/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.jet.math.Functions;
import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.ExploreFastIca;
import edu.cmu.tetrad.search.GaussianConfidenceInterval;
import edu.cmu.tetrad.search.GraphWithParameters;
import edu.cmu.tetrad.search.Hungarian;
import edu.cmu.tetrad.search.SemLearningMethod;
import edu.cmu.tetrad.search.fastica.FastICA;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.StatUtils;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Vector;

/**
 * Legacy implementation of the LiNGAM-style search attributed to Shimizu et al. (2006).
 *
 * <p>The search proceeds in three stages:
 * <ol>
 *   <li>{@link #lingamDiscovery_Bhat(DataSet)} runs FastICA, inverts the estimated mixing matrix,
 *       permutes its rows (Hungarian algorithm) so the diagonal avoids zeros, rescales it with
 *       {@code MatrixUtils.normalizeDiagonal}, and returns the transpose of {@code I - W'}.</li>
 *   <li>{@link #lingamDiscovery_step5(ColtDataSet)} zeroes the {@code n(n+1)/2} smallest-magnitude
 *       entries, then keeps zeroing the next smallest until {@link #dagPermutation(DoubleMatrix2D)}
 *       finds an ordering, and permutes rows and columns by that ordering.</li>
 *   <li>Entries are tested against bootstrap means/standard deviations with a
 *       {@link GaussianConfidenceInterval}; entries failing the test are zeroed, and the remaining
 *       below-diagonal entries become weighted edges of a DAG.</li>
 * </ol>
 *
 * <p>The most recent result is also kept in static fields shared by every instance; that state is
 * not synchronized.
 */
