/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sempre;

import edu.stanford.nlp.sempre.ChartParserState;
import edu.stanford.nlp.sempre.Derivation;
import edu.stanford.nlp.sempre.Example;
import edu.stanford.nlp.sempre.Params;
import edu.stanford.nlp.sempre.Parser;
import edu.stanford.nlp.sempre.Rule;
import fig.basic.Evaluation;
import fig.basic.Fmt;
import fig.basic.LogInfo;
import fig.basic.NumUtils;
import fig.basic.Option;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public abstract class ParserState {
    public static Options opts = new Options();

    public final Parser parser;
    public final Params params;
    /** The example being parsed. */
    public final Example ex;
    /** Whether expected counts are requested; when true, {@link #ensureExecuted()} executes every prediction. */
    public final boolean computeExpectedCounts;
    /** Predicted derivations for {@link #ex}; executed by {@link #ensureExecuted()}. */
    public final List<Derivation> predDerivations = new ArrayList<Derivation>();
    /** Per-example statistics, filled in by {@link #setEvaluation()}. */
    public final Evaluation evaluation = new Evaluation();
    public Map<String, Double> expectedCounts;
    public double objectiveValue;
    /** Number of tokens in {@link #ex}, cached at construction time. */
    public final int numTokens;
    public long parseTime;
    /** Largest (pre-pruning) cell size seen by {@link #pruneCell}, and the description of that cell. */
    public int maxCellSize;
    public String maxCellDescription;
    /** Set to true once {@link #pruneCell} has discarded at least one derivation. */
    public boolean fallOffBeam;
    public int totalGeneratedDerivs;
    /** Number of derivations actually featurized by {@link #featurizeAndScoreDerivation}. */
    public int numOfFeaturizedDerivs = 0;

    /**
     * Creates the parsing state for a single example.
     *
     * @param parser the parser driving inference (supplies the feature extractor, executor and value evaluator)
     * @param params model parameters used to score derivations
     * @param ex the example to parse
     * @param computeExpectedCounts whether expected feature counts are requested for this example
     */
    public ParserState(Parser parser, Params params, Example ex, boolean computeExpectedCounts) {
        this.parser = parser;
        this.params = params;
        this.ex = ex;
        this.computeExpectedCounts = computeExpectedCounts;
        this.numTokens = ex.numTokens();
    }

    /** Returns the maximum number of derivations kept per cell; defaults to {@code Parser.opts.beamSize}. */
    protected int getBeamSize() {
        return Parser.opts.beamSize;
    }

    /** Runs inference on {@link #ex}; implemented by concrete parser states. */
    public abstract void infer();

    /**
     * Extracts local features for {@code deriv} and computes its local score.
     *
     * <p>Derivations that are already featurized and scored are skipped with a warning. When
     * {@code opts.throwFeaturesAway} is set, the features are cleared after scoring.
     *
     * @param deriv the derivation to featurize and score
     */
    protected void featurizeAndScoreDerivation(Derivation deriv) {
        if (deriv.isFeaturizedAndScored()) {
            LogInfo.warnings("Derivation already featurized: %s", deriv);
            return;
        }
        parser.extractor.extractLocal(ex, deriv);
        deriv.computeScoreLocal(params);
        if (opts.throwFeaturesAway) {
            deriv.clearFeatures();
        }
        if (parser.verbose(5)) {
            LogInfo.logs("featurizeAndScoreDerivation(score=%s) %s %s: %s [rule: %s]",
                    Fmt.D(deriv.score), deriv.cat, ex.spanString(deriv.start, deriv.end), deriv, deriv.rule);
        }
        ++numOfFeaturizedDerivs;
    }

    /**
     * Sorts a chart cell by score and prunes it in place.
     *
     * <p>Before sorting, each derivation's unsorted position is recorded
     * ({@code preSortBeamPosition} is only set if still -1; {@code maxUnsortedBeamPosition} also
     * takes the maximum over the children). If {@code Parser.opts.derivationScoreNoise > 0},
     * random noise is added to every score before sorting. After sorting, sorted positions are
     * recorded the same way. The cell is then truncated either by score gap to the best derivation
     * ({@code pruneByProbDiff}) or to {@link #getBeamSize()} entries.
     *
     * @param cellDescription human-readable identifier of the cell, used in logs and statistics
     * @param derivations the cell contents, modified in place; {@code null} is ignored
     */
    protected void pruneCell(String cellDescription, List<Derivation> derivations) {
        if (derivations == null) {
            return;
        }
        if (derivations.size() > maxCellSize) {
            maxCellSize = derivations.size();
            maxCellDescription = cellDescription;
            if (maxCellSize > 5000) {
                LogInfo.logs("ParserState.pruneCell %s: maxCellSize = %s entries (not pruned yet)",
                        maxCellDescription, maxCellSize);
            }
        }

        // Record positions in the cell before sorting.
        int position = 0;
        for (Derivation deriv : derivations) {
            deriv.maxUnsortedBeamPosition = position;
            if (deriv.children != null) {
                for (Derivation child : deriv.children) {
                    deriv.maxUnsortedBeamPosition = Math.max(deriv.maxUnsortedBeamPosition, child.maxUnsortedBeamPosition);
                }
            }
            if (deriv.preSortBeamPosition == -1) {
                deriv.preSortBeamPosition = position;
            }
            ++position;
        }

        if (Parser.opts.derivationScoreNoise > 0.0) {
            for (Derivation deriv : derivations) {
                deriv.score += Parser.opts.derivationScoreRandom.nextDouble() * Parser.opts.derivationScoreNoise;
            }
        }
        Derivation.sortByScore(derivations);

        if (Parser.opts.verbose >= 3) {
            LogInfo.begin_track("ParserState.pruneCell(%s): %d derivations", cellDescription, derivations.size());
            for (Derivation deriv : derivations) {
                LogInfo.logs("%s(%s,%s): %s %s, [score=%s] allAnchored: %s",
                        deriv.cat, deriv.start, deriv.end, deriv.formula, deriv.canonicalUtterance,
                        deriv.score, deriv.allAnchored());
            }
            LogInfo.end_track();
        }

        // Record positions in the cell after sorting.
        position = 0;
        for (Derivation deriv : derivations) {
            deriv.maxBeamPosition = position;
            if (deriv.children != null) {
                for (Derivation child : deriv.children) {
                    deriv.maxBeamPosition = Math.max(deriv.maxBeamPosition, child.maxBeamPosition);
                }
            }
            deriv.postSortBeamPosition = position++;
        }

        // NOTE(review): the pruning mode is read from ChartParserState.opts while the threshold is
        // read from ParserState.opts; confirm both refer to the same Options instance.
        if (ChartParserState.opts.pruneByProbDiff) {
            pruneByScoreGap(derivations);
        } else {
            int beamSize = getBeamSize();
            if (derivations.size() > beamSize) {
                if (Parser.opts.verbose >= 1) {
                    LogInfo.logs("ParserState.pruneCell %s: Pruning %d -> %d derivations",
                            cellDescription, derivations.size(), beamSize);
                }
                derivations.subList(beamSize, derivations.size()).clear();
                fallOffBeam = true;
            }
        }
    }

    /**
     * Removes trailing derivations (from a list sorted by decreasing score) whose score is more
     * than {@code log(opts.probDiffPruningThresh)} below the best score.
     *
     * <p>Empty lists are left untouched. The top derivation is always kept: it is within any
     * threshold of itself, and removing it (possible when the threshold is below 1) would
     * otherwise empty the list and index past its end.
     */
    private void pruneByScoreGap(List<Derivation> derivations) {
        if (derivations.isEmpty()) {
            return;
        }
        double highestScore = derivations.get(0).score;
        double maxLogGap = Math.log(opts.probDiffPruningThresh);
        while (derivations.size() > 1
                && highestScore - derivations.get(derivations.size() - 1).score > maxLogGap) {
            derivations.remove(derivations.size() - 1);
            fallOffBeam = true;
        }
    }

    /**
     * Builds the leaf derivations for every token and every contiguous span of {@link #ex}.
     *
     * <p>For each token {@code i}, a {@code $TOKEN} and a {@code $LEMMA_TOKEN} derivation over
     * {@code [i, i+1)} are created. After that, for every span {@code [i, j)}, a {@code $PHRASE}
     * and a {@code $LEMMA_PHRASE} derivation are created. The lemma variants take their formula
     * from the lemmatized text but keep the surface text as canonical utterance.
     *
     * @return a new mutable list holding all token derivations followed by all phrase derivations
     */
    public List<Derivation> gatherTokenAndPhraseDerivations() {
        int numSpans = numTokens * (numTokens + 1) / 2;
        List<Derivation> derivs = new ArrayList<Derivation>(2 * numTokens + 2 * numSpans);
        for (int i = 0; i < numTokens; ++i) {
            String token = ex.token(i);
            derivs.add(nullRuleDerivation("$TOKEN", i, i + 1, token, token));
            derivs.add(nullRuleDerivation("$LEMMA_TOKEN", i, i + 1, ex.lemmaToken(i), token));
        }
        for (int i = 0; i < numTokens; ++i) {
            for (int j = i + 1; j <= numTokens; ++j) {
                String phrase = ex.phrase(i, j);
                derivs.add(nullRuleDerivation("$PHRASE", i, j, phrase, phrase));
                derivs.add(nullRuleDerivation("$LEMMA_PHRASE", i, j, ex.lemmaPhrase(i, j), phrase));
            }
        }
        return derivs;
    }

    /** Builds a childless {@code Rule.nullRule} derivation whose formula is built from {@code formulaText}. */
    private static Derivation nullRuleDerivation(String cat, int start, int end, String formulaText, String utterance) {
        return new Derivation.Builder()
                .cat(cat)
                .start(start)
                .end(end)
                .rule(Rule.nullRule)
                .children(Derivation.emptyList)
                .withStringFormulaFrom(formulaText)
                .canonicalUtterance(utterance)
                .createDerivation();
    }

    /**
     * Executes the predicted derivations and, if the example has a target value, sets each
     * executed derivation's {@code compatibility}.
     *
     * <p>Only the first prediction is executed when {@code Parser.opts.executeTopFormulaOnly} is set
     * and expected counts were not requested; otherwise all predictions are executed.
     */
    public void ensureExecuted() {
        LogInfo.begin_track("Parser.ensureExecuted");
        boolean topOnly = !computeExpectedCounts && Parser.opts.executeTopFormulaOnly;
        for (Derivation deriv : predDerivations) {
            deriv.ensureExecuted(parser.executor, ex.context);
            if (ex.targetValue != null) {
                deriv.compatibility = parser.valueEvaluator.getCompatibility(ex.targetValue, deriv.value);
            }
            if (topOnly) {
                break;
            }
        }
        LogInfo.end_track();
    }

    /** Copies the parsing statistics gathered in this state into {@link #evaluation}. */
    protected void setEvaluation() {
        // Explicit casts are kept so that the intended Evaluation.add overloads are selected.
        evaluation.add("numTokens", (double) numTokens);
        evaluation.add("parseTime", (double) parseTime);
        evaluation.add("maxCellSize", (Object) maxCellDescription, (double) maxCellSize);
        evaluation.add("fallOffBeam", fallOffBeam);
        evaluation.add("totalDerivs", (double) totalGeneratedDerivs);
        evaluation.add("numOfFeaturizedDerivs", (double) numOfFeaturizedDerivs);
    }

    /**
     * Converts a compatibility value into a training reward.
     *
     * @param compatibility compatibility between a derivation's value and the target value
     * @return {@code compatibility} itself when {@code Parser.opts.partialReward} is set; otherwise
     *     1.0 if {@code compatibility == 1.0} and 0.0 otherwise
     */
    public static double compatibilityToReward(double compatibility) {
        if (Parser.opts.partialReward) {
            return compatibility;
        }
        return compatibility == 1.0 ? 1.0 : 0.0;
    }

    /**
     * Adds to {@code counts} the difference between feature expectations under a "true"
     * distribution and a "predicted" distribution over {@code derivations}.
     *
     * <p>How the two distributions are built depends on {@code opts.customExpectedCounts}:
     * <ul>
     *   <li>NONE: true scores are model score plus log reward; predicted scores are model scores.</li>
     *   <li>UNIFORM: true scores are log rewards; predicted scores are all 0.</li>
     *   <li>TOP / RANDOM: each distribution is a point mass on one chosen good (compatibility 1)
     *       or bad derivation. Nothing is added if no suitable pair is found.</li>
     * </ul>
     * Nothing is added if either distribution cannot be normalized
     * ({@code NumUtils.expNormalize} returns false).
     *
     * @param derivations candidate derivations with scores and compatibilities set
     * @param counts accumulator for per-feature increments
     */
    public static void computeExpectedCounts(List<Derivation> derivations, Map<String, Double> counts) {
        int n = derivations.size();
        if (n == 0) {
            return;
        }
        CustomExpectedCount mode = opts.customExpectedCounts;

        int[] goodAndBad = null;
        if (mode == CustomExpectedCount.TOP) {
            goodAndBad = getTopDerivations(derivations);
            if (goodAndBad == null) {
                return;
            }
        } else if (mode == CustomExpectedCount.RANDOM) {
            goodAndBad = getRandomDerivations(derivations);
            if (goodAndBad == null) {
                return;
            }
        }

        double[] trueScores = new double[n];
        double[] predScores = new double[n];
        for (int i = 0; i < n; ++i) {
            Derivation deriv = derivations.get(i);
            double logReward = Math.log(compatibilityToReward(deriv.compatibility));
            switch (mode) {
                case NONE:
                    trueScores[i] = deriv.score + logReward;
                    predScores[i] = deriv.score;
                    break;
                case UNIFORM:
                    trueScores[i] = logReward;
                    predScores[i] = 0.0;
                    break;
                case TOP:
                case RANDOM:
                    trueScores[i] = i == goodAndBad[0] ? 0.0 : Double.NEGATIVE_INFINITY;
                    predScores[i] = i == goodAndBad[1] ? 0.0 : Double.NEGATIVE_INFINITY;
                    break;
                default:
                    throw new RuntimeException("Unknown customExpectedCounts: " + mode);
            }
        }

        if (!NumUtils.expNormalize(trueScores)) {
            return;
        }
        if (!NumUtils.expNormalize(predScores)) {
            return;
        }
        for (int i = 0; i < n; ++i) {
            double incr = trueScores[i] - predScores[i];
            if (incr != 0.0) {
                derivations.get(i).incrementAllFeatureVector(incr, counts);
            }
        }
    }

    /**
     * Picks the highest-scoring good (compatibility 1) and highest-scoring bad derivation.
     *
     * @return {@code {goodIndex, badIndex}}, or {@code null} if either kind is missing or the good
     *     score is already at least {@code opts.contrastiveMargin} above the bad score
     */
    private static int[] getTopDerivations(List<Derivation> derivations) {
        int chosenGood = -1;
        int chosenBad = -1;
        double chosenGoodScore = Double.NEGATIVE_INFINITY;
        double chosenBadScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < derivations.size(); ++i) {
            Derivation deriv = derivations.get(i);
            if (deriv.compatibility == 1.0) {
                if (deriv.score > chosenGoodScore) {
                    chosenGood = i;
                    chosenGoodScore = deriv.score;
                }
            } else if (deriv.score > chosenBadScore) {
                chosenBad = i;
                chosenBadScore = deriv.score;
            }
        }
        if (chosenGood == -1 || chosenBad == -1 || chosenGoodScore >= chosenBadScore + opts.contrastiveMargin) {
            return null;
        }
        return new int[] {chosenGood, chosenBad};
    }

    /**
     * Picks one good (compatibility 1) and one bad derivation, each uniformly at random, using
     * single-element reservoir sampling over a single pass.
     *
     * @return {@code {goodIndex, badIndex}}, or {@code null} if either kind is missing
     */
    private static int[] getRandomDerivations(List<Derivation> derivations) {
        int chosenGood = -1;
        int chosenBad = -1;
        int numGoodSoFar = 0;
        int numBadSoFar = 0;
        for (int i = 0; i < derivations.size(); ++i) {
            Derivation deriv = derivations.get(i);
            if (deriv.compatibility == 1.0) {
                ++numGoodSoFar;
                if (Math.random() <= 1.0 / numGoodSoFar) {
                    chosenGood = i;
                }
            } else {
                ++numBadSoFar;
                if (Math.random() <= 1.0 / numBadSoFar) {
                    chosenBad = i;
                }
            }
        }
        if (chosenGood == -1 || chosenBad == -1) {
            return null;
        }
        return new int[] {chosenGood, chosenBad};
    }

    /** Distribution used by {@link #computeExpectedCounts} to build the true/predicted scores. */
    public static enum CustomExpectedCount {
        NONE,
        UNIFORM,
        TOP,
        RANDOM;
    }

    /** Command-line options controlling pruning and expected-count computation. */
    public static class Options {
        @Option(gloss="Use a custom distribution for computing expected counts")
        public CustomExpectedCount customExpectedCounts = CustomExpectedCount.NONE;
        @Option(gloss="For customExpectedCounts = TOP, only update if good < bad + margin")
        public double contrastiveMargin = 1000000.0;
        @Option(gloss="Whether to prune based on probability difference")
        public boolean pruneByProbDiff = false;
        @Option(gloss="Difference in probability for pruning by prob diff")
        public double probDiffPruningThresh = 100.0;
        @Option(gloss="Throw features away after scoring to save memory")
        public boolean throwFeaturesAway = false;
    }
}

