/*
 * Decompiled with CFR 0.152.
 */
package edu.pitt.dbmi.algo.resampling;

import edu.cmu.tetrad.algcomparison.algorithm.Algorithm;
import edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm;
import edu.cmu.tetrad.algcomparison.score.ScoreWrapper;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.GraphTools;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.TetradLogger;
import edu.pitt.dbmi.algo.resampling.GeneralResamplingSearch;
import edu.pitt.dbmi.algo.resampling.ResamplingEdgeEnsemble;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;

public class GeneralResamplingTest {
    private final GeneralResamplingSearch resamplingSearch;
    private final ResamplingEdgeEnsemble edgeEnsemble;
    private ScoreWrapper scoreWrapper;
    private PrintStream out = System.out;
    private Parameters parameters;
    private Algorithm algorithm;
    private MultiDataSetAlgorithm multiDataSetAlgorithm;
    private List<Graph> graphs;
    private boolean verbose;
    private Knowledge knowledge = new Knowledge();
    private Graph externalGraph;

    public GeneralResamplingTest(DataSet data, Algorithm algorithm, int numberResampling, double percentResamplingSize, boolean resamplingWithReplacement, int edgeEnsemble, boolean addOriginalDataset) {
        this.algorithm = algorithm;
        this.resamplingSearch = new GeneralResamplingSearch(data, numberResampling);
        this.resamplingSearch.setPercentResampleSize(percentResamplingSize);
        this.resamplingSearch.setResamplingWithReplacement(resamplingWithReplacement);
        this.resamplingSearch.setAddOriginalDataset(addOriginalDataset);
        switch (edgeEnsemble) {
            case 1: {
                this.edgeEnsemble = ResamplingEdgeEnsemble.Preserved;
                break;
            }
            case 2: {
                this.edgeEnsemble = ResamplingEdgeEnsemble.Highest;
                break;
            }
            case 3: {
                this.edgeEnsemble = ResamplingEdgeEnsemble.Majority;
                break;
            }
            default: {
                throw new IllegalArgumentException("Expecting 1, 2, or 3.");
            }
        }
    }

    public GeneralResamplingTest(List<DataSet> dataSets, MultiDataSetAlgorithm multiDataSetAlgorithm, int numberResampling, double percentResamplingSize, boolean resamplingWithReplacement, int edgeEnsemble, boolean addOriginalDataset) {
        this.multiDataSetAlgorithm = multiDataSetAlgorithm;
        this.resamplingSearch = new GeneralResamplingSearch(dataSets, numberResampling);
        this.resamplingSearch.setPercentResampleSize(percentResamplingSize);
        this.resamplingSearch.setResamplingWithReplacement(resamplingWithReplacement);
        this.resamplingSearch.setAddOriginalDataset(addOriginalDataset);
        switch (edgeEnsemble) {
            case 1: {
                this.edgeEnsemble = ResamplingEdgeEnsemble.Preserved;
                break;
            }
            case 2: {
                this.edgeEnsemble = ResamplingEdgeEnsemble.Highest;
                break;
            }
            case 3: {
                this.edgeEnsemble = ResamplingEdgeEnsemble.Majority;
                break;
            }
            default: {
                throw new IllegalArgumentException("Expecting 1, 2, or 3.");
            }
        }
    }

    public static int[][] getAdjConfusionMatrix(Graph truth, Graph estimate) {
        EdgeListGraph complete = new EdgeListGraph(estimate.getNodes());
        complete.fullyConnect(Endpoint.TAIL);
        ArrayList<Edge> edges = new ArrayList<Edge>(complete.getEdges());
        int numEdges = edges.size();
        int[][] adjAr = new int[2][2];
        GeneralResamplingTest.countAdjConfMatrix(adjAr, edges, truth, estimate, 0, numEdges - 1);
        return adjAr;
    }

