/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.math.Functions;
import cern.jet.math.PlusMult;
import edu.cmu.tetrad.cluster.FastIca;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Hungarian;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.EVD;
import no.uib.cipr.matrix.LowerTriangDenseMatrix;
import no.uib.cipr.matrix.Matrices;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.MatrixSingularException;
import no.uib.cipr.matrix.NotConvergedException;
import no.uib.cipr.matrix.QL;
import no.uib.cipr.matrix.Vector;

/**
 * Estimates a linear, non-Gaussian, acyclic causal model (LiNGAM) from a data set.
 * <p>
 * {@link #estimate} runs FastICA, reorders the rows of the unmixing matrix with the Hungarian
 * algorithm, and prunes the resulting connection matrix until it can be permuted to strictly
 * lower-triangular form. {@link #pruneEdgesByResampling} re-estimates the connection matrix on
 * disjoint pieces of the samples and zeroes entries that are not stable across pieces.
 * {@link #search} combines both steps into a graph.
 */
public class Lingam {
    /** Number of disjoint, contiguous sample pieces used by {@link #pruneEdgesByResampling}. */
    private static final int NUM_PIECES = 10;

    private double pruneFactor = 1.0;
    private final TetradLogger logger = TetradLogger.getInstance();

    /**
     * Runs LiNGAM on the given data and returns the estimated graph.
     * <p>
     * An edge {@code node_j -> node_i} is added for every nonzero entry {@code B[i][j]} of the
     * matrix returned by {@link #pruneEdgesByResampling}.
     *
     * @param data the data set; its variables become the nodes of the returned graph
     * @return a graph over the data set's variables
     */
    public Graph search(DataSet data) {
        DoubleMatrix2D X = data.getDoubleData();
        List<Node> nodes = data.getVariables();
        EstimateResult result = this.estimate(X);
        DoubleMatrix2D bHat = this.pruneEdgesByResampling(X, result.getK());
        EdgeListGraph graph = new EdgeListGraph(nodes);

        for (int j = 0; j < bHat.columns(); ++j) {
            for (int i = 0; i < bHat.rows(); ++i) {
                if (bHat.get(i, j) != 0.0) {
                    graph.addDirectedEdge(nodes.get(j), nodes.get(i));
                }
            }
        }

        this.logger.log("graph", "\nReturning this graph: " + graph);
        return graph;
    }

    /**
     * Estimates the connection matrix B and a causal order from the data.
     *
     * @param X the data, one column per variable and one row per sample
     * @return the estimate: B in the original variable order, disturbance standard deviations
     *         ({@code 1/|Wp[i][i]|}), constants ({@code Wp * columnMeans(X)}), the causal
     *         permutation, and the row-permuted unmixing matrix before normalization
     * @throws IllegalArgumentException if no permutation to lower-triangular form is found
     */
    public EstimateResult estimate(DoubleMatrix2D X) {
        FastIca fastIca = new FastIca(X.copy(), X.columns());
        fastIca.setVerbose(true);
        FastIca.IcaResult result = fastIca.findComponents();
        DoubleMatrix2D A = result.getA().viewDice();
        DoubleMatrix2D W = MatrixUtils.inverse(A);
        this.logger.log("lingamDetails", "\nW " + W);

        int[] rowp = this.diagonalizingRowPermutation(W);
        this.logger.log("lingamDetails", "\nrowp = ");
        for (int r : rowp) {
            this.logger.log("lingamDetails", r + "\t");
        }

        // Row-reordered view onto W; W itself is not used again below.
        DoubleMatrix2D Wp = W.viewSelection(rowp, this.range(0, W.columns() - 1));
        this.logger.log("lingamDetails", "\nWp = " + Wp);
        DoubleMatrix2D Wout = Wp.copy();

        // diagonal() returns a copy, so these values are unaffected by the row scaling below.
        DoubleMatrix1D diag = DoubleFactory2D.dense.diagonal(Wp);

        // Disturbance standard deviations come from the diagonal before normalization.
        DenseDoubleMatrix1D estdisturbancesstd = new DenseDoubleMatrix1D(Wp.rows());
        for (int i = 0; i < Wp.rows(); ++i) {
            estdisturbancesstd.set(i, 1.0 / Math.abs(diag.get(i)));
        }

        // Scale each row of Wp by its own diagonal entry.
        for (int i = 0; i < Wp.rows(); ++i) {
            double d = diag.get(i);
            for (int j = 0; j < Wp.columns(); ++j) {
                Wp.set(i, j, Wp.get(i, j) / d);
            }
        }
        this.logger.log("lingamDetails", "\nWp = " + Wp);

        // Best = I - Wp.
        DoubleMatrix2D Best = DoubleFactory2D.dense.identity(Wp.rows());
        Best.assign(Wp, PlusMult.plusMult(-1.0));
        this.logger.log("lingamDetails", "\nBest = " + Best);

        DoubleMatrix1D Xm = this.columnMeans(X);
        DoubleMatrix1D cest = new Algebra().mult(Wp, Xm);
        this.logger.log("lingamDetails", "cest = " + cest);

        StlPruneResult pruneResult = this.stlPrune(Best);
        DoubleMatrix2D bestCausal = pruneResult.getBestcausal();
        int[] causalperm = pruneResult.getCausalperm();
        this.logger.log("lingamDetails", "\nBest causal " + bestCausal);
        this.logger.log("lingamDetails", "\ncausalperm = " + Arrays.toString(causalperm));

        // In causal order only entries below the diagonal may be nonzero; clear everything above it.
        for (int i = 0; i < bestCausal.rows(); ++i) {
            for (int j = i + 1; j < bestCausal.columns(); ++j) {
                bestCausal.set(i, j, 0.0);
            }
        }
        this.logger.log("lingamDetails", "\nbestCausal = " + bestCausal);

        // Undo the causal permutation so B is indexed by the original variables.
        int[] icausal = this.iperm(causalperm);
        DoubleMatrix2D B = bestCausal.viewSelection(icausal, icausal).copy();
        this.logger.log("lingamDetails", "B = " + B);
        return new EstimateResult(B, estdisturbancesstd.toArray(), cest.toArray(), causalperm, Wout);
    }

