/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.jet.math.Functions;
import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Hungarian;
import edu.cmu.tetrad.search.ShimizuResult;
import edu.cmu.tetrad.search.fastica.FastICA;
import edu.cmu.tetrad.util.MatrixUtils;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Vector;

public final class Shimizu2006SearchCleanup {
    /**
     * Estimates the connection-strength matrix (Bhat) for the given data.
     *
     * <p>Pipeline, as implemented here:
     * <ol>
     *   <li>the data matrix is transposed with {@code viewDice()} and handed to {@link FastICA};</li>
     *   <li>the mixing matrix reported by FastICA is inverted to obtain W;</li>
     *   <li>the rows of W are reordered by {@code permuteZerolessDiagonal};</li>
     *   <li>the result is passed through {@code MatrixUtils.normalizeDiagonal};</li>
     *   <li>Bhat is built with {@code MatrixUtils.linearCombination(I, 1.0, W', -1.0)}
     *       (presumably {@code I - W'}; confirm against MatrixUtils).</li>
     * </ol>
     * The transpose of Bhat is wrapped in a data set that reuses the variable names of
     * {@code dataSet}.
     *
     * @param dataSet the input data
     * @return a data set holding the transposed Bhat matrix
     * @throws IllegalStateException if any estimation step throws; the original failure is
     *         attached as the cause (previously the failure was only printed and the method
     *         then died with a NullPointerException)
     */
    public static ColtDataSet lingamDiscoveryBhat(DataSet dataSet) {
        DoubleMatrix2D data = dataSet.getDoubleData().viewDice();
        DoubleMatrix2D bHat;
        try {
            double[][] inV = MatrixUtils.convert(data);
            FastICA fica = new FastICA(inV, data.rows());
            DoubleMatrix2D mixing = MatrixUtils.convertToColt(fica.getMixingMatrix());
            DoubleMatrix2D w = MatrixUtils.inverse(mixing);
            int n = w.rows();
            DoubleMatrix2D zldW = Shimizu2006SearchCleanup.permuteZerolessDiagonal(w);
            DoubleMatrix2D zldWprime = MatrixUtils.normalizeDiagonal(zldW);
            bHat = MatrixUtils.linearCombination(MatrixUtils.identityMatrix(n), 1.0, zldWprime, -1.0);
        }
        catch (Exception e) {
            // Surface the real failure instead of printing it and crashing later on a null Bhat.
            throw new IllegalStateException("LiNGAM Bhat estimation failed", e);
        }
        return Shimizu2006SearchCleanup.makeDataSet(bHat.viewDice(), dataSet.getVariables());
    }

    /**
     * Wraps a matrix in a continuous {@link ColtDataSet} whose variables are fresh
     * {@link ContinuousVariable}s named after the given nodes.
     *
     * <p>A mismatch between the matrix column count and the number of nodes is only
     * reported (a stack trace is printed); construction is still attempted.
     *
     * @param inVectors the matrix to wrap
     * @param nodes     the nodes supplying the variable names, in column order
     * @return the data set built by {@code ColtDataSet.makeContinuousData}
     */
    public static ColtDataSet makeDataSet(DoubleMatrix2D inVectors, List<Node> nodes) {
        boolean dimensionsAgree = inVectors.columns() == nodes.size();
        if (!dimensionsAgree) {
            new Exception("dimensions don't match!").printStackTrace();
        }
        Vector<Node> continuousVars = new Vector<Node>();
        for (Node source : nodes) {
            String name = source.getName();
            continuousVars.add(new ContinuousVariable(name));
        }
        ColtDataSet wrapped = ColtDataSet.makeContinuousData(continuousVars, inVectors);
        return wrapped;
    }

    /**
     * Searches for an ordering of the row indices of {@code mat} in which every index,
     * at the moment it is chosen, has a row that is all zeros once the columns of the
     * previously chosen indices are ignored (as decided by {@code containsAllZeros}).
     *
     * <p>At each step the lowest-numbered eligible index is taken.
     *
     * @param mat the matrix to examine
     * @return the indices in the order they were chosen, or {@code null} when at some
     *         step no remaining row qualifies
     */
    public static Vector<Integer> dagPermutation(DoubleMatrix2D mat) {
        // One list serves both as the "already removed" set and as the resulting order.
        Vector<Integer> order = new Vector<Integer>();
        while (order.size() < mat.rows()) {
            int chosen = -1;
            int candidate = 0;
            while (chosen == -1 && candidate < mat.rows()) {
                boolean eligible = !order.contains(candidate)
                        && Shimizu2006SearchCleanup.containsAllZeros(mat.viewRow(candidate), order);
                if (eligible) {
                    chosen = candidate;
                }
                ++candidate;
            }
            if (chosen == -1) {
                return null;
            }
            order.add(chosen);
        }
        return order;
    }

