/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sempre;

import com.google.common.base.Joiner;
import edu.stanford.nlp.sempre.Example;
import edu.stanford.nlp.sempre.Grammar;
import edu.stanford.nlp.sempre.Parser;
import edu.stanford.nlp.sempre.Rule;
import fig.basic.LogInfo;
import fig.basic.MapUtils;
import fig.basic.Pair;
import fig.basic.StopWatch;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Coarse CKY-style chart parser that works over grammar categories only.
 *
 * <p>For every token span it records which categories can be built over that span and
 * from which child spans, without building full derivations. Callers query the result
 * through {@link CoarseParserState#coarseAllows}. NOTE(review): presumably a finer parser
 * uses this to prune its search; confirm against callers.
 *
 * <p>The grammar must be binarized, i.e. every rule has at most two RHS items.
 */
public class CoarseParser {
    /** The grammar whose rules were indexed by the constructor. */
    public final Grammar grammar;
    /** Binary rules: (left RHS item, right RHS item) -> set of LHS categories. */
    private final Map<Pair<String, String>, Set<String>> rhsToLhsMap;
    /**
     * Category-unary rules, ordered so that every rule whose LHS is some category C
     * comes before every rule whose RHS is C. A single in-order pass over this list
     * therefore reaches the full unary closure of a chart cell.
     */
    ArrayList<Rule> catUnaryRules;
    /**
     * All remaining rules (expected to have an all-terminal RHS), keyed by their RHS
     * items joined with single spaces.
     */
    Map<String, List<Rule>> terminalsToRulesList = new HashMap<String, List<Rule>>();

    /**
     * Indexes the rules of {@code grammar} for coarse parsing.
     *
     * @param grammar a binarized grammar
     * @throws RuntimeException if a rule has more than two RHS items, or if the
     *     category-unary rules form a cycle
     */
    public CoarseParser(Grammar grammar) {
        this.grammar = grammar;
        this.catUnaryRules = new ArrayList<Rule>();
        this.rhsToLhsMap = new HashMap<Pair<String, String>, Set<String>>();

        // Edges LHS -> rules, restricted to category-unary rules; used only to order them.
        Map<String, List<Rule>> unaryGraph = new HashMap<String, List<Rule>>();
        for (Rule rule : grammar.rules) {
            if (rule.rhs.size() > 2) {
                throw new RuntimeException("We assume that the grammar is binarized, rule: " + rule);
            }
            if (rule.isCatUnary()) {
                MapUtils.addToList(unaryGraph, rule.lhs, rule);
            } else if (rule.rhs.size() == 2) {
                MapUtils.addToSet(this.rhsToLhsMap, Pair.newPair(rule.rhs.get(0), rule.rhs.get(1)), rule.lhs);
            } else {
                assert (rule.isRhsTerminals());
                MapUtils.addToList(this.terminalsToRulesList, Joiner.on(' ').join(rule.rhs), rule);
            }
        }

        // Depth-first post-order: a rule X -> Y is appended only after every rule with LHS Y.
        Map<String, Boolean> visitState = new HashMap<String, Boolean>();
        for (String lhs : unaryGraph.keySet()) {
            this.traverse(this.catUnaryRules, lhs, unaryGraph, visitState);
        }
        LogInfo.logs("Coarse parser: %d catUnaryRules (sorted), %d nonCatUnaryRules",
                this.catUnaryRules.size(), grammar.rules.size() - this.catUnaryRules.size());
    }

    /**
     * Appends the category-unary rules reachable from {@code node} to
     * {@code catUnaryRules}, children first.
     *
     * @param catUnaryRules output list, appended to
     * @param node category whose outgoing unary rules are visited
     * @param graph unary rules keyed by LHS category
     * @param done visit state per category: absent = unvisited, {@code false} = on the
     *     current DFS path, {@code true} = finished
     * @throws RuntimeException if a cycle of unary rules is detected
     */
    private void traverse(List<Rule> catUnaryRules, String node, Map<String, List<Rule>> graph, Map<String, Boolean> done) {
        Boolean state = done.get(node);
        if (Boolean.TRUE.equals(state)) {
            return;
        }
        if (Boolean.FALSE.equals(state)) {
            throw new RuntimeException("Found cycle of unaries involving " + node);
        }
        done.put(node, false);
        // Categories with no outgoing unary rule are not keys of the graph.
        // Reading with get() (no insertion) keeps the caller's keySet iteration safe.
        List<Rule> outgoing = graph.get(node);
        if (outgoing != null) {
            for (Rule rule : outgoing) {
                this.traverse(catUnaryRules, rule.rhs.get(0), graph, done);
                catUnaryRules.add(rule);
            }
        }
        done.put(node, true);
    }

    /**
     * Runs the coarse parser on {@code ex}.
     *
     * @param ex the example to parse
     * @return the filled and pruned chart
     */
    public CoarseParserState getCoarsePrunedChart(Example ex) {
        CoarseParserState res = new CoarseParserState(ex, this);
        res.infer();
        return res;
    }

    /**
     * Immutable (category, span) pair. The span is half-open, covering tokens
     * {@code start .. end-1}. For terminal children, {@code cat} holds the token text.
     */
    static class CategorySpan {
        public final String cat;
        public final int start;
        public final int end;

        public CategorySpan(String cat, int start, int end) {
            this.cat = cat;
            this.start = start;
            this.end = end;
        }

        @Override
        public int hashCode() {
            int result = 1;
            result = 31 * result + (this.cat == null ? 0 : this.cat.hashCode());
            result = 31 * result + this.end;
            result = 31 * result + this.start;
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || this.getClass() != obj.getClass()) {
                return false;
            }
            CategorySpan other = (CategorySpan) obj;
            if (this.cat == null ? other.cat != null : !this.cat.equals(other.cat)) {
                return false;
            }
            return this.end == other.end && this.start == other.start;
        }
    }

    /**
     * Chart produced by running the coarse parser on one example.
     *
     * <p>{@code chart[start][end]} maps every category derivable over tokens
     * {@code start .. end-1} to the child spans it was built from. Children from all
     * derivations are appended to the same list.
     */
    class CoarseParserState {
        private final Map<String, List<CategorySpan>>[][] chart;
        public final Example example;
        public final CoarseParser parser;
        private final int numTokens;
        /** Elapsed time of {@link #infer()}, as reported by {@code StopWatch.getCurrTimeLong()}. */
        private long time;
        /** {@code phrases[start][end]} is the tokens of [start, end) joined by single spaces. */
        private final String[][] phrases;

        public CoarseParserState(Example example, CoarseParser parser) {
            this.example = example;
            this.parser = parser;
            this.numTokens = example.numTokens();
            // Arrays of a parameterized type cannot be created directly.
            // Every valid cell is filled with a HashMap below.
            @SuppressWarnings("unchecked")
            Map<String, List<CategorySpan>>[][] cells = new HashMap[this.numTokens][this.numTokens + 1];
            this.chart = cells;
            this.phrases = new String[this.numTokens][this.numTokens + 1];
            for (int start = 0; start < this.numTokens; ++start) {
                StringBuilder sb = new StringBuilder();
                for (int end = start + 1; end <= this.numTokens; ++end) {
                    if (end - start > 1) {
                        sb.append(' ');
                    }
                    sb.append(example.languageInfo.tokens.get(end - 1));
                    this.phrases[start][end] = sb.toString();
                    this.chart[start][end] = new HashMap<String, List<CategorySpan>>();
                }
            }
        }

        public long getCoarseParseTime() {
            return this.time;
        }

        /**
         * Fills the chart bottom-up over spans of increasing length.
         * It then removes entries not reachable from {@code $ROOT} over the whole input,
         * and records the elapsed time.
         */
        public void infer() {
            StopWatch watch = new StopWatch();
            watch.start();
            this.parseTokensAndPhrases();
            for (int len = 1; len <= this.numTokens; ++len) {
                for (int start = 0; start + len <= this.numTokens; ++start) {
                    this.build(start, start + len);
                }
            }
            this.keepTopDownReachable();
            watch.stop();
            this.time = watch.getCurrTimeLong();
        }

        /**
         * Returns whether {@code cat} is present in the chart cell for tokens [start, end).
         * Requires {@code 0 <= start < end <= numTokens}; other spans have no cell.
         */
        public boolean coarseAllows(String cat, int start, int end) {
            return this.chart[start][end].containsKey(cat);
        }

        /** Applies binary rules, then terminal and category-unary rules, to one span. */
        private void build(int start, int end) {
            this.handleBinaryRules(start, end);
            this.handleUnaryRules(start, end);
        }

        /** Seeds the built-in token and phrase categories on every applicable span. */
        private void parseTokensAndPhrases() {
            for (int i = 0; i < this.numTokens; ++i) {
                this.addToChart("$TOKEN", i, i + 1);
                this.addToChart("$LEMMA_TOKEN", i, i + 1);
            }
            for (int i = 0; i < this.numTokens; ++i) {
                for (int j = i + 1; j <= this.numTokens; ++j) {
                    this.addToChart("$PHRASE", i, j);
                    this.addToChart("$LEMMA_PHRASE", i, j);
                }
            }
        }

        /** Returns the child list of {@code cat} in cell [start, end), creating an empty one if absent. */
        private List<CategorySpan> childrenOf(String cat, int start, int end) {
            Map<String, List<CategorySpan>> cell = this.chart[start][end];
            List<CategorySpan> children = cell.get(cat);
            if (children == null) {
                children = new ArrayList<CategorySpan>();
                cell.put(cat, children);
            }
            return children;
        }

        /** Marks {@code cat} as present over [start, end) without recording children. */
        private void addToChart(String cat, int start, int end) {
            if (Parser.opts.verbose >= 5) {
                LogInfo.logs("Adding to chart %s(%s,%s)", cat, start, end);
            }
            this.childrenOf(cat, start, end);
        }

        /** Records a unary derivation parentCat(start,end) -> childCat(start,end). */
        private void addToChart(String parentCat, String childCat, int start, int end) {
            if (Parser.opts.verbose >= 5) {
                LogInfo.logs("Adding to chart %s(%s,%s)-->%s(%s,%s)", parentCat, start, end, childCat, start, end);
            }
            this.childrenOf(parentCat, start, end).add(new CategorySpan(childCat, start, end));
        }

        /** Records a binary derivation parentCat(start,end) -> leftCat(start,i) rightCat(i,end). */
        private void addToChart(String parentCat, String leftCat, String rightCat, int start, int i, int end) {
            if (Parser.opts.verbose >= 5) {
                LogInfo.logs("Adding to chart %s(%s,%s)-->%s(%s,%s) %s(%s,%s)", parentCat, start, end, leftCat, start, i, rightCat, i, end);
            }
            List<CategorySpan> children = this.childrenOf(parentCat, start, end);
            children.add(new CategorySpan(leftCat, start, i));
            children.add(new CategorySpan(rightCat, i, end));
        }

        /** Combines every pair of adjacent sub-spans of [start, end) using the binary rules. */
        private void handleBinaryRules(int start, int end) {
            for (int split = start + 1; split < end; ++split) {
                List<String> leftSymbols = new ArrayList<String>(this.chart[start][split].keySet());
                List<String> rightSymbols = new ArrayList<String>(this.chart[split][end].keySet());
                // Binarized rules may name a single-token terminal directly as one side.
                if (split - start == 1) {
                    leftSymbols.add(this.phrases[start][split]);
                }
                if (end - split == 1) {
                    rightSymbols.add(this.phrases[split][end]);
                }
                for (String left : leftSymbols) {
                    for (String right : rightSymbols) {
                        Set<String> parentCats = this.parser.rhsToLhsMap.get(Pair.newPair(left, right));
                        if (parentCats == null) {
                            continue;
                        }
                        for (String parentCat : parentCats) {
                            this.addToChart(parentCat, left, right, start, split, end);
                        }
                    }
                }
            }
        }

        /**
         * Adds categories produced by terminal rules matching the span's text.
         * It then applies the category-unary rules in their dependency order.
         */
        private void handleUnaryRules(int start, int end) {
            List<Rule> terminalRules = this.parser.terminalsToRulesList.get(this.phrases[start][end]);
            if (terminalRules != null) {
                for (Rule rule : terminalRules) {
                    this.addToChart(rule.lhs, start, end);
                }
            }
            Map<String, List<CategorySpan>> cell = this.chart[start][end];
            for (Rule rule : this.parser.catUnaryRules) {
                String rhsCat = rule.rhs.get(0);
                if (cell.containsKey(rhsCat)) {
                    this.addToChart(rule.lhs, rhsCat, start, end);
                }
            }
        }

        /**
         * Removes every chart entry that cannot be reached top-down from {@code $ROOT}
         * spanning the whole input. If {@code $ROOT} is absent there, every entry is
         * removed.
         */
        public void keepTopDownReachable() {
            if (this.numTokens == 0) {
                return;
            }
            Set<CategorySpan> reachable = new HashSet<CategorySpan>();
            this.collectReachable(reachable, new CategorySpan("$ROOT", 0, this.numTokens));
            for (int start = 0; start < this.numTokens; ++start) {
                for (int end = start + 1; end <= this.numTokens; ++end) {
                    List<String> toRemoveCats = new ArrayList<String>();
                    for (String cat : this.chart[start][end].keySet()) {
                        if (!reachable.contains(new CategorySpan(cat, start, end))) {
                            toRemoveCats.add(cat);
                        }
                    }
                    // Sorted so verbose logging is deterministic.
                    Collections.sort(toRemoveCats);
                    for (String cat : toRemoveCats) {
                        if (Parser.opts.verbose >= 5) {
                            LogInfo.logs("Pruning chart %s(%s,%s)", cat, start, end);
                        }
                        this.chart[start][end].remove(cat);
                    }
                }
            }
        }

        /** Adds {@code catSpan} and, recursively, its recorded children to {@code reachable}. */
        private void collectReachable(Set<CategorySpan> reachable, CategorySpan catSpan) {
            if (reachable.contains(catSpan)) {
                return;
            }
            List<CategorySpan> children = this.chart[catSpan.start][catSpan.end].get(catSpan.cat);
            // Terminal children (token text) have no chart entry and stop here.
            if (children == null) {
                return;
            }
            reachable.add(catSpan);
            for (CategorySpan childCatSpan : children) {
                this.collectReachable(reachable, childCatSpan);
            }
        }
    }
}

