/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sempre;

import edu.stanford.nlp.sempre.AbstractReinforcementParserState;
import edu.stanford.nlp.sempre.ChartParserState;
import edu.stanford.nlp.sempre.DerivInfo;
import edu.stanford.nlp.sempre.Derivation;
import edu.stanford.nlp.sempre.DerivationStream;
import edu.stanford.nlp.sempre.Example;
import edu.stanford.nlp.sempre.FeatureVector;
import edu.stanford.nlp.sempre.JoinFn;
import edu.stanford.nlp.sempre.Json;
import edu.stanford.nlp.sempre.ListParserAgenda;
import edu.stanford.nlp.sempre.Params;
import edu.stanford.nlp.sempre.Parser;
import edu.stanford.nlp.sempre.ParserAgenda;
import edu.stanford.nlp.sempre.ParserState;
import edu.stanford.nlp.sempre.PrioritizedDerivationStream;
import edu.stanford.nlp.sempre.QueueParserAgenda;
import edu.stanford.nlp.sempre.ReinforcementParser;
import edu.stanford.nlp.sempre.ReinforcementUtils;
import edu.stanford.nlp.sempre.SemanticFn;
import edu.stanford.nlp.sempre.SempreUtils;
import edu.stanford.nlp.sempre.SingleDerivationStream;
import fig.basic.Fmt;
import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.basic.Pair;
import fig.prob.SampleUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

