/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.discPCFG;

import edu.berkeley.nlp.PCFGLA.ArrayParser;
import edu.berkeley.nlp.PCFGLA.Binarization;
import edu.berkeley.nlp.PCFGLA.ConditionalTrainer;
import edu.berkeley.nlp.PCFGLA.ConstrainedHierarchicalTwoChartParser;
import edu.berkeley.nlp.PCFGLA.ConstrainedTwoChartsParser;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.Lexicon;
import edu.berkeley.nlp.PCFGLA.ParserConstrainer;
import edu.berkeley.nlp.PCFGLA.ParserData;
import edu.berkeley.nlp.PCFGLA.SimpleLexicon;
import edu.berkeley.nlp.PCFGLA.SpanPredictor;
import edu.berkeley.nlp.PCFGLA.StateSetTreeList;
import edu.berkeley.nlp.discPCFG.EncodedDatum;
import edu.berkeley.nlp.discPCFG.Encoding;
import edu.berkeley.nlp.discPCFG.IndexLinearizer;
import edu.berkeley.nlp.discPCFG.Linearizer;
import edu.berkeley.nlp.discPCFG.ObjectiveFunction;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Numberer;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPInputStream;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class ParsingObjectiveFunction
implements ObjectiveFunction {
    /** Value of {@link #myRegularization} for which ensureCache applies no penalty. */
    public static final int NO_REGULARIZATION = 0;
    /** Value of {@link #myRegularization} that routes through {@link #l1_regularize}. */
    public static final int L1_REGULARIZATION = 1;
    /** Value of {@link #myRegularization} that routes through {@link #l2_regularize}. */
    public static final int L2_REGULARIZATION = 2;
    // Model components; re-fetched from the linearizer every time a new weight vector is evaluated.
    Grammar grammar;
    SimpleLexicon lexicon;
    SpanPredictor spanPredictor;
    // Converts between the flat weight vector and the grammar/lexicon/span predictor.
    Linearizer linearizer;
    // Regularization mode; compared against the *_REGULARIZATION constants above.
    int myRegularization;
    // Regularization strength; the regularizers divide by sigma * sigma.
    double sigma;
    // Cached results for the point stored in x. Value and gradients are sign-flipped by ensureCache.
    double lastValue;
    double[] lastDerivative;
    double[] lastUnregularizedDerivative;
    // Private copy of the most recently evaluated point; null forces the next query to recompute.
    double[] x;
    // Length of the weight vector, as reported by the linearizer.
    int dimension;
    // Sizes of the grammar / lexicon / span sections; l1_regularize walks x in this order.
    int nGrammarWeights;
    int nLexiconWeights;
    int nSpanWeights;
    // Number of Calculator tasks (and worker threads in pool).
    int nProcesses;
    // Base path of the constraint files, which are named "<base>-<n>.data".
    String consBaseName;
    // Training trees partitioned per process; index i belongs to tasks[i].
    StateSetTreeList[] trainingTrees;
    ExecutorService pool;
    Calculator[] tasks;
    // Smallest (sign-flipped) objective seen so far; an improvement triggers a grammar save.
    double bestObjectiveSoFar;
    // Prefix of the file names used when saving grammars.
    String outFileName;
    // Gold span feature counts; added onto the last spanGoldCounts.length gradient entries.
    double[] spanGoldCounts;

    /**
     * Reports the length of the weight vector this objective is defined over.
     *
     * @return the dimension obtained from the linearizer at construction time
     */
    @Override
    public int dimension() {
        final int size = dimension;
        return size;
    }

    /**
     * Evaluates the objective at {@code x}, recomputing only when {@code x} differs from the
     * cached point.
     *
     * @param x the weight vector to evaluate
     * @return the cached value; ensureCache stores the sign-flipped (and, if enabled,
     *         regularized) objective
     */
    @Override
    public double valueAt(double[] x) {
        ensureCache(x);
        final double cachedValue = lastValue;
        return cachedValue;
    }

    /**
     * Returns the gradient matching {@link #valueAt(double[])} for the same point.
     *
     * @param x the weight vector to evaluate
     * @return the cached gradient array itself (not a copy), sign-flipped like the value
     */
    @Override
    public double[] derivativeAt(double[] x) {
        ensureCache(x);
        final double[] cachedGradient = lastDerivative;
        return cachedGradient;
    }

    /**
     * Returns the gradient as it was before the regularization term was applied.
     *
     * @param x the weight vector to evaluate
     * @return the cached unregularized gradient array itself (not a copy), sign-flipped
     */
    @Override
    public double[] unregularizedDerivativeAt(double[] x) {
        ensureCache(x);
        final double[] cachedGradient = lastUnregularizedDerivative;
        return cachedGradient;
    }

    /**
     * Brings the cached value and gradients up to date for {@code proposed_x}.
     *
     * <p>Does nothing when {@link #requiresUpdate(double[])} reports that the point is already
     * cached. Otherwise it pushes the weights into the model through the linearizer, runs every
     * {@link Calculator} (inline when there is a single process, on the pool otherwise), sums
     * their objectives and gradients, adds the gold span counts, applies the configured
     * regularizer, flips the sign of value and gradients, and saves the grammar when the
     * sign-flipped objective improved.
     *
     * @param proposed_x the weight vector to evaluate
     * @throws IllegalStateException if a calculator task fails or the wait is interrupted
     */
    private void ensureCache(double[] proposed_x) {
        if (!this.requiresUpdate(proposed_x)) {
            return;
        }
        this.linearizer.delinearizeWeights(proposed_x);
        this.grammar = this.linearizer.getGrammar();
        this.lexicon = this.linearizer.getLexicon();
        this.spanPredictor = this.linearizer.getSpanPredictor();
        if (this.x == null) {
            this.x = (double[])proposed_x.clone();
        } else {
            System.arraycopy(proposed_x, 0, this.x, 0, this.x.length);
        }

        System.out.print("Task: ");
        Future<?>[] submits = new Future<?>[this.nProcesses];
        if (this.nProcesses > 1) {
            for (int i = 0; i < this.nProcesses; ++i) {
                submits[i] = this.pool.submit(this.tasks[i]);
            }
            // No polling on isDone(): Future.get() below blocks until each task has finished.
        }

        double objective = 0.0;
        int nUnparsable = 0;
        int nIncorrectLL = 0;
        double[] derivatives = new double[this.dimension];
        for (int i = 0; i < this.nProcesses; ++i) {
            Counts counts = this.nProcesses == 1 ? this.tasks[0].call() : this.awaitCounts(submits[i], i);
            objective += counts.myObjective;
            for (int j = 0; j < this.dimension; ++j) {
                derivatives[j] += counts.myDerivatives[j];
            }
            nUnparsable += counts.unparsableTrees;
            nIncorrectLL += counts.incorrectLLTrees;
        }

        if (this.spanPredictor != null) {
            // The span weights occupy the tail of the gradient vector.
            int offset = this.dimension - this.spanGoldCounts.length;
            double total = 0.0;
            for (int rule = 0; rule < this.spanGoldCounts.length; ++rule) {
                // total sums the gradient entries as they were before the gold counts are added.
                total += derivatives[offset + rule];
                derivatives[offset + rule] += this.spanGoldCounts[rule];
                if (SloppyMath.isVeryDangerous(derivatives[offset + rule])) {
                    System.out.print(derivatives[offset + rule] + " ");
                }
            }
            System.out.println(total);
        }

        System.out.print(" done. ");
        if (nUnparsable > 0) {
            System.out.println(nUnparsable + " trees were not parsable.");
        }
        if (nIncorrectLL > 0) {
            System.out.println(nIncorrectLL + " trees had a higher gold LL than all LL.");
        }
        System.out.print("\nThe objective was " + objective);

        // Snapshot before regularization; l1_regularize may still zero entries of this copy.
        this.lastUnregularizedDerivative = (double[])derivatives.clone();
        switch (this.myRegularization) {
            case L2_REGULARIZATION: {
                objective = this.l2_regularize(objective, derivatives);
                System.out.print(" and is " + objective + " after L2 regularization");
                break;
            }
            case L1_REGULARIZATION: {
                objective = this.l1_regularize(objective, derivatives);
                System.out.print(" and is " + objective + " after L1 regularization");
                break;
            }
            default: {
                break;
            }
        }
        System.out.print(".\n");

        // Callers receive the sign-flipped objective and gradients.
        objective = -objective;
        for (int index = 0; index < derivatives.length; ++index) {
            derivatives[index] = -derivatives[index];
            this.lastUnregularizedDerivative[index] = -this.lastUnregularizedDerivative[index];
        }
        this.lastValue = objective;
        this.lastDerivative = derivatives;

        if (objective < this.bestObjectiveSoFar && !ConditionalTrainer.Options.dontSaveGrammarsAfterEachIteration) {
            this.bestObjectiveSoFar = objective;
            ParserData pData = new ParserData(this.lexicon, this.grammar, this.spanPredictor, Numberer.getNumberers(), this.grammar.numSubStates, 1, 0, Binarization.RIGHT);
            // Scale the objective up to at least five digits so it yields a usable integer suffix.
            double val = objective;
            if (val != 0.0) {
                while (Math.abs(val) < 10000.0) {
                    val *= 10.0;
                }
            }
            int value = (int)val;
            System.out.println("Saving grammar to " + this.outFileName + "-" + value + ".");
            if (!pData.Save(this.outFileName + "-" + value)) {
                System.out.println("Saving failed!");
            }
        }
    }

    /**
     * Blocks until the given calculator task has finished and returns its counts.
     *
     * <p>On failure the cached point is dropped so a later query cannot return stale results
     * for the weights that were being evaluated.
     *
     * @param pending   the future returned by the pool for the task
     * @param taskIndex index of the task, used only in error messages
     * @return the counts produced by the task
     * @throws IllegalStateException if the task threw or the calling thread was interrupted
     */
    private Counts awaitCounts(Future<?> pending, int taskIndex) {
        try {
            return (Counts)pending.get();
        }
        catch (ExecutionException e) {
            this.x = null;
            Throwable cause = e.getCause() != null ? e.getCause() : e;
            throw new IllegalStateException("Calculator task " + taskIndex + " failed", cause);
        }
        catch (InterruptedException e) {
            this.x = null;
            // Restore the interrupt flag for whoever owns this thread.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting for calculator task " + taskIndex, e);
        }
    }

    /**
     * Decides whether {@code proposed_x} differs from the cached point.
     *
     * <p>While scanning, NaN entries of {@code proposed_x} are overwritten in place with
     * {@code Double.NEGATIVE_INFINITY}. The test uses {@link Double#isNaN(double)} because a
     * comparison with {@code == Double.NaN} is always false. The scan stops at the first
     * differing entry, so NaN entries after that position are left untouched.
     *
     * @param proposed_x the candidate weight vector; may be modified as described above
     * @return {@code true} if nothing is cached yet or any scanned entry differs
     */
    private boolean requiresUpdate(double[] proposed_x) {
        if (this.x == null) {
            return true;
        }
        for (int i = 0; i < this.x.length; ++i) {
            if (Double.isNaN(proposed_x[i])) {
                // Report the value the optimizer actually proposed before replacing it.
                System.out.println("Optimizer proposed " + proposed_x[i]);
                proposed_x[i] = Double.NEGATIVE_INFINITY;
            }
            if (this.x[i] != proposed_x[i]) {
                return true;
            }
        }
        return false;
    }

    /**
     * Applies an L2 penalty based on the cached point {@code x}.
     *
     * <p>Subtracts {@code sum(x_i^2) / (2 * sigma^2)} from the objective and
     * {@code x_i / sigma^2} from each gradient entry. A gradient entry that
     * {@code SloppyMath.isVeryDangerous} flags afterwards is reset to zero. If the incoming
     * objective is itself flagged, nothing is changed.
     *
     * @param objective   the unregularized objective
     * @param derivatives the gradient, updated in place
     * @return the penalized objective, or {@code objective} unchanged if it was flagged
     */
    public double l2_regularize(double objective, double[] derivatives) {
        if (SloppyMath.isVeryDangerous(objective)) {
            return objective;
        }
        final double variance = this.sigma * this.sigma;

        double sumOfSquares = 0.0;
        for (double weight : this.x) {
            sumOfSquares += weight * weight;
        }

        for (int k = 0; k < this.x.length; ++k) {
            derivatives[k] -= this.x[k] / variance;
            if (SloppyMath.isVeryDangerous(derivatives[k])) {
                System.out.println("Setting regularized derivative to zero because it is Inf.");
                derivatives[k] = 0.0;
            }
        }
        return objective - sumOfSquares / (2.0 * variance);
    }

    /**
     * Applies an L1 penalty based on the cached point {@code x}.
     *
     * <p>The weight vector is treated as three consecutive sections (grammar, lexicon, span).
     * Each section contributes {@code sum(|x_i|) / (2 * s)} to the penalty, where {@code s} is
     * {@code sigma^2} for grammar and lexicon weights and {@code 1.0} for span weights. The
     * penalties are accumulated in {@code double}; integer accumulators would truncate the sum
     * at every step and drop every weight with magnitude below one.
     *
     * <p>Gradient entries are adjusted by {@code 1 / s} as described on
     * {@link #applyL1Penalty(double[], int, int, double)}. Entries of
     * {@code lastUnregularizedDerivative} are zeroed alongside gradient entries that are zeroed.
     *
     * <p>NOTE(review): the objective uses {@code |x| / (2 * s)} while the gradient step is
     * {@code 1 / s}; the factor of two looks inconsistent and has been left as found - confirm.
     *
     * @param objective   the unregularized objective
     * @param derivatives the gradient, updated in place
     * @return the penalized objective, or {@code objective} unchanged if
     *         {@code SloppyMath.isVeryDangerous} flags it
     */
    public double l1_regularize(double objective, double[] derivatives) {
        if (SloppyMath.isVeryDangerous(objective)) {
            return objective;
        }
        double sigma2 = this.sigma * this.sigma;
        double sigma2span = 1.0;
        double sigma2lex = sigma2;

        int lexiconStart = this.nGrammarWeights;
        int spanStart = lexiconStart + this.nLexiconWeights;

        double penaltyGr = this.sumAbsoluteWeights(0, this.nGrammarWeights) / (2.0 * sigma2);
        double penaltyLex = this.sumAbsoluteWeights(lexiconStart, this.nLexiconWeights) / (2.0 * sigma2lex);
        double penaltySpan = this.sumAbsoluteWeights(spanStart, this.nSpanWeights) / (2.0 * sigma2span);
        objective -= penaltyGr + penaltyLex + penaltySpan;

        this.applyL1Penalty(derivatives, 0, this.nGrammarWeights, sigma2);
        this.applyL1Penalty(derivatives, lexiconStart, this.nLexiconWeights, sigma2lex);
        this.applyL1Penalty(derivatives, spanStart, this.nSpanWeights, sigma2span);
        return objective;
    }

    /** Sums {@code |x[i]|} over {@code count} cached weights starting at {@code start}. */
    private double sumAbsoluteWeights(int start, int count) {
        double sum = 0.0;
        for (int index = start; index < start + count; ++index) {
            sum += Math.abs(this.x[index]);
        }
        return sum;
    }

    /**
     * Adjusts the gradient entries of one weight section for the L1 penalty.
     *
     * <p>With {@code step = 1 / scale}: a negative weight adds {@code step}, a positive weight
     * subtracts it. For a weight that is exactly zero, a gradient below {@code -step} has
     * {@code step} subtracted, a gradient above {@code step} has it added, and anything in
     * between is zeroed together with the matching unregularized entry. Entries that end up
     * flagged by {@code SloppyMath.isVeryDangerous} or larger than 1e10 in magnitude are zeroed
     * as well.
     *
     * <p>NOTE(review): for zero weights the adjustment moves the gradient away from zero; an
     * orthant-wise pseudo-gradient would normally shrink it. Behaviour is unchanged - confirm.
     */
    private void applyL1Penalty(double[] derivatives, int start, int count, double scale) {
        double step = 1.0 / scale;
        for (int index = start; index < start + count; ++index) {
            double weight = this.x[index];
            if (weight < 0.0) {
                derivatives[index] += step;
            } else if (weight > 0.0) {
                derivatives[index] -= step;
            } else if (derivatives[index] < -step) {
                derivatives[index] -= step;
            } else if (derivatives[index] > step) {
                derivatives[index] += step;
            } else {
                derivatives[index] = 0.0;
                this.lastUnregularizedDerivative[index] = 0.0;
            }
            if (SloppyMath.isVeryDangerous(derivatives[index]) || Math.abs(derivatives[index]) > 1.0E10) {
                System.out.println("Setting regularized derivative to zero because it is " + derivatives[index]);
                derivatives[index] = 0.0;
                this.lastUnregularizedDerivative[index] = 0.0;
            }
        }
    }

    /**
     * Creates an instance with every field left at its default value; in particular no
     * linearizer, thread pool or calculator tasks are set up here.
     */
    public ParsingObjectiveFunction() {
    }

    /**
     * Sets up the objective: caches the model components and weight-section sizes from the
     * linearizer, splits the training trees across {@code nProc} calculators and creates the
     * thread pool.
     *
     * <p>Trees are dealt out in consecutive blocks, with block {@code b} going to process
     * {@code b % nProc}. The block size is the length of the constraint block loaded from
     * {@code consName + "-0.data"} when that load succeeds, otherwise
     * {@code trainTrees.size() / nProc}. It is clamped to at least one so that fewer trees than
     * processes (or an empty first constraint block) cannot cause a division by zero.
     *
     * @param linearizer              maps between the weight vector and the model
     * @param trainTrees              all training trees
     * @param sigma                   regularization strength
     * @param regularization          one of the {@code *_REGULARIZATION} constants
     * @param consName                base name of the constraint files
     * @param nProc                   number of calculator tasks/threads; must be at least 1
     * @param outName                 prefix for saved grammar files
     * @param doNotProjectConstraints forwarded to each calculator
     * @param combinedLexicon         not read by this constructor
     * @throws IllegalArgumentException if {@code nProc} is less than 1
     */
    public ParsingObjectiveFunction(Linearizer linearizer, StateSetTreeList trainTrees, double sigma, int regularization, String consName, int nProc, String outName, boolean doNotProjectConstraints, boolean combinedLexicon) {
        if (nProc < 1) {
            throw new IllegalArgumentException("nProc must be at least 1, got " + nProc);
        }
        this.sigma = sigma;
        this.myRegularization = regularization;
        this.grammar = linearizer.getGrammar();
        this.lexicon = linearizer.getLexicon();
        this.spanPredictor = linearizer.getSpanPredictor();
        this.linearizer = linearizer;
        this.outFileName = outName;
        this.dimension = linearizer.dimension();
        this.nGrammarWeights = linearizer.getNGrammarWeights();
        this.nLexiconWeights = linearizer.getNLexiconWeights();
        this.nSpanWeights = linearizer.getNSpanWeights();
        if (this.spanPredictor != null) {
            this.spanGoldCounts = this.spanPredictor.countGoldSpanFeatures(trainTrees);
        }

        int nTreesPerBlock = trainTrees.size() / nProc;
        this.consBaseName = consName;
        boolean[][][][][] tmp = ParserConstrainer.loadData(String.valueOf(consName) + "-0.data");
        if (tmp != null) {
            nTreesPerBlock = tmp.length;
        }
        if (nTreesPerBlock < 1) {
            // Guards the modulo below when there are fewer trees than processes.
            nTreesPerBlock = 1;
        }

        this.nProcesses = nProc;
        this.trainingTrees = new StateSetTreeList[this.nProcesses];
        for (int i = 0; i < this.nProcesses; ++i) {
            this.trainingTrees[i] = new StateSetTreeList();
        }

        int block = -1;
        for (int i = 0; i < trainTrees.size(); ++i) {
            if (i % nTreesPerBlock == 0) {
                ++block;
            }
            this.trainingTrees[block % this.nProcesses].add(trainTrees.get(i));
        }
        for (int i = 0; i < this.nProcesses; ++i) {
            System.out.println("Process " + i + " has " + this.trainingTrees[i].size() + " trees.");
        }

        this.pool = Executors.newFixedThreadPool(this.nProcesses);
        this.tasks = new Calculator[this.nProcesses];
        for (int i = 0; i < this.nProcesses; ++i) {
            this.tasks[i] = this.newCalculator(doNotProjectConstraints, i);
        }
        this.bestObjectiveSoFar = Double.POSITIVE_INFINITY;
    }

    /**
     * Shuts down the worker pool. Does nothing when no pool was created, which is the case for
     * instances built with the no-argument constructor.
     */
    @Override
    public void shutdown() {
        if (this.pool != null) {
            this.pool.shutdown();
        }
    }

    /**
     * Builds the calculator for process {@code i} from that process's tree partition and the
     * current model components.
     *
     * @param doNotProjectConstraints forwarded to the calculator
     * @param i                       index of the process / tree partition
     * @return a new calculator
     */
    protected Calculator newCalculator(boolean doNotProjectConstraints, int i) {
        final StateSetTreeList partition = trainingTrees[i];
        return new Calculator(partition, consBaseName, i, grammar, lexicon, spanPredictor, dimension, doNotProjectConstraints);
    }

    /**
     * Returns the weight vector the linearizer currently produces for the model.
     *
     * @return the result of {@code linearizer.getLinearizedWeights()}
     */
    public double[] getCurrentWeights() {
        final double[] weights = linearizer.getLinearizedWeights();
        return weights;
    }

    /**
     * Not implemented for this objective: every argument is ignored.
     *
     * @return always {@code null}
     */
    @Override
    public <F, L> double[] getLogProbabilities(EncodedDatum datum, double[] weights, Encoding<F, L> encoding, IndexLinearizer indexLinearizer) {
        return null;
    }

    /**
     * Changes the regularization strength. The cached point is dropped so the next query
     * recomputes, and the best-objective tracker is reset.
     *
     * @param newSigma the new value for {@code sigma}
     */
    public void setSigma(double newSigma) {
        bestObjectiveSoFar = Double.POSITIVE_INFINITY;
        x = null;
        sigma = newSigma;
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    class Calculator
    implements Callable {
        // Scores gold trees (and increments gold counts) when the hierarchical chart option is off.
        ArrayParser gParser;
        // Computes the constrained inside/outside scores and expected counts for each sentence.
        ConstrainedTwoChartsParser eParser;
        // The trees this calculator is responsible for.
        StateSetTreeList myTrees;
        // Base name of the constraint files; null means no constraints are loaded or used.
        String consName;
        // Index of this calculator; also selects which constraint blocks it loads.
        int myID;
        // Set from the constructor's dimension argument; not read elsewhere in this class.
        int nCounts;
        // Result object built at the end of the most recent call().
        Counts myCounts;
        // Per-tree constraints, loaded on the first call(); entries may be null.
        boolean[][][][][] myConstraints;
        // Counters that are reset at the start of every call().
        int unparsableTrees;
        int incorrectLLTrees;
        // When true, loadConstraints skips eParser.projectConstraints.
        boolean doNotProjectConstraints;
        // Gradient accumulator; reallocated at the start of every call().
        double[] myDerivatives;

        /**
         * Stores this calculator's tree partition and settings, then creates its two parsers.
         *
         * @param myT        trees handled by this calculator
         * @param consN      base name of the constraint files, or null for none
         * @param i          index of this calculator
         * @param gr         grammar handed to both parsers
         * @param lex        lexicon handed to both parsers
         * @param sp         span predictor handed to the constrained parser
         * @param dimension  stored in {@code nCounts}
         * @param notProject stored in {@code doNotProjectConstraints}
         */
        Calculator(StateSetTreeList myT, String consN, int i, Grammar gr, Lexicon lex, SpanPredictor sp, int dimension, boolean notProject) {
            // Plain settings first ...
            myID = i;
            myTrees = myT;
            consName = consN;
            nCounts = dimension;
            doNotProjectConstraints = notProject;
            // ... then the parsers, gold parser before the constrained one.
            gParser = new ArrayParser(gr, lex);
            eParser = newEParser(gr, lex, sp);
        }

        /**
         * Creates the constrained parser: the hierarchical variant when
         * {@code ConditionalTrainer.Options.hierarchicalChart} is set, the plain two-chart
         * parser otherwise.
         *
         * @param gr  grammar for the parser
         * @param lex lexicon for the parser
         * @param sp  span predictor for the parser
         * @return the newly created parser
         */
        protected ConstrainedTwoChartsParser newEParser(Grammar gr, Lexicon lex, SpanPredictor sp) {
            if (ConditionalTrainer.Options.hierarchicalChart) {
                return new ConstrainedHierarchicalTwoChartParser(gr, lex, sp, gr.finalLevel);
            }
            return new ConstrainedTwoChartsParser(gr, lex, sp);
        }

        /**
         * Loads the constraints for every tree of this calculator into {@code myConstraints}.
         *
         * <p>Constraints are read block by block from files named
         * {@code consName-<n>.data}, where {@code n = block * nProcesses + myID}. Unless
         * {@code doNotProjectConstraints} is set, each entry is passed through
         * {@code eParser.projectConstraints}. An entry whose length does not match the
         * sentence length is reported and replaced by {@code null}. When {@code consName} is
         * null the array is allocated but left empty.
         *
         * @throws IllegalStateException if a required constraint block cannot be loaded or
         *         contains no entries
         */
        protected void loadConstraints() {
            this.myConstraints = new boolean[this.myTrees.size()][][][][];
            if (this.consName == null) {
                return;
            }
            boolean[][][][][] curBlock = null;
            int block = 0;
            int i = 0;
            for (int tree = 0; tree < this.myTrees.size(); ++tree) {
                if (curBlock == null || i >= curBlock.length) {
                    int blockNumber = block * ParsingObjectiveFunction.this.nProcesses + this.myID;
                    String fileName = String.valueOf(this.consName) + "-" + blockNumber + ".data";
                    curBlock = this.loadData(fileName);
                    if (curBlock == null || curBlock.length == 0) {
                        // Fail with the file name instead of an anonymous NPE/AIOOBE at curBlock[i].
                        throw new IllegalStateException("Could not load constraints for tree " + tree + " of calculator " + this.myID + " from " + fileName);
                    }
                    ++block;
                    i = 0;
                    System.out.print(".");
                }
                if (!this.doNotProjectConstraints) {
                    this.eParser.projectConstraints(curBlock[i], false);
                }
                this.myConstraints[tree] = curBlock[i];
                ++i;
                if (this.myConstraints[tree].length != this.myTrees.get(tree).getYield().size()) {
                    System.out.println("My ID: " + this.myID + ", block: " + block + ", sentence: " + i);
                    System.out.println("Sentence length and constraints length do not match!");
                    this.myConstraints[tree] = null;
                }
            }
        }

        /**
         * Computes this calculator's share of the objective and gradient.
         *
         * <p>For every tree it obtains the constrained all-trees score from {@code eParser} and
         * the gold-tree score (from {@code eParser} when the hierarchical chart option is set,
         * from {@code gParser} otherwise). Trees failing {@link #sanityCheckLLs} contribute a
         * fixed -1000 to the objective and nothing to the gradient; all others contribute
         * {@code goldLL - allLL} and have their expected counts accumulated.
         *
         * <p>Constraints are loaded on the first invocation. A {@code null} constraint entry
         * (as left by {@code loadConstraints} on a length mismatch) is passed to the parser
         * as-is, the same value that is passed when no constraints are configured. A non-null
         * entry of the wrong length terminates the JVM.
         *
         * @return the counts for this calculator's trees (also stored in {@code myCounts})
         */
        public Counts call() {
            double myObjective = 0.0;
            this.myDerivatives = new double[ParsingObjectiveFunction.this.dimension];
            this.unparsableTrees = 0;
            this.incorrectLLTrees = 0;
            if (this.myConstraints == null) {
                this.loadConstraints();
            }
            final boolean noSmoothing = false;
            final boolean debugOutput = false;
            // Never advanced in this method; kept so the diagnostic text below is unchanged.
            final int block = 0;
            // Never accumulated in this method; kept so the summary line below is unchanged.
            double totalBias = 0.0;
            int i = -1;
            for (Tree<StateSet> stateSetTree : this.myTrees) {
                // Advance per tree so indexing and progress dots work with or without constraints.
                ++i;
                List<StateSet> yield = stateSetTree.getYield();
                boolean[][][][] cons = null;
                if (this.consName != null) {
                    cons = this.myConstraints[i];
                    if (cons != null && cons.length != yield.size()) {
                        System.out.println("My ID: " + this.myID + ", block: " + block + ", sentence: " + i);
                        System.out.println("Sentence length (" + yield.size() + ") and constraints length (" + cons.length + ") do not match!");
                        System.exit(-1);
                    }
                }
                double allLL = this.eParser.doConstrainedInsideOutsideScores(yield, cons, noSmoothing, null, null, false);
                double goldLL = ConditionalTrainer.Options.hierarchicalChart ? this.eParser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput, this.eParser.spanScores) : this.gParser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput, this.eParser.spanScores);
                if (i % 500 == 0) {
                    System.out.print(".");
                }
                if (!this.sanityCheckLLs(goldLL, allLL, stateSetTree)) {
                    myObjective += -1000.0;
                    continue;
                }
                this.eParser.incrementExpectedCounts(ParsingObjectiveFunction.this.linearizer, this.myDerivatives, yield);
                if (ConditionalTrainer.Options.hierarchicalChart) {
                    this.eParser.incrementExpectedGoldCounts(ParsingObjectiveFunction.this.linearizer, this.myDerivatives, stateSetTree);
                } else {
                    this.gParser.incrementExpectedGoldCounts(ParsingObjectiveFunction.this.linearizer, this.myDerivatives, stateSetTree);
                }
                myObjective += goldLL - allLL;
            }
            this.myCounts = new Counts(myObjective, this.myDerivatives, this.unparsableTrees, this.incorrectLLTrees);
            System.out.println("\nAverage bias: " + (totalBias / (double)this.myTrees.size()) + "\n");
            System.out.print(" " + this.myID + " ");
            return this.myCounts;
        }

        /**
         * Reads a gzip-compressed, Java-serialized {@code boolean[][][][][]} from a file.
         *
         * <p>The streams are closed in a {@code finally} block so that a failure while opening
         * a wrapper or while deserializing does not leak the file handle. Errors raised while
         * closing are ignored, because the data has either been read completely or the read
         * failure has already been reported.
         *
         * @param fileName path of the file to read
         * @return the deserialized array, or {@code null} if the file could not be read or the
         *         serialized class could not be resolved
         */
        public boolean[][][][][] loadData(String fileName) {
            FileInputStream fis = null;
            GZIPInputStream gzis = null;
            ObjectInputStream in = null;
            try {
                fis = new FileInputStream(fileName);
                gzis = new GZIPInputStream(fis);
                in = new ObjectInputStream(gzis);
                return (boolean[][][][][])in.readObject();
            }
            catch (IOException e) {
                System.out.println("IOException\n" + e);
                return null;
            }
            catch (ClassNotFoundException e) {
                System.out.println("Class not found!");
                return null;
            }
            finally {
                // Closing the outermost stream that was created also closes the ones it wraps.
                try {
                    if (in != null) {
                        in.close();
                    } else if (gzis != null) {
                        gzis.close();
                    } else if (fis != null) {
                        fis.close();
                    }
                }
                catch (IOException ignored) {
                    // Nothing useful can be done about a failed close of a read-only stream.
                }
            }
        }

        /**
         * Checks whether the two scores of a tree are usable and updates the failure counters.
         *
         * <p>If {@code SloppyMath.isVeryDangerous} flags either score, the tree counts as
         * unparsable. If the gold score exceeds the all-trees score by more than 1e-4, the
         * tree is reported and counted as having an incorrect likelihood.
         *
         * @param goldLL       score of the gold tree
         * @param allLL        score over all trees
         * @param stateSetTree the tree, printed when the second check fails
         * @return {@code true} if both checks pass
         */
        protected boolean sanityCheckLLs(double goldLL, double allLL, Tree<StateSet> stateSetTree) {
            final boolean unusable = SloppyMath.isVeryDangerous(allLL) || SloppyMath.isVeryDangerous(goldLL);
            if (unusable) {
                this.unparsableTrees++;
                return false;
            }
            final boolean goldWithinBound = !(goldLL - allLL > 1.0E-4);
            if (goldWithinBound) {
                return true;
            }
            System.out.println("Something is wrong! The gold LL is " + goldLL + " and the all LL is " + allLL);
            System.out.println(stateSetTree);
            this.incorrectLLTrees++;
            return false;
        }
    }

    /**
     * Plain holder for the result of one calculator run: its objective share, its gradient
     * contribution and the two failure counters.
     */
    class Counts {
        /** Objective contribution of the calculator's trees. */
        double myObjective;
        /** Gradient contribution; the array passed to the constructor is stored, not copied. */
        double[] myDerivatives;
        /** Number of trees counted as unparsable. */
        int unparsableTrees;
        /** Number of trees counted as having an incorrect likelihood. */
        int incorrectLLTrees;

        /**
         * Stores the given values in the corresponding fields.
         *
         * @param myObjective   objective contribution
         * @param myDerivatives gradient contribution (kept by reference)
         * @param unpars        number of unparsable trees
         * @param incorr        number of trees with an incorrect likelihood
         */
        public Counts(double myObjective, double[] myDerivatives, int unpars, int incorr) {
            this.incorrectLLTrees = incorr;
            this.unparsableTrees = unpars;
            this.myDerivatives = myDerivatives;
            this.myObjective = myObjective;
        }
    }
}

