/*
 * Decompiled with CFR 0.152.
 */
package edu.jhu.thrax.hadoop.paraphrasing;

import edu.jhu.thrax.hadoop.datatypes.Annotation;
import edu.jhu.thrax.hadoop.datatypes.FeatureMap;
import edu.jhu.thrax.hadoop.datatypes.RuleWritable;
import edu.jhu.thrax.hadoop.features.annotation.AnnotationFeature;
import edu.jhu.thrax.hadoop.features.annotation.AnnotationFeatureFactory;
import edu.jhu.thrax.hadoop.features.pivot.PivotedAnnotationFeature;
import edu.jhu.thrax.hadoop.features.pivot.PivotedFeature;
import edu.jhu.thrax.hadoop.features.pivot.PivotedFeatureFactory;
import edu.jhu.thrax.util.BackwardsCompatibility;
import edu.jhu.thrax.util.Vocabulary;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.mapreduce.Reducer;

/**
 * Reducer that turns translation rules into paraphrase rules by pivoting.
 *
 * <p>Incoming rules are grouped by their (left-hand side, source side). For every ordered pair of
 * collected target sides within a group (including a target paired with itself), one rule is
 * emitted whose source is the first translation and whose target is the second. Its features are
 * computed by the configured {@link PivotedFeature}s, and it is pruned according to
 * {@code thrax.pruning}.
 *
 * <p>A group boundary is detected by comparing each key with the previous one. This relies on
 * rules that share a left-hand side and source side arriving consecutively (presumably guaranteed
 * by the job's sort order; not visible from this class).
 *
 * <p>The framework context type is spelled out as
 * {@code Reducer<RuleWritable, FeatureMap, RuleWritable, FeatureMap>.Context}. It must match the
 * superclass parameterization exactly, otherwise {@code reduce} overloads instead of overriding
 * {@link Reducer#reduce} and is never called.
 */