final class ReinforcementParserState
extends AbstractReinforcementParserState {
    private static final double LOG_SMALL_PROB = Math.log(ReinforcementParser.opts.lowProb);
    private final ParserAgenda<PrioritizedDerivationStream> agenda;
    private int completeDerivationsPushed = 0;
    private int firstCorrectItem = -1;
    private String samplingStrategy;
    private Sampler sampler;
    List<Derivation> correctDerivations = new ArrayList<Derivation>();
    private Map<String, Double> stateSequenceExpectedCounts = new HashMap<String, Double>();
    Random randGen = new Random(1L);
    private Map<Long, Pair<ArrayList<Derivation>, Integer>> backpointerList;
    private int numItemsSampled = 0;

    /**
     * Creates a parser state for one example.
     *
     * @param parser the owning reinforcement parser
     * @param params model parameters used to score derivations
     * @param ex the example being parsed
     * @param computeExpectedCounts whether search/gradient counts are accumulated during inference
     * @param samplingStrategy one of "max", "agenda", "proposal", or null; null is treated as
     *     "agenda", the same way createSampler() treats it
     */
    private ReinforcementParserState(ReinforcementParser parser, Params params, Example ex, boolean computeExpectedCounts, String samplingStrategy) {
        super(parser, params, ex, computeExpectedCounts);
        // Normalize null up front: createSampler() already maps null to the "agenda" sampler, but the
        // .equals("max") calls below and in unrollHighProbStreams() would otherwise throw an NPE.
        this.samplingStrategy = samplingStrategy == null ? "agenda" : samplingStrategy;
        this.backpointerList = new HashMap<Long, Pair<ArrayList<Derivation>, Integer>>();
        // "max" uses a QueueParserAgenda; all other strategies use a ListParserAgenda, which the
        // samplers access by index.
        this.agenda = this.samplingStrategy.equals("max") ? new QueueParserAgenda() : new ListParserAgenda();
    }

    /**
     * Resets the per-parse search state: agenda, chart, push counter, first-correct marker,
     * sampled-item counter, correct derivations, accumulated search counts and backpointers.
     */
    private void clearState() {
        agenda.clear();
        clearChart();
        completeDerivationsPushed = 0;
        firstCorrectItem = -1;
        numItemsSampled = 0;
        correctDerivations.clear();
        stateSequenceExpectedCounts.clear();
        backpointerList.clear();
    }

    /**
     * Adds a derivation stream to the agenda with a probSum of 0.0.
     *
     * @param derivationStream stream to enqueue; exhausted streams are ignored
     */
    @Override
    protected void addToAgenda(DerivationStream derivationStream) {
        final double initialProbSum = 0.0;
        addToAgenda(derivationStream, initialProbSum);
    }

    /**
     * Scores and enqueues a derivation stream, tagging each agenda item with {@code probSum}.
     *
     * If alwaysUnroll is off, or the stream has at most one estimated element, the stream is
     * kept whole and enqueued under the score of its head. Otherwise every derivation is
     * drained into its own single-element stream. The agenda is re-sorted whenever the push
     * counter is a multiple of 100.
     *
     * @param derivationStream stream to enqueue; returns immediately if it is exhausted
     * @param probSum value stored on the resulting PrioritizedDerivationStream(s)
     */
    private void addToAgenda(DerivationStream derivationStream, double probSum) {
        if (!derivationStream.hasNext()) {
            return;
        }
        boolean unrollEagerly = ReinforcementParser.opts.alwaysUnroll && derivationStream.estimatedSize() > 1;
        if (!unrollEagerly) {
            Derivation head = derivationStream.peek();
            featurizeAndScoreDerivation(head);
            addToAgendaWithScore(derivationStream, head.score, probSum);
            if (completeDerivationsPushed % 100 == 0) {
                agenda.sort();
            }
            return;
        }
        while (derivationStream.hasNext()) {
            Derivation next = (Derivation) derivationStream.next();
            featurizeAndScoreDerivation(next);
            addToAgendaWithScore(SingleDerivationStream.constant(next), next.score, probSum);
            if (completeDerivationsPushed % 100 == 0) {
                agenda.sort();
            }
        }
    }

    /**
     * Extracts local features for {@code deriv} and sets its score, unless it is already
     * featurized and scored.
     *
     * The score is the dot product of the search-prefixed local feature vector with the
     * parameters, plus the score of every child (when children is non-null). Each call that
     * does work increments numOfFeaturizedDerivs.
     */
    @Override
    protected void featurizeAndScoreDerivation(Derivation deriv) {
        if (deriv.isFeaturizedAndScored()) {
            return;
        }
        parser.extractor.extractLocal(ex, deriv);
        FeatureVector searchFeatures = deriv.addPrefixLocalFeatureVector(parser.searchPrefix);
        double total = searchFeatures.dotProduct(params);
        if (deriv.children != null) {
            for (Derivation child : deriv.children) {
                total += child.score;
            }
        }
        deriv.score = total;
        if (parser.verbose(3)) {
            LogInfo.logs("featurizeAndScore(score=%s) %s %s: %s [rule: %s]",
                Fmt.D(deriv.score), deriv.cat, ex.spanString(deriv.start, deriv.end), deriv, deriv.rule);
        }
        ++numOfFeaturizedDerivs;
    }

    /**
     * Pushes {@code derivationStream} onto the agenda; streams scored at negative infinity are dropped.
     *
     * The priority is {@code derivScore - completeDerivationsPushed * 1e-19}, a tiny offset that
     * breaks ties between equal scores by push order. completeDerivationsPushed is incremented.
     * NOTE(review): when |derivScore| is large, this offset may be lost to double rounding.
     */
    private void addToAgendaWithScore(DerivationStream derivationStream, double derivScore, double probSum) {
        if (derivScore == Double.NEGATIVE_INFINITY) {
            return;
        }
        Derivation head = derivationStream.peek();
        double tieBreak = (double) completeDerivationsPushed * 1.0E-19;
        completeDerivationsPushed++;
        double priority = derivScore - tieBreak;
        agenda.add(new PrioritizedDerivationStream(derivationStream, priority, probSum), priority);
        if (parser.verbose(3)) {
            LogInfo.logs("ReinforcementParser: adding to agenda: size=%s, priority=%s, deriv=%s(%s,%s), formula=%s,|pushed|=%s",
                agenda.size(), priority, head.cat, head.start, head.end, head.formula, completeDerivationsPushed);
        }
    }

    /**
     * Decides whether the sampling loop should keep going.
     *
     * @return false if the agenda is empty (logged); otherwise true while the $ROOT cell covering
     *     the whole utterance is missing or holds fewer than getBeamSize() derivations
     */
    public boolean continueParsing() {
        if (agenda.size() == 0) {
            LogInfo.log((Object) "Agenda is empty");
            return false;
        }
        List<?> fullSpanRoots = (List<?>) chart[0][numTokens].get("$ROOT");
        if (fullSpanRoots == null) {
            return true;
        }
        return fullSpanRoots.size() < getBeamSize();
    }

    /**
     * Parses the example by sampling.
     *
     * When expected counts are requested and simulateNonRlObjective is off, a separate state
     * using the "agenda" strategy (with expected counts disabled) is run first to collect correct
     * derivations. If it finds none, this method returns early. It then builds the sampler,
     * optionally computes a coarse-pruned chart, runs the sampling loop, sets the predicted
     * derivations, and writes the chart-filling visualization if enabled.
     */
    @Override
    public void infer() {
        if (numTokens == 0) {
            return;
        }
        ReinforcementParserState oracleState = null;
        expectedCounts = new HashMap<>();
        boolean needsOracle = computeExpectedCounts && !ReinforcementParser.opts.simulateNonRlObjective;
        if (needsOracle) {
            LogInfo.begin_track("Finding oracle derivation");
            oracleState = new StateBuilder()
                .parser(parser)
                .params(params)
                .example(ex)
                .samplingStrategy("agenda")
                .computeExpectedCounts(false)
                .createState();
            oracleState.infer();
            LogInfo.end_track();
            if (oracleState.correctDerivations.isEmpty()) {
                LogInfo.logs("No oracle derivation found");
                return;
            }
        }
        createSampler(oracleState);

        LogInfo.begin_track("Coarse parsing");
        coarseParserState = null;
        if (ReinforcementParser.opts.efficientCoarsePrune) {
            coarseParserState = coarseParser.getCoarsePrunedChart(ex);
        }
        LogInfo.end_track();

        LogInfo.begin_track("ReinforcementParserState.inferBySampling");
        sampleHistoryAndInfer();
        LogInfo.end_track();

        setPredDerivations();
        if (parser.verbose(3)) {
            LogInfo.logs("Expected reward = %s", new Object[]{objectiveValue});
        }
        visualizeChart();
    }

    /**
     * Runs the sampling-based search loop for the current example.
     *
     * Seeds the agenda with token/phrase derivations and RHS-terminal derivation streams, then,
     * while continueParsing() holds, does the following:
     * <ol>
     *   <li>unrolls streams;</li>
     *   <li>draws one agenda item from the sampler and takes the next derivation of its stream;</li>
     *   <li>records backpointers and handles the derivation if it is a root;</li>
     *   <li>optionally accumulates per-step search counts;</li>
     *   <li>inserts the derivation into the bounded chart, combining it with existing chart
     *       derivations if accepted;</li>
     *   <li>pushes the rest of the stream back onto the agenda.</li>
     * </ol>
     * Afterwards it finalizes the search counts, re-scores and sorts the root derivations and,
     * when expected counts are requested, computes the gradient.
     */
    private void sampleHistoryAndInfer() {
        // Seed the agenda.
        for (Derivation deriv : this.gatherTokenAndPhraseDerivations()) {
            this.addToAgenda(SingleDerivationStream.constant(deriv));
        }
        for (DerivationStream derivStream : this.gatherRhsTerminalsDerivations()) {
            this.addToAgenda(derivStream);
        }
        this.ensureExecuted();
        while (this.continueParsing()) {
            this.unrollHighProbStreams();
            // The sampler removes the chosen item from the agenda and returns it with a probability.
            Pair<PrioritizedDerivationStream, Double> pdsAndProbability = this.sampler.sample();
            DerivationStream sampledDerivations = ((PrioritizedDerivationStream)pdsAndProbability.getFirst()).derivStream;
            Derivation sampledDerivation = (Derivation)sampledDerivations.next();
            this.updateBackpointers(sampledDerivations, sampledDerivation);
            ++this.numItemsSampled;
            // Sanity checks: the item must already be scored, and its score must match the agenda
            // priority up to the tiny tie-breaking offset applied in addToAgendaWithScore.
            assert (sampledDerivation.isFeaturizedAndScored()) : "top derivation is not featurized and scored: " + sampledDerivation;
            assert (Math.abs(sampledDerivation.score - ((PrioritizedDerivationStream)pdsAndProbability.getFirst()).priority) < 1.0E-4) : sampledDerivation.score + " != " + ((PrioritizedDerivationStream)pdsAndProbability.getFirst()).priority;
            if (this.parser.verbose(2)) {
                // "+ 1" because the sampled item has already been removed from the agenda.
                LogInfo.begin_track((String)"Item %d (|agenda|=%d), priority %s: |item|=%s -> %s %s %s [%s], prob=%s", (Object[])new Object[]{this.numItemsSampled, this.agenda.size() + 1, Fmt.D((double)((PrioritizedDerivationStream)pdsAndProbability.getFirst()).priority), sampledDerivations.estimatedSize(), sampledDerivation.cat, this.ex.spanString(sampledDerivation.start, sampledDerivation.end), sampledDerivation, sampledDerivation.rule, pdsAndProbability.getSecond()});
            }
            this.handleRootDerivation(this.ex, this.numItemsSampled, sampledDerivation);
            if (this.computeExpectedCounts) {
                HashMap<String, Double> counts = new HashMap<String, Double>();
                // A negative returned probability (e.g. the -1.0 that MultiplicativeProposalSampler.sample
                // returns when returnProb is false) means this move only contributes -probSum;
                // otherwise it contributes (1 - probSum).
                if ((Double)pdsAndProbability.getSecond() > -1.0E-4) {
                    sampledDerivation.incrementLocalFeatureVector(1.0 - ((PrioritizedDerivationStream)pdsAndProbability.getFirst()).probSum, counts);
                } else {
                    sampledDerivation.incrementLocalFeatureVector(-((PrioritizedDerivationStream)pdsAndProbability.getFirst()).probSum, counts);
                }
                if (this.parser.verbose(3)) {
                    SempreUtils.logMap(counts, "agenda item gradient");
                }
                ReinforcementUtils.addToDoubleMap(this.stateSequenceExpectedCounts, counts, this.parser.searchPrefix);
            }
            // Only derivations accepted by the bounded chart are combined with existing chart entries.
            if (this.addToBoundedChart(sampledDerivation)) {
                if (this.parser.verbose(5)) {
                    LogInfo.logs((String)"ReinforcementParserState.infer: adding to chart %s(%s,%s) formula=%s", (Object[])new Object[]{sampledDerivation.cat, sampledDerivation.start, sampledDerivation.end, sampledDerivation.formula});
                }
                this.combineWithChartDerivations(sampledDerivation);
            }
            // Re-enqueue the rest of the stream (addToAgenda ignores exhausted streams). It goes through
            // the single-argument overload, so its probSum becomes 0.0.
            // NOTE(review): unrollHighProbStreams keeps the original probSum when re-adding streams;
            // confirm the reset here is intended.
            this.addToAgenda(sampledDerivations);
            // Close the per-item track opened above when verbose(2).
            if (!this.parser.verbose(2)) continue;
            LogInfo.end_track();
        }
        this.finalizeSearchExpectedCounts();
        this.rerankRootDerivations();
        if (this.computeExpectedCounts) {
            this.computeGradient();
        }
    }

    /**
     * Splits agenda streams whose hidden (not yet popped) derivations could carry
     * non-negligible probability.
     *
     * Does nothing for the "max" strategy. Otherwise, it first lets the sampler do its own
     * unrolling. It then accumulates {@code lb} over the agenda items' getScore() with
     * NumUtils.logAdd, and counts streams with more than one estimated element ("hidden"
     * streams). For each agenda stream, while illegalStream(...) holds for its head, the head is
     * popped into its own SingleDerivationStream and backpointers are recorded. Streams that
     * were modified are removed from the agenda. The popped items and the remaining streams are
     * then re-added with the original stream's probSum.
     */
    private void unrollHighProbStreams() {
        int i;
        if (this.samplingStrategy.equals("max")) {
            return;
        }
        this.sampler.unroll();
        if (this.parser.verbose(3)) {
            LogInfo.begin_track((String)"Unrolling high probability streams", (Object[])new Object[0]);
        }
        // lb: log-space sum of agenda scores; numOfHiddenStreams: streams with more than one estimated item.
        double lb = Double.NEGATIVE_INFINITY;
        int numOfHiddenStreams = 0;
        for (PrioritizedDerivationStream pds : this.agenda) {
            lb = NumUtils.logAdd((double)lb, (double)pds.getScore());
            if (pds.derivStream.estimatedSize() <= 1) continue;
            ++numOfHiddenStreams;
        }
        if (this.parser.verbose(3)) {
            LogInfo.logs((String)"unrollHighProbStreams(): |agenda|=%s, lb=%s, |hiddenstreams|=%s", (Object[])new Object[]{this.agenda.size(), lb, numOfHiddenStreams});
        }
        // Agenda changes are deferred until after the scan so that indices stay valid.
        ArrayList<Pair> derivsToAdd = new ArrayList<Pair>();
        ArrayList<Integer> indicesToRemove = new ArrayList<Integer>();
        for (i = 0; i < this.agenda.size(); ++i) {
            PrioritizedDerivationStream pds = this.agenda.get(i);
            boolean modified = false;
            while (pds.derivStream.hasNext() && pds.derivStream.estimatedSize() > 1 && this.illegalStream(pds.derivStream, lb, pds.derivStream.estimatedSize(), numOfHiddenStreams)) {
                modified = true;
                Derivation nextDeriv = (Derivation)pds.derivStream.next();
                this.updateBackpointers(pds.derivStream, nextDeriv);
                SingleDerivationStream derivStream = SingleDerivationStream.constant(nextDeriv);
                if (this.parser.verbose(3) && derivStream.hasNext()) {
                    Derivation deriv = derivStream.peek();
                    LogInfo.logs((String)"unrollIllegalStreams(): add deriv=%s(%s,%s) [%s] score=%s, |stream|=%s", (Object[])new Object[]{deriv.cat, deriv.start, deriv.end, deriv.formula, deriv.score, pds.derivStream.estimatedSize()});
                }
                derivsToAdd.add(Pair.newPair((Object)derivStream, (Object)pds.probSum));
                if (pds.derivStream.hasNext()) {
                    // Score the new head so illegalStream() can evaluate it on the next iteration.
                    this.featurizeAndScoreDerivation(pds.derivStream.peek());
                    // NOTE(review): this adds pds.getScore() to lb a second time. Whether getScore()
                    // reflects the new head or the stored agenda priority depends on
                    // PrioritizedDerivationStream; confirm the intended bound.
                    lb = NumUtils.logAdd((double)lb, (double)pds.getScore());
                }
                // A stream reduced to at most one item no longer counts as hidden.
                if (pds.derivStream.estimatedSize() > 1) continue;
                --numOfHiddenStreams;
            }
            if (!modified) continue;
            indicesToRemove.add(i);
            // Re-add the (possibly non-empty) remainder of the modified stream.
            derivsToAdd.add(Pair.newPair((Object)pds.derivStream, (Object)pds.probSum));
        }
        // Remove from the highest index down so the earlier indices stay valid.
        for (i = indicesToRemove.size() - 1; i >= 0; --i) {
            this.agenda.remove(this.agenda.get((Integer)indicesToRemove.get(i)), (Integer)indicesToRemove.get(i));
        }
        for (Pair pair : derivsToAdd) {
            this.addToAgenda((DerivationStream)pair.getFirst(), (Double)pair.getSecond());
        }
        if (this.parser.verbose(3)) {
            LogInfo.logs((String)"unrollHighProbStreams(): |agenda|=%s", (Object[])new Object[]{this.agenda.size()});
        }
        if (this.parser.verbose(3)) {
            LogInfo.end_track();
        }
    }

    /**
     * Tests whether a stream's head, relative to {@code logSum}, plus
     * log(estimatedSize) + log(numOfHiddenStreams), exceeds LOG_SMALL_PROB.
     *
     * @param derivStream stream whose head (peek) is examined
     * @param logSum log-space normalizer for the head's score
     * @param estimatedSize estimated number of items in the stream
     * @param numOfHiddenStreams number of multi-item streams on the agenda
     * @return true if the stream should be unrolled
     */
    private boolean illegalStream(DerivationStream derivStream, double logSum, int estimatedSize, int numOfHiddenStreams) {
        Derivation head = derivStream.peek();
        double headLogProb = head.score - logSum;
        double logMultiplicity = Math.log(estimatedSize) + Math.log(numOfHiddenStreams);
        double boundedLogProb = headLogProb + logMultiplicity;
        if (parser.verbose(3)) {
            LogInfo.logs("IllegalStream(): score=%s, logsum=%s, |stream|=%s, |hiddenstreams|=%s, deriv=%s(%s,%s) %s, sum=%s",
                head.score, logSum, estimatedSize, numOfHiddenStreams, head.cat, head.start, head.end, head.formula, boundedLogProb);
        }
        return boundedLogProb > LOG_SMALL_PROB;
    }

    /**
     * Tests whether a stream's head is within LOG_SMALL_PROB - log(estimatedSize) of
     * {@code maxScore}.
     *
     * @return true if (head score - maxScore) exceeds that threshold
     */
    private boolean isHighProbStream(DerivationStream derivStream, double maxScore, int estimatedSize) {
        Derivation head = derivStream.peek();
        double threshold = LOG_SMALL_PROB - Math.log(estimatedSize);
        double gapFromMax = head.score - maxScore;
        if (parser.verbose(3)) {
            LogInfo.logs("isHighProbStream(): gapFromMax=%s, threshold=%s, deriv=%s(%s,%s) %s |stream|=%s",
                gapFromMax, threshold, head.cat, head.start, head.end, head.formula, derivStream.estimatedSize());
        }
        return gapFromMax > threshold;
    }

    /**
     * Refreshes predDerivations, recomputes each root's score with the current params, and
     * sorts them by score.
     */
    private void rerankRootDerivations() {
        setPredDerivations();
        for (Derivation root : predDerivations) {
            double scoreBefore = root.score;
            root.computeScore(params);
            if (parser.verbose(3)) {
                LogInfo.logs("ReinforcementParser.rerankRootDerivations: deriv=%s, old=%s, new=%s",
                    root, scoreBefore, root.score);
            }
        }
        Derivation.sortByScore(predDerivations);
    }

    /**
     * Links {@code sampledDeriv} to the next derivation of the same stream.
     *
     * Each chain is a list of derivations popped consecutively from one stream. The chain is
     * created starting with {@code sampledDeriv} if it has none yet. The stream's next head is
     * appended, and backpointerList maps that head's creationIndex to (chain, its position).
     * Does nothing if the stream is exhausted.
     */
    private void updateBackpointers(DerivationStream stream, Derivation sampledDeriv) {
        if (!stream.hasNext()) {
            return;
        }
        Pair<ArrayList<Derivation>, Integer> entry = backpointerList.get(sampledDeriv.creationIndex);
        if (entry == null) {
            ArrayList<Derivation> freshChain = new ArrayList<Derivation>();
            freshChain.add(sampledDeriv);
            entry = Pair.newPair(freshChain, 0);
            backpointerList.put(sampledDeriv.creationIndex, entry);
        }
        ArrayList<Derivation> chain = entry.getFirst();
        Derivation successor = stream.peek();
        chain.add(successor);
        backpointerList.put(successor.creationIndex, Pair.newPair(chain, chain.size() - 1));
    }

    /**
     * Returns sum_i probs[i] * compatibilityToReward(predDerivations[i].compatibility).
     *
     * @param predDerivations root derivations, aligned with {@code probs}
     * @param probs probability of each derivation
     */
    private double computeExpectedReward(List<Derivation> predDerivations, double[] probs) {
        double expected = 0.0;
        int idx = 0;
        for (Derivation deriv : predDerivations) {
            expected += probs[idx++] * ReinforcementParserState.compatibilityToReward(deriv.compatibility);
        }
        return expected;
    }

    /**
     * Adds this example's gradient contribution to expectedCounts and logs its L2 norm.
     *
     * Let q be the sampler's distribution over predDerivations and pi the model's
     * exp-normalized distribution. Let R = compatibilityToReward(compatibility) and phi be the
     * full feature vector (incrementAllFeatureVector). With simulateNonRlObjective, the counts
     * come from ParserState.computeExpectedCounts. Otherwise they are
     * stateSequenceExpectedCounts * E_q[R] + E_q[R * phi] - E_q[R] * E_pi[phi].
     * Does nothing when there are no predicted derivations.
     */
    private void computeGradient() {
        if (predDerivations.isEmpty()) {
            return;
        }
        double[] qDist = sampler.getDerivDistribution(predDerivations);
        double[] piDist = ReinforcementUtils.expNormalize(predDerivations);
        LogInfo.begin_track("Computing gradient");
        double expectedReward = computeExpectedReward(predDerivations, qDist);

        HashMap<String, Double> modelFeatureMean = new HashMap<String, Double>();
        HashMap<String, Double> rewardWeightedFeatures = new HashMap<String, Double>();
        int idx = 0;
        for (Derivation root : predDerivations) {
            root.incrementAllFeatureVector(piDist[idx], modelFeatureMean);
            double reward = ReinforcementParserState.compatibilityToReward(root.compatibility);
            root.incrementAllFeatureVector(qDist[idx] * reward, rewardWeightedFeatures);
            idx++;
        }

        Map<String, Double> sampleCounts;
        if (ReinforcementParser.opts.simulateNonRlObjective) {
            sampleCounts = new HashMap<String, Double>();
            ParserState.computeExpectedCounts(predDerivations, sampleCounts);
        } else {
            sampleCounts = ReinforcementUtils.multiplyDoubleMap(stateSequenceExpectedCounts, expectedReward);
            SempreUtils.addToDoubleMap(sampleCounts, rewardWeightedFeatures);
            ReinforcementUtils.subtractFromDoubleMap(sampleCounts, ReinforcementUtils.multiplyDoubleMap(modelFeatureMean, expectedReward));
        }
        SempreUtils.addToDoubleMap(expectedCounts, sampleCounts);

        double squaredNorm = 0.0;
        for (Map.Entry<String, Double> entry : sampleCounts.entrySet()) {
            double value = entry.getValue();
            if (parser.verbose(3)) {
                LogInfo.logs("feature=%s, value=%s", entry.getKey(), value);
            }
            squaredNorm += value * value;
        }
        LogInfo.logs("L2 norm: %s", Math.sqrt(squaredNorm));
        LogInfo.end_track();
    }

    /**
     * Instantiates {@link #sampler} according to {@link #samplingStrategy}.
     *
     * <ul>
     *   <li>"proposal": MultiplicativeProposalSampler, which requires a non-null oracle state;</li>
     *   <li>"max": MaxSampler;</li>
     *   <li>"agenda" or null: AgendaSampler.</li>
     * </ul>
     *
     * @param oracleState oracle parse used by the proposal sampler; may be null otherwise
     * @throws RuntimeException if "proposal" is requested without an oracle state
     * @throws IllegalArgumentException if the strategy is not recognized. Previously this left
     *     the sampler null, which later failed with an NPE in sample()/unroll().
     */
    private void createSampler(ReinforcementParserState oracleState) {
        if ("proposal".equals(this.samplingStrategy)) {
            if (oracleState == null) {
                throw new RuntimeException("missing oracle state");
            }
            this.sampler = new MultiplicativeProposalSampler(oracleState);
        } else if ("max".equals(this.samplingStrategy)) {
            this.sampler = new MaxSampler();
        } else if (this.samplingStrategy == null || "agenda".equals(this.samplingStrategy)) {
            this.sampler = new AgendaSampler();
        } else {
            throw new IllegalArgumentException("Unknown sampling strategy: " + this.samplingStrategy);
        }
    }

    /**
     * Writes this example's chart-filling data as one JSON line to parser.chartFillOut, when
     * that stream exists and Parser.opts.visualizeChartFilling is on.
     */
    private void visualizeChart() {
        if (parser.chartFillOut == null || !Parser.opts.visualizeChartFilling) {
            return;
        }
        ChartParserState.ChartFillingData fillingData =
            new ChartParserState.ChartFillingData(ex.id, chartFillingList, ex.utterance, ex.numTokens());
        parser.chartFillOut.println(Json.writeValueAsStringHard(fillingData));
        parser.chartFillOut.flush();
    }

    /**
     * Subtracts, from stateSequenceExpectedCounts, the local features of every item still on
     * the agenda, each weighted by its probSum.
     *
     * Skipped when simulateNonRlObjective is on or expected counts are not requested.
     */
    private void finalizeSearchExpectedCounts() {
        if (ReinforcementParser.opts.simulateNonRlObjective || !computeExpectedCounts) {
            return;
        }
        HashMap<String, Double> remainingMass = new HashMap<String, Double>();
        for (PrioritizedDerivationStream item : agenda) {
            item.derivStream.peek().incrementLocalFeatureVector(item.probSum, remainingMass);
        }
        boolean chatty = parser.verbose(3);
        if (chatty) {
            SempreUtils.logMap(remainingMass, "subtracted");
        }
        ReinforcementUtils.subtractFromDoubleMap(stateSequenceExpectedCounts, remainingMass, parser.searchPrefix);
        if (parser.verbose(3)) {
            SempreUtils.logMap(stateSequenceExpectedCounts, "final search gradient");
        }
    }

    /**
     * If {@code sampledDerivation} is a root derivation, executes it, sets its compatibility
     * against the target value (when present), and records it when it counts as correct.
     *
     * A derivation counts as correct when its compatibility is > 0 (partialReward) or == 1.0.
     * correctDerivations keeps its highest-compatibility element at index 0. firstCorrectItem is
     * set to {@code numItemsSampled} the first time a correct derivation is seen.
     */
    private void handleRootDerivation(Example ex, int numItemsSampled, Derivation sampledDerivation) {
        if (!sampledDerivation.isRoot(ex.numTokens())) {
            return;
        }
        sampledDerivation.ensureExecuted(parser.executor, ex.context);
        if (ex.targetValue != null) {
            sampledDerivation.compatibility = parser.valueEvaluator.getCompatibility(ex.targetValue, sampledDerivation.value);
        }
        double reward = sampledDerivation.compatibility;
        boolean rewarded;
        if (Parser.opts.partialReward) {
            rewarded = reward > 0.0;
        } else {
            rewarded = reward == 1.0;
        }
        if (!rewarded) {
            return;
        }
        if (parser.verbose(2)) {
            LogInfo.logs("Top-level %s: reward = %s", numItemsSampled, reward);
        }
        correctDerivations.add(sampledDerivation);
        // The new derivation was appended; move it to the front if it beats the current leader.
        int lastIndex = correctDerivations.size() - 1;
        if (correctDerivations.get(0).compatibility < sampledDerivation.compatibility) {
            Collections.swap(correctDerivations, 0, lastIndex);
        }
        if (firstCorrectItem == -1) {
            firstCorrectItem = numItemsSampled;
        }
    }

    /**
     * Extends the base evaluation with coarseParseTime (when a coarse chart was built) and
     * firstCorrectItem (when a correct derivation was sampled).
     */
    @Override
    public void setEvaluation() {
        LogInfo.begin_track_printAll("ReinforcementParserParserState.setEvaluation");
        super.setEvaluation();
        if (coarseParserState != null) {
            double coarseTime = coarseParserState.getCoarseParseTime();
            evaluation.add("coarseParseTime", coarseTime);
        }
        boolean foundCorrect = firstCorrectItem != -1;
        if (foundCorrect) {
            evaluation.add("firstCorrectItem", (double) firstCorrectItem);
        }
        LogInfo.end_track();
    }

    /**
     * Orders correct derivations best-first. The keys, in order, are:
     * <ol>
     *   <li>higher compatibility;</li>
     *   <li>derivations whose tree contains a JoinFn rule;</li>
     *   <li>higher score;</li>
     *   <li>lower creationIndex.</li>
     * </ol>
     */
    public static class CorrectDerivationComparator
    implements Comparator<Derivation> {
        @Override
        public int compare(Derivation deriv1, Derivation deriv2) {
            if (deriv1.compatibility > deriv2.compatibility) {
                return -1;
            }
            if (deriv1.compatibility < deriv2.compatibility) {
                return 1;
            }
            boolean deriv1Join = this.containsJoin(deriv1);
            boolean deriv2Join = this.containsJoin(deriv2);
            if (deriv1Join != deriv2Join) {
                return deriv1Join ? -1 : 1;
            }
            if (deriv1.score > deriv2.score) {
                return -1;
            }
            if (deriv1.score < deriv2.score) {
                return 1;
            }
            return Long.compare(deriv1.creationIndex, deriv2.creationIndex);
        }

        /**
         * Returns true if {@code d} or any descendant was built by a rule whose semantic
         * function is a JoinFn.
         */
        private boolean containsJoin(Derivation d) {
            SemanticFn semanticFn = d.rule.getSem();
            if (semanticFn instanceof JoinFn) {
                return true;
            }
            // children may be null (featurizeAndScoreDerivation guards for it); such leaves contain no join.
            if (d.children == null) {
                return false;
            }
            for (Derivation child : d.children) {
                if (this.containsJoin(child)) {
                    return true;
                }
            }
            return false;
        }
    }

    /**
     * Proposal sampler that adds a fixed bonus (in log space) to the score of agenda items whose
     * (cat, span, formula, rule) signature occurs in the oracle derivation. Its unroll() also
     * pops oracle-necessary derivations out of multi-item streams.
     */
    class MultiplicativeProposalSampler
    extends Sampler {
        /** Log-space bonus for oracle-matching items (ReinforcementParser.opts.multiplicativeBonus). */
        private double bonus;
        /** Oracle derivation signatures collected from the oracle parser state. */
        private OracleInfo oracleInfo;

        public MultiplicativeProposalSampler(ReinforcementParserState oracleState) {
            this.oracleInfo = new OracleInfo(oracleState);
            this.bonus = ReinforcementParser.opts.multiplicativeBonus;
            LogInfo.logs((String)"Bonus=%s", (Object[])new Object[]{this.bonus});
        }

        /** Returns true if {@code deriv}'s signature is one of the oracle's derivation infos. */
        private boolean matchesOracle(Derivation deriv) {
            return this.oracleInfo.oracleDerivInfos.contains(new DerivInfo(deriv.cat, deriv.start, deriv.end, deriv.formula, deriv.rule));
        }

        /**
         * Samples one agenda item from the bonus-adjusted distribution and removes it from the agenda.
         *
         * @return the item paired with its proposal probability. When expected counts are computed
         *     with updateGradientForCorrectMovesOnly and the item does not match the oracle, the
         *     probability is -1.0.
         * @throws RuntimeException if the proposal scores cannot be normalized
         */
        @Override
        public Pair<PrioritizedDerivationStream, Double> sample() {
            double[] modelProbs = ReinforcementUtils.expNormalize(ReinforcementParserState.this.agenda);
            double[] samplerProbs = this.getUnnormalizedAgendaDistribution();
            if (!NumUtils.expNormalize((double[])samplerProbs)) {
                throw new RuntimeException("Normalization failed" + Arrays.toString(samplerProbs));
            }
            int sampledIndex = ReinforcementUtils.sampleIndex(ReinforcementParserState.this.randGen, samplerProbs);
            PrioritizedDerivationStream pds = (PrioritizedDerivationStream)ReinforcementParserState.this.agenda.get(sampledIndex);
            double prob = samplerProbs[sampledIndex];
            if (ReinforcementParserState.this.parser.verbose(3)) {
                if (this.matchesOracle(pds.derivStream.peek())) {
                    LogInfo.logs((String)"MultiplicativeProposalSampler.sample(): Sampled from correct!, prob=%s", (Object[])new Object[]{prob});
                } else {
                    LogInfo.logs((String)"MultiplicativeProposalSampler.sample(): Sampled from incorrect!, prob=%s", (Object[])new Object[]{prob});
                }
            }
            boolean returnProb = true;
            if (ReinforcementParserState.this.computeExpectedCounts && ReinforcementParser.opts.updateGradientForCorrectMovesOnly) {
                if (this.matchesOracle(pds.derivStream.peek())) {
                    this.updateProbSum(modelProbs);
                } else {
                    returnProb = false;
                }
                if (ReinforcementParserState.this.parser.verbose(3)) {
                    LogInfo.logs((String)"Updating gradient=%s", (Object[])new Object[]{returnProb});
                }
            } else if (ReinforcementParserState.this.computeExpectedCounts) {
                this.updateProbSum(modelProbs);
            }
            ReinforcementParserState.this.agenda.remove(pds, sampledIndex);
            return Pair.newPair((Object)pds, (Object)(returnProb ? prob : -1.0));
        }

        /**
         * For each agenda stream, pops every leading derivation that the oracle marks as
         * necessary into its own single-item stream. Modified streams are then re-added with
         * their original probSum.
         */
        @Override
        public void unroll() {
            int i;
            if (ReinforcementParserState.this.parser.verbose(3)) {
                LogInfo.begin_track((String)"MultiplicativeBonusSampler.unroll()", (Object[])new Object[0]);
            }
            ArrayList<Pair> derivsToAdd = new ArrayList<Pair>();
            ArrayList<Integer> indicesToRemove = new ArrayList<Integer>();
            for (i = 0; i < ReinforcementParserState.this.agenda.size(); ++i) {
                PrioritizedDerivationStream pds = (PrioritizedDerivationStream)ReinforcementParserState.this.agenda.get(i);
                boolean modified = false;
                while (pds.derivStream.hasNext() && pds.derivStream.estimatedSize() > 1 && this.oracleInfo.isNecessaryDeriv(pds.derivStream.peek())) {
                    modified = true;
                    Derivation nextDeriv = (Derivation)pds.derivStream.next();
                    SingleDerivationStream newDerivStream = SingleDerivationStream.constant(nextDeriv);
                    if (ReinforcementParserState.this.parser.verbose(3) && newDerivStream.hasNext()) {
                        Derivation deriv = newDerivStream.peek();
                        LogInfo.logs((String)"MultiplicativeSampler.unroll(): add necessary deriv=%s(%s,%s) [%s] score=%s, |stream|=%s, creationIndex=%s", (Object[])new Object[]{deriv.cat, deriv.start, deriv.end, deriv.formula, deriv.score, pds.derivStream.estimatedSize(), deriv.creationIndex});
                    }
                    derivsToAdd.add(Pair.newPair((Object)newDerivStream, (Object)pds.probSum));
                    if (!pds.derivStream.hasNext()) continue;
                    // Score the new head so it can be compared/prioritized when re-added.
                    ReinforcementParserState.this.featurizeAndScoreDerivation(pds.derivStream.peek());
                }
                if (!modified) continue;
                indicesToRemove.add(i);
                derivsToAdd.add(Pair.newPair((Object)pds.derivStream, (Object)pds.probSum));
            }
            // Remove from the highest index down so earlier indices remain valid.
            for (i = indicesToRemove.size() - 1; i >= 0; --i) {
                ReinforcementParserState.this.agenda.remove((PrioritizedDerivationStream)ReinforcementParserState.this.agenda.get((Integer)indicesToRemove.get(i)), (Integer)indicesToRemove.get(i));
            }
            for (Pair pair : derivsToAdd) {
                ReinforcementParserState.this.addToAgenda((DerivationStream)pair.getFirst(), (Double)pair.getSecond());
            }
            if (ReinforcementParserState.this.parser.verbose(3)) {
                LogInfo.end_track();
            }
        }

        /** Returns each agenda head's score, plus {@link #bonus} when it matches the oracle. */
        private double[] getUnnormalizedAgendaDistribution() {
            double[] scores = new double[ReinforcementParserState.this.agenda.size()];
            for (int i = 0; i < scores.length; ++i) {
                Derivation d = ((PrioritizedDerivationStream)ReinforcementParserState.this.agenda.get(i)).derivStream.peek();
                scores[i] = this.matchesOracle(d) ? d.score + this.bonus : d.score;
            }
            return scores;
        }

        /**
         * Returns the sampler's normalized distribution over root derivations. Each root's
         * log-weight is score + bonus * compatibility.
         *
         * @throws RuntimeException if a non-empty weight vector cannot be normalized, matching
         *     sample(). Previously the failure was ignored and unnormalized values were returned.
         */
        @Override
        public double[] getDerivDistribution(List<Derivation> rootDerivs) {
            double[] res = new double[rootDerivs.size()];
            for (int i = 0; i < rootDerivs.size(); ++i) {
                Derivation rootDeriv = rootDerivs.get(i);
                res[i] = rootDeriv.score + this.bonus * rootDeriv.compatibility;
            }
            if (res.length > 0 && !NumUtils.expNormalize((double[])res)) {
                throw new RuntimeException("Normalization failed" + Arrays.toString(res));
            }
            return res;
        }
    }

    class OracleInfo {
        private List<DerivInfo> necessaryDerivInfos;
        private List<DerivInfo> oracleDerivInfos;
        private Map<Long, Pair<ArrayList<Derivation>, Integer>> backPointers;
        private NecessaryDeriv[] necessaryDerivsCache;
        long firstCorrectDerivNumber = -1L;

        /**
         * Builds oracle information from a finished oracle parse.
         *
         * Sorts the oracle's correct derivations with CorrectDerivationComparator (in place) and
         * takes the first as the oracle derivation. It then records that derivation's
         * creationIndex, adopts the oracle's backpointers, and populates the necessary and
         * oracle derivation infos. If there are no correct derivations, both info lists stay empty.
         *
         * @throws RuntimeException if {@code oracleState} is null
         */
        public OracleInfo(ReinforcementParserState oracleState) {
            if (oracleState == null) {
                throw new RuntimeException("oracle state is null");
            }
            this.necessaryDerivInfos = new ArrayList<DerivInfo>();
            this.oracleDerivInfos = new ArrayList<DerivInfo>();
            List<Derivation> correct = oracleState.correctDerivations;
            if (correct.isEmpty()) {
                return;
            }
            Collections.sort(correct, new CorrectDerivationComparator());
            this.backPointers = oracleState.backpointerList;
            Derivation best = correct.get(0);
            this.firstCorrectDerivNumber = best.creationIndex;
            LogInfo.logs("OracleSampler: deriv=%s, comp=%s", best, best.compatibility);
            this.populateCorrectDerivations(best);
            if (!ReinforcementParserState.this.parser.verbose(2)) {
                return;
            }
            LogInfo.begin_track("OracleSampler: necessary infos:");
            for (DerivInfo info : this.necessaryDerivInfos) {
                LogInfo.log((Object) info);
            }
            LogInfo.end_track();
            LogInfo.begin_track("OracleSampler: oracle infos:");
            for (DerivInfo info : this.oracleDerivInfos) {
                LogInfo.log((Object) info);
            }
            LogInfo.end_track();
        }

        private void populateCorrectDerivations(Derivation oracleDeriv) {
            DerivInfo derivInfo;
            Pair<ArrayList<Derivation>, Integer> listAndIndex;
            if (ReinforcementParserState.this.parser.verbose(4)) {
                LogInfo.logs((String)"populateCorrectDerivations(): oracle deriv: %s", (Object[])new Object[]{oracleDeriv});
            }
            if ((listAndIndex = this.backPointers.get(oracleDeriv.creationIndex)) != null) {
                for (int i = (Integer)listAndIndex.getSecond() - 1; i >= 0; --i) {
                    DerivInfo derivInfo2;
                    Derivation deriv = (Derivation)((ArrayList)listAndIndex.getFirst()).get(i);
                    if (ReinforcementParserState.this.parser.verbose(4)) {
                        LogInfo.logs((String)"populateCorrectDerivations(): necessary deriv: %s", (Object[])new Object[]{deriv});
                    }
                    if (this.necessaryDerivInfos.contains(derivInfo2 = new DerivInfo(deriv.cat, deriv.start, deriv.end, deriv.formula, deriv.rule))) continue;
                    this.necessaryDerivInfos.add(derivInfo2);
                }
            }
            if (!this.oracleDerivInfos.contains(derivInfo = new DerivInfo(oracleDeriv.cat, oracleDeriv.start, oracleDeriv.end, oracleDeriv.formula, oracleDeriv.rule))) {
                this.necessaryDerivInfos.add(derivInfo);
                this.oracleDerivInfos.add(derivInfo);
            }
            for (Derivation child : oracleDeriv.children) {
                this.populateCorrectDerivations(child);
            }
        }

        protected boolean isNecessaryDeriv(Derivation deriv) {
            int index;
            if (this.necessaryDerivInfos.isEmpty()) {
                return false;
            }
            if (this.necessaryDerivsCache == null) {
                this.necessaryDerivsCache = new NecessaryDeriv[200000];
                Arrays.fill((Object[])this.necessaryDerivsCache, (Object)NecessaryDeriv.UNKNOWN);
            }
            if ((index = (int)(deriv.creationIndex - this.firstCorrectDerivNumber)) < 0) {
                throw new RuntimeException("Negative index - correct index larger than deriv number");
            }
            if (index >= 200000) {
                LogInfo.warnings((String)"isNecessaryDeriv(): index larger than 200000: %s", (Object[])new Object[]{index});
                return this.necessaryDerivInfos.contains(new DerivInfo(deriv.cat, deriv.start, deriv.end, deriv.formula, deriv.rule));
            }
            if (this.necessaryDerivsCache[index] == NecessaryDeriv.UNNECESSARY_DERIV) {
                return false;
            }
            if (this.necessaryDerivsCache[index] == NecessaryDeriv.NECESSARY_DERIV) {
                return true;
            }
            boolean res = this.necessaryDerivInfos.contains(new DerivInfo(deriv.cat, deriv.start, deriv.end, deriv.formula, deriv.rule));
            this.necessaryDerivsCache[index] = res ? NecessaryDeriv.NECESSARY_DERIV : NecessaryDeriv.UNNECESSARY_DERIV;
            return res;
        }
    }

    /**
     * Deterministic sampler: always takes whatever the agenda's {@code pop()} yields and puts all
     * probability mass on the first root derivation.
     */
    class MaxSampler
    extends Sampler {
        MaxSampler() {
        }

        /**
         * Pops the next stream off the agenda. Which stream that is depends on the agenda
         * implementation (presumably the highest-priority one). Because the choice is
         * deterministic, its probability is reported as 1.0.
         */
        @Override
        public Pair<PrioritizedDerivationStream, Double> sample() {
            PrioritizedDerivationStream pds = (PrioritizedDerivationStream) ReinforcementParserState.this.agenda.pop();
            return Pair.newPair(pds, 1.0);
        }

        /**
         * Returns a point mass on the first root derivation. An empty input yields an empty array
         * instead of throwing {@link ArrayIndexOutOfBoundsException}.
         */
        @Override
        public double[] getDerivDistribution(List<Derivation> rootDerivs) {
            // Java zero-initialises new arrays, so only the first slot needs setting.
            double[] res = new double[rootDerivs.size()];
            if (res.length > 0) {
                res[0] = 1.0;
            }
            return res;
        }

        /** No state to undo for this sampler. */
        @Override
        public void unroll() {
        }
    }

    /**
     * Stochastic sampler that draws an agenda entry in proportion to the probabilities that
     * {@code ReinforcementUtils.expNormalize} computes over the current agenda.
     */
    class AgendaSampler
    extends Sampler {
        AgendaSampler() {
        }

        /**
         * Draws one stream from the agenda, removes it, and returns it with the probability it
         * was drawn with. When expected counts are being computed, every agenda entry's
         * probability is first accumulated via {@link #updateProbSum(double[])}.
         */
        @Override
        public Pair<PrioritizedDerivationStream, Double> sample() {
            final double[] probs = ReinforcementUtils.expNormalize(ReinforcementParserState.this.agenda);
            if (ReinforcementParserState.this.computeExpectedCounts) {
                updateProbSum(probs);
            }
            final int chosen = SampleUtils.sampleMultinomial(ReinforcementParserState.this.randGen, probs);
            final double chosenProb = probs[chosen];
            final PrioritizedDerivationStream picked =
                    (PrioritizedDerivationStream) ReinforcementParserState.this.agenda.get(chosen);
            ReinforcementParserState.this.agenda.remove(picked, chosen);
            return Pair.newPair(picked, chosenProb);
        }

        /** Delegates to {@code ReinforcementUtils.expNormalize} over the root derivations. */
        @Override
        public double[] getDerivDistribution(List<Derivation> rootDerivs) {
            return ReinforcementUtils.expNormalize(rootDerivs);
        }

        /** No state to undo for this sampler. */
        @Override
        public void unroll() {
        }
    }

    /** Strategy for picking the next agenda item and for weighting finished root derivations. */
    abstract class Sampler {
        Sampler() {
        }

        /**
         * Chooses the next stream to process, paired with the probability of that choice.
         * Both implementations in this file remove the chosen stream from the agenda.
         */
        public abstract Pair<PrioritizedDerivationStream, Double> sample();

        /** Returns one weight per element of {@code rootDerivs}. */
        public abstract double[] getDerivDistribution(List<Derivation> rootDerivs);

        /** Sampler-specific hook; MaxSampler and AgendaSampler implement it as a no-op. */
        public abstract void unroll();

        /**
         * Adds {@code modelProbs[k]} to the running probability sum of the k-th agenda entry.
         * At verbosity 3 or higher, each update is logged.
         *
         * @param modelProbs probabilities aligned index-by-index with the current agenda
         */
        public void updateProbSum(double[] modelProbs) {
            for (int idx = 0; idx < ReinforcementParserState.this.agenda.size(); idx++) {
                final PrioritizedDerivationStream entry =
                        (PrioritizedDerivationStream) ReinforcementParserState.this.agenda.get(idx);
                entry.addProb(modelProbs[idx]);
                if (ReinforcementParserState.this.parser.verbose(3)) {
                    LogInfo.logs("updateProbSum(): deriv=%s, probSum=%s", entry.derivStream.peek(), entry.probSum);
                }
            }
        }
    }

    /**
     * Fluent builder for {@link ReinforcementParserState}. Any setter that is not called leaves
     * its Java default (null / false) in place.
     */
    public static class StateBuilder {
        private ReinforcementParser parser;
        private Params params;
        private Example example;
        // NOTE(review): never assigned or read inside this builder; kept in case the enclosing
        // class refers to it.
        private ParserState coarseState;
        private String samplingStrategy;
        private boolean computeExpectedCounts;

        /** Sets the parser that owns the state. */
        public StateBuilder parser(ReinforcementParser parser) {
            this.parser = parser;
            return this;
        }

        /** Sets the model parameters. */
        public StateBuilder params(Params params) {
            this.params = params;
            return this;
        }

        /** Sets the example being parsed. */
        public StateBuilder example(Example example) {
            this.example = example;
            return this;
        }

        /** Sets the sampling strategy name (null if never called). */
        public StateBuilder samplingStrategy(String samplingStrategy) {
            this.samplingStrategy = samplingStrategy;
            return this;
        }

        /** Sets whether expected counts are computed during parsing. */
        public StateBuilder computeExpectedCounts(boolean computeExpectedCounts) {
            this.computeExpectedCounts = computeExpectedCounts;
            return this;
        }

        /** Builds a new state from the values collected so far. */
        public ReinforcementParserState createState() {
            final ReinforcementParserState built = new ReinforcementParserState(
                    this.parser,
                    this.params,
                    this.example,
                    this.computeExpectedCounts,
                    this.samplingStrategy);
            return built;
        }
    }

    /**
     * Memo states for the necessity cache: a known-necessary derivation, a known-unnecessary
     * derivation, or a slot not yet computed.
     */
    public enum NecessaryDeriv {
        NECESSARY_DERIV,
        UNNECESSARY_DERIV,
        UNKNOWN
    }
}