    /**
     * Chooses a row order for W by solving a minimum-cost assignment over {@code 1/|W[i][j]|}
     * (passed transposed) with the Hungarian algorithm.
     * <p>
     * NOTE(review): assumes {@code Hungarian.hgAlgorithm} returns one (row, column) pair per row
     * of its input; {@code rowp[i]} is the second entry of the i-th pair.
     */
    private int[] diagonalizingRowPermutation(DoubleMatrix2D W) {
        DoubleMatrix2D S = W.copy();
        S.assign(Functions.abs);
        S.assign(Functions.inv);
        int[][] assignment = Hungarian.hgAlgorithm(S.viewDice().toArray(), "min");
        int[] rowp = new int[assignment.length];
        for (int i = 0; i < rowp.length; ++i) {
            rowp[i] = assignment[i][1];
        }
        return rowp;
    }

    /** Returns the arithmetic mean of each column of X. */
    private DoubleMatrix1D columnMeans(DoubleMatrix2D X) {
        DenseDoubleMatrix1D means = new DenseDoubleMatrix1D(X.columns());
        for (int j = 0; j < X.columns(); ++j) {
            double sum = 0.0;
            for (int i = 0; i < X.rows(); ++i) {
                sum += X.get(i, j);
            }
            means.set(j, sum / (double) X.rows());
        }
        return means;
    }

    /**
     * @return the factor by which an entry's mean must exceed its standard deviation across
     *         resampled pieces for the entry to be kept
     */
    public double getPruneFactor() {
        return this.pruneFactor;
    }

    /**
     * @param pruneFactor the new prune factor; must be strictly positive
     * @throws IllegalArgumentException if {@code pruneFactor} is not greater than zero (or NaN)
     */
    public void setPruneFactor(double pruneFactor) {
        // Written as !(x > 0) so that NaN is rejected as well.
        if (!(pruneFactor > 0.0)) {
            throw new IllegalArgumentException("Prune factor must be greater than zero.");
        }
        this.pruneFactor = pruneFactor;
    }