public class Shimizu2006SearchOld
implements SemLearningMethod {
    /** Number of bootstrap samples used by {@link #lingamDiscovery_DAG(DataSet, double)}. */
    private static final int DEFAULT_NUM_BOOTSTRAP_SAMPLES = 15;

    /** Significance level handed to the edge-pruning confidence test. */
    private double alpha;
    /** "Graph used" flag for {@link #lastShimizuGraph}; every search resets it to false. */
    static boolean isGraphUsed = false;
    /** Graph produced by the most recent search (shared by all instances). */
    static GraphWithParameters lastShimizuGraph = null;
    /** Alpha recorded by the most recent successful search, or {@code null} if none yet. */
    public static Double lastAlpha = null;

    /**
     * Creates a search that prunes edges at the given significance level.
     *
     * @param alpha significance level passed to {@link GaussianConfidenceInterval}
     */
    public Shimizu2006SearchOld(double alpha) {
        this.alpha = alpha;
    }

    /** @return the significance level used when pruning edges */
    public double getAlpha() {
        return this.alpha;
    }

    /** @return the shared "graph used" flag for {@link #lastShimizuGraph} */
    public static boolean isGraphUsed() {
        return isGraphUsed;
    }

    /**
     * Sets the shared "graph used" flag.
     *
     * @param value new flag value
     */
    public static void setIsGraphUsed(boolean value) {
        isGraphUsed = value;
    }

    /**
     * Returns the alpha recorded by the most recent successful search.
     *
     * @return the last alpha, or {@link Double#NaN} when none has been recorded
     */
    public double getLastAlpha() {
        // Read once so a concurrent reset cannot cause an unboxing NPE between check and use.
        Double last = lastAlpha;
        return last == null ? Double.NaN : last.doubleValue();
    }

    @Override
    public String getName() {
        return "Shimizu(alpha=" + this.alpha + ")";
    }

    /**
     * Runs the full search with this instance's alpha and stores the result in
     * {@link #lastShimizuGraph}.
     *
     * @param dataSet data to search over (one column per variable)
     * @return the pruned weighted DAG, or {@code null} if ICA failed on the full data set
     */
    public GraphWithParameters run(DataSet dataSet) {
        lastShimizuGraph = lingamDiscovery_DAG(dataSet, this.alpha);
        setIsGraphUsed(false);
        return lastShimizuGraph;
    }

    /**
     * Returns a dense copy of {@code mat} in which row {@code i} has been moved to row
     * {@code assignment[i][1]}.
     */
    private static DoubleMatrix2D assignment2Matrix(DoubleMatrix2D mat, int[][] assignment) {
        DenseDoubleMatrix2D swappedMat = new DenseDoubleMatrix2D(mat.rows(), mat.columns());
        for (int i = 0; i < mat.rows(); ++i) {
            // NOTE(review): assumes assignment[i][0] == i (pairs listed in row order) — TODO confirm.
            swappedMat.viewRow(assignment[i][1]).assign(mat.viewRow(i));
        }
        return swappedMat;
    }

    /**
     * Permutes the rows of {@code w} so its diagonal avoids zero/small entries: the Hungarian
     * algorithm chooses the assignment minimizing the sum of {@code |1 / w_ij|}.
     */
    private static DoubleMatrix2D permuteZerolessDiagonal(DoubleMatrix2D w) {
        DenseDoubleMatrix2D cost = new DenseDoubleMatrix2D(w.rows(), w.columns());
        cost.assign(w);
        cost.assign(Functions.inv);
        cost.assign(Functions.abs);
        int[][] assignment = Hungarian.hgAlgorithm(MatrixUtils.convert(cost), "min");
        return assignment2Matrix(w, assignment);
    }

    /**
     * Estimates the connection-strength matrix from ICA.
     *
     * <p>Errors from ICA or the matrix algebra are printed and swallowed; callers detect failure
     * through the {@code null} return (the bootstrap loop retries on {@code null}).
     *
     * @param dataSet data with one column per variable
     * @return the transposed {@code I - W'} estimate wrapped as a data set, or {@code null} on failure
     */
    public static ColtDataSet lingamDiscovery_Bhat(DataSet dataSet) {
        DoubleMatrix2D Bhat = null;
        // Transposed so each row holds one variable.
        DoubleMatrix2D data = dataSet.getDoubleData().viewDice();
        try {
            FastICA fica = new FastICA(MatrixUtils.convert(data), data.rows());
            DoubleMatrix2D A = MatrixUtils.convertToColt(fica.getMixingMatrix());
            DoubleMatrix2D W = MatrixUtils.inverse(A);
            int n = W.rows();
            if (W.rows() != W.columns()) {
                System.out.println("W = " + W);
                new Exception("W is not square!").printStackTrace();
            }
            if (W.rows() != dataSet.getNumColumns()) {
                new Exception("W does not have the right number of dimensions!").printStackTrace();
            }
            DoubleMatrix2D zldW = permuteZerolessDiagonal(W);
            DoubleMatrix2D zldWprime = MatrixUtils.normalizeDiagonal(zldW);
            Bhat = MatrixUtils.linearCombination(MatrixUtils.identityMatrix(n), 1.0, zldWprime, -1.0);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        if (Bhat == null) {
            return null;
        }
        return makeDataSet(Bhat.viewDice(), dataSet.getVariables());
    }

    /**
     * Greedily orders the rows of {@code mat}: at each step it picks the lowest-index row not yet
     * chosen whose entries are zero in every column whose index has not yet been chosen.
     *
     * @param mat square matrix
     * @return chosen row indices in selection order, or {@code null} if at some step no row qualifies
     */
    public static Vector<Integer> dagPermutation(DoubleMatrix2D mat) {
        Vector<Integer> removedIndices = new Vector<Integer>();
        Vector<Integer> v = new Vector<Integer>();
        while (removedIndices.size() < mat.rows()) {
            int allZerosRow = -1;
            for (int i = 0; i < mat.rows(); ++i) {
                if (removedIndices.contains(i) || !containsAllZeros(mat.viewRow(i), removedIndices)) continue;
                allZerosRow = i;
                break;
            }
            if (allZerosRow == -1) {
                return null;
            }
            removedIndices.add(allZerosRow);
            v.add(allZerosRow);
        }
        return v;
    }

    /**
     * Lists every cell of {@code mat} as an {@link Entry}, row by row.
     */
    public static Vector<Entry> getEntries(DoubleMatrix2D mat) {
        Vector<Entry> entries = new Vector<Entry>();
        for (int i = 0; i < mat.rows(); ++i) {
            for (int j = 0; j < mat.columns(); ++j) {
                entries.add(new Entry(i, j, mat.get(i, j)));
            }
        }
        return entries;
    }

    /**
     * Zeroes the smallest-magnitude entries of the (square) matrix in {@code dataSet} until
     * {@link #dagPermutation(DoubleMatrix2D)} succeeds, then permutes rows and columns of the
     * ORIGINAL matrix by the found ordering. The input data set is not modified.
     */
    public static ColtDataSet iterativeC(ColtDataSet dataSet) {
        DoubleMatrix2D mat = dataSet.getDoubleData();
        int n = mat.rows();
        Vector<Entry> remainingEntries = getEntries(mat);
        // Ascending by absolute value: smallest-magnitude entries first.
        Collections.sort(remainingEntries);
        Vector<Entry> entryQueue = getNFirst(n * (n + 1) / 2, remainingEntries);
        // Fresh dense copy, so it can be mutated in place below without touching mat.
        DoubleMatrix2D tempMat = setEntriesToZero(entryQueue, mat);
        // entryQueue is exactly the sorted prefix, so drop that prefix directly.
        remainingEntries.subList(0, entryQueue.size()).clear();
        Vector<Integer> permutation;
        int next = 0;
        // Terminates: once every entry is zeroed the matrix is all zeros and an ordering exists.
        while ((permutation = dagPermutation(tempMat)) == null) {
            Entry entry = remainingEntries.get(next++);
            tempMat.set(entry.row, entry.column, 0.0);
        }
        return permute(permutation, dataSet);
    }

    /** Applies {@code permutation} to the variables, rows and columns of {@code dataSet}. */
    private static ColtDataSet permute(Vector<Integer> permutation, ColtDataSet dataSet) {
        List<Node> resNodes = permute(permutation, dataSet.getVariables());
        DoubleMatrix2D resMatrix = permuteRows(permutation, dataSet.getDoubleData());
        resMatrix = permuteColumns(permutation, resMatrix);
        return makeDataSet(resMatrix, resNodes);
    }

    /** Returns a copy whose row {@code i} is row {@code permutation.get(i)} of {@code mat}. */
    private static DoubleMatrix2D permuteRows(Vector<Integer> permutation, DoubleMatrix2D mat) {
        DenseDoubleMatrix2D m = new DenseDoubleMatrix2D(mat.rows(), mat.columns());
        for (int i = 0; i < mat.rows(); ++i) {
            m.viewRow(i).assign(mat.viewRow(permutation.get(i)));
        }
        return m;
    }

    /** Returns a copy whose column {@code i} is column {@code permutation.get(i)} of {@code mat}. */
    private static DoubleMatrix2D permuteColumns(Vector<Integer> permutation, DoubleMatrix2D mat) {
        DenseDoubleMatrix2D m = new DenseDoubleMatrix2D(mat.rows(), mat.columns());
        for (int i = 0; i < mat.columns(); ++i) {
            m.viewColumn(i).assign(mat.viewColumn(permutation.get(i)));
        }
        return m;
    }

    /** Returns a new list whose element {@code i} is {@code variables.get(permutation.get(i))}. */
    private static List<Node> permute(Vector<Integer> permutation, List<Node> variables) {
        Vector<Node> nodes = new Vector<Node>(variables.size());
        for (int i = 0; i < variables.size(); ++i) {
            nodes.add(variables.get(permutation.get(i)));
        }
        return nodes;
    }

    /** Returns a dense copy of {@code mat} with the cells named by {@code entries} set to zero. */
    private static DoubleMatrix2D setEntriesToZero(Vector<Entry> entries, DoubleMatrix2D mat) {
        DenseDoubleMatrix2D m = new DenseDoubleMatrix2D(mat.rows(), mat.columns());
        m.assign(mat);
        for (Entry entry : entries) {
            m.set(entry.row, entry.column, 0.0);
        }
        return m;
    }

    /** Returns a dense copy of {@code mat} with the cell named by {@code entry} set to zero. */
    private static DoubleMatrix2D setEntryToZero(Entry entry, DoubleMatrix2D mat) {
        DenseDoubleMatrix2D m = new DenseDoubleMatrix2D(mat.rows(), mat.columns());
        m.assign(mat);
        m.set(entry.row, entry.column, 0.0);
        return m;
    }

    /**
     * Returns a copy of the first {@code n} elements of {@code list}.
     *
     * @param n number of elements wanted; clamped to {@code [0, list.size()]}
     */
    public static Vector<Entry> getNFirst(int n, Vector<Entry> list) {
        int count = Math.max(0, Math.min(n, list.size()));
        return new Vector<Entry>(list.subList(0, count));
    }

    /** Returns a copy of {@code mat} without column {@code index}. */
    public static DoubleMatrix2D removeColumn(DoubleMatrix2D mat, int index) {
        int[] rows = new int[mat.rows()];
        for (int i = 0; i < mat.rows(); ++i) {
            rows[i] = i;
        }
        int[] cols = new int[mat.columns() - 1];
        int m = -1;
        for (int j = 0; j < mat.columns(); ++j) {
            if (j == index) continue;
            cols[++m] = j;
        }
        return mat.viewSelection(rows, cols).copy();
    }

    /** Returns a copy of {@code mat} without row {@code index}. */
    public static DoubleMatrix2D removeRow(DoubleMatrix2D mat, int index) {
        int[] cols = new int[mat.columns()];
        for (int j = 0; j < mat.columns(); ++j) {
            cols[j] = j;
        }
        // Row indices must be sized and bounded by rows(), not columns() (wrong for non-square input).
        int[] rows = new int[mat.rows() - 1];
        int m = -1;
        for (int i = 0; i < mat.rows(); ++i) {
            if (i == index) continue;
            rows[++m] = i;
        }
        return mat.viewSelection(rows, cols).copy();
    }

    /** Returns a copy of {@code mat} without row {@code n} and column {@code n}. */
    private static DoubleMatrix2D removeNthRowAndColumn(int n, DoubleMatrix2D mat) {
        System.out.println("removeNthRowAndColumn: mat = " + mat);
        DoubleMatrix2D result = removeRow(mat, n);
        return removeColumn(result, n);
    }

    /**
     * Returns true if every entry of {@code vec} is zero, ignoring positions listed in
     * {@code removedIndices}.
     */
    private static boolean containsAllZeros(DoubleMatrix1D vec, Vector<Integer> removedIndices) {
        for (int i = 0; i < vec.size(); ++i) {
            if (vec.get(i) == 0.0 || removedIndices.contains(i)) continue;
            return false;
        }
        return true;
    }

    /** Step 5 of the search; delegates to {@link #iterativeC(ColtDataSet)}. */
    public static ColtDataSet lingamDiscovery_step5(ColtDataSet bhat) {
        return iterativeC(bhat);
    }

    /**
     * Builds a weighted DAG from the below-diagonal entries of {@code ltDataSet}: each nonzero
     * {@code (i, j)} with {@code i > j} becomes an edge with a TAIL at variable {@code i} and an
     * ARROW at variable {@code j}, weighted by the entry. Entries on or above the diagonal are ignored.
     */
    public static GraphWithParameters makeDagWithParms(ColtDataSet ltDataSet) {
        DoubleMatrix2D ltMat = ltDataSet.getDoubleData();
        List<Node> variables = ltDataSet.getVariables();
        Dag dag = new Dag(variables);
        GraphWithParameters dwp = new GraphWithParameters(dag);
        for (int i = 0; i < ltMat.rows(); ++i) {
            for (int j = 0; j < i; ++j) {
                if (ltMat.get(i, j) == 0.0) continue;
                Edge edge = new Edge(variables.get(i), variables.get(j), Endpoint.TAIL, Endpoint.ARROW);
                dwp.getGraph().addEdge(edge);
                dwp.getWeightHash().put(edge, ltMat.get(i, j));
            }
        }
        return dwp;
    }

    /** Runs the full search with this instance's alpha. */
    public GraphWithParameters lingamDiscovery_DAG(DataSet dataSet) {
        return lingamDiscovery_DAG(dataSet, this.alpha);
    }

    /**
     * Runs the full search using {@value #DEFAULT_NUM_BOOTSTRAP_SAMPLES} bootstrap samples.
     *
     * @return the pruned weighted DAG, or {@code null} if ICA failed on the full data set
     */
    public static GraphWithParameters lingamDiscovery_DAG(DataSet dataSet, double alpha) {
        return lingamDiscovery_DAG(dataSet, alpha, DEFAULT_NUM_BOOTSTRAP_SAMPLES);
    }

    /**
     * Runs the full search with a caller-chosen number of bootstrap samples, and records the
     * resulting graph and alpha in the shared static fields.
     *
     * @param dataSet  data with one column per variable
     * @param alpha    significance level for edge pruning
     * @param nSamples number of bootstrap samples; must be at least 1
     * @return the pruned weighted DAG, or {@code null} if ICA failed on the full data set
     * @throws IllegalArgumentException if {@code nSamples < 1}
     */
    public static GraphWithParameters lingamDiscovery_DAG(DataSet dataSet, double alpha, int nSamples) {
        if (nSamples < 1) {
            throw new IllegalArgumentException("nSamples must be at least 1: " + nSamples);
        }
        ColtDataSet Bhat = lingamDiscovery_Bhat(dataSet);
        if (Bhat == null) {
            System.out.println("ICA throws an exception!");
            return null;
        }
        ColtDataSet B = lingamDiscovery_step5(Bhat);
        pruneEdges(B, dataSet, alpha, nSamples);
        GraphWithParameters icaGraph = makeDagWithParms(B);
        lastShimizuGraph = icaGraph;
        lastAlpha = alpha;
        isGraphUsed = false;
        return icaGraph;
    }

    /** Prunes entries of {@code B} using bootstrap resampling. */
    private static void pruneEdges(ColtDataSet B, DataSet dataSet, double alpha, int nSamples) {
        pruneEdgesBySampling(B, dataSet, alpha, nSamples, true);
    }

    /**
     * Re-estimates Bhat on {@code nSamples} resamples of {@code dataSet}, computes per-entry means
     * and population standard deviations, and zeroes every entry of {@code B} (located by variable
     * name, since {@code B} is permuted) for which {@code gci.test(0.0, mean, sd, nSamples)} is false.
     *
     * @param isBootstrapSampling true for bootstrap samples, false for contiguous row blocks
     */
    private static void pruneEdgesBySampling(ColtDataSet B, DataSet dataSet, double alpha, int nSamples, boolean isBootstrapSampling) {
        ColtDataSet[] bHats = new ColtDataSet[nSamples];
        for (int i = 0; i < nSamples; ++i) {
            System.out.print("i = ");
            System.out.print("" + i + " ");
            ColtDataSet bHat = null;
            // NOTE(review): retries without an upper bound; with block sampling the same block is
            // retried each time, so a persistently failing ICA would loop forever.
            while (bHat == null) {
                DataSet sample = isBootstrapSampling ? getBootstrapSample(dataSet) : getSubsetOfDataSet(dataSet, nSamples, i);
                bHat = lingamDiscovery_Bhat(sample);
            }
            bHats[i] = bHat;
        }
        System.out.println("");
        DoubleMatrix2D bMat = B.getDoubleData();
        int rows = bMat.rows();
        int cols = bMat.columns();
        DenseDoubleMatrix2D meanMat = new DenseDoubleMatrix2D(rows, cols);
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                double sum = 0.0;
                for (int z = 0; z < nSamples; ++z) {
                    sum += bHats[z].getDouble(i, j);
                }
                meanMat.set(i, j, sum / (double) nSamples);
            }
        }
        DenseDoubleMatrix2D sdMat = new DenseDoubleMatrix2D(rows, cols);
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                double varianceSum = 0.0;
                for (int z = 0; z < nSamples; ++z) {
                    varianceSum += Math.pow(bHats[z].getDouble(i, j) - meanMat.get(i, j), 2.0);
                }
                double variance = varianceSum / (double) nSamples;
                sdMat.set(i, j, Math.pow(variance, 0.5));
            }
        }
        GaussianConfidenceInterval gci = new GaussianConfidenceInterval(alpha);
        List<Node> swappedVariables = B.getVariables();
        List<Node> originalVariables = dataSet.getVariables();
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                if (gci.test(0.0, meanMat.get(i, j), sdMat.get(i, j), nSamples)) continue;
                // bHats use the original variable order; B is permuted, so map indices by name.
                int xi = findColumn(i, swappedVariables, originalVariables);
                int xj = findColumn(j, swappedVariables, originalVariables);
                B.setDouble(xi, xj, 0.0);
            }
        }
    }

    /**
     * Counts the rows that are not equal (per {@code DoubleMatrix1D.equals}) to any earlier row,
     * i.e. the number of distinct rows.
     */
    private static int diversity(DataSet currentDataSet) {
        int n = 0;
        DoubleMatrix2D mat = currentDataSet.getDoubleData();
        for (int i = 0; i < mat.rows(); ++i) {
            if (isContained(mat.viewRow(i), mat.viewPart(0, 0, i, mat.columns()))) continue;
            ++n;
        }
        return n;
    }

    /** Returns true if some row of {@code mat} equals {@code vec}. */
    private static boolean isContained(DoubleMatrix1D vec, DoubleMatrix2D mat) {
        for (int i = 0; i < mat.rows(); ++i) {
            if (mat.viewRow(i).equals(vec)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the position in {@code variables} of the variable whose name matches
     * {@code originalVariables.get(index)}, or -1 (after printing a diagnostic) if none does.
     */
    private static int findColumn(int index, List<Node> variables, List<Node> originalVariables) {
        String target = originalVariables.get(index).getName();
        int i = 0;
        for (Node node : variables) {
            if (node.getName().equals(target)) {
                return i;
            }
            ++i;
        }
        System.out.println(" index = " + index + " variables = " + variables + " originalVariables = " + originalVariables);
        new Exception("not found").printStackTrace();
        return -1;
    }

    /**
     * Returns block {@code pieceIndex} of {@code pieces} contiguous, equally sized row blocks
     * ({@code rows / pieces} rows each; any remainder rows are never used).
     */
    private static DataSet getSubsetOfDataSet(DataSet dataSet, int pieces, int pieceIndex) {
        DoubleMatrix2D mat = dataSet.getDoubleData();
        int pieceSize = mat.rows() / pieces;
        DoubleMatrix2D res = mat.viewPart(pieceSize * pieceIndex, 0, pieceSize, mat.columns());
        return makeDataSet(res, dataSet.getVariables());
    }

    /**
     * Draws bootstrap samples until one has more distinct rows than the data set has columns.
     * NOTE(review): loops forever if the data itself has no more than that many distinct rows.
     */
    private static DataSet getBootstrapSample(DataSet dataSet) {
        DataSet sample = null;
        int diversity = 0;
        while (diversity <= dataSet.getNumColumns()) {
            sample = makeBootstrapSample(dataSet);
            diversity = diversity(sample);
        }
        return sample;
    }

    /** Resamples the rows of {@code dataSet} with replacement, keeping the same number of rows. */
    private static DataSet makeBootstrapSample(DataSet dataSet) {
        DoubleMatrix2D mat = dataSet.getDoubleData();
        int dataSize = mat.rows();
        DenseDoubleMatrix2D bootstrapSample = new DenseDoubleMatrix2D(dataSize, mat.columns());
        for (int i = 0; i < dataSize; ++i) {
            // assumes StatUtils.dieToss(n) returns a uniform index in [0, n) — TODO confirm
            bootstrapSample.viewRow(i).assign(mat.viewRow(StatUtils.dieToss(dataSize)));
        }
        return makeDataSet(bootstrapSample, dataSet.getVariables());
    }

    /**
     * Divides each column by its population standard deviation (denominator n). Columns are not
     * centered; a constant column yields NaN or infinite values.
     */
    private static DataSet normalizeVariance(DataSet dataSet) {
        DoubleMatrix2D mat = dataSet.getDoubleData();
        DenseDoubleMatrix2D res = new DenseDoubleMatrix2D(mat.rows(), mat.columns());
        for (int j = 0; j < mat.columns(); ++j) {
            double sum = 0.0;
            for (int i = 0; i < mat.rows(); ++i) {
                sum += mat.get(i, j);
            }
            double mean = sum / (double) mat.rows();
            double sumDistSq = 0.0;
            for (int i = 0; i < mat.rows(); ++i) {
                sumDistSq += Math.pow(mat.get(i, j) - mean, 2.0);
            }
            double variance = sumDistSq / (double) mat.rows();
            double sd = Math.pow(variance, 0.5);
            for (int i = 0; i < mat.rows(); ++i) {
                res.set(i, j, mat.get(i, j) / sd);
            }
        }
        return makeDataSet(res, dataSet.getVariables());
    }

    /** Ignores the extra arguments and delegates to {@link #run(DataSet)}. */
    @Override
    public GraphWithParameters run(DataSet dataSet, boolean estimateCoefficients, ExploreFastIca.PwpPlusGeneratingParameters standardPwpPlusParms) {
        return this.run(dataSet);
    }

    /** Wraps {@code inVectors} as a continuous data set with variables named X1..Xn. */
    public static ColtDataSet makeDataSet(DoubleMatrix2D inVectors) {
        int n = inVectors.columns();
        Vector<Node> nodes = new Vector<Node>();
        for (int i = 0; i < n; ++i) {
            nodes.add(new ContinuousVariable("X" + (i + 1)));
        }
        return makeDataSet(inVectors, nodes);
    }

    /**
     * Wraps {@code inVectors} as a continuous data set, creating fresh continuous variables that
     * reuse the names of {@code nodes}. A column/node count mismatch is printed, not thrown.
     *
     * @param nodes variables supplying the names; if {@code null}, names X1..Xn are generated
     */
    public static ColtDataSet makeDataSet(DoubleMatrix2D inVectors, List<Node> nodes) {
        if (nodes == null) {
            // Previously fell through to a NullPointerException on nodes.size().
            return makeDataSet(inVectors);
        }
        if (inVectors.columns() != nodes.size()) {
            System.out.println("inVectors.columns() = " + inVectors.columns());
            System.out.println("nodes.size() = " + nodes.size());
            new Exception("dimensions don't match!").printStackTrace();
        }
        Vector<Node> variables = new Vector<Node>();
        for (Node node : nodes) {
            variables.add(new ContinuousVariable(node.getName()));
        }
        return ColtDataSet.makeContinuousData(variables, inVectors);
    }

    /**
     * One matrix cell: position and value. Ordered by absolute value; this ordering is
     * inconsistent with {@code equals}, which is identity-based.
     */
    public static class Entry
    implements Comparable<Entry> {
        int row;
        int column;
        double value;

        public Entry(int row, int col, double val) {
            this.row = row;
            this.column = col;
            this.value = val;
        }

        @Override
        public int compareTo(Entry entry) {
            return Double.compare(Math.abs(this.value), Math.abs(entry.value));
        }

        @Override
        public String toString() {
            return "[" + this.row + "," + this.column + "]:" + this.value + " ";
        }
    }
}

