/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.sem.SemEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.sem.StandardizedSemIm;
import edu.cmu.tetrad.util.RandomUtil;
import java.util.List;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;

/**
 * Tests for {@link StandardizedSemIm}. They cover:
 * <ul>
 *   <li>construction from a {@link SemIm};</li>
 *   <li>the edge-coefficient and error-covariance ranges it reports;</li>
 *   <li>a round trip through a slider/value mapping for bounded and unbounded ranges.</li>
 * </ul>
 *
 * <p>Standardization is checked empirically by simulating data from the standardized model; see
 * {@link #isStandardized(StandardizedSemIm)}.
 */
public class TestStandardizedSem extends TestCase {

    /** pi / 2; the slider mapping sends half- or fully-unbounded ranges through tan/atan. */
    private static final double HALF_PI = Math.PI / 2.0;

    public TestStandardizedSem(String name) {
        super(name);
    }

    /**
     * Estimates a SEM from standardized data simulated from a random 5-node DAG. Checks that a
     * {@link StandardizedSemIm} wrapping the generating model is standardized.
     */
    public void test1() {
        SemGraph graph = new SemGraph(GraphUtils.randomDag(5, 5, false));
        SemPm pm = new SemPm(graph);
        SemIm im = new SemIm(pm);

        DataSet dataSet = im.simulateData(1000, false);
        DoubleMatrix2D standardizedData = DataUtils.standardizeData(dataSet.getDoubleData());
        ColtDataSet dataSetStandardized = ColtDataSet.makeData(dataSet.getVariables(), standardizedData);
        System.out.println(DataUtils.cov(standardizedData));
        System.out.println(DataUtils.mean(standardizedData));

        SemEstimator estimator = new SemEstimator(dataSetStandardized, pm);
        SemIm imStandardized = estimator.estimate();
        System.out.println("Edge coef: " + imStandardized.getEdgeCoef());
        System.out.println("Error cover: " + imStandardized.getErrCovar());
        System.out.println("Variable means: " + new DenseDoubleMatrix1D(imStandardized.getMeans()));

        // Print the generating model's parameters before and after wrapping it, so any change the
        // StandardizedSemIm constructor makes to `im` shows up in the output.
        System.out.println("Original edge coefficients: " + im.getEdgeCoef());
        System.out.println("Original error covariances: " + im.getErrCovar());
        StandardizedSemIm sem = new StandardizedSemIm(im);
        System.out.println("Edge coefficients after construction: " + im.getEdgeCoef());
        System.out.println("Error covariances after construction: " + im.getErrCovar());

        assertTrue(isStandardized(sem));
    }

    /**
     * Standardizes a fixed 5-variable model with error terms shown. The model includes a collider
     * at X3 and a node (X4) with three parents.
     */
    public void test2() {
        RandomUtil.getInstance().setSeed(5729384723L);
        SemGraph graph = new SemGraph();
        ContinuousVariable x1 = new ContinuousVariable("X1");
        ContinuousVariable x2 = new ContinuousVariable("X2");
        ContinuousVariable x3 = new ContinuousVariable("X3");
        ContinuousVariable x4 = new ContinuousVariable("X4");
        ContinuousVariable x5 = new ContinuousVariable("X5");
        graph.addNode(x1);
        graph.addNode(x2);
        graph.addNode(x3);
        graph.addNode(x4);
        graph.addNode(x5);
        graph.setShowErrorTerms(true);

        graph.addDirectedEdge(x1, x2);
        graph.addDirectedEdge(x2, x3);
        graph.addDirectedEdge(x4, x3);
        graph.addDirectedEdge(x2, x4);
        graph.addDirectedEdge(x1, x4);
        graph.addDirectedEdge(x5, x4);
        System.out.println(graph);

        SemPm pm = new SemPm(graph);
        SemIm im = new SemIm(pm);
        StandardizedSemIm sem = new StandardizedSemIm(im);
        System.out.println(sem);
        assertTrue(isStandardized(sem));
    }

