/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sempre;

import edu.stanford.nlp.sempre.ContextValue;
import edu.stanford.nlp.sempre.Derivation;
import edu.stanford.nlp.sempre.Example;
import edu.stanford.nlp.sempre.Executor;
import edu.stanford.nlp.sempre.FeatureExtractor;
import edu.stanford.nlp.sempre.FeatureVector;
import edu.stanford.nlp.sempre.Formula;
import edu.stanford.nlp.sempre.Grammar;
import edu.stanford.nlp.sempre.Params;
import edu.stanford.nlp.sempre.ParserState;
import edu.stanford.nlp.sempre.Rule;
import edu.stanford.nlp.sempre.ValueEvaluator;
import fig.basic.Evaluation;
import fig.basic.Fmt;
import fig.basic.LogInfo;
import fig.basic.MapUtils;
import fig.basic.Option;
import fig.basic.StopWatch;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * Base class for semantic parsers.
 *
 * <p>Holds the grammar, feature extractor, executor and value evaluator supplied via a {@link Spec}.
 * Subclasses provide the actual search through {@link #newParserState}; this class runs a parse
 * ({@link #parse}) and summarizes the resulting predictions into an {@link Evaluation}
 * ({@link #addToEvaluation}).
 */
public abstract class Parser {
    /** Options shared by all parser instances. */
    public static final Options opts = new Options();
    public final Grammar grammar;
    public final FeatureExtractor extractor;
    public final Executor executor;
    public final ValueEvaluator valueEvaluator;
    /**
     * Grammar rules for which {@code isCatUnary()} holds, ordered by {@link #computeCatUnaryRules} so
     * that each rule appears after every cat-unary rule whose left-hand side is its first right-hand-side
     * category. Rules added later through {@link #addRule} are appended without re-sorting.
     */
    protected List<Rule> catUnaryRules;
    /** Optional writer for chart-filling output; {@code null} by default and not read by this class. */
    public PrintWriter chartFillOut = null;

    /** Returns true if {@code opts.verbose} is at least {@code level}. */
    public boolean verbose(int level) {
        return opts.verbose >= level;
    }

    /** Returns the (live, mutable) list of ordered cat-unary rules. */
    public List<Rule> getCatUnaryRules() {
        return catUnaryRules;
    }

    /**
     * Creates a parser from {@code spec} and precomputes the ordered cat-unary rules.
     *
     * <p>NOTE(review): the overridable {@link #computeCatUnaryRules} is invoked from the constructor,
     * so an override runs before any subclass fields are initialized.
     *
     * @throws RuntimeException if the cat-unary rules contain a cycle
     */
    public Parser(Spec spec) {
        this.grammar = spec.grammar;
        this.extractor = spec.extractor;
        this.executor = spec.executor;
        this.valueEvaluator = spec.valueEvaluator;
        computeCatUnaryRules();
        LogInfo.logs("Parser: %d catUnaryRules (sorted), %d nonCatUnaryRules (in trie)",
                catUnaryRules.size(), grammar.rules.size() - catUnaryRules.size());
    }

    /**
     * Registers a rule added after construction. Only cat-unary rules are recorded here, and they are
     * appended to the end of {@link #catUnaryRules} without re-establishing the traversal order.
     */
    public synchronized void addRule(Rule rule) {
        if (rule.isCatUnary()) {
            catUnaryRules.add(rule);
        }
    }

    /**
     * Rebuilds {@link #catUnaryRules} from the grammar: builds a graph from each cat-unary rule's
     * left-hand side to the rule, then runs {@link #traverse} from every left-hand-side category.
     *
     * @throws RuntimeException if the cat-unary rules form a cycle
     */
    protected void computeCatUnaryRules() {
        catUnaryRules = new ArrayList<Rule>();
        // left-hand-side category -> cat-unary rules with that left-hand side
        Map<String, List<Rule>> graph = new HashMap<String, List<Rule>>();
        for (Rule rule : grammar.rules) {
            if (rule.isCatUnary()) {
                MapUtils.addToList(graph, rule.lhs, rule);
            }
        }
        Map<String, Boolean> done = new HashMap<String, Boolean>();
        for (String node : graph.keySet()) {
            traverse(catUnaryRules, node, graph, done);
        }
    }

    /**
     * Post-order depth-first traversal: for each rule with left-hand side {@code node}, first visits the
     * rule's first right-hand-side category, then appends the rule to {@code catUnaryRules}.
     * {@code done} marks nodes as in progress ({@code false}) or finished ({@code true}).
     *
     * @throws RuntimeException if {@code node} is reached again while still in progress (a cycle)
     */
    protected void traverse(List<Rule> catUnaryRules, String node, Map<String, List<Rule>> graph, Map<String, Boolean> done) {
        Boolean mark = done.get(node);
        if (Boolean.TRUE.equals(mark)) {
            return; // already emitted
        }
        if (Boolean.FALSE.equals(mark)) {
            throw new RuntimeException("Found cycle of unaries involving " + node);
        }
        done.put(node, false);
        for (Rule rule : MapUtils.getList(graph, node)) {
            traverse(catUnaryRules, rule.rhs.get(0), graph, done);
            catUnaryRules.add(rule);
        }
        done.put(node, true);
    }

    /** Hook for subclasses; does nothing in this base class. */
    public void onBeginDataGroup(int iter, int numIters, String group) {
    }

    /** Creates the subclass-specific state that performs inference for {@code ex}. */
    public abstract ParserState newParserState(Params params, Example ex, boolean computeExpectedCounts);

    /** Returns the parameters to use during search; by default {@code params} itself. */
    public Params getSearchParams(Params params) {
        return params;
    }

    /**
     * Parses {@code ex} and records the predictions on it.
     *
     * <p>If the example has a target formula but no target value, the formula is executed first to
     * obtain the value. After inference, {@code ex.predDerivations} is set to the state's predicted
     * derivations sorted by score and, when {@code opts.callSetEvaluation} is set, {@code ex.evaluation}
     * is replaced by a fresh {@link Evaluation} filled in by {@link #addToEvaluation}. Temporary state on
     * the example and on each predicted derivation is cleared before returning.
     *
     * @return the parser state after inference
     */
    public ParserState parse(Params params, Example ex, boolean computeExpectedCounts) {
        if (ex.targetFormula != null && ex.targetValue == null) {
            ex.targetValue = executor.execute(ex.targetFormula, ex.context).value;
        }
        StopWatch watch = new StopWatch();
        watch.start();
        LogInfo.begin_track_printAll("Parser.parse: parse");
        ParserState state;
        try {
            state = newParserState(params, ex, computeExpectedCounts);
            state.infer();
        } finally {
            // Keep the log track nesting balanced even if inference throws.
            LogInfo.end_track();
        }
        watch.stop();
        state.parseTime = watch.getCurrTimeLong();
        state.setEvaluation();
        ex.predDerivations = state.predDerivations;
        Derivation.sortByScore(ex.predDerivations);
        if (opts.callSetEvaluation) {
            ex.evaluation = new Evaluation();
            addToEvaluation(state, ex.evaluation);
        }
        ex.clearTempState();
        for (Derivation deriv : ex.predDerivations) {
            deriv.clearTempState();
        }
        return state;
    }

    /**
     * Summarizes the predictions in {@code state} into {@code evaluation} and logs them.
     *
     * <p>Sets each predicted derivation's {@code prob} from {@code Derivation.getProbs(.., 1.0)}. When
     * the example has a target value, records:
     * <ul>
     *   <li>{@code correct}: share of the probability mass of the leading score-tied candidates
     *       (scores within 1e-10 of the first) whose compatibility is exactly 1;</li>
     *   <li>{@code partCorrect}: the same share weighted by compatibility;</li>
     *   <li>{@code oracle}: whether any candidate has compatibility 1;</li>
     *   <li>{@code partOracle}: the maximum compatibility.</li>
     * </ul>
     * Also merges {@code state.evaluation} and every non-null {@code executorStats}. The candidate list
     * is expected to be sorted by score, as {@link #parse} does before calling this.
     */
    public void addToEvaluation(ParserState state, Evaluation evaluation) {
        Example ex = state.ex;
        List<Derivation> predDerivations = state.predDerivations;
        boolean printAll = opts.printAllPredictions;
        int numCandidates = predDerivations.size();
        LogInfo.begin_track_printAll("Parser.setEvaluation: %d candidates", numCandidates);
        try {
            // Compatibility of each candidate with the target value; null when there is no target.
            double[] compatibilities = null;
            // Index of the first fully compatible candidate, or -1.
            int correctIndex = -1;
            double maxCompatibility = 0.0;
            if (ex.targetValue != null) {
                compatibilities = new double[numCandidates];
                for (int i = 0; i < numCandidates; ++i) {
                    compatibilities[i] = predDerivations.get(i).compatibility;
                    if (compatibilities[i] == 1.0 && correctIndex == -1) {
                        correctIndex = i;
                    }
                    maxCompatibility = Math.max(compatibilities[i], maxCompatibility);
                }
            }

            double[] probs = Derivation.getProbs(predDerivations, 1.0);
            for (int i = 0; i < numCandidates; ++i) {
                predDerivations.get(i).prob = probs[i];
            }

            double correct = 0.0;
            double partCorrect = 0.0;
            if (compatibilities != null) {
                int numTop = countTiedWithFirst(predDerivations);
                double topMass = 0.0;
                for (int i = 0; i < numTop; ++i) {
                    topMass += probs[i];
                }
                for (int i = 0; i < numTop; ++i) {
                    if (compatibilities[i] == 1.0) {
                        correct += probs[i] / topMass;
                    }
                    if (compatibilities[i] > 0.0) {
                        partCorrect += compatibilities[i] * probs[i] / topMass;
                    }
                }
            }

            // On a (partial) miss, log feature and choice differences between the first correct
            // candidate and the top prediction.
            if (correctIndex != -1 && correct != 1.0) {
                Derivation trueDeriv = predDerivations.get(correctIndex);
                Derivation predDeriv = predDerivations.get(0);
                HashMap<String, Double> featureDiff = new HashMap<String, Double>();
                trueDeriv.incrementAllFeatureVector(1.0, featureDiff);
                predDeriv.incrementAllFeatureVector(-1.0, featureDiff);
                String heading = String.format("TopTrue (%d) - Pred (%d) = Diff", correctIndex, 0);
                FeatureVector.logFeatureWeights(heading, featureDiff, state.params);
                LinkedHashMap<String, Integer> choiceDiff = new LinkedHashMap<String, Integer>();
                trueDeriv.incrementAllChoices(1, choiceDiff);
                predDeriv.incrementAllChoices(-1, choiceDiff);
                FeatureVector.logChoices(heading, choiceDiff);
            }

            if (compatibilities != null) {
                // Fully correct candidates, capped at maxPrintedTrue unless printing everything.
                int printed = 0;
                for (int i = 0; i < numCandidates; ++i) {
                    if (compatibilities[i] == 1.0 && (printAll || printed < opts.maxPrintedTrue)) {
                        logCandidate("True@%04d: %s [score=%s, prob=%s%s]", i, predDerivations.get(i), probs, compatibilities, state);
                        ++printed;
                    }
                }
                // Partially compatible candidates; counted separately but share the maxPrintedTrue cap.
                printed = 0;
                for (int i = 0; i < numCandidates; ++i) {
                    boolean partial = compatibilities[i] > 0.0 && compatibilities[i] < 1.0;
                    if (partial && (printAll || printed < opts.maxPrintedTrue)) {
                        logCandidate("Part@%04d: %s [score=%s, prob=%s%s]", i, predDerivations.get(i), probs, compatibilities, state);
                        ++printed;
                    }
                }
            }
            // Predictions near the top (at least half the best probability, or among the first 10),
            // capped at maxPrintedPredictions unless printing everything.
            for (int i = 0; i < numCandidates; ++i) {
                if (printAll || (probs[i] >= probs[0] / 2.0 || i < 10) && i < opts.maxPrintedPredictions) {
                    logCandidate("Pred@%04d: %s [score=%s, prob=%s%s]", i, predDerivations.get(i), probs, compatibilities, state);
                }
            }

            evaluation.add("correct", correct);
            evaluation.add("oracle", correctIndex != -1);
            evaluation.add("partCorrect", partCorrect);
            evaluation.add("partOracle", maxCompatibility);
            if (correctIndex != -1) {
                // The "after parse" index is the first fully correct candidate, i.e. correctIndex.
                evaluation.add("correctIndexAfterParse", (double) correctIndex);
                Derivation firstCorrect = predDerivations.get(correctIndex);
                evaluation.add("correctMaxBeamPosition", (double) firstCorrect.maxBeamPosition);
                evaluation.add("correctMaxUnsortedBeamPosition", (double) firstCorrect.maxUnsortedBeamPosition);
            }
            evaluation.add("parsed", numCandidates > 0);
            evaluation.add("numCandidates", (double) numCandidates);
            if (numCandidates > 0) {
                evaluation.add("parsedNumCandidates", (double) numCandidates);
            }
            evaluation.add(state.evaluation);
            for (Derivation deriv : predDerivations) {
                if (deriv.executorStats != null) {
                    evaluation.add(deriv.executorStats);
                }
            }
        } finally {
            LogInfo.end_track();
        }
    }

    /** Returns the length of the leading run of derivations whose score is within 1e-10 of the first. */
    private static int countTiedWithFirst(List<Derivation> derivs) {
        int count = 0;
        while (count < derivs.size() && Math.abs(derivs.get(count).score - derivs.get(0).score) < 1.0E-10) {
            ++count;
        }
        return count;
    }

    /**
     * Logs candidate {@code index} using {@code format} (index, derivation, score, prob, optional
     * compatibility suffix) and, if {@code opts.dumpAllFeatures} is set, its feature weights.
     */
    private static void logCandidate(String format, int index, Derivation deriv, double[] probs, double[] compatibilities, ParserState state) {
        LogInfo.logs(format, index, deriv.toString(), Fmt.D(deriv.score), Fmt.D(probs[index]),
                compatibilities != null ? ", comp=" + Fmt.D(compatibilities[index]) : "");
        if (opts.dumpAllFeatures) {
            FeatureVector.logFeatureWeights("Features", deriv.getAllFeatureVector(), state.params);
        }
    }

    /** Bundle of the components a {@link Parser} is constructed from. */
    public static class Spec {
        public final Grammar grammar;
        public final FeatureExtractor extractor;
        public final Executor executor;
        public final ValueEvaluator valueEvaluator;

        public Spec(Grammar grammar, FeatureExtractor extractor, Executor executor, ValueEvaluator valueEvaluator) {
            this.grammar = grammar;
            this.extractor = extractor;
            this.executor = executor;
            this.valueEvaluator = valueEvaluator;
        }
    }

    /** Parser options; each field is exposed through fig's {@code @Option} annotation. */
    public static class Options {
        @Option(gloss="For debugging, whether to print out all the predicted derivations")
        public boolean printAllPredictions;
        @Option(gloss="Maximal number of predictions to print")
        public int maxPrintedPredictions = Integer.MAX_VALUE;
        @Option(gloss="Maximal number of correct predictions to print")
        public int maxPrintedTrue = Integer.MAX_VALUE;
        @Option(gloss="Use a coarse pass to prune the chart before full parsing")
        public boolean coarsePrune = false;
        @Option(gloss="How much output to print")
        public int verbose = 0;
        @Option(gloss="Execute only top formula to be cheap (hack at test time for fast demo)")
        public boolean executeTopFormulaOnly = false;
        @Option(gloss="Whether to output chart filling visualization (huge file!)")
        public boolean visualizeChartFilling = false;
        @Option(gloss="Keep this number of derivations per cell (exact use depends on the parser)")
        public int beamSize = 200;
        @Option(gloss="Whether to update based on partial reward (for learning)")
        public boolean partialReward = true;
        @Option(gloss="Whether to unroll derivation streams (applies to lazy parsers)")
        public boolean unrollStream = false;
        @Option(gloss="Inject random noise into the score (to mix things up a bit)")
        public double derivationScoreNoise = 0.0;
        @Option(gloss="Source of random noise")
        public Random derivationScoreRandom = new Random(1L);
        @Option(gloss="Prune away error denotations")
        public boolean pruneErrorValues = false;
        @Option(gloss="Dump all features (for debugging)")
        public boolean dumpAllFeatures = false;
        @Option(gloss="Call SetEvaluation during parsing")
        public boolean callSetEvaluation = true;
    }
}

