/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sempre;

import com.google.common.base.Joiner;
import com.google.common.collect.Maps;
import edu.stanford.nlp.sempre.Builder;
import edu.stanford.nlp.sempre.Dataset;
import edu.stanford.nlp.sempre.Derivation;
import edu.stanford.nlp.sempre.Example;
import edu.stanford.nlp.sempre.ExampleUtils;
import edu.stanford.nlp.sempre.Formula;
import edu.stanford.nlp.sempre.LearnerParallelProcessor;
import edu.stanford.nlp.sempre.Params;
import edu.stanford.nlp.sempre.Parser;
import edu.stanford.nlp.sempre.ParserState;
import edu.stanford.nlp.sempre.Rule;
import edu.stanford.nlp.sempre.SemanticFn;
import edu.stanford.nlp.sempre.SempreUtils;
import fig.basic.Evaluation;
import fig.basic.IOUtils;
import fig.basic.LogInfo;
import fig.basic.MapUtils;
import fig.basic.Option;
import fig.basic.Pair;
import fig.basic.Parallelizer;
import fig.basic.StopWatchSet;
import fig.basic.Utils;
import fig.exec.Execution;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class Learner {
    public static Options opts = new Options();
    private Parser parser;
    private final Params params;
    private final Dataset dataset;
    private final PrintWriter eventsOut;
    private final List<SemanticFn> semFuncsToUpdate;

    public Learner(Parser parser, Params params, Dataset dataset) {
        this.parser = parser;
        this.params = params;
        this.dataset = dataset;
        this.eventsOut = IOUtils.openOutAppendEasy((String)Execution.getFile((String)"learner.events"));
        if (Learner.opts.initialization != null && this.params.isEmpty()) {
            this.params.init(Learner.opts.initialization);
        }
        this.semFuncsToUpdate = new ArrayList<SemanticFn>();
        for (Rule rule : parser.grammar.getRules()) {
            SemanticFn currSemFn = rule.getSem();
            boolean toAdd = true;
            for (SemanticFn semFuncToUpdate : this.semFuncsToUpdate) {
                if (!semFuncToUpdate.getClass().equals(currSemFn.getClass())) continue;
                toAdd = false;
                break;
            }
            if (!toAdd) continue;
            this.semFuncsToUpdate.add(currSemFn);
        }
    }

    public void learn() {
        this.learn(Learner.opts.maxTrainIters, Maps.newHashMap());
    }

    public void learn(int numIters, Map<String, List<Evaluation>> evaluations) {
        LogInfo.begin_track((String)"Learner.learn()", (Object[])new Object[0]);
        if (!this.params.isEmpty()) {
            this.sortOnFeedback();
        }
        for (int iter = 0; iter <= numIters; ++iter) {
            LogInfo.begin_track((String)"Iteration %s/%s", (Object[])new Object[]{iter, numIters});
            Execution.putOutput((String)"iter", (Object)iter);
            HashMap meanEvaluations = Maps.newHashMap();
            for (String group : this.dataset.groups()) {
                meanEvaluations.put(group, new Evaluation());
            }
            for (String group : this.dataset.groups()) {
                boolean updateWeights;
                boolean lastIter = iter == numIters;
                boolean bl = updateWeights = Learner.opts.updateWeights && group.equals("train") && !lastIter;
                if (Learner.opts.skipUnnecessaryGroups && (group.equals("train") && lastIter || !group.equals("train") && !lastIter)) continue;
                this.parser.onBeginDataGroup(iter, numIters, group);
                Evaluation eval = this.processExamples(iter, group, this.dataset.examples(group), updateWeights);
                MapUtils.addToList(evaluations, (Object)group, (Object)eval);
                ((Evaluation)meanEvaluations.get(group)).add(eval);
                StopWatchSet.logStats();
                this.writeParams(iter);
            }
            LogInfo.end_track();
        }
        LogInfo.end_track();
    }

    private void writeParams(int iter) {
        String path = Execution.getFile((String)("params." + iter));
        if (path != null) {
            this.params.write(path);
            Utils.systemHard((String)("ln -sf params." + iter + " " + Execution.getFile((String)"params")));
        }
    }

    public void onlineLearnExample(Example ex) {
        LogInfo.begin_track((String)"onlineLearnExample: %s derivations", (Object[])new Object[]{ex.predDerivations.size()});
        HashMap<String, Double> counts = new HashMap<String, Double>();
        for (Derivation deriv : ex.predDerivations) {
            deriv.compatibility = this.parser.valueEvaluator.getCompatibility(ex.targetValue, deriv.value);
        }
        ParserState.computeExpectedCounts(ex.predDerivations, counts);
        this.params.update(counts);
        LogInfo.end_track();
    }

    public void onlineLearnExampleByFormula(Example ex, List<Formula> formulas) {
        HashMap<String, Double> counts = new HashMap<String, Double>();
        for (Derivation deriv : ex.predDerivations) {
            deriv.compatibility = formulas.contains(deriv.formula) ? 1.0 : 0.0;
        }
        ParserState.computeExpectedCounts(ex.predDerivations, counts);
        this.params.update(counts);
    }

    private Evaluation processExamples(int iter, String group, List<Example> examples, boolean computeExpectedCounts) {
        Evaluation evaluation = new Evaluation();
        if (examples.size() == 0) {
            return evaluation;
        }
        String prefix = "iter=" + iter + "." + group;
        Execution.putOutput((String)"group", (Object)group);
        LogInfo.begin_track_printAll((String)"Processing %s: %s examples", (Object[])new Object[]{prefix, examples.size()});
        LogInfo.begin_track((String)"Examples", (Object[])new Object[0]);
        if (Learner.opts.numParallelThreads > 1) {
            Parallelizer paral = new Parallelizer(Learner.opts.numParallelThreads);
            LearnerParallelProcessor processor = new LearnerParallelProcessor(this.parser, this.params, prefix, computeExpectedCounts, evaluation);
            LogInfo.begin_threads();
            paral.process(examples, (Parallelizer.Processor)processor);
            LogInfo.end_threads();
        } else {
            HashMap<String, Double> counts = new HashMap<String, Double>();
            int batchSize = 0;
            for (int e = 0; e < examples.size(); ++e) {
                Example ex = examples.get(e);
                LogInfo.begin_track_printAll((String)"%s: example %s/%s: %s", (Object[])new Object[]{prefix, e, examples.size(), ex.id});
                ex.log();
                Execution.putOutput((String)"example", (Object)e);
                ParserState state = this.parseExample(this.params, ex, computeExpectedCounts);
                if (computeExpectedCounts) {
                    if (Learner.opts.checkGradient) {
                        LogInfo.begin_track((String)"Checking gradient", (Object[])new Object[0]);
                        this.checkGradient(ex, state);
                        LogInfo.end_track();
                    }
                    SempreUtils.addToDoubleMap(counts, state.expectedCounts);
                    if (++batchSize >= Learner.opts.batchSize) {
                        this.updateWeights(counts);
                        batchSize = 0;
                    }
                }
                LogInfo.logs((String)"Current: %s", (Object[])new Object[]{ex.evaluation.summary()});
                evaluation.add(ex.evaluation);
                LogInfo.logs((String)"Cumulative(%s): %s", (Object[])new Object[]{prefix, evaluation.summary()});
                this.printLearnerEventsIter(ex, iter, group);
                LogInfo.end_track();
                if (Learner.opts.addFeedback && computeExpectedCounts) {
                    this.addFeedback(ex);
                }
                if (Learner.opts.outputPredDerivations && Builder.opts.parser.equals("FloatingParser")) {
                    ExampleUtils.writeParaphraseSDF(iter, group, ex, Learner.opts.outputPredDerivations);
                }
                ex.predDerivations.clear();
            }
            if (computeExpectedCounts && batchSize > 0) {
                this.updateWeights(counts);
            }
        }
        this.params.finalizeWeights();
        if (Learner.opts.sortOnFeedback && computeExpectedCounts) {
            this.sortOnFeedback();
        }
        LogInfo.end_track();
        this.logEvaluationStats(evaluation, prefix);
        evaluation.putOutput(prefix.replace('.', '-'));
        this.printLearnerEventsSummary(evaluation, iter, group);
        ExampleUtils.writeEvaluationSDF(iter, group, evaluation, examples.size());
        LogInfo.end_track();
        return evaluation;
    }

    private void checkGradient(Example ex, ParserState state) {
        double eps = 0.01;
        for (String feature : state.expectedCounts.keySet()) {
            LogInfo.begin_track((String)"feature=%s", (Object[])new Object[]{feature});
            double computedGradient = state.expectedCounts.get(feature);
            Params perturbedParams = this.params.copyParams();
            perturbedParams.getWeights().put(feature, perturbedParams.getWeight(feature) + eps);
            ParserState perturbedState = this.parseExample(perturbedParams, ex, true);
            double checkedGradient = (perturbedState.objectiveValue - state.objectiveValue) / eps;
            LogInfo.logs((String)"Learner.checkGradient(): weight=%s, pertWeight=%s, obj=%s, pertObj=%s, feature=%s, computed=%s, checked=%s, diff=%s", (Object[])new Object[]{this.params.getWeight(feature), perturbedParams.getWeight(feature), state.objectiveValue, perturbedState.objectiveValue, feature, computedGradient, checkedGradient, Math.abs(checkedGradient - computedGradient)});
            LogInfo.end_track();
        }
    }

    private void sortOnFeedback() {
        for (SemanticFn semFn : this.semFuncsToUpdate) {
            semFn.sortOnFeedback(this.parser.getSearchParams(this.params));
        }
    }

    private void addFeedback(Example ex) {
        for (SemanticFn semFn : this.semFuncsToUpdate) {
            semFn.addFeedback(ex);
        }
    }

    private ParserState parseExample(Params params, Example ex, boolean computeExpectedCounts) {
        StopWatchSet.begin((String)"Parser.parse");
        ParserState res = this.parser.parse(params, ex, computeExpectedCounts);
        StopWatchSet.end();
        return res;
    }

    private void updateWeights(Map<String, Double> counts) {
        StopWatchSet.begin((String)"Learner.updateWeights");
        LogInfo.begin_track((String)"Updating learner weights", (Object[])new Object[0]);
        double sum = 0.0;
        for (double v : counts.values()) {
            sum += v * v;
        }
        if (Learner.opts.verbose >= 2) {
            SempreUtils.logMap(counts, "gradient");
        }
        LogInfo.logs((String)"L2 norm: %s", (Object[])new Object[]{Math.sqrt(sum)});
        this.params.update(counts);
        if (Learner.opts.verbose >= 2) {
            this.params.log();
        }
        counts.clear();
        LogInfo.end_track();
        StopWatchSet.end();
    }

    private void logEvaluationStats(Evaluation evaluation, String prefix) {
        LogInfo.logs((String)"Stats for %s: %s", (Object[])new Object[]{prefix, evaluation.summary()});
        evaluation.logStats(prefix);
        evaluation.putOutput(prefix);
    }

    private void printLearnerEventsIter(Example ex, int iter, String group) {
        if (this.eventsOut == null) {
            return;
        }
        ArrayList<String> fields = new ArrayList<String>();
        fields.add("iter=" + iter);
        fields.add("group=" + group);
        fields.add("utterance=" + ex.utterance);
        fields.add("targetValue=" + ex.targetValue);
        if (ex.predDerivations.size() > 0) {
            Derivation deriv = ex.predDerivations.get(0);
            fields.add("predValue=" + deriv.value);
            fields.add("predFormula=" + deriv.formula);
        }
        fields.add(ex.evaluation.summary("\t"));
        this.eventsOut.println(Joiner.on((char)'\t').join(fields));
        this.eventsOut.flush();
        if (Learner.opts.dumpFeaturesAndCompatibility) {
            for (Derivation deriv : ex.predDerivations) {
                fields = new ArrayList();
                fields.add("iter=" + iter);
                fields.add("group=" + group);
                fields.add("utterance=" + ex.utterance);
                HashMap<String, Double> features = new HashMap<String, Double>();
                deriv.incrementAllFeatureVector(1.0, features);
                for (String f : features.keySet()) {
                    double v = (Double)features.get(f);
                    fields.add(f + "=" + v);
                }
                fields.add("comp=" + deriv.compatibility);
                this.eventsOut.println(Joiner.on((char)'\t').join(fields));
            }
        }
    }

    private void printLearnerEventsSummary(Evaluation evaluation, int iter, String group) {
        if (this.eventsOut == null) {
            return;
        }
        ArrayList<String> fields = new ArrayList<String>();
        fields.add("iter=" + iter);
        fields.add("group=" + group);
        fields.add(evaluation.summary("\t"));
        this.eventsOut.println(Joiner.on((char)'\t').join(fields));
        this.eventsOut.flush();
    }

    public static class Options {
        @Option(gloss="Number of iterations to train")
        public int maxTrainIters = 0;
        @Option(gloss="When using mini-batch updates for SGD, this is the batch size")
        public int batchSize = 1;
        @Option(gloss="Write predDerivations to examples file (huge)")
        public boolean outputPredDerivations = false;
        @Option(gloss="Dump all features and compatibility scores")
        public boolean dumpFeaturesAndCompatibility = false;
        @Option(gloss="Whether to add feedback")
        public boolean addFeedback = false;
        @Option(gloss="Whether to sort on feedback")
        public boolean sortOnFeedback = true;
        @Option(gloss="Verbosity")
        public int verbose = 0;
        @Option(gloss="Initialize with these parameters")
        public List<Pair<String, Double>> initialization;
        @Option(gloss="Whether to update weights")
        public boolean updateWeights = true;
        @Option(gloss="Whether to check gradient")
        public boolean checkGradient = false;
        @Option(gloss="Whether to skip the 'train' group in the last iteration and non-'train' groups in other iterations")
        public boolean skipUnnecessaryGroups = false;
        @Option(gloss="Number of threads to parallelize")
        public int numParallelThreads = 1;
    }
}

