/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.EigenvalueDecomposition;
import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphGroup;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.FastIca;
import edu.cmu.tetrad.search.GraphGroupSearch;
import edu.cmu.tetrad.search.GraphWithParameters;
import edu.cmu.tetrad.search.PermutationMatrixPair;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.TetradLogger;
import edu.cmu.tetrad.util.dist.Distribution;
import edu.cmu.tetrad.util.dist.GaussianPower;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Vector;

/**
 * Searches for candidate linear (possibly cyclic) models using independent component analysis.
 *
 * <p>{@link #search()} runs FastICA on the data and inverts the resulting mixing matrix to get an
 * unmixing matrix W. It then zeroes entries of W whose magnitude is below the threshold (see
 * {@link #setThreshold(double)}), and enumerates every row permutation of W whose diagonal has no
 * zero entry. Each such permutation is turned into a coefficient matrix {@code B-hat = I - W'}.
 * Here {@code W'} is the permuted (and transposed) W after {@code MatrixUtils.normalizeDiagonal}.
 * Every candidate is stored together with a flag saying whether all eigenvalues of B-hat have
 * modulus strictly below one.
 */
public class Ling implements GraphGroupSearch {
    /** Number of samples to simulate when this search is built from a graph. */
    private int numSamples;
    /** Entries of W with absolute value below this are set to zero before permuting; default 0.05. */
    private double threshold = 0.05;
    /** Duration of the last {@link #search()} that completed without an exception, in milliseconds. */
    private long elapsedTime = 0L;
    /** The data analyzed by {@link #search()}, either supplied or simulated. */
    private DataSet dataSet;

    /**
     * Creates a search over an existing data set.
     *
     * @param d the data to analyze
     */
    public Ling(DataSet d) {
        this.dataSet = d;
    }

    /**
     * Creates a search over data simulated from a parameterized graph.
     *
     * @param graphWP the graph whose graph matrix supplies the linear coefficients for simulation
     * @param samples the number of samples to simulate
     */
    public Ling(GraphWithParameters graphWP, int samples) {
        this.numSamples = samples;
        this.makeDataSet(graphWP);
    }

    /**
     * Creates a search over data simulated from {@code g}, wrapped in a new {@link GraphWithParameters}.
     *
     * @param g       the graph to simulate from
     * @param samples the number of samples to simulate
     */
    public Ling(Graph g, int samples) {
        this.numSamples = samples;
        GraphWithParameters graphWP = new GraphWithParameters(g);
        this.makeDataSet(graphWP);
    }

    /**
     * Returns the data analyzed by this search.
     *
     * @return the supplied or simulated data set
     */
    public DataSet getData() {
        return this.dataSet;
    }

    /**
     * Runs FastICA on the data and enumerates the candidate models derived from the unmixing matrix.
     *
     * <p>Any exception is printed and an empty (or partially built) group is returned instead of
     * being propagated; {@link #getElapsedTime()} is only updated when the search completes.
     *
     * @return the candidate graphs with their B-hat matrices and stability flags
     */
    @Override
    public StoredGraphs search() {
        StoredGraphs graphs = new StoredGraphs();
        try {
            // nanoTime is monotonic, so the measurement is immune to wall-clock adjustments.
            long startNanos = System.nanoTime();
            DoubleMatrix2D data = this.dataSet.getDoubleData();
            FastIca fastIca = new FastIca(data.copy(), data.columns());
            fastIca.setVerbose(true);
            FastIca.IcaResult result = fastIca.findComponents();
            DoubleMatrix2D icaA = result.getA().viewDice();
            DoubleMatrix2D icaW = MatrixUtils.inverse(icaA);
            int n = icaW.rows();
            graphs = this.findCandidateModels(this.dataSet.getVariables(), icaW, n, true);
            this.elapsedTime = (System.nanoTime() - startNanos) / 1000000L;
        } catch (Exception e) {
            // Best-effort contract kept from the original: report and fall back to the empty result.
            e.printStackTrace();
        }
        return graphs;
    }

    /**
     * Returns the duration of the last completed {@link #search()}.
     *
     * @return elapsed time in milliseconds, or 0 if no search has completed
     */
    @Override
    public long getElapsedTime() {
        return this.elapsedTime;
    }

    /**
     * Sets the magnitude below which entries of the unmixing matrix are treated as zero.
     *
     * @param t the new threshold
     */
    public void setThreshold(double t) {
        this.threshold = t;
    }

