/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.cluster.KMeans;
import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.data.SimpleDataLoader;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Demixer;
import edu.cmu.tetrad.search.MixtureModel;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.pitt.dbmi.data.reader.Delimiter;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Demixes a continuous data set into a finite mixture of multivariate Gaussian components.
 *
 * <p>The fit is initialised from {@code k} k-means clusters and then alternates two steps:
 * <ul>
 *   <li>responsibilities {@code gamma[i][j] = w_i N(x_j | i) / sum_w w_w N(x_j | w)}, and</li>
 *   <li>a penalised weight update {@code w_i = (sum_j gamma[i][j] - lambda) / (n - lambda * K)}
 *       together with responsibility-weighted means and covariances.</li>
 * </ul>
 * Components whose updated weight is below {@link #minWeight} (or is NaN) are pruned, so the
 * number of components can only shrink. Iteration stops in any of three cases:
 * <ul>
 *   <li>the relative change of the penalised objective falls below a threshold;</li>
 *   <li>every component has been pruned;</li>
 *   <li>{@link #MAX_ITERATIONS} iterations have run.</li>
 * </ul>
 *
 * <p>Progress is printed to standard output.
 */
public class DemixerMMLKun {
    /**
     * Upper bound on outer EM iterations. Without it, a NaN or oscillating objective would
     * never satisfy the convergence test and the loop would never exit.
     */
    private static final int MAX_ITERATIONS = 10000;

    /** Components whose penalised weight falls below this value (or is NaN) are pruned. */
    private final double minWeight;

    /** Creates a demixer that prunes components whose weight drops below 0.001. */
    public DemixerMMLKun() {
        this.minWeight = 0.001;
    }

    /**
     * Command-line driver. It loads a tab-delimited continuous data set, demixes it and prints
     * the component weights. It then writes {@code model.getDistribution(i)} for every row, and
     * each demixed data set, to text files.
     *
     * @param args optional {@code [inputFile [outputPrefix [numComponents]]]}. Missing entries
     *             fall back to the original hard-coded paths and 25 starting components. The
     *             per-row file is {@code outputPrefix + ".txt"}. The demixed data sets go to
     *             {@code outputPrefix + "_demixed_<i>.txt"}, with {@code i} starting at 1.
     */
    public static void main(String... args) {
        String inputPath = args.length > 0 ? args[0]
                : "/Users/user/Documents/Demix_Testing/NonGaussian/sub_1500_4var_3comp.txt";
        String outputPrefix = args.length > 1 ? args[1]
                : "/Users/user/Documents/Demix_Testing/sub_1500_4var_3comp";
        int numComponents = args.length > 2 ? Integer.parseInt(args[2]) : 25;

        DataSet dataSet;
        try {
            dataSet = SimpleDataLoader.loadContinuousData(new File(inputPath), "//", '\"', "*", true, Delimiter.TAB, false);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        DemixerMMLKun demixer = new DemixerMMLKun();
        long startTime = System.currentTimeMillis();
        MixtureModel model = demixer.demix(dataSet, numComponents);
        long elapsed = System.currentTimeMillis() - startTime;

        for (double weight : model.getWeights()) {
            System.out.print(weight + "\t");
        }

        try {
            // try-with-resources: the writers are closed even if a write throws.
            try (BufferedWriter out = new BufferedWriter(new FileWriter(outputPrefix + ".txt"))) {
                for (int i = 0; i < dataSet.getNumRows(); ++i) {
                    out.write(model.getDistribution(i) + "\n");
                }
            }
            DataSet[] dataSets = model.getDemixedData();
            for (int i = 0; i < dataSets.length; ++i) {
                try (BufferedWriter out = new BufferedWriter(
                        new FileWriter(outputPrefix + "_demixed_" + (i + 1) + ".txt"))) {
                    out.write(dataSets[i].toString());
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        System.out.println("Elapsed: " + elapsed / 1000L);
    }

    /**
     * Fits the mixture, starting from {@code k} k-means clusters.
     *
     * @param data the data set to demix
     * @param k    the initial number of components
     * @return the fitted model. It contains only the components that survived pruning, and its
     *         responsibility array has exactly one row per surviving component.
     */
    private MixtureModel demix(DataSet data, int k) {
        double[][] dataArray = data.getDoubleData().toArray();
        int numVars = data.getNumColumns();
        int numCases = data.getNumRows();

        // Per-component penalty. The factor d^2/2 + 3d/2 + 1 equals d(d+1)/2 + d + 1, i.e. the
        // number of covariance and mean entries of a d-variate Gaussian plus one for its weight.
        double lambda2 = Math.sqrt((Math.log(numCases) - Math.log(Math.log(numCases))) / 2.0);
        double lambda = lambda2 * (Math.pow(numVars, 2.0) * 0.5 + 1.5 * (double) numVars + 1.0);
        double epsilon = 1.0E-6;    // offset inside log(w / epsilon + 1) in the objective
        double threshold = 1.0E-8;  // relative-change convergence tolerance
        System.out.println("Lambda: " + lambda);

        KMeans kMeans = KMeans.randomClusters(k);
        kMeans.cluster(data.getDoubleData());
        List<List<Integer>> clusters = kMeans.getClusters();

        double[][] meansArray = new double[k][numVars];
        double[] weightsArray = new double[k];
        Matrix[] variancesArray = new Matrix[k]; // covariance matrices
        Matrix[] variances = new Matrix[k];      // Cholesky factors of those covariances
        double[][] gammaArray = new double[k][numCases];

        for (int i = 0; i < clusters.size(); ++i) {
            List<Integer> cluster = clusters.get(i);
            int clusterSize = cluster.size();
            double[] means = new double[numVars];
            double[][] clusterMatrixArray = new double[clusterSize][];
            for (int j = 0; j < clusterSize; ++j) {
                double[] row = dataArray[cluster.get(j)];
                // Accumulate explicitly instead of relying on a helper mutating its argument.
                for (int v = 0; v < numVars; ++v) {
                    means[v] += row[v];
                }
                clusterMatrixArray[j] = row;
            }
            double scale = 1.0 / (double) clusterSize;
            for (int v = 0; v < numVars; ++v) {
                means[v] *= scale;
            }
            meansArray[i] = means;
            weightsArray[i] = (double) clusterSize / (double) numCases;

            List<Node> variables = data.getVariables();
            BoxDataSet clusterData = new BoxDataSet(new DoubleDataBox(clusterMatrixArray), variables);
            Matrix clusterCovMatrix = clusterData.getCovarianceMatrix();
            if (MatrixUtils.determinant(clusterCovMatrix.toArray()) == 0.0) {
                // Singular cluster covariance: fall back to the covariance of the whole data set.
                variances[i] = MatrixUtils.cholesky(data.getCovarianceMatrix());
                variancesArray[i] = data.getCovarianceMatrix();
            } else {
                variances[i] = MatrixUtils.cholesky(clusterCovMatrix);
                variancesArray[i] = clusterCovMatrix;
            }
        }

        updateResponsibilities(gammaArray, weightsArray, meansArray, variances, dataArray, numCases, numVars);
        System.out.println("Clusters: " + k);
        System.out.println("Weights: " + Arrays.toString(weightsArray));

        double oldLogL = Double.POSITIVE_INFINITY;
        int iteration = 0;
        while (true) {
            if (iteration++ >= MAX_ITERATIONS) {
                System.out.println("Stopped after " + MAX_ITERATIONS + " iterations without converging.");
                break;
            }
            DeterminingStats stats = this.innerStep(data, dataArray, meansArray, gammaArray,
                    weightsArray.length, numCases, numVars, lambda);
            meansArray = stats.getMeans();
            weightsArray = stats.getWeights();
            variancesArray = stats.getVariances();
            variances = stats.getVarMatrixArray();
            k = weightsArray.length;
            if (k == 0) {
                break;
            }
            System.out.println("Clusters: " + k);
            System.out.println("Weights: " + Arrays.toString(weightsArray));

            updateResponsibilities(gammaArray, weightsArray, meansArray, variances, dataArray, numCases, numVars);
            double newLogL = penalizedObjective(gammaArray, weightsArray, numCases, lambda, epsilon);
            // On the first pass oldLogL is +Inf, so the ratio is infinite and never "converges".
            if (Math.abs(oldLogL / newLogL - 1.0) < threshold) {
                break;
            }
            oldLogL = newLogL;
        }
        // Rows >= k of gammaArray belong to pruned components and are stale; hand over only the
        // rows that correspond to the surviving components.
        return new MixtureModel(data, dataArray, meansArray, weightsArray, variancesArray,
                Arrays.copyOf(gammaArray, k));
    }

    /**
     * Performs the M-step with the penalised weight update, then prunes components whose new
     * weight is below {@link #minWeight} or is NaN.
     *
     * <p>Pruning is decided before the means and covariances are computed. Components that are
     * about to be discarded therefore never reach {@code MatrixUtils.cholesky}.
     *
     * @param data       the full data set; its covariance is the fallback for singular components
     * @param dataArray  the data as a row-major array
     * @param meansArray current means; surviving rows are overwritten in place with new means
     * @param gammaArray responsibilities; row {@code i} belongs to component {@code i}
     * @param k          number of components before pruning
     * @param numCases   number of rows
     * @param numVars    number of columns
     * @param lambda     per-component penalty
     * @return the weights, means, covariances and Cholesky factors of the surviving components
     */
    private DeterminingStats innerStep(DataSet data, double[][] dataArray, double[][] meansArray,
                                       double[][] gammaArray, int k, int numCases, int numVars,
                                       double lambda) {
        List<Double> keptWeights = new ArrayList<>();
        List<double[]> keptMeans = new ArrayList<>();
        List<Matrix> keptCovariances = new ArrayList<>();
        List<Matrix> keptCholesky = new ArrayList<>();

        // The denominator uses the component count before pruning.
        // NOTE(review): it is <= 0 when lambda * k >= numCases, and the surviving weights are not
        // renormalised after pruning. Confirm that both are intended.
        double weightDenominator = (double) numCases - lambda * (double) k;

        for (int i = 0; i < k; ++i) {
            double pSum = 0.0;
            for (int j = 0; j < numCases; ++j) {
                pSum += gammaArray[i][j];
            }
            double weight = (pSum - lambda) / weightDenominator;
            // Written as !(>=) so that NaN weights are pruned as well.
            if (!(weight >= this.minWeight)) {
                continue;
            }

            // Update the whole mean vector first so that every covariance entry uses new means.
            for (int v = 0; v < numVars; ++v) {
                double meanNumerator = 0.0;
                for (int j = 0; j < numCases; ++j) {
                    meanNumerator += gammaArray[i][j] * dataArray[j][v];
                }
                meansArray[i][v] = meanNumerator / pSum;
            }

            Matrix tempVar = new Matrix(numVars, numVars);
            for (int v = 0; v < numVars; ++v) {
                for (int v2 = v; v2 < numVars; ++v2) {
                    double var = Demixer.getVar(i, v, v2, numCases, gammaArray, dataArray, meansArray);
                    tempVar.set(v, v2, var);
                    tempVar.set(v2, v, var);
                }
            }

            Matrix covariance;
            Matrix cholesky;
            Matrix varMatrix = new Matrix(tempVar);
            if (varMatrix.det() != 0.0) {
                // Store the covariance itself, matching the initialisation and fallback paths;
                // the Cholesky factor goes in the separate array.
                covariance = tempVar;
                cholesky = MatrixUtils.cholesky(varMatrix);
            } else {
                // Singular covariance: fall back to the covariance of the whole data set.
                cholesky = MatrixUtils.cholesky(data.getCovarianceMatrix());
                covariance = data.getCovarianceMatrix();
            }

            keptWeights.add(weight);
            keptMeans.add(meansArray[i]);
            keptCovariances.add(covariance);
            keptCholesky.add(cholesky);
        }
        System.out.println();

        int size = keptWeights.size();
        double[] newWeights = new double[size];
        double[][] newMeans = new double[size][];
        Matrix[] newCovariances = new Matrix[size];
        Matrix[] newCholesky = new Matrix[size];
        for (int i = 0; i < size; ++i) {
            newWeights[i] = keptWeights.get(i);
            newMeans[i] = keptMeans.get(i);
            newCovariances[i] = keptCovariances.get(i);
            newCholesky[i] = keptCholesky.get(i);
        }
        return new DeterminingStats(newMeans, newWeights, newCovariances, newCholesky);
    }

    /**
     * E-step. Overwrites rows {@code 0 .. weightsArray.length - 1} of {@code gammaArray} with
     * {@code w_i N(x_j | i) / sum_w w_w N(x_j | w)}. Any further rows are left untouched.
     *
     * <p>Each component's densities are evaluated once and reused for every divisor, rather than
     * being recomputed for every (component, case) pair.
     */
    private void updateResponsibilities(double[][] gammaArray, double[] weightsArray, double[][] meansArray,
                                        Matrix[] variances, double[][] dataArray, int numCases, int numVars) {
        int k = weightsArray.length;
        double[][] weighted = new double[k][];
        for (int i = 0; i < k; ++i) {
            double[] densities = normalDensities(variances[i], meansArray[i], dataArray, numCases, numVars);
            for (int j = 0; j < numCases; ++j) {
                densities[j] *= weightsArray[i];
            }
            weighted[i] = densities;
        }
        for (int j = 0; j < numCases; ++j) {
            double divisor = 0.0;
            for (int i = 0; i < k; ++i) {
                divisor += weighted[i][j];
            }
            for (int i = 0; i < k; ++i) {
                gammaArray[i][j] = weighted[i][j] / divisor;
            }
        }
    }

    /**
     * Evaluates one component's Gaussian density at every case.
     *
     * <p>Let {@code L} be the factor stored in {@code variances}. The code takes {@code cov = L^T}
     * and squares the entries of {@code diff * cov^{-1}}. It then divides
     * {@code (2*pi)^(-d/2)} by {@code det(cov)}. This equals the usual normal density only if
     * {@code MatrixUtils.cholesky} returns a lower-triangular {@code L} with
     * {@code Sigma = L L^T}. NOTE(review): confirm that convention.
     *
     * @return an array of length {@code numCases} holding the density at each row
     */
    private double[] normalDensities(Matrix choleskyFactor, double[] mu, double[][] dataArray,
                                     int numCases, int numVars) {
        // The inverse and the normaliser depend only on the component, so compute them once.
        Matrix cov = choleskyFactor.transpose();
        Matrix covIn = cov.inverse();
        double norm = Math.pow(Math.PI * 2, (double) (-numVars) / 2.0) / cov.det();

        double[] densities = new double[numCases];
        for (int j = 0; j < numCases; ++j) {
            double[] thisCase = dataArray[j];
            double[][] diffs = new double[1][numVars];
            for (int v = 0; v < numVars; ++v) {
                diffs[0][v] = thisCase[v] - mu[v];
            }
            Matrix mah = new Matrix(diffs).times(covIn);
            double mahScal = 0.0;
            for (int r = 0; r < mah.getNumRows(); ++r) {
                for (int c = 0; c < mah.getNumColumns(); ++c) {
                    double val = mah.get(r, c);
                    mahScal += val * val;
                }
            }
            densities[j] = norm * Math.exp(-0.5 * mahScal);
        }
        return densities;
    }

    /**
     * Computes the convergence objective. It is the mean over components of
     * {@code log(mean_j gamma[i][j])}, plus {@code (lambda / n) * sum_i log(w_i / epsilon + 1)}.
     */
    private static double penalizedObjective(double[][] gammaArray, double[] weightsArray, int numCases,
                                             double lambda, double epsilon) {
        int k = weightsArray.length;
        double mml = 0.0;
        for (int i = 0; i < k; ++i) {
            double gammaSum = 0.0;
            for (int j = 0; j < numCases; ++j) {
                gammaSum += gammaArray[i][j];
            }
            mml += Math.log(gammaSum / (double) numCases);
        }
        mml /= (double) k;

        double weightSum = 0.0;
        for (double w : weightsArray) {
            weightSum += Math.log(w / epsilon + 1.0);
        }
        return mml + weightSum * (lambda / (double) numCases);
    }

    /** Immutable holder for the parameters of the components that survived an M-step. */
    private static final class DeterminingStats {
        private final double[][] meansArray;
        private final double[] weightsArray;
        private final Matrix[] variancesArray;  // covariance matrices
        private final Matrix[] varMatrixArray;  // their Cholesky factors

        public DeterminingStats(double[][] meansArray, double[] weightsArray, Matrix[] variancesArray, Matrix[] varMatrixArray) {
            this.meansArray = meansArray;
            this.weightsArray = weightsArray;
            this.variancesArray = variancesArray;
            this.varMatrixArray = varMatrixArray;
        }

        public double[] getWeights() {
            return this.weightsArray;
        }

        public double[][] getMeans() {
            return this.meansArray;
        }

        public Matrix[] getVariances() {
            return this.variancesArray;
        }

        public Matrix[] getVarMatrixArray() {
            return this.varMatrixArray;
        }
    }
}

