/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.crf;

import edu.berkeley.nlp.classify.Encoding;
import edu.berkeley.nlp.classify.FeatureExtractor;
import edu.berkeley.nlp.crf.Inference;
import edu.berkeley.nlp.crf.InstanceSequence;
import edu.berkeley.nlp.crf.LabeledInstanceSequence;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Logger;
import edu.berkeley.nlp.util.Pair;
import java.util.ArrayList;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Accumulates per-label feature counts for CRF training over sequences whose
 * edges connect each position {@code i} to its predecessor {@code i - 1}.
 *
 * <p>Both count methods return one {@link Counter} per label index of the
 * {@link Encoding} (list index == label index). Features fired at a position are
 * credited to the counter of the label at that position: the gold label for
 * empirical counts, and every label weighted by its posterior for expected counts.
 *
 * @param <V> vertex (per-position) instance type
 * @param <E> edge instance type
 * @param <F> feature type
 * @param <L> label type
 */
public class Counts<V, E, F, L> {
    private final Encoding<F, L> encoding;
    private final FeatureExtractor<V, F> vertexExtractor;
    private final FeatureExtractor<E, F> edgeExtractor;
    private final Inference<V, E, F, L> inf;

    /**
     * Creates a counter for the given label/feature encoding and feature extractors.
     * An {@link Inference} built from the same three arguments is used to compute
     * the quantities needed for expected counts.
     *
     * @param encoding        maps labels to indices and back
     * @param vertexExtractor extracts features from per-position (vertex) instances
     * @param edgeExtractor   extracts features from edge instances
     */
    public Counts(Encoding<F, L> encoding, FeatureExtractor<V, F> vertexExtractor, FeatureExtractor<E, F> edgeExtractor) {
        this.encoding = encoding;
        this.vertexExtractor = vertexExtractor;
        this.edgeExtractor = edgeExtractor;
        this.inf = new Inference<V, E, F, L>(encoding, vertexExtractor, edgeExtractor);
    }

    /**
     * Sums the features that fire along each sequence's gold labeling.
     *
     * <p>For each position {@code i}, the vertex features of position {@code i} are
     * added to the counter of the gold label at {@code i}. For {@code i > 0}, the
     * features of the edge instance built from position {@code i} and the gold label
     * at {@code i - 1} are also added to that counter.
     *
     * @param sequences labeled sequences to count over
     * @return {@code encoding.getNumLabels()} counters, indexed by label index
     */
    public List<Counter<F>> getEmpiricalCounts(List<? extends LabeledInstanceSequence<V, E, L>> sequences) {
        List<Counter<F>> counts = newCounterList(this.encoding.getNumLabels());
        for (LabeledInstanceSequence<V, E, L> s : sequences) {
            for (int i = 0; i < s.getSequenceLength(); i++) {
                Counter<F> vertexFeatures = this.vertexExtractor.extractFeatures(s.getVertexInstance(i));
                Counter<F> goldCounts = counts.get(this.encoding.getLabelIndex(s.getGoldLabel(i)));
                goldCounts.incrementAll(vertexFeatures);
                if (i > 0) {
                    // The edge instance is conditioned on the gold label of the previous position.
                    Counter<F> edgeFeatures = this.edgeExtractor.extractFeatures(s.getEdgeInstance(i, s.getGoldLabel(i - 1)));
                    goldCounts.incrementAll(edgeFeatures);
                }
            }
        }
        return counts;
    }

    /**
     * Computes the model's expected feature counts and the summed log normalizer.
     *
     * <p>For each sequence, alphas and betas are computed under weights {@code w}.
     * {@code Math.log(inf.getNormalizationConstant(alpha, beta))} is added to the
     * running total. Then, for every position {@code i} and label {@code l}, the
     * vertex features are added to counter {@code l} scaled by
     * {@code vertexPosteriors[i][l]}. For {@code i > 0} and each previous label
     * {@code pl}, the features of the edge instance built with label {@code pl} are
     * added to counter {@code cl} scaled by {@code edgePosteriors[i][pl][cl]}.
     * NOTE(review): the posteriors are presumably marginal probabilities; confirm
     * against {@link Inference}.
     *
     * @param sequences sequences to compute expectations over
     * @param w         weight vector passed through to {@link Inference}
     * @return a pair of (sum of per-sequence log normalizers,
     *         {@code encoding.getNumLabels()} expected-count counters indexed by label)
     */
    public Pair<Double, List<Counter<F>>> getLogNormalizationAndExpectedCounts(List<? extends InstanceSequence<V, E, L>> sequences, double[] w) {
        int numLabels = this.encoding.getNumLabels();
        List<Counter<F>> counts = newCounterList(numLabels);
        double totalLogZ = 0.0;
        int numProcessed = 0;
        Logger.startTrack("Computing expected counts", new Object[0]);
        try {
            for (InstanceSequence<V, E, L> s : sequences) {
                double[][] alpha = this.inf.getAlphas(s, w);
                double[][] beta = this.inf.getBetas(s, w);
                totalLogZ += Math.log(this.inf.getNormalizationConstant(alpha, beta));
                double[][] vertexPosteriors = this.inf.getVertexPosteriors(alpha, beta);
                double[][][] edgePosteriors = this.inf.getEdgePosteriors(s, w, alpha, beta);
                for (int i = 0; i < s.getSequenceLength(); i++) {
                    Counter<F> vertexFeatures = this.vertexExtractor.extractFeatures(s.getVertexInstance(i));
                    for (int l = 0; l < numLabels; l++) {
                        counts.get(l).incrementAll(vertexFeatures.scaledClone(vertexPosteriors[i][l]));
                    }
                    if (i > 0) {
                        for (int pl = 0; pl < numLabels; pl++) {
                            // Edge features depend only on the previous label, so extract once per pl.
                            Counter<F> edgeFeatures = this.edgeExtractor.extractFeatures(s.getEdgeInstance(i, this.encoding.getLabel(pl)));
                            for (int cl = 0; cl < numLabels; cl++) {
                                counts.get(cl).incrementAll(edgeFeatures.scaledClone(edgePosteriors[i][pl][cl]));
                            }
                        }
                    }
                }
                Logger.logs("Processed %d/%d sentences", ++numProcessed, sequences.size());
            }
        } finally {
            // Close the track even if inference or feature extraction throws.
            Logger.endTrack();
        }
        return Pair.makePair(totalLogZ, counts);
    }

    /**
     * Creates a list of {@code numLabels} empty counters, one per label index.
     */
    private List<Counter<F>> newCounterList(int numLabels) {
        List<Counter<F>> counts = new ArrayList<Counter<F>>(numLabels);
        for (int l = 0; l < numLabels; l++) {
            counts.add(new Counter<F>());
        }
        return counts;
    }
}