    /**
     * Simulates {@link #numSamples} samples from {@code graphWP} and stores them as {@link #dataSet}.
     * Errors are drawn from {@code GaussianPower(2.0)} and every error coefficient is 1.
     */
    private void makeDataSet(GraphWithParameters graphWP) {
        GaussianPower errorDistribution = new GaussianPower(2.0);
        int numNodes = graphWP.getGraph().getNumNodes();
        DoubleMatrix1D errorCoefficients = Ling.getErrorCoeffsIdentity(numNodes);
        DoubleMatrix2D inVectors = Ling.simulateCyclic(graphWP, errorCoefficients, this.numSamples, errorDistribution);
        // simulateCyclic stores one sample per column; the data set expects one sample per row.
        this.dataSet = ColtDataSet.makeContinuousData(graphWP.getGraph().getNodes(), inVectors.viewDice());
    }

    /** Returns a length-{@code n} vector with every entry equal to 1.0. */
    private static DoubleMatrix1D getErrorCoeffsIdentity(int n) {
        DenseDoubleMatrix1D errorCoefficients = new DenseDoubleMatrix1D(n);
        for (int i = 0; i < n; ++i) {
            errorCoefficients.set(i, 1.0);
        }
        return errorCoefficients;
    }

    /**
     * Draws {@code n} samples {@code x = (I - B)^-1 e}, one per column of the returned
     * (numNodes x n) matrix.
     *
     * @param dwp               the graph whose graph matrix is B
     * @param errorCoefficients per-variable scale applied to each error draw
     * @param n                 number of samples (columns)
     * @param distribution      distribution each raw error is drawn from
     * @return the simulated samples, one per column
     */
    private static DoubleMatrix2D simulateCyclic(GraphWithParameters dwp, DoubleMatrix1D errorCoefficients, int n, Distribution distribution) {
        DoubleMatrix2D reducedForm = Ling.reducedForm(dwp);
        DenseDoubleMatrix2D vectors = new DenseDoubleMatrix2D(dwp.getGraph().getNumNodes(), n);
        for (int j = 0; j < n; ++j) {
            DoubleMatrix1D vector = Ling.simulateReducedForm(reducedForm, errorCoefficients, distribution);
            vectors.viewColumn(j).assign(vector);
        }
        return vectors;
    }

    /** Returns {@code (I - B)^-1}, where B is the graph matrix of {@code graph}. */
    private static DoubleMatrix2D reducedForm(GraphWithParameters graph) {
        int n = graph.getGraph().getNumNodes();
        DoubleMatrix2D graphMatrix = graph.getGraphMatrix().getDoubleData();
        DoubleMatrix2D identityMinusGraphMatrix = MatrixUtils.linearCombination(MatrixUtils.identityMatrix(n), 1.0, graphMatrix, -1.0);
        return MatrixUtils.inverse(identityMinusGraphMatrix);
    }

    /**
     * Draws one sample: {@code reducedForm * e}, where {@code e_j = errorCoefficients_j * distr.nextRandom()}.
     * All errors are drawn first, in index order, and then the matrix-vector product is formed.
     */
    private static DoubleMatrix1D simulateReducedForm(DoubleMatrix2D reducedForm, DoubleMatrix1D errorCoefficients, Distribution distr) {
        int n = reducedForm.rows();
        DenseDoubleMatrix1D vector = new DenseDoubleMatrix1D(n);
        DenseDoubleMatrix1D samples = new DenseDoubleMatrix1D(n);
        for (int j = 0; j < n; ++j) {
            samples.set(j, distr.nextRandom() * errorCoefficients.get(j));
        }
        for (int i = 0; i < n; ++i) {
            double sum = 0.0;
            for (int j = 0; j < n; ++j) {
                sum += reducedForm.get(i, j) * samples.get(j);
            }
            vector.set(i, sum);
        }
        return vector;
    }

