/*
 * Decompiled with CFR 0.152.
 */
package edu.pitt.csb.mgm;

import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.CorrelationMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.data.SimpleDataLoader;
import edu.cmu.tetrad.data.VerticalDoubleDataBox;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.work_in_progress.IndTestMixedMultipleTTest;
import edu.cmu.tetrad.sem.GeneralizedSemIm;
import edu.cmu.tetrad.sem.GeneralizedSemPm;
import edu.cmu.tetrad.sem.TemplateExpander;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.StatUtils;
import edu.pitt.dbmi.data.reader.Delimiter;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.math3.util.FastMath;

/**
 * Static utilities for mixed (continuous + discrete) data and models: selecting variables by
 * type, re-typing data sets and graphs, building generalized SEM simulations with categorical
 * nodes, scoring an estimated graph against a true graph, and constructing independence tests
 * by name.
 */
public class MixedUtils {
    /**
     * Tab-separated header for the ten columns of each row returned by
     * {@link #allEdgeStats(Graph, Graph, HashMap)}: true directed, true undirected, flipped,
     * false directed, false undirected, false-positive directed, false-positive undirected,
     * false-negative directed, false-negative undirected, bidirected.
     */
    public static final String EdgeStatHeader = "TD\tTU\tFL\tFD\tFU\tFPD\tFPU\tFND\tFNU\tBidir";

    /**
     * Returns the 0-based positions, in list order, of the nodes that are
     * {@link DiscreteVariable}s.
     *
     * @param nodes the nodes to scan
     * @return indices into {@code nodes} of the discrete variables
     */
    public static int[] getDiscreteInds(List<Node> nodes) {
        return indicesOfType(nodes, DiscreteVariable.class);
    }

    /**
     * Returns the 0-based positions, in list order, of the nodes that are
     * {@link ContinuousVariable}s.
     *
     * @param nodes the nodes to scan
     * @return indices into {@code nodes} of the continuous variables
     */
    public static int[] getContinuousInds(List<Node> nodes) {
        return indicesOfType(nodes, ContinuousVariable.class);
    }

    /** Collects the positions of the nodes in {@code nodes} that are instances of {@code type}. */
    private static int[] indicesOfType(List<Node> nodes, Class<?> type) {
        List<Integer> indList = new ArrayList<>();
        int curInd = 0;
        for (Node n : nodes) {
            if (type.isInstance(n)) {
                indList.add(curInd);
            }
            ++curInd;
        }
        int[] inds = new int[indList.size()];
        for (int i = 0; i < inds.length; ++i) {
            inds[i] = indList.get(i);
        }
        return inds;
    }

    /**
     * Returns a data set in which every discrete variable is replaced by a continuous variable
     * of the same name. All other variables are reused as-is, and the values from
     * {@code dsMix.getDoubleData()} are copied unchanged.
     *
     * @param dsMix a data set that may mix continuous and discrete columns
     * @return an all-continuous data set with the same column order and values
     */
    public static DataSet makeContinuousData(DataSet dsMix) {
        List<Node> contVars = new ArrayList<>();
        for (Node n : dsMix.getVariables()) {
            contVars.add(n instanceof DiscreteVariable ? new ContinuousVariable(n.getName()) : n);
        }
        // The matrix is transposed before boxing. VerticalDoubleDataBox presumably stores data
        // column-by-column ([column][row]) -- confirm against its constructor.
        return new BoxDataSet(new VerticalDoubleDataBox(dsMix.getDoubleData().transpose().toArray()), contVars);
    }

    /**
     * Re-types the columns of {@code dsCont}. Every variable whose entry in {@code nodeDists}
     * is {@code "Disc"} becomes a {@link DiscreteVariable} with {@code numCategories}
     * categories; all others are kept. Values are copied unchanged, so discrete columns are
     * presumably expected to already hold category indices.
     *
     * @param dsCont        source data
     * @param nodeDists     variable name to distribution label ({@code "Disc"} marks discrete)
     * @param numCategories number of categories for each discrete variable
     * @return the re-typed data set
     * @throws NullPointerException if a variable name is missing from {@code nodeDists}
     */
    public static DataSet makeMixedData(DataSet dsCont, Map<String, String> nodeDists, int numCategories) {
        List<Node> mixVars = new ArrayList<>();
        for (Node n : dsCont.getVariables()) {
            if (nodeDists.get(n.getName()).equals("Disc")) {
                mixVars.add(new DiscreteVariable(n.getName(), numCategories));
            } else {
                mixVars.add(n);
            }
        }
        return new BoxDataSet(new DoubleDataBox(dsCont.getDoubleData().toArray()), mixVars);
    }

