/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.search.Ges;
import edu.cmu.tetrad.search.LingamPattern;
import edu.cmu.tetrad.search.TestPc;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.dist.Distribution;
import edu.cmu.tetrad.util.dist.Normal;
import edu.cmu.tetrad.util.dist.Uniform;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;

public class TestLingamPattern
extends TestCase {
    public TestLingamPattern(String name) {
        super(name);
    }

    public void test1() {
        int sampleSize = 1000;
        Dag graph = GraphUtils.randomDag(6, 0, 6, 4, 4, 4, false);
        System.out.println("true graph = " + graph);
        ArrayList<Distribution> variableDistributions = new ArrayList<Distribution>();
        variableDistributions.add(new Normal(0.0, 1.0));
        variableDistributions.add(new Normal(0.0, 1.0));
        variableDistributions.add(new Normal(0.0, 1.0));
        variableDistributions.add(new Uniform(-1.0, 1.0));
        variableDistributions.add(new Normal(0.0, 1.0));
        variableDistributions.add(new Normal(0.0, 1.0));
        SemPm semPm = new SemPm(graph);
        SemIm semIm = new SemIm(semPm);
        DataSet dataSet = this.simulateDataNonNormal(semIm, sampleSize, variableDistributions);
        Graph estPattern = new Ges(dataSet).search();
        LingamPattern lingam = new LingamPattern(estPattern, dataSet);
        Graph pattern = lingam.search();
        System.out.println("Pattern = " + pattern);
        double[] pvals = lingam.getPValues();
        System.out.println("Anderson Darling P value for Variables\n");
        DecimalFormat nf = new DecimalFormat("0.0000");
        for (int j = 0; j < dataSet.getNumColumns(); ++j) {
            System.out.println(dataSet.getVariable(j) + ": " + nf.format(pvals[j]));
        }
        System.out.println();
    }

    private DataSet simulateDataNonNormal(SemIm semIm, int sampleSize, List<Distribution> distributions) {
        LinkedList<Node> variables = new LinkedList<Node>();
        List<Node> variableNodes = semIm.getSemPm().getVariableNodes();
        for (Node node : variableNodes) {
            ContinuousVariable var = new ContinuousVariable(node.getName());
            variables.add(var);
        }
        ColtDataSet dataSet = new ColtDataSet(sampleSize, variables);
        SemGraph graph = semIm.getSemPm().getGraph();
        List<Node> tierOrdering = graph.getTierOrdering();
        int[] tierIndices = new int[variableNodes.size()];
        for (int i = 0; i < tierIndices.length; ++i) {
            tierIndices[i] = variableNodes.indexOf(tierOrdering.get(i));
        }
        int[][] _parents = new int[variables.size()][];
        for (int i = 0; i < variableNodes.size(); ++i) {
            Node node = variableNodes.get(i);
            List<Node> parents = graph.getParents(node);
            Iterator<Node> j = parents.iterator();
            while (j.hasNext()) {
                Node _node = j.next();
                if (_node.getNodeType() != NodeType.ERROR) continue;
                j.remove();
            }
            _parents[i] = new int[parents.size()];
            for (int j2 = 0; j2 < parents.size(); ++j2) {
                Node _parent = parents.get(j2);
                _parents[i][j2] = variableNodes.indexOf(_parent);
            }
        }
        for (int row = 0; row < sampleSize; ++row) {
            for (int i = 0; i < tierOrdering.size(); ++i) {
                int col = tierIndices[i];
                Distribution distribution = distributions.get(col);
                double value = distribution.nextRandom();
                for (int j = 0; j < _parents[col].length; ++j) {
                    int parent = _parents[col][j];
                    value += dataSet.getDouble(row, parent) * semIm.getEdgeCoef().get(parent, col);
                }
                dataSet.setDouble(row, col, value += semIm.getMeans()[col]);
            }
        }
        return dataSet;
    }

    public static Test suite() {
        return new TestSuite(TestPc.class);
    }
}

