/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sempre;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.type.TypeReference;
import com.google.common.collect.Lists;
import edu.stanford.nlp.sempre.ContextValue;
import edu.stanford.nlp.sempre.Example;
import edu.stanford.nlp.sempre.Json;
import edu.stanford.nlp.sempre.KnowledgeGraph;
import edu.stanford.nlp.sempre.NaiveKnowledgeGraph;
import fig.basic.IOUtils;
import fig.basic.LispTree;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.basic.Pair;
import fig.basic.StatFig;
import fig.exec.Execution;
import fig.prob.SampleUtils;
import java.io.File;
import java.io.Reader;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

/**
 * A collection of {@link Example}s organized into named groups (e.g. {@code "train"},
 * {@code "dev"}), kept in the order in which the groups were first added.
 *
 * <p>Examples are read from {@code <group, path>} pairs (by default {@link Options#inPaths}).
 * If any path ends in {@code .json}, every path is read as a JSON list of examples; otherwise
 * every path is read as a LispTree file of {@code (example ...)} entries.
 */
public class Dataset {
    public static Options opts = new Options();

    /** Examples per group; a LinkedHashMap so group order follows insertion order. */
    private LinkedHashMap<String, List<Example>> allExamples = new LinkedHashMap<String, List<Example>>();
    /** Distinct tokens over all examples accepted so far. */
    private final HashSet<String> tokenTypes = new HashSet<String>();
    /** Statistics over the token count of each accepted example. */
    private final StatFig numTokensFig = new StatFig();

    /** Returns the group names, in insertion order. */
    public Set<String> groups() {
        return this.allExamples.keySet();
    }

    /** Returns the examples of {@code group}, or {@code null} if there is no such group. */
    public List<Example> examples(String group) {
        return this.allExamples.get(group);
    }

    /** Returns one {@link GroupInfo} per group; serialized to JSON under {@code "groups"}. */
    @JsonProperty(value="groups")
    public List<GroupInfo> getAllGroupInfos() {
        List<GroupInfo> all = Lists.newArrayList();
        for (Map.Entry<String, List<Example>> entry : this.allExamples.entrySet()) {
            all.add(new GroupInfo(entry.getKey(), entry.getValue()));
        }
        return all;
    }

    /** JSON factory: builds a dataset from deserialized groups (the inverse of {@link #getAllGroupInfos()}). */
    @JsonCreator
    public static Dataset fromGroupInfos(@JsonProperty(value="groups") List<GroupInfo> groups) {
        Dataset d = new Dataset();
        d.readFromGroupInfos(groups);
        return d;
    }

    /** Reads the files listed in {@link Options#inPaths}. */
    public void read() {
        this.readFromPathPairs(Dataset.opts.inPaths);
    }

    /**
     * Reads examples from {@code <group, path>} pairs.
     *
     * <p>If any path ends in {@code .json}, all paths are read as JSON. Otherwise all paths are
     * read as LispTree files, after which the global context (if
     * {@link Options#globalGraphPath} is set) is attached to every example.
     *
     * @param pathPairs pairs of (group name, file path)
     */
    public void readFromPathPairs(List<Pair<String, String>> pathPairs) {
        if (containsJsonPath(pathPairs)) {
            // NOTE(review): this branch never calls updateGlobalContext(), so globalGraphPath is
            // ignored for JSON input. Kept as-is; confirm this is intended.
            this.readJsonFromPathPairs(pathPairs);
            return;
        }
        this.readLispTreeFromPathPairs(pathPairs);
        this.updateGlobalContext();
    }

    /** Returns true if any path in {@code pathPairs} ends in {@code .json}. */
    private static boolean containsJsonPath(List<Pair<String, String>> pathPairs) {
        for (Pair<String, String> pathPair : pathPairs) {
            if (pathPair.getSecond().endsWith(".json")) {
                return true;
            }
        }
        return false;
    }