    /**
     * Uses a 3-variable model (X1 -> X2, X1 -> X3, X2 -> X3). Checks that {@code setEdgeCoefficient}
     * rejects X1 -> X2 coefficients of 1.2 and 1.5, accepts 0.6, and accepts -0.1 for X1 -> X3.
     * The model must still be standardized afterwards.
     */
    public void test3() {
        RandomUtil.getInstance().setSeed(582374923L);
        SemGraph graph = new SemGraph();
        ContinuousVariable x1 = new ContinuousVariable("X1");
        ContinuousVariable x2 = new ContinuousVariable("X2");
        ContinuousVariable x3 = new ContinuousVariable("X3");
        graph.addNode(x1);
        graph.addNode(x2);
        graph.addNode(x3);
        graph.setShowErrorTerms(true);
        graph.addDirectedEdge(x1, x3);
        graph.addDirectedEdge(x2, x3);
        graph.addDirectedEdge(x1, x2);

        SemPm pm = new SemPm(graph);
        SemIm im = new SemIm(pm);
        System.out.println(im);
        StandardizedSemIm sem = new StandardizedSemIm(im);
        System.out.println(sem);

        DataSet data = sem.simulateData(5000, false);
        System.out.println(sem.getVariableNodes());
        System.out.println(DataUtils.cov(data.getDoubleData()));
        System.out.println(sem.getCoefficientRange(x1, x2));

        assertFalse(sem.setEdgeCoefficient(x1, x2, 1.2));
        assertFalse(sem.setEdgeCoefficient(x1, x2, 1.5));
        assertTrue(sem.setEdgeCoefficient(x1, x2, 0.6));
        assertTrue(sem.setEdgeCoefficient(x1, x3, -0.1));
        System.out.println(sem);
        assertTrue(isStandardized(sem));
    }

    /**
     * Runs 20 trials on a random 10-node DAG. Each trial picks a random edge and sets a coefficient
     * drawn from inside its reported range, which must be accepted. It then tries values beyond
     * either end of the range, which must be rejected.
     */
    public void test4() {
        SemGraph graph = new SemGraph(GraphUtils.randomDag(10, 10, false));
        SemPm pm = new SemPm(graph);
        SemIm im = new SemIm(pm);
        StandardizedSemIm sem = new StandardizedSemIm(im);
        RandomUtil random = RandomUtil.getInstance();

        for (int trial = 0; trial < 20; trial++) {
            List<Edge> edges = graph.getEdges();
            Edge edge = edges.get(random.nextInt(edges.size()));
            Node a = edge.getNode1();
            Node b = edge.getNode2();

            StandardizedSemIm.ParameterRange range = sem.getCoefficientRange(a, b);
            double high = range.getHigh();
            double low = range.getLow();
            double width = high - low;

            assertTrue(sem.setEdgeCoefficient(a, b, low + random.nextDouble() * width));
            assertFalse(sem.setEdgeCoefficient(a, b, high + random.nextDouble() * width));
            assertFalse(sem.setEdgeCoefficient(a, b, low - random.nextDouble() * width));
        }
    }

    /**
     * Standardizes a collider model X1 -> X3 <- X2. Estimation on standardized simulated data is
     * also run, but only to confirm it completes; its result is not inspected.
     */
    public void test5() {
        RandomUtil.getInstance().setSeed(582374923L);
        SemGraph graph = new SemGraph();
        graph.setShowErrorTerms(true);
        ContinuousVariable x1 = new ContinuousVariable("X1");
        ContinuousVariable x2 = new ContinuousVariable("X2");
        ContinuousVariable x3 = new ContinuousVariable("X3");
        graph.addNode(x1);
        graph.addNode(x2);
        graph.addNode(x3);
        graph.setShowErrorTerms(true);
        graph.addDirectedEdge(x1, x3);
        graph.addDirectedEdge(x2, x3);

        SemPm pm = new SemPm(graph);
        SemIm im = new SemIm(pm);
        DataSet dataSet = im.simulateDataRecursive(1000, false);
        DoubleMatrix2D standardizedData = DataUtils.standardizeData(dataSet.getDoubleData());
        ColtDataSet dataSetStandardized = ColtDataSet.makeData(dataSet.getVariables(), standardizedData);
        new SemEstimator(dataSetStandardized, im.getSemPm()).estimate();

        StandardizedSemIm sem = new StandardizedSemIm(im);
        System.out.println(sem);
        assertTrue(isStandardized(sem));
    }

