/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.calibration;

import edu.cmu.tetrad.calibration.BootstrapWorker;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.RandomGraph;
import edu.cmu.tetrad.search.BfciFoo;
import edu.cmu.tetrad.search.IndTestFisherZ;
import edu.cmu.tetrad.search.Rfci;
import edu.cmu.tetrad.search.SearchGraphUtils;
import edu.cmu.tetrad.search.SemBicScore;
import edu.cmu.tetrad.sem.LargeScaleSimulation;
import edu.cmu.tetrad.util.MillisecondTimes;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

public class DataForCalibration_RFCI {
    // Writer for the per-edge probability table; assigned in setOut and written by probDistribution.
    PrintWriter outProb;
    // Writer that receives the "Graph: ..." dump of the simulated DAG; assigned in setOut.
    private PrintWriter outGraph;
    // Writer that receives the "Pag:..." dump of the true PAG; assigned in setOut.
    private PrintWriter outPag;
    // Depth handed to the full-data search in main via setDepth; defaults to 5.
    public int depth = 5;

    /**
     * Command-line entry point for one calibration run.
     *
     * <p>Recognised options (each followed by a value): {@code -v} number of variables,
     * {@code -p} edges per node, {@code -c} number of cases, {@code -l} fraction of latent
     * confounders, {@code -b} number of bootstrap samples, {@code -s} seed index (used only in
     * output file names here), {@code -a} algorithm name, {@code -alpha} value stored in
     * {@code BootstrapWorker.alpha}, {@code -dir} base output directory. Unknown tokens are ignored.
     *
     * <p>The run simulates a random DAG and data, searches on the full data set, queues one
     * {@link BootstrapWorker} per bootstrap sample, and finally writes the probability table,
     * the DAG and the true PAG to the files opened by {@link #setOut}. If the probability file
     * already exists the run stops immediately with a warning.
     *
     * @param args command-line arguments as described above
     * @throws IOException declared for callers; kept from the original signature
     * @throws IllegalArgumentException if an option is the last token and therefore has no value
     * @throws NumberFormatException if a numeric option value cannot be parsed
     */
    public static void main(String[] args) throws IOException {
        String algorithm = "";
        double alpha = 0.05;
        double numLatentConfounders = 0.1;
        int numVars = 0;
        int numCases = 0;
        int edgesPerNode = 0;
        int numBootstrapSamples = 0;
        int seedIndex = -1;
        String data_path = System.getProperty("user.dir");
        System.out.println(Arrays.asList(args));

        // Every token is inspected as a potential flag (values are not skipped), as before.
        for (int i = 0; i < args.length; ++i) {
            switch (args[i]) {
                case "-v":
                    numVars = Integer.parseInt(optionValue(args, i));
                    break;
                case "-p":
                    edgesPerNode = Integer.parseInt(optionValue(args, i));
                    break;
                case "-c":
                    numCases = Integer.parseInt(optionValue(args, i));
                    break;
                case "-l":
                    numLatentConfounders = Double.parseDouble(optionValue(args, i));
                    break;
                case "-b":
                    numBootstrapSamples = Integer.parseInt(optionValue(args, i));
                    break;
                case "-s":
                    seedIndex = Integer.parseInt(optionValue(args, i));
                    break;
                case "-a":
                    algorithm = optionValue(args, i);
                    break;
                case "-alpha":
                    alpha = Double.parseDouble(optionValue(args, i));
                    break;
                case "-dir":
                    data_path = optionValue(args, i);
                    break;
                default:
                    break;
            }
        }

        DataForCalibration_RFCI dfc = new DataForCalibration_RFCI();
        int numEdges = numVars * edgesPerNode;
        boolean probFileExists = dfc.checkProbFileExists("RandomGraph", numVars, numEdges, numCases, numBootstrapSamples, algorithm, seedIndex, numLatentConfounders, alpha, data_path);
        if (probFileExists) {
            String dirname = data_path + "/CalibrationConstraintBased/" + algorithm + "/RandomGraph-Vars" + numVars + "-Edges" + numEdges + "-Cases" + numCases + "-BS" + numBootstrapSamples + "-H" + numLatentConfounders + "-a" + alpha;
            String probFileName = "probs_v" + numVars + "_e" + numEdges + "_c" + numCases + "_b" + numBootstrapSamples + "_" + seedIndex + ".txt";
            System.out.println("Warning: The program stopped because the Prob File already exists in the following path: \n" + dirname + "/" + probFileName);
            return;
        }

        // Random tag used only to pair the "Started!" and "Done!" console lines of this run.
        String configString = String.valueOf(FastMath.random());
        System.out.println(configString + ": Started!");
        int latentCount = (int)FastMath.floor(numLatentConfounders * (double)numVars);
        System.out.println("LV: " + latentCount);
        Graph dag = dfc.makeDAG(numVars, edgesPerNode, latentCount);
        System.out.println("Graph simulation done");
        Graph truePag = SearchGraphUtils.dagToPag(dag);
        System.out.println("true PAG construction Done!");
        truePag = GraphUtils.replaceNodes(truePag, dag.getNodes());
        LargeScaleSimulation simulator = new LargeScaleSimulation(dag);
        DataSet data = simulator.simulateDataReducedForm(numCases);
        data = DataUtils.restrictToMeasured(data);
        System.out.println("Data simulation done");
        System.out.println("Covariance matrix done");

        // Search on the complete data set.
        long time1 = MillisecondTimes.timeMillis();
        IndTestFisherZ test = new IndTestFisherZ(data, 0.001);
        SemBicScore score = new SemBicScore(data);
        score.setPenaltyDiscount(2.0);
        System.out.println("Starting search with all data");
        BfciFoo fci = new BfciFoo(test, score);
        fci.setVerbose(false);
        fci.setCompleteRuleSetUsed(true);
        fci.setDepth(dfc.depth);
        Graph estPag = fci.search();
        System.out.println("Search done with all data");
        long time2 = MillisecondTimes.timeMillis();
        System.out.println("Elapsed (running RFCI on the data): " + (time2 - time1) / 1000L + " sec");
        estPag = GraphUtils.replaceNodes(estPag, truePag.getNodes());

        // Queue one worker per bootstrap sample; the workers share the result list.
        System.out.println("Generating bootstrap samples from data");
        ArrayList<Graph> BNfromBootstrap = new ArrayList<Graph>();
        BootstrapWorker.alpha = alpha;
        BootstrapWorker.DFC = dfc;
        BootstrapWorker.truePag = truePag;
        BootstrapWorker.BootstrapNum = numBootstrapSamples;
        long start = MillisecondTimes.timeMillis();
        // Loop-invariant check hoisted out of the loop; it still only triggers when at least one
        // sample would have been drawn, matching the previous behaviour.
        if (numBootstrapSamples > 0 && !algorithm.equals("RFCI")) {
            System.out.println("invalid search algorithm");
            return;
        }
        for (int sample = 0; sample < numBootstrapSamples; ++sample) {
            DataSet bootstrapSample = dfc.bootStrapSampling(data, data.getNumRows());
            BootstrapWorker.addToWaitingList(new BootstrapWorker(bootstrapSample, BNfromBootstrap));
        }
        try {
            BootstrapWorker.executeThreads_and_wait();
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can observe it.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
        long stop = MillisecondTimes.timeMillis();
        System.out.println("Bootstrap finished in " + (stop - start) + " ms");

        System.out.println("Probability estimates...");
        EdgeFrequency frequency = new EdgeFrequency(BNfromBootstrap);
        try {
            boolean set = dfc.setOut("RandomGraph", numVars, numEdges, numCases, numBootstrapSamples, algorithm, seedIndex, numLatentConfounders, alpha, data_path);
            if (!set) {
                return;
            }
            start = MillisecondTimes.timeMillis();
            dfc.probDistribution(truePag, estPag, frequency, dfc.outProb, algorithm);
            stop = MillisecondTimes.timeMillis();
            System.out.println("probDistribution finished in " + (stop - start) + " ms");
            System.out.println("Writing Probs File: done!");
            dfc.print("Graph: " + dag, dfc.outGraph);
            dfc.print("\n\n", dfc.outGraph);
            dfc.print("Pag:" + truePag, dfc.outPag);
            dfc.print("\n\n", dfc.outPag);
        }
        finally {
            // Close whatever was opened, even when writing failed part-way through.
            closeIfOpen(dfc.outProb);
            closeIfOpen(dfc.outGraph);
            closeIfOpen(dfc.outPag);
        }
        System.out.println(configString + ": Done!");
    }

    /** Returns the token following the flag at {@code flagIndex}, or fails with a clear message. */
    private static String optionValue(String[] args, int flagIndex) {
        if (flagIndex + 1 >= args.length) {
            throw new IllegalArgumentException("Missing value for option " + args[flagIndex]);
        }
        return args[flagIndex + 1];
    }

    /** Closes {@code writer} when it is not {@code null}. */
    private static void closeIfOpen(PrintWriter writer) {
        if (writer != null) {
            writer.close();
        }
    }

    /**
     * Writes one CSV-like row per unordered node pair to {@code outP}.
     *
     * <p>Each row holds the two node names, the edge-type code (0-7) of the pair in
     * {@code trueBN}, the eight values returned by {@code frequency} for "no edge" and the seven
     * endpoint combinations, and finally the edge-type code of the pair in {@code gesOut}.
     * A header line naming the columns is written first.
     *
     * @param trueBN    graph supplying the node set and the "true" edge type column
     * @param gesOut    graph supplying the last column (estimated edge type)
     * @param frequency source of the per-edge probabilities
     * @param outP      destination writer; rows are silently dropped by {@code print} if it is null
     * @param algorithm label used for the last header column
     */
    private void probDistribution(Graph trueBN, Graph gesOut, EdgeFrequency frequency, PrintWriter outP, String algorithm) {
        this.print("A, B, 0-7, A  B, A --> B, B --> A, A o-> B, B o-> A, A o-o B, A <-> B, A --- B, " + algorithm + " \n", outP);

        // A fully connected graph over the true nodes enumerates every node pair exactly once.
        EdgeListGraph allPairs = new EdgeListGraph(trueBN.getNodes());
        allPairs.fullyConnect(Endpoint.TAIL);

        for (Edge pair : allPairs.getEdges()) {
            Node a = pair.getNode1();
            Node b = pair.getNode2();

            int trueType = DataForCalibration_RFCI.classifyEdge(trueBN, a, b);

            double pNone = frequency.getProbability(new Edge(a, b, Endpoint.NULL, Endpoint.NULL));
            double pAtoB = frequency.getProbability(new Edge(a, b, Endpoint.TAIL, Endpoint.ARROW));
            double pBtoA = frequency.getProbability(new Edge(a, b, Endpoint.ARROW, Endpoint.TAIL));
            double pACircleArrowB = frequency.getProbability(new Edge(a, b, Endpoint.CIRCLE, Endpoint.ARROW));
            double pBCircleArrowA = frequency.getProbability(new Edge(a, b, Endpoint.ARROW, Endpoint.CIRCLE));
            double pCircleCircle = frequency.getProbability(new Edge(a, b, Endpoint.CIRCLE, Endpoint.CIRCLE));
            double pBidirected = frequency.getProbability(new Edge(a, b, Endpoint.ARROW, Endpoint.ARROW));
            double pUndirected = frequency.getProbability(new Edge(a, b, Endpoint.TAIL, Endpoint.TAIL));

            int estType = DataForCalibration_RFCI.classifyEdge(gesOut, a, b);

            StringBuilder row = new StringBuilder();
            row.append(a).append(", ").append(b).append(", ").append(trueType);
            row.append(", ").append(pNone);
            row.append(", ").append(pAtoB);
            row.append(", ").append(pBtoA);
            row.append(", ").append(pACircleArrowB);
            row.append(", ").append(pBCircleArrowA);
            row.append(", ").append(pCircleCircle);
            row.append(", ").append(pBidirected);
            row.append(", ").append(pUndirected);
            row.append(", ").append(estType).append("\n");
            this.print(row.toString(), outP);
        }
    }

    /**
     * Maps the edge between {@code a} and {@code b} in {@code graph} to the 0-7 code used in the
     * output table: 0 none/unrecognised, 1 a-->b, 2 a<--b, 3 a o->b, 4 a<-o b, 5 a o-o b,
     * 6 a<->b, 7 a---b.
     */
    private static int classifyEdge(Graph graph, Node a, Node b) {
        Edge edge = graph.getEdge(a, b);
        if (edge == null) {
            return 0;
        }
        Endpoint first = edge.getEndpoint1();
        Endpoint second = edge.getEndpoint2();
        if (first == Endpoint.TAIL) {
            if (second == Endpoint.ARROW) {
                return 1;
            }
            return second == Endpoint.TAIL ? 7 : 0;
        }
        if (first == Endpoint.ARROW) {
            if (second == Endpoint.TAIL) {
                return 2;
            }
            if (second == Endpoint.CIRCLE) {
                return 4;
            }
            return second == Endpoint.ARROW ? 6 : 0;
        }
        if (first == Endpoint.CIRCLE) {
            if (second == Endpoint.ARROW) {
                return 3;
            }
            return second == Endpoint.CIRCLE ? 5 : 0;
        }
        return 0;
    }

    /**
     * Builds a random graph over {@code numVars} continuous variables named "0", "1", ....
     *
     * @param numVars              number of variables to create
     * @param edgesPerNode         multiplied by {@code numVars} and truncated to get the edge count
     * @param numLatentConfounders forwarded unchanged to the random-graph generator
     * @return the graph produced by {@code RandomGraph.randomGraphRandomForwardEdges}
     */
    public Graph makeDAG(int numVars, double edgesPerNode, int numLatentConfounders) {
        final int edgeCount = (int)(edgesPerNode * (double)numVars);

        System.out.println("Making list of vars");
        ArrayList<Node> variables = new ArrayList<Node>();
        int index = 0;
        while (index < numVars) {
            variables.add(new ContinuousVariable(Integer.toString(index)));
            index++;
        }

        System.out.println("Making dag");
        // The numeric limits (30, 15, 15) and the two flags are fixed generator settings.
        Graph generated = RandomGraph.randomGraphRandomForwardEdges(variables, numLatentConfounders, edgeCount, 30, 15, 15, false, true);
        return generated;
    }

    /**
     * Draws a bootstrap sample from {@code data} by delegating to
     * {@code DataUtils.getBootstrapSample}.
     *
     * @param data               data set to resample
     * @param bootsrapSampleSize size argument forwarded to the sampler
     * @return the resampled data set
     */
    public DataSet bootStrapSampling(DataSet data, int bootsrapSampleSize) {
        DataSet resampled = DataUtils.getBootstrapSample(data, bootsrapSampleSize);
        return resampled;
    }

    /**
     * Runs RFCI on one bootstrap sample and returns the estimated graph with its nodes replaced
     * by the node objects of {@code truePag}.
     *
     * <p>The independence test is Fisher Z with a fixed significance level of 0.001.
     * NOTE(review): the {@code -alpha} command-line value is not used here — confirm that the
     * fixed level is intended.
     *
     * @param bootstrapSample data to search over
     * @param depth           depth passed to {@code Rfci.setDepth}
     * @param truePag         graph whose nodes replace those of the search result
     * @return the estimated graph expressed over {@code truePag}'s nodes
     */
    public Graph learnBNRFCI(DataSet bootstrapSample, int depth, Graph truePag) {
        // RFCI only needs the independence test; the score object that used to be built here
        // was never passed to the search, so it is no longer constructed for every sample.
        IndTestFisherZ test = new IndTestFisherZ(bootstrapSample, 0.001);
        System.out.println("Starting search with a bootstrap");
        Rfci fci = new Rfci(test);
        fci.setVerbose(false);
        fci.setDepth(depth);
        Graph estPag = fci.search();
        estPag = GraphUtils.replaceNodes(estPag, truePag.getNodes());
        System.out.println("Search done with a bootstrap");
        return estPag;
    }

    /**
     * Reports whether the probability file for the given configuration is already on disk.
     *
     * <p>The output directory for the configuration is created as a side effect (via
     * {@code mkdirs}) before the file is looked up.
     *
     * @param i index appended to the file name
     * @return {@code true} if {@code probs_v..._e..._c..._b..._i.txt} exists in that directory
     */
    public boolean checkProbFileExists(String modelName, int numVars, int numEdges, int numCases, int numBootstrapSamples, String alg, int i, double numLatentConfounders, double alpha, String data_path) {
        StringBuilder dirPath = new StringBuilder();
        dirPath.append(data_path).append("/CalibrationConstraintBased/").append(alg).append("/").append(modelName);
        dirPath.append("-Vars").append(numVars);
        dirPath.append("-Edges").append(numEdges);
        dirPath.append("-Cases").append(numCases);
        dirPath.append("-BS").append(numBootstrapSamples);
        dirPath.append("-H").append(numLatentConfounders);
        dirPath.append("-a").append(alpha);

        File outputDir = new File(dirPath.toString());
        outputDir.mkdirs();

        StringBuilder fileName = new StringBuilder("probs_v");
        fileName.append(numVars);
        fileName.append("_e").append(numEdges);
        fileName.append("_c").append(numCases);
        fileName.append("_b").append(numBootstrapSamples);
        fileName.append("_").append(i).append(".txt");

        return new File(outputDir, fileName.toString()).exists();
    }

    /**
     * Creates the output directory for the given configuration and opens the three output
     * writers (probability table, DAG dump, PAG dump).
     *
     * <p>The writers are published to the fields only when all three could be opened. If any of
     * them fails, the ones already opened are closed again so no file handle is leaked, and the
     * fields are left untouched.
     *
     * @param i index appended to each file name
     * @return {@code true} when all writers were opened; {@code false} when the probability file
     *         already exists or a file could not be opened (the stack trace is printed)
     */
    public boolean setOut(String modelName, int numVars, int numEdges, int numCases, int numBootstrapSamples, String alg, int i, double numLatentConfounders, double alpha, String data_path) {
        String dirname = data_path + "/CalibrationConstraintBased/" + alg + "/" + modelName + "-Vars" + numVars + "-Edges" + numEdges + "-Cases" + numCases + "-BS" + numBootstrapSamples + "-H" + numLatentConfounders + "-a" + alpha;
        File dir = new File(dirname);
        dir.mkdirs();
        String suffix = "_v" + numVars + "_e" + numEdges + "_c" + numCases + "_b" + numBootstrapSamples + "_" + i + ".txt";
        File probFile = new File(dir, "probs" + suffix);
        File graphFile = new File(dir, "BN" + suffix);
        File pagFile = new File(dir, "PAG" + suffix);
        // Never overwrite results of an earlier run.
        if (probFile.exists()) {
            return false;
        }

        PrintWriter prob = null;
        PrintWriter graph = null;
        try {
            prob = new PrintWriter(probFile);
            graph = new PrintWriter(graphFile);
            PrintWriter pag = new PrintWriter(pagFile);
            this.outProb = prob;
            this.outGraph = graph;
            this.outPag = pag;
            return true;
        }
        catch (FileNotFoundException e) {
            e.printStackTrace();
            // Release any writer that was opened before the failure.
            if (prob != null) {
                prob.close();
            }
            if (graph != null) {
                graph.close();
            }
            return false;
        }
    }

    /**
     * Writes {@code s} to {@code out}, flushing before and after the write.
     * Does nothing when {@code out} is {@code null}.
     */
    private void print(String s, PrintWriter out) {
        if (out != null) {
            out.flush();
            out.print(s);
            out.flush();
        }
    }

    /**
     * Estimates edge probabilities as relative frequencies over a list of graphs.
     * The list is held by reference, not copied.
     */
    private static class EdgeFrequency
    implements EdgeProbabiity {
        private final List<Graph> PagProbs;

        /**
         * @param PagProb graphs to count over; read on every call to {@link #getProbability}
         */
        public EdgeFrequency(List<Graph> PagProb) {
            this.PagProbs = PagProb;
        }

        /**
         * Returns the fraction of graphs that match {@code e}.
         *
         * <p>If either endpoint of {@code e} is {@code Endpoint.NULL} the query means "no edge":
         * a graph matches when the two nodes are not adjacent in it. Otherwise a graph matches
         * when {@code containsEdge(e)} is true.
         *
         * @param e the edge (or "no edge" marker) to look up
         * @return matches divided by the number of graphs
         * @throws IllegalStateException    if the graph list is empty
         * @throws IllegalArgumentException if a node of {@code e} is not in the first graph
         */
        @Override
        public double getProbability(Edge e) {
            if (this.PagProbs.isEmpty()) {
                // Previously surfaced as an opaque IndexOutOfBoundsException from get(0).
                throw new IllegalStateException("Cannot estimate edge probabilities: the list of bootstrap graphs is empty");
            }
            Node first = e.getNode1();
            Node second = e.getNode2();
            Graph reference = this.PagProbs.get(0);
            if (!reference.containsNode(first)) {
                throw new IllegalArgumentException("Node not present in the bootstrap graphs: " + first);
            }
            if (!reference.containsNode(second)) {
                throw new IllegalArgumentException("Node not present in the bootstrap graphs: " + second);
            }

            // The kind of query does not change per graph, so decide it once.
            boolean absenceQuery = e.getEndpoint1() == Endpoint.NULL || e.getEndpoint2() == Endpoint.NULL;
            int count = 0;
            for (Graph g : this.PagProbs) {
                boolean matches = absenceQuery ? !g.isAdjacentTo(first, second) : g.containsEdge(e);
                if (matches) {
                    ++count;
                }
            }
            return (double)count / (double)this.PagProbs.size();
        }
    }

    private static interface EdgeProbabiity {
        public double getProbability(Edge var1);
    }
}