    /**
     * Builds one candidate model per zero-free-diagonal permutation of {@code matrixW}.
     *
     * @param variables        nodes labelling the B-hat data sets
     * @param matrixW          the unmixing matrix (modified in place when {@code approximateZeros})
     * @param n                dimension of the identity used to form B-hat
     * @param approximateZeros whether to zero entries below {@link #threshold} first
     * @return the candidates, their B-hat matrices and stability flags, in enumeration order
     */
    private StoredGraphs findCandidateModels(List<Node> variables, DoubleMatrix2D matrixW, int n, boolean approximateZeros) {
        StoredGraphs gs = new StoredGraphs();
        TetradLogger.getInstance().log("lingDetails", "Calculating zeroless diagonal permutations.");
        List<PermutationMatrixPair> zldPerms = this.zerolessDiagonalPermutations(matrixW, approximateZeros);
        for (PermutationMatrixPair zldPerm : zldPerms) {
            TetradLogger.getInstance().log("lingDetails", "" + zldPerm);
            DoubleMatrix2D normalizedZldW = MatrixUtils.normalizeDiagonal(zldPerm.getMatrixW());
            zldPerm.setMatrixBhat(Ling.computeBhatMatrix(normalizedZldW, n, variables));
            DataSet bhat = zldPerm.getMatrixBhat();
            boolean isStableMatrix = Ling.allEigenvaluesAreSmallerThanOneInModulus(bhat.getDoubleData());
            GraphWithParameters graph = new GraphWithParameters(bhat);
            gs.addGraph(graph.getGraph());
            gs.addStable(isStableMatrix);
            gs.addData(bhat);
        }
        Ling.logGraphsByStability(gs, true, "stableGraphs", "Stable Graphs:");
        Ling.logGraphsByStability(gs, false, "unstableGraphs", "Unstable Graphs:");
        return gs;
    }

    /**
     * Logs {@code header} under {@code event}, then every graph whose stability flag equals
     * {@code wantStable}. When {@code event} is active, each graph's B-hat is also logged under
     * "wMatrices".
     */
    private static void logGraphsByStability(StoredGraphs gs, boolean wantStable, String event, String header) {
        TetradLogger.getInstance().log(event, header);
        for (int d = 0; d < gs.getNumGraphs(); ++d) {
            if (gs.isStable(d) != wantStable) {
                continue;
            }
            TetradLogger.getInstance().log(event, "" + gs.getGraph(d));
            if (TetradLogger.getInstance().getLoggerConfig().isEventActive(event)) {
                TetradLogger.getInstance().log("wMatrices", "" + gs.getData(d));
            }
        }
    }

    /**
     * Enumerates the row permutations of {@code ica_W} that leave no zero on the diagonal.
     *
     * @param ica_W            the unmixing matrix; entries below the threshold are zeroed in place
     *                         when {@code approximateZeros} is true
     * @param approximateZeros whether to apply the threshold first
     * @return one pair per valid permutation, holding the permutation and the transposed permuted matrix
     */
    private List<PermutationMatrixPair> zerolessDiagonalPermutations(DoubleMatrix2D ica_W, boolean approximateZeros) {
        List<PermutationMatrixPair> permutations = new ArrayList<PermutationMatrixPair>();
        if (approximateZeros) {
            this.setInsignificantEntriesToZero(ica_W);
        }
        DoubleMatrix2D mat = ica_W.viewDice();
        List<List<Integer>> nRookAssignments = Ling.nRookColumnAssignments(mat, Ling.makeAllRows(mat.rows()));
        for (List<Integer> permutation : nRookAssignments) {
            DoubleMatrix2D matrixW = Ling.permuteRows(ica_W, permutation).viewDice();
            permutations.add(new PermutationMatrixPair(permutation, matrixW));
        }
        return permutations;
    }

