/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.DirichletBayesIm;
import edu.cmu.tetrad.bayes.DirichletEstimator;
import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.MisclassificationUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeEqualityMode;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.RandomGraph;
import edu.cmu.tetrad.performance.Comparison;
import edu.cmu.tetrad.search.BDeuScore;
import edu.cmu.tetrad.search.Fges;
import edu.cmu.tetrad.search.IndTestChiSquare;
import edu.cmu.tetrad.search.IndTestProbabilistic;
import edu.cmu.tetrad.search.Rfci;
import edu.cmu.tetrad.search.SearchGraphUtils;
import edu.cmu.tetrad.search.XdslXmlParser;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.TextTable;
import edu.pitt.dbmi.algo.bayesian.constraint.inference.BCInference;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import nu.xom.Builder;
import nu.xom.Document;
import nu.xom.ParsingException;
import org.apache.commons.math3.util.FastMath;

/**
 * Experiment driver that compares RFCI run with a chi-square test against the RB (RFCI-BSC)
 * approach, in both its independence-only (RB-I) and dependency-model (RB-D) variants.
 *
 * <p>Each run loads a benchmark Bayes network from {@code <directory>/xdsl}, marks a fraction of
 * its variables as latent, simulates data, runs the searches, and writes graph-comparison
 * statistics against the true PAG to a results file.
 */
public class RBExperiments {
    /** Depth limit used for every RFCI search. */
    private final int depth = 5;
    /** Base directory whose {@code xdsl} sub-directory holds the network files. */
    private String directory;
    /** Log-space gap below which the smaller term in {@link #lnXplusY} is treated as negligible. */
    private static final int MINIMUM_EXPONENT = -1022;
    /** Number of rounds run by {@link #main} when no {@code -i} option is given. */
    private static final int DEFAULT_NUM_ROUNDS = 10;
    /** Separator between X and Y in the string form of an independence fact. */
    private static final Pattern INDEPENDENCE_SEPARATOR = Pattern.compile(Pattern.quote("_||_"));
    /** Separator between Y and the conditioning set in the string form of an independence fact. */
    private static final Pattern CONDITIONING_SEPARATOR = Pattern.compile(Pattern.quote("|"));
    /** Separator between members of the conditioning set. */
    private static final Pattern LIST_SEPARATOR = Pattern.compile(Pattern.quote(","));

    /** Returns the nodes of {@code dag} whose node type is {@link NodeType#LATENT}. */
    private List<Node> getLatents(Graph dag) {
        List<Node> latents = new ArrayList<>();
        for (Node node : dag.getNodes()) {
            if (node.getNodeType() == NodeType.LATENT) {
                latents.add(node);
            }
        }
        return latents;
    }

    /**
     * Builds a fixed five-node DAG over discrete variables named "1".."5" with edges
     * 1->2, 1->3, 2->4, 3->4 and 3->5.
     *
     * @param numLatentConfounders currently unused; no variables are made latent
     * @return the DAG
     */
    public Graph makeSimpleDAG(int numLatentConfounders) {
        ArrayList<Node> nodes = new ArrayList<>();
        for (int i = 0; i < 5; ++i) {
            nodes.add(new DiscreteVariable(Integer.toString(i + 1)));
        }
        EdgeListGraph dag = new EdgeListGraph(nodes);
        int[][] edges = {{0, 1}, {0, 2}, {1, 3}, {2, 3}, {2, 4}};
        for (int[] edge : edges) {
            dag.addDirectedEdge(nodes.get(edge[0]), nodes.get(edge[1]));
        }
        return dag;
    }

    /**
     * Fills the conditional probability tables of an IM over {@link #makeSimpleDAG}'s structure
     * (assumes binary variables and that node order matches that DAG — confirm before use).
     *
     * @param im the IM to modify in place
     * @return the same {@code im}
     */
    private BayesIm initializeIM(BayesIm im) {
        // cpts[node][row][category]
        double[][][] cpts = {
                {{0.8, 0.2}},
                {{0.9, 0.1}, {0.3, 0.7}},
                {{0.8, 0.2}, {0.4, 0.6}},
                {{0.9, 0.1}, {0.7, 0.3}, {0.6, 0.4}, {0.2, 0.8}},
                {{0.9, 0.1}, {0.6, 0.4}}
        };
        for (int node = 0; node < cpts.length; ++node) {
            for (int row = 0; row < cpts[node].length; ++row) {
                for (int category = 0; category < cpts[node][row].length; ++category) {
                    im.setProbability(node, row, category, cpts[node][row][category]);
                }
            }
        }
        return im;
    }

