/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.AndersonDarlingTest;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataTransforms;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.FastIca;
import edu.cmu.tetrad.search.utils.HungarianAlgorithm;
import edu.cmu.tetrad.search.utils.NRooks;
import edu.cmu.tetrad.search.utils.PermutationMatrixPair;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.TetradLogger;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.util.FastMath;
import org.jetbrains.annotations.NotNull;

public class IcaLingD {
    private double spineThreshold = 0.5;
    private double bThreshold = 0.1;

    public static boolean isAcyclic(Matrix scaledBHat) {
        ArrayList<Node> dummyVars = new ArrayList<Node>();
        for (int i = 0; i < scaledBHat.getNumRows(); ++i) {
            dummyVars.add(new GraphNode("dummy" + i));
        }
        Graph g = IcaLingD.makeGraph(scaledBHat, dummyVars);
        return !g.paths().existsDirectedCycle();
    }

    public static Matrix estimateW(DataSet data, int fastIcaMaxIter, double fastIcaTolerance, double fastIcaA) {
        double[][] _data = data.getDoubleData().transpose().toArray();
        TetradLogger.getInstance().forceLogMessage("Anderson Darling P-values Per Variables (p < alpha means Non-Gaussian)");
        TetradLogger.getInstance().forceLogMessage("");
        for (int i = 0; i < _data.length; ++i) {
            Node node = data.getVariable(i);
            AndersonDarlingTest test = new AndersonDarlingTest(_data[i]);
            double p = test.getP();
            DecimalFormat nf = new DecimalFormat("0.000");
            TetradLogger.getInstance().forceLogMessage(node.getName() + ": p = " + nf.format(p));
        }
        TetradLogger.getInstance().forceLogMessage("");
        Matrix X = data.getDoubleData();
        X = DataTransforms.centerData(X).transpose();
        FastIca fastIca = new FastIca(X, X.getNumRows());
        fastIca.setVerbose(false);
        fastIca.setMaxIterations(fastIcaMaxIter);
        fastIca.setAlgorithmType(FastIca.PARALLEL);
        fastIca.setTolerance(fastIcaTolerance);
        fastIca.setFunction(FastIca.LOGCOSH);
        fastIca.setRowNorm(false);
        fastIca.setAlpha(fastIcaA);
        FastIca.IcaResult result11 = fastIca.findComponents();
        return result11.getW();
    }

    @NotNull
    public static Graph makeGraph(Matrix B, List<Node> variables) {
        EdgeListGraph g = new EdgeListGraph(variables);
        for (int j = 0; j < B.getNumColumns(); ++j) {
            for (int i = 0; i < B.getNumRows(); ++i) {
                if (B.get(i, j) == 0.0) continue;
                g.addDirectedEdge(variables.get(j), variables.get(i));
            }
        }
        return g;
    }

    public static PermutationMatrixPair hungarianDiagonal(Matrix W) {
        return IcaLingD.hungarian(W.transpose());
    }

    public static boolean isStable(Matrix bHat) {
        EigenDecomposition eigen = new EigenDecomposition(new BlockRealMatrix(bHat.toArray()));
        double[] realEigenvalues = eigen.getRealEigenvalues();
        double[] imagEigenvalues = eigen.getImagEigenvalues();
        for (int i = 0; i < realEigenvalues.length; ++i) {
            double realEigenvalue = realEigenvalues[i];
            double imagEigenvalue = imagEigenvalues[i];
            double modulus = FastMath.sqrt(FastMath.pow(realEigenvalue, 2) + FastMath.pow(imagEigenvalue, 2));
            System.out.println("Modulus for eigenvalue " + (i + 1) + " = " + modulus);
            if (!(modulus >= 1.0)) continue;
            return false;
        }
        return true;
    }

    public static Matrix scale(Matrix M) {
        Matrix _M = M.like();
        for (int i = 0; i < _M.getNumRows(); ++i) {
            for (int j = 0; j < _M.getNumColumns(); ++j) {
                _M.set(i, j, M.get(i, j) / M.get(j, j));
            }
        }
        return _M;
    }

