/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.classify;

import edu.berkeley.nlp.classify.FeatureExtractor;
import edu.berkeley.nlp.classify.LabeledInstance;
import edu.berkeley.nlp.classify.ProbabilisticClassifier;
import edu.berkeley.nlp.classify.ProbabilisticClassifierFactory;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.CounterMap;
import java.util.ArrayList;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Naive Bayes classifier over counted features.
 *
 * <p>Each label is scored by
 * {@code log P(label) + sum_f count(f) * log P'(f | label)}. Here {@code P'} interpolates a
 * per-label feature distribution with a label-independent backoff distribution (see
 * {@link #getFeatureProb}). The scores are then normalized into a posterior over the labels
 * present in the prior.
 *
 * @param <I> instance type
 * @param <F> feature type produced by the {@link FeatureExtractor}
 * @param <L> label type
 */
public class NaiveBayesClassifier<I, F, L>
implements ProbabilisticClassifier<I, L> {
    /** Backoff interpolation weight used when none is supplied (the historical fixed value). */
    private static final double DEFAULT_ALPHA = 0.1;

    /** Per-label feature distributions, read by this class as P(feat | label). */
    private final CounterMap<L, F> featureProbs;
    /** Label-independent feature distribution used for smoothing. */
    private final Counter<F> backoffProbs;
    /** Label priors, read by this class as P(label); its key set defines the candidate labels. */
    private final Counter<L> labelProbs;
    private final FeatureExtractor<I, F> featureExtractor;
    /** Weight given to {@link #backoffProbs} when interpolating; always in (0, 1]. */
    private final double alpha;

    /**
     * Computes the posterior distribution over labels for {@code instance}.
     *
     * <p>Features are extracted once per call. A feature is ignored in two cases:
     * <ul>
     *   <li>its count is zero;</li>
     *   <li>it has zero backoff probability, i.e. it was never observed in training.</li>
     * </ul>
     * Such a feature would contribute {@code log(0)} equally to every label and turn every
     * posterior into NaN, while carrying no evidence that distinguishes one label from another.
     *
     * @param instance the instance to classify
     * @return a counter mapping every label in the prior to its posterior probability
     */
    @Override
    public Counter<L> getProbabilities(I instance) {
        // Feature extraction does not depend on the label, so do it once, outside the label loop.
        Counter<F> featCounts = this.featureExtractor.extractFeatures(instance);
        Counter<L> logScores = new Counter<L>();
        ArrayList<Double> logScoreList = new ArrayList<Double>();
        for (L label : this.labelProbs.keySet()) {
            double logScore = Math.log(this.labelProbs.getCount(label));
            for (F feat : featCounts.keySet()) {
                double count = featCounts.getCount(feat);
                // Skipping these is exact when count == 0 (0 * log p == 0 for p > 0). It also
                // avoids -Infinity / NaN for unseen features, whose factor is identical across
                // labels and would cancel in normalization.
                if (count == 0.0 || this.backoffProbs.getCount(feat) == 0.0) {
                    continue;
                }
                logScore += count * Math.log(this.getFeatureProb(feat, label));
            }
            logScoreList.add(logScore);
            logScores.setCount(label, logScore);
        }
        // Normalize in log space to avoid underflow from long products of small probabilities.
        double logNorm = SloppyMath.logAdd(logScoreList);
        Counter<L> posteriors = new Counter<L>();
        for (L label : this.labelProbs.keySet()) {
            posteriors.setCount(label, Math.exp(logScores.getCount(label) - logNorm));
        }
        return posteriors;
    }

    /**
     * Returns the smoothed probability of {@code feat} given {@code label}.
     *
     * <p>The value is {@code (1 - alpha) * P(feat | label) + alpha * P_backoff(feat)}.
     */
    private double getFeatureProb(F feat, L label) {
        double mleProb = this.featureProbs.getCount(label, feat);
        double backoffProb = this.backoffProbs.getCount(feat);
        return (1.0 - this.alpha) * mleProb + this.alpha * backoffProb;
    }

    /**
     * Returns the label with the highest posterior probability for {@code instance}.
     *
     * <p>The choice is delegated to {@code Counter.argMax()} on {@link #getProbabilities}.
     */
    @Override
    public L getLabel(I instance) {
        return this.getProbabilities(instance).argMax();
    }

    /**
     * Creates a classifier using the default backoff weight (0.1).
     *
     * @param featureProbs per-label feature distributions, P(feat | label)
     * @param backoffProbs label-independent feature distribution used for smoothing
     * @param labelProbs label priors, P(label)
     * @param featureExtractor extracts counted features from an instance
     */
    public NaiveBayesClassifier(CounterMap<L, F> featureProbs, Counter<F> backoffProbs, Counter<L> labelProbs, FeatureExtractor<I, F> featureExtractor) {
        this(featureProbs, backoffProbs, labelProbs, featureExtractor, DEFAULT_ALPHA);
    }

    /**
     * Creates a classifier with an explicit backoff interpolation weight.
     *
     * @param featureProbs per-label feature distributions, P(feat | label)
     * @param backoffProbs label-independent feature distribution used for smoothing
     * @param labelProbs label priors, P(label)
     * @param featureExtractor extracts counted features from an instance
     * @param alpha weight on the backoff distribution, in (0, 1]. It must be positive so that
     *     every feature seen in training has non-zero probability under every label.
     * @throws IllegalArgumentException if {@code alpha} is not in (0, 1]
     */
    public NaiveBayesClassifier(CounterMap<L, F> featureProbs, Counter<F> backoffProbs, Counter<L> labelProbs, FeatureExtractor<I, F> featureExtractor, double alpha) {
        this.featureProbs = featureProbs;
        this.backoffProbs = backoffProbs;
        this.labelProbs = labelProbs;
        this.featureExtractor = featureExtractor;
        this.alpha = checkAlpha(alpha);
    }

    /** Validates a backoff weight and returns it unchanged; NaN is rejected as well. */
    private static double checkAlpha(double alpha) {
        if (!(alpha > 0.0 && alpha <= 1.0)) {
            throw new IllegalArgumentException("alpha must be in (0, 1], got " + alpha);
        }
        return alpha;
    }

    /**
     * Trains {@link NaiveBayesClassifier}s by relative-frequency estimation.
     *
     * <p>For each training instance it accumulates three sets of counts:
     * <ul>
     *   <li>label counts;</li>
     *   <li>per-label feature counts;</li>
     *   <li>overall feature counts.</li>
     * </ul>
     * Each is then normalized with {@code normalize()}.
     *
     * @param <I> instance type
     * @param <F> feature type
     * @param <L> label type
     */
    public static class Factory<I, F, L>
    implements ProbabilisticClassifierFactory<I, L> {
        private final FeatureExtractor<I, F> featureExtractor;
        private final double alpha;

        /**
         * Creates a factory whose classifiers use the default backoff weight (0.1).
         *
         * @param featureExtractor extracts counted features from an instance
         */
        public Factory(FeatureExtractor<I, F> featureExtractor) {
            this(featureExtractor, DEFAULT_ALPHA);
        }

        /**
         * Creates a factory whose classifiers use the given backoff weight.
         *
         * @param featureExtractor extracts counted features from an instance
         * @param alpha weight on the backoff distribution, in (0, 1]
         * @throws IllegalArgumentException if {@code alpha} is not in (0, 1]
         */
        public Factory(FeatureExtractor<I, F> featureExtractor, double alpha) {
            this.featureExtractor = featureExtractor;
            this.alpha = checkAlpha(alpha);
        }

        /**
         * Estimates priors and feature distributions from {@code trainingData}.
         *
         * @param trainingData labeled instances to count over
         * @return a classifier built from the normalized counts
         */
        @Override
        public ProbabilisticClassifier<I, L> trainClassifier(List<LabeledInstance<I, L>> trainingData) {
            CounterMap<L, F> featureProbs = new CounterMap<L, F>();
            Counter<F> backoffProbs = new Counter<F>();
            Counter<L> labelProbs = new Counter<L>();
            for (LabeledInstance<I, L> instance : trainingData) {
                L label = instance.getLabel();
                labelProbs.incrementCount(label, 1.0);
                Counter<F> featCounts = this.featureExtractor.extractFeatures(instance.getInput());
                for (F feat : featCounts.keySet()) {
                    double count = featCounts.getCount(feat);
                    backoffProbs.incrementCount(feat, count);
                    featureProbs.incrementCount(label, feat, count);
                }
            }
            featureProbs.normalize();
            labelProbs.normalize();
            backoffProbs.normalize();
            return new NaiveBayesClassifier<I, F, L>(featureProbs, backoffProbs, labelProbs, this.featureExtractor, this.alpha);
        }
    }
}

