/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.algcomparison.algorithm.oracle.cpdag.RestrictedBoss;
import edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper;
import edu.cmu.tetrad.algcomparison.score.ScoreWrapper;
import edu.cmu.tetrad.data.BootstrapSampler;
import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.data.SimpleDataLoader;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphSaveLoadUtils;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Boss;
import edu.cmu.tetrad.search.Fges;
import edu.cmu.tetrad.search.Ida;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.Pc;
import edu.cmu.tetrad.search.PermutationSearch;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.StatUtils;
import edu.cmu.tetrad.util.TetradLogger;
import edu.cmu.tetrad.util.TetradSerializable;
import edu.cmu.tetrad.util.TextTable;
import edu.pitt.dbmi.data.reader.Delimiter;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;

/**
 * Ranks (possible cause, possible effect) pairs by how stably they appear among the largest
 * IDA minimum effects across resamples of the data (a CStaR-style procedure).
 * <p>
 * For each of {@code numSubsamples} resamples, a CPDAG is estimated with the configured
 * {@link CpdagAlgorithm}. IDA is then used to compute the minimum effect of each possible cause on
 * each possible effect. A pair's "pi" is the fraction of resamples in which its effect is at least
 * the (q * |effects|)-th largest effect of that resample. Intermediate CPDAGs and effect matrices
 * are written to disk and can be reused on later runs.
 */
public class Cstar {
    // Whether subsample tasks run on the common ForkJoinPool (true) or sequentially (false).
    private boolean parallelized = false;
    // Number of resamples drawn from the data.
    private int numSubsamples = 30;
    // q: the top-bracket size last passed to getRecords; reported in makeTable's header.
    private int topBracket = 5;
    // Records whose average minimum effect is not strictly greater than this are dropped.
    private double selectionAlpha = 0.0;
    private final IndependenceWrapper test;
    private final ScoreWrapper score;
    private final Parameters parameters;
    private CpdagAlgorithm cpdagAlgorithm = CpdagAlgorithm.PC_STABLE;
    private SampleStyle sampleStyle = SampleStyle.SUBSAMPLE;
    private boolean verbose;
    // Output directory created by the most recent getRecords call; null until then.
    private File newDir = null;

    /**
     * Constructs a new instance.
     *
     * @param test       independence-test wrapper, used when the CPDAG algorithm is PC-Stable.
     * @param score      score wrapper, used by FGES, BOSS and restricted BOSS.
     * @param parameters parameters handed to the test, score and search wrappers.
     */
    public Cstar(IndependenceWrapper test, ScoreWrapper score, Parameters parameters) {
        this.test = test;
        this.score = score;
        this.parameters = parameters;
    }

    /**
     * Aggregates several record lists into one.
     * <p>
     * For each directed (cause -> effect) edge, takes the median pi and the median effect over all
     * records for that edge.
     *
     * @param allRecords record lists, e.g. from repeated calls to {@code getRecords}.
     * @return one record per distinct edge, sorted by pi and then by effect, both descending.
     */
    public static LinkedList<Record> cStar(LinkedList<LinkedList<Record>> allRecords) {
        Map<Edge, List<Record>> recordsByEdge = new HashMap<>();
        for (List<Record> list : allRecords) {
            for (Record record : list) {
                Edge edge = Edges.directedEdge(record.getCauseNode(), record.getEffectNode());
                recordsByEdge.computeIfAbsent(edge, k -> new ArrayList<>()).add(record);
            }
        }

        LinkedList<Record> cstar = new LinkedList<>();
        for (Map.Entry<Edge, List<Record>> entry : recordsByEdge.entrySet()) {
            Edge edge = entry.getKey();
            List<Record> recordList = entry.getValue();
            double[] pis = new double[recordList.size()];
            double[] effects = new double[recordList.size()];
            for (int i = 0; i < recordList.size(); ++i) {
                pis[i] = recordList.get(i).getPi();
                effects[i] = recordList.get(i).getMinBeta();
            }
            Record first = recordList.get(0);
            cstar.add(new Record(edge.getNode1(), edge.getNode2(), StatUtils.median(pis),
                    StatUtils.median(effects), first.getNumCauses(), first.getNumEffects()));
        }
        cstar.sort((o1, o2) -> compareByPiThenMinBeta(o1.getPi(), o1.getMinBeta(), o2.getPi(), o2.getMinBeta()));
        return cstar;
    }