    /**
     * Flattens a matrix into a list of {@link Entry} objects, one per cell,
     * visiting the cells row by row and, within a row, column by column.
     *
     * @param mat the matrix to flatten
     * @return the entries in row-major order
     */
    public static Vector<Entry> getEntries(DoubleMatrix2D mat) {
        Vector<Entry> cells = new Vector<Entry>();
        int r = 0;
        while (r < mat.rows()) {
            int c = 0;
            while (c < mat.columns()) {
                cells.add(new Entry(r, c, mat.get(r, c)));
                ++c;
            }
            ++r;
        }
        return cells;
    }

    /**
     * Finds a simultaneous row/column permutation of the data set's matrix by zeroing
     * its smallest-magnitude entries until {@code dagPermutation} succeeds.
     *
     * <p>All entries are sorted by absolute value. The first {@code n(n+1)/2} of them
     * (n = number of rows) are zeroed in a private copy of the matrix; then one further
     * entry at a time is zeroed, in sorted order, until {@code dagPermutation} returns a
     * non-null ordering. That ordering is applied to the <em>original</em> data set
     * (variables, rows and columns); the zeroing itself is not carried over.
     *
     * @param dataSet the data set holding the matrix to permute; it is not modified
     * @return a new data set with variables, rows and columns reordered
     * @throws IllegalStateException if every entry has been zeroed and still no
     *         ordering was found
     */
    public static ColtDataSet iterativeC(ColtDataSet dataSet) {
        DoubleMatrix2D mat = dataSet.getDoubleData();
        int n = mat.rows();
        Vector<Entry> sortedEntries = Shimizu2006SearchCleanup.getEntries(mat);
        Collections.sort(sortedEntries);

        // Cursor into sortedEntries: everything before it has already been zeroed.
        int next = n * (n + 1) / 2;
        // setEntriesToZero returns a fresh copy, so it is safe to keep zeroing it in place.
        DoubleMatrix2D working = Shimizu2006SearchCleanup.setEntriesToZero(
                Shimizu2006SearchCleanup.getNFirst(next, sortedEntries), mat);

        Vector<Integer> permutation;
        while ((permutation = Shimizu2006SearchCleanup.dagPermutation(working)) == null) {
            if (next >= sortedEntries.size()) {
                throw new IllegalStateException("no DAG permutation found after zeroing all entries");
            }
            Entry entry = sortedEntries.get(next++);
            working.set(entry.row, entry.column, 0.0);
        }
        return Shimizu2006SearchCleanup.permute(permutation, dataSet);
    }

    /**
     * Copies the first {@code n} elements of {@code list} into a new vector.
     *
     * @param n    how many leading elements to copy
     * @param list the source list; left unchanged
     * @return a new vector holding elements {@code 0 .. n-1} of {@code list}
     */
    public static Vector<Entry> getNFirst(int n, Vector<Entry> list) {
        List<Entry> head = list.subList(0, n);
        Vector<Entry> firstN = new Vector<Entry>(head);
        return firstN;
    }

    /**
     * Returns a copy of {@code mat} with the column at {@code index} left out.
     *
     * @param mat   the source matrix; left unchanged
     * @param index the column to drop
     * @return a new matrix with all rows and every column except {@code index}
     */
    public static DoubleMatrix2D removeColumn(DoubleMatrix2D mat, int index) {
        int rowCount = mat.rows();
        int[] allRows = new int[rowCount];
        for (int r = 0; r < rowCount; ++r) {
            allRows[r] = r;
        }
        int colCount = mat.columns();
        int[] keptCols = new int[colCount - 1];
        int write = 0;
        for (int c = 0; c < colCount; ++c) {
            if (c != index) {
                keptCols[write++] = c;
            }
        }
        return mat.viewSelection(allRows, keptCols).copy();
    }