    /**
     * Finds a permutation that makes a pruned version of {@code bHat} strictly lower triangular.
     * <p>
     * Entries are sorted by absolute value. The {@code m(m+1)/2} smallest are zeroed, and then
     * the next-smallest entry is zeroed each time {@link #algorithmB} fails to find an ordering.
     * The caller's matrix is not modified.
     *
     * @param bHat a square matrix
     * @return a copy of the unpruned {@code bHat} with rows and columns permuted, plus the
     *         permutation
     * @throws IllegalArgumentException if no permutation was found
     */
    private StlPruneResult stlPrune(DoubleMatrix2D bHat) {
        int m = bHat.rows();
        List<Entry> entries = this.getEntries(bHat);
        Collections.sort(entries);

        DoubleMatrix2D pruned = bHat.copy();
        int numUpperTriangle = m * (m + 1) / 2;
        int numTotal = m * m;

        for (int i = 0; i < numUpperTriangle; ++i) {
            this.zeroEntry(pruned, entries.get(i));
        }

        for (int i = numUpperTriangle; i < numTotal; ++i) {
            int[] permutation = this.algorithmB(pruned);
            if (permutation != null) {
                DoubleMatrix2D bestCausal = this.permute(permutation, bHat).copy();
                return new StlPruneResult(bestCausal, permutation);
            }
            this.zeroEntry(pruned, entries.get(i));
        }

        throw new IllegalArgumentException("No permutation was found.");
    }

    /** Sets the cell of {@code mat} identified by {@code entry} to zero. */
    private void zeroEntry(DoubleMatrix2D mat, Entry entry) {
        mat.set(entry.row, entry.column, 0.0);
    }

    /** Returns a view of {@code data} with both rows and columns reordered by {@code permutation}. */
    private DoubleMatrix2D permute(int[] permutation, DoubleMatrix2D data) {
        return data.viewSelection(permutation, permutation);
    }

    /** Lists every cell of {@code mat}, in row-major order. An ArrayList is used for O(1) indexed access. */
    private List<Entry> getEntries(DoubleMatrix2D mat) {
        List<Entry> entries = new ArrayList<Entry>(mat.rows() * mat.columns());
        for (int i = 0; i < mat.rows(); ++i) {
            for (int j = 0; j < mat.columns(); ++j) {
                entries.add(new Entry(i, j, mat.get(i, j)));
            }
        }
        return entries;
    }

    /**
     * Searches for an ordering of the rows of {@code mat} in which every row has nonzero entries
     * only in columns of rows placed before it. The matrix permuted by that ordering on both rows
     * and columns is therefore strictly lower triangular.
     * <p>
     * At each step the lowest-indexed remaining row that qualifies is chosen.
     *
     * @param mat the matrix to test
     * @return the ordering ({@code result[p]} is the row placed at position p), or {@code null}
     *         if no such ordering exists
     */
    public int[] algorithmB(DoubleMatrix2D mat) {
        int n = mat.rows();
        boolean[] removed = new boolean[n];
        int[] permutation = new int[n];

        for (int position = 0; position < n; ++position) {
            int next = -1;
            for (int i = 0; i < n; ++i) {
                if (!removed[i] && this.zeroesInNewColumns(mat.viewRow(i), removed)) {
                    next = i;
                    break;
                }
            }
            if (next == -1) {
                return null;
            }
            removed[next] = true;
            permutation[position] = next;
        }

        return permutation;
    }

    /**
     * Returns true if every nonzero entry of {@code vec} lies in a column marked as removed.
     * Columns beyond {@code removed.length} count as not removed.
     */
    private boolean zeroesInNewColumns(DoubleMatrix1D vec, boolean[] removed) {
        for (int i = 0; i < vec.size(); ++i) {
            boolean isRemoved = i < removed.length && removed[i];
            if (vec.get(i) != 0.0 && !isRemoved) {
                return false;
            }
        }
        return true;
    }