    /**
     * Returns q^2 / ((2 * pi - 1) * p^2), the value shown in the PCER column of makeTable.
     * NOTE(review): presumably the stability-selection per-comparison error-rate bound. It is only
     * meaningful for pi > 0.5; makeTable prints "*" otherwise.
     */
    private static double pcer(double pi, double q, double p) {
        return 1.0 / (2.0 * pi - 1.0) * (q * q / (p * p));
    }

    /**
     * Orders by pi descending. Ties on pi (compared with ==) are broken by effect, also descending.
     */
    private static int compareByPiThenMinBeta(double pi1, double beta1, double pi2, double beta2) {
        if (pi1 == pi2) {
            return Double.compare(beta2, beta1);
        }
        return Double.compare(pi2, pi1);
    }

    /**
     * Sets whether subsamples are processed in parallel on the common ForkJoinPool.
     *
     * @param parallelized true to run subsamples in parallel.
     */
    public void setParallelized(boolean parallelized) {
        this.parallelized = parallelized;
    }

    /**
     * Computes stability records for every (possible cause, possible effect) pair.
     * <p>
     * Results are written to a new directory {@code path.N}, where N is the smallest positive integer
     * for which that directory does not yet exist. If a directory named exactly {@code path} exists,
     * it is used as a cache: the following files found there are loaded instead of being recomputed:
     * <ul>
     *   <li>{@code data.txt};</li>
     *   <li>{@code possible.causes.txt} together with {@code possible.effects.txt};</li>
     *   <li>per-subsample {@code cpdag.K.txt} together with {@code effects.K.txt}.</li>
     * </ul>
     *
     * @param dataSet         the data; replaced by the cached {@code data.txt} if present.
     * @param possibleCauses  candidate cause variables.
     * @param possibleEffects candidate effect variables.
     * @param topBracket      q. In a subsample, a pair counts when its effect is at least the
     *                        (q * |effects|)-th largest effect. Must be in [1, |possibleCauses|].
     * @param path            base output path; "cstar-out" is used when null or empty.
     * @return a list holding a single record list. It contains the pairs with pi &gt; 0 and average
     *         minimum effect &gt; selectionAlpha, sorted by pi and then by effect, both descending.
     * @throws IllegalArgumentException if topBracket is out of range or cached data cannot be loaded.
     * @throws IllegalStateException    if the output directory cannot be created.
     */
    public LinkedList<LinkedList<Record>> getRecords(DataSet dataSet, List<Node> possibleCauses, List<Node> possibleEffects, int topBracket, String path) {
        validateTopBracket(topBracket, possibleCauses);
        this.topBracket = topBracket;

        if (path == null || path.isEmpty()) {
            path = "cstar-out";
            logMessage("Using path = 'cstar-out'.");
        }

        // New results always go to a fresh sibling directory path.N. A directory named exactly
        // `path` (if it exists) is read as a cache; otherwise the new, empty directory is used.
        File newDir = nextFreeDirectory(path);
        File origDir = new File(path).exists() ? new File(path) : newDir;
        if (!newDir.mkdirs()) {
            throw new IllegalStateException("Could not make a new directory; perhaps file permissions need to be adjusted.");
        }
        logMessage("Creating directories for " + newDir.getAbsolutePath());
        logMessage("Using files in directory " + origDir.getAbsolutePath());
        this.newDir = newDir;
        logMessage("Results directory = " + newDir.getAbsolutePath());

        // Load cached data first, so that the node lists below refer to the data actually used.
        File cachedData = new File(origDir, "data.txt");
        if (cachedData.exists()) {
            try {
                dataSet = SimpleDataLoader.loadContinuousData(cachedData, "//", '\"', "*", true, Delimiter.TAB, false);
            } catch (Exception e) {
                throw new IllegalArgumentException("Could not load data from " + cachedData.getAbsolutePath(), e);
            }
        }

        possibleEffects = GraphUtils.replaceNodes(possibleEffects, dataSet.getVariables());
        possibleCauses = GraphUtils.replaceNodes(possibleCauses, dataSet.getVariables());

        if (new File(origDir, "possible.causes.txt").exists() && new File(origDir, "possible.effects.txt").exists()) {
            logMessage("Loading data, possible causes, and possible effects from " + origDir.getAbsolutePath());
            possibleCauses = readVars(dataSet, origDir, "possible.causes.txt");
            possibleEffects = readVars(dataSet, origDir, "possible.effects.txt");
            // The cached cause list may differ in size from the one that was validated above.
            validateTopBracket(topBracket, possibleCauses);
        }

        writeVars(possibleCauses, newDir, "possible.causes.txt");
        writeVars(possibleEffects, newDir, "possible.effects.txt");
        writeData(dataSet, newDir);

        final List<Node> causes = possibleCauses;
        final List<Node> effects = possibleEffects;
        final DataSet data = dataSet;
        List<Callable<double[][]>> tasks = new ArrayList<>();
        for (int subsample = 0; subsample < this.numSubsamples; ++subsample) {
            final int index = subsample;
            tasks.add(() -> runSubsample(index, causes, effects, data, origDir, newDir));
        }
        List<double[][]> allEffects = runCallablesDoubleArray(tasks, this.parallelized);

        // cutoffs[s] = the (topBracket * |effects|)-th largest effect in subsample s. It is computed
        // once per subsample rather than once per (cause, effect) pair. The cutoffs are only read
        // when there is at least one effect.
        int cutoffRank = this.topBracket * effects.size();
        double[] cutoffs = new double[this.numSubsamples];
        for (int subsample = 0; subsample < this.numSubsamples; ++subsample) {
            double[][] matrix = allEffects.get(subsample);
            if (matrix.length != causes.size() || matrix[0].length != effects.size()) {
                throw new IllegalStateException("Length of subsample " + (subsample + 1) + " does not match the number of possible causes.");
            }
            if (cutoffRank > 0) {
                cutoffs[subsample] = kthLargest(matrix, cutoffRank);
            }
        }

        if (this.verbose) {
            logMessage("Examining top bracket = " + this.topBracket + ".");
        }

        List<Tuple> tuples = new ArrayList<>();
        for (int e = 0; e < effects.size(); ++e) {
            for (int c = 0; c < causes.size(); ++c) {
                int count = 0;
                for (int subsample = 0; subsample < this.numSubsamples; ++subsample) {
                    if (allEffects.get(subsample)[c][e] >= cutoffs[subsample]) {
                        ++count;
                    }
                }
                double pi = (double) count / (double) this.numSubsamples;
                if (pi <= 0.0) {
                    continue;
                }
                Node cause = causes.get(c);
                Node effect = effects.get(e);
                tuples.add(new Tuple(cause, effect, pi, avgMinEffect(causes, effects, allEffects, cause, effect)));
            }
        }
        tuples.sort((o1, o2) -> compareByPiThenMinBeta(o1.getPi(), o1.getMinBeta(), o2.getPi(), o2.getMinBeta()));

        LinkedList<Record> records = new LinkedList<>();
        for (Tuple tuple : tuples) {
            if (tuple.getMinBeta() > this.selectionAlpha) {
                records.add(new Record(tuple.getCauseNode(), tuple.getEffectNode(), tuple.getPi(),
                        tuple.getMinBeta(), causes.size(), effects.size()));
            }
        }

        LinkedList<LinkedList<Record>> allRecords = new LinkedList<>();
        allRecords.add(records);
        return allRecords;
    }

