/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.sem.GeneralizedSemIm;
import edu.cmu.tetrad.sem.GeneralizedSemPm;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.sem.TemplateExpander;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.StatUtils;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;

public class TestGeneralizedSem
extends TestCase {
    /**
     * Creates a test case with the given name, which is handed straight to
     * the {@link TestCase} superclass constructor.
     *
     * @param name the name of the test case.
     */
    public TestGeneralizedSem(final String name) {
        super(name);
    }

    /**
     * Checks the reference bookkeeping of the parametric model built by
     * {@link #makeTypicalPm()}: which nodes reference parameter {@code b1},
     * which parameters X3 references, and which nodes reference / are
     * referenced by X3. It then overrides a few node and parameter
     * expressions, verifies the parameter expression string round-trips, and
     * simulates ten rows from the resulting instantiated model.
     *
     * <p>A {@link ParseException} raised by any of the expressions fails the
     * test; it is not merely printed, because a swallowed exception would skip
     * every assertion after it and still report success.
     */
    public void test1() {
        GeneralizedSemPm pm = this.makeTypicalPm();
        Node x1 = pm.getNode("X1");
        Node x2 = pm.getNode("X2");
        Node x3 = pm.getNode("X3");
        Node x4 = pm.getNode("X4");
        Node x5 = pm.getNode("X5");
        SemGraph graph = pm.getGraph();

        List<Node> variablesNodes = pm.getVariableNodes();
        System.out.println(variablesNodes);
        List<Node> errorNodes = pm.getErrorNodes();
        System.out.println(errorNodes);

        try {
            pm.setNodeExpression(x1, "cos(b1) +\n E_X1");

            String b1 = "b1";
            String b2 = "b2";
            String b3 = "b3";

            // b1 must be referenced by both X1 and X3, but not by X1 and X2 together.
            Set<Node> nodes = pm.getReferencingNodes(b1);
            TestGeneralizedSem.assertTrue(nodes.contains(x1) && nodes.contains(x3));
            TestGeneralizedSem.assertFalse(nodes.contains(x1) && nodes.contains(x2));

            // X3 must reference b1 and b2, but not b1 and b3 together.
            Set<String> referencedParameters = pm.getReferencedParameters(x3);
            System.out.println("Parameters referenced by X3 are: " + referencedParameters);
            TestGeneralizedSem.assertTrue(referencedParameters.contains(b1) && referencedParameters.contains(b2));
            TestGeneralizedSem.assertFalse(referencedParameters.contains(b1) && referencedParameters.contains(b3));

            Node e_x3 = pm.getNode("E_X3");

            for (Node node : pm.getNodes()) {
                Set<Node> referencingNodes = pm.getReferencingNodes(node);
                System.out.println("Nodes referencing " + node + " are: " + referencingNodes);
            }
            for (Node node : pm.getVariableNodes()) {
                Set<Node> referencedNodes = pm.getReferencedNodes(node);
                System.out.println("Nodes referenced by " + node + " are: " + referencedNodes);
            }

            Set<Node> referencingX3 = pm.getReferencingNodes(x3);
            TestGeneralizedSem.assertTrue(referencingX3.contains(x4) && !referencingX3.contains(x5));

            Set<Node> referencedByX3 = pm.getReferencedNodes(x3);
            TestGeneralizedSem.assertTrue(referencedByX3.contains(x1) && referencedByX3.contains(x2)
                    && referencedByX3.contains(e_x3) && !referencedByX3.contains(x4));

            pm.setNodeExpression(x5, "a * E^X2 + X4 + E_X5");

            // The PM's error node for X5 must be the graph's exogenous term for X5.
            Node e_x5 = pm.getErrorNode(x5);
            graph.setShowErrorTerms(true);
            TestGeneralizedSem.assertTrue(e_x5.equals(graph.getExogenous(x5)));

            pm.setNodeExpression(e_x5, "Beta(3, 5)");
            System.out.println(pm);

            TestGeneralizedSem.assertEquals("U(0, 1)", pm.getParameterExpressionString(b1));
            pm.setParameterExpression(b1, "N(0, 2)");
            TestGeneralizedSem.assertEquals("N(0, 2)", pm.getParameterExpressionString(b1));

            GeneralizedSemIm im = new GeneralizedSemIm(pm);
            System.out.println(im);

            DataSet dataSet = im.simulateDataAvoidInfinity(10, false);
            System.out.println(dataSet);
        }
        catch (ParseException e) {
            // Report the failure instead of letting the test pass with assertions skipped.
            TestGeneralizedSem.fail("Unexpected ParseException: " + e);
        }
    }

    /**
     * Simulates the same five-variable model twice -- once through a linear
     * {@link SemIm} and once through a {@link GeneralizedSemIm} built from the
     * same {@link SemPm} and {@link SemIm} -- and checks, column by column,
     * that the sample means agree to within 0.3 and the variance ratio is
     * within 0.2 of one. The random seed is fixed at the start of the test.
     */
    public void test2() {
        RandomUtil.getInstance().setSeed(29483L);
        final int numCases = 1000;

        ContinuousVariable x1 = new ContinuousVariable("X1");
        ContinuousVariable x2 = new ContinuousVariable("X2");
        ContinuousVariable x3 = new ContinuousVariable("X3");
        ContinuousVariable x4 = new ContinuousVariable("X4");
        ContinuousVariable x5 = new ContinuousVariable("X5");

        List<Node> vars = new ArrayList<Node>();
        vars.add(x1);
        vars.add(x2);
        vars.add(x3);
        vars.add(x4);
        vars.add(x5);

        SemGraph graph = new SemGraph(new EdgeListGraph(vars));
        graph.addDirectedEdge(x1, x3);
        graph.addDirectedEdge(x2, x3);
        graph.addDirectedEdge(x3, x4);
        graph.addDirectedEdge(x2, x4);
        graph.addDirectedEdge(x4, x5);
        graph.addDirectedEdge(x2, x5);

        // Linear model and its sample.
        SemPm linearPm = new SemPm(graph);
        SemIm linearIm = new SemIm(linearPm);
        DataSet linearData = linearIm.simulateData(numCases, false);
        System.out.println(linearPm);

        // Generalized model derived from the linear one, and its sample.
        GeneralizedSemPm generalPm = new GeneralizedSemPm(linearPm);
        GeneralizedSemIm generalIm = new GeneralizedSemIm(generalPm, linearIm);
        DataSet generalData = generalIm.simulateDataMinimizeSurface(numCases, false);
        System.out.println(generalPm);

        for (int c = 0; c < linearData.getNumColumns(); c++) {
            double[] expected = linearData.getDoubleData().viewColumn(c).toArray();
            double[] actual = generalData.getDoubleData().viewColumn(c).toArray();

            TestGeneralizedSem.assertEquals(StatUtils.mean(expected), StatUtils.mean(actual), 0.3);
            TestGeneralizedSem.assertEquals(1.0, StatUtils.variance(expected) / StatUtils.variance(actual), 0.2);
        }
    }

    /**
     * Builds a five-variable graph that contains a directed cycle
     * (X1 -&gt; X2 -&gt; X5 -&gt; X1), assigns nonlinear expressions to every
     * variable and a {@code U(0, 1)} expression to every error term, and runs
     * {@code simulateDataNSteps} on the instantiated model. The test makes no
     * assertion about the simulated values; it only requires that setup and
     * simulation complete.
     *
     * <p>A {@link ParseException} raised by any of the expressions fails the
     * test; it is not merely printed, because a swallowed exception would let
     * the test pass without ever running the simulation.
     */
    public void test3() {
        ArrayList<Node> variableNodes = new ArrayList<Node>();
        ContinuousVariable x1 = new ContinuousVariable("X1");
        ContinuousVariable x2 = new ContinuousVariable("X2");
        ContinuousVariable x3 = new ContinuousVariable("X3");
        ContinuousVariable x4 = new ContinuousVariable("X4");
        ContinuousVariable x5 = new ContinuousVariable("X5");
        variableNodes.add(x1);
        variableNodes.add(x2);
        variableNodes.add(x3);
        variableNodes.add(x4);
        variableNodes.add(x5);

        EdgeListGraph _graph = new EdgeListGraph(variableNodes);
        SemGraph graph = new SemGraph(_graph);

        // Error terms are looked up right after they are switched on, before any edge is added.
        graph.setShowErrorTerms(true);
        Node e1 = graph.getExogenous(x1);
        Node e2 = graph.getExogenous(x2);
        Node e3 = graph.getExogenous(x3);
        Node e4 = graph.getExogenous(x4);
        Node e5 = graph.getExogenous(x5);

        graph.addDirectedEdge(x1, x3);
        graph.addDirectedEdge(x1, x2);
        graph.addDirectedEdge(x2, x3);
        graph.addDirectedEdge(x3, x4);
        graph.addDirectedEdge(x2, x4);
        graph.addDirectedEdge(x4, x5);
        graph.addDirectedEdge(x2, x5);
        graph.addDirectedEdge(x5, x1); // closes the cycle back to X1

        GeneralizedSemPm pm = new GeneralizedSemPm(graph);
        List<Node> variablesNodes = pm.getVariableNodes();
        System.out.println(variablesNodes);
        List<Node> errorNodes = pm.getErrorNodes();
        System.out.println(errorNodes);

        try {
            pm.setNodeExpression(x1, "cos(b1) + a1 * X5 + E_X1");
            pm.setNodeExpression(x2, "a2 * X1 + E_X2");
            pm.setNodeExpression(x3, "tanh(a3*X2 + a4*X1) + E_X3");
            pm.setNodeExpression(x4, "0.1 * E^X2 + X3 + E_X4");
            pm.setNodeExpression(x5, "0.1 * E^X4 + a6* X2 + E_X5");
            pm.setNodeExpression(e1, "U(0, 1)");
            pm.setNodeExpression(e2, "U(0, 1)");
            pm.setNodeExpression(e3, "U(0, 1)");
            pm.setNodeExpression(e4, "U(0, 1)");
            pm.setNodeExpression(e5, "U(0, 1)");

            GeneralizedSemIm im = new GeneralizedSemIm(pm);
            System.out.println(im);

            // Only successful completion is checked; the simulated data set is not inspected.
            im.simulateDataNSteps(1000, false);
        }
        catch (ParseException e) {
            // Report the failure instead of letting the test pass without simulating.
            TestGeneralizedSem.fail("Unexpected ParseException: " + e);
        }
    }

    /**
     * Table-driven check of template expansion. Each template string is mapped
     * to the names of the variable nodes for which expanding the template and
     * installing the result as the node's expression is expected to succeed.
     *
     * <p>For every template a fresh model from {@link #makeTypicalPm()} is
     * used. The template is tried on every node of the model (recording the
     * variable nodes for which it could be set) and then on every parameter
     * (expanded with a {@code null} node). Failures to expand or to set are
     * expected for some combinations and are only logged; the test asserts
     * that the set of variable nodes that worked equals the expected set.
     */
    public void test7() {
        HashMap<String, String[]> templates = new HashMap<String, String[]>();
        templates.put("NEW(b) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", new String[]{"X1", "X2", "X3", "X4", "X5"});
        templates.put("NEW(X1) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", new String[0]);
        templates.put("$", new String[0]);
        templates.put("TSUM($)", new String[]{"X1", "X2", "X3", "X4", "X5"});
        templates.put("TPROD($)", new String[]{"X1", "X2", "X3", "X4", "X5"});
        templates.put("TPROD($) + X2", new String[]{"X3", "X4", "X5"});
        templates.put("TPROD($) + TSUM($)", new String[]{"X1", "X2", "X3", "X4", "X5"});
        templates.put("tanh(TSUM(NEW(a)*$))", new String[]{"X3", "X4", "X5"});
        templates.put("Normal(0, 1)", new String[]{"X1", "X2", "X3", "X4", "X5"});
        templates.put("Normal(m, s)", new String[]{"X1", "X2", "X3", "X4", "X5"});
        templates.put("Normal(NEW(m), s)", new String[]{"X1", "X2", "X3", "X4", "X5"});
        templates.put("Normal(NEW(m), NEW(s)) + m1 + s6", new String[]{"X1", "X2", "X3", "X4", "X5"});
        templates.put("TSUM($) + a", new String[]{"X1", "X2", "X3", "X4", "X5"});
        templates.put("TSUM($) + TSUM($) + TSUM($) + 1", new String[]{"X1", "X2", "X3", "X4", "X5"});

        for (String template : templates.keySet()) {
            GeneralizedSemPm semPm = this.makeTypicalPm();
            System.out.println(semPm.getGraph());

            // Variable nodes for which this template is expected to be accepted.
            HashSet<Node> shouldWork = new HashSet<Node>();
            String[] expectedNames = templates.get(template);
            for (String expectedName : expectedNames) {
                shouldWork.add(semPm.getNode(expectedName));
            }

            // Variable nodes for which the expanded template was actually accepted.
            HashSet<Node> works = new HashSet<Node>();

            for (int idx = 0; idx < semPm.getNodes().size(); idx++) {
                System.out.println("-----------");
                Node node = semPm.getNodes().get(idx);
                System.out.println(node);
                System.out.println("Trying template: " + template);

                String expanded;
                try {
                    expanded = TemplateExpander.getInstance().expandTemplate(template, semPm, node);
                }
                catch (Exception e) {
                    System.out.println("Couldn't expand template: " + template);
                    continue;
                }

                try {
                    semPm.setNodeExpression(node, expanded);
                    System.out.println("Set formula " + expanded + " for " + node);
                    // Only variable nodes count toward the expectation; other nodes are just exercised.
                    if (semPm.getVariableNodes().contains(node)) {
                        works.add(node);
                    }
                }
                catch (Exception e) {
                    System.out.println("Couldn't set formula " + expanded + " for " + node);
                }
            }

            for (String parameter : semPm.getParameters()) {
                System.out.println("-----------");
                System.out.println(parameter);
                System.out.println("Trying template: " + template);

                String expanded;
                try {
                    expanded = TemplateExpander.getInstance().expandTemplate(template, semPm, null);
                }
                catch (Exception e) {
                    System.out.println("Couldn't expand template: " + template);
                    continue;
                }

                try {
                    semPm.setParameterExpression(parameter, expanded);
                    System.out.println("Set formula " + expanded + " for " + parameter);
                }
                catch (Exception e) {
                    System.out.println("Couldn't set formula " + expanded + " for " + parameter);
                }
            }

            TestGeneralizedSem.assertEquals(shouldWork, works);
        }
    }

    /**
     * Builds the parametric model shared by several tests: continuous
     * variables X1..X5 with the directed edges X1-&gt;X3, X2-&gt;X3, X3-&gt;X4,
     * X2-&gt;X4, X4-&gt;X5 and X2-&gt;X5.
     *
     * @return a new {@link GeneralizedSemPm} constructed over that graph.
     */
    private GeneralizedSemPm makeTypicalPm() {
        List<Node> vars = new ArrayList<Node>();
        for (int i = 1; i <= 5; i++) {
            vars.add(new ContinuousVariable("X" + i));
        }

        SemGraph graph = new SemGraph(new EdgeListGraph(vars));

        // {from, to} pairs using the 1-based variable numbers, in insertion order.
        int[][] edges = {{1, 3}, {2, 3}, {3, 4}, {2, 4}, {4, 5}, {2, 5}};
        for (int[] edge : edges) {
            graph.addDirectedEdge(vars.get(edge[0] - 1), vars.get(edge[1] - 1));
        }

        return new GeneralizedSemPm(graph);
    }

    /**
     * Replaces each {@code NEW(x)} occurrence in {@code formula}, left to
     * right, with the name returned by
     * {@code semPm.nextParameterName(x, usedNames)}, where {@code x} is either
     * {@code $} or an identifier that starts with a letter. Every generated
     * name is appended to {@code usedNames} before the next occurrence is
     * processed.
     *
     * <p>The generated name is spliced in literally. It is deliberately not
     * passed to {@link String#replaceFirst(String, String)}, which would
     * interpret {@code $} and {@code \} in the name as group references and
     * escapes and could throw or corrupt the formula.
     *
     * @param semPm     the model asked for the next free parameter name.
     * @param formula   the formula possibly containing {@code NEW(...)} markers.
     * @param usedNames names already taken; generated names are added to it.
     * @return the formula with every {@code NEW(...)} marker replaced.
     */
    private String replaceNewParameters(GeneralizedSemPm semPm, String formula, List<String> usedNames) {
        String parameterPattern = "\\$|(([a-zA-Z]{1})([a-zA-Z0-9-_/]*))";
        Pattern p = Pattern.compile("NEW\\((" + parameterPattern + ")\\)");

        Matcher m = p.matcher(formula);
        while (m.find()) {
            String nextName = semPm.nextParameterName(m.group(1), usedNames);

            // Literal splice of the leftmost match; no regex replacement semantics involved.
            formula = formula.substring(0, m.start()) + nextName + formula.substring(m.end());
            usedNames.add(nextName);

            // Rescan the updated formula from the beginning, as the text has changed.
            m = p.matcher(formula);
        }
        return formula;
    }

    /**
     * Builds the suite for this class by passing {@code TestGeneralizedSem.class}
     * to the {@link TestSuite} constructor.
     *
     * @return the suite for this test class.
     */
    public static Test suite() {
        Test allTests = new TestSuite(TestGeneralizedSem.class);
        return allTests;
    }
}