    /**
     * Re-estimates the connection matrix on disjoint pieces of the data and keeps only stable
     * entries.
     * <p>
     * The samples are split into {@value #NUM_PIECES} contiguous pieces of equal size; any
     * leftover samples at the end are ignored. For each piece, the variables are taken in the
     * causal order {@code k}, centered, and B is derived from a QL decomposition of the inverse
     * square root of their covariance. An entry of the result is its mean across pieces, or zero
     * when {@code |mean| < pruneFactor * std}. Here std divides by the number of pieces.
     *
     * @param data the data, one column per variable and one row per sample
     * @param k    a causal order, i.e. a permutation of {@code 0..data.columns()-1}
     * @return the pruned connection matrix, indexed by the original variables
     * @throws IllegalArgumentException if {@code k} is not such a permutation, or there are
     *                                  fewer samples than pieces
     * @throws RuntimeException         if an eigen-decomposition does not converge or a piece's
     *                                  covariance square root is singular
     */
    public DoubleMatrix2D pruneEdgesByResampling(DoubleMatrix2D data, int[] k) {
        checkPermutation(k, data.columns());

        // MTJ matrix laid out variables x samples.
        DenseMatrix X = new DenseMatrix(data.viewDice().toArray());
        int cols = X.numColumns();
        int rows = X.numRows();
        int piecesize = cols / NUM_PIECES;

        if (piecesize < 1) {
            throw new IllegalArgumentException("Need at least " + NUM_PIECES
                    + " samples to prune edges by resampling; got " + cols + ".");
        }

        int[] ik = this.iperm(k);
        List<Matrix> bpieces = new ArrayList<Matrix>(NUM_PIECES);
        for (int p = 0; p < NUM_PIECES; ++p) {
            int[] sampleRange = this.range(p * piecesize, (p + 1) * piecesize - 1);
            bpieces.add(this.estimatePieceB(X, k, ik, sampleRange));
        }

        double[][] bFinal = new double[rows][rows];
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < rows; ++j) {
                double sum = 0.0;
                for (Matrix piece : bpieces) {
                    sum += piece.get(i, j);
                }
                double themean = sum / (double) NUM_PIECES;

                double sumVar = 0.0;
                for (Matrix piece : bpieces) {
                    sumVar += Math.pow(piece.get(i, j) - themean, 2.0);
                }
                double thestd = Math.sqrt(sumVar / (double) NUM_PIECES);

                bFinal[i][j] = Math.abs(themean) < this.getPruneFactor() * thestd ? 0.0 : themean;
            }
        }

        return new DenseDoubleMatrix2D(bFinal);
    }

    /** Throws IllegalArgumentException unless {@code k} is a permutation of {@code 0..n-1}. */
    private static void checkPermutation(int[] k, int n) {
        if (k.length != n) {
            throw new IllegalArgumentException("Expecting a permutation.");
        }
        boolean[] seen = new boolean[k.length];
        for (int value : k) {
            if (value < 0 || value >= k.length || seen[value]) {
                throw new IllegalArgumentException("Expecting a permutation.");
            }
            seen[value] = true;
        }
    }

    /**
     * Estimates B from the samples in {@code sampleRange}, with variables taken in order {@code k}.
     * The result is mapped back to the original variable order with {@code ik}.
     */
    private Matrix estimatePieceB(DenseMatrix X, int[] k, int[] ik, int[] sampleRange) {
        int rows = X.numRows();

        // Dense copy: getSubMatrix yields a view, and centering must not write through to X.
        DenseMatrix Xp = new DenseMatrix(Matrices.getSubMatrix(X, k, sampleRange));
        int n = Xp.numColumns();

        // Center each variable (row) on its mean over this piece.
        for (int i = 0; i < rows; ++i) {
            double sum = 0.0;
            for (int j = 0; j < n; ++j) {
                sum += Xp.get(i, j);
            }
            double mean = sum / (double) n;
            for (int j = 0; j < n; ++j) {
                Xp.set(i, j, Xp.get(i, j) - mean);
            }
        }

        // cov = Xp * Xp^T / n
        DenseMatrix XpT = new DenseMatrix(n, rows);
        Xp.transpose(XpT);
        DenseMatrix cov = new DenseMatrix(rows, rows);
        Xp.mult(XpT, cov);
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < rows; ++j) {
                cov.set(i, j, cov.get(i, j) / (double) n);
            }
        }

        DenseMatrix sqrtCov;
        try {
            sqrtCov = this.sqrt(new DenseMatrix(cov));
        } catch (NotConvergedException e) {
            throw new RuntimeException(e);
        }

        DenseMatrix I = Matrices.identity(rows);
        DenseMatrix invSqrt;
        try {
            invSqrt = new DenseMatrix(sqrtCov.solve(I, I.copy()));
        } catch (MatrixSingularException e) {
            throw new RuntimeException("Singular matrix.", e);
        }

        LowerTriangDenseMatrix L = QL.factorize(invSqrt).getL();

        // Divide each row of L by its diagonal entry. The diagonal (t == s) is processed last,
        // so every entry in the row is divided by the original diagonal value.
        for (int s = 0; s < rows; ++s) {
            for (int t = 0; t <= s; ++t) {
                L.set(s, t, L.get(s, t) / L.get(s, s));
            }
        }

        Matrix bnewest = Matrices.identity(rows).add(-1.0, L);
        return new DenseMatrix(Matrices.getSubMatrix(bnewest, ik, ik));
    }

    /**
     * Returns the inverse of {@code k}: {@code ik[k[i]] == i}. Values outside
     * {@code 0..k.length-1} are ignored; for duplicate values the last index wins.
     */
    public int[] iperm(int[] k) {
        int[] ik = new int[k.length];
        for (int i = 0; i < k.length; ++i) {
            int v = k[i];
            if (v >= 0 && v < k.length) {
                ik[v] = i;
            }
        }
        return ik;
    }

    /**
     * Returns {@code V * diag(sqrt(|lambda|)) * V^T}, built from an eigen-decomposition of
     * {@code m} using its left eigenvectors.
     * <p>
     * NOTE(review): this is the matrix square root only when {@code m} is symmetric (as the
     * covariance matrices passed here are). Callers pass a copy in case {@code EVD.factor}
     * overwrites its argument.
     */
    private DenseMatrix sqrt(DenseMatrix m) throws NotConvergedException {
        int n = m.numRows();
        EVD eig = new EVD(n);
        eig.factor(m);
        double[] r = eig.getRealEigenvalues();
        DenseMatrix v = eig.getLeftEigenvectors();

        DenseMatrix d = new DenseMatrix(n, n);
        for (int i = 0; i < n; ++i) {
            d.set(i, i, Math.sqrt(Math.abs(r[i])));
        }

        Matrix vd = v.mult(d, new DenseMatrix(n, n));
        Matrix vT = v.transpose(new DenseMatrix(n, n));
        DenseMatrix prod = new DenseMatrix(n, n);
        vd.mult(vT, prod);
        return prod;
    }

    /** Returns the inclusive integer range {@code i1..i2}. */
    private int[] range(int i1, int i2) {
        if (i2 < i1) {
            throw new IllegalArgumentException("i2 must be >= i1: " + i1 + ", " + i2);
        }
        int[] series = new int[i2 - i1 + 1];
        for (int j = i1; j <= i2; ++j) {
            series[j - i1] = j;
        }
        return series;
    }

    /** A matrix cell, ordered by the absolute value of its contents. */
    private static final class Entry implements Comparable<Entry> {
        private final int row;
        private final int column;
        private final double value;

        public Entry(int row, int col, double val) {
            this.row = row;
            this.column = col;
            this.value = val;
        }

        @Override
        public int compareTo(Entry entry) {
            return Double.compare(Math.abs(this.value), Math.abs(entry.value));
        }

        @Override
        public String toString() {
            return "[" + this.row + "," + this.column + "]:" + this.value + " ";
        }
    }

    /** Result of {@link #stlPrune}: the permuted matrix and the permutation used. */
    private static final class StlPruneResult {
        private final DoubleMatrix2D Bestcausal;
        private final int[] causalperm;

        public StlPruneResult(DoubleMatrix2D Bestcausal, int[] causalPerm) {
            this.Bestcausal = Bestcausal;
            this.causalperm = causalPerm;
        }

        public DoubleMatrix2D getBestcausal() {
            return this.Bestcausal;
        }

        public int[] getCausalperm() {
            return this.causalperm;
        }
    }

    /** Output of {@link Lingam#estimate}. Arrays and matrices are returned as-is, not copied. */
    public static class EstimateResult {
        private final DoubleMatrix2D B;
        private final double[] stde;
        private final double[] ci;
        private final int[] k;
        private final DoubleMatrix2D Wout;

        /**
         * @param B    connection matrix in the original variable order
         * @param stde estimated disturbance standard deviations
         * @param ci   estimated constants
         * @param k    causal permutation
         * @param Wout row-permuted unmixing matrix before normalization
         */
        public EstimateResult(DoubleMatrix2D B, double[] stde, double[] ci, int[] k, DoubleMatrix2D Wout) {
            this.B = B;
            this.stde = stde;
            this.ci = ci;
            this.k = k;
            this.Wout = Wout;
        }

        public DoubleMatrix2D getB() {
            return this.B;
        }

        public double[] getStde() {
            return this.stde;
        }

        public double[] getCi() {
            return this.ci;
        }

        public int[] getK() {
            return this.k;
        }

        public DoubleMatrix2D getWout() {
            return this.Wout;
        }
    }
}