    /**
     * Checks that 1 &lt;= topBracket &lt;= possibleCauses.size().
     */
    private static void validateTopBracket(int topBracket, List<Node> possibleCauses) {
        if (topBracket < 1) {
            throw new IllegalArgumentException("Top bracket must be at least 1.");
        }
        if (topBracket > possibleCauses.size()) {
            throw new IllegalArgumentException("Top bracket (q) is too large; it is " + topBracket + " but the number of possible causes is " + possibleCauses.size());
        }
    }

    /**
     * Returns {@code path.N}, where N is the smallest positive integer for which the directory does
     * not exist.
     */
    private static File nextFreeDirectory(String path) {
        int i = 1;
        while (new File(path + "." + i).exists()) {
            ++i;
        }
        return new File(path + "." + i);
    }

    /**
     * Returns the k-th largest entry (1-based) of the matrix, ordering values with Double.compare.
     */
    private static double kthLargest(double[][] matrix, int k) {
        List<Double> values = new ArrayList<>();
        for (double[] row : matrix) {
            for (double value : row) {
                values.add(value);
            }
        }
        values.sort((a, b) -> Double.compare(b, a));
        return values.get(k - 1);
    }

    /**
     * Forces a message to the Tetrad log.
     */
    private static void logMessage(String message) {
        TetradLogger.getInstance().forceLogMessage(message);
    }

