/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.search.MixtureModel;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.Vector;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Random;

/**
 * Fits a mixture of multivariate Gaussians to the columns of a {@link DataSet} using the
 * expectation-maximization (EM) algorithm.
 *
 * <p>Initialization: every cluster mean is the column mean plus standard-normal noise, every
 * cluster weight is {@code 1/k}, and every cluster covariance is the data set's covariance
 * matrix. {@link #demix()} then alternates E and M steps until no cluster weight changes by
 * more than {@code 1e-4} between iterations, or until the iteration cap is reached.
 *
 * <p>Instances hold mutable EM state without synchronization. The arrays passed to the returned
 * {@link MixtureModel} are this object's internal arrays, not copies.
 */
public class Demixer {
    /** EM stops once no cluster weight changes by more than this between iterations. */
    private static final double WEIGHT_TOLERANCE = 1.0E-4;
    /** EM also stops once the zero-based iteration index exceeds this value. */
    private static final int MAX_ITERATION_INDEX = 100;

    private final int numVars;
    private final int numCases;
    private final int numClusters;
    private final DataSet data;
    /** Data values indexed as {@code dataArray[case][variable]}. */
    private final double[][] dataArray;
    /** Covariance matrix of each cluster as used by the E step (a copy of variancesArray). */
    private final Matrix[] variances;
    /** Mean vector of each cluster, indexed as {@code meansArray[cluster][variable]}. */
    private final double[][] meansArray;
    /** Covariance matrix of each cluster, updated in place by the M step. */
    private final Matrix[] variancesArray;
    /** Mixing weight of each cluster. */
    private final double[] weightsArray;
    /** Responsibilities: {@code gammaArray[cluster][case]} is P(cluster | case). */
    private final double[][] gammaArray;
    private boolean demixed = false;

    /**
     * Creates a demixer with randomly perturbed initial cluster means.
     *
     * @param data the data set to cluster; read via {@code getDoubleData()}, so presumably it
     *             should contain only continuous columns (NOTE(review): confirm with callers)
     * @param k    the number of mixture components
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public Demixer(DataSet data, int k) {
        if (k < 1) {
            throw new IllegalArgumentException("Number of clusters must be at least 1: " + k);
        }
        this.numClusters = k;
        this.data = data;
        this.dataArray = data.getDoubleData().toArray();
        this.numVars = data.getNumColumns();
        this.numCases = data.getNumRows();
        this.meansArray = new double[k][this.numVars];
        this.weightsArray = new double[k];
        this.variancesArray = new Matrix[k];
        this.variances = new Matrix[k];
        this.gammaArray = new double[k][this.numCases];

        Random rand = new Random();
        for (int v = 0; v < this.numVars; ++v) {
            // The column mean does not depend on the cluster, so compute it once per column.
            double columnMean = this.calcMean(data.getDoubleData().getColumn(v));
            for (int c = 0; c < k; ++c) {
                this.meansArray[c][v] = columnMean + rand.nextGaussian();
            }
        }
        Arrays.fill(this.weightsArray, 1.0 / k);
        for (int c = 0; c < k; ++c) {
            this.variances[c] = data.getCovarianceMatrix();
            // Must be non-null: the M step writes into these matrices in place.
            this.variancesArray[c] = new Matrix(this.variances[c]);
        }
    }

    /**
     * Runs EM until the cluster weights converge and returns the fitted model.
     *
     * <p>The weights are printed to standard output before the first iteration and after
     * every iteration. The loop stops when the largest absolute change in any weight is below
     * {@code 1e-4}, or after the iteration with zero-based index 101.
     *
     * @return a mixture model built from this demixer's (shared, not copied) parameter arrays
     */
    public MixtureModel demix() {
        double[] previousWeights = Arrays.copyOf(this.weightsArray, this.numClusters);
        System.out.println("Weights: " + Arrays.toString(this.weightsArray));
        for (int iteration = 0; ; ++iteration) {
            this.expectation();
            this.maximization();
            System.out.println("Weights: " + Arrays.toString(this.weightsArray));
            // Math.max propagates NaN, so a NaN change never counts as converged.
            double maxChange = 0.0;
            for (int c = 0; c < this.numClusters; ++c) {
                maxChange = Math.max(maxChange, Math.abs(this.weightsArray[c] - previousWeights[c]));
            }
            if (maxChange < WEIGHT_TOLERANCE || iteration > MAX_ITERATION_INDEX) {
                break;
            }
            System.arraycopy(this.weightsArray, 0, previousWeights, 0, this.numClusters);
        }
        MixtureModel model = new MixtureModel(this.data, this.dataArray, this.meansArray,
                this.weightsArray, this.variancesArray, this.gammaArray);
        this.demixed = true;
        return model;
    }

    /**
     * Reports whether {@link #demix()} has completed at least once.
     *
     * @return {@code true} after a successful call to {@link #demix()}
     */
    public boolean isDemixed() {
        return this.demixed;
    }

