/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.SemEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.RandomUtil;
import java.util.List;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;

/**
 * Checks that {@link SemEstimator} recovers the variable means of a {@link SemIm}
 * from data simulated by each of the SemIm simulation routines: recursive,
 * reduced-form and Cholesky.
 */
public class TestSemVarMeans
extends TestCase {
    /** Means assigned, in variable-node order, to the true model. */
    private static final double[] TRUE_MEANS = new double[]{5.0, 4.0, 3.0, 2.0, 1.0};

    /** Seed applied to RandomUtil before simulating so every run sees the same data. */
    private static final long SEED = -379467L;

    /** Number of rows simulated in each test. */
    private static final int SAMPLE_SIZE = 1000;

    /** Maximum allowed absolute difference between a true and an estimated mean. */
    private static final double TOLERANCE = 0.06;

    public TestSemVarMeans(String name) {
        super(name);
    }

    /** Means should be recovered from data produced by {@code simulateDataRecursive}. */
    public void testMeansRecursive() {
        SemPm semPm = this.newSemPm();
        SemIm trueIm = this.newSemImWithMeans(semPm);
        DataSet dataSet = trueIm.simulateDataRecursive(SAMPLE_SIZE, false);
        this.assertMeansRecovered(semPm, trueIm, dataSet);
    }

    /** Means should be recovered from data produced by {@code simulateDataReducedForm}. */
    public void testMeansReducedForm() {
        SemPm semPm = this.newSemPm();
        SemIm trueIm = this.newSemImWithMeans(semPm);
        DataSet dataSet = trueIm.simulateDataReducedForm(SAMPLE_SIZE, false);
        this.assertMeansRecovered(semPm, trueIm, dataSet);
    }

    /** Means should be recovered from data produced by {@code simulateDataCholesky}. */
    public void testMeansCholesky() {
        SemPm semPm = this.newSemPm();
        SemIm trueIm = this.newSemImWithMeans(semPm);
        DataSet dataSet = trueIm.simulateDataCholesky(SAMPLE_SIZE, false);
        this.assertMeansRecovered(semPm, trueIm, dataSet);
    }

    /**
     * Builds a SemPm over {@link #constructGraph1()} with random initialization
     * switched off for every parameter.
     *
     * @return the parameterized model shared by the true and the estimated SemIm
     */
    private SemPm newSemPm() {
        SemPm semPm = new SemPm(this.constructGraph1());
        List<Parameter> parameters = semPm.getParameters();
        for (Parameter p : parameters) {
            // NOTE(review): presumably makes the SemIm use fixed starting values
            // instead of random draws -- confirm against Parameter.
            p.setInitializedRandomly(false);
        }
        return semPm;
    }

    /**
     * Instantiates the true SemIm for {@code semPm}, seeds RandomUtil, and assigns
     * {@link #TRUE_MEANS} to the variable nodes in the order the SemIm lists them.
     *
     * @param semPm the model to instantiate
     * @return the SemIm with means set, ready for simulation
     */
    private SemIm newSemImWithMeans(SemPm semPm) {
        SemIm semIm = new SemIm(semPm);
        // Seeded after construction and before simulation, as in the original tests,
        // so the simulated data set is reproducible.
        RandomUtil.getInstance().setSeed(SEED);
        List<Node> variables = semIm.getVariableNodes();
        // Fail clearly rather than with an index error if the graph changes size.
        assertEquals("Number of variable nodes", TRUE_MEANS.length, variables.size());
        for (int i = 0; i < variables.size(); ++i) {
            semIm.setMean(variables.get(i), TRUE_MEANS[i]);
        }
        return semIm;
    }

    /**
     * Estimates {@code semPm} from {@code dataSet} and asserts that each variable's
     * estimated mean is within {@link #TOLERANCE} of its mean in {@code trueIm}.
     *
     * @param semPm   the model to estimate
     * @param trueIm  the SemIm the data were simulated from
     * @param dataSet the simulated data
     */
    private void assertMeansRecovered(SemPm semPm, SemIm trueIm, DataSet dataSet) {
        SemEstimator semEst = new SemEstimator(dataSet, semPm);
        semEst.estimate();
        SemIm estSemIm = semEst.getEstimatedSem();
        List<Node> nodes = semPm.getVariableNodes();
        for (Node node : nodes) {
            assertEquals("Mean of " + node, trueIm.getMean(node), estSemIm.getMean(node), TOLERANCE);
        }
    }

    /**
     * Builds the five-node DAG used by every test:
     * X1 -> X2 -> X3 -> X4 -> X5, plus X1 -> X4.
     *
     * @return the graph
     */
    private Graph constructGraph1() {
        EdgeListGraph graph = new EdgeListGraph();
        GraphNode x1 = new GraphNode("X1");
        GraphNode x2 = new GraphNode("X2");
        GraphNode x3 = new GraphNode("X3");
        GraphNode x4 = new GraphNode("X4");
        GraphNode x5 = new GraphNode("X5");
        graph.addNode(x1);
        graph.addNode(x2);
        graph.addNode(x3);
        graph.addNode(x4);
        graph.addNode(x5);
        graph.addDirectedEdge(x1, x2);
        graph.addDirectedEdge(x2, x3);
        graph.addDirectedEdge(x3, x4);
        graph.addDirectedEdge(x1, x4);
        graph.addDirectedEdge(x4, x5);
        return graph;
    }

    /** @return a JUnit 3 suite containing every {@code test*} method of this class */
    public static Test suite() {
        return new TestSuite(TestSemVarMeans.class);
    }
}

