/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.LayoutUtil;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.search.Cpc;
import edu.cmu.tetrad.search.test.IndTestTrekSep;
import edu.cmu.tetrad.util.Matrix;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.PowellOptimizer;

/**
 * Learns the structure among the latent variables of a clustered measurement model.
 *
 * <p>Each cluster of measured variables is assigned one latent parent. The structure over the
 * latents is found by a {@link Cpc} search driven by {@link IndTestTrekSep}. Separately, the
 * latent covariance matrix is estimated by fitting the model in which every indicator loads on
 * exactly one latent:
 * {@code cov(x_ik, x_jl) = loading_ik * loading_jl * latentCov(i, j)}, plus an error variance
 * on the diagonal.
 */
public class MimbuildTrek {
    /** Shared empty "error variance" vector for the fit that has no error-variance parameters. */
    private static final double[] NO_DELTA = new double[0];

    /** Clusters actually modeled (undersized ones removed), index-aligned with {@link #latents}. */
    private List<List<Node>> clustering;
    /** Structure over the latent variables found by the most recent search. */
    private Graph structureGraph;
    /** Significance level handed to the trek-separation independence test. */
    private double alpha = 0.001;
    /** Background knowledge handed to the CPC search. */
    private Knowledge knowledge = new Knowledge();
    /** Estimated covariance matrix over the latent variables. */
    private ICovarianceMatrix latentsCov;
    /** Objective value reported by the most recent optimization. */
    private double minimum;
    /** p-value of the chi-square fit statistic, or NaN when the model has no degrees of freedom. */
    private double pValue;
    /** Latent nodes, one per retained cluster. */
    private List<Node> latents;
    /** Clusters with fewer indicators than this are dropped before fitting. */
    private int minClusterSize = 3;

    /**
     * Runs the search.
     *
     * @param clustering  clusters of measured variables; cluster {@code i} is measured by latent
     *                    {@code latentNames.get(i)}
     * @param latentNames one name per cluster, used for the created latent nodes
     * @param measuresCov covariance matrix over (at least) all clustered measured variables
     * @return the structure graph over the latent variables of the retained clusters
     * @throws IllegalArgumentException if the number of latent names differs from the number of
     *                                  clusters
     */
    public Graph search(List<List<Node>> clustering, List<String> latentNames, ICovarianceMatrix measuresCov) {
        if (clustering.size() != latentNames.size()) {
            throw new IllegalArgumentException("Expected one latent name per cluster, but got "
                    + latentNames.size() + " names for " + clustering.size() + " clusters.");
        }

        List<String> allVarNames = new ArrayList<>();
        for (List<Node> cluster : clustering) {
            for (Node node : cluster) {
                allVarNames.add(node.getName());
            }
        }
        measuresCov = measuresCov.getSubmatrix(allVarNames);

        // Re-resolve every clustered node by name against the sub-matrix so that the clusters
        // share Node instances with the covariance matrix used below.
        List<List<Node>> _clustering = new ArrayList<>();
        for (List<Node> cluster : clustering) {
            List<Node> _cluster = new ArrayList<>();
            for (Node node : cluster) {
                _cluster.add(measuresCov.getVariable(node.getName()));
            }
            _clustering.add(_cluster);
        }

        List<Node> latentNodes = defineLatents(latentNames);
        this.latents = latentNodes;
        // Drops undersized clusters together with their latents (in place), keeping both aligned.
        removeSmallClusters(latentNodes, _clustering, getMinClusterSize());
        this.clustering = _clustering;

        Node[][] indicators = new Node[latentNodes.size()][];
        for (int i = 0; i < latentNodes.size(); i++) {
            indicators[i] = _clustering.get(i).toArray(new Node[0]);
        }

        Matrix cov = getCov(measuresCov, latentNodes, indicators);
        this.latentsCov = new CovarianceMatrix(latentNodes, cov, measuresCov.getSampleSize());

        // Use the filtered clustering: it is index-aligned with latentNodes and its nodes are the
        // variables of measuresCov (the caller's original list is neither after filtering).
        Cpc search = new Cpc(new IndTestTrekSep(measuresCov, this.alpha, _clustering, latentNodes));
        search.setKnowledge(this.knowledge);
        Graph graph = search.search();

        this.structureGraph = new EdgeListGraph(graph);
        LayoutUtil.fruchtermanReingoldLayout(this.structureGraph);
        return this.structureGraph;
    }