    /**
     * Returns a copy of {@code mat} with the row at {@code index} left out.
     *
     * <p>The kept-row index array is sized and filled from {@code mat.rows()}; the
     * earlier version used {@code mat.columns()} there, which was only correct for
     * square matrices.
     *
     * @param mat   the source matrix; left unchanged
     * @param index the row to drop
     * @return a new matrix with all columns and every row except {@code index}
     */
    public static DoubleMatrix2D removeRow(DoubleMatrix2D mat, int index) {
        int[] cols = new int[mat.columns()];
        for (int i = 0; i < mat.columns(); ++i) {
            cols[i] = i;
        }
        int[] rows = new int[mat.rows() - 1];
        int m = -1;
        for (int i = 0; i < mat.rows(); ++i) {
            if (i == index) continue;
            rows[++m] = i;
        }
        return mat.viewSelection(rows, cols).copy();
    }

    /**
     * Runs the permutation step on an estimated Bhat data set by delegating to
     * {@code iterativeC}.
     *
     * @param bhat the Bhat data set
     * @return whatever {@code iterativeC} returns for {@code bhat}
     */
    public static ColtDataSet lingamDiscoveryStep5(ColtDataSet bhat) {
        ColtDataSet permuted = Shimizu2006SearchCleanup.iterativeC(bhat);
        return permuted;
    }

    /**
     * Builds a weighted DAG from the strictly lower triangle of the data set's matrix.
     *
     * <p>For every cell {@code (i, j)} with {@code j < i} whose value is not exactly
     * {@code 0.0}, a directed edge from variable {@code i} to variable {@code j}
     * (tail at i, arrow at j) is added and the cell value is stored as its weight.
     * Cells on or above the diagonal are never read.
     *
     * @param ltDataSet the data set holding the matrix and the variables
     * @return the result object holding the graph and the edge weights
     */
    public static ShimizuResult makeDagWithParms(ColtDataSet ltDataSet) {
        DoubleMatrix2D weights = ltDataSet.getDoubleData();
        List<Node> nodes = ltDataSet.getVariables();
        ShimizuResult result = new ShimizuResult(new Dag(nodes));
        for (int from = 0; from < weights.rows(); ++from) {
            for (int to = 0; to < from; ++to) {
                double weight = weights.get(from, to);
                if (weight != 0.0) {
                    Edge edge = new Edge(nodes.get(from), nodes.get(to), Endpoint.TAIL, Endpoint.ARROW);
                    result.getGraph().addEdge(edge);
                    result.setWeight(edge, weight);
                }
            }
        }
        return result;
    }

    /**
     * Runs the whole search: estimate Bhat, prune it by sub-sampling, permute it,
     * and turn the result into a weighted DAG.
     *
     * <p>The number of sub-samples used for pruning is {@code floor(sqrt(rows) / 2)},
     * where {@code rows} is the row count of the data matrix; the pruning call is
     * made with the constant {@code 0.05} as its alpha argument.
     *
     * @param dataSet the input data
     * @return the DAG and edge weights produced by {@code makeDagWithParms}
     */
    public static ShimizuResult lingamDiscoveryDag(DataSet dataSet) {
        ColtDataSet bHat = Shimizu2006SearchCleanup.lingamDiscoveryBhat(dataSet);
        double rootOfRows = Math.sqrt(dataSet.getDoubleData().rows());
        int nPieces = (int) Math.floor(rootOfRows / 2.0);
        // Pruning works in place on bHat; its return value is not used.
        Shimizu2006SearchCleanup.pruneEdges(bHat, dataSet, 0.05, nPieces);
        ColtDataSet permuted = Shimizu2006SearchCleanup.lingamDiscoveryStep5(bHat);
        return Shimizu2006SearchCleanup.makeDagWithParms(permuted);
    }

    /**
     * Builds a new matrix in which row {@code i} of {@code mat} is placed at the row
     * index given by {@code assignment[i][1]}.
     *
     * @param mat        the source matrix; left unchanged
     * @param assignment one pair per source row; element {@code [i][1]} is the target row
     * @return the row-reordered copy
     */
    private static DoubleMatrix2D assignment2Matrix(DoubleMatrix2D mat, int[][] assignment) {
        DoubleMatrix2D reordered = new DenseDoubleMatrix2D(mat.rows(), mat.columns());
        int sourceRow = 0;
        while (sourceRow < mat.rows()) {
            int targetRow = assignment[sourceRow][1];
            DoubleMatrix1D target = reordered.viewRow(targetRow);
            target.assign(mat.viewRow(sourceRow));
            ++sourceRow;
        }
        return reordered;
    }

