/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphConverter;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.search.Ges;
import edu.cmu.tetrad.search.SearchGraphUtils;
import edu.cmu.tetrad.sem.SemEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.TetradLogger;
import java.text.DecimalFormat;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;

/**
 * JUnit 3 tests for {@link Ges}.
 *
 * <p>JUnit 3's {@link TestSuite} only collects public no-arg methods whose names
 * begin with {@code test}. The {@code rtest...} methods below are therefore kept
 * for manual runs and are not executed by {@link #suite()}.
 */
public class TestGes
extends TestCase {
    public TestGes(String name) {
        super(name);
    }

    /**
     * Routes Tetrad log output to standard out and forces logging on.
     *
     * <p>NOTE(review): this adds {@code System.out} to the logger before every
     * test and nothing here removes it again; confirm whether the logger
     * de-duplicates output streams.
     */
    @Override
    public void setUp() throws Exception {
        TetradLogger.getInstance().addOutputStream(System.out);
        TetradLogger.getInstance().setForceLog(true);
    }

    /** Intentionally empty placeholder test. */
    public void testBlank() {
    }

    /** Diamond X1 -> {X2, X3} -> X4: only the collider at X4 should be oriented. */
    public void rtestSearch1() {
        checkSearch("X1-->X2,X1-->X3,X2-->X4,X3-->X4", "X1---X2,X1---X3,X2-->X4,X3-->X4");
    }

    /** Five-variable graph; A---B is expected to remain unoriented in the pattern. */
    public void rtestSearch2() {
        checkSearch("A-->D,A-->B,B-->D,C-->D,D-->E", "A-->D,A---B,B-->D,C-->D,D-->E");
    }

    /**
     * Runs GES with background knowledge that forbids B-D in both directions and C-->B.
     *
     * <p>NOTE(review): the expected output contains adjacencies (C---A, C-->D, A-->D)
     * that are absent from the generating graph. Presumably these compensate for the
     * forbidden B-D edge; confirm the expected pattern.
     */
    public void rtestSearch3() {
        Knowledge knowledge = new Knowledge();
        knowledge.setEdgeForbidden("B", "D", true);
        knowledge.setEdgeForbidden("D", "B", true);
        knowledge.setEdgeForbidden("C", "B", true);
        checkWithKnowledge("A-->B,C-->B,B-->D", "A---B,C---A,B-->C,C-->D,A-->D", knowledge);
    }

    /**
     * Compares the SEM-estimated BIC of a DAG chosen from the GES pattern with
     * GES's own score for that same DAG.
     */
    public void rtestSearch3_5() {
        Dag dag = GraphUtils.randomDag(20, 0, 20, 5, 5, 5, false);
        System.out.println(dag);

        SemPm pm = new SemPm(dag);
        SemIm im = new SemIm(pm);
        DataSet dataSet = im.simulateData(100, false);

        Ges ges = new Ges(dataSet);
        Graph graph = ges.search();
        System.out.println(graph);

        // Pick one DAG consistent with the GES output pattern and fit a SEM to it.
        Graph dag2 = SearchGraphUtils.chooseDagInPattern(graph);
        SemPm pm2 = new SemPm(dag2);
        SemEstimator est = new SemEstimator(dataSet, pm2);
        est.estimate();
        SemIm estIm = est.getEstimatedSem();

        double estBicScore = estIm.getBicScore();
        System.out.println("Estimate BIC = " + estBicScore);

        // Score the same DAG that was estimated above. The original scored the true
        // generating DAG, so the printed ratio compared two different models.
        double gesBicScore = ges.scoreGraph(dag2);
        System.out.println("GES score = " + gesBicScore);
        System.out.println("bic / ges = " + gesBicScore / estBicScore);
    }

    /** Single GES run on a random 40-variable DAG, reporting adjacency errors. */
    public void rtestSearch4() {
        int numVars = 40;
        int numEdges = 40;
        int sampleSize = 200;

        Dag trueGraph = GraphUtils.randomDag(numVars, 0, numEdges, 7, 5, 5, false);
        System.out.println("\nInput graph:");
        System.out.println(trueGraph);

        SemPm pm = new SemPm(trueGraph);
        SemIm im = new SemIm(pm);
        DataSet dataSet = im.simulateData(sampleSize, false);

        Ges ges = new Ges(dataSet);
        ges.setTrueGraph(trueGraph);
        Graph pattern = ges.search();
        System.out.println("\nResult graph:");
        System.out.println(pattern);

        int adjFp = GraphUtils.countAdjErrors(pattern, trueGraph);
        int adjFn = GraphUtils.countAdjErrors(trueGraph, pattern);
        System.out.println("adj fp = " + adjFp + " adjFn = " + adjFn);
    }

    /**
     * Repeats a GES run on fresh random 30-variable DAGs and prints per-run and
     * mean adjacency false-positive / false-negative counts.
     */
    public void rtestSearch4a() {
        int numVars = 30;
        int numEdges = 30;
        int sampleSize = 1000;
        int numIterations = 10;
        double sumFp = 0.0;
        double sumFn = 0.0;
        DecimalFormat nf = new DecimalFormat("0.00");

        System.out.println("\tADJ_FP\tADJ_FN");
        for (int count = 0; count < numIterations; ++count) {
            Dag trueGraph = GraphUtils.randomDag(numVars, 0, numEdges, 7, 5, 5, false);
            SemPm pm = new SemPm(trueGraph);
            SemIm im = new SemIm(pm);
            DataSet dataSet = im.simulateData(sampleSize, false);

            Ges ges = new Ges(dataSet);
            ges.setTrueGraph(trueGraph);
            Graph pattern = ges.search();

            int adjFp = GraphUtils.countAdjErrors(pattern, trueGraph);
            int adjFn = GraphUtils.countAdjErrors(trueGraph, pattern);
            sumFp += adjFp;
            sumFn += adjFn;
            System.out.println((count + 1) + "\t" + adjFp + "\t" + adjFn);
        }

        double avgFp = sumFp / numIterations;
        double avgFn = sumFn / numIterations;
        System.out.println("Means\t" + nf.format(avgFp) + "\t" + nf.format(avgFn));
    }

    /** Smoke test: GES on a random 10-variable DAG with a large simulated sample. */
    public void testSearch5() {
        int numVars = 10;
        int numEdges = 20;
        int sampleSize = 20000;

        Dag trueGraph = GraphUtils.randomDag(numVars, 0, numEdges, 7, 5, 5, false);
        System.out.println("\nInput graph:");
        System.out.println(trueGraph);
        System.out.println("********** SAMPLE SIZE = " + sampleSize);

        SemPm semPm = new SemPm(trueGraph);
        SemIm semIm = new SemIm(semPm);
        DataSet dataSet = semIm.simulateData(sampleSize, false);

        Ges ges = new Ges(dataSet);
        ges.setTrueGraph(trueGraph);
        Graph resultGraph = ges.search();
        System.out.println("\nResult graph:");
        System.out.println(resultGraph);

        assertNotNull("GES returned no graph", resultGraph);
    }

    /** Smoke test: prints the true pattern next to the pattern GES recovers. */
    public void testSearch6() {
        Dag trueGraph = GraphUtils.randomDag(10, 10, false);
        int sampleSize = 1000;

        SemPm semPm = new SemPm(trueGraph);
        SemIm semIm = new SemIm(semPm);
        DataSet dataSet = semIm.simulateData(sampleSize, false);

        Ges ges = new Ges(dataSet);
        Graph pattern = ges.search();
        System.out.println("True graph = " + SearchGraphUtils.patternForDag(trueGraph));
        System.out.println("Pattern = " + pattern);

        assertNotNull("GES returned no pattern", pattern);
    }

    /**
     * Simulates 500 samples from a SEM over {@code inputGraph}, runs GES, and
     * asserts that the result equals {@code outputGraph}.
     *
     * @param inputGraph  generating graph in {@link GraphConverter} text form
     * @param outputGraph expected GES result in {@link GraphConverter} text form
     */
    private void checkSearch(String inputGraph, String outputGraph) {
        Graph graph = GraphConverter.convert(inputGraph);
        SemPm semPm = new SemPm(graph);
        SemIm semIm = new SemIm(semPm);
        DataSet dataSet = semIm.simulateData(500, false);

        Ges ges = new Ges(dataSet);
        ges.setTrueGraph(graph);
        Graph resultGraph = ges.search();
        Graph expected = GraphConverter.convert(outputGraph);

        System.out.println("\nInput graph:");
        System.out.println(graph);
        System.out.println("\nResult graph:");
        System.out.println(resultGraph);

        assertTrue("Expected " + expected + " but got " + resultGraph,
                resultGraph.equals(expected));
    }

    /**
     * Like {@link #checkSearch(String, String)}, but simulates 1000 samples and
     * runs GES under the given background knowledge.
     *
     * @param inputGraph  generating graph in {@link GraphConverter} text form
     * @param outputGraph expected GES result in {@link GraphConverter} text form
     * @param knowledge   background knowledge passed to {@link Ges#setKnowledge}
     */
    private void checkWithKnowledge(String inputGraph, String outputGraph, Knowledge knowledge) {
        Graph graph = GraphConverter.convert(inputGraph);
        SemPm semPm = new SemPm(graph);
        SemIm semIm = new SemIm(semPm);
        DataSet dataSet = semIm.simulateData(1000, false);

        Ges ges = new Ges(dataSet);
        ges.setKnowledge(knowledge);
        Graph resultGraph = ges.search();

        System.out.println(knowledge);
        System.out.println("Input graph:");
        System.out.println(graph);
        System.out.println("Result graph:");
        System.out.println(resultGraph);

        Graph expected = GraphConverter.convert(outputGraph);
        assertTrue("Expected " + expected + " but got " + resultGraph,
                resultGraph.equals(expected));
    }

    /** @return a suite containing every {@code test*} method of this class */
    public static Test suite() {
        return new TestSuite(TestGes.class);
    }
}

