/*
 * Decompiled with CFR 0.152.
 */
package edu.jhu.thrax.hadoop.features.mapred;

import edu.jhu.thrax.hadoop.comparators.FieldComparator;
import edu.jhu.thrax.hadoop.comparators.PrimitiveArrayMarginalComparator;
import edu.jhu.thrax.hadoop.datatypes.Annotation;
import edu.jhu.thrax.hadoop.datatypes.FeaturePair;
import edu.jhu.thrax.hadoop.datatypes.PrimitiveUtils;
import edu.jhu.thrax.hadoop.datatypes.RuleWritable;
import edu.jhu.thrax.hadoop.features.mapred.MapReduceFeature;
import edu.jhu.thrax.util.Vocabulary;
import java.io.IOException;
import java.util.Arrays;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Partitioner;
import org.apache.hadoop.mapreduce.Reducer;

/**
 * Map-reduce feature that scores each rule with the negative log relative
 * frequency of its left-hand side given its target side, i.e.
 * {@code -log(count(target, LHS) / count(target))}, stored under the label
 * {@value #LABEL}.
 *
 * <p>For every input rule the mapper emits three keys carrying the rule's count:
 * the rule itself, a (target, LHS) marginal (source replaced by
 * {@code PrimitiveArrayMarginalComparator.MARGINAL}), and a target-only marginal
 * (source replaced by the marginal marker and LHS set to 0). The reducer relies
 * on {@link Comparator} delivering, for each target, the target marginal first,
 * then for each LHS its marginal before the concrete rules.
 */
