/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.EigenvalueDecomposition;
import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.CyclicDiscoveryMethod;
import edu.cmu.tetrad.search.GraphWithParameters;
import edu.cmu.tetrad.search.PermutationMatrixPair;
import edu.cmu.tetrad.search.fastica.FastICA;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Vector;

/**
 * Cyclic discovery search named after Lacerda, Spirtes and Ramsey (2007).
 *
 * <p>{@link #run(DataSet)} does the following:
 * <ol>
 *   <li>Runs FastICA on the data.</li>
 *   <li>Inverts the estimated mixing matrix A to obtain W.</li>
 *   <li>Zeroes entries of (a copy of) W whose magnitude is below 0.05.</li>
 *   <li>Enumerates every row permutation of W that leaves the diagonal free of zeros.</li>
 *   <li>For each candidate, prints B-hat = I - normalizeDiagonal(W), whether all of
 *       B-hat's eigenvalues lie strictly inside the unit circle, and the resulting graph.</li>
 * </ol>
 * No single best model is selected yet.
 */
public class LacerdaSpirtesRamsey2007Search implements CyclicDiscoveryMethod {
    /** NOTE(review): not read or written anywhere in this class. */
    HashMap<Integer, List<Double>> moments;

    @Override
    public String getName() {
        return "LacerdaSpirtesRamsey2007";
    }

    /**
     * Estimates the n-th raw moment of a sample as the mean of {@code x^n}.
     *
     * @param n   moment order
     * @param vec sample values
     * @return mean of {@code vec.get(i)^n}; NaN when {@code vec} is empty
     */
    public double estimateNthMoment(int n, DoubleMatrix1D vec) {
        return estimateNthMoment(vec, n);
    }

    /**
     * Returns {@code sum_j mixingCoeffs[j]^n * errorTermsNthMoments[j]}.
     *
     * @param mixingCoeffs         coefficients c_j
     * @param errorTermsNthMoments per-term n-th moments m_j, indexed like {@code mixingCoeffs}
     * @param n                    power applied to each coefficient
     * @return the weighted sum
     */
    public static double impliedNthMoment(DoubleMatrix1D mixingCoeffs, DoubleMatrix1D errorTermsNthMoments, int n) {
        double sum = 0.0;
        for (int j = 0; j < mixingCoeffs.size(); ++j) {
            sum += Math.pow(mixingCoeffs.get(j), n) * errorTermsNthMoments.get(j);
        }
        return sum;
    }

    /** {@link #impliedNthMoment(DoubleMatrix1D, DoubleMatrix1D, int)} with {@code n = 4}. */
    public static double impliedFourthMoment(DoubleMatrix1D mixingCoeffs, DoubleMatrix1D errorTermsNthMoments) {
        return impliedNthMoment(mixingCoeffs, errorTermsNthMoments, 4);
    }

    /**
     * Runs FastICA on {@code dataSet}, derives W = A^-1 and prints every candidate
     * model (see {@link #findCandidateModels}).
     *
     * <p>Any exception thrown after the data has been extracted is printed and
     * swallowed. This includes a W that is not square or whose size does not
     * match the number of variables.
     *
     * @param dataSet continuous data, one column per variable
     * @return always {@code null}: candidates are printed, but no best model is chosen
     */
    @Override
    public GraphWithParameters run(DataSet dataSet) {
        // Transpose so that rows are variables and columns are samples.
        DoubleMatrix2D data = dataSet.getDoubleData().viewDice();
        try {
            double[][] inV = MatrixUtils.convert(data);
            FastICA fica = new FastICA(inV, data.rows());
            DoubleMatrix2D ica_A = MatrixUtils.convertToColt(fica.getMixingMatrix());
            DoubleMatrix2D ica_W = MatrixUtils.inverse(ica_A);
            int n = ica_W.rows();
            System.out.println("W = " + ica_W);
            // Later steps (permutation search, I - W) require a square W sized to the variable list.
            if (ica_W.rows() != ica_W.columns()) {
                throw new IllegalStateException("W is not square!");
            }
            if (ica_W.rows() != dataSet.getNumColumns()) {
                throw new IllegalStateException("W does not have the right number of dimensions!");
            }
            findCandidateModels(dataSet.getVariables(), ica_W, data, n, true);
        }
        catch (Exception e) {
            // Boundary of the search: report and fall through to the null result.
            e.printStackTrace();
        }
        // NOTE(review): no model selection is implemented, so there is no best graph to return.
        return null;
    }

