/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.classify;

import edu.berkeley.nlp.classify.Feature;
import edu.berkeley.nlp.classify.FeatureExtractor;
import edu.berkeley.nlp.classify.FeatureManager;
import edu.berkeley.nlp.math.CachingDifferentiableFunction;
import edu.berkeley.nlp.math.LBFGSMinimizer;
import edu.berkeley.nlp.util.CollectionUtils;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Pair;
import java.util.Collection;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A linear regression model over string-named features.
 *
 * <p>The predicted response for an input is the dot product of the input's feature counts
 * (produced by a {@link FeatureExtractor}) with a learned weight vector. Models are built with
 * {@link Factory#train(Collection)}.
 *
 * @param <I> the input type
 */
public class LinearRegression<I> {
    private final FeatureExtractor<I, String> featureExtractor;
    private final double[] weights;
    private final FeatureManager featureManager;

    private LinearRegression(FeatureExtractor<I, String> featureExtractor, FeatureManager featureManager, double[] weights) {
        this.featureExtractor = featureExtractor;
        this.featureManager = featureManager;
        this.weights = weights;
    }

    /**
     * Predicts the response for {@code input} as {@code sum_f count(f) * weight(f)}.
     *
     * @param input the instance to score
     * @return the linear score of the input's extracted features
     */
    public double getResponse(I input) {
        Counter<String> featCounts = this.featureExtractor.extractFeatures(input);
        double score = 0.0;
        for (String f : featCounts.keySet()) {
            Feature feat = this.featureManager.getFeature(f);
            // A feature with no trained weight contributes nothing; skip it instead of indexing
            // past the end of the weight vector. NOTE(review): whether a locked FeatureManager
            // returns null or throws for never-seen keys is not visible here -- confirm.
            if (feat == null) {
                continue;
            }
            int index = feat.getIndex();
            if (index < 0 || index >= this.weights.length) {
                continue;
            }
            score += featCounts.getCount(f) * this.weights[index];
        }
        return score;
    }

    /** Tiny demo: fits a bag-of-tokens regression on two examples and prints one prediction. */
    public static void main(String[] args) {
        List<String> elem1 = CollectionUtils.makeList("a", "b", "c");
        List<String> elem2 = CollectionUtils.makeList("a", "b");
        Pair<List<String>, Double> d1 = Pair.newPair(elem1, 3.0);
        Pair<List<String>, Double> d2 = Pair.newPair(elem2, 2.0);
        // Each token becomes a feature whose count is the number of times it occurs.
        FeatureExtractor<List<String>, String> featExtractor = new FeatureExtractor<List<String>, String>() {

            @Override
            public Counter<String> extractFeatures(List<String> instance) {
                Counter<String> counts = new Counter<String>();
                for (String elem : instance) {
                    counts.incrementCount(elem, 1.0);
                }
                return counts;
            }
        };
        Factory<List<String>> factory = new Factory<List<String>>(featExtractor);
        List<Pair<List<String>, Double>> datums = CollectionUtils.makeList(d1, d2);
        LinearRegression<List<String>> linearRegressionModel = factory.train(datums);
        double guess = linearRegressionModel.getResponse(elem1);
        System.out.println("guess: " + guess);
    }

    /**
     * Decompiler artifact (the synthetic accessor javac emitted for the private constructor).
     * The trailing argument is ignored. It is kept only so existing same-package callers still
     * compile; use {@link Factory#train(Collection)} to obtain models.
     *
     * @deprecated synthetic accessor; not part of the intended API
     */
    @Deprecated
    @SuppressWarnings("unchecked")
    LinearRegression(FeatureExtractor featureExtractor, FeatureManager featureManager, double[] dArray, LinearRegression linearRegression) {
        this(featureExtractor, featureManager, dArray);
    }

    /**
     * Trains {@link LinearRegression} models by minimizing squared error with L-BFGS.
     *
     * <p>Not designed for reuse across trainings: {@link #train(Collection)} locks this
     * factory's {@link FeatureManager}, and the resulting model shares that manager.
     *
     * @param <I> the input type
     */
    public static class Factory<I> {
        /** Current weight vector; overwritten with every point the minimizer evaluates. */
        double[] weights;
        FeatureManager featureManager;
        FeatureExtractor<I, String> featureExtractor;
        Collection<Pair<I, Double>> trainingData;

        /**
         * @param featureExtractor maps each input to string-keyed feature counts
         */
        public Factory(FeatureExtractor<I, String> featureExtractor) {
            this.featureExtractor = featureExtractor;
            this.featureManager = new FeatureManager();
        }

        /** Extracts the input's string features and maps them to indexed {@link Feature}s. */
        private Counter<Feature> getFeatures(I input) {
            Counter<String> strCounts = this.featureExtractor.extractFeatures(input);
            Counter<Feature> featCounts = new Counter<Feature>();
            for (String f : strCounts.keySet()) {
                featCounts.setCount(this.featureManager.getFeature(f), strCounts.getCount(f));
            }
            return featCounts;
        }

        /** Dot product of {@code featureCounts} with the current {@link #weights}. */
        private double getScore(Counter<Feature> featureCounts) {
            double score = 0.0;
            for (Feature feat : featureCounts.keySet()) {
                score += featureCounts.getCount(feat) * this.weights[feat.getIndex()];
            }
            return score;
        }

        /** Registers every feature seen in the training data, then locks the manager. */
        private void extractAllFeatures() {
            for (Pair<I, Double> datum : this.trainingData) {
                Counter<String> counts = this.featureExtractor.extractFeatures(datum.getFirst());
                for (String f : counts.keySet()) {
                    // Called for its side effect of registering f with the manager.
                    this.featureManager.getFeature(f);
                }
            }
            this.featureManager.lock();
        }

        /** Debug helper: renders the current weights as a feature-to-weight counter string. */
        private String examineWeights() {
            Counter<Feature> counts = new Counter<Feature>();
            for (int i = 0; i < this.weights.length; ++i) {
                counts.setCount(this.featureManager.getFeature(i), this.weights[i]);
            }
            return counts.toString();
        }

        /**
         * Fits weights minimizing {@code sum 0.5 * (score - gold)^2} (no regularization), starting
         * from all-zero weights, with L-BFGS at tolerance {@code 1e-4}.
         *
         * @param trainingData pairs of (input, gold response)
         * @return a model holding its own copy of the learned weights
         */
        public LinearRegression<I> train(Collection<Pair<I, Double>> trainingData) {
            this.trainingData = trainingData;
            this.extractAllFeatures();
            ObjectiveFunction objFn = new ObjectiveFunction();
            LBFGSMinimizer gradMinimizer = new LBFGSMinimizer();
            double[] initial = new double[objFn.dimension()];
            this.weights = gradMinimizer.minimize(objFn, initial, 1.0E-4);
            // Copy so the model does not alias an array the factory/objective may later write.
            return new LinearRegression<I>(this.featureExtractor, this.featureManager, this.weights.clone());
        }

        /** Squared-error objective over {@link #trainingData} and its gradient. */
        private class ObjectiveFunction
        extends CachingDifferentiableFunction {
            private ObjectiveFunction() {
            }

            @Override
            protected Pair<Double, double[]> calculate(double[] x) {
                // getScore reads Factory.weights, so point it at the candidate weights.
                Factory.this.weights = x;
                double objective = 0.0;
                double[] gradient = new double[this.dimension()];
                for (Pair<I, Double> datum : Factory.this.trainingData) {
                    Counter<Feature> featCounts = Factory.this.getFeatures(datum.getFirst());
                    double diff = Factory.this.getScore(featCounts) - datum.getSecond();
                    objective += 0.5 * diff * diff;
                    // d/dw_f of 0.5 * diff^2 is count(f) * diff.
                    for (Feature feat : featCounts.keySet()) {
                        gradient[feat.getIndex()] += featCounts.getCount(feat) * diff;
                    }
                }
                return Pair.newPair(objective, gradient);
            }

            @Override
            public int dimension() {
                return Factory.this.featureManager.getNumFeatures();
            }

            /** Not implemented; always returns {@code null}. */
            public double[] unregularizedDerivativeAt(double[] x) {
                return null;
            }
        }
    }
}

