/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphConverter;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Cpc;
import edu.cmu.tetrad.search.IndTestDSep;
import edu.cmu.tetrad.search.IndTestFisherZ;
import edu.cmu.tetrad.search.Pc;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.TetradLogger;
import java.text.NumberFormat;
import java.util.LinkedList;
import java.util.List;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;

public class TestCpc
extends TestCase {
    public TestCpc(String name) {
        super(name);
    }

    @Override
    public void setUp() throws Exception {
        TetradLogger.getInstance().addOutputStream(System.out);
        TetradLogger.getInstance().setForceLog(true);
    }

    @Override
    public void tearDown() {
        TetradLogger.getInstance().setForceLog(false);
        TetradLogger.getInstance().removeOutputStream(System.out);
    }

    public void testSearch1() {
        this.checkSearch("X1-->X2,X1-->X3,X2-->X4,X3-->X4", "X1---X2,X1---X3,X2-->X4,X3-->X4");
    }

    public void testSearch2() {
        this.checkSearch("A-->D,A-->B,B-->D,C-->D,D-->E", "A-->D,A---B,B-->D,C-->D,D-->E");
    }

    public void testSearch3() {
        Knowledge knowledge = new Knowledge();
        knowledge.setEdgeForbidden("B", "D", true);
        knowledge.setEdgeForbidden("D", "B", true);
        knowledge.setEdgeForbidden("C", "B", true);
        this.checkWithKnowledge("A-->B,C-->B,B-->D", "A-->B,C-->B,A-->D,C-->D", knowledge);
    }

    public void showInefficiency() {
        int numVars = 20;
        int numEdges = 20;
        int maxSample = 2000;
        boolean latentDataSaved = false;
        int increment = 1;
        Dag trueGraph = GraphUtils.randomDag(numVars, 0, numEdges, 7, 5, 5, false);
        System.out.println("\nInput graph:");
        System.out.println(trueGraph);
        SemPm semPm = new SemPm(trueGraph);
        SemIm semIm = new SemIm(semPm);
        DataSet _dataSet = semIm.simulateData(maxSample, latentDataSaved);
        Graph previousResult = null;
        for (int n = 3; n <= maxSample; n += increment) {
            int[] rows = new int[n];
            for (int i = 0; i < rows.length; ++i) {
                rows[i] = i;
            }
            DataSet dataSet = _dataSet.subsetRows(rows);
            IndTestFisherZ test = new IndTestFisherZ(dataSet, 0.05);
            Cpc search = new Cpc(test);
            Graph resultGraph = search.search();
            if (previousResult != null) {
                List<Edge> resultEdges = resultGraph.getEdges();
                List<Edge> previousEdges = previousResult.getEdges();
                LinkedList<Edge> addedEdges = new LinkedList<Edge>();
                for (Edge edge : resultEdges) {
                    if (previousEdges.contains(edge)) continue;
                    addedEdges.add(edge);
                }
                LinkedList<Edge> removedEdges = new LinkedList<Edge>();
                for (Edge edge : previousEdges) {
                    if (resultEdges.contains(edge)) continue;
                    removedEdges.add(edge);
                }
                if (!addedEdges.isEmpty() && !removedEdges.isEmpty()) {
                    System.out.println("\nn = " + n + ":");
                    if (!addedEdges.isEmpty()) {
                        System.out.println("Added: " + addedEdges);
                    }
                    if (!removedEdges.isEmpty()) {
                        System.out.println("Removed: " + removedEdges);
                    }
                }
            }
            previousResult = resultGraph;
        }
        System.out.println("Final graph = " + previousResult);
    }

    public void test7() {
        int numVars = 6;
        int numEdges = 6;
        Dag trueGraph = GraphUtils.randomDag(numVars, 0, numEdges, 7, 5, 5, false);
        System.out.println("\nInput graph:");
        System.out.println(trueGraph);
        SemPm semPm = new SemPm(trueGraph);
        SemIm semIm = new SemIm(semPm);
        DataSet _dataSet = semIm.simulateData(1000, false);
        IndTestFisherZ test = new IndTestFisherZ(_dataSet, 0.05);
        Cpc search = new Cpc(test);
        Graph resultGraph = search.search();
    }

    public void tripleAccuracy() {
        boolean success = false;
        boolean fail = false;
        boolean totBidirected = false;
        boolean numCyclic = false;
        int numCorrectColliders = 0;
        int numCorrectNoncolliders = 0;
        int numEstColliders = 0;
        int numEstNoncolliders = 0;
        int numEstAmbiguous = 0;
        for (int i = 0; i < 100; ++i) {
            int[] choice;
            TetradLogger.getInstance().log("info", "# " + (i + 1));
            Dag graph = GraphUtils.randomDag(20, 0, 20, 4, 4, 4, false);
            SemPm pm = new SemPm(graph);
            SemIm im = new SemIm(pm);
            DataSet dataSet = im.simulateData(1000, false);
            IndTestFisherZ test = new IndTestFisherZ(dataSet, 0.05);
            Cpc search = new Cpc(test);
            Graph graph2 = search.search();
            ChoiceGenerator cg = new ChoiceGenerator(graph.getNumNodes(), 3);
            List<Node> nodes = graph.getNodes();
            while ((choice = cg.next()) != null) {
                Node node0 = nodes.get(choice[0]);
                Node node1 = nodes.get(choice[1]);
                Node node2 = nodes.get(choice[2]);
                Node node02 = graph2.getNode(node0.getName());
                Node node12 = graph2.getNode(node1.getName());
                Node node22 = graph2.getNode(node2.getName());
                if (!graph2.isAdjacentTo(node02, node12) || !graph2.isAdjacentTo(node12, node22)) continue;
                if (graph2.isAmbiguousTriple(node02, node12, node22)) {
                    ++numEstAmbiguous;
                    continue;
                }
                if (graph2.isDefCollider(node02, node12, node22)) {
                    ++numEstColliders;
                    if (!graph.isDefCollider(node0, node1, node2)) continue;
                    ++numCorrectColliders;
                    continue;
                }
                ++numEstNoncolliders;
                if (!graph.isAdjacentTo(node0, node1) || !graph.isAdjacentTo(node1, node2) || graph.isDefCollider(node0, node1, node2)) continue;
                ++numCorrectNoncolliders;
            }
        }
        double percentCorrectColliders = 100.0 * ((double)numCorrectColliders / (double)numEstColliders);
        double percentCorrectNoncolliders = 100.0 * ((double)numCorrectNoncolliders / (double)numEstNoncolliders);
        double percentAmbiguousTriples = 100.0 * ((double)numEstAmbiguous / (double)(numEstColliders + numEstNoncolliders));
        NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
        TetradLogger.getInstance().log("info", "# estimated colliders = " + numEstColliders);
        TetradLogger.getInstance().log("info", "# estimated noncolliders = " + numEstNoncolliders);
        TetradLogger.getInstance().log("info", "# estimated ambiguous = " + numEstAmbiguous);
        TetradLogger.getInstance().log("info", "# correct colliders = " + numCorrectColliders);
        TetradLogger.getInstance().log("info", "# correct noncolliders = " + numCorrectNoncolliders);
        TetradLogger.getInstance().log("info", "% correct colliders = " + nf.format(percentCorrectColliders));
        TetradLogger.getInstance().log("info", "% correct noncolliders = " + nf.format(percentCorrectNoncolliders));
        TetradLogger.getInstance().log("info", "% ambiguous triples = " + nf.format(percentAmbiguousTriples));
    }

    private void checkSearch(String inputGraph, String outputGraph) {
        Graph graph = GraphConverter.convert(inputGraph);
        IndTestDSep independence = new IndTestDSep(graph);
        Pc pcSearch = new Pc(independence);
        Graph resultGraph = pcSearch.search();
        Graph trueGraph = GraphConverter.convert(outputGraph);
        System.out.println("\nInput graph:");
        System.out.println(graph);
        System.out.println("\nResult graph:");
        System.out.println(resultGraph);
        System.out.println("\nTrue graph:");
        System.out.println(trueGraph);
        TestCpc.assertTrue(((Object)resultGraph).equals(trueGraph));
    }

    private void checkWithKnowledge(String inputGraph, String outputGraph, Knowledge knowledge) {
        Graph graph = GraphConverter.convert(inputGraph);
        SemPm semPm = new SemPm(graph);
        SemIm semIM = new SemIm(semPm);
        DataSet dataSet = semIM.simulateData(1000, false);
        IndTestFisherZ independence = new IndTestFisherZ(dataSet, 0.05);
        Cpc cpc = new Cpc(independence);
        cpc.setKnowledge(knowledge);
        Graph resultGraph = cpc.search();
        GraphConverter.convert(outputGraph);
        System.out.println(knowledge);
        System.out.println("\nInput graph:");
        System.out.println(graph);
        System.out.println("\nResult graph:");
        System.out.println(resultGraph);
    }

    public static Test suite() {
        return new TestSuite(TestCpc.class);
    }
}