    /**
     * Prints each eigenvalue of {@code mat} and returns whether every one has
     * modulus strictly below 1.
     *
     * <p>All eigenvalues are printed, even after one with modulus at least 1 is found.
     * An eigenvalue whose modulus is NaN does not make the result false.
     *
     * @param mat square matrix
     * @return {@code true} iff no eigenvalue has modulus {@code >= 1}
     */
    public static boolean allEigenvaluesAreSmallerThanOneInModulus(DoubleMatrix2D mat) {
        EigenvalueDecomposition dec = new EigenvalueDecomposition(mat);
        DoubleMatrix1D realEigenvalues = dec.getRealEigenvalues();
        DoubleMatrix1D imagEigenvalues = dec.getImagEigenvalues();
        boolean allSmaller = true;
        for (int i = 0; i < realEigenvalues.size(); ++i) {
            double re = realEigenvalues.get(i);
            double im = imagEigenvalues.get(i);
            System.out.println("eigenvalue #" + i + " = " + re + "+" + im + "i");
            System.out.println("eigenvalue #" + i + " has argument = " + arg(re, im));
            // hypot avoids the intermediate overflow/underflow of sqrt(re^2 + im^2).
            double modulus = Math.hypot(re, im);
            System.out.println("eigenvalue #" + i + " has modulus = " + modulus);
            System.out.println("eigenvalue #" + i + " has modulus^3 = " + Math.pow(modulus, 3.0));
            if (modulus >= 1.0) {
                allSmaller = false;
            }
        }
        return allSmaller;
    }

    /**
     * Returns the argument (phase angle) of {@code realPart + i*imagPart}, in radians, in (-pi, pi].
     * Uses atan2 so that the quadrant is correct and a zero real part is handled.
     */
    private static double arg(double realPart, double imagPart) {
        return Math.atan2(imagPart, realPart);
    }

    /**
     * Enumerates and prints every candidate model derived from {@code matrixW}.
     *
     * <p>For each row permutation of W whose diagonal contains no zero, this
     * method does the following:
     * <ol>
     *   <li>Normalizes the diagonal via {@code MatrixUtils.normalizeDiagonal}.</li>
     *   <li>Stores B-hat = I - normalizedW on the candidate.</li>
     *   <li>Prints B-hat, whether all of its eigenvalues have modulus below 1, and the
     *       {@link GraphWithParameters} built from it.</li>
     * </ol>
     * {@code matrixW} itself is not modified.
     *
     * @param variables        node list used to label B-hat
     * @param matrixW          square unmixing matrix W
     * @param data             currently unused
     * @param n                dimension of the identity used for B-hat (the size of W)
     * @param approximateZeros if true, entries of W with magnitude below 0.05 are treated as zero
     */
    public static void findCandidateModels(List<Node> variables, DoubleMatrix2D matrixW, DoubleMatrix2D data, int n, boolean approximateZeros) {
        System.out.println("variables = " + variables);
        List<PermutationMatrixPair> zldPerms = zerolessDiagonalPermutations(matrixW, approximateZeros);
        int candidateNumber = 0;
        for (PermutationMatrixPair zldPerm : zldPerms) {
            ++candidateNumber;
            System.out.println();
            System.out.println("------------------------");
            System.out.println("---- candidate #" + candidateNumber + " ------");
            System.out.println();
            DoubleMatrix2D normalizedZldW = MatrixUtils.normalizeDiagonal(zldPerm.getMatrixW());
            zldPerm.setMatrixBhat(computeBhatMatrix(normalizedZldW, n, variables));
            System.out.println("matrixBhat = " + zldPerm.getMatrixBhat());
            boolean isShrinkingMatrix = allEigenvaluesAreSmallerThanOneInModulus(zldPerm.getMatrixBhat().getDoubleData());
            System.out.println("isShrinkingMatrix = " + isShrinkingMatrix);
            GraphWithParameters graph = new GraphWithParameters(zldPerm.getMatrixBhat());
            System.out.println("graph:\n" + graph);
        }
        System.out.println("------------------------");
        System.out.println("There are " + zldPerms.size() + " candidates.");
    }

    /** Sum of squared element-wise differences; assumes {@code data2} is at least as long as {@code data1}. */
    private static double difference(DoubleMatrix1D data1, DoubleMatrix1D data2) {
        double sum = 0.0;
        for (int i = 0; i < data1.size(); ++i) {
            double eltDiff = data1.get(i) - data2.get(i);
            sum += eltDiff * eltDiff;
        }
        return sum;
    }

    /** Sums the n-th moment estimates of every row of {@code mat}. */
    private double estimateNthMoment(DoubleMatrix2D mat, int n) {
        double sum = 0.0;
        for (int i = 0; i < mat.rows(); ++i) {
            sum += estimateNthMoment(mat.viewRow(i), n);
        }
        return sum;
    }