public class PivotingReducer
extends Reducer<RuleWritable, FeatureMap, RuleWritable, FeatureMap> {
    /** Left-hand side of the group currently being collected; 0 before the first key. */
    private int currentLhs;
    /** Source side of the group currently being collected; null before the first key. */
    private int[] currentSrc;
    /** The first (at most two) nonterminals found in {@link #currentSrc}. */
    private int[] nts;
    /** Unpruned translations collected for the current group. */
    private List<ParaphrasePattern> targets;
    private List<PivotedFeature> pivotedFeatures;
    private List<AnnotationFeature> annotationFeatures;
    /** Pruning applied to incoming translation rules, keyed by the Vocabulary id of a feature label. */
    private Map<Integer, PruningRule> translationPruningRules;
    /** Pruning applied to emitted pivoted rules, keyed by the Vocabulary id of a feature label. */
    private Map<Integer, PruningRule> pivotedPruningRules;

    /**
     * Loads the vocabulary, instantiates the pivoted and annotation features named in
     * {@code thrax.features}, and parses the pruning rules in {@code thrax.pruning}.
     */
    @Override
    protected void setup(Reducer<RuleWritable, FeatureMap, RuleWritable, FeatureMap>.Context context)
            throws IOException, InterruptedException {
        Configuration conf = context.getConfiguration();
        String vocabulary_path = conf.getRaw("thrax.work-dir") + "vocabulary/part-*";
        Vocabulary.initialize(conf, vocabulary_path);

        String features = BackwardsCompatibility.equivalent(conf.get("thrax.features", ""));
        this.pivotedFeatures = PivotedFeatureFactory.getAll(features);
        // Requested annotation features are pivoted through a single PivotedAnnotationFeature.
        if (!AnnotationFeatureFactory.getAll(features).isEmpty()) {
            this.pivotedFeatures.add(new PivotedAnnotationFeature());
        }

        // The annotation features scored on each translation rule are exactly the
        // prerequisites declared by the pivoted features.
        Set<String> prerequisite_features = new HashSet<String>();
        for (PivotedFeature pf : this.pivotedFeatures) {
            prerequisite_features.addAll(pf.getPrerequisites());
        }
        this.annotationFeatures = new ArrayList<AnnotationFeature>();
        for (String f_name : prerequisite_features) {
            AnnotationFeature af = AnnotationFeatureFactory.get(f_name);
            if (af != null) {
                this.annotationFeatures.add(af);
            }
        }

        this.currentLhs = 0;
        this.currentSrc = null;
        this.nts = null;
        this.targets = new ArrayList<ParaphrasePattern>();

        String pruning_rules = BackwardsCompatibility.equivalent(conf.get("thrax.pruning", ""));
        this.translationPruningRules = this.getTranslationPruningRules(pruning_rules);
        this.pivotedPruningRules = this.getPivotedPruningRules(pruning_rules);

        for (AnnotationFeature af : this.annotationFeatures) {
            af.init(context);
        }
    }

    /**
     * Collects the translations of one (lhs, source) group, pivoting the previous group when a
     * new one begins.
     *
     * <p>Keys with a zero left-hand side or an empty source side are skipped, together with every
     * other key in their group. No state from the preceding group is carried over.
     *
     * @throws RuntimeException if a key carries more than one feature map
     */
    @Override
    protected void reduce(RuleWritable key, Iterable<FeatureMap> values,
            Reducer<RuleWritable, FeatureMap, RuleWritable, FeatureMap>.Context context)
            throws IOException, InterruptedException {
        if (this.currentLhs == 0 || key.lhs != this.currentLhs
                || !Arrays.equals(key.source, this.currentSrc)) {
            if (this.isPivotableGroup()) {
                this.pivotAll(context);
            }
            this.startGroup(key);
        }
        if (!this.isPivotableGroup()) {
            return;
        }
        boolean seen_first = false;
        for (FeatureMap features : values) {
            if (seen_first) {
                throw new RuntimeException("Multiple feature maps for one rule:" + key.toString());
            }
            seen_first = true;
            Annotation annotation = (Annotation) features.get("Annotation");
            for (AnnotationFeature f : this.annotationFeatures) {
                features.put(f.getLabel(), f.score(key, annotation));
            }
            if (PivotingReducer.prune(features, this.translationPruningRules)) {
                context.getCounter(PivotingCounters.EF_PRUNED).increment(1L);
            } else {
                this.targets.add(new ParaphrasePattern(
                        key.target, this.nts, this.currentLhs, key.monotone, features));
            }
        }
    }

    /** Pivots the last collected group, if it is one that should be pivoted. */
    @Override
    protected void cleanup(Reducer<RuleWritable, FeatureMap, RuleWritable, FeatureMap>.Context context)
            throws IOException, InterruptedException {
        if (this.isPivotableGroup()) {
            this.pivotAll(context);
        }
    }

    /** Resets all per-group state for the group that {@code key} begins. */
    private void startGroup(RuleWritable key) {
        this.targets.clear();
        this.currentLhs = key.lhs;
        // NOTE(review): the key's arrays are retained by reference. This assumes
        // RuleWritable.readFields allocates fresh arrays per key (group detection above already
        // depends on that); confirm against RuleWritable.
        this.currentSrc = key.source;
        this.nts = this.isPivotableGroup() ? PivotingReducer.extractNonterminals(this.currentSrc) : null;
    }

    /** Returns whether the current group has a nonzero left-hand side and a non-empty source. */
    private boolean isPivotableGroup() {
        return this.currentLhs != 0 && this.currentSrc.length > 0;
    }

    /**
     * Emits a pivoted rule for every ordered pair of collected translations, including each
     * translation paired with itself. Increments the F_READ and EF_READ counters.
     */
    protected void pivotAll(Reducer<RuleWritable, FeatureMap, RuleWritable, FeatureMap>.Context context)
            throws IOException, InterruptedException {
        context.getCounter(PivotingCounters.F_READ).increment(1L);
        context.getCounter(PivotingCounters.EF_READ).increment((long) this.targets.size());
        for (int i = 0; i < this.targets.size(); ++i) {
            for (int j = i; j < this.targets.size(); ++j) {
                this.pivotOne(this.targets.get(i), this.targets.get(j), context);
                if (i != j) {
                    this.pivotOne(this.targets.get(j), this.targets.get(i), context);
                }
            }
        }
    }

    /**
     * Builds the rule rewriting {@code src}'s target side into {@code tgt}'s target side.
     *
     * <p>The rule is written unless the pivoted pruning rules reject it. It keeps {@code src}'s
     * left-hand side and is monotone iff both inputs have the same monotone flag.
     *
     * @throws RuntimeException if any pivoted feature fails; the message lists both rules and
     *     their features, and the original exception is attached as the cause
     */
    protected void pivotOne(ParaphrasePattern src, ParaphrasePattern tgt,
            Reducer<RuleWritable, FeatureMap, RuleWritable, FeatureMap>.Context context)
            throws IOException, InterruptedException {
        RuleWritable pivoted_rule = new RuleWritable();
        FeatureMap pivoted_features = new FeatureMap();
        pivoted_rule.lhs = src.lhs;
        pivoted_rule.source = src.rhs;
        pivoted_rule.target = tgt.rhs;
        pivoted_rule.monotone = src.monotone == tgt.monotone;
        try {
            for (PivotedFeature f : this.pivotedFeatures) {
                pivoted_features.put(f.getLabel(), f.pivot(src.features, tgt.features));
            }
        } catch (Exception e) {
            throw new RuntimeException(Vocabulary.getWords(src.rhs) + " \n " + Vocabulary.getWords(tgt.rhs)
                    + " \n " + PivotingReducer.describeFeatures(src.features)
                    + " \n " + PivotingReducer.describeFeatures(tgt.features) + " \n", e);
        }
        if (PivotingReducer.prune(pivoted_features, this.pivotedPruningRules)) {
            context.getCounter(PivotingCounters.EE_PRUNED).increment(1L);
        } else {
            context.write(pivoted_rule, pivoted_features);
            context.getCounter(PivotingCounters.EE_WRITTEN).increment(1L);
        }
    }

    /** Renders a feature map as space-terminated {@code word=value} entries, for error messages. */
    private static String describeFeatures(FeatureMap features) {
        StringBuilder sb = new StringBuilder();
        for (int w : features.keySet()) {
            sb.append(Vocabulary.word(w)).append('=').append(features.get(w)).append(' ');
        }
        return sb.toString();
    }

    /**
     * Parses pruning rules for emitted pivoted rules.
     *
     * <p>Each {@code name<threshold} or {@code name>threshold} clause prunes on the label of the
     * pivoted feature {@code name}.
     *
     * @throws IllegalArgumentException if a clause is malformed or names an unknown pivoted feature
     */
    protected Map<Integer, PruningRule> getPivotedPruningRules(String conf_string) {
        Map<Integer, PruningRule> rules = new HashMap<Integer, PruningRule>();
        for (PruningClause clause : PivotingReducer.parsePruningClauses(conf_string)) {
            int label = Vocabulary.id(PivotingReducer.lookupPivotedFeature(clause.featureName).getLabel());
            rules.put(label, new PruningRule(clause.smaller, clause.threshold));
        }
        return rules;
    }

    /**
     * Parses pruning rules for incoming translation rules.
     *
     * <p>For each clause, every lower-bound label of the named pivoted feature is pruned in the
     * clause's direction. Every upper-bound label is pruned in the opposite direction, with the
     * same threshold.
     *
     * @throws IllegalArgumentException if a clause is malformed or names an unknown pivoted feature
     */
    protected Map<Integer, PruningRule> getTranslationPruningRules(String conf_string) {
        Map<Integer, PruningRule> rules = new HashMap<Integer, PruningRule>();
        for (PruningClause clause : PivotingReducer.parsePruningClauses(conf_string)) {
            PivotedFeature feature = PivotingReducer.lookupPivotedFeature(clause.featureName);
            Set<String> lower_bound_labels = feature.getLowerBoundLabels();
            if (lower_bound_labels != null) {
                for (String label : lower_bound_labels) {
                    rules.put(Vocabulary.id(label), new PruningRule(clause.smaller, clause.threshold));
                }
            }
            Set<String> upper_bound_labels = feature.getUpperBoundLabels();
            if (upper_bound_labels != null) {
                for (String label : upper_bound_labels) {
                    rules.put(Vocabulary.id(label), new PruningRule(!clause.smaller, clause.threshold));
                }
            }
        }
        return rules;
    }

    /**
     * Splits a comma-separated pruning setting into clauses.
     *
     * <p>Entries containing neither {@code <} nor {@code >} (including the empty setting) are
     * ignored. Feature names are trimmed.
     */
    private static List<PruningClause> parsePruningClauses(String conf_string) {
        List<PruningClause> clauses = new ArrayList<PruningClause>();
        for (String rule_string : conf_string.split("\\s*,\\s*")) {
            boolean smaller;
            String[] f;
            if (rule_string.contains("<")) {
                f = rule_string.split("<");
                smaller = true;
            } else if (rule_string.contains(">")) {
                f = rule_string.split(">");
                smaller = false;
            } else {
                continue;
            }
            if (f.length < 2) {
                throw new IllegalArgumentException("Malformed pruning rule (missing threshold): " + rule_string);
            }
            clauses.add(new PruningClause(f[0].trim(), smaller, Float.parseFloat(f[1])));
        }
        return clauses;
    }

    /** Looks up a pivoted feature by name, failing with a descriptive message if it is unknown. */
    private static PivotedFeature lookupPivotedFeature(String name) {
        PivotedFeature feature = PivotedFeatureFactory.get(name);
        if (feature == null) {
            throw new IllegalArgumentException("Unknown pivoted feature in pruning rule: '" + name + "'");
        }
        return feature;
    }

    /** Returns whether any rule whose feature is present in {@code features} rejects its value. */
    protected static boolean prune(FeatureMap features, Map<Integer, PruningRule> rules) {
        for (Map.Entry<Integer, PruningRule> e : rules.entrySet()) {
            if (features.containsKey(e.getKey())
                    && e.getValue().applies((FloatWritable) features.get(e.getKey()))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the first two nonterminal tokens of {@code source}, in order.
     *
     * <p>Returns fewer if the source has fewer, or an empty array if it has none. A token id of 0
     * is used as the "none found yet" sentinel.
     */
    protected static int[] extractNonterminals(int[] source) {
        int first_nt = 0;
        for (int token : source) {
            if (!Vocabulary.nt(token)) {
                continue;
            }
            if (first_nt == 0) {
                first_nt = token;
            } else {
                return new int[] {first_nt, token};
            }
        }
        return first_nt == 0 ? new int[0] : new int[] {first_nt};
    }

    /** One parsed {@code name<threshold} / {@code name>threshold} clause of the pruning setting. */
    private static final class PruningClause {
        final String featureName;
        final boolean smaller;
        final float threshold;

        PruningClause(String featureName, boolean smaller, float threshold) {
            this.featureName = featureName;
            this.smaller = smaller;
            this.threshold = threshold;
        }
    }

    /** A threshold test on a single feature value. */
    static class PruningRule {
        private final boolean smaller;
        private final float threshold;

        PruningRule(boolean smaller, float threshold) {
            this.smaller = smaller;
            this.threshold = threshold;
        }

        /** Returns true (prune) if the value is strictly below/above the threshold. */
        protected boolean applies(FloatWritable value) {
            return this.smaller ? value.get() < this.threshold : value.get() > this.threshold;
        }
    }

    /** One collected translation of the current group's source side, with its features. */
    static class ParaphrasePattern {
        int arity;
        int lhs;
        int[] rhs;
        boolean monotone;
        FeatureMap features;

        public ParaphrasePattern(int[] target, int[] nts, int lhs, boolean mono, FeatureMap features) {
            this.arity = nts.length;
            this.lhs = lhs;
            this.rhs = target;
            this.monotone = mono;
            // Copied because Hadoop may reuse the value instance across iterations of the values.
            this.features = new FeatureMap(features);
        }
    }

    private static enum PivotingCounters {
        F_READ,
        EF_READ,
        EF_PRUNED,
        EE_PRUNED,
        EE_WRITTEN;
    }
}