    /**
     * Reorders the rows of {@code w} according to an assignment computed by the
     * Hungarian algorithm.
     *
     * <p>The cost matrix handed to {@code Hungarian.hgAlgorithm} (in {@code "min"} mode)
     * is a copy of {@code w} with {@code Functions.inv} and then {@code Functions.abs}
     * applied to every element (presumably {@code |1 / w_ij|}; confirm against Colt's
     * {@code Functions}). The resulting assignment is applied via
     * {@code assignment2Matrix}.
     *
     * @param w the matrix whose rows are to be reordered; left unchanged
     * @return the row-permuted copy of {@code w}
     */
    private static DoubleMatrix2D permuteZerolessDiagonal(DoubleMatrix2D w) {
        DoubleMatrix2D cost = new DenseDoubleMatrix2D(w.rows(), w.columns());
        cost.assign(w).assign(Functions.inv).assign(Functions.abs);
        double[][] costArray = MatrixUtils.convert(cost);
        int[][] assignment = Hungarian.hgAlgorithm(costArray, "min");
        return Shimizu2006SearchCleanup.assignment2Matrix(w, assignment);
    }

    /**
     * Applies one permutation to the variables, the rows and the columns of a data set.
     *
     * @param permutation for each position in the result, the source index to take
     * @param dataSet     the data set to reorder; left unchanged
     * @return a new data set with variables, rows and columns reordered alike
     */
    private static ColtDataSet permute(Vector<Integer> permutation, ColtDataSet dataSet) {
        List<Node> reorderedNodes = Shimizu2006SearchCleanup.permute(permutation, dataSet.getVariables());
        DoubleMatrix2D rowsDone = Shimizu2006SearchCleanup.permuteRows(permutation, dataSet.getDoubleData());
        DoubleMatrix2D rowsAndColsDone = Shimizu2006SearchCleanup.permuteColumns(permutation, rowsDone);
        return Shimizu2006SearchCleanup.makeDataSet(rowsAndColsDone, reorderedNodes);
    }

    /**
     * Builds a matrix whose row {@code i} is row {@code permutation.get(i)} of {@code mat}.
     *
     * @param permutation for each target row, the source row index
     * @param mat         the source matrix; left unchanged
     * @return the row-permuted copy
     */
    private static DoubleMatrix2D permuteRows(Vector<Integer> permutation, DoubleMatrix2D mat) {
        DoubleMatrix2D reordered = new DenseDoubleMatrix2D(mat.rows(), mat.columns());
        int target = 0;
        while (target < mat.rows()) {
            DoubleMatrix1D sourceRow = mat.viewRow(permutation.get(target));
            reordered.viewRow(target).assign(sourceRow);
            ++target;
        }
        return reordered;
    }

    /**
     * Builds a matrix whose column {@code i} is column {@code permutation.get(i)} of
     * {@code mat}.
     *
     * @param permutation for each target column, the source column index
     * @param mat         the source matrix; left unchanged
     * @return the column-permuted copy
     */
    private static DoubleMatrix2D permuteColumns(Vector<Integer> permutation, DoubleMatrix2D mat) {
        DoubleMatrix2D reordered = new DenseDoubleMatrix2D(mat.rows(), mat.columns());
        int target = 0;
        while (target < mat.columns()) {
            DoubleMatrix1D sourceColumn = mat.viewColumn(permutation.get(target));
            reordered.viewColumn(target).assign(sourceColumn);
            ++target;
        }
        return reordered;
    }

    /**
     * Builds a list whose element {@code i} is {@code variables.get(permutation.get(i))}.
     *
     * @param permutation for each target position, the source index
     * @param variables   the nodes to reorder; left unchanged
     * @return a new list of the same length holding the reordered nodes
     */
    private static List<Node> permute(Vector<Integer> permutation, List<Node> variables) {
        Vector<Node> reordered = new Vector<Node>(variables.size());
        for (int slot = 0; slot < variables.size(); ++slot) {
            int source = permutation.get(slot);
            reordered.add(variables.get(source));
        }
        return reordered;
    }