    /**
     * Re-types the columns of {@code dsCont}. A variable whose entry in {@code nodeDists} is
     * positive becomes a {@link DiscreteVariable} with that many categories; all others (entry
     * {@code <= 0}) are kept. Values are copied unchanged.
     *
     * @param dsCont    source data
     * @param nodeDists variable name to number of categories (0 for continuous)
     * @return the re-typed data set
     * @throws NullPointerException if a variable name is missing from {@code nodeDists}
     */
    public static DataSet makeMixedData(DataSet dsCont, Map<String, Integer> nodeDists) {
        List<Node> mixVars = new ArrayList<>();
        for (Node n : dsCont.getVariables()) {
            int nC = nodeDists.get(n.getName());
            if (nC > 0) {
                mixVars.add(new DiscreteVariable(n.getName(), nC));
            } else {
                mixVars.add(n);
            }
        }
        return new BoxDataSet(new DoubleDataBox(dsCont.getDoubleData().toArray()), mixVars);
    }

    /**
     * Copies {@code ds}, including fresh copies of every variable object.
     *
     * @param ds data set whose variables are all continuous or discrete
     * @return an independent copy of the data and variables
     * @throws IllegalArgumentException if a variable is neither continuous nor discrete
     */
    public static DataSet deepCopy(DataSet ds) {
        List<Node> vars = new ArrayList<>(ds.getNumColumns());
        for (Node n : ds.getVariables()) {
            if (n instanceof ContinuousVariable) {
                vars.add(new ContinuousVariable((ContinuousVariable) n));
            } else if (n instanceof DiscreteVariable) {
                vars.add(new DiscreteVariable((DiscreteVariable) n));
            } else {
                throw new IllegalArgumentException("Variable type of node " + n + " could not be determined");
            }
        }
        return new BoxDataSet(new DoubleDataBox(ds.getDoubleData().toArray()), vars);
    }

    /**
     * Returns the subset of {@code ds} containing only its continuous columns.
     * (The misspelled name is kept for compatibility with existing callers.)
     */
    public static DataSet getContinousData(DataSet ds) {
        List<Node> contVars = new ArrayList<>();
        for (Node n : ds.getVariables()) {
            if (n instanceof ContinuousVariable) {
                contVars.add(n);
            }
        }
        return ds.subsetColumns(contVars);
    }

    /** Returns the subset of {@code ds} containing only its discrete columns. */
    public static DataSet getDiscreteData(DataSet ds) {
        List<Node> discVars = new ArrayList<>();
        for (Node n : ds.getVariables()) {
            if (n instanceof DiscreteVariable) {
                discVars.add(n);
            }
        }
        return ds.subsetColumns(discVars);
    }

    /** Returns the number of categories of each discrete variable of {@code ds}, in column order. */
    public static int[] getDiscLevels(DataSet ds) {
        DataSet discDs = MixedUtils.getDiscreteData(ds);
        int[] levels = new int[discDs.getNumColumns()];
        int i = 0;
        for (Node n : discDs.getVariables()) {
            levels[i++] = ((DiscreteVariable) n).getNumCategories();
        }
        return levels;
    }

    /**
     * Returns the per-column maximum of {@code m}, truncated to {@code int}. The running maximum
     * starts at -1, so a column with no rows, or with all values at or below -1, reports -1.
     * This is presumably intended for non-negative category codes.
     */
    public static int[] colMax(DoubleMatrix2D m) {
        int[] maxVec = new int[m.columns()];
        for (int i = 0; i < m.columns(); ++i) {
            double curmax = -1.0;
            for (int j = 0; j < m.rows(); ++j) {
                double curval = m.getQuick(j, i);
                if (curval > curmax) {
                    curmax = curval;
                }
            }
            maxVec[i] = (int) curmax;
        }
        return maxVec;
    }

