/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.search.LocalDiscreteScore;
import edu.cmu.tetrad.search.LocalScoreCache;
import edu.cmu.tetrad.util.ProbUtils;

/**
 * Scores a discrete child variable against a set of discrete parent variables using the
 * BDeu (Bayesian Dirichlet equivalent uniform) marginal likelihood plus a structure-prior
 * term, with counts taken from a {@link DataSet}. Computed local scores are memoized in a
 * {@link LocalScoreCache}.
 */
public class BDeuScore
implements LocalDiscreteScore {
    /**
     * Memoized local scores. Not final: it is replaced whenever a prior changes, because
     * every cached value was computed with the previous priors.
     */
    private LocalScoreCache localScoreCache = new LocalScoreCache();
    private final DataSet dataSet;
    private double samplePrior;
    private double structurePrior;

    /**
     * Creates a BDeu score over the given data.
     *
     * @param dataSet        the data to count from; every scored column is cast to
     *                       {@link DiscreteVariable}
     * @param samplePrior    total Dirichlet pseudo-count; it is spread uniformly over the
     *                       r*q cells of each child/parent count table
     * @param structurePrior value whose log is multiplied by (r - 1) * q and added to each
     *                       local score
     * @throws NullPointerException if {@code dataSet} is null
     */
    public BDeuScore(DataSet dataSet, double samplePrior, double structurePrior) {
        if (dataSet == null) {
            throw new NullPointerException();
        }
        this.dataSet = dataSet;
        this.samplePrior = samplePrior;
        this.structurePrior = structurePrior;
    }

    /**
     * Returns the local score of column {@code i} given the parent columns {@code parents}.
     * Results are cached per (i, parents); a cached value is any non-NaN value returned by
     * the cache.
     *
     * @param i       index of the child column
     * @param parents indices of the parent columns (may be empty)
     * @return structure-prior term plus BDeu log marginal likelihood
     * @throws IllegalStateException if a used cell holds the missing-value marker -99
     * @throws ArithmeticException   if the number of parent configurations overflows an int
     */
    @Override
    public double localScore(int i, int[] parents) {
        double cached = this.localScoreCache.get(i, parents);
        if (!Double.isNaN(cached)) {
            return cached;
        }
        int r = this.numCategories(i);
        int[] dims = new int[parents.length];
        // q = number of joint parent configurations; fail loudly instead of wrapping.
        int q = 1;
        for (int p = 0; p < parents.length; ++p) {
            dims[p] = this.numCategories(parents[p]);
            q = Math.multiplyExact(q, dims[p]);
        }
        int[][] counts = this.countChildGivenParents(i, parents, dims, q, r);
        double score = this.scoreCounts(counts, r, q);
        this.localScoreCache.add(i, parents, score);
        return score;
    }

    /**
     * Builds the q x r contingency table: {@code counts[j][k]} is the number of rows whose
     * parent configuration has mixed-radix index j and whose child value is k.
     */
    private int[][] countChildGivenParents(int i, int[] parents, int[] dims, int q, int r) {
        int[][] counts = new int[q][r];
        int[] values = new int[parents.length];
        int rows = this.sampleSize();
        for (int n = 0; n < rows; ++n) {
            for (int p = 0; p < parents.length; ++p) {
                int parentValue = this.dataSet().getInt(n, parents[p]);
                // -99 is treated as the missing-value marker.
                if (parentValue == -99) {
                    throw new IllegalStateException("Please remove or impute missing values (record " + n + " column " + parents[p] + ")");
                }
                values[p] = parentValue;
            }
            int childValue = this.dataSet().getInt(n, i);
            if (childValue == -99) {
                throw new IllegalStateException("Please remove or impute missing values (record " + n + " column " + i + ")");
            }
            counts[this.getRowIndex(dims, values)][childValue]++;
        }
        return counts;
    }

    /**
     * Turns a q x r count table into the local score:
     * (r-1)q*log(structurePrior)
     *   + sum_j [ lnG(a/q) - lnG(a/q + N_j) + sum_k ( lnG(a/(rq) + N_jk) - lnG(a/(rq)) ) ]
     * where a is the sample prior.
     */
    private double scoreCounts(int[][] counts, int r, int q) {
        double alpha = this.getSamplePrior();
        double rowPrior = alpha / q;
        // Multiply in double so r * q cannot overflow int.
        double cellPrior = alpha / ((double) r * q);
        double score = (double) (r - 1) * q * Math.log(this.getStructurePrior());
        for (int[] row : counts) {
            int rowTotal = 0;
            for (int count : row) {
                score += ProbUtils.lngamma(cellPrior + count);
                rowTotal += count;
            }
            score -= ProbUtils.lngamma(rowPrior + rowTotal);
        }
        score += q * ProbUtils.lngamma(rowPrior);
        // One lnG(a/(rq)) per cell of the q x r table (the original subtracted only r of them).
        score -= (double) r * q * ProbUtils.lngamma(cellPrior);
        return score;
    }

    /** Returns the data set this score counts from. */
    @Override
    public DataSet getDataSet() {
        return this.dataSet;
    }

    /**
     * Maps a parent configuration to a row of the count table by reading {@code values} as a
     * mixed-radix number whose digit radices are {@code dim} (first entry most significant).
     */
    private int getRowIndex(int[] dim, int[] values) {
        int rowIndex = 0;
        for (int i = 0; i < dim.length; ++i) {
            rowIndex *= dim[i];
            rowIndex += values[i];
        }
        return rowIndex;
    }

    /** Number of rows in the data set. */
    private int sampleSize() {
        return this.dataSet().getNumRows();
    }

    /** Number of categories of column {@code i}; the column must be a DiscreteVariable. */
    private int numCategories(int i) {
        return ((DiscreteVariable)this.dataSet().getVariable(i)).getNumCategories();
    }

    private DataSet dataSet() {
        return this.dataSet;
    }

    /** Returns the structure prior used in the (r-1)*q*log(structurePrior) term. */
    public double getStructurePrior() {
        return this.structurePrior;
    }

    /** Returns the total Dirichlet pseudo-count (equivalent sample size). */
    public double getSamplePrior() {
        return this.samplePrior;
    }

    /**
     * Sets the structure prior and discards all cached local scores, since they were
     * computed with the previous value.
     */
    @Override
    public void setStructurePrior(double structurePrior) {
        this.structurePrior = structurePrior;
        this.localScoreCache = new LocalScoreCache();
    }

    /**
     * Sets the sample prior and discards all cached local scores, since they were computed
     * with the previous value.
     */
    @Override
    public void setSamplePrior(double samplePrior) {
        this.samplePrior = samplePrior;
        this.localScoreCache = new LocalScoreCache();
    }
}