    /**
     * Returns a dense copy of {@code mat} in which the cells named by {@code entries}
     * are set to {@code 0.0}.
     *
     * @param entries the cells (row/column) to clear
     * @param mat     the source matrix; left unchanged
     * @return the copy with the listed cells zeroed
     */
    private static DoubleMatrix2D setEntriesToZero(Vector<Entry> entries, DoubleMatrix2D mat) {
        DoubleMatrix2D zeroed = new DenseDoubleMatrix2D(mat.rows(), mat.columns());
        zeroed.assign(mat);
        for (Entry cell : entries) {
            int r = cell.row;
            int c = cell.column;
            zeroed.set(r, c, 0.0);
        }
        return zeroed;
    }

    /**
     * Returns a dense copy of {@code mat} in which the single cell named by
     * {@code entry} is set to {@code 0.0}.
     *
     * @param entry the cell (row/column) to clear
     * @param mat   the source matrix; left unchanged
     * @return the copy with that cell zeroed
     */
    private static DoubleMatrix2D setEntryToZero(Entry entry, DoubleMatrix2D mat) {
        DoubleMatrix2D zeroed = new DenseDoubleMatrix2D(mat.rows(), mat.columns());
        zeroed.assign(mat);
        int r = entry.row;
        int c = entry.column;
        zeroed.set(r, c, 0.0);
        return zeroed;
    }

    /**
     * Returns a copy of {@code mat} without row {@code n} and without column {@code n};
     * the row is removed first, then the column.
     *
     * @param n   the index of the row and column to drop
     * @param mat the source matrix; left unchanged
     * @return the reduced copy
     */
    private static DoubleMatrix2D removeNthRowAndColumn(int n, DoubleMatrix2D mat) {
        DoubleMatrix2D withoutRow = Shimizu2006SearchCleanup.removeRow(mat, n);
        return Shimizu2006SearchCleanup.removeColumn(withoutRow, n);
    }

    /**
     * Tells whether every element of {@code vec} is exactly {@code 0.0}, ignoring the
     * positions listed in {@code removedIndices}.
     *
     * <p>An earlier version called {@code System.exit(0)} for vectors longer than 20
     * elements, silently terminating the JVM; that guard has been removed because the
     * check itself does not depend on the vector length.
     *
     * @param vec            the vector to inspect
     * @param removedIndices positions to skip
     * @return {@code true} if no non-skipped element differs from {@code 0.0}
     */
    private static boolean containsAllZeros(DoubleMatrix1D vec, Vector<Integer> removedIndices) {
        for (int i = 0; i < vec.size(); ++i) {
            if (vec.get(i) == 0.0 || removedIndices.contains(i)) continue;
            return false;
        }
        return true;
    }

    /**
     * Prunes weak entries of {@code B} by delegating to {@code pruneEdgesBySampling}
     * with the same arguments.
     *
     * @param B       the coefficient data set to prune
     * @param dataSet the raw data used for re-estimation
     * @param alpha   forwarded unchanged
     * @param nPieces the number of sub-samples
     * @return whatever {@code pruneEdgesBySampling} returns
     */
    private static ColtDataSet pruneEdges(ColtDataSet B, DataSet dataSet, double alpha, int nPieces) {
        ColtDataSet outcome = Shimizu2006SearchCleanup.pruneEdgesBySampling(B, dataSet, alpha, nPieces);
        return outcome;
    }