    /**
     * Sets to 0.0 every entry of {@code mat} whose absolute value is below {@link #threshold}.
     * Iterates over the real column count, so non-square matrices are handled too.
     * NaN entries are left unchanged.
     */
    private void setInsignificantEntriesToZero(DoubleMatrix2D mat) {
        int rows = mat.rows();
        int columns = mat.columns();
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < columns; ++j) {
                if (Math.abs(mat.get(i, j)) < this.threshold) {
                    mat.set(i, j, 0.0);
                }
            }
        }
    }

    /** Returns the list {@code [0, 1, ..., n-1]}. */
    private static List<Integer> makeAllRows(int n) {
        List<Integer> l = new ArrayList<Integer>(Math.max(n, 0));
        for (int i = 0; i < n; ++i) {
            l.add(i);
        }
        return l;
    }

    /**
     * Enumerates every way to pick, for each column of {@code mat} in order, a distinct row from
     * {@code availableRows} whose entry in that column is non-zero (n non-attacking rooks on the
     * non-zero cells). Element k of each returned list is the row chosen for column k.
     *
     * @param mat           the matrix; only its first column is inspected at each recursion level
     * @param availableRows row indices not yet used by earlier columns
     * @return all such assignments, as mutable lists
     */
    private static List<List<Integer>> nRookColumnAssignments(DoubleMatrix2D mat, List<Integer> availableRows) {
        List<List<Integer>> concats = new ArrayList<List<Integer>>();
        boolean lastColumn = mat.columns() <= 1;
        // The remaining-columns view does not depend on the row chosen, so build it once.
        DoubleMatrix2D subMat = lastColumn ? null : mat.viewPart(0, 1, mat.rows(), mat.columns() - 1);
        for (int currentRowIndex : availableRows) {
            if (mat.get(currentRowIndex, 0) == 0.0) {
                continue;
            }
            if (lastColumn) {
                List<Integer> l = new ArrayList<Integer>();
                l.add(currentRowIndex);
                concats.add(l);
                continue;
            }
            List<Integer> remainingRows = new ArrayList<Integer>(availableRows);
            // Box explicitly so List.remove(Object) (by value) is chosen, not remove(int index).
            remainingRows.remove(Integer.valueOf(currentRowIndex));
            for (List<Integer> laterPerm : Ling.nRookColumnAssignments(subMat, remainingRows)) {
                laterPerm.add(0, currentRowIndex);
                concats.add(laterPerm);
            }
        }
        return concats;
    }

    /**
     * Returns a new matrix whose row {@code permutation.get(j)} is a copy of row j of {@code mat}.
     * The result is allocated as {@code columns() x columns()}, so {@code mat} is assumed to be square.
     */
    private static DoubleMatrix2D permuteRows(DoubleMatrix2D mat, List<Integer> permutation) {
        int n = mat.columns();
        DenseDoubleMatrix2D permutedMat = new DenseDoubleMatrix2D(n, n);
        for (int j = 0; j < n; ++j) {
            permutedMat.viewRow(permutation.get(j)).assign(mat.viewRow(j));
        }
        return permutedMat;
    }

    /** Returns {@code I - normalizedZldW} wrapped as a continuous data set over {@code nodes}. */
    private static DataSet computeBhatMatrix(DoubleMatrix2D normalizedZldW, int n, List<Node> nodes) {
        DoubleMatrix2D mat = MatrixUtils.linearCombination(MatrixUtils.identityMatrix(n), 1.0, normalizedZldW, -1.0);
        return ColtDataSet.makeContinuousData(nodes, mat);
    }

    /**
     * Returns whether every eigenvalue of {@code mat} has modulus strictly below one.
     * Returns as soon as one eigenvalue with modulus of at least one is found.
     */
    private static boolean allEigenvaluesAreSmallerThanOneInModulus(DoubleMatrix2D mat) {
        EigenvalueDecomposition dec = new EigenvalueDecomposition(mat);
        DoubleMatrix1D realEigenvalues = dec.getRealEigenvalues();
        DoubleMatrix1D imagEigenvalues = dec.getImagEigenvalues();
        for (int i = 0; i < realEigenvalues.size(); ++i) {
            // hypot computes sqrt(re^2 + im^2) without intermediate overflow or underflow.
            double modulus = Math.hypot(realEigenvalues.get(i), imagEigenvalues.get(i));
            if (modulus >= 1.0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Deserialization hook that restores only the default serialized fields.
     * NOTE(review): Ling does not declare Serializable here; this matters only if a supertype does.
     */
    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
    }

    /**
     * Candidate models produced by {@link Ling#search()}.
     *
     * <p>Three lists are kept in parallel, one entry per candidate: the graph, its B-hat data set
     * and its stability flag. Entries with the same index describe the same candidate.
     */
    public static class StoredGraphs implements GraphGroup {
        private final List<Graph> graphs = new ArrayList<Graph>();
        private final List<DataSet> dataSet = new ArrayList<DataSet>();
        private final List<Boolean> stable = new ArrayList<Boolean>();

        /** Returns the number of stored graphs. */
        @Override
        public int getNumGraphs() {
            return this.graphs.size();
        }

        /** Returns the graph at index {@code g}. */
        @Override
        public Graph getGraph(int g) {
            return this.graphs.get(g);
        }

        /** Returns the B-hat data set at index {@code d}. */
        public DataSet getData(int d) {
            return this.dataSet.get(d);
        }

        /** Returns the stability flag at index {@code s}; a null flag throws NullPointerException on unboxing. */
        public boolean isStable(int s) {
            return this.stable.get(s);
        }

        /** Appends a graph. */
        @Override
        public void addGraph(Graph g) {
            this.graphs.add(g);
        }

        /** Appends a B-hat data set. */
        public void addData(DataSet d) {
            this.dataSet.add(d);
        }

        /** Appends a stability flag. */
        public void addStable(Boolean s) {
            this.stable.add(s);
        }
    }
}