    /**
     * Command-line entry point.
     *
     * <p>Every recognized option takes exactly one value: {@code -c} number of cases, {@code -lv}
     * fraction of variables made latent, {@code -bs} number of bootstrap samples, {@code -alpha}
     * chi-square significance level, {@code -m} number of RB models, {@code -net} network name,
     * {@code -t1}/{@code -t2} threshold flags for the RB search and the bootstrap tests,
     * {@code -low}/{@code -up} bounds selecting uncertain facts, {@code -out} results directory,
     * {@code -data} base directory of the network files, and {@code -i} a single round to run
     * (when absent, rounds 0 through 9 are run). Unrecognized tokens are ignored.
     *
     * @param args command-line options as described above
     * @throws IllegalArgumentException if a recognized option has no value following it
     */
    public static void main(String[] args) throws IOException {
        NodeEqualityMode.setEqualityMode(NodeEqualityMode.Type.OBJECT);
        double alpha = 0.05;
        double lower = 0.3;
        double upper = 0.7;
        int numCases = 200;
        double latentFraction = 0.0;
        int numModels = 5;
        int numBootstrapSamples = 10;
        int round = -1; // negative means "run all default rounds"
        String modelName = "Alarm";
        String filePath = "/Users/chw20/Documents/DBMI/bsc-results";
        String dataPath = System.getProperty("user.dir");
        boolean threshold1 = false;
        boolean threshold2 = true;
        for (int i = 0; i < args.length; ++i) {
            boolean consumesValue = true;
            switch (args[i]) {
                case "-c":
                    numCases = Integer.parseInt(optionValue(args, i));
                    break;
                case "-lv":
                    latentFraction = Double.parseDouble(optionValue(args, i));
                    break;
                case "-bs":
                    numBootstrapSamples = Integer.parseInt(optionValue(args, i));
                    break;
                case "-alpha":
                    alpha = Double.parseDouble(optionValue(args, i));
                    break;
                case "-m":
                    numModels = Integer.parseInt(optionValue(args, i));
                    break;
                case "-net":
                    modelName = optionValue(args, i);
                    break;
                case "-t1":
                    threshold1 = Boolean.parseBoolean(optionValue(args, i));
                    break;
                case "-t2":
                    threshold2 = Boolean.parseBoolean(optionValue(args, i));
                    break;
                case "-low":
                    lower = Double.parseDouble(optionValue(args, i));
                    break;
                case "-up":
                    upper = Double.parseDouble(optionValue(args, i));
                    break;
                case "-out":
                    filePath = optionValue(args, i);
                    break;
                case "-data":
                    dataPath = optionValue(args, i);
                    break;
                case "-i":
                    round = Integer.parseInt(optionValue(args, i));
                    break;
                default:
                    consumesValue = false;
                    break;
            }
            if (consumesValue) {
                ++i; // skip the option's value so it is not re-read as an option
            }
        }
        RBExperiments experiments = new RBExperiments();
        experiments.directory = dataPath;
        int firstRound = round >= 0 ? round : 0;
        int lastRound = round >= 0 ? round : DEFAULT_NUM_ROUNDS - 1;
        for (int r = firstRound; r <= lastRound; ++r) {
            experiments.experiment(modelName, numCases, numModels, numBootstrapSamples, alpha,
                    latentFraction, threshold1, threshold2, lower, upper, filePath, r);
        }
    }

    /** Returns the token following option {@code args[index]}, failing if there is none. */
    private static String optionValue(String[] args, int index) {
        if (index + 1 >= args.length) {
            throw new IllegalArgumentException("Missing value for option " + args[index]);
        }
        return args[index + 1];
    }