    /**
     * Processes one subsample and returns its effects matrix.
     * <p>
     * The CPDAG and effects are loaded from origDir when both cached files exist. Otherwise the data
     * is resampled, a CPDAG is searched for, and IDA is run. In both cases the CPDAG and effects are
     * saved to newDir.
     *
     * @return effects[c][e] = minimum IDA effect of cause c on effect e (0.0 when IDA reports none).
     * @throws RuntimeException wrapping any failure, so the task can run on an executor.
     */
    private double[][] runSubsample(int subsample, List<Node> possibleCauses, List<Node> possibleEffects,
                                    DataSet dataSet, File origDir, File newDir) {
        int index = subsample + 1;
        logMessage("\nRunning subsample " + index + " of " + this.numSubsamples + ".");
        try {
            File cachedCpdag = new File(origDir, "cpdag." + index + ".txt");
            File cachedEffects = new File(origDir, "effects." + index + ".txt");
            Graph cpdag;
            double[][] effects;
            if (cachedCpdag.exists() && cachedEffects.exists()) {
                logMessage("Loading CPDAG and effects from " + origDir.getAbsolutePath() + " for index " + index);
                cpdag = GraphSaveLoadUtils.loadGraphTxt(cachedCpdag);
                effects = loadMatrix(cachedEffects);
            } else {
                logMessage("Sampling data for index " + index);
                DataSet sample = drawSample(dataSet);
                cpdag = searchCpdag(sample, dataSet, index);
                effects = estimateEffects(sample, cpdag, possibleCauses, possibleEffects, index);
            }
            logMessage("Saving CPDAG and effects for index " + index);
            saveMatrix(effects, new File(newDir, "effects." + index + ".txt"));
            GraphSaveLoadUtils.saveGraph(cpdag, new File(newDir, "cpdag." + index + ".txt"), false);
            return effects;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Draws numRows / 2 rows from the data, with replacement (BOOTSTRAP) or without (SUBSAMPLE).
     */
    private DataSet drawSample(DataSet dataSet) {
        BootstrapSampler sampler = new BootstrapSampler();
        if (this.sampleStyle == SampleStyle.BOOTSTRAP) {
            sampler.setWithoutReplacements(false);
        } else if (this.sampleStyle == SampleStyle.SUBSAMPLE) {
            sampler.setWithoutReplacements(true);
        } else {
            throw new IllegalArgumentException("That type of sample is not configured: " + this.sampleStyle);
        }
        return sampler.sample(dataSet, dataSet.getNumRows() / 2);
    }

    /**
     * Runs the configured CPDAG algorithm on the sample. fullData is used only by restricted BOSS.
     */
    private Graph searchCpdag(DataSet sample, DataSet fullData, int index) {
        if (this.cpdagAlgorithm == CpdagAlgorithm.PC_STABLE) {
            logMessage("Running PC-Stable for index " + index);
            return getPatternPcStable(sample);
        } else if (this.cpdagAlgorithm == CpdagAlgorithm.FGES) {
            logMessage("Running FGES for index " + index);
            return getPatternFges(sample);
        } else if (this.cpdagAlgorithm == CpdagAlgorithm.BOSS) {
            logMessage("Running BOSS for index " + index);
            return getPatternBoss(sample);
        } else if (this.cpdagAlgorithm == CpdagAlgorithm.RESTRICTED_BOSS) {
            logMessage("Running Restricted BOSS for index " + index);
            return getPatternRestrictedBoss(sample, fullData);
        }
        throw new IllegalArgumentException("That type of cpdag algorithm is not configured: " + this.cpdagAlgorithm);
    }

    /**
     * Runs IDA on the CPDAG.
     *
     * @return a |causes| x |effects| matrix of minimum effects; entries missing from IDA's output are
     *         0.0.
     */
    private double[][] estimateEffects(DataSet sample, Graph cpdag, List<Node> possibleCauses,
                                       List<Node> possibleEffects, int index) {
        Ida ida = new Ida(sample, cpdag, possibleCauses);
        double[][] effects = new double[possibleCauses.size()][possibleEffects.size()];
        logMessage("Running IDA for index " + index);
        for (int e = 0; e < possibleEffects.size(); ++e) {
            Map<Node, Double> minEffects = ida.calculateMinimumEffectsOnY(possibleEffects.get(e));
            for (int c = 0; c < possibleCauses.size(); ++c) {
                Double value = minEffects.get(possibleCauses.get(c));
                effects[c][e] = value != null ? value : 0.0;
            }
        }
        return effects;
    }

    /**
     * Builds a graph with a directed edge cause -> effect for every record whose pi is strictly
     * greater than 0.5. Only the endpoints of such records are added as nodes.
     *
     * @param records records, e.g. from getRecords or cStar.
     * @return the resulting graph.
     */
    public Graph makeGraph(List<Record> records) {
        List<Node> outNodes = new ArrayList<>();
        EdgeListGraph graph = new EdgeListGraph(outNodes);
        for (Record record : records) {
            if (record.getPi() > 0.5) {
                graph.addNode(record.getCauseNode());
                graph.addNode(record.getEffectNode());
                graph.addDirectedEdge(record.getCauseNode(), record.getEffectNode());
            }
        }
        return graph;
    }

    /**
     * Sets the algorithm used to estimate a CPDAG for each subsample.
     *
     * @param cpdagAlgorithm the algorithm.
     */
    public void setCpdagAlgorithm(CpdagAlgorithm cpdagAlgorithm) {
        this.cpdagAlgorithm = cpdagAlgorithm;
    }

    /**
     * Sets verbose logging.
     *
     * @param verbose true to log the top bracket being examined.
     */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /**
     * Sets the threshold that a pair's average minimum effect must strictly exceed for the pair to be
     * reported by getRecords.
     *
     * @param selectionAlpha the threshold.
     */
    public void setSelectionAlpha(double selectionAlpha) {
        this.selectionAlpha = selectionAlpha;
    }

    /**
     * Sets how each resample is drawn.
     *
     * @param sampleStyle BOOTSTRAP (with replacement) or SUBSAMPLE (without replacement).
     */
    public void setSampleStyle(SampleStyle sampleStyle) {
        this.sampleStyle = sampleStyle;
    }

    /**
     * Sets the number of resamples.
     *
     * @param numSubsamples the number of resamples.
     */
    public void setNumSubsamples(int numSubsamples) {
        this.numSubsamples = numSubsamples;
    }

    /**
     * Returns the output directory created by the most recent getRecords call, or null if there has
     * been no such call.
     *
     * @return the output directory.
     */
    public File getDir() {
        return this.newDir;
    }

    /**
     * Writes dataSet.toString() to dir/data.txt.
     */
    private void writeData(DataSet dataSet, File dir) {
        try (PrintStream out = new PrintStream(new FileOutputStream(new File(dir, "data.txt")))) {
            out.println(dataSet.toString());
        } catch (FileNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads one variable name per non-blank line of dir/s and looks each name up in dataSet.
     */
    private List<Node> readVars(DataSet dataSet, File dir, String s) {
        List<Node> vars = new ArrayList<>();
        try (BufferedReader in = new BufferedReader(new FileReader(new File(dir, s)))) {
            String line;
            while ((line = in.readLine()) != null) {
                String name = line.trim();
                if (!name.isEmpty()) {
                    vars.add(dataSet.getVariable(name));
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return vars;
    }

    /**
     * Writes the names of vars, one per line, to dir/s.
     */
    private void writeVars(List<Node> vars, File dir, String s) {
        try (PrintStream out = new PrintStream(new FileOutputStream(new File(dir, s)))) {
            for (Node node : vars) {
                out.println(node.getName());
            }
        } catch (FileNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the mean, over all subsamples, of the effect of causeNode on effectNode.
     */
    private double avgMinEffect(List<Node> possibleCauses, List<Node> possibleEffects, List<double[][]> allEffects, Node causeNode, Node effectNode) {
        if (allEffects == null) {
            throw new NullPointerException("effects null");
        }
        int c = possibleCauses.indexOf(causeNode);
        int e = possibleEffects.indexOf(effectNode);
        double[] values = new double[this.numSubsamples];
        for (int k = 0; k < this.numSubsamples; ++k) {
            values[k] = allEffects.get(k)[c][e];
        }
        return StatUtils.mean(values);
    }

    /**
     * Formats records as a text table with the columns Index, Cause, Effect, PI, Effect and PCER.
     * PCER is printed as "*" when pi &lt;= 0.5.
     *
     * @param records records in ranked order.
     * @return the header plus table, or a short message if there are no records.
     */
    public String makeTable(LinkedList<Record> records) {
        // Checked first: the header below reads the first record.
        if (records.isEmpty()) {
            return "\nThere are no records above chance.\n";
        }
        Record first = records.get(0);
        String header = "# Potential Causes = " + first.getNumCauses() + "\n# Potential Effects = " + first.getNumEffects() + "\nTop Bracket (\u2018q\u2019) = " + this.topBracket + "\n\n";
        int numColumns = 6;
        TextTable table = new TextTable(records.size() + 1, numColumns);
        DecimalFormat nf = new DecimalFormat("0.0000");
        String[] headings = {"Index", "Cause", "Effect", "PI", "Effect", "PCER"};
        for (int column = 0; column < headings.length; ++column) {
            table.setToken(0, column, headings[column]);
        }
        int p = records.getLast().getNumCauses();
        // Iterate rather than calling LinkedList.get(i), which would make the loop quadratic.
        int rank = 0;
        for (Record record : records) {
            ++rank;
            table.setToken(rank, 0, String.valueOf(rank));
            table.setToken(rank, 1, record.getCauseNode().getName());
            table.setToken(rank, 2, record.getEffectNode().getName());
            table.setToken(rank, 3, nf.format(record.getPi()));
            table.setToken(rank, 4, nf.format(record.getMinBeta()));
            double bound = Cstar.pcer(record.getPi(), rank, p);
            table.setToken(rank, 5, record.getPi() <= 0.5 ? "*" : nf.format(bound));
        }
        return header + table;
    }

    private Graph getPatternPcStable(DataSet sample) {
        IndependenceTest test = this.test.getTest(sample, this.parameters);
        test.setVerbose(false);
        Pc pc = new Pc(test);
        pc.setStable(true);
        pc.setVerbose(false);
        return pc.search();
    }

    private Graph getPatternFges(DataSet sample) {
        Score score = this.score.getScore(sample, this.parameters);
        Fges fges = new Fges(score);
        fges.setVerbose(false);
        return fges.search();
    }

    private Graph getPatternBoss(DataSet sample) {
        Score score = this.score.getScore(sample, this.parameters);
        PermutationSearch boss = new PermutationSearch(new Boss(score));
        return boss.search();
    }

    /**
     * Runs restricted BOSS on the sample and maps the result's nodes onto data's variables.
     * NOTE(review): this mutates the shared Parameters ("trimmingStyle") and may run concurrently
     * when parallelized.
     */
    private Graph getPatternRestrictedBoss(DataSet sample, DataSet data) {
        RestrictedBoss restrictedBoss = new RestrictedBoss(this.score);
        this.parameters.set("trimmingStyle", (Object) 1);
        Graph g = restrictedBoss.search(sample, this.parameters);
        g = GraphUtils.replaceNodes(g, data.getVariables());
        return g;
    }

    /**
     * Writes the matrix to file as a tab-delimited data set with columns V1..Vn; does nothing if file
     * is null.
     */
    private void saveMatrix(double[][] effects, File file) {
        int numColumns = effects.length == 0 ? 0 : effects[0].length;
        List<Node> vars = new ArrayList<>();
        for (int i = 0; i < numColumns; ++i) {
            vars.add(new ContinuousVariable("V" + (i + 1)));
        }
        BoxDataSet data = new BoxDataSet(new DoubleDataBox(effects), vars);
        if (file == null) {
            return;
        }
        try (PrintStream out = new PrintStream(new FileOutputStream(file))) {
            out.println(data);
        } catch (FileNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads a matrix previously written by saveMatrix.
     */
    private double[][] loadMatrix(File file) {
        try {
            DataSet dataSet = SimpleDataLoader.loadContinuousData(file, "//", '\"', "*", true, Delimiter.TAB, false);
            return dataSet.getDoubleData().toArray();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs the tasks, sequentially or on the common ForkJoinPool, and returns their results in task
     * order.
     *
     * @throws RuntimeException      (or subclass) if any task fails.
     * @throws IllegalStateException if interrupted while waiting; the interrupt flag is restored.
     */
    private List<double[][]> runCallablesDoubleArray(List<Callable<double[][]>> tasks, boolean parallelized) {
        List<double[][]> results = new ArrayList<>();
        if (tasks.isEmpty()) {
            return results;
        }
        if (!parallelized) {
            for (Callable<double[][]> task : tasks) {
                try {
                    results.add(task.call());
                } catch (RuntimeException e) {
                    throw e;
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        } else {
            ForkJoinPool executorService = ForkJoinPool.commonPool();
            try {
                List<Future<double[][]>> futures = executorService.invokeAll(tasks);
                for (Future<double[][]> future : futures) {
                    results.add(future.get());
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while running subsamples.", e);
            } catch (ExecutionException e) {
                throw new RuntimeException(e.getCause());
            }
        }
        return results;
    }

    /**
     * Algorithms available for estimating a CPDAG on each subsample.
     */
    public enum CpdagAlgorithm {
        PC_STABLE,
        FGES,
        BOSS,
        RESTRICTED_BOSS
    }

    /**
     * How each resample is drawn: with replacement (BOOTSTRAP) or without (SUBSAMPLE).
     */
    public enum SampleStyle {
        BOOTSTRAP,
        SUBSAMPLE
    }

    /**
     * One ranked (cause, effect) pair.
     * <p>
     * The effect value is the average minimum effect when the record comes from getRecords, and the
     * median of those when it comes from cStar.
     */
    public static class Record
    implements TetradSerializable {
        private static final long serialVersionUID = 23L;
        private final Node causeNode;
        private final Node target;
        private final double pi;
        private final double effect;
        private final int numCauses;
        private final int numEffects;

        Record(Node predictor, Node target, double pi, double minEffect, int numCauses, int numEffects) {
            this.causeNode = predictor;
            this.target = target;
            this.pi = pi;
            this.effect = minEffect;
            this.numCauses = numCauses;
            this.numEffects = numEffects;
        }

        public Node getCauseNode() {
            return this.causeNode;
        }

        public Node getEffectNode() {
            return this.target;
        }

        public double getPi() {
            return this.pi;
        }

        double getMinBeta() {
            return this.effect;
        }

        public int getNumCauses() {
            return this.numCauses;
        }

        public int getNumEffects() {
            return this.numEffects;
        }
    }

    /**
     * An intermediate (cause, effect, pi, average minimum effect) tuple used before filtering into
     * records.
     */
    private static class Tuple {
        private final Node cause;
        private final Node effect;
        private final double pi;
        private final double minBeta;

        private Tuple(Node cause, Node effect, double pi, double minBeta) {
            this.cause = cause;
            this.effect = effect;
            this.pi = pi;
            this.minBeta = minBeta;
        }

        public Node getCauseNode() {
            return this.cause;
        }

        public Node getEffectNode() {
            return this.effect;
        }

        public double getPi() {
            return this.pi;
        }

        public double getMinBeta() {
            return this.minBeta;
        }
    }
}