    /**
     * Mean of {@code vec.get(j)^n}, consistent with the public
     * {@link #estimateNthMoment(int, DoubleMatrix1D)}; NaN for an empty vector.
     */
    private static double estimateNthMoment(DoubleMatrix1D vec, int n) {
        double sum = 0.0;
        for (int j = 0; j < vec.size(); ++j) {
            sum += Math.pow(vec.get(j), n);
        }
        return sum / (double) vec.size();
    }

    /**
     * Builds B-hat = I - normalizedZldW, using {@code MatrixUtils.linearCombination}
     * with coefficients 1 and -1, wrapped as a data set labelled by {@code nodes}.
     */
    private static DataSet computeBhatMatrix(DoubleMatrix2D normalizedZldW, int n, List<Node> nodes) {
        DoubleMatrix2D mat = MatrixUtils.linearCombination(MatrixUtils.identityMatrix(n), 1.0, normalizedZldW, -1.0);
        return ColtDataSet.makeContinuousData(nodes, mat);
    }

    /**
     * Array form of {@link #permuteRows(DoubleMatrix2D, List)}: row {@code i} of
     * {@code ica_A} is moved to row {@code zldWpermutation[i]}.
     */
    private DoubleMatrix2D permuteRows(int[] zldWpermutation, DoubleMatrix2D ica_A) {
        return permuteRows(ica_A, makeVector(zldWpermutation));
    }

    /**
     * Index of the smallest score. The first index wins ties.
     *
     * @throws IllegalArgumentException if {@code scores} is empty
     */
    private int argmin(List<Double> scores) {
        if (scores.isEmpty()) {
            throw new IllegalArgumentException("scores must not be empty");
        }
        int minIndex = 0;
        double min = scores.get(0);
        for (int i = 1; i < scores.size(); ++i) {
            double value = scores.get(i);
            if (value < min) {
                minIndex = i;
                min = value;
            }
        }
        return minIndex;
    }

    /** Sum of squared element-wise differences over the shape of {@code m1}. */
    private double sumSquaredDifferences(DoubleMatrix2D m1, DoubleMatrix2D m2) {
        double sum = 0.0;
        for (int i = 0; i < m1.rows(); ++i) {
            for (int j = 0; j < m1.columns(); ++j) {
                double diff = m1.get(i, j) - m2.get(i, j);
                sum += diff * diff;
            }
        }
        return sum;
    }

    /**
     * Returns every row permutation of {@code ica_W} whose diagonal has no zero entry.
     * Each permutation is returned paired with the permuted matrix.
     *
     * <p>When {@code approximateZeros} is set, entries with magnitude below 0.05 are
     * zeroed first. This is done on a copy, so the caller's matrix is left untouched.
     */
    private static List<PermutationMatrixPair> zerolessDiagonalPermutations(DoubleMatrix2D ica_W, boolean approximateZeros) {
        DoubleMatrix2D w = ica_W;
        if (approximateZeros) {
            w = ica_W.copy();
            System.out.println("before pruning, ica_W = " + w);
            setInsignificantEntriesToZero(w, 0.05);
            System.out.println("after pruning, ica_W = " + w);
        }
        List<List<Integer>> nRookAssignments = nRookRowAssignments(w);
        List<PermutationMatrixPair> permutations = new ArrayList<PermutationMatrixPair>(nRookAssignments.size());
        for (List<Integer> permutation : nRookAssignments) {
            DoubleMatrix2D matrixW = permuteRows(w, permutation);
            PermutationMatrixPair permMatrixPair = new PermutationMatrixPair(permutation, matrixW);
            TetradLogger.getInstance().log("lingDetails", "adding to permutations: " + permMatrixPair);
            permutations.add(permMatrixPair);
        }
        return permutations;
    }

    /** Inverse permutation: {@code result[permutation[i]] = i}. */
    private List<Integer> inverse(List<Integer> permutation) {
        System.out.println("inverting a permutation: " + permutation);
        int[] inv = new int[permutation.size()];
        for (int i = 0; i < permutation.size(); ++i) {
            inv[permutation.get(i)] = i;
        }
        return makeVector(inv);
    }

    /** Boxes an int array into a new mutable list. */
    private List<Integer> makeVector(int[] array) {
        List<Integer> v = new ArrayList<Integer>(array.length);
        for (int value : array) {
            v.add(value);
        }
        return v;
    }