    /**
     * Uses a 3-variable model whose X1 and X2 error terms are joined by a bidirected edge. Prints
     * the X1/X2 error covariance from three sources: the original model, a model estimated on
     * standardized data, and the StandardizedSemIm. Then checks that the last is standardized.
     */
    public void test6() {
        SemGraph graph = new SemGraph();
        graph.setShowErrorTerms(true);
        ContinuousVariable x1 = new ContinuousVariable("X1");
        ContinuousVariable x2 = new ContinuousVariable("X2");
        ContinuousVariable x3 = new ContinuousVariable("X3");
        graph.addNode(x1);
        graph.addNode(x2);
        graph.addNode(x3);
        graph.setShowErrorTerms(true);
        Node ex1 = graph.getExogenous(x1);
        Node ex2 = graph.getExogenous(x2);
        graph.addDirectedEdge(x1, x3);
        graph.addDirectedEdge(x2, x3);
        graph.addDirectedEdge(x1, x2);
        graph.addBidirectedEdge(ex1, ex2);

        SemPm pm = new SemPm(graph);
        SemIm im = new SemIm(pm);
        DataSet dataSet = im.simulateDataRecursive(1000, false);
        System.out.println("im " + im.getErrCovar(x1, x2));

        DoubleMatrix2D standardizedData = DataUtils.standardizeData(dataSet.getDoubleData());
        ColtDataSet dataSetStandardized = ColtDataSet.makeData(dataSet.getVariables(), standardizedData);
        SemEstimator estimator = new SemEstimator(dataSetStandardized, im.getSemPm());
        SemIm imStandardized = estimator.estimate();
        System.out.println(imStandardized);
        System.out.println("im st " + imStandardized.getErrCovar(x1, x2));

        StandardizedSemIm sem = new StandardizedSemIm(im);
        System.out.println("sem " + sem.getErrorCovariance(x1, x2));
        System.out.println(sem);
        assertTrue(isStandardized(sem));
    }

    /**
     * Adds one bidirected edge between two distinct random nodes of a random 5-node DAG. Every
     * directed edge's coefficient range and the bidirected edge's covariance range are then probed.
     */
    public void test7() {
        RandomUtil random = RandomUtil.getInstance();
        SemGraph graph = new SemGraph(GraphUtils.randomDag(5, 5, false));
        List<Node> nodes = graph.getNodes();

        int n1 = random.nextInt(nodes.size());
        int n2 = random.nextInt(nodes.size());
        while (n1 == n2) {
            n2 = random.nextInt(nodes.size());
        }
        Node node1 = nodes.get(n1);
        Node node2 = nodes.get(n2);
        Edge bidirected = Edges.bidirectedEdge(node1, node2);
        System.out.println(bidirected);
        graph.addEdge(bidirected);

        SemPm pm = new SemPm(graph);
        SemIm im = new SemIm(pm);
        StandardizedSemIm sem = new StandardizedSemIm(im);
        DataSet data3 = sem.simulateDataReducedForm(1000, false);
        System.out.println(new CovarianceMatrix(data3));

        for (Edge edge : graph.getEdges()) {
            Node a = edge.getNode1();
            Node b = edge.getNode2();
            if (Edges.isDirectedEdge(edge)) {
                checkEdgeCoefficientRange(sem, a, b, random);
            } else if (Edges.isBidirectedEdge(edge)) {
                System.out.println("covariance = " + sem.getErrorCovariance(a, b));
                // NOTE(review): the result of this set is ignored. If 0.15 is outside the allowed
                // range, the call presumably fails silently and the test carries on.
                sem.setErrorCovariance(node1, node2, 0.15);
                assertTrue(isStandardized(sem));
                checkErrorCovarianceRange(sem, a, b, random);
            }
        }
    }

    /**
     * Probes the coefficient range reported for the edge a -> b:
     * <ul>
     *   <li>asking for the range must leave the coefficient unchanged;</li>
     *   <li>a value drawn from inside the range must be accepted (the original value is then
     *       restored);</li>
     *   <li>values beyond either end of the range must be rejected.</li>
     * </ul>
     */
    private static void checkEdgeCoefficientRange(StandardizedSemIm sem, Node a, Node b, RandomUtil random) {
        double initial = sem.getEdgeCoefficient(a, b);
        StandardizedSemIm.ParameterRange range = sem.getCoefficientRange(a, b);
        assertEquals(initial, sem.getEdgeCoefficient(a, b));

        double low = range.getLow();
        double high = range.getHigh();
        double width = high - low;

        assertTrue(sem.setEdgeCoefficient(a, b, low + random.nextDouble() * width));
        sem.setEdgeCoefficient(a, b, initial);
        assertFalse(sem.setEdgeCoefficient(a, b, high + random.nextDouble() * width));
        assertFalse(sem.setEdgeCoefficient(a, b, low - random.nextDouble() * width));
    }