    /**
     * Returns the clusters used by the most recent search.
     *
     * <p>Clusters smaller than the minimum cluster size have been removed, and the nodes are the
     * variables of the covariance sub-matrix. The returned list is the internal one, not a copy.
     *
     * @return the clusters used by the most recent search, or {@code null} before any search
     */
    public List<List<Node>> getClustering() {
        return this.clustering;
    }

    /**
     * Sets the significance level used by the trek-separation test.
     *
     * @param alpha the significance level
     */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /**
     * Sets background knowledge for the latent structure search.
     *
     * @param knowledge the knowledge; a copy is stored
     */
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = new Knowledge(knowledge);
    }

    /**
     * Returns the estimated latent covariance matrix.
     *
     * @return the estimated latent covariance matrix, or {@code null} before any search
     */
    public ICovarianceMatrix getLatentsCov() {
        return this.latentsCov;
    }

    /**
     * Returns the p-value of the measurement-model fit.
     *
     * @return the p-value of the chi-square fit statistic, or NaN if the fitted model had no
     *         positive degrees of freedom
     */
    public double getpValue() {
        return this.pValue;
    }

    /**
     * Returns the structure graph extended with the measurement model.
     *
     * <p>The result contains the latent structure plus a directed edge from each latent to each of
     * its indicators.
     *
     * @return the full graph
     * @throws IllegalStateException if {@link #search} has not been called yet
     */
    public Graph getFullGraph() {
        if (this.structureGraph == null) {
            throw new IllegalStateException("search() must be called before getFullGraph().");
        }
        EdgeListGraph graph = new EdgeListGraph(this.structureGraph);
        for (int i = 0; i < this.latents.size(); i++) {
            Node latent = this.latents.get(i);
            for (Node measured : this.clustering.get(i)) {
                if (!graph.containsNode(measured)) {
                    graph.addNode(measured);
                }
                graph.addDirectedEdge(latent, measured);
            }
        }
        return graph;
    }

    /**
     * Validates an epsilon value.
     *
     * <p>NOTE: the value is currently neither stored nor used by this class.
     *
     * @param epsilon the value to validate
     * @throws IllegalArgumentException if {@code epsilon < 0}
     */
    public void setEpsilon(double epsilon) {
        if (epsilon < 0.0) {
            throw new IllegalArgumentException("Epsilon must be >= 0: " + epsilon);
        }
    }

    /** Creates one fresh latent {@link GraphNode} per name, in order. */
    private List<Node> defineLatents(List<String> names) {
        List<Node> nodes = new ArrayList<>();
        for (String name : names) {
            GraphNode node = new GraphNode(name);
            node.setNodeType(NodeType.LATENT);
            nodes.add(node);
        }
        return nodes;
    }

    /**
     * Removes, in place, every cluster smaller than {@code minimumSize} together with its latent.
     *
     * <p>The two lists are index-aligned. Iteration runs backwards so that removals do not shift
     * indices still to be visited. Removal is by index, not by {@code equals}, so duplicate clusters
     * or identically named latents cannot cause the wrong entry to be removed.
     */
    private static void removeSmallClusters(List<Node> latents, List<List<Node>> clustering, int minimumSize) {
        for (int i = latents.size() - 1; i >= 0; i--) {
            if (clustering.get(i).size() < minimumSize) {
                clustering.remove(i);
                latents.remove(i);
            }
        }
    }

    /**
     * Fits the measurement model and returns the estimated latent covariance matrix.
     *
     * <p>As a side effect it sets {@link #minimum} and {@link #pValue}.
     *
     * <p>Starting values are: latent variances 1, latent covariances 0.5, loadings 0.5 and error
     * variances 1. A quick fit of latent covariances and loadings (least squares on off-diagonal
     * entries) is followed by a joint fit of all parameters including the error variances.
     */
    private Matrix getCov(ICovarianceMatrix _measurescov, List<Node> latents, Node[][] indicators) {
        if (latents.size() != indicators.length) {
            throw new IllegalArgumentException();
        }
        Matrix measurescov = _measurescov.getMatrix();
        int numLatents = latents.size();

        Matrix latentscov = new Matrix(numLatents, numLatents);
        for (int i = 0; i < numLatents; i++) {
            for (int j = 0; j < numLatents; j++) {
                latentscov.set(i, j, i == j ? 1.0 : 0.5);
            }
        }

        double[][] loadings = new double[indicators.length][];
        for (int i = 0; i < indicators.length; i++) {
            loadings[i] = new double[indicators[i].length];
            Arrays.fill(loadings[i], 0.5);
        }

        // Position of each indicator in the measured covariance matrix.
        int[][] indicatorIndices = new int[indicators.length][];
        List<Node> measures = _measurescov.getVariables();
        for (int i = 0; i < indicators.length; i++) {
            indicatorIndices[i] = new int[indicators[i].length];
            for (int j = 0; j < indicators[i].length; j++) {
                indicatorIndices[i][j] = measures.indexOf(indicators[i][j]);
            }
        }

        double[] delta = new double[measurescov.getNumRows()];
        Arrays.fill(delta, 1.0);

        int numParams = packParams(latentscov, loadings, delta).length;

        optimizeNonMeasureVariancesQuick(measurescov, latentscov, loadings, indicatorIndices);
        optimizeAllParamsSimultaneously(measurescov, latentscov, loadings, indicatorIndices, delta);

        double N = _measurescov.getSampleSize();
        int p = _measurescov.getDimension();
        int df = p * (p + 1) / 2 - numParams;
        double x = (N - 1.0) * this.minimum;
        // ChiSquaredDistribution rejects df <= 0; such a model has no fit test, so report NaN
        // rather than aborting the whole search.
        this.pValue = df > 0
                ? 1.0 - new ChiSquaredDistribution(df).cumulativeProbability(x)
                : Double.NaN;
        return latentscov;
    }

    /**
     * Fits latent covariances and loadings by minimizing {@link #sumOfDifferences}.
     *
     * <p>The fitted values are left in {@code latentscov} and {@code loadings}.
     */
    private void optimizeNonMeasureVariancesQuick(Matrix measurescov, Matrix latentscov, double[][] loadings,
                                                  int[][] indicatorIndices) {
        Function1 function = new Function1(indicatorIndices, measurescov, loadings, latentscov);
        PointValuePair pair = minimize(function, packParams(latentscov, loadings, NO_DELTA));
        // The objective writes each probed point into the shared arrays, and Powell's last probe
        // is not necessarily the optimum it reports, so write the reported optimum back.
        function.value(pair.getPoint());
        this.minimum = pair.getValue();
    }

    /**
     * Jointly fits latent covariances, loadings and error variances by minimizing
     * {@link Function4}.
     *
     * <p>The fitted values are left in {@code latentscov}, {@code loadings} and {@code delta}.
     */
    private void optimizeAllParamsSimultaneously(Matrix measurescov, Matrix latentscov, double[][] loadings,
                                                 int[][] indicatorIndices, double[] delta) {
        Function4 function = new Function4(indicatorIndices, measurescov, loadings, latentscov, delta);
        PointValuePair pair = minimize(function, packParams(latentscov, loadings, delta));
        // See optimizeNonMeasureVariancesQuick: restore the reported optimum into the shared arrays.
        function.value(pair.getPoint());
        this.minimum = pair.getValue();
    }

    /** Minimizes {@code function} from {@code start} with Powell's derivative-free method. */
    private static PointValuePair minimize(MultivariateFunction function, double[] start) {
        PowellOptimizer optimizer = new PowellOptimizer(1.0E-7, 1.0E-7);
        return optimizer.optimize(new InitialGuess(start), new ObjectiveFunction(function),
                GoalType.MINIMIZE, new MaxEval(100000));
    }

    /**
     * Flattens the parameters into one vector.
     *
     * <p>The order is: the upper triangle (including the diagonal) of {@code latentscov} row by
     * row, then the loadings cluster by cluster, then {@code delta}. This is the exact inverse of
     * {@link #unpackParams}.
     */
    private static double[] packParams(Matrix latentscov, double[][] loadings, double[] delta) {
        int numLatents = loadings.length;
        int size = numLatents * (numLatents + 1) / 2 + delta.length;
        for (double[] clusterLoadings : loadings) {
            size += clusterLoadings.length;
        }
        double[] values = new double[size];
        int k = 0;
        for (int i = 0; i < numLatents; i++) {
            for (int j = i; j < numLatents; j++) {
                values[k++] = latentscov.get(i, j);
            }
        }
        for (double[] clusterLoadings : loadings) {
            for (double loading : clusterLoadings) {
                values[k++] = loading;
            }
        }
        for (double d : delta) {
            values[k++] = d;
        }
        return values;
    }

    /**
     * Writes a vector produced in {@link #packParams} order back into the parameter holders.
     *
     * <p>The latent covariance matrix is kept symmetric.
     */
    private static void unpackParams(double[] values, Matrix latentscov, double[][] loadings, double[] delta) {
        int numLatents = loadings.length;
        int k = 0;
        for (int i = 0; i < numLatents; i++) {
            for (int j = i; j < numLatents; j++) {
                latentscov.set(i, j, values[k]);
                latentscov.set(j, i, values[k]);
                k++;
            }
        }
        for (double[] clusterLoadings : loadings) {
            for (int j = 0; j < clusterLoadings.length; j++) {
                clusterLoadings[j] = values[k++];
            }
        }
        for (int i = 0; i < delta.length; i++) {
            delta[i] = values[k++];
        }
    }

    /**
     * Returns the minimum cluster size.
     *
     * @return the minimum cluster size; smaller clusters are dropped before fitting
     */
    public int getMinClusterSize() {
        return this.minClusterSize;
    }

    /**
     * Sets the minimum number of indicators a cluster needs in order to be kept.
     *
     * @param minClusterSize the minimum cluster size, at least 3
     * @throws IllegalArgumentException if {@code minClusterSize < 3}
     */
    public void setMinClusterSize(int minClusterSize) {
        if (minClusterSize < 3) {
            throw new IllegalArgumentException("Minimum cluster size must be >= 3: " + minClusterSize);
        }
        this.minClusterSize = minClusterSize;
    }

    /**
     * Builds the model-implied covariance matrix over the measured variables.
     *
     * <p>Each entry for indicators {@code k} of latent {@code i} and {@code l} of latent {@code j}
     * is {@code loadings[i][k] * loadings[j][l] * loadingscov(i, j)}. {@code delta} is then added
     * to the diagonal. {@code cov} only supplies the dimensions.
     */
    private static Matrix impliedCovariance(int[][] indicatorIndices, double[][] loadings, Matrix cov,
                                            Matrix loadingscov, double[] delta) {
        Matrix implied = new Matrix(cov.getNumRows(), cov.getNumColumns());
        for (int i = 0; i < loadings.length; i++) {
            for (int j = 0; j < loadings.length; j++) {
                for (int k = 0; k < loadings[i].length; k++) {
                    for (int l = 0; l < loadings[j].length; l++) {
                        double prod = loadings[i][k] * loadings[j][l] * loadingscov.get(i, j);
                        implied.set(indicatorIndices[i][k], indicatorIndices[j][l], prod);
                    }
                }
            }
        }
        for (int i = 0; i < implied.getNumRows(); i++) {
            implied.set(i, i, implied.get(i, i) + delta[i]);
        }
        return implied;
    }

    /**
     * Computes the squared discrepancy between observed and implied off-diagonal covariances.
     *
     * <p>Within-cluster pairs are counted once (k &lt; l). Between-cluster pairs are weighted 2,
     * since each appears twice in the symmetric matrix. Diagonal entries are excluded, so error
     * variances play no role.
     */
    private static double sumOfDifferences(int[][] indicatorIndices, Matrix cov, double[][] loadings,
                                           Matrix loadingscov) {
        double sum = 0.0;
        for (int i = 0; i < loadings.length; i++) {
            for (int k = 0; k < loadings[i].length; k++) {
                for (int l = k + 1; l < loadings[i].length; l++) {
                    double observed = cov.get(indicatorIndices[i][k], indicatorIndices[i][l]);
                    double implied = loadings[i][k] * loadings[i][l] * loadingscov.get(i, i);
                    double diff = observed - implied;
                    sum += diff * diff;
                }
            }
        }
        for (int i = 0; i < loadings.length; i++) {
            for (int j = i + 1; j < loadings.length; j++) {
                for (int k = 0; k < loadings[i].length; k++) {
                    for (int l = 0; l < loadings[j].length; l++) {
                        double observed = cov.get(indicatorIndices[i][k], indicatorIndices[j][l]);
                        double implied = loadings[i][k] * loadings[j][l] * loadingscov.get(i, j);
                        double diff = observed - implied;
                        sum += 2.0 * diff * diff;
                    }
                }
            }
        }
        return sum;
    }

    /**
     * Objective over (latent covariances, loadings) returning {@link #sumOfDifferences}.
     *
     * <p>Each evaluation writes the candidate parameters into the shared {@code latentscov} and
     * {@code loadings}; this is how fitted values reach the caller.
     */
    private static class Function1 implements MultivariateFunction {
        private final int[][] indicatorIndices;
        private final Matrix measurescov;
        private final double[][] loadings;
        private final Matrix latentscov;

        Function1(int[][] indicatorIndices, Matrix measurescov, double[][] loadings, Matrix latentscov) {
            this.indicatorIndices = indicatorIndices;
            this.measurescov = measurescov;
            this.loadings = loadings;
            this.latentscov = latentscov;
        }

        @Override
        public double value(double[] values) {
            unpackParams(values, this.latentscov, this.loadings, NO_DELTA);
            return sumOfDifferences(this.indicatorIndices, this.measurescov, this.loadings, this.latentscov);
        }
    }

    /**
     * Objective over (latent covariances, loadings, error variances).
     *
     * <p>It returns {@code 0.5 * tr((I - implied * S^-1)^2)}, where {@code S} is the measured
     * covariance matrix; this is the generalized-least-squares discrepancy form. Each evaluation
     * writes the candidate parameters into the shared {@code latentscov}, {@code loadings} and
     * {@code delta}.
     */
    private static class Function4 implements MultivariateFunction {
        private final int[][] indicatorIndices;
        private final Matrix measurescov;
        private final Matrix measuresCovInverse;
        private final double[][] loadings;
        private final Matrix latentscov;
        private final double[] delta;

        Function4(int[][] indicatorIndices, Matrix measurescov, double[][] loadings, Matrix latentscov,
                  double[] delta) {
            this.indicatorIndices = indicatorIndices;
            this.measurescov = measurescov;
            this.loadings = loadings;
            this.latentscov = latentscov;
            this.delta = delta;
            // S^-1 is constant across evaluations, so invert it once.
            this.measuresCovInverse = measurescov.inverse();
        }

        @Override
        public double value(double[] values) {
            unpackParams(values, this.latentscov, this.loadings, this.delta);
            Matrix implied = impliedCovariance(this.indicatorIndices, this.loadings, this.measurescov,
                    this.latentscov, this.delta);
            Matrix I = Matrix.identity(implied.getNumRows());
            Matrix diff = I.minus(implied.times(this.measuresCovInverse));
            return 0.5 * diff.times(diff).trace();
        }
    }
}