    /** Sets, in place, every entry of {@code mat} with {@code |x| < threshold} to 0. */
    private static void setInsignificantEntriesToZero(DoubleMatrix2D mat, double threshold) {
        for (int i = 0; i < mat.rows(); ++i) {
            for (int j = 0; j < mat.columns(); ++j) {
                if (Math.abs(mat.get(i, j)) < threshold) {
                    mat.set(i, j, 0.0);
                }
            }
        }
    }

    /** Applies {@code permutation} to both the rows and the columns of {@code mat}. */
    private DoubleMatrix2D permuteRowsAndColumns(DoubleMatrix2D mat, List<Integer> permutation) {
        return permuteColumns(permuteRows(mat, permutation), permutation);
    }

    /** Returns a new matrix in which column {@code j} of {@code mat} is at column {@code permutation[j]}. */
    private DoubleMatrix2D permuteColumns(DoubleMatrix2D mat, List<Integer> permutation) {
        DenseDoubleMatrix2D permutedMat = new DenseDoubleMatrix2D(mat.rows(), mat.columns());
        for (int j = 0; j < mat.columns(); ++j) {
            permutedMat.viewColumn(permutation.get(j)).assign(mat.viewColumn(j));
        }
        return permutedMat;
    }

    /** Returns a new matrix in which row {@code i} of {@code mat} is at row {@code permutation[i]}. */
    private static DoubleMatrix2D permuteRows(DoubleMatrix2D mat, List<Integer> permutation) {
        DenseDoubleMatrix2D permutedMat = new DenseDoubleMatrix2D(mat.rows(), mat.columns());
        for (int i = 0; i < mat.rows(); ++i) {
            permutedMat.viewRow(permutation.get(i)).assign(mat.viewRow(i));
        }
        return permutedMat;
    }

    /**
     * All ways to place one rook per column on a nonzero entry, with no two rooks in
     * the same row. Element {@code j} of each result is the row used for column {@code j}.
     */
    public static List<List<Integer>> nRookColumnAssignments(DoubleMatrix2D mat) {
        return nRookColumnAssignments(mat, makeAllRows(mat.rows()));
    }

    /**
     * Row-wise counterpart of {@link #nRookColumnAssignments(DoubleMatrix2D)}.
     * Element {@code i} of each result is the column used for row {@code i}.
     */
    public static List<List<Integer>> nRookRowAssignments(DoubleMatrix2D mat) {
        return nRookColumnAssignments(mat.viewDice());
    }

    /** Returns the list {@code [0, 1, ..., n-1]}. */
    private static List<Integer> makeAllRows(int n) {
        List<Integer> l = new ArrayList<Integer>(n);
        for (int i = 0; i < n; ++i) {
            l.add(i);
        }
        return l;
    }

    /** Builds the 0/1 matrix with a 1 at {@code (perm[j], j)} for each column {@code j}. */
    public static DoubleMatrix2D displayNRookAssignment(List<Integer> perm) {
        int n = perm.size();
        DenseDoubleMatrix2D mat = new DenseDoubleMatrix2D(n, n);
        for (int j = 0; j < n; ++j) {
            mat.set(perm.get(j), j, 1.0);
        }
        return mat;
    }

    /**
     * Recursively enumerates the rook placements for the columns of {@code mat}.
     * Each column's rook goes on a nonzero entry (exact {@code == 0.0} test), and
     * only rows in {@code availableRows} may be used.
     *
     * @param mat           matrix whose columns are assigned left to right
     * @param availableRows rows not yet occupied; not modified
     * @return one mutable list per complete assignment (possibly none)
     */
    public static List<List<Integer>> nRookColumnAssignments(DoubleMatrix2D mat, List<Integer> availableRows) {
        List<List<Integer>> concats = new ArrayList<List<Integer>>();
        if (mat.columns() > 1) {
            // The remaining columns are the same whichever row column 0 uses.
            DoubleMatrix2D subMat = mat.viewPart(0, 1, mat.rows(), mat.columns() - 1);
            for (int currentRowIndex : availableRows) {
                if (mat.get(currentRowIndex, 0) == 0.0) {
                    continue;
                }
                List<Integer> newAvailableRows = new ArrayList<Integer>(availableRows);
                // Remove by value, not by index.
                newAvailableRows.remove(Integer.valueOf(currentRowIndex));
                for (List<Integer> laterPerm : nRookColumnAssignments(subMat, newAvailableRows)) {
                    laterPerm.add(0, currentRowIndex);
                    concats.add(laterPerm);
                }
            }
        } else {
            for (int currentRowIndex : availableRows) {
                if (mat.get(currentRowIndex, 0) == 0.0) {
                    continue;
                }
                List<Integer> single = new ArrayList<Integer>();
                single.add(currentRowIndex);
                concats.add(single);
            }
        }
        return concats;
    }
}