    private static void countAdjConfMatrix(int[][] adjAr, List<Edge> edges, Graph truth, Graph estimate, int start, int end) {
        if (start == end) {
            Edge edge = edges.get(start);
            Node n1 = truth.getNode(edge.getNode1().toString());
            Node n2 = truth.getNode(edge.getNode2().toString());
            Node nn1 = estimate.getNode(edge.getNode1().toString());
            Node nn2 = estimate.getNode(edge.getNode2().toString());
            int i = truth.getEdge(n1, n2) == null ? 0 : 1;
            int j = estimate.getEdge(nn1, nn2) == null ? 0 : 1;
            int[] nArray = adjAr[i];
            int n = j;
            nArray[n] = nArray[n] + 1;
        } else if (start < end) {
            int mid = (start + end) / 2;
            GeneralResamplingTest.countAdjConfMatrix(adjAr, edges, truth, estimate, start, mid);
            GeneralResamplingTest.countAdjConfMatrix(adjAr, edges, truth, estimate, mid + 1, end);
        }
    }

    public static int[][] getEdgeTypeConfusionMatrix(Graph truth, Graph estimate) {
        EdgeListGraph complete = new EdgeListGraph(estimate.getNodes());
        complete.fullyConnect(Endpoint.TAIL);
        ArrayList<Edge> edges = new ArrayList<Edge>(complete.getEdges());
        int numEdges = edges.size();
        int[][] edgeAr = new int[8][8];
        GeneralResamplingTest.countEdgeTypeConfMatrix(edgeAr, edges, truth, estimate, 0, numEdges - 1);
        return edgeAr;
    }

    private static void countEdgeTypeConfMatrix(int[][] edgeAr, List<Edge> edges, Graph truth, Graph estimate, int start, int end) {
        if (start == end) {
            Edge eTT;
            Edge eAA;
            Edge eCC;
            Edge eAC;
            Edge eCA;
            Edge eAT;
            Edge edge = edges.get(start);
            Node n1 = truth.getNode(edge.getNode1().toString());
            Node n2 = truth.getNode(edge.getNode2().toString());
            Node nn1 = estimate.getNode(edge.getNode1().toString());
            Node nn2 = estimate.getNode(edge.getNode2().toString());
            int i = 0;
            int j = 0;
            Edge eTA = new Edge(n1, n2, Endpoint.TAIL, Endpoint.ARROW);
            if (truth.containsEdge(eTA)) {
                i = 1;
            }
            if (estimate.containsEdge(eTA = new Edge(nn1, nn2, Endpoint.TAIL, Endpoint.ARROW))) {
                j = 1;
            }
            if (truth.containsEdge(eAT = new Edge(n1, n2, Endpoint.ARROW, Endpoint.TAIL))) {
                i = 2;
            }
            if (estimate.containsEdge(eAT = new Edge(nn1, nn2, Endpoint.ARROW, Endpoint.TAIL))) {
                j = 2;
            }
            if (truth.containsEdge(eCA = new Edge(n1, n2, Endpoint.CIRCLE, Endpoint.ARROW))) {
                i = 3;
            }
            if (estimate.containsEdge(eCA = new Edge(nn1, nn2, Endpoint.CIRCLE, Endpoint.ARROW))) {
                j = 3;
            }
            if (truth.containsEdge(eAC = new Edge(n1, n2, Endpoint.ARROW, Endpoint.CIRCLE))) {
                i = 4;
            }
            if (estimate.containsEdge(eAC = new Edge(nn1, nn2, Endpoint.ARROW, Endpoint.CIRCLE))) {
                j = 4;
            }
            if (truth.containsEdge(eCC = new Edge(n1, n2, Endpoint.CIRCLE, Endpoint.CIRCLE))) {
                i = 5;
            }
            if (estimate.containsEdge(eCC = new Edge(nn1, nn2, Endpoint.CIRCLE, Endpoint.CIRCLE))) {
                j = 5;
            }
            if (truth.containsEdge(eAA = new Edge(n1, n2, Endpoint.ARROW, Endpoint.ARROW))) {
                i = 6;
            }
            if (estimate.containsEdge(eAA = new Edge(nn1, nn2, Endpoint.ARROW, Endpoint.ARROW))) {
                j = 6;
            }
            if (truth.containsEdge(eTT = new Edge(n1, n2, Endpoint.TAIL, Endpoint.TAIL))) {
                i = 7;
            }
            if (estimate.containsEdge(eTT = new Edge(nn1, nn2, Endpoint.TAIL, Endpoint.TAIL))) {
                j = 7;
            }
            int[] nArray = edgeAr[i];
            int n = j;
            nArray[n] = nArray[n] + 1;
        } else if (start < end) {
            int mid = (start + end) / 2;
            GeneralResamplingTest.countEdgeTypeConfMatrix(edgeAr, edges, truth, estimate, start, mid);
            GeneralResamplingTest.countEdgeTypeConfMatrix(edgeAr, edges, truth, estimate, mid + 1, end);
        }
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    public PrintStream getOut() {
        return this.out;
    }

    public void setOut(PrintStream out) {
        this.out = out;
    }

    public void setParameters(Parameters parameters) {
        this.parameters = parameters;
        Object obj = parameters.get("printStream");
        if (obj instanceof PrintStream) {
            this.setOut((PrintStream)obj);
        }
    }

    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = new Knowledge(knowledge);
    }