    /**
     * E step: recomputes every responsibility {@code gammaArray[c][j]}.
     *
     * <p>Works in log space. For each case it takes
     * {@code log w_c + log N(x_j | mu_c, Sigma_c)} and normalizes with log-sum-exp, so cases far
     * from every mean do not underflow to {@code 0/0 = NaN}. Each covariance is inverted once
     * per call rather than once per (case, cluster) pair.
     */
    private void expectation() {
        double logTwoPi = Math.log(2.0 * Math.PI);
        double[][][] covInverses = new double[this.numClusters][][];
        double[] logConstants = new double[this.numClusters];
        for (int c = 0; c < this.numClusters; ++c) {
            Matrix cov = this.variances[c];
            covInverses[c] = cov.inverse().toArray();
            // log w_c - 0.5 * log((2*pi)^d * det(Sigma_c))
            logConstants[c] = Math.log(this.weightsArray[c])
                    - 0.5 * (this.numVars * logTwoPi + Math.log(cov.det()));
        }

        double[] logJoint = new double[this.numClusters];
        for (int j = 0; j < this.numCases; ++j) {
            double max = Double.NEGATIVE_INFINITY;
            for (int c = 0; c < this.numClusters; ++c) {
                logJoint[c] = logConstants[c]
                        - 0.5 * mahalanobisSquared(this.dataArray[j], this.meansArray[c], covInverses[c]);
                max = Math.max(max, logJoint[c]);
            }
            double scaledSum = 0.0;
            for (int c = 0; c < this.numClusters; ++c) {
                scaledSum += Math.exp(logJoint[c] - max);
            }
            double logNormalizer = max + Math.log(scaledSum);
            for (int c = 0; c < this.numClusters; ++c) {
                this.gammaArray[c][j] = Math.exp(logJoint[c] - logNormalizer);
            }
        }
    }

    /**
     * M step: re-estimates each cluster's weight, mean and covariance from the current
     * responsibilities.
     *
     * <p>A cluster whose responsibilities sum to zero gets NaN means and covariances (0/0); this
     * is unchanged from the original behavior.
     */
    private void maximization() {
        for (int c = 0; c < this.numClusters; ++c) {
            double[] responsibilities = this.gammaArray[c];
            double total = 0.0;
            for (int j = 0; j < this.numCases; ++j) {
                total += responsibilities[j];
            }
            this.weightsArray[c] = total / (double) this.numCases;

            for (int v = 0; v < this.numVars; ++v) {
                double weightedSum = 0.0;
                for (int j = 0; j < this.numCases; ++j) {
                    weightedSum += responsibilities[j] * this.dataArray[j][v];
                }
                this.meansArray[c][v] = weightedSum / total;
            }

            // The covariance uses this cluster's freshly updated mean. Only the upper triangle
            // is computed; the matrix is filled symmetrically.
            Matrix covariance = this.variancesArray[c];
            for (int v = 0; v < this.numVars; ++v) {
                for (int v2 = v; v2 < this.numVars; ++v2) {
                    double var = Demixer.getVar(c, v, v2, this.numCases, this.gammaArray,
                            this.dataArray, this.meansArray);
                    covariance.set(v, v2, var);
                    covariance.set(v2, v, var);
                }
            }
            this.variances[c] = new Matrix(covariance);
        }
    }

    /**
     * Computes the responsibility-weighted covariance between variables {@code v} and
     * {@code v2} for cluster {@code i}.
     *
     * @param i          cluster index
     * @param v          first variable index
     * @param v2         second variable index
     * @param numCases   number of cases to sum over
     * @param gammaArray responsibilities, indexed {@code [cluster][case]}
     * @param dataArray  data, indexed {@code [case][variable]}
     * @param meansArray cluster means, indexed {@code [cluster][variable]}
     * @return {@code sum_j g_ij (x_jv - mu_iv)(x_jv2 - mu_iv2) / sum_j g_ij}; NaN if the
     *         responsibilities of cluster {@code i} sum to zero
     */
    static double getVar(int i, int v, int v2, int numCases, double[][] gammaArray,
                         double[][] dataArray, double[][] meansArray) {
        double meanV = meansArray[i][v];
        double meanV2 = meansArray[i][v2];
        double varNumerator = 0.0;
        double varDivisor = 0.0;
        for (int j = 0; j < numCases; ++j) {
            varNumerator += gammaArray[i][j] * (dataArray[j][v] - meanV) * (dataArray[j][v2] - meanV2);
            varDivisor += gammaArray[i][j];
        }
        return varNumerator / varDivisor;
    }

    /**
     * Computes the squared Mahalanobis distance {@code (x - mu)^T S (x - mu)}.
     *
     * @param x          the point
     * @param mu         the mean; its length sets the dimension
     * @param covInverse the inverse covariance matrix {@code S}
     * @return the quadratic form value
     */
    private static double mahalanobisSquared(double[] x, double[] mu, double[][] covInverse) {
        int d = mu.length;
        double[] diff = new double[d];
        for (int a = 0; a < d; ++a) {
            diff[a] = x[a] - mu[a];
        }
        double total = 0.0;
        for (int a = 0; a < d; ++a) {
            double row = 0.0;
            for (int b = 0; b < d; ++b) {
                row += covInverse[a][b] * diff[b];
            }
            total += diff[a] * row;
        }
        return total;
    }

    /**
     * Returns the arithmetic mean of a vector's entries.
     *
     * @param dataPoints the values to average
     * @return the mean, or NaN when the vector is empty (0/0)
     */
    private double calcMean(Vector dataPoints) {
        double sum = 0.0;
        for (int i = 0; i < dataPoints.size(); ++i) {
            sum += dataPoints.get(i);
        }
        return sum / (double) dataPoints.size();
    }
}