    /**
     * Runs one round of the RFCI vs. RB-I vs. RB-D comparison and writes the results file.
     *
     * <p>If the results file for this configuration and round already exists and is non-empty,
     * nothing is done.
     *
     * @param modelName            network name understood by {@link #getBayesIM}
     * @param numCases             number of simulated cases
     * @param numModels            number of RB (RFCI-BSC) models to sample; must be at least 1
     * @param numBootstrapSamples  number of bootstrap rows in the dependency data
     * @param alpha                significance level of the chi-square test used by RFCI
     * @param numLatentConfounders fraction of variables (floored) to make latent
     * @param threshold1           threshold flag for the probabilistic test in the RB search
     * @param threshold2           threshold flag for the probabilistic tests on bootstrap samples
     * @param lower                exclusive lower bound on P(independent) for uncertain facts
     * @param upper                exclusive upper bound on P(independent) for uncertain facts
     * @param filePath             base directory for results
     * @param round                round index; selects the data-simulation seed
     * @throws IllegalArgumentException if {@code numModels < 1}
     * @throws RuntimeException wrapping an {@link IOException} if the results file can't be opened
     */
    public void experiment(String modelName, int numCases, int numModels, int numBootstrapSamples, double alpha, double numLatentConfounders, boolean threshold1, boolean threshold2, double lower, double upper, String filePath, int round) {
        if (numModels < 1) {
            throw new IllegalArgumentException("numModels must be at least 1 but was " + numModels);
        }
        // Fixed seed for model setup; data simulation below is reseeded per round.
        RandomUtil.getInstance().setSeed(878376L);
        BayesIm im = getBayesIM(modelName);
        Graph dag = im.getBayesPm().getDag();
        int numLatents = (int) FastMath.floor(numLatentConfounders * (double) im.getNumNodes());
        RandomGraph.fixLatents4(numLatents, dag);
        System.out.println("Variables set to be latent:" + getLatents(dag));

        String config = modelName + "-Vars" + dag.getNumNodes() + "-Edges" + dag.getNumEdges()
                + "-H" + numLatentConfounders + "-Cases" + numCases
                + "-numModels" + numModels + "-BS" + numBootstrapSamples;
        File dir = new File(filePath + "/" + config);
        dir.mkdirs();
        File file = new File(dir, "Results-" + config + "-" + round + ".txt");
        if (file.exists() && file.length() != 0L) {
            return; // this configuration/round has already been run
        }

        try (PrintStream out = new PrintStream(new FileOutputStream(file))) {
            RandomUtil.getInstance().setSeed((long) round * 1000000L + 71512L);
            DataSet fullData = refineData(im.simulateData(numCases, true));
            DataSet data = DataUtils.restrictToMeasured(fullData);
            Graph truePag = GraphUtils.replaceNodes(SearchGraphUtils.dagToPag(dag), data.getVariables());

            long start = MillisecondTimes.timeMillis();
            Graph rfciPag = runPagCs(data, alpha);
            long rfciTime = MillisecondTimes.timeMillis() - start;
            System.out.println("RFCI done!");

            List<Graph> bscPags = new ArrayList<>();
            start = MillisecondTimes.timeMillis();
            IndTestProbabilistic testBsc = runRB(data, bscPags, numModels, threshold1);
            long bscRfciTime = MillisecondTimes.timeMillis() - start;
            Map<IndependenceFact, Double> H = testBsc.getH();
            System.out.println("RB (RFCI-BSC) done!");

            // Dependency model: bootstrap indicators of uncertain facts -> FGES DAG -> Dirichlet IM.
            start = MillisecondTimes.timeMillis();
            DataSet depData = createDepDataFiltering(H, data, numBootstrapSamples, threshold2, lower, upper);
            out.println("DepData(row,col):" + depData.getNumRows() + "," + depData.getNumColumns());
            System.out.println("Dep data creation done!");
            Graph estDepBN = SearchGraphUtils.dagFromCPDAG(runFGS(depData));
            System.out.println("estDepBN: " + estDepBN.getEdges());
            out.println("DepGraph(nodes,edges):" + estDepBN.getNumNodes() + "," + estDepBN.getNumEdges());
            System.out.println("Dependency graph done!");
            BayesPm pmHat = new BayesPm(estDepBN, 2, 2);
            DirichletBayesIm prior = DirichletBayesIm.symmetricDirichletIm(pmHat, 0.5);
            DirichletBayesIm imHat = DirichletEstimator.estimate(prior, depData);
            long depModelTime = MillisecondTimes.timeMillis() - start;
            System.out.println("Dependency BN_Param done");

            // Both variants are scored in one pass, so each is charged half of that time.
            start = MillisecondTimes.timeMillis();
            AllScores lnProbs = getLnProbsAll(bscPags, H, data, imHat, estDepBN);
            long sharedScoringTime = (MillisecondTimes.timeMillis() - start) / 2L;

            start = MillisecondTimes.timeMillis();
            Map<Graph, Double> normalizedDep = normalProbs(lnProbs.lnDep);
            long depNormalizeTime = MillisecondTimes.timeMillis() - start;
            start = MillisecondTimes.timeMillis();
            Map<Graph, Double> normalizedInd = normalProbs(lnProbs.lnInd);
            long indNormalizeTime = MillisecondTimes.timeMillis() - start;

            normalizedDep = MapUtil.sortByValue(normalizedDep);
            Graph maxBND = normalizedDep.keySet().iterator().next();
            normalizedInd = MapUtil.sortByValue(normalizedInd);
            Graph maxBNI = normalizedInd.keySet().iterator().next();

            long rbIndTime = bscRfciTime + sharedScoringTime + indNormalizeTime;
            long rbDepTime = bscRfciTime + depModelTime + sharedScoringTime + depNormalizeTime;

            out.println("*** RFCI time (sec):" + rfciTime / 1000L);
            summarize(rfciPag, truePag, out);
            out.println("\n*** RB-I time (sec):" + rbIndTime / 1000L);
            summarize(maxBNI, truePag, out);
            out.println("\n*** RB-D time (sec):" + rbDepTime / 1000L);
            summarize(maxBND, truePag, out);
            out.println("P(maxBNI): \n" + normalizedInd.get(maxBNI));
            out.println("P(maxBND): \n" + normalizedDep.get(maxBND));
            out.println("------------------------------------------");
            out.println("PAG_True: \n" + truePag);
            out.println("------------------------------------------");
            out.println("Rfci: \n" + rfciPag);
            out.println("------------------------------------------");
            out.println("RB-I: \n" + maxBNI);
            out.println("------------------------------------------");
            out.println("RB-D: \n" + maxBND);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Replaces every negative value in {@code fullData} with 0 (presumably negative values are
     * missing-value markers from simulation — confirm).
     *
     * @param fullData data set modified in place
     * @return the same {@code fullData}
     */
    private DataSet refineData(DataSet fullData) {
        int numRows = fullData.getNumRows();
        int numColumns = fullData.getNumColumns();
        for (int c = 0; c < numColumns; ++c) {
            for (int r = 0; r < numRows; ++r) {
                if (fullData.getInt(r, c) < 0) {
                    fullData.setInt(r, c, 0);
                }
            }
        }
        return fullData;
    }

    /**
     * Loads one of the supported benchmark networks by name.
     *
     * @param type one of "Alarm", "Hailfinder", "Hepar", "Win95", "Barley"
     * @return the loaded IM
     * @throws IllegalArgumentException if {@code type} is not one of the supported names
     */
    private BayesIm getBayesIM(String type) {
        if (type == null) {
            throw new IllegalArgumentException("Not a recognized Bayes IM type: null");
        }
        switch (type) {
            case "Alarm":
                return loadBayesIm("Alarm.xdsl", true);
            case "Hailfinder":
                return loadBayesIm("Hailfinder.xdsl", false);
            case "Hepar":
                return loadBayesIm("Hepar2.xdsl", true);
            case "Win95":
                return loadBayesIm("win95pts.xdsl", false);
            case "Barley":
                return loadBayesIm("barley.xdsl", false);
            default:
                throw new IllegalArgumentException("Not a recognized Bayes IM type: " + type);
        }
    }

    /**
     * Writes the edge-misclassification table of {@code graph} against {@code trueGraph},
     * followed by adjacency and endpoint difference counts.
     */
    private void summarize(Graph graph, Graph trueGraph, PrintStream out) {
        out.flush();
        GraphUtils.GraphComparison comparison = SearchGraphUtils.getGraphComparison(trueGraph, graph);
        out.println(MisclassificationUtils.edgeMisclassifications(graph, trueGraph));
        int adjacencyErrors = comparison.getEdgesAdded().size() + comparison.getEdgesRemoved().size();
        // Remaining SHD after charging 2 per missing/extra edge (assumes that is how SHD counts
        // adjacency errors — TODO confirm).
        int diffEdgePoint = comparison.getShd() - adjacencyErrors * 2;
        out.println("# missing/extra edges: " + adjacencyErrors);
        out.println("# different edge points: " + diffEdgePoint);
        out.println("-------------------------------");
    }

    /**
     * Prints and returns arrowhead precision/recall of {@code outGraph} against {@code truePag}.
     *
     * @return {correct, totalEstimated, totalTrue, precision, recall}
     */
    private double[] printCorrectArrows(Graph outGraph, Graph truePag, PrintStream out) {
        return printCorrectEndpoints(outGraph, truePag, Endpoint.ARROW, "arrows", "Arrow", out);
    }

    /**
     * Prints and returns tail precision/recall of {@code outGraph} against {@code truePag}.
     *
     * @return {correct, totalEstimated, totalTrue, precision, recall}
     */
    private double[] printCorrectTails(Graph outGraph, Graph truePag, PrintStream out) {
        return printCorrectEndpoints(outGraph, truePag, Endpoint.TAIL, "tails", "Tail", out);
    }

    /**
     * Shared implementation of {@link #printCorrectArrows} and {@link #printCorrectTails}.
     * Precision/recall are NaN when the corresponding denominator is zero.
     *
     * @param endpoint endpoint type being counted
     * @param plural   lower-case plural noun used in the count lines
     * @param label    capitalized singular noun used in the precision/recall lines
     */
    private double[] printCorrectEndpoints(Graph outGraph, Graph truePag, Endpoint endpoint, String plural, String label, PrintStream out) {
        int correct = 0;
        int totalEstimated = 0;
        for (Edge edge : outGraph.getEdges()) {
            Node x = edge.getNode1();
            Node y = edge.getNode2();
            Edge trueEdge = truePag.getEdge(x, y);
            if (edge.getEndpoint1() == endpoint) {
                ++totalEstimated;
                if (trueEdge != null && trueEdge.getProximalEndpoint(x) == endpoint) {
                    ++correct;
                }
            }
            if (edge.getEndpoint2() == endpoint) {
                ++totalEstimated;
                if (trueEdge != null && trueEdge.getProximalEndpoint(y) == endpoint) {
                    ++correct;
                }
            }
        }
        int totalTrue = 0;
        for (Edge edge : truePag.getEdges()) {
            if (edge.getEndpoint1() == endpoint) {
                ++totalTrue;
            }
            if (edge.getEndpoint2() == endpoint) {
                ++totalTrue;
            }
        }
        out.println();
        out.println("# correct " + plural + ": " + correct);
        out.println("# total estimated " + plural + ": " + totalEstimated);
        out.println("# total true " + plural + ": " + totalTrue);
        out.println();
        DecimalFormat nf = new DecimalFormat("0.00");
        double precision = (double) correct / (double) totalEstimated;
        out.println(label + " precision: " + nf.format(precision));
        double recall = (double) correct / (double) totalTrue;
        out.println(label + " recall: " + nf.format(recall));
        return new double[]{correct, totalEstimated, totalTrue, precision, recall};
    }

    /**
     * Formats the given columns of {@code dataSet} as a table with a header row, one row per
     * data row (labelled by run number), and a final "Avg" row of column means.
     */
    private TextTable getTextTable(DataSet dataSet, int[] columns, NumberFormat nf) {
        int numRows = dataSet.getNumRows();
        int avgRow = numRows + 1;
        TextTable table = new TextTable(numRows + 2, columns.length + 1);
        table.setToken(0, 0, "Run #");
        for (int j = 0; j < columns.length; ++j) {
            table.setToken(0, j + 1, dataSet.getVariable(columns[j]).getName());
        }
        for (int i = 0; i < numRows; ++i) {
            table.setToken(i + 1, 0, Integer.toString(i + 1));
            for (int j = 0; j < columns.length; ++j) {
                table.setToken(i + 1, j + 1, nf.format(dataSet.getDouble(i, columns[j])));
            }
        }
        DecimalFormat avgFormat = new DecimalFormat("0.00");
        for (int j = 0; j < columns.length; ++j) {
            double sum = 0.0;
            for (int i = 0; i < numRows; ++i) {
                sum += dataSet.getDouble(i, columns[j]);
            }
            table.setToken(avgRow, j + 1, avgFormat.format(sum / (double) numRows));
        }
        table.setToken(avgRow, 0, "Avg");
        return table;
    }

    /**
     * Builds a binary data set with one column per "uncertain" independence fact (those with
     * {@code lower < H(fact) < upper}) and one row per bootstrap sample; a cell is 1 when the
     * probabilistic test on that bootstrap sample judges the fact independent, else 0.
     */
    private DataSet createDepDataFiltering(Map<IndependenceFact, Double> H, DataSet data, int numBootstrapSamples, boolean threshold, double lower, double upper) {
        ArrayList<Node> vars = new ArrayList<>();
        HashMap<IndependenceFact, Double> uncertain = new HashMap<>();
        for (Map.Entry<IndependenceFact, Double> entry : H.entrySet()) {
            double p = entry.getValue();
            if (p > lower && p < upper) {
                uncertain.put(entry.getKey(), p);
                vars.add(new DiscreteVariable(entry.getKey().toString()));
            }
        }
        BoxDataSet depData = new BoxDataSet(new DoubleDataBox(numBootstrapSamples, vars.size()), vars);
        System.out.println("\nDep data rows: " + depData.getNumRows() + ", columns: " + depData.getNumColumns());
        System.out.println("HCopy size: " + uncertain.size());

        // Resolve each fact's column once; iteration order matches what the test loop uses.
        List<IndependenceFact> facts = new ArrayList<>(uncertain.keySet());
        int[] columns = new int[facts.size()];
        for (int k = 0; k < columns.length; ++k) {
            columns[k] = depData.getColumn(depData.getVariable(facts.get(k).toString()));
        }

        for (int b = 0; b < numBootstrapSamples; ++b) {
            DataSet bsData = DataUtils.getBootstrapSample(data, data.getNumRows());
            IndTestProbabilistic bsTest = new IndTestProbabilistic(bsData);
            bsTest.setThreshold(threshold);
            for (int k = 0; k < columns.length; ++k) {
                IndependenceFact f = facts.get(k);
                boolean ind = bsTest.checkIndependence(f.getX(), f.getY(), f.getZ()).independent();
                depData.setInt(b, columns[k], ind ? 1 : 0);
            }
        }
        return depData;
    }

    /** Runs FGES with a BDeu score (sample and structure priors 1.0) and returns its CPDAG. */
    private Graph runFGS(DataSet data) {
        BDeuScore score = new BDeuScore(data);
        score.setSamplePrior(1.0);
        score.setStructurePrior(1.0);
        Fges fges = new Fges(score);
        fges.setVerbose(false);
        fges.setFaithfulnessAssumed(true);
        return GraphUtils.replaceNodes(fges.search(), data.getVariables());
    }

    /**
     * Scores each distinct PAG in {@code pags} under both the independence-only model and the
     * dependency-model variant.
     */
    private AllScores getLnProbsAll(List<Graph> pags, Map<IndependenceFact, Double> H, DataSet data, BayesIm im, Graph dep) {
        HashMap<Graph, Double> lnDep = new HashMap<>();
        HashMap<Graph, Double> lnInd = new HashMap<>();
        for (Graph pag : pags) {
            if (lnDep.containsKey(pag)) {
                continue; // already scored an equal PAG
            }
            double ind = getLnProb(pag, H);
            double depScore = getLnProbUsingDepFiltering(pag, H, im, dep);
            lnDep.put(pag, depScore);
            lnInd.put(pag, ind);
        }
        System.out.println("pags size: " + pags.size());
        System.out.println("unique pags size: " + lnDep.size());
        return new AllScores(lnDep, lnInd);
    }

    /**
     * Runs the RFCI search {@code numModels} times with one shared probabilistic test, appending
     * each resulting PAG to {@code pags}.
     *
     * @return the probabilistic test, whose {@code getH()} holds the facts it evaluated
     */
    private IndTestProbabilistic runRB(DataSet data, List<Graph> pags, int numModels, boolean threshold) {
        IndTestProbabilistic bscTest = new IndTestProbabilistic(data);
        bscTest.setThreshold(threshold);
        Rfci bscRfci = new Rfci(bscTest);
        bscRfci.setVerbose(false);
        bscRfci.setDepth(this.depth);
        for (int i = 0; i < numModels; ++i) {
            pags.add(GraphUtils.replaceNodes(bscRfci.search(), data.getVariables()));
        }
        return bscTest;
    }

    /** Runs RFCI with a chi-square test at significance {@code alpha}. */
    private Graph runPagCs(DataSet data, double alpha) {
        IndTestChiSquare test = new IndTestChiSquare(data, alpha);
        Rfci rfci = new Rfci(test);
        rfci.setDepth(this.depth);
        rfci.setVerbose(false);
        return GraphUtils.replaceNodes(rfci.search(), data.getVariables());
    }

    /**
     * Log score of {@code pag} where facts modelled in the dependency IM take their probability
     * from that IM (conditioned on the PAG's verdicts for the node's parent facts), and all other
     * facts use {@code H}. Non-finite terms are skipped.
     *
     * @throws IllegalArgumentException if a probability lies outside [-1e-4, 1.0001] or is
     *                                  NaN/infinite
     */
    private double getLnProbUsingDepFiltering(Graph pag, Map<IndependenceFact, Double> H, BayesIm im, Graph dep) {
        double lnQ = 0.0;
        for (Map.Entry<IndependenceFact, Double> entry : H.entrySet()) {
            IndependenceFact fact = entry.getKey();
            boolean independent = pag.paths().isDSeparatedFrom(fact.getX(), fact.getY(), fact.getZ());
            Node node = im.getNode(fact.toString());
            double p;
            if (node != null) {
                p = probIndependentFromDepModel(pag, im, node);
                if (!independent) {
                    p = 1.0 - p;
                }
                checkProbability(p);
            } else {
                p = entry.getValue();
                checkProbability(p);
                if (!independent) {
                    p = 1.0 - p;
                }
            }
            lnQ = addLogSkippingNonFinite(lnQ, p);
        }
        return lnQ;
    }

    /**
     * Returns the dependency IM's probability of category 1 ("independent") for {@code node},
     * using as parent values the PAG's d-separation verdicts (1 = separated) for each parent fact.
     */
    private double probIndependentFromDepModel(Graph pag, BayesIm im, Node node) {
        int nodeIndex = im.getNodeIndex(node);
        int[] parents = im.getParents(nodeIndex);
        if (parents.length == 0) {
            return im.getProbability(nodeIndex, 0, 1);
        }
        int[] parentValues = new int[parents.length];
        for (int k = 0; k < parents.length; ++k) {
            IndependenceFact parentFact = parseFact(pag, im.getNode(parents[k]).getName());
            parentValues[k] = pag.paths().isDSeparatedFrom(parentFact.getX(), parentFact.getY(), parentFact.getZ()) ? 1 : 0;
        }
        int rowIndex = im.getRowIndex(nodeIndex, parentValues);
        return im.getProbability(nodeIndex, rowIndex, 1);
    }

    /**
     * Parses a fact name of the form {@code "X _||_ Y"} or {@code "X _||_ Y | Z1, Z2"} into an
     * {@link IndependenceFact} over the nodes of {@code pag} with those names.
     */
    private IndependenceFact parseFact(Graph pag, String name) {
        String[] xAndRest = INDEPENDENCE_SEPARATOR.split(name);
        Node x = pag.getNode(xAndRest[0].trim());
        String[] yAndConditioning = CONDITIONING_SEPARATOR.split(xAndRest[1].trim());
        Node y = pag.getNode(yAndConditioning[0].trim());
        ArrayList<Node> z = new ArrayList<>();
        if (yAndConditioning.length > 1) {
            for (String s : LIST_SEPARATOR.split(yAndConditioning[1].trim())) {
                z.add(pag.getNode(s.trim()));
            }
        }
        return new IndependenceFact(x, y, z);
    }

    /**
     * Log score of {@code pag} under the independence-only model: the sum over facts in
     * {@code H} of ln H(fact) when the PAG d-separates the fact, else ln(1 - H(fact)).
     * Non-finite terms are skipped.
     *
     * @throws IllegalArgumentException if a probability lies outside [-1e-4, 1.0001] or is
     *                                  NaN/infinite
     */
    private double getLnProb(Graph pag, Map<IndependenceFact, Double> H) {
        double lnQ = 0.0;
        for (Map.Entry<IndependenceFact, Double> entry : H.entrySet()) {
            IndependenceFact fact = entry.getKey();
            boolean independent = pag.paths().isDSeparatedFrom(fact.getX(), fact.getY(), fact.getZ());
            double p = entry.getValue();
            checkProbability(p);
            if (!independent) {
                p = 1.0 - p;
            }
            lnQ = addLogSkippingNonFinite(lnQ, p);
        }
        return lnQ;
    }

    /** Throws if {@code p} is NaN, infinite, or outside [-1e-4, 1.0001]. */
    private static void checkProbability(double p) {
        if (p < -1.0E-4 || p > 1.0001 || Double.isNaN(p) || Double.isInfinite(p)) {
            throw new IllegalArgumentException("p illegally equals " + p);
        }
    }

    /** Returns {@code lnQ + ln(p)}, or {@code lnQ} unchanged when that sum is NaN or infinite. */
    private static double addLogSkippingNonFinite(double lnQ, double p) {
        double v = lnQ + FastMath.log(p);
        return Double.isNaN(v) || Double.isInfinite(v) ? lnQ : v;
    }

    /**
     * Converts log scores to probabilities that sum to 1 by subtracting the log of their total
     * and exponentiating. An empty input yields an empty map.
     */
    private Map<Graph, Double> normalProbs(Map<Graph, Double> pagLnProbs) {
        double lnTotal = lnQTotal(pagLnProbs);
        HashMap<Graph, Double> normalized = new HashMap<>();
        for (Map.Entry<Graph, Double> entry : pagLnProbs.entrySet()) {
            normalized.put(entry.getKey(), FastMath.exp(entry.getValue() - lnTotal));
        }
        return normalized;
    }

    /**
     * Parses {@code <directory>/xdsl/<filename>} into a Bayes IM.
     *
     * @throws RuntimeException wrapping any I/O or XML parsing failure
     */
    private BayesIm loadBayesIm(String filename, boolean useDisplayNames) {
        try {
            Builder builder = new Builder();
            File dir = new File(this.directory + "/xdsl");
            Document document = builder.build(new File(dir, filename));
            XdslXmlParser parser = new XdslXmlParser();
            parser.setUseDisplayNames(useDisplayNames);
            return parser.getBayesIm(document.getRootElement());
        } catch (IOException | ParsingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Computes ln(X + Y) from lnX and lnY without leaving log space. When the smaller term is
     * more than {@code -MINIMUM_EXPONENT} below the larger one, the larger is returned as-is.
     * Both inputs being -infinity (X = Y = 0) yields -infinity.
     */
    protected double lnXplusY(double lnX, double lnY) {
        if (lnX == Double.NEGATIVE_INFINITY && lnY == Double.NEGATIVE_INFINITY) {
            return Double.NEGATIVE_INFINITY; // ln(0 + 0); the general formula would give NaN
        }
        if (lnY > lnX) {
            double temp = lnX;
            lnX = lnY;
            lnY = temp;
        }
        double lnYminusLnX = lnY - lnX;
        if (lnYminusLnX < MINIMUM_EXPONENT) {
            return lnX;
        }
        return FastMath.log1p(FastMath.exp(lnYminusLnX)) + lnX;
    }

    /** Log of the sum of exp(value) over all values; -infinity for an empty map. */
    private double lnQTotal(Map<Graph, Double> pagLnProb) {
        double total = Double.NEGATIVE_INFINITY;
        boolean first = true;
        for (double lnQ : pagLnProb.values()) {
            total = first ? lnQ : lnXplusY(total, lnQ);
            first = false;
        }
        return total;
    }

    /**
     * Returns one bootstrap sample of {@code bootsrapSampleSize} rows drawn from {@code data}.
     *
     * @param numBootstrapSamples currently unused; exactly one sample is returned
     */
    public DataSet bootStrapSampling(DataSet data, int numBootstrapSamples, int bootsrapSampleSize) {
        return DataUtils.getBootstrapSample(data, bootsrapSampleSize);
    }

    /** Per-PAG log scores under the dependency-model and independence-only variants. */
    private static final class AllScores {
        final Map<Graph, Double> lnDep;
        final Map<Graph, Double> lnInd;

        AllScores(Map<Graph, Double> lnDep, Map<Graph, Double> lnInd) {
            this.lnDep = lnDep;
            this.lnInd = lnInd;
        }
    }

    /** Map helpers. */
    private static final class MapUtil {
        private MapUtil() {
        }

        /**
         * Returns a copy of {@code map} that iterates in descending value order; entries with
         * equal values keep the source map's iteration order because the sort is stable.
         */
        public static <K, V extends Comparable<? super V>> Map<K, V> sortByValue(Map<K, V> map) {
            List<Map.Entry<K, V>> entries = new ArrayList<>(map.entrySet());
            entries.sort(Map.Entry.<K, V>comparingByValue().reversed());
            Map<K, V> result = new LinkedHashMap<>();
            for (Map.Entry<K, V> entry : entries) {
                result.put(entry.getKey(), entry.getValue());
            }
            return result;
        }
    }
}