public class LhsGivenTargetPhraseFeature
extends MapReduceFeature {
    public static final String NAME = "lhs_given_e";
    public static final String LABEL = "p(LHS|e)";
    /** Score assigned to glue rules. */
    private static final FloatWritable ZERO = new FloatWritable(0.0f);

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public String getLabel() {
        return LABEL;
    }

    /** Sort order required by {@link Reduce}; see {@link Comparator}. */
    @Override
    public Class<? extends WritableComparator> sortComparatorClass() {
        return Comparator.class;
    }

    /**
     * Partitions by target side. Presumably this guarantees that a rule and all of
     * its marginals reach the same reducer; confirm against RuleWritable.TargetPartitioner.
     */
    public Class<? extends Partitioner<RuleWritable, Writable>> partitionerClass() {
        return RuleWritable.TargetPartitioner.class;
    }

    public Class<? extends Mapper<RuleWritable, Annotation, RuleWritable, IntWritable>> mapperClass() {
        return Map.class;
    }

    public Class<? extends Reducer<RuleWritable, IntWritable, RuleWritable, FeaturePair>> reducerClass() {
        return Reduce.class;
    }

    /** Unary glue rules receive a score of 0 for this feature. */
    @Override
    public void unaryGlueRuleScore(int nt, java.util.Map<Integer, Writable> map) {
        map.put(Vocabulary.id(LABEL), ZERO);
    }

    /** Binary glue rules receive a score of 0 for this feature. */
    @Override
    public void binaryGlueRuleScore(int nt, java.util.Map<Integer, Writable> map) {
        map.put(Vocabulary.id(LABEL), ZERO);
    }

    /**
     * Raw-byte comparator for serialized {@link RuleWritable} keys.
     *
     * <p>The serialized layout it assumes is: one leading byte, then the LHS id
     * as a VInt, then the fields. The leading byte appears to encode the monotone
     * flag; confirm against RuleWritable.write. Keys are ordered by:
     * <ol>
     *   <li>target (field 1),</li>
     *   <li>absolute value of the LHS id,</li>
     *   <li>source (field 0),</li>
     *   <li>the leading byte.</li>
     * </ol>
     * Fields 0 and 1 are compared with {@link PrimitiveArrayMarginalComparator}.
     */
    public static class Comparator
    extends WritableComparator {
        private static final WritableComparator PARRAY_COMP = new PrimitiveArrayMarginalComparator();
        private static final FieldComparator SOURCE_COMP = new FieldComparator(0, PARRAY_COMP);
        private static final FieldComparator TARGET_COMP = new FieldComparator(1, PARRAY_COMP);

        public Comparator() {
            super(RuleWritable.class);
        }

        /**
         * Compares two serialized rules in the order described on the class.
         *
         * @throws IllegalArgumentException wrapping the IOException raised when a VInt cannot be decoded
         */
        @Override
        public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
            try {
                // Header = leading byte + variable-length LHS id; the fields start right after it.
                int h1 = WritableUtils.decodeVIntSize(b1[s1 + 1]) + 1;
                int h2 = WritableUtils.decodeVIntSize(b2[s2 + 1]) + 1;

                int cmp = TARGET_COMP.compare(b1, s1 + h1, l1 - h1, b2, s2 + h2, l2 - h2);
                if (cmp != 0) {
                    return cmp;
                }

                // Compare LHS by magnitude, so LHS 0 (the target marginal) sorts before any real LHS.
                int lhs1 = Math.abs(WritableComparator.readVInt(b1, s1 + 1));
                int lhs2 = Math.abs(WritableComparator.readVInt(b2, s2 + 1));
                cmp = PrimitiveUtils.compare(lhs1, lhs2);
                if (cmp != 0) {
                    return cmp;
                }

                cmp = SOURCE_COMP.compare(b1, s1 + h1, l1 - h1, b2, s2 + h2, l2 - h2);
                if (cmp != 0) {
                    return cmp;
                }
                return PrimitiveUtils.compare(b1[s1], b2[s2]);
            }
            catch (IOException e) {
                throw new IllegalArgumentException(e);
            }
        }
    }

    /**
     * Turns the marginal counts into {@code -log(count(target, LHS) / count(target))}.
     * It emits that score for every concrete rule.
     *
     * <p>The reducer is stateful across keys. It assumes each target's marginal
     * (LHS 0) is reduced before that target's (target, LHS) marginals, and each
     * (target, LHS) marginal is reduced before the matching rules.
     */
    private static class Reduce
    extends Reducer<RuleWritable, IntWritable, RuleWritable, FeaturePair> {
        /** Total count of the current target side, summed over all LHS. */
        private int marginal;
        /** Score for the current (target, LHS) pair, shared by every rule in that group. */
        private FloatWritable prob;

        // Hadoop instantiates this class reflectively.
        private Reduce() {
        }

        /** Loads the vocabulary from {@code <thrax.work-dir>vocabulary/part-*}. */
        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            Configuration conf = context.getConfiguration();
            String vocabulary_path = conf.getRaw("thrax.work-dir") + "vocabulary/part-*";
            Vocabulary.initialize(conf, vocabulary_path);
        }

        /**
         * Handles one key. The action depends on the kind of key:
         * <ul>
         *   <li>Target marginal (LHS 0): records the target count.</li>
         *   <li>(target, LHS) marginal (source is the marginal marker): computes the score.</li>
         *   <li>Concrete rule: writes the current score.</li>
         * </ul>
         */
        @Override
        protected void reduce(RuleWritable key, Iterable<IntWritable> values, Context context) throws IOException, InterruptedException {
            if (key.lhs == 0) {
                this.marginal = sum(values);
                return;
            }
            if (Arrays.equals(key.source, PrimitiveArrayMarginalComparator.MARGINAL)) {
                int count = sum(values);
                this.prob = new FloatWritable((float)(-Math.log((float)count / (float)this.marginal)));
                return;
            }
            context.write(key, new FeaturePair(Vocabulary.id(LABEL), this.prob));
        }

        /** Sums the counts of one reduce group. */
        private static int sum(Iterable<IntWritable> values) {
            int total = 0;
            for (IntWritable x : values) {
                total += x.get();
            }
            return total;
        }
    }

    /**
     * Emits each rule's count under three keys:
     * <ul>
     *   <li>the rule itself,</li>
     *   <li>its (target, LHS) marginal,</li>
     *   <li>its target marginal.</li>
     * </ul>
     */
    private static class Map
    extends Mapper<RuleWritable, Annotation, RuleWritable, IntWritable> {
        // Hadoop instantiates this class reflectively.
        private Map() {
        }

        @Override
        protected void map(RuleWritable key, Annotation value, Context context) throws IOException, InterruptedException {
            // Target-only marginal: source replaced by the marker, LHS cleared to 0.
            RuleWritable target_marginal = new RuleWritable(key);
            target_marginal.source = PrimitiveArrayMarginalComparator.MARGINAL;
            target_marginal.lhs = 0;
            target_marginal.monotone = false;

            // (target, LHS) marginal: source replaced by the marker, LHS kept.
            RuleWritable lhs_target_marginal = new RuleWritable(key);
            lhs_target_marginal.source = PrimitiveArrayMarginalComparator.MARGINAL;
            // The monotone flag must be normalized here too. Comparator breaks ties on the
            // leading byte, so otherwise rules that differ only in this flag would produce
            // separate marginal keys, and Reduce would overwrite prob with a partial count.
            lhs_target_marginal.monotone = false;

            IntWritable count = new IntWritable(value.count());
            context.write(key, count);
            context.write(lhs_target_marginal, count);
            context.write(target_marginal, count);
        }
    }
}