    /**
     * If {@link Options#globalGraphPath} is set, loads that graph once and sets a new
     * {@link ContextValue} wrapping it as the context of every example in every group.
     */
    private void updateGlobalContext() {
        if (Dataset.opts.globalGraphPath == null) {
            return;
        }
        KnowledgeGraph graph = NaiveKnowledgeGraph.fromFile(Dataset.opts.globalGraphPath);
        for (List<Example> examples : this.allExamples.values()) {
            for (Example ex : examples) {
                ex.setContext(new ContextValue(graph));
            }
        }
    }

    /** Reads each path as a JSON list of examples, then ingests them via {@link #readFromGroupInfos}. */
    private void readJsonFromPathPairs(List<Pair<String, String>> pathPairs) {
        List<GroupInfo> groups = Lists.newArrayListWithCapacity(pathPairs.size());
        for (Pair<String, String> pathPair : pathPairs) {
            String group = pathPair.getFirst();
            String path = pathPair.getSecond();
            List<Example> examples = Json.readValueHard((Reader)IOUtils.openInHard(path), new TypeReference<List<Example>>(){});
            GroupInfo gi = new GroupInfo(group, examples);
            // Remembered so that examples without an id get a path-based one in readHelper.
            gi.path = path;
            groups.add(gi);
        }
        this.readFromGroupInfos(groups);
    }

    /**
     * Appends the examples of each group (subject to the per-group limit from
     * {@link Options#maxExamples}), then runs the shared post-processing.
     */
    private void readFromGroupInfos(List<GroupInfo> groupInfos) {
        LogInfo.begin_track_printAll("Dataset.read");
        for (GroupInfo groupInfo : groupInfos) {
            int maxExamples = Dataset.getMaxExamplesForGroup(groupInfo.group);
            List<Example> examples = this.getOrCreateGroup(groupInfo.group);
            this.readHelper(groupInfo.examples, maxExamples, examples, groupInfo.path);
        }
        this.finishReading();
        LogInfo.end_track();
    }

    /** Returns the example list of {@code group}, creating and registering an empty one if absent. */
    private List<Example> getOrCreateGroup(String group) {
        List<Example> examples = this.allExamples.get(group);
        if (examples == null) {
            examples = new ArrayList<Example>();
            this.allExamples.put(group, examples);
        }
        return examples;
    }

    /** Post-processing shared by the JSON and LispTree readers: optional train/dev split, then stats. */
    private void finishReading() {
        if (Dataset.opts.splitDevFromTrain) {
            this.splitDevFromTrain();
        }
        this.collectStats();
    }

    /**
     * Randomly re-splits the {@code "train"} group (no-op if it is absent).
     *
     * <p>The original train examples are permuted with {@link Options#splitRandom}. The first
     * {@code trainFrac} fraction of the permutation becomes the new {@code "train"} group. The
     * last {@code devFrac} fraction is appended to {@code "dev"}, which is created right after
     * {@code "train"} if it does not exist. The two slices overlap when trainFrac + devFrac &gt; 1.
     * With the default fractions (1.0 / 0.0), train is only reordered and dev receives nothing.
     * NOTE(review): assumes both fractions lie in [0, 1]; other values cause index errors.
     */
    private void splitDevFromTrain() {
        List<Example> origTrainExamples = this.allExamples.get("train");
        if (origTrainExamples == null) {
            return;
        }
        int total = origTrainExamples.size();
        int split1 = (int) (Dataset.opts.trainFrac * (double) total);
        int split2 = (int) ((1.0 - Dataset.opts.devFrac) * (double) total);
        int[] perm = SampleUtils.samplePermutation(Dataset.opts.splitRandom, total);

        List<Example> trainExamples = new ArrayList<Example>();
        // Re-putting an existing key keeps its position in the LinkedHashMap.
        this.allExamples.put("train", trainExamples);
        List<Example> devExamples = this.allExamples.get("dev");
        if (devExamples == null) {
            // Rebuild the map so the new "dev" group is placed immediately after "train".
            devExamples = new ArrayList<Example>();
            LinkedHashMap<String, List<Example>> reordered = new LinkedHashMap<String, List<Example>>();
            for (Map.Entry<String, List<Example>> entry : this.allExamples.entrySet()) {
                reordered.put(entry.getKey(), entry.getValue());
                if (entry.getKey().equals("train")) {
                    reordered.put("dev", devExamples);
                }
            }
            this.allExamples = reordered;
        }
        for (int i = 0; i < split1; ++i) {
            trainExamples.add(origTrainExamples.get(perm[i]));
        }
        for (int i = split2; i < total; ++i) {
            devExamples.add(origTrainExamples.get(perm[i]));
        }
    }