    /**
     * Probes the error-covariance range reported for a <-> b:
     * <ul>
     *   <li>a value drawn from inside the range must be accepted (the original value is then
     *       restored);</li>
     *   <li>values beyond each finite end must be rejected.</li>
     * </ul>
     * An infinite end is replaced by a +/-10000 stand-in so a value can be sampled, and no
     * out-of-range probe is made on that side.
     */
    private static void checkErrorCovarianceRange(StandardizedSemIm sem, Node a, Node b, RandomUtil random) {
        StandardizedSemIm.ParameterRange range = sem.getCovarianceRange(a, b);
        System.out.println(range);

        double rawLow = range.getLow();
        double rawHigh = range.getHigh();
        boolean boundedBelow = rawLow != Double.NEGATIVE_INFINITY;
        boolean boundedAbove = rawHigh != Double.POSITIVE_INFINITY;
        double low = boundedBelow ? rawLow : -10000.0;
        double high = boundedAbove ? rawHigh : 10000.0;
        double width = high - low;

        double initial = sem.getErrorCovariance(a, b);
        double coef = low + random.nextDouble() * width;
        System.out.println("Picked " + coef);
        assertTrue(sem.setErrorCovariance(a, b, coef));
        sem.setErrorCovariance(a, b, initial);

        if (boundedAbove) {
            assertFalse(sem.setErrorCovariance(a, b, high + random.nextDouble() * width));
        }
        if (boundedBelow) {
            assertFalse(sem.setErrorCovariance(a, b, low - random.nextDouble() * width));
        }
    }

    /**
     * Uses X -> Y with correlated X/Y errors, plus X -> Z and Y -> Z. Sets the X/Y error covariance
     * to a value drawn from its reported range (infinite ends replaced by +/-1000). The set must be
     * accepted and the model must remain standardized.
     */
    public void test8() {
        SemGraph graph = new SemGraph();
        ContinuousVariable x = new ContinuousVariable("X");
        ContinuousVariable y = new ContinuousVariable("Y");
        ContinuousVariable z = new ContinuousVariable("Z");
        graph.addNode(x);
        graph.addNode(y);
        graph.addNode(z);
        graph.addDirectedEdge(x, y);
        graph.addBidirectedEdge(x, y);
        graph.addDirectedEdge(x, z);
        graph.addDirectedEdge(y, z);
        graph.setShowErrorTerms(true);

        SemPm semPm = new SemPm(graph);
        SemIm semIm = new SemIm(semPm);
        System.out.println(semIm);
        StandardizedSemIm sem = new StandardizedSemIm(semIm, StandardizedSemIm.Initialization.CALCULATE_FROM_SEM);
        System.out.println(sem);

        DataSet data = semIm.simulateDataCholesky(1000, false);
        data = ColtDataSet.makeContinuousData(data.getVariables(), DataUtils.standardizeData(data.getDoubleData()));
        SemEstimator estimator = new SemEstimator(data, semPm);
        semIm = estimator.estimate();
        DataSet data2 = semIm.simulateDataReducedForm(1000, false);
        System.out.println(new CovarianceMatrix(data2));
        DataSet data3 = sem.simulateDataReducedForm(1000, false);
        System.out.println(new CovarianceMatrix(data3));

        StandardizedSemIm.ParameterRange range2 = sem.getCovarianceRange(x, y);
        System.out.println(range2);
        double high = range2.getHigh() == Double.POSITIVE_INFINITY ? 1000.0 : range2.getHigh();
        double low = range2.getLow() == Double.NEGATIVE_INFINITY ? -1000.0 : range2.getLow();

        double coef = low + RandomUtil.getInstance().nextDouble() * (high - low);
        System.out.println("Picked " + coef);
        assertTrue(sem.setErrorCovariance(x, y, coef));
        System.out.println(new CovarianceMatrix(data3));

        // JUnit assertion rather than a Java `assert`, which is skipped unless the JVM runs with -ea.
        assertTrue(isStandardized(sem));
    }