    public void setExternalGraph(Graph externalGraph) {
        this.externalGraph = externalGraph;
    }

    public Graph search() {
        long start = MillisecondTimes.timeMillis();
        if (this.algorithm != null) {
            this.resamplingSearch.setAlgorithm(this.algorithm);
        } else {
            this.resamplingSearch.setMultiDataSetAlgorithm(this.multiDataSetAlgorithm);
        }
        boolean runParallel = true;
        this.resamplingSearch.setRunParallel(runParallel);
        this.resamplingSearch.setVerbose(this.verbose);
        this.resamplingSearch.setParameters(this.parameters);
        this.resamplingSearch.setScoreWrapper(this.scoreWrapper);
        if (!this.knowledge.isEmpty()) {
            this.resamplingSearch.setKnowledge(this.knowledge);
        }
        if (this.externalGraph != null) {
            this.resamplingSearch.setExternalGraph(this.externalGraph);
        }
        if (this.verbose) {
            if (this.algorithm != null) {
                this.out.println("Resampling on the " + this.algorithm.getDescription());
            } else if (this.multiDataSetAlgorithm != null) {
                this.out.println("Resampling on the " + this.multiDataSetAlgorithm.getDescription());
            }
        }
        this.graphs = this.resamplingSearch.search();
        int numNoGraphs = this.resamplingSearch.getNumNograph();
        TetradLogger.getInstance().forceLogMessage("Bootstrappiung: Number of searches that didn't return a graph = " + numNoGraphs);
        if (this.verbose) {
            this.out.println("Resampling number is : " + this.graphs.size());
        }
        long stop = MillisecondTimes.timeMillis();
        if (this.verbose) {
            this.out.println("Processing time of total resamplings : " + (double)(stop - start) / 1000.0 + " sec");
        }
        start = MillisecondTimes.timeMillis();
        Graph graph = GraphTools.createHighEdgeProbabilityGraph(this.graphs, this.edgeEnsemble);
        stop = MillisecondTimes.timeMillis();
        TetradLogger.getInstance().forceLogMessage("Final Resampling Search Result:");
        TetradLogger.getInstance().forceLogMessage(GraphUtils.graphToText(graph, false));
        TetradLogger.getInstance().forceLogMessage("probDistribution finished in " + (stop - start) + " ms");
        return graph;
    }

    public void setScoreWrapper(ScoreWrapper scoreWrapper) {
        this.scoreWrapper = scoreWrapper;
    }
}

