/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.matrix.DoubleMatrix2D;
import edu.cmu.tetrad.data.Clusters;
import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.data.KnowledgeEdge;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.SearchLogUtils;
import edu.cmu.tetrad.sem.MimBuildEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.ProbUtils;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Conditional independence test among the latent variables of a MIMBuild measurement model.
 *
 * <p>Latents are named {@code _L1}, {@code _L2}, ... from the cluster ids of a {@link Clusters}
 * object, and each latent's measured indicators are the variables assigned to that cluster. In the
 * default (MLE) mode, {@code x _||_ y | z} is decided by fitting two measurement models with
 * {@link MimBuildEstimator}, one without and one with an {@code x -> y} edge. The chi-square
 * difference is then compared against the significance level. A bootstrap mode exists as well. It
 * reads per-sample latent covariance matrices prepared by {@link #bootstrap()}. No setter for the
 * test type is visible in this class, so it is selected only internally.
 */
public final class IndTestMimBuild implements IndependenceTest {
    public static final int MIMBUILD_MLE = 0;
    public static final int MIMBUILD_2SLS = 1;
    public static final int MIMBUILD_BOOTSTRAP = 2;
    public static final int MIMBUILD_GES_ABIC = 0;
    public static final int MIMBUILD_GES_SBIC = -1;
    public static final int MIMBUILD_PC = 1;

    /** Raw data; null when this test was constructed from a covariance matrix. */
    private DataSet dataSet;
    private CovarianceMatrix covMatrix;
    private List<Node> vars;
    /** Required latent -> measured edges derived from the clustering. */
    private Knowledge measurements;
    /** Latent names in order of first appearance (reordered to SemPm order before bootstrapping). */
    private List<String> latents;
    /** Graph built by the most recent MLE test. */
    private SemGraph graph;
    private double sig = Double.NaN;
    /** Maps each latent name to the names of its measured indicators. */
    private Hashtable<String, List<String>> measureTable;
    private int testType;
    private int algorithmType;
    private int numBootstrapSamples;
    /** Per-sample latent covariance matrices; null until {@link #bootstrap()} is called. */
    private double[][][] bootstrapSamples;
    /** p-value of the most recent MLE test, or NaN if none is available. */
    private double lastPValue = Double.NaN;

    /**
     * Builds the test from raw data; the covariance matrix is computed from {@code dataSet}.
     *
     * @param dataSet the measured data
     * @param sig significance level in [0, 1]
     * @param measurements assignment of measured variables to clusters (latents)
     * @throws IllegalArgumentException if {@code sig} is outside [0, 1]
     */
    public IndTestMimBuild(DataSet dataSet, double sig, Clusters measurements) {
        this.setData(dataSet);
        this.vars = dataSet.getVariables();
        this.initCommon(sig, measurements);
    }

    /**
     * Builds the test from a covariance matrix. Bootstrap mode is unavailable in this case because
     * there is no raw data to resample.
     *
     * @param covMatrix covariance matrix over the measured variables
     * @param sig significance level in [0, 1]
     * @param measurements assignment of measured variables to clusters (latents)
     * @throws IllegalArgumentException if {@code sig} is outside [0, 1]
     */
    public IndTestMimBuild(CovarianceMatrix covMatrix, double sig, Clusters measurements) {
        this.setCovMatrix(covMatrix);
        this.vars = covMatrix.getVariables();
        this.initCommon(sig, measurements);
    }

    /** Shared constructor tail: builds the latent/indicator tables and sets defaults. */
    private void initCommon(double sig, Clusters measurements) {
        this.latents = new ArrayList<String>();
        this.measureTable = new Hashtable<String, List<String>>();
        this.setMeasurementsSource(measurements);
        this.setSignificance(sig);
        this.testType = MIMBUILD_MLE;
        this.algorithmType = MIMBUILD_GES_ABIC;
        this.numBootstrapSamples = 100;
    }

    public IndependenceTest indTestSubset(List vars) {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the names of all latents and measured variables that appear in the required
     * measurement edges, each exactly once, in order of first appearance.
     */
    public List<String> getAllVariablesStrings() {
        List<String> names = new ArrayList<String>();
        Iterator<KnowledgeEdge> it = this.measurements.requiredEdgesIterator();
        while (it.hasNext()) {
            KnowledgeEdge edge = it.next();
            addIfAbsent(names, edge.getFrom());
            // A measured variable listed under several clusters must not be reported twice.
            addIfAbsent(names, edge.getTo());
        }
        return names;
    }

    /** Appends {@code name} to {@code names} unless it is already present. */
    private static void addIfAbsent(List<String> names, String name) {
        if (!names.contains(name)) {
            names.add(name);
        }
    }

    /**
     * Returns a new {@link ContinuousVariable} for every name reported by
     * {@link #getAllVariablesStrings()}, in the same order.
     */
    public List<Node> getVariableList() {
        List<Node> outputList = new ArrayList<Node>();
        for (String name : this.getAllVariablesStrings()) {
            outputList.add(new ContinuousVariable(name));
        }
        return outputList;
    }

    /** Replaces the raw data and recomputes the covariance matrix from it. */
    public void setData(DataSet dataSet) {
        this.dataSet = dataSet;
        this.covMatrix = new CovarianceMatrix(dataSet);
    }

    @Override
    public DataSet getData() {
        return this.dataSet;
    }

    public void setCovMatrix(CovarianceMatrix covMatrix) {
        this.covMatrix = covMatrix;
    }

    public CovarianceMatrix getCovMatrix() {
        return this.covMatrix;
    }

    /**
     * Sets how many bootstrap samples {@link #bootstrap()} will draw.
     *
     * @throws IllegalArgumentException if {@code numSamples} is less than 1
     */
    public void setNumBootstrapSamples(int numSamples) {
        if (numSamples < 1) {
            throw new IllegalArgumentException("Number of bootstrap samples must be at least 1: " + numSamples);
        }
        this.numBootstrapSamples = numSamples;
    }

    public int getNumBootstrapSamples() {
        return this.numBootstrapSamples;
    }

    public Knowledge getMeasurements() {
        return this.measurements;
    }

    /**
     * Rebuilds {@code latents} and the latent -> indicators table from the required edges in
     * {@link #getMeasurements()}.
     */
    public void initMeasurements() {
        this.latents.clear();
        this.measureTable.clear();
        Iterator<KnowledgeEdge> it = this.measurements.requiredEdgesIterator();
        while (it.hasNext()) {
            KnowledgeEdge edge = it.next();
            String latent = edge.getFrom();
            List<String> measures = this.measureTable.get(latent);
            if (measures == null) {
                measures = new ArrayList<String>();
                this.measureTable.put(latent, measures);
                this.latents.add(latent);
            }
            measures.add(edge.getTo());
        }
    }

    /**
     * Derives the measurement knowledge from a clustering. Each variable in cluster {@code k}
     * gets a required edge from latent {@code "_L" + (k + 1)}. Variables mapped to a null
     * cluster id are skipped.
     */
    public void setMeasurementsSource(Clusters clusters) {
        this.measurements = new Knowledge();
        for (String varName : clusters.getClusters().keySet()) {
            Integer clusterId = clusters.getClusters().get(varName);
            if (clusterId == null) {
                continue;
            }
            String latentName = "_L" + (clusterId + 1);
            this.measurements.setEdgeRequired(latentName, varName, true);
        }
        this.initMeasurements();
    }

    /**
     * Sets the significance level.
     *
     * @throws IllegalArgumentException if {@code sig} is NaN or outside [0, 1]
     */
    public void setSignificance(double sig) {
        if (!(sig >= 0.0) || !(sig <= 1.0)) {
            throw new IllegalArgumentException("Significance out of range.");
        }
        this.sig = sig;
    }

    public double getSignificance() {
        return this.sig;
    }

    /**
     * Sets the algorithm type. Accepts MIMBUILD_GES_SBIC, MIMBUILD_GES_ABIC or MIMBUILD_PC.
     *
     * @throws IllegalArgumentException for any other value
     */
    public void setAlgorithmType(int algoType) {
        if (algoType != MIMBUILD_GES_SBIC && algoType != MIMBUILD_GES_ABIC && algoType != MIMBUILD_PC) {
            throw new IllegalArgumentException("Invalid algorithm test.");
        }
        this.algorithmType = algoType;
    }

    public int getAlgorithmType() {
        return this.algorithmType;
    }

    @Override
    public List<Node> getVariables() {
        return Collections.unmodifiableList(this.vars);
    }

    /**
     * Tests whether latent {@code x} is independent of latent {@code y} given latents {@code z}.
     *
     * <p>In MLE mode, two models are fitted to the covariance submatrix of all involved
     * indicators. In both models every z points into x and y, and the z's form a complete DAG in
     * list order. Each latent points into its indicators. The second model adds {@code x -> y}.
     * The p-value is {@code 1 - chisqCdf(chi2_without - chi2_with, 1)}, and independence is
     * reported when it exceeds the significance level.
     *
     * @throws IllegalArgumentException if a latent has no recorded indicators
     * @throws UnsupportedOperationException in 2SLS mode
     */
    @Override
    public boolean isIndependent(Node x, Node y, List<Node> z) {
        System.out.println("\n\n************************************************");
        System.out.println(" Testing " + x + " against " + y);
        System.out.print(" Conditional on " + z);
        System.out.println();
        System.out.println("************************************************");
        if (this.testType == MIMBUILD_BOOTSTRAP) {
            this.lastPValue = Double.NaN;
            return this.isIndependentBootstrap(x, y, z);
        }
        if (this.testType == MIMBUILD_2SLS) {
            throw new UnsupportedOperationException("Not currently supported!");
        }
        if (this.testType != MIMBUILD_MLE) {
            throw new IllegalStateException("Unknown test type: " + this.testType);
        }

        // Latent structure: x, y, then each z with edges into x and y.
        this.graph = new SemGraph();
        Node nodeX = addLatentNode(this.graph, x.getName());
        Node nodeY = addLatentNode(this.graph, y.getName());
        Node[] nodeZ = new Node[z.size()];
        int k = 0;
        for (Node current : z) {
            nodeZ[k] = addLatentNode(this.graph, current.getName());
            this.graph.addDirectedEdge(nodeZ[k], nodeX);
            this.graph.addDirectedEdge(nodeZ[k], nodeY);
            ++k;
        }
        for (int p = 0; p < nodeZ.length - 1; ++p) {
            for (int q = p + 1; q < nodeZ.length; ++q) {
                this.graph.addDirectedEdge(nodeZ[p], nodeZ[q]);
            }
        }

        // Measurement part; indicator names are collected in the same order for the submatrix.
        List<String> subset = new ArrayList<String>();
        this.addIndicators(this.graph, nodeX, subset);
        this.addIndicators(this.graph, nodeY, subset);
        for (Node latentZ : nodeZ) {
            this.addIndicators(this.graph, latentZ, subset);
        }
        CovarianceMatrix newCov = this.covMatrix.getSubmatrix(subset.toArray(new String[subset.size()]));

        // The first PM gets a copy so that adding x -> y below does not alter it.
        SemPm pm = new SemPm(new SemGraph(this.graph));
        MimBuildEstimator estimator = MimBuildEstimator.newInstance(newCov, pm);
        System.out.println("\nEvaluating model without edge, MLE...");
        estimator.estimate();
        SemIm semWithoutEdge = estimator.getEstimatedSem();
        System.out.println("Prob significance = " + semWithoutEdge.getPValue());

        this.graph.addDirectedEdge(nodeX, nodeY);
        pm = new SemPm(this.graph);
        estimator = MimBuildEstimator.newInstance(newCov, pm);
        System.out.println("Evaluating model with edge, MLE...");
        estimator.estimate();
        SemIm semWithEdge = estimator.getEstimatedSem();
        System.out.println("Prob significance = " + semWithEdge.getPValue());

        // Nested models differing by one free parameter: chi-square difference test with 1 df.
        double pValue = 1.0 - ProbUtils.chisqCdf(semWithoutEdge.getChiSquare() - semWithEdge.getChiSquare(), 1.0);
        this.lastPValue = pValue;
        boolean independent = pValue > this.sig;
        if (independent) {
            TetradLogger.getInstance().log("independencies", SearchLogUtils.independenceFactMsg(x, y, z, pValue));
            System.out.println("Independent!");
        } else {
            TetradLogger.getInstance().log("dependencies", SearchLogUtils.dependenceFactMsg(x, y, z, pValue));
            System.out.println("NOT independent!");
        }
        return independent;
    }

    @Override
    public boolean isIndependent(Node x, Node y, Node ... z) {
        List<Node> zList = Arrays.asList(z);
        return this.isIndependent(x, y, zList);
    }

    @Override
    public boolean isDependent(Node x, Node y, List<Node> z) {
        return !this.isIndependent(x, y, z);
    }

    @Override
    public boolean isDependent(Node x, Node y, Node ... z) {
        List<Node> zList = Arrays.asList(z);
        return this.isDependent(x, y, zList);
    }

    /**
     * Bootstrap version of the test. For each stored sample it inverts the latent covariance
     * submatrix over (x, y, z...) and reads off the partial correlation of x and y. It then
     * checks whether the mean over samples, divided by its bootstrap standard deviation, is
     * within +/-1.96.
     *
     * @throws IllegalStateException if {@link #bootstrap()} has not been called
     * @throws IllegalArgumentException if a node is not a known latent
     * @throws RuntimeException if a covariance submatrix is singular
     */
    public boolean isIndependentBootstrap(Node x, Node y, List<Node> z) {
        if (this.bootstrapSamples == null) {
            throw new IllegalStateException("Call bootstrap() before running bootstrap independence tests.");
        }
        int size = z.size() + 2;
        int[] indices = new int[size];
        indices[0] = this.latentIndex(x);
        indices[1] = this.latentIndex(y);
        for (int i = 0; i < z.size(); ++i) {
            indices[i + 2] = this.latentIndex(z.get(i));
        }
        // Use the number of samples actually stored; numBootstrapSamples may have changed since.
        int numSamples = this.bootstrapSamples.length;
        double sumR = 0.0;
        double sumR2 = 0.0;
        for (int iter = 0; iter < numSamples; ++iter) {
            double[][] submatrix = new double[size][size];
            for (int i = 0; i < size; ++i) {
                for (int j = 0; j < size; ++j) {
                    submatrix[i][j] = this.bootstrapSamples[iter][indices[i]][indices[j]];
                }
            }
            try {
                submatrix = MatrixUtils.inverse(submatrix);
            }
            catch (Exception e) {
                throw new RuntimeException("Matrix singularity detected while using correlations \nto check for independence; probably due to collinearity \nin the data. The independence fact being checked was \n" + x + " _||_ " + y + " | " + z + ".", e);
            }
            // Partial correlation from the precision matrix.
            double r = -1.0 * submatrix[0][1] / Math.sqrt(submatrix[0][0] * submatrix[1][1]);
            sumR += r;
            sumR2 += r * r;
        }
        double mean = sumR / (double) numSamples;
        // E[r^2] - E[r]^2 can dip slightly below zero from rounding.
        double variance = Math.max(0.0, sumR2 / (double) numSamples - mean * mean);
        System.out.println("Statistic: " + bootstrapStatistic(mean, variance));
        return this.isZeroBootstrap(mean, variance, this.sig);
    }

    /**
     * Returns true when {@code mean / sqrt(variance)} lies strictly inside +/-1.96.
     * NOTE(review): {@code sig} is currently ignored; 1.96 is the two-sided 5% normal cutoff.
     */
    boolean isZeroBootstrap(double mean, double variance, double sig) {
        return Math.abs(bootstrapStatistic(mean, variance)) < 1.96;
    }

    /** z-statistic of a bootstrap mean: the mean divided by the bootstrap standard deviation. */
    private static double bootstrapStatistic(double mean, double variance) {
        return mean / Math.sqrt(variance);
    }

    /**
     * Returns the p-value of the most recent MLE test. Returns NaN if no MLE test has run yet,
     * or if the most recent test used the bootstrap path.
     */
    @Override
    public double getPValue() {
        return this.lastPValue;
    }

    @Override
    public List<String> getVariableNames() {
        List<String> variableNames = new ArrayList<String>();
        for (Node variable : this.getVariables()) {
            variableNames.add(variable.getName());
        }
        return variableNames;
    }

    public boolean determines(List z, Node x1) {
        throw new UnsupportedOperationException("This independence test does not test whether Z determines X for list Z of variable and variable X.");
    }

    @Override
    public double getAlpha() {
        return this.sig;
    }

    /**
     * Sets the significance level; equivalent to {@link #setSignificance(double)}.
     *
     * @throws IllegalArgumentException if {@code alpha} is NaN or outside [0, 1]
     */
    @Override
    public void setAlpha(double alpha) {
        this.setSignificance(alpha);
    }

    /** Returns the variable with the given name, or null if there is none. */
    @Override
    public Node getVariable(String name) {
        for (Node variable : this.getVariables()) {
            if (variable.getName().equals(name)) {
                return variable;
            }
        }
        return null;
    }

    /**
     * Draws {@link #getNumBootstrapSamples()} bootstrap samples of the raw data. For each one it
     * stores the latent covariance matrix implied by a fitted full measurement model.
     *
     * @throws IllegalStateException if this test was built from a covariance matrix
     */
    public void bootstrap() {
        this.bootstrapSamples = this.getBootstrapSamples(this.numBootstrapSamples);
    }

    /**
     * Builds a measurement model in which the latents form a complete DAG (in {@code latents}
     * order) and each latent points into its indicators. It then fits that model to
     * {@code numSamples} row-resampled copies of the data. Entry {@code [s][a][b]} of the result
     * is the implied covariance of latents {@code a} and {@code b} in sample {@code s}. Indices
     * follow {@code latents} after {@link #fixLatentOrder} has aligned it with the SemPm.
     */
    private double[][][] getBootstrapSamples(int numSamples) {
        DataSet dataContinuous = this.getData();
        if (dataContinuous == null) {
            throw new IllegalStateException("Bootstrapping requires raw data; this test was built from a covariance matrix.");
        }
        SemGraph fullGraph = new SemGraph();
        int totalLatents = this.latents.size();
        Node[] latentsArray = new Node[totalLatents];
        for (int k = 0; k < totalLatents; ++k) {
            latentsArray[k] = addLatentNode(fullGraph, this.latents.get(k));
        }
        for (int p = 0; p < totalLatents - 1; ++p) {
            for (int q = p + 1; q < totalLatents; ++q) {
                fullGraph.addDirectedEdge(latentsArray[p], latentsArray[q]);
            }
        }
        for (Node latent : latentsArray) {
            this.addIndicators(fullGraph, latent, null);
        }
        SemPm pm = new SemPm(fullGraph);
        this.fixLatentOrder(pm);

        // Positions of the latent variables among the PM's variable nodes. These are the rows and
        // columns of the implied covariance matrix. The model does not change, so compute once.
        int[] latentPositions = new int[totalLatents];
        int position = 0;
        int found = 0;
        for (Node node : pm.getVariableNodes()) {
            if (node.getNodeType() == NodeType.LATENT) {
                latentPositions[found++] = position;
            }
            ++position;
        }

        double[][][] samples = new double[numSamples][totalLatents][totalLatents];
        int sampleSize = dataContinuous.getNumRows();
        int numColumns = dataContinuous.getNumColumns();
        ColtDataSet resampled = new ColtDataSet(sampleSize, this.getVariables());
        for (int iter = 0; iter < numSamples; ++iter) {
            // Row i of the resampled data is a uniformly drawn row of the original data.
            for (int i = 0; i < sampleSize; ++i) {
                int source = RandomUtil.getInstance().nextInt(sampleSize);
                for (int j = 0; j < numColumns; ++j) {
                    resampled.setDouble(i, j, dataContinuous.getDouble(source, j));
                }
            }
            System.out.println("********\n Estimating latent covariance matrix #" + iter + "...");
            MimBuildEstimator estimator = MimBuildEstimator.newInstance(resampled, pm);
            estimator.estimate();
            DoubleMatrix2D implCovarC = estimator.getEstimatedSem().getImplCovar();
            double[][] implCov = implCovarC.toArray();
            for (int a = 0; a < found; ++a) {
                for (int b = 0; b < found; ++b) {
                    samples[iter][a][b] = implCov[latentPositions[a]][latentPositions[b]];
                }
            }
        }
        return samples;
    }

    /**
     * Reorders {@code latents} to match the order of the latent nodes in {@code semPm}. Latent
     * nodes not already in {@code latents} are ignored. The new order is printed.
     */
    private void fixLatentOrder(SemPm semPm) {
        List<String> newLatents = new ArrayList<String>(this.latents.size());
        for (Node pmNext : semPm.getVariableNodes()) {
            if (pmNext.getNodeType() == NodeType.LATENT && this.latents.contains(pmNext.getName())) {
                newLatents.add(pmNext.getName());
            }
        }
        this.latents = newLatents;
        System.out.println(this.latents.toString());
    }

    /** Adds a LATENT node named {@code name} to {@code g} and returns it. */
    private static Node addLatentNode(SemGraph g, String name) {
        GraphNode node = new GraphNode(name);
        node.setNodeType(NodeType.LATENT);
        g.addNode(node);
        return node;
    }

    /**
     * Adds a MEASURED node with an edge from {@code latent} for each of the latent's indicators.
     * Appends the indicator names to {@code names} when it is non-null.
     */
    private void addIndicators(SemGraph g, Node latent, List<String> names) {
        for (String measure : this.measuresOf(latent.getName())) {
            GraphNode measured = new GraphNode(measure);
            measured.setNodeType(NodeType.MEASURED);
            g.addNode(measured);
            g.addDirectedEdge(latent, measured);
            if (names != null) {
                names.add(measure);
            }
        }
    }

    /** Returns the indicator names of {@code latent}; fails clearly instead of with an NPE. */
    private List<String> measuresOf(String latent) {
        List<String> measures = this.measureTable.get(latent);
        if (measures == null) {
            throw new IllegalArgumentException("No measured indicators are recorded for latent " + latent + ".");
        }
        return measures;
    }

    /** Index of {@code latent} (by name) in {@code latents}. */
    private int latentIndex(Node latent) {
        int index = this.latents.indexOf(latent.getName());
        if (index < 0) {
            throw new IllegalArgumentException("Unknown latent variable: " + latent.getName());
        }
        return index;
    }

    public String toString() {
        return "MimBuild independence test";
    }
}

