/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IndTestFisherZGeneralizedInverse;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.RandomUtil;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;

/**
 * Tests for {@link IndTestFisherZGeneralizedInverse} and for {@link MatrixUtils#ginverse}.
 *
 * <p>The random inputs (SEM parameters, simulated data) come from unseeded random sources.
 * The tests therefore assert only properties that hold for every draw. Everything else is
 * printed for inspection.
 */
public class TestIndTestFisherZD extends TestCase {

    /** Significance level handed to every independence test constructed here. */
    private static final double ALPHA = 0.05;

    public TestIndTestFisherZD(String name) {
        super(name);
    }

    /**
     * Simulates data from a randomly parameterized SEM over the graph built by
     * {@link #constructGraph()}. It then tests the first and third test variables for
     * unconditional independence.
     *
     * <p>The independence verdict is printed only, because the SEM coefficients are random.
     * The test does assert that the reported p-value is a probability.
     */
    public void testIsIndependent() {
        Graph graph = constructGraph();
        SemIm semIm = new SemIm(new SemPm(graph));
        System.out.println("Original SemIm: " + semIm);
        DataSet dataSet = semIm.simulateData(400, false);

        IndTestFisherZGeneralizedInverse test = new IndTestFisherZGeneralizedInverse(dataSet, ALPHA);
        List<Node> variables = test.getVariables();
        // Presumably X0 and X2 if variable order follows graph insertion order -- TODO confirm.
        Node x = variables.get(0);
        Node y = variables.get(2);
        List<Node> conditioningSet = new ArrayList<Node>();

        boolean independent = test.isIndependent(x, y, conditioningSet);
        System.out.println(independent);

        // Holds regardless of the random SEM parameters; a NaN or out-of-range value is a defect.
        double p = test.getPValue();
        assertTrue("p-value out of range: " + p, p >= 0.0 && p <= 1.0);
    }

    /**
     * Builds a graph over five nodes X0..X4 containing the chain X0 -> X1 -> X2.
     * X3 and X4 are isolated.
     */
    private Graph constructGraph() {
        EdgeListGraph graph = new EdgeListGraph();
        Node[] nodes = new Node[5];
        for (int k = 0; k < nodes.length; ++k) {
            nodes[k] = new GraphNode("X" + k);
            graph.addNode(nodes[k]);
        }
        graph.addDirectedEdge(nodes[0], nodes[1]);
        graph.addDirectedEdge(nodes[1], nodes[2]);
        return graph;
    }

    /**
     * Checks the defining Moore-Penrose identity {@code X * G * X == X} for
     * {@code G = ginverse(X)}. The input X is a 4x2 matrix of full column rank.
     */
    public void testGInverse() {
        DenseDoubleMatrix2D x = new DenseDoubleMatrix2D(new double[][]{
                {2.0, 2.0}, {4.0, 6.0}, {3.0, -5.0}, {3.0, -5.0}});
        DoubleMatrix2D g = MatrixUtils.ginverse(x);
        Algebra algebra = new Algebra();
        DoubleMatrix2D prod = algebra.mult(x, algebra.mult(g, x));
        System.out.println("Prod = " + prod);

        assertEquals(x.rows(), prod.rows());
        assertEquals(x.columns(), prod.columns());
        for (int r = 0; r < x.rows(); ++r) {
            for (int c = 0; c < x.columns(); ++c) {
                assertEquals("X*G*X differs from X at (" + r + ", " + c + ")",
                        x.get(r, c), prod.get(r, c), 1.0E-8);
            }
        }
    }

    /**
     * Runs 200 trials on data that mixes two linear regimes row by row. For each trial it
     * prints whether X2 and X6 are judged independent given X4, along with the p-value.
     *
     * <p>In each row, with probability 1/2:
     * <ul>
     *   <li>X2 = a1 * X4 + a2, and X6 is independent noise; otherwise</li>
     *   <li>X2 is independent noise, and X6 = a3 * X4 + a4.</li>
     * </ul>
     *
     * <p>The coefficients are drawn per trial: a1 from [-20, 20) with |a1| >= 0.4, a2 and a4
     * from [-30, 30), and a3 from [-1e-5, 1e-5). X4 and the noise terms lie in [-100, 100).
     * These ranges assume {@code RandomUtil.nextDouble()} returns values in [0, 1).
     */
    public void testSplitDetermination() {
        NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
        int sampleSize = 1000;
        System.out.println("trial\ta1\ta2\ta3\ta4\tcases\tp\tind");

        for (int trial = 1; trial <= 200; ++trial) {
            // Rejection-sample a1 so that the X4 -> X2 coefficient stays away from zero.
            double a1;
            do {
                a1 = symmetricDraw(20.0);
            } while (Math.abs(a1) < 0.4);
            double a2 = symmetricDraw(30.0);
            double a3 = symmetricDraw(1.0E-5);
            double a4 = symmetricDraw(30.0);

            Node x2 = new ContinuousVariable("X2");
            Node x4 = new ContinuousVariable("X4");
            Node x6 = new ContinuousVariable("X6");
            List<Node> variables = new ArrayList<Node>();
            variables.add(x2);
            variables.add(x4);
            variables.add(x6);
            ColtDataSet dataSet = new ColtDataSet(sampleSize, variables);

            int colX2 = dataSet.getVariables().indexOf(x2);
            int colX4 = dataSet.getVariables().indexOf(x4);
            int colX6 = dataSet.getVariables().indexOf(x6);

            for (int row = 0; row < sampleSize; ++row) {
                double d1 = (RandomUtil.getInstance().nextDouble() - 0.5) * 200.0;
                double d2 = (RandomUtil.getInstance().nextDouble() - 0.5) * 200.0;
                dataSet.setDouble(row, colX4, d1);
                if (RandomUtil.getInstance().nextDouble() >= 0.5) {
                    dataSet.setDouble(row, colX2, a1 * d1 + a2);
                    dataSet.setDouble(row, colX6, d2);
                } else {
                    dataSet.setDouble(row, colX2, d2);
                    dataSet.setDouble(row, colX6, a3 * d1 + a4);
                }
            }

            IndTestFisherZGeneralizedInverse test = new IndTestFisherZGeneralizedInverse(dataSet, ALPHA);
            boolean independent = test.isIndependent(x2, x6, Collections.singletonList(x4));

            System.out.println(trial + "\t"
                    + nf.format(a1) + "\t"
                    + nf.format(a2) + "\t"
                    + nf.format(a3) + "\t"
                    + nf.format(a4) + "\t"
                    + sampleSize + "\t"
                    + nf.format(test.getPValue()) + "\t"
                    + independent);
        }
    }

    /**
     * Returns {@code u * 2 * limit - limit} for one draw {@code u} of
     * {@code RandomUtil.nextDouble()}. If {@code u} lies in [0, 1), the result lies in
     * [-limit, limit).
     */
    private static double symmetricDraw(double limit) {
        return RandomUtil.getInstance().nextDouble() * 2.0 * limit - limit;
    }

    public static Test suite() {
        return new TestSuite(TestIndTestFisherZD.class);
    }
}