    /**
     * Returns the largest entry of {@code vec}, or negative infinity if it is empty. NaN entries
     * are never selected.
     */
    public static double vecMax(DoubleMatrix1D vec) {
        double curMax = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < vec.size(); ++i) {
            double curVal = vec.getQuick(i);
            if (curVal > curMax) {
                curMax = curVal;
            }
        }
        return curMax;
    }

    /** Returns the number of distinct values in {@code vec} (as a double). */
    public static double numVals(DoubleMatrix1D vec) {
        return MixedUtils.valSet(vec).size();
    }

    /** Returns the set of distinct values in {@code vec}. */
    public static Set<Double> valSet(DoubleMatrix1D vec) {
        Set<Double> vals = new HashSet<>();
        for (int i = 0; i < vec.size(); ++i) {
            vals.add(vec.getQuick(i));
        }
        return vals;
    }

    /**
     * Builds a generalized SEM PM for a graph whose discrete nodes have three levels.
     * <ul>
     *   <li>A node marked {@code "Disc"} gets a {@code DiscError} expression with three terms:
     *       constants {@code .001} if it has no parents, otherwise a fresh linear combination
     *       (B parameters) of its parents per level. Its error term becomes {@code U(0,1)}.</li>
     *   <li>Every {@code B*parent} term whose parent is {@code "Disc"} is replaced by an
     *       {@code IF} that picks one of three fresh D parameters by parent value 0, 1, or
     *       other.</li>
     * </ul>
     *
     * @param trueGraph     graph to parameterize
     * @param nodeDists     node name to distribution label ({@code "Disc"} marks discrete)
     * @param maxSample     unused; kept for interface compatibility
     * @param paramTemplate template applied to B and D parameters
     * @return the configured PM
     * @throws IllegalStateException if an expression fails to parse
     */
    public static GeneralizedSemPm GaussianTrinaryPm(Graph trueGraph, HashMap<String, String> nodeDists, int maxSample, String paramTemplate) throws IllegalStateException {
        GeneralizedSemPm semPm = new GeneralizedSemPm(trueGraph);
        try {
            List<Node> variableNodes = semPm.getVariableNodes();
            semPm.setStartsWithParametersTemplate("B", paramTemplate);
            semPm.setStartsWithParametersTemplate("D", paramTemplate);
            semPm.setStartsWithParametersTemplate("al", "U(.3,1.3)");
            semPm.setStartsWithParametersTemplate("s", "U(1,2)");
            for (Node node : variableNodes) {
                List<Node> parents = trueGraph.getParents(node);
                Node eNode = semPm.getErrorNode(node);
                String curEx = semPm.getNodeExpressionString(node);
                String errEx = semPm.getNodeExpressionString(eNode);
                if (nodeDists.get(node.getName()).equals("Disc")) {
                    String discTemp = parents.isEmpty()
                            ? "DiscError(err, .001,.001,.001)"
                            : "DiscError(err, (TSUM(NEW(B)*$)), (TSUM(NEW(B)*$)), (TSUM(NEW(B)*$)))";
                    discTemp = discTemp.replace("err", eNode.getName());
                    curEx = TemplateExpander.getInstance().expandTemplate(discTemp, semPm, node);
                    errEx = TemplateExpander.getInstance().expandTemplate("U(0,1)", semPm, eNode);
                }
                // Accumulate the substitutions so that every discrete parent is rewritten, not
                // just the last one.
                String replaced = curEx;
                boolean hasDiscParent = false;
                for (Node parNode : parents) {
                    if (!nodeDists.get(parNode.getName()).equals("Disc")) {
                        continue;
                    }
                    String curName = parNode.getName();
                    String disRep = "IF(" + curName + "=0,NEW(D),IF(" + curName + "=1,NEW(D),NEW(D)))";
                    // (?![0-9]) stops e.g. X1 from matching the prefix of X10.
                    replaced = replaced.replaceAll("(B[0-9]*\\*" + Pattern.quote(curName) + ")(?![0-9])",
                            Matcher.quoteReplacement(disRep));
                    hasDiscParent = true;
                }
                if (hasDiscParent && !replaced.isEmpty()) {
                    curEx = TemplateExpander.getInstance().expandTemplate(replaced, semPm, node);
                }
                semPm.setNodeExpression(node, curEx);
                semPm.setNodeExpression(eNode, errEx);
            }
        } catch (ParseException e) {
            throw new IllegalStateException("Parse error in fixing parameters.", e);
        }
        return semPm;
    }

    /**
     * Builds a generalized SEM PM for a graph whose discrete nodes are {@link DiscreteVariable}s
     * with any number (>= 2) of categories.
     * <ul>
     *   <li>A discrete node gets a {@code DiscError} expression with one term per category:
     *       constant {@code 1} if it has no parents, otherwise a fresh linear combination
     *       (C parameters) of its parents. Its error term becomes {@code U(0,1)}.</li>
     *   <li>Each coefficient on a discrete parent is replaced by a {@code Switch} over that
     *       parent's categories. A continuous child's B terms switch over fresh C parameters; a
     *       discrete child's C terms switch over fresh D parameters.</li>
     * </ul>
     *
     * @param trueGraph     graph to parameterize
     * @param paramTemplate template applied to B, C and D parameters
     * @return the configured PM
     * @throws IllegalArgumentException if a discrete node has exactly one category
     * @throws IllegalStateException    if an expression fails to parse
     */
    public static GeneralizedSemPm GaussianCategoricalPm(Graph trueGraph, String paramTemplate) throws IllegalStateException {
        Map<String, Integer> nodeDists = MixedUtils.getNodeDists(trueGraph);
        GeneralizedSemPm semPm = new GeneralizedSemPm(trueGraph);
        try {
            List<Node> variableNodes = semPm.getVariableNodes();
            semPm.setStartsWithParametersTemplate("B", paramTemplate);
            semPm.setStartsWithParametersTemplate("C", paramTemplate);
            semPm.setStartsWithParametersTemplate("D", paramTemplate);
            semPm.setStartsWithParametersTemplate("s", "U(1,2)");
            for (Node node : variableNodes) {
                List<Node> parents = trueGraph.getParents(node);
                Node eNode = semPm.getErrorNode(node);
                String curEx = semPm.getNodeExpressionString(node);
                String errEx = semPm.getNodeExpressionString(eNode);
                int curDist = nodeDists.get(node.getName());
                if (curDist == 1) {
                    throw new IllegalArgumentException("Dist for node " + node.getName() + " is set to one (i.e. constant) which is not supported.");
                }
                if (curDist > 0) {
                    String levelTerm = parents.isEmpty() ? ",1" : ", TSUM(NEW(C)*$)";
                    StringBuilder discTemp = new StringBuilder("DiscError(").append(eNode.getName());
                    for (int l = 0; l < curDist; ++l) {
                        discTemp.append(levelTerm);
                    }
                    discTemp.append(')');
                    curEx = TemplateExpander.getInstance().expandTemplate(discTemp.toString(), semPm, node);
                    errEx = TemplateExpander.getInstance().expandTemplate("U(0,1)", semPm, eNode);
                }
                String newTemp = curEx;
                for (Node parNode : parents) {
                    int parDist = nodeDists.get(parNode.getName());
                    if (parDist <= 0) {
                        continue;
                    }
                    String curName = parNode.getName();
                    // A discrete child's parent terms carry C coefficients and switch over D
                    // parameters; a continuous child's carry B coefficients and switch over C.
                    String coefPrefix = curDist > 0 ? "C" : "B";
                    String perLevel = curDist > 0 ? ",NEW(D)" : ",NEW(C)";
                    StringBuilder disRep = new StringBuilder("Switch(").append(curName);
                    for (int l = 0; l < parDist; ++l) {
                        disRep.append(perLevel);
                    }
                    disRep.append(')');
                    newTemp = newTemp.replaceAll("(" + coefPrefix + "[0-9]*\\*" + Pattern.quote(curName) + ")(?![0-9])",
                            Matcher.quoteReplacement(disRep.toString()));
                }
                if (!newTemp.isEmpty()) {
                    curEx = TemplateExpander.getInstance().expandTemplate(newTemp, semPm, node);
                }
                semPm.setNodeExpression(node, curEx);
                semPm.setNodeExpression(eNode, errEx);
            }
        } catch (ParseException e) {
            throw new IllegalStateException("Parse error in fixing parameters.", e);
        }
        return semPm;
    }

    /**
     * Sets the starts-with template for parameters beginning with {@code sta}. It then applies
     * {@code template} to every existing parameter of {@code pm} with that prefix.
     * Best-effort: exceptions are printed rather than thrown; Errors still propagate.
     */
    public static void setStartsWith(String sta, String template, GeneralizedSemPm pm) {
        try {
            pm.setStartsWithParametersTemplate(sta, template);
            for (String param : pm.getParameters()) {
                if (param.startsWith(sta)) {
                    pm.setParameterExpression(param, template);
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Equivalent to {@code GaussianCategoricalIm(pm, true)}. */
    public static GeneralizedSemIm GaussianCategoricalIm(GeneralizedSemPm pm) {
        return MixedUtils.GaussianCategoricalIm(pm, true);
    }

    /**
     * Builds an IM from {@code pm} and re-sets the parameters of every edge that touches a
     * discrete node. Continuous–continuous edges keep their sampled values. For each such edge
     * the first parameter's sampled value {@code w} is used as the edge strength.
     * <ul>
     *   <li>Discrete–discrete: for each child category one parent category gets {@code |w|}
     *       and the rest get {@code -|w|/(pL-1)}.</li>
     *   <li>Mixed: the per-category parameters get mean-centred random values scaled so the
     *       largest magnitude is {@code |w|} when {@code discParamRand}, else evenly spaced
     *       values in {@code [-w, w]}.</li>
     * </ul>
     *
     * @param pm            PM, presumably built by {@link #GaussianCategoricalPm(Graph, String)}
     * @param discParamRand whether to randomize category weight assignment
     * @return the configured IM
     */
    public static GeneralizedSemIm GaussianCategoricalIm(GeneralizedSemPm pm, boolean discParamRand) {
        Map<String, Integer> nodeDists = MixedUtils.getNodeDists(pm.getGraph());
        GeneralizedSemIm im = new GeneralizedSemIm(pm);
        for (Node n : pm.getVariableNodes()) {
            for (Node par : pm.getReferencedNodes(n)) {
                if (par.getNodeType() == NodeType.ERROR) {
                    continue;
                }
                int cL = nodeDists.get(n.getName());
                int pL = nodeDists.get(par.getName());
                if (cL == 0 && pL == 0) {
                    continue;
                }
                List<String> params = MixedUtils.getEdgeParams(n, par, pm);
                double w = im.getParameterValue(params.get(0));
                if (cL > 0 && pL > 0) {
                    setDiscreteEdgeWeights(im, params, FastMath.abs(w), cL, pL, discParamRand);
                    continue;
                }
                int curL = pL > 0 ? pL : cL;
                double[] newWeights = discParamRand
                        ? MixedUtils.generateMixedEdgeParams(w, curL)
                        : MixedUtils.evenSplitVector(w, curL);
                int count = 0;
                for (String p : params) {
                    im.setParameterValue(p, newWeights[count++]);
                }
            }
        }
        return im;
    }

    /**
     * Sets the {@code cL x pL} discrete–discrete edge parameters (row-major, parameter index
     * {@code i * pL + j}). For each child category {@code i}, one parent category gets
     * {@code w} and the others get {@code -w / (pL - 1)}.
     */
    private static void setDiscreteEdgeWeights(GeneralizedSemIm im, List<String> params, double w, int cL, int pL, boolean discParamRand) {
        double bgW = w / ((double) pL - 1.0);
        int[] weightInds = new int[cL];
        for (int i = 0; i < cL; ++i) {
            // Child categories cycle through parent categories when cL > pL.
            weightInds[i] = i % pL;
        }
        if (discParamRand) {
            weightInds = MixedUtils.arrayPermute(weightInds);
        }
        for (int i = 0; i < cL; ++i) {
            for (int j = 0; j < pL; ++j) {
                im.setParameterValue(params.get(i * pL + j), weightInds[i] == j ? w : -bgW);
            }
        }
    }

    /** Looks up the nodes named {@code s1} and {@code s2} in {@code pm} and delegates to {@link #getEdgeParams(Node, Node, GeneralizedSemPm)}. */
    public static List<String> getEdgeParams(String s1, String s2, GeneralizedSemPm pm) {
        Node n1 = pm.getNode(s1);
        Node n2 = pm.getNode(s2);
        return MixedUtils.getEdgeParams(n1, n2, pm);
    }

    /** Returns a randomly shuffled copy of {@code a} (shuffled via {@link RandomUtil#shuffle}). */
    public static double[] arrayPermute(double[] a) {
        List<Double> l = new ArrayList<>(a.length);
        for (double v : a) {
            l.add(v);
        }
        RandomUtil.shuffle(l);
        double[] out = new double[a.length];
        for (int i = 0; i < a.length; ++i) {
            out[i] = l.get(i);
        }
        return out;
    }

    /** Returns a randomly shuffled copy of {@code a} (shuffled via {@link RandomUtil#shuffle}). */
    public static int[] arrayPermute(int[] a) {
        List<Integer> l = new ArrayList<>(a.length);
        for (int v : a) {
            l.add(v);
        }
        RandomUtil.shuffle(l);
        int[] out = new int[a.length];
        for (int i = 0; i < a.length; ++i) {
            out[i] = l.get(i);
        }
        return out;
    }

    /**
     * Returns {@code L} values evenly spaced from {@code -w} to {@code w} inclusive. With
     * {@code L == 1} the single value is the midpoint 0 (previously NaN from a division by
     * zero).
     */
    public static double[] evenSplitVector(double w, int L) {
        double[] vec = new double[L];
        if (L == 1) {
            return vec;
        }
        double step = 2.0 * w / ((double) L - 1.0);
        for (int i = 0; i < L; ++i) {
            vec[i] = -w + (double) i * step;
        }
        return vec;
    }

    /**
     * Extracts the parameter names on the edge between {@code n1} and {@code n2} from the child's
     * expression in {@code pm}. The child is whichever node references the other.
     * <ul>
     *   <li>For a discrete parent: the comma-separated entries of {@code Switch(parent,...)}.</li>
     *   <li>Otherwise: the B/C coefficient names multiplying the parent.</li>
     * </ul>
     *
     * @return the parameter names in expression order, or {@code null} if neither node's
     *         expression references the other
     */
    public static List<String> getEdgeParams(Node n1, Node n2, GeneralizedSemPm pm) {
        Node parent;
        Node child;
        if (pm.getReferencedNodes(n1).contains(n2)) {
            child = n1;
            parent = n2;
        } else if (pm.getReferencedNodes(n2).contains(n1)) {
            child = n2;
            parent = n1;
        } else {
            return null;
        }
        boolean discParent = parent instanceof DiscreteVariable;
        String quotedName = Pattern.quote(parent.getName());
        Pattern parPat = discParent
                ? Pattern.compile("Switch\\(" + quotedName + ",.*?\\)")
                : Pattern.compile("([BC][0-9]*\\*" + quotedName + ")(?![0-9])");
        List<String> paramList = new ArrayList<>();
        Matcher mat = parPat.matcher(pm.getNodeExpressionString(child));
        while (mat.find()) {
            String curGroup = mat.group();
            if (discParent) {
                // Strip the leading "Switch(<name>," and the trailing ")".
                curGroup = curGroup.substring(("Switch(" + parent.getName()).length() + 1, curGroup.length() - 1);
                paramList.addAll(Arrays.asList(curGroup.split(",")));
            } else {
                paramList.add(curGroup.split("\\*")[0]);
            }
        }
        return paramList;
    }

    /**
     * Returns {@code L} random values drawn from U(0,1), mean-centred and scaled so that the
     * largest magnitude equals {@code |w|}. If all centred values are 0 (e.g. {@code L == 1}),
     * zeros are returned instead of NaN.
     */
    public static double[] generateMixedEdgeParams(double w, int L) {
        double[] vec = new double[L];
        RandomUtil ru = RandomUtil.getInstance();
        for (int i = 0; i < L; ++i) {
            vec[i] = ru.nextUniform(0.0, 1.0);
        }
        double vMean = StatUtils.mean(vec);
        double vMax = 0.0;
        for (int i = 0; i < L; ++i) {
            vec[i] = vec[i] - vMean;
            if (FastMath.abs(vec[i]) > FastMath.abs(vMax)) {
                vMax = vec[i];
            }
        }
        if (vMax == 0.0) {
            // Nothing to scale; avoids 0 * Infinity = NaN.
            return vec;
        }
        double scale = w / FastMath.abs(vMax);
        for (int i = 0; i < L; ++i) {
            vec[i] = vec[i] * scale;
        }
        return vec;
    }

    /**
     * Like {@link #allEdgeStats(Graph, Graph, HashMap)}. Node distributions are taken from
     * {@code pE}: {@link DiscreteVariable}s are {@code "Disc"}, everything else is
     * {@code "Norm"}.
     */
    public static int[][] allEdgeStats(Graph pT, Graph pE) {
        HashMap<String, String> nd = new HashMap<>();
        for (Node n : pE.getNodes()) {
            nd.put(n.getName(), n instanceof DiscreteVariable ? "Disc" : "Norm");
        }
        return MixedUtils.allEdgeStats(pT, pE, nd);
    }

    /**
     * Compares an estimated graph {@code pE} with a true graph {@code pT}; nodes are matched by
     * name. Row 0 counts continuous–continuous edges ({@code "Norm"}/{@code "Norm"}), row 2
     * discrete–discrete ({@code "Disc"}/{@code "Disc"}), row 1 all others. Columns follow
     * {@link #EdgeStatHeader}:
     * <ul>
     *   <li>0 TD: both directed, same direction</li>
     *   <li>1 TU: both present, neither directed</li>
     *   <li>2 FL: both directed, opposite direction</li>
     *   <li>3 FD: estimated directed, true not directed</li>
     *   <li>4 FU: estimated not directed, true directed</li>
     *   <li>5/6 FPD/FPU: estimated edge absent from the true graph (directed / not)</li>
     *   <li>7/8 FND/FNU: true edge absent from the estimate (directed / not)</li>
     *   <li>9 Bidir: estimated edges with arrowheads at both ends</li>
     * </ul>
     *
     * @return a 3 x 10 table of counts
     */
    public static int[][] allEdgeStats(Graph pT, Graph pE, HashMap<String, String> nodeDists) {
        int[][] stats = new int[3][10];
        for (Edge eT : pT.getEdges()) {
            Node n1 = pE.getNode(eT.getNode1().getName());
            Node n2 = pE.getNode(eT.getNode2().getName());
            int[] row = stats[edgeType(nodeDists, n1, n2)];
            Edge eE = pE.getEdge(n1, n2);
            if (eE == null) {
                ++row[eT.isDirected() ? 7 : 8];
            } else if (eE.isDirected()) {
                if (!eT.isDirected()) {
                    ++row[3];
                } else if (eT.pointsTowards(eT.getNode1()) == eE.pointsTowards(n1)) {
                    ++row[0];
                } else {
                    ++row[2];
                }
            } else {
                ++row[eT.isDirected() ? 4 : 1];
            }
        }
        for (Edge eE : pE.getEdges()) {
            Node n1 = pT.getNode(eE.getNode1().getName());
            Node n2 = pT.getNode(eE.getNode2().getName());
            int[] row = stats[edgeType(nodeDists, n1, n2)];
            if (eE.getEndpoint1() == Endpoint.ARROW && eE.getEndpoint2() == Endpoint.ARROW) {
                ++row[9];
            }
            if (pT.getEdge(n1, n2) == null) {
                ++row[eE.isDirected() ? 5 : 6];
            }
        }
        return stats;
    }

    /** Row index for an edge: 0 if both ends are "Norm", 2 if both are "Disc", else 1. */
    private static int edgeType(Map<String, String> nodeDists, Node n1, Node n2) {
        String d1 = nodeDists.get(n1.getName());
        String d2 = nodeDists.get(n2.getName());
        if (d1.equals("Norm") && d2.equals("Norm")) {
            return 0;
        }
        if (d1.equals("Disc") && d2.equals("Disc")) {
            return 2;
        }
        return 1;
    }

    /**
     * Returns a new graph with the same edges as {@code g}. Each node whose entry in {@code m}
     * is positive is replaced by a {@link DiscreteVariable} with that many categories. The node
     * list of {@code g} is copied, not modified.
     *
     * @throws NullPointerException if a node name is missing from {@code m}
     */
    public static Graph makeMixedGraph(Graph g, Map<String, Integer> m) {
        List<Node> nodes = new ArrayList<>(g.getNodes());
        for (int i = 0; i < nodes.size(); ++i) {
            Node n = nodes.get(i);
            int nL = m.get(n.getName());
            if (nL > 0) {
                nodes.set(i, new DiscreteVariable(n.getName(), nL));
            }
        }
        EdgeListGraph outG = new EdgeListGraph(nodes);
        for (Edge e : g.getEdges()) {
            Node n1 = outG.getNode(e.getNode1().getName());
            Node n2 = outG.getNode(e.getNode2().getName());
            outG.addEdge(new Edge(n1, n2, e.getEndpoint1(), e.getEndpoint2()));
        }
        return outG;
    }

    /** Formats {@code arr} with tab-separated columns and one newline-terminated line per row. */
    public static String stringFrom2dArray(int[][] arr) {
        StringBuilder out = new StringBuilder();
        for (int[] row : arr) {
            for (int j = 0; j < row.length; ++j) {
                if (j > 0) {
                    out.append('\t');
                }
                out.append(row[j]);
            }
            out.append('\n');
        }
        return out.toString();
    }

    /**
     * Loads a tab-delimited continuous data file {@code dir/filename}. The {@code true}
     * argument presumably marks a header row ({@link #loadDelim} passes {@code false}).
     */
    public static DataSet loadDataSet(String dir, String filename) throws IOException {
        File file = new File(dir, filename);
        return SimpleDataLoader.loadContinuousData(file, "//", '\"', "*", true, Delimiter.TAB, false);
    }

    /** Same as {@link #loadDataSet} but passes {@code false} where it passes {@code true} (presumably no header row). */
    public static DataSet loadDelim(String dir, String filename) throws IOException {
        File file = new File(dir, filename);
        return SimpleDataLoader.loadContinuousData(file, "//", '\"', "*", false, Delimiter.TAB, false);
    }

    /** Maps each node name to its number of categories ({@link DiscreteVariable}) or 0 (any other node). */
    public static Map<String, Integer> getNodeDists(Graph g) {
        Map<String, Integer> map = new HashMap<>();
        for (Node n : g.getNodes()) {
            map.put(n.getName(), n instanceof DiscreteVariable ? ((DiscreteVariable) n).getNumCategories() : 0);
        }
        return map;
    }

    /** Identical to {@link #loadDataSet(String, String)}. */
    public static DataSet loadData(String dir, String filename) throws IOException {
        File file = new File(dir, filename);
        return SimpleDataLoader.loadContinuousData(file, "//", '\"', "*", true, Delimiter.TAB, false);
    }

    /**
     * Reports whether any two columns of {@code ds} (discrete columns treated as continuous)
     * have a correlation of exactly +1 or -1.
     *
     * @param verbose if true, print every colinear pair instead of returning at the first one
     * @return true if at least one colinear pair exists
     */
    public static boolean isColinear(DataSet ds, boolean verbose) {
        List<Node> nodes = ds.getVariables();
        boolean isco = false;
        CorrelationMatrix cor = new CorrelationMatrix(MixedUtils.makeContinuousData(ds));
        for (int i = 0; i < nodes.size(); ++i) {
            for (int j = i + 1; j < nodes.size(); ++j) {
                // Perfect negative correlation is colinearity too.
                if (FastMath.abs(cor.getValue(i, j)) != 1.0) {
                    continue;
                }
                if (!verbose) {
                    return true;
                }
                isco = true;
                System.out.println("Colinearity found between: " + nodes.get(i).getName() + " and " + nodes.get(j).getName());
            }
        }
        return isco;
    }

    /**
     * Returns an n x n adjacency matrix in {@code graph.getNodes()} order.
     * <ul>
     *   <li>A directed edge tail -> head sets {@code [tail][head] = directedWeight}.</li>
     *   <li>Any edge that is not directed (including bidirected) sets both {@code [a][b]} and
     *       {@code [b][a]} to {@code undirectedWeight}.</li>
     * </ul>
     */
    public static DoubleMatrix2D graphToMatrix(Graph graph, double undirectedWeight, double directedWeight) {
        int n = graph.getNumNodes();
        DoubleMatrix2D matrix = DoubleFactory2D.dense.make(n, n, 0.0);
        Map<Node, Integer> map = nodeIndexMap(graph);
        for (Edge edge : graph.getEdges()) {
            int i1 = map.get(edge.getNode1());
            int i2 = map.get(edge.getNode2());
            if (!edge.isDirected() || edge.getEndpoint1() == Endpoint.ARROW && edge.getEndpoint2() == Endpoint.ARROW) {
                matrix.set(i1, i2, undirectedWeight);
                matrix.set(i2, i1, undirectedWeight);
            } else if (edge.pointsTowards(edge.getNode1())) {
                matrix.set(i2, i1, directedWeight);
            } else {
                matrix.set(i1, i2, directedWeight);
            }
        }
        return matrix;
    }

    /** Returns the symmetric 0/1 adjacency matrix of {@code graph} in {@code graph.getNodes()} order. */
    public static DoubleMatrix2D skeletonToMatrix(Graph graph) {
        int n = graph.getNumNodes();
        DoubleMatrix2D matrix = DoubleFactory2D.dense.make(n, n, 0.0);
        Map<Node, Integer> map = nodeIndexMap(graph);
        for (Edge edge : graph.getEdges()) {
            int i1 = map.get(edge.getNode1());
            int i2 = map.get(edge.getNode2());
            matrix.set(i1, i2, 1.0);
            matrix.set(i2, i1, 1.0);
        }
        return matrix;
    }

    /** Maps each node of {@code graph} to its position in {@code graph.getNodes()}. */
    private static Map<Node, Integer> nodeIndexMap(Graph graph) {
        Map<Node, Integer> map = new HashMap<>();
        int i = 0;
        for (Node node : graph.getNodes()) {
            map.put(node, i++);
        }
        return map;
    }

    /** Equivalent to {@code graphToMatrix(graph, 1.0, 1.0)}. */
    public static DoubleMatrix2D graphToMatrix(Graph graph) {
        return MixedUtils.graphToMatrix(graph, 1.0, 1.0);
    }

    /**
     * Creates an independence test by name.
     * <ul>
     *   <li>{@code "tlin"} / {@code "tlog"}: {@link IndTestMixedMultipleTTest} with linear
     *       preference on / off.</li>
     *   <li>Otherwise: the class {@code name} is loaded from {@code edu.cmu.tetrad.search},
     *       then {@code edu.pitt.csb.mgm}, and its {@code (DataSet, double)} constructor is
     *       invoked.</li>
     * </ul>
     *
     * @return the test, or {@code null} if the class has no such constructor or construction
     *         fails
     * @throws IllegalArgumentException if the class is found in neither package
     */
    public static IndependenceTest IndTestFromString(String name, DataSet data, double alpha) {
        if (name.equals("tlin") || name.equals("tlog")) {
            IndTestMixedMultipleTTest tTest = new IndTestMixedMultipleTTest(data, alpha);
            tTest.setPreferLinear(name.equals("tlin"));
            return tTest;
        }
        Class<?> cl = null;
        try {
            cl = Class.forName("edu.cmu.tetrad.search." + name);
        } catch (ClassNotFoundException e) {
            System.out.println("Not found: edu.cmu.tetrad.search." + name);
        } catch (RuntimeException e) {
            e.printStackTrace();
        }
        if (cl == null) {
            try {
                cl = Class.forName("edu.pitt.csb.mgm." + name);
            } catch (ClassNotFoundException e) {
                throw new IllegalArgumentException("-test argument not recognized", e);
            } catch (RuntimeException e) {
                e.printStackTrace();
            }
        }
        if (cl == null) {
            System.err.println("Independence Test: " + name + " not found");
            return null;
        }
        try {
            Constructor<?> con = cl.getConstructor(DataSet.class, Double.TYPE);
            return (IndependenceTest) con.newInstance(data, alpha);
        } catch (NoSuchMethodException e) {
            System.err.println("Independence Test: " + name + " not found");
        } catch (Exception e) {
            System.err.println("Independence Test: " + name + " found but not constructed");
            e.printStackTrace();
        }
        return null;
    }

    /** Small manual demo: builds a mixed graph, PM and IM, and simulates a few samples. */
    public static void main(String[] args) {
        Graph g = GraphUtils.convert("X1-->X2,X2-->X3,X3-->X4, X5-->X4");
        HashMap<String, Integer> nd = new HashMap<>();
        nd.put("X1", 0);
        nd.put("X2", 0);
        nd.put("X3", 4);
        nd.put("X4", 4);
        nd.put("X5", 0);
        g = MixedUtils.makeMixedGraph(g, nd);
        GeneralizedSemPm pm = MixedUtils.GaussianCategoricalPm(g, "Split(-1.5,-1,1,1.5)");
        System.out.println(pm);
        System.out.println("STARTS WITH");
        System.out.println(pm.getStartsWithParameterTemplate("C"));
        MixedUtils.setStartsWith("C", "Split(-.9,-.5,.5,.9)", pm);
        System.out.println("STARTS WITH");
        System.out.println(pm.getStartsWithParameterTemplate("C"));
        System.out.println(pm);
        GeneralizedSemIm im = MixedUtils.GaussianCategoricalIm(pm);
        System.out.println(im);
        int samps = 15;
        DataSet ds = im.simulateDataFisher(samps);
        System.out.println(ds);
        System.out.println("num cats " + ((DiscreteVariable) g.getNode("X4")).getNumCategories());
    }
}