    /**
     * Checks standardization empirically. Simulates 5000 rows from {@code sem} and requires every
     * column to have a sample variance within 0.1 of 1 and a sample mean within 0.1 of 0.
     *
     * @return true if every column passes both checks; otherwise false, after printing the first
     *     offending value
     */
    private boolean isStandardized(StandardizedSemIm sem) {
        DataSet dataSet = sem.simulateData(5000, false);
        DoubleMatrix2D data = dataSet.getDoubleData();
        DoubleMatrix2D cov = DataUtils.cov(data);
        DoubleMatrix1D means = DataUtils.mean(data);

        for (int i = 0; i < cov.rows(); ++i) {
            double variance = cov.get(i, i);
            // Written as !(|d| < tol) so that a NaN statistic also counts as a failure.
            if (!(Math.abs(variance - 1.0) < 0.1)) {
                // Label the column with the data set's own i-th variable so the name matches column i.
                System.out.println("Variable " + dataSet.getVariables().get(i) + " variance not equal to 1: " + variance);
                return false;
            }
            if (!(Math.abs(means.get(i)) < 0.1)) {
                System.out.println("Mean not equal to 0:" + means.get(i));
                return false;
            }
        }
        return true;
    }

    /**
     * Every slider position 0..n must survive a slider -> value -> slider round trip. This is
     * checked for a bounded range, both half-bounded ranges, and the fully unbounded range.
     */
    public void testSliderValues() {
        int n = 100;
        double[][] bounds = {
                {-5.0, 5.0},
                {-5.0, Double.POSITIVE_INFINITY},
                {Double.NEGATIVE_INFINITY, 5.0},
                {Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY}
        };
        for (double[] bound : bounds) {
            for (int i = 0; i <= n; i++) {
                assertEquals("slider " + i + " on [" + bound[0] + ", " + bound[1] + "]",
                        i, sliderToSlider(i, bound[0], bound[1], n));
            }
        }
    }

    /** Maps a slider position to a value and back again. */
    private int sliderToSlider(int slider, double min, double max, int n) {
        double value = sliderToValue(slider, min, max, n);
        return valueToSlider(value, min, max, n);
    }

    /**
     * Maps a slider position in [0, n] to a value in [min, max].
     * <ul>
     *   <li>Bounded range: linear mapping.</li>
     *   <li>Bounded below only: {@code min + tan}, starting at min when slider = 0.</li>
     *   <li>Bounded above only: {@code max + tan}, ending at max when slider = n.</li>
     *   <li>Fully unbounded: tan over (-pi/2, pi/2), so slider = n/2 maps to 0.</li>
     * </ul>
     */
    private double sliderToValue(int slider, double min, double max, int n) {
        double fraction = (double) slider / (double) n;
        boolean hasMin = min != Double.NEGATIVE_INFINITY;
        boolean hasMax = max != Double.POSITIVE_INFINITY;

        if (hasMin && hasMax) {
            return min + fraction * (max - min);
        } else if (hasMin) {
            return min + Math.tan(fraction * HALF_PI);
        } else if (hasMax) {
            return max + Math.tan(-(((double) n - (double) slider) / (double) n) * HALF_PI);
        } else {
            return Math.tan(-HALF_PI + fraction * Math.PI);
        }
    }

    /**
     * Inverse of {@link #sliderToValue}. Maps a value back to a slider position, rounded to the
     * nearest integer and clamped to [0, n].
     */
    private int valueToSlider(double value, double min, double max, int n) {
        boolean hasMin = min != Double.NEGATIVE_INFINITY;
        boolean hasMax = max != Double.POSITIVE_INFINITY;
        double x;

        if (hasMin && hasMax) {
            x = (double) n * (value - min) / (max - min);
        } else if (hasMin) {
            x = 2.0 * (double) n / Math.PI * Math.atan(value - min);
        } else if (hasMax) {
            x = (double) n + 2.0 * (double) n / Math.PI * Math.atan(value - max);
        } else {
            x = (double) n / Math.PI * (Math.atan(value) + HALF_PI);
        }

        int slider = (int) Math.round(x);
        return Math.max(0, Math.min(n, slider));
    }

    public static Test suite() {
        return new TestSuite(TestStandardizedSem.class);
    }
}

