/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sempre;

import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import edu.stanford.nlp.sempre.BeamParser;
import edu.stanford.nlp.sempre.ChartParserState;
import edu.stanford.nlp.sempre.Derivation;
import edu.stanford.nlp.sempre.DerivationStream;
import edu.stanford.nlp.sempre.Example;
import edu.stanford.nlp.sempre.Formula;
import edu.stanford.nlp.sempre.Json;
import edu.stanford.nlp.sempre.Params;
import edu.stanford.nlp.sempre.Parser;
import edu.stanford.nlp.sempre.ParserState;
import edu.stanford.nlp.sempre.Rule;
import edu.stanford.nlp.sempre.SemanticFn;
import edu.stanford.nlp.sempre.Trie;
import fig.basic.IntRef;
import fig.basic.LogInfo;
import fig.basic.SetUtils;
import fig.basic.StopWatchSet;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Chart-parser state used by {@link BeamParser}. It fills the chart bottom-up, span by span in order
 * of increasing span length, and prunes every cell of a span once that span has been built.
 *
 * <p>When a {@code coarseState} is supplied, a rule / trie node is only expanded over a span if the
 * coarse state's chart already holds a matching category for that same span (see the two
 * {@code coarseAllows} overloads). NOTE(review): the coarse state is presumably an earlier
 * {@link Mode#bool} pass over the same example — confirm against BeamParser.
 */
class BeamParserState extends ChartParserState {
    /** Controls how rules are applied; see {@link Mode}. */
    public final Mode mode;
    private final BeamParser parser;
    /** Optional chart that restricts which categories may be built; {@code null} disables the restriction. */
    private final BeamParserState coarseState;

    /**
     * @param parser                the owning parser (supplies the rule trie, cat-unary rules and options)
     * @param params                model parameters, forwarded to the superclass
     * @param ex                    the example being parsed
     * @param computeExpectedCounts whether {@link #infer()} should fill {@code expectedCounts} (full mode only)
     * @param mode                  how rules are applied; see {@link Mode}
     * @param coarseState           optional coarse chart used to filter categories, or {@code null}
     */
    public BeamParserState(BeamParser parser, Params params, Example ex, boolean computeExpectedCounts, Mode mode, BeamParserState coarseState) {
        super(parser, params, ex, computeExpectedCounts);
        this.parser = parser;
        this.mode = mode;
        this.coarseState = coarseState;
    }

    /**
     * Runs the parser over the whole example.
     *
     * <p>It seeds the chart, builds every span in order of increasing length, optionally dumps
     * chart-filling data, and sets the predicted derivations. In {@link Mode#full} it also executes
     * those derivations and, if requested, computes expected counts. An empty utterance is a no-op.
     */
    @Override
    public void infer() {
        if (this.numTokens == 0) {
            return;
        }
        boolean tracking = this.parser.verbose(2);
        if (tracking) {
            LogInfo.begin_track("ParserState.infer");
        }
        try {
            // Seed the chart with the derivations returned by gatherTokenAndPhraseDerivations().
            for (Derivation deriv : this.gatherTokenAndPhraseDerivations()) {
                this.featurizeAndScoreDerivation(deriv);
                this.addToChart(deriv);
            }
            // Shorter spans first: rules over [i, i + len) read the cells of strictly shorter
            // sub-spans, which must already be built and pruned.
            for (int len = 1; len <= this.numTokens; ++len) {
                for (int i = 0; i + len <= this.numTokens; ++i) {
                    this.build(i, i + len);
                }
            }
        } finally {
            // Close the track even if building a span throws, so LogInfo nesting stays balanced.
            if (tracking) {
                LogInfo.end_track();
            }
        }
        if (this.parser.chartFillOut != null && Parser.opts.visualizeChartFilling && this.mode != Mode.bool) {
            this.parser.chartFillOut.println(Json.writeValueAsStringHard(new ChartParserState.ChartFillingData(this.ex.id, this.chartFillingList, this.ex.utterance, this.ex.numTokens())));
            this.parser.chartFillOut.flush();
        }
        this.setPredDerivations();
        if (this.mode == Mode.full) {
            this.ensureExecuted();
            if (this.computeExpectedCounts) {
                this.expectedCounts = new HashMap<>();
                ParserState.computeExpectedCounts(this.predDerivations, this.expectedCounts);
            }
        }
    }

    /**
     * Builds all derivations for the span {@code [start, end)}.
     *
     * <p>It applies the trie-indexed (non-cat-unary) rules first, then the cat-unary rules. Finally
     * it prunes every category cell of the span that was not already pruned along the way.
     */
    protected void build(int start, int end) {
        this.applyNonCatUnaryRules(start, end, start, this.parser.trie, new ArrayList<Derivation>(), new IntRef(0));
        Set<String> cellsPruned = new HashSet<>();
        this.applyCatUnaryRules(start, end, cellsPruned);
        for (Map.Entry<String, List<Derivation>> entry : this.chart[start][end].entrySet()) {
            this.pruneCell(cellsPruned, entry.getKey(), start, end, entry.getValue());
        }
    }

    /** Returns the {@code cat:start:end} key that identifies a single chart cell. */
    private static String cellString(String cat, int start, int end) {
        return cat + ":" + start + ":" + end;
    }

    /**
     * Applies {@code rule} to {@code children} over {@code [start, end)} and adds the results to the chart.
     *
     * <p>In {@link Mode#full} the rule's semantic function is invoked (timed under the rule's
     * semantic representation), and each produced derivation is featurized and scored. In
     * {@link Mode#bool} a single derivation with {@link Formula#nullFormula} is added without
     * calling the semantic function.
     *
     * @return in full mode, {@code estimatedSize()} of the semantic function's result stream; in bool mode, 1
     * @throws RuntimeException wrapping any failure during composition, or if {@link #mode} is neither value
     */
    private int applyRule(int start, int end, Rule rule, List<Derivation> children) {
        if (Parser.opts.verbose >= 5) {
            LogInfo.logs("applyRule %s %s %s %s", start, end, rule, children);
        }
        try {
            if (this.mode == Mode.full) {
                StopWatchSet.begin(rule.getSemRepn());
                DerivationStream results;
                try {
                    results = rule.sem.call(this.ex, new SemanticFn.CallInfo(rule.lhs, start, end, rule, ImmutableList.copyOf(children)));
                } finally {
                    // Keep the stopwatch stack balanced even when the semantic function throws.
                    StopWatchSet.end();
                }
                while (results.hasNext()) {
                    Derivation newDeriv = results.next();
                    this.featurizeAndScoreDerivation(newDeriv);
                    this.addToChart(newDeriv);
                }
                return results.estimatedSize();
            }
            if (this.mode == Mode.bool) {
                Derivation deriv = new Derivation.Builder().cat(rule.lhs).start(start).end(end).rule(rule).children(ImmutableList.copyOf(children)).formula(Formula.nullFormula).createDerivation();
                this.addToChart(deriv);
                return 1;
            }
        } catch (Exception e) {
            LogInfo.errors("Composition failed: rule = %s, children = %s", rule, children);
            e.printStackTrace();
            throw new RuntimeException(e);
        }
        // Raised outside the try so a bad mode is not misreported as a composition failure.
        throw new RuntimeException("Invalid mode");
    }

    /**
     * Prunes the cell {@code (cat, start, end)} via the inherited {@code pruneCell(String, List)}.
     * Each cell is pruned at most once per set of already-pruned cell keys.
     *
     * @param cellsPruned keys of cells already pruned while building the current span; updated in place
     */
    protected void pruneCell(Set<String> cellsPruned, String cat, int start, int end, List<Derivation> derivations) {
        String cell = BeamParserState.cellString(cat, start, end);
        if (!cellsPruned.add(cell)) {
            return;
        }
        this.pruneCell(cell, derivations);
    }

    /**
     * Applies every cat-unary rule (a single category on the right-hand side) over {@code [start, end)}.
     *
     * <p>The right-hand-side cell is pruned before the rule consumes it, so the rule only sees the
     * pruned derivations of that cell. Rules are visited in the order of {@code parser.catUnaryRules}.
     */
    private void applyCatUnaryRules(int start, int end, Set<String> cellsPruned) {
        for (Rule rule : this.parser.catUnaryRules) {
            if (!this.coarseAllows(rule.lhs, start, end)) {
                continue;
            }
            String rhsCat = rule.rhs.get(0);
            List<Derivation> derivations = this.chart[start][end].get(rhsCat);
            if (Parser.opts.verbose >= 5) {
                LogInfo.logs("applyCatUnaryRules %s %s %s %s", start, end, rule, derivations);
            }
            if (derivations == null) {
                continue;
            }
            this.pruneCell(cellsPruned, rhsCat, start, end, derivations);
            for (Derivation deriv : derivations) {
                this.applyRule(start, end, rule, Collections.singletonList(deriv));
            }
        }
    }

    /**
     * Recursively matches trie-indexed rules against {@code [start, end)}, starting at position {@code i}.
     *
     * <p>At each step it either consumes the token at {@code i} (following {@code node.next(token)}),
     * or consumes a chart derivation covering some {@code [i, j)} (following {@code node.next(cat)}).
     * Once {@code i == end}, every rule stored at the current trie node is applied to the collected
     * {@code children}.
     *
     * @param children derivations consumed so far; restored to its entry state before returning
     * @param numNew   running total of new derivations for this span. In full mode, the search
     *                 stops once it reaches {@code BeamParser.opts.maxNewTreesPerSpan}.
     */
    private void applyNonCatUnaryRules(int start, int end, int i, Trie node, ArrayList<Derivation> children, IntRef numNew) {
        if (node == null || !this.coarseAllows(node, start, end)) {
            return;
        }
        if (Parser.opts.verbose >= 5) {
            LogInfo.logs("applyNonCatUnaryRules(start=%d, end=%d, i=%d, children=[%s], %s rules)", start, end, i, Joiner.on(", ").join(children), node.rules.size());
        }
        if (i == end) {
            for (Rule rule : node.rules) {
                if (this.coarseAllows(rule.lhs, start, end)) {
                    numNew.value += this.applyRule(start, end, rule, children);
                }
            }
            return;
        }
        // Consume the literal token at position i.
        this.applyNonCatUnaryRules(start, end, i + 1, node.next(this.ex.token(i)), children, numNew);
        // Consume a category whose derivations cover [i, j).
        for (int j = i + 1; j <= end; ++j) {
            for (Map.Entry<String, List<Derivation>> entry : this.chart[i][j].entrySet()) {
                Trie nextNode = node.next(entry.getKey());
                for (Derivation arg : entry.getValue()) {
                    children.add(arg);
                    this.applyNonCatUnaryRules(start, end, j, nextNode, children, numNew);
                    children.remove(children.size() - 1);
                    // Outside full mode one derivation per category is enough, so move to the next category.
                    if (this.mode != Mode.full) {
                        break;
                    }
                    if (numNew.value >= BeamParser.opts.maxNewTreesPerSpan) {
                        return;
                    }
                }
            }
        }
    }

    /**
     * Removes from the chart every category cell that is not reachable from {@code $ROOT} over the
     * whole utterance by following derivation children. An empty utterance is a no-op.
     */
    public void keepTopDownReachable() {
        if (this.numTokens == 0) {
            return;
        }
        Set<String> reachable = new HashSet<>();
        this.collectReachable(reachable, "$ROOT", 0, this.numTokens);
        for (int start = 0; start < this.numTokens; ++start) {
            for (int end = start + 1; end <= this.numTokens; ++end) {
                // Collect first, then remove, so the key set is not mutated while being iterated.
                List<String> toRemoveCats = new ArrayList<>();
                for (String cat : this.chart[start][end].keySet()) {
                    if (!reachable.contains(BeamParserState.cellString(cat, start, end))) {
                        toRemoveCats.add(cat);
                    }
                }
                // Sorted so that the verbose log output is deterministic.
                Collections.sort(toRemoveCats);
                for (String cat : toRemoveCats) {
                    if (this.parser.verbose(4)) {
                        LogInfo.logs("Pruning chart %s(%s,%s)", cat, start, end);
                    }
                    this.chart[start][end].remove(cat);
                }
            }
        }
    }

    /**
     * Adds the key of {@code (cat, start, end)} to {@code reachable} if that cell exists in the
     * chart. It then recurses into the cells of the children of every derivation in that cell.
     * Cells that are already visited, or absent from the chart, are skipped.
     */
    private void collectReachable(Set<String> reachable, String cat, int start, int end) {
        String key = BeamParserState.cellString(cat, start, end);
        if (reachable.contains(key) || !this.chart[start][end].containsKey(cat)) {
            return;
        }
        reachable.add(key);
        for (Derivation deriv : this.chart[start][end].get(cat)) {
            for (Derivation subderiv : deriv.children) {
                this.collectReachable(reachable, subderiv.cat, subderiv.start, subderiv.end);
            }
        }
    }

    /**
     * Returns true if there is no coarse state, or if the coarse chart cell for {@code [start, end)}
     * contains at least one of {@code node.cats}.
     */
    protected boolean coarseAllows(Trie node, int start, int end) {
        if (this.coarseState == null) {
            return true;
        }
        return SetUtils.intersects(node.cats, this.coarseState.chart[start][end].keySet());
    }

    /**
     * Returns true if there is no coarse state, or if the coarse chart cell for {@code [start, end)}
     * contains {@code cat}.
     */
    protected boolean coarseAllows(String cat, int start, int end) {
        if (this.coarseState == null) {
            return true;
        }
        return this.coarseState.chart[start][end].containsKey(cat);
    }

    /** How {@link #applyRule} handles a matched rule. */
    public static enum Mode {
        /** Add one structural derivation with {@link Formula#nullFormula}; no semantic function is called. */
        bool,
        /** Call the rule's semantic function, then featurize, score and add every derivation it yields. */
        full;

    }
}