    public static Matrix threshold(Matrix M, double threshold) {
        if (threshold < 0.0) {
            throw new IllegalArgumentException("Expecting a non-negative number: " + threshold);
        }
        Matrix _M = M.copy();
        for (int i = 0; i < M.getNumRows(); ++i) {
            for (int j = 0; j < M.getNumColumns(); ++j) {
                if (!(FastMath.abs(M.get(i, j)) < FastMath.abs(threshold))) continue;
                _M.set(i, j, 0.0);
            }
        }
        return _M;
    }

    public static Matrix getScaledBHat(PermutationMatrixPair pair, double bThreshold) {
        Matrix WTilde = pair.getPermutedMatrix().transpose();
        WTilde = IcaLingD.scale(WTilde);
        Matrix BHat = Matrix.identity(WTilde.getNumColumns()).minus(WTilde);
        BHat = IcaLingD.threshold(BHat, FastMath.abs(bThreshold));
        int[] perm = pair.getRowPerm();
        int[] inverse = IcaLingD.inversePermutation(perm);
        PermutationMatrixPair inversePair = new PermutationMatrixPair(BHat, inverse, inverse);
        return inversePair.getPermutedMatrix();
    }

    @NotNull
    private static PermutationMatrixPair hungarian(Matrix W) {
        double[][] costMatrix = new double[W.getNumRows()][W.getNumColumns()];
        for (int i = 0; i < W.getNumRows(); ++i) {
            for (int j = 0; j < W.getNumColumns(); ++j) {
                costMatrix[i][j] = W.get(i, j) != 0.0 ? 1.0 / FastMath.abs(W.get(i, j)) : 1000.0;
            }
        }
        HungarianAlgorithm alg = new HungarianAlgorithm(costMatrix);
        int[][] assignment = alg.findOptimalAssignment();
        int[] perm = new int[assignment.length];
        for (int i = 0; i < perm.length; ++i) {
            perm[i] = assignment[i][1];
        }
        return new PermutationMatrixPair(W, perm, null);
    }

    @NotNull
    private static List<PermutationMatrixPair> pairsNRook(Matrix W, double spineThreshold) {
        boolean[][] allowablePositions = new boolean[W.getNumRows()][W.getNumColumns()];
        for (int i = 0; i < W.getNumRows(); ++i) {
            for (int j = 0; j < W.getNumColumns(); ++j) {
                allowablePositions[i][j] = FastMath.abs(W.get(i, j)) > spineThreshold;
            }
        }
        ArrayList<PermutationMatrixPair> pairs = new ArrayList<PermutationMatrixPair>();
        ArrayList<int[]> colPermutations = NRooks.nRooks(allowablePositions);
        for (int[] colPermutation : colPermutations) {
            pairs.add(new PermutationMatrixPair(W, null, colPermutation));
        }
        return pairs;
    }

    static int[] inversePermutation(int[] perm) {
        int[] inverse = new int[perm.length];
        for (int i = 0; i < perm.length; ++i) {
            inverse[perm[i]] = i;
        }
        return inverse;
    }

    public List<Matrix> fit(DataSet D) {
        Matrix W = IcaLingD.estimateW(D, 5000, 1.0E-6, 1.2);
        return this.fitW(W);
    }

    public List<Matrix> fitW(Matrix W) {
        List<PermutationMatrixPair> pairs = IcaLingD.pairsNRook(W.transpose(), this.spineThreshold);
        if (pairs.isEmpty()) {
            throw new IllegalArgumentException("Could not find an N Rooks solution with that threshold.");
        }
        ArrayList<Matrix> results = new ArrayList<Matrix>();
        for (PermutationMatrixPair pair : pairs) {
            Matrix bHat = IcaLingD.getScaledBHat(pair, this.bThreshold);
            results.add(bHat);
        }
        return results;
    }

    public void setBThreshold(double bThreshold) {
        if (bThreshold < 0.0) {
            throw new IllegalArgumentException("Expecting a non-negative number: " + bThreshold);
        }
        this.bThreshold = bThreshold;
    }

    public void setSpineThreshold(double spineThreshold) {
        if (spineThreshold < 0.0) {
            throw new IllegalArgumentException("Expecting a non-negative number: " + spineThreshold);
        }
        this.spineThreshold = spineThreshold;
    }
}