    /**
     * Zeroes, in place, the entries of {@code B} that are not stable across sub-samples
     * of the data.
     *
     * <p>The data is cut into {@code nPieces} consecutive pieces and Bhat is re-estimated
     * on each piece. For every cell the mean and the (population) standard deviation of
     * the per-piece estimates are computed; the cell of {@code B} is set to {@code 0.0}
     * when {@code |mean| < sd} (prune factor 1.0).
     *
     * <p>When {@code nPieces < 1} there is nothing to estimate from and {@code B} is
     * returned untouched.
     *
     * @param B       the coefficient data set; modified in place
     * @param dataSet the raw data to sub-sample
     * @param alpha   currently not used by this method
     * @param nPieces the number of sub-samples
     * @return {@code B} itself after pruning (earlier versions always returned {@code null})
     */
    private static ColtDataSet pruneEdgesBySampling(ColtDataSet B, DataSet dataSet, double alpha, int nPieces) {
        if (nPieces < 1) {
            return B;
        }
        ColtDataSet[] bHats = new ColtDataSet[nPieces];
        for (int i = 0; i < nPieces; ++i) {
            DataSet currentDataSet = Shimizu2006SearchCleanup.getSubsetOfDataSet(dataSet, nPieces, i);
            bHats[i] = Shimizu2006SearchCleanup.lingamDiscoveryBhat(currentDataSet);
        }
        DoubleMatrix2D target = B.getDoubleData();
        double pruneFactor = 1.0;
        for (int i = 0; i < target.rows(); ++i) {
            for (int j = 0; j < target.columns(); ++j) {
                double sum = 0.0;
                for (int z = 0; z < nPieces; ++z) {
                    sum += bHats[z].getDouble(i, j);
                }
                double mean = sum / (double)nPieces;
                double varianceSum = 0.0;
                for (int z = 0; z < nPieces; ++z) {
                    varianceSum += Math.pow(bHats[z].getDouble(i, j) - mean, 2.0);
                }
                double sd = Math.pow(varianceSum / (double)nPieces, 0.5);
                // Only the sub-sample estimates are read above, so zeroing B here is safe.
                if (Math.abs(mean) < sd * pruneFactor) {
                    target.set(i, j, 0.0);
                }
            }
        }
        return B;
    }

    /**
     * Extracts one of {@code pieces} consecutive, equally sized row blocks of the data.
     *
     * <p>The block size is {@code rows / pieces} (integer division), so trailing rows
     * that do not fill a whole block are never selected. The returned data set is built
     * from a {@code viewPart} of the original matrix.
     *
     * @param dataSet    the data to slice
     * @param pieces     the number of blocks the rows are divided into
     * @param pieceIndex the zero-based block to return
     * @return a data set over the selected rows with the same variable names
     */
    private static DataSet getSubsetOfDataSet(DataSet dataSet, int pieces, int pieceIndex) {
        DoubleMatrix2D full = dataSet.getDoubleData();
        int blockRows = full.rows() / pieces;
        int firstRow = blockRows * pieceIndex;
        DoubleMatrix2D block = full.viewPart(firstRow, 0, blockRows, full.columns());
        return Shimizu2006SearchCleanup.makeDataSet(block, dataSet.getVariables());
    }

    /**
     * Returns a data set in which every column of the input has been divided by that
     * column's population standard deviation. Values are scaled only; the column mean
     * is not subtracted.
     *
     * @param dataSet the data to rescale; left unchanged
     * @return a new data set with the rescaled values and the same variable names
     */
    private static DataSet normalizeVariance(DataSet dataSet) {
        DoubleMatrix2D source = dataSet.getDoubleData();
        DoubleMatrix2D scaled = new DenseDoubleMatrix2D(source.rows(), source.columns());
        for (int col = 0; col < source.columns(); ++col) {
            // Column mean.
            double total = 0.0;
            for (int row = 0; row < source.rows(); ++row) {
                total += source.get(row, col);
            }
            double mean = total / (double)source.rows();

            // Population standard deviation around that mean.
            double squaredDeviations = 0.0;
            for (int row = 0; row < source.rows(); ++row) {
                squaredDeviations += Math.pow(source.get(row, col) - mean, 2.0);
            }
            double sd = Math.pow(squaredDeviations / (double)source.rows(), 0.5);

            for (int row = 0; row < source.rows(); ++row) {
                scaled.set(row, col, source.get(row, col) / sd);
            }
        }
        return Shimizu2006SearchCleanup.makeDataSet(scaled, dataSet.getVariables());
    }

    /**
     * A single matrix cell: its row, its column and its value.
     *
     * <p>Entries are ordered by the absolute value of {@code value}, smallest first;
     * row and column play no part in the ordering.
     */
    public static class Entry
    implements Comparable<Entry> {
        int row;
        int column;
        double value;

        /**
         * @param row the cell's row index
         * @param col the cell's column index
         * @param val the cell's value
         */
        public Entry(int row, int col, double val) {
            this.row = row;
            this.column = col;
            this.value = val;
        }

        /** Compares by magnitude of the value, using the ordering of {@code Double.compare}. */
        @Override
        public int compareTo(Entry entry) {
            return Double.compare(Math.abs(this.value), Math.abs(entry.value));
        }

        /** Formats the entry as {@code [row,column]:value} followed by one space. */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder();
            text.append("[").append(this.row).append(",").append(this.column);
            text.append("]:").append(this.value).append(" ");
            return text.toString();
        }
    }
}