    /**
     * Appends examples from {@code incoming} to {@code examples} until {@code examples} holds
     * {@code maxExamples} entries.
     *
     * <p>An example without an id is copied with id {@code "<path>:<i>"}, or
     * {@code "<nopath>:<i>"} when {@code path} is null. Here {@code i} is its 0-based position
     * among the incoming examples processed so far.
     */
    private void readHelper(List<Example> incoming, int maxExamples, List<Example> examples, String path) {
        if (examples.size() >= maxExamples) {
            return;
        }
        int i = 0;
        for (Example ex : incoming) {
            if (examples.size() >= maxExamples) {
                break;
            }
            if (ex.id == null) {
                String id = (path != null ? path : "<nopath>") + ":" + i;
                ex = new Example.Builder().withExample(ex).setId(id).createExample();
            }
            ++i;
            this.addIfWithinTokenLimit(ex, examples);
        }
    }

    /**
     * Preprocesses {@code ex}. Unless it has more than {@link Options#maxTokens} tokens, logs it,
     * appends it to {@code examples} and records its tokens in the dataset statistics.
     */
    private void addIfWithinTokenLimit(Example ex, List<Example> examples) {
        ex.preprocess();
        if (ex.numTokens() > Dataset.opts.maxTokens) {
            return;
        }
        LogInfo.logs("Example %s (%d): %s => %s", ex.id, examples.size(), ex.getTokens(), ex.targetValue);
        examples.add(ex);
        this.numTokensFig.add((double) ex.numTokens());
        for (String token : ex.getTokens()) {
            this.tokenTypes.add(token);
        }
    }

    /** Reads each path as a LispTree file into its group, then runs the shared post-processing. */
    private void readLispTreeFromPathPairs(List<Pair<String, String>> pathPairs) {
        LogInfo.begin_track_printAll("Dataset.read");
        for (Pair<String, String> pathPair : pathPairs) {
            String group = pathPair.getFirst();
            String path = pathPair.getSecond();
            int maxExamples = Dataset.getMaxExamplesForGroup(group);
            this.readLispTreeHelper(path, maxExamples, this.getOrCreateGroup(group));
        }
        // Same post-processing as the JSON path, so statistics are recorded for LispTree input too.
        this.finishReading();
        LogInfo.end_track();
    }

    /**
     * Parses {@code path} as a sequence of LispTrees and appends each {@code (example ...)} tree
     * to {@code examples}, stopping once {@code examples} holds {@code maxExamples} entries.
     *
     * <p>Trees headed by {@code metadata} are skipped. Each example gets the fallback id
     * {@code "<path>:<n>"}, where {@code n} counts example trees read so far.
     *
     * @throws RuntimeException on any other top-level tree
     */
    private void readLispTreeHelper(String path, int maxExamples, List<Example> examples) {
        if (examples.size() >= maxExamples) {
            return;
        }
        LogInfo.begin_track("Reading %s", path);
        Iterator<LispTree> trees = LispTree.proto.parseFromFile(path);
        int n = 0;
        while (examples.size() < maxExamples && trees.hasNext()) {
            LispTree tree = trees.next();
            // A bare atom or an empty list has no head; report it as invalid instead of
            // failing inside children.size() / child(0) below.
            if (tree.children == null || tree.children.isEmpty()) {
                throw new RuntimeException("Invalid example: " + tree);
            }
            String head = tree.child(0).value;
            if (tree.children.size() < 2 || !"example".equals(head)) {
                if ("metadata".equals(head)) {
                    continue;
                }
                throw new RuntimeException("Invalid example: " + tree);
            }
            Example ex = Example.fromLispTree(tree, path + ":" + n);
            ++n;
            this.addIfWithinTokenLimit(ex, examples);
        }
        LogInfo.end_track();
    }

