/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.CovarianceMatrixOnTheFly;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.search.Fges;
import edu.cmu.tetrad.search.score.SemBicScore;
import edu.cmu.tetrad.util.Matrix;

/**
 * The result of fitting a mixture model to a data set: the data, per-component means, mixing
 * weights and covariance matrices, and the per-component, per-case weights
 * {@code gammaArray[component][case]} (presumably EM responsibilities).
 *
 * <p>On construction every case (row of the data set) is assigned to the component with the
 * largest gamma, and the number of cases assigned to each component is tallied. Those assignments
 * drive {@link #getDemixedData()}, which splits the data set by component.
 *
 * <p>The arrays passed to the constructor are stored by reference, not copied, and the getters
 * return those same references.
 */
public class MixtureModel {
    /** Penalty discount used by {@link #searchDemixedData()}. */
    private static final double DEFAULT_PENALTY_DISCOUNT = 2.0;

    /** The mixed data set; one case per row. */
    private final DataSet data;
    /** {@code cases[r]} is the component index assigned to row {@code r} of {@link #data}. */
    private final int[] cases;
    /** {@code caseCounts[k]} is the number of rows assigned to component {@code k}. */
    private final int[] caseCounts;
    private final double[][] dataArray;
    private final double[][] meansArray;
    private final double[] weightsArray;
    /** {@code gammaArray[k][r]} is the weight of row {@code r} for component {@code k}. */
    private final double[][] gammaArray;
    private final Matrix[] variancesArray;
    /** Number of mixture components, taken from the length of the weights array. */
    private final int numModels;

    /**
     * Creates a mixture model and assigns each case to its highest-gamma component.
     *
     * @param data           the mixed data set; one case per row
     * @param dataArray      the data as a raw array, stored as given (presumably the same values
     *                       as {@code data} — confirm with callers)
     * @param meansArray     per-component means
     * @param weightsArray   per-component mixing weights; its length defines the number of
     *                       components
     * @param variancesArray per-component covariance matrices
     * @param gammaArray     {@code gammaArray[k][r]} is the weight of row {@code r} for component
     *                       {@code k}; must have at least {@code weightsArray.length} rows, each
     *                       with at least {@code data.getNumRows()} entries
     */
    public MixtureModel(DataSet data, double[][] dataArray, double[][] meansArray,
                        double[] weightsArray, Matrix[] variancesArray, double[][] gammaArray) {
        this.data = data;
        this.dataArray = dataArray;
        this.meansArray = meansArray;
        this.weightsArray = weightsArray;
        this.variancesArray = variancesArray;
        this.numModels = weightsArray.length;
        this.gammaArray = gammaArray;

        // Use the private helper rather than the public (overridable) getDistribution so that a
        // subclass override cannot run against a partially constructed object.
        this.cases = new int[data.getNumRows()];
        for (int row = 0; row < this.cases.length; row++) {
            this.cases[row] = mostLikelyComponent(row);
        }
        this.caseCounts = countCases(this.cases, this.numModels);
    }

    /** Returns the raw data array passed to the constructor (not a copy). */
    public double[][] getData() {
        return this.dataArray;
    }

    /** Returns the per-component means passed to the constructor (not a copy). */
    public double[][] getMeans() {
        return this.meansArray;
    }

    /** Returns the per-component mixing weights passed to the constructor (not a copy). */
    public double[] getWeights() {
        return this.weightsArray;
    }

    /** Returns the per-component covariance matrices passed to the constructor (not a copy). */
    public Matrix[] getVariances() {
        return this.variancesArray;
    }

    /**
     * Returns the component assignment of every case, indexed by row (the internal array, not a
     * copy).
     */
    public int[] getCases() {
        return this.cases;
    }

    /**
     * Returns the index of the component with the largest gamma for the given case.
     *
     * <p>Ties are broken in favour of the lowest index, and NaN gammas are ignored. Returns 0 if
     * there are no components or every gamma for the case is NaN.
     *
     * @param caseNum row index into the data set
     * @return a component index
     */
    public int getDistribution(int caseNum) {
        return mostLikelyComponent(caseNum);
    }

    /**
     * Splits the data set by component assignment.
     *
     * @return one data set per component. Each holds, in original row order, the rows assigned to
     *         that component and shares the original data set's variables. A component with no
     *         assigned rows is backed by a zero-row data box.
     */
    public DataSet[] getDemixedData() {
        int numColumns = this.data.getNumColumns();
        DoubleDataBox[] boxes = new DoubleDataBox[this.numModels];
        for (int k = 0; k < this.numModels; k++) {
            boxes[k] = new DoubleDataBox(this.caseCounts[k], numColumns);
        }

        // nextRow[k] is the next free row in boxes[k]; Java zero-initialises the array.
        int[] nextRow = new int[this.numModels];
        for (int row = 0; row < this.cases.length; row++) {
            int k = this.cases[row];
            DoubleDataBox box = boxes[k];
            int target = nextRow[k]++;
            for (int col = 0; col < numColumns; col++) {
                box.set(target, col, this.data.getDouble(row, col));
            }
        }

        DataSet[] dataSets = new DataSet[this.numModels];
        for (int k = 0; k < this.numModels; k++) {
            dataSets[k] = new BoxDataSet(boxes[k], this.data.getVariables());
        }
        return dataSets;
    }

    /**
     * Runs FGES with a {@link SemBicScore} (penalty discount 2.0) on each demixed data set.
     *
     * @return the model score reported by FGES for each component, indexed by component
     */
    public double[] searchDemixedData() {
        return searchDemixedData(DEFAULT_PENALTY_DISCOUNT);
    }

    /**
     * Runs FGES with a {@link SemBicScore} using the given penalty discount on each demixed data
     * set.
     *
     * @param penaltyDiscount penalty discount passed to {@link SemBicScore#setPenaltyDiscount}
     * @return the model score reported by FGES for each component, indexed by component
     */
    public double[] searchDemixedData(double penaltyDiscount) {
        DataSet[] dataSets = getDemixedData();
        double[] scores = new double[this.numModels];
        for (int k = 0; k < this.numModels; k++) {
            SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSets[k]));
            score.setPenaltyDiscount(penaltyDiscount);
            Fges fges = new Fges(score);
            fges.search();
            scores[k] = fges.getModelScore();
        }
        return scores;
    }

    /**
     * Returns the component with the largest gamma for {@code caseNum}. Ties go to the lowest
     * index; NaN gammas are skipped.
     */
    private int mostLikelyComponent(int caseNum) {
        int best = 0;
        // Seed with -inf rather than 0.0 so non-positive gammas (e.g. log-space values) still
        // yield the true argmax instead of silently defaulting to component 0.
        double bestGamma = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < this.numModels; k++) {
            double gamma = this.gammaArray[k][caseNum];
            // Strict '>' keeps the first maximum on ties and is false for NaN.
            if (gamma > bestGamma) {
                bestGamma = gamma;
                best = k;
            }
        }
        return best;
    }

    /** Tallies how many cases were assigned to each of {@code numComponents} components. */
    private static int[] countCases(int[] assignments, int numComponents) {
        int[] counts = new int[numComponents];
        for (int component : assignments) {
            // With zero components every assignment is 0, which has no slot; skip out-of-range
            // indices rather than throwing, matching the original tally loop.
            if (component >= 0 && component < numComponents) {
                counts[component]++;
            }
        }
        return counts;
    }
}