    /** Records token-type count, tokens-per-example stats and per-group example counts via Execution.putLogRec. */
    private void collectStats() {
        LogInfo.begin_track_printAll("Dataset stats");
        Execution.putLogRec("numTokenTypes", this.tokenTypes.size());
        Execution.putLogRec("numTokensPerExample", this.numTokensFig);
        for (Map.Entry<String, List<Example>> e : this.allExamples.entrySet()) {
            Execution.putLogRec("numExamples." + e.getKey(), e.getValue().size());
        }
        LogInfo.end_track();
    }

    /**
     * Returns the example limit for {@code group} from {@link Options#maxExamples}. If several
     * pairs name the group, the last one wins. Returns {@link Integer#MAX_VALUE} if none does.
     */
    public static int getMaxExamplesForGroup(String group) {
        int maxExamples = Integer.MAX_VALUE;
        for (Pair<String, Integer> maxPair : Dataset.opts.maxExamples) {
            if (maxPair.getFirst().equals(group)) {
                maxExamples = maxPair.getSecond();
            }
        }
        return maxExamples;
    }

    /**
     * Appends {@code ex} to the JSON example list stored at {@code path}, starting a new list if
     * the file does not exist. The whole file is re-read and rewritten (pretty-printed). This is
     * not atomic and not safe for concurrent writers.
     */
    public static void appendExampleToFile(String path, Example ex) {
        List<Example> examples;
        if (new File(path).exists()) {
            examples = Json.readValueHard((Reader)IOUtils.openInHard(path), new TypeReference<List<Example>>(){});
        } else {
            examples = new ArrayList<Example>();
        }
        examples.add(ex);
        Json.prettyWriteValueHard(new File(path), examples);
    }

    /** JSON-serializable (group name, examples) pair; {@code path} is not serialized. */
    static class GroupInfo {
        @JsonProperty
        final String group;
        @JsonProperty
        final List<Example> examples;
        /** Source file, if known; used to synthesize ids for examples lacking one. */
        String path;

        @JsonCreator
        public GroupInfo(@JsonProperty(value="group") String group, @JsonProperty(value="examples") List<Example> examples) {
            this.group = group;
            this.examples = examples;
        }
    }

    public static class Options {
        @Option(gloss="Paths to read input files (format: <group>:<file>)")
        public ArrayList<Pair<String, String>> inPaths = new ArrayList<Pair<String, String>>();
        @Option(gloss="Maximum number of examples to read")
        public ArrayList<Pair<String, Integer>> maxExamples = new ArrayList<Pair<String, Integer>>();
        @Option(gloss="Fraction of trainExamples (from the beginning) to keep for training")
        public double trainFrac = 1.0;
        @Option(gloss="Fraction of trainExamples (from the end) to keep for development")
        public double devFrac = 0.0;
        @Option(gloss="Used to randomly divide training examples")
        public Random splitRandom = new Random(1L);
        @Option(gloss="whether to split dev from train")
        public boolean splitDevFromTrain = true;
        @Option(gloss="Only keep examples which have at most this number of tokens")
        public int maxTokens = Integer.MAX_VALUE;
        @Option(gloss="Path to a knowledge graph that will be uploaded as global context")
        public String globalGraphPath;
    }
}

