/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.score;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.DataBox;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.VerticalIntDataBox;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.score.DiscreteScore;
import java.text.DecimalFormat;
import java.util.List;
import org.apache.commons.math3.special.Gamma;
import org.apache.commons.math3.util.FastMath;

/**
 * Scores discrete data with the BDeu (Bayesian Dirichlet equivalent uniform) local score.
 *
 * <p>The local score of a node given a set of parents is the log BDeu marginal likelihood of the
 * child-by-parent-configuration count table, using {@link #getSamplePrior()} as the equivalent
 * sample size, plus a log structure prior controlled by {@link #getStructurePrior()}. Rows in
 * which the child or any parent holds the value {@code -99} are skipped when counting.
 *
 * <p>At construction time the data are stored column by column as {@code int} arrays. For a
 * {@link BoxDataSet} backed by a {@link VerticalIntDataBox}, the columns come from
 * {@code getVariableVectors()}. Otherwise they are copied cell by cell with
 * {@link DataSet#getInt(int, int)}.
 */
public class BdeuScore implements DiscreteScore {

    /** Cell value that causes a row to be skipped when building the count table. */
    private static final int MISSING_VALUE = -99;

    /** Column-major data: {@code data[column][row]}. */
    private final int[][] data;
    private final int sampleSize;
    /** Number of categories of each variable, indexed like {@link #variables}. */
    private final int[] numCategories;
    private final DataSet dataSet;
    private List<Node> variables;
    private double samplePrior = 1.0;
    private double structurePrior = 1.0;

    /**
     * Creates a BDeu score over the given data set.
     *
     * @param dataSet the data to score; every variable is cast to {@link DiscreteVariable}.
     * @throws NullPointerException if {@code dataSet} is null.
     * @throws ClassCastException   if some variable of {@code dataSet} is not a
     *                              {@link DiscreteVariable}.
     */
    public BdeuScore(DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("Data was not provided.");
        }
        this.dataSet = dataSet;
        this.variables = dataSet.getVariables();

        DataBox dataBox = dataSet instanceof BoxDataSet ? ((BoxDataSet) dataSet).getDataBox() : null;
        if (dataBox instanceof VerticalIntDataBox) {
            // Fast path: the box already exposes its columns as int arrays.
            VerticalIntDataBox box = (VerticalIntDataBox) dataBox;
            this.data = box.getVariableVectors();
            this.sampleSize = box.numRows();
        } else {
            int numColumns = dataSet.getNumColumns();
            int numRows = dataSet.getNumRows();
            this.data = new int[numColumns][numRows];
            for (int j = 0; j < numColumns; ++j) {
                for (int i = 0; i < numRows; ++i) {
                    this.data[j][i] = dataSet.getInt(i, j);
                }
            }
            this.sampleSize = numRows;
        }

        this.numCategories = new int[this.variables.size()];
        for (int i = 0; i < this.numCategories.length; ++i) {
            this.numCategories[i] = getVariable(i).getNumCategories();
        }
    }

    /**
     * Maps a combination of parent values to a row of the count table (mixed-radix encoding,
     * first parent most significant).
     *
     * <p>NOTE(review): assumes {@code 0 <= values[i] < dim[i]} for every i. Out-of-range values are
     * not checked here and would address the wrong row.
     *
     * @param dim    number of categories of each parent.
     * @param values the value of each parent in the current row.
     * @return the row index into the count table.
     */
    private static int getRowIndex(int[] dim, int[] values) {
        int rowIndex = 0;
        for (int i = 0; i < dim.length; ++i) {
            rowIndex *= dim[i];
            rowIndex += values[i];
        }
        return rowIndex;
    }

    /**
     * Computes the BDeu local score of {@code node} given {@code parents}.
     *
     * @param node    column index of the child variable.
     * @param parents column indices of the parent variables.
     * @return the log score, or {@link Double#NaN} if the result is not finite.
     * @throws IllegalArgumentException if the number of parent value combinations does not fit in
     *                                  an {@code int}.
     */
    @Override
    public double localScore(int node, int[] parents) {
        final int c = this.numCategories[node];

        // Categories per parent and the number of parent configurations (rows of the table).
        final int[] dims = new int[parents.length];
        int r = 1;
        for (int p = 0; p < parents.length; ++p) {
            dims[p] = this.numCategories[parents[p]];
            try {
                // A wrapped product would silently corrupt the table size and the row indices.
                r = Math.multiplyExact(r, dims[p]);
            } catch (ArithmeticException e) {
                throw new IllegalArgumentException("Too many parent value combinations for node "
                        + node + " with " + parents.length + " parents.", e);
            }
        }

        final int[][] nJk = new int[r][c]; // counts per (parent configuration, child value)
        final int[] nJ = new int[r];       // counts per parent configuration
        final int[] parentValues = new int[parents.length];
        final int[][] parentColumns = new int[parents.length][];
        for (int p = 0; p < parents.length; ++p) {
            parentColumns[p] = this.data[parents[p]];
        }
        final int[] childColumn = this.data[node];

        int n = 0; // number of rows actually counted
        rows:
        for (int i = 0; i < this.sampleSize; ++i) {
            for (int p = 0; p < parents.length; ++p) {
                if (parentColumns[p][i] == MISSING_VALUE) {
                    continue rows;
                }
                parentValues[p] = parentColumns[p][i];
            }
            int childValue = childColumn[i];
            if (childValue == MISSING_VALUE) {
                continue;
            }
            int rowIndex = getRowIndex(dims, parentValues);
            nJk[rowIndex][childValue]++;
            nJ[rowIndex]++;
            n++;
        }

        double score = getPriorForStructure(parents.length, n);

        // Dirichlet hyperparameters; c * r is formed in double to avoid int overflow.
        final double cellPrior = getSamplePrior() / ((double) c * r);
        final double rowPrior = getSamplePrior() / r;

        for (int j = 0; j < r; ++j) {
            score -= Gamma.logGamma(rowPrior + nJ[j]);
            for (int k = 0; k < c; ++k) {
                score += Gamma.logGamma(cellPrior + nJk[j][k]);
            }
        }
        score += r * Gamma.logGamma(rowPrior);
        score -= (double) c * r * Gamma.logGamma(cellPrior);

        return Double.isFinite(score) ? score : Double.NaN;
    }

    /**
     * Returns the change in the local score of {@code y} when {@code x} is added to parents
     * {@code z}.
     */
    @Override
    public double localScoreDiff(int x, int y, int[] z) {
        return this.localScore(y, this.append(z, x)) - this.localScore(y, z);
    }

    /** Returns the variables this score is defined over (the internal list, not a copy). */
    @Override
    public List<Node> getVariables() {
        return this.variables;
    }

    /**
     * Replaces the variable list with an equally named, equally sized list.
     *
     * @param variables the replacement variables, in the same order as the current ones.
     * @throws IllegalArgumentException if the sizes differ or any name does not match.
     */
    void setVariables(List<Node> variables) {
        if (variables.size() != this.variables.size()) {
            throw new IllegalArgumentException("Expected " + this.variables.size()
                    + " variables but got " + variables.size() + ".");
        }
        for (int i = 0; i < variables.size(); ++i) {
            if (variables.get(i).getName().equals(this.variables.get(i).getName())) continue;
            throw new IllegalArgumentException("Variable in index " + (i + 1) + " does not have the same name as the variable being substituted for it.");
        }
        this.variables = variables;
    }

    /** Returns the number of rows in the data, including rows that contain {@code -99}. */
    @Override
    public int getSampleSize() {
        return this.sampleSize;
    }

    /** Returns true when the score bump is strictly positive. */
    @Override
    public boolean isEffectEdge(double bump) {
        return bump > 0.0;
    }

    /** Returns the data set passed to the constructor. */
    @Override
    public DataSet getDataSet() {
        return this.dataSet;
    }

    /** Returns the structure prior parameter; 0 disables the structure prior. */
    public double getStructurePrior() {
        return this.structurePrior;
    }

    /** Sets the structure prior parameter; 0 disables the structure prior. */
    @Override
    public void setStructurePrior(double structurePrior) {
        this.structurePrior = structurePrior;
    }

    /** Returns the equivalent sample size used for the Dirichlet hyperparameters. */
    public double getSamplePrior() {
        return this.samplePrior;
    }

    /** Sets the equivalent sample size used for the Dirichlet hyperparameters. */
    @Override
    public void setSamplePrior(double samplePrior) {
        this.samplePrior = samplePrior;
    }

    @Override
    public String toString() {
        DecimalFormat nf = new DecimalFormat("0.00");
        return "BDeu Score SampP " + nf.format(this.samplePrior) + " StuctP " + nf.format(this.structurePrior);
    }

    /**
     * Returns {@code ceil(ln(sampleSize))}, clamped at 0 so that an empty data set does not yield
     * a negative degree.
     */
    @Override
    public int getMaxDegree() {
        return Math.max(0, (int) FastMath.ceil(FastMath.log(this.sampleSize)));
    }

    /** Not supported by this score. */
    @Override
    public boolean determines(List<Node> z, Node y) {
        throw new UnsupportedOperationException("The BDeu score does not implement a 'determines' method.");
    }

    /** Returns variable {@code i} cast to {@link DiscreteVariable}. */
    private DiscreteVariable getVariable(int i) {
        return (DiscreteVariable) this.variables.get(i);
    }

    /**
     * Log structure prior for a node with {@code numParents} parents.
     *
     * <p>Computes {@code k*log(e/vm) + (vm-k)*log(1 - e/vm)} with {@code e = structurePrior},
     * {@code k = numParents} and {@code vm = N - 1}. For {@code N <= 1}, or when
     * {@code e/vm >= 1}, the result is not finite; {@link #localScore} then reports NaN.
     *
     * <p>NOTE(review): {@code vm} is derived from the number of counted rows, not the number of
     * variables. Confirm this is the intended parameterization.
     *
     * @param numParents number of parents of the node.
     * @param N          number of rows counted for this node/parent set.
     * @return the log prior, or 0 when the structure prior is 0.
     */
    private double getPriorForStructure(int numParents, int N) {
        double e = this.getStructurePrior();
        if (e == 0.0) {
            return 0.0;
        }
        int vm = N - 1;
        return (double) numParents * FastMath.log(e / (double) vm)
                + (double) (vm - numParents) * FastMath.log(1.0 - e / (double) vm);
    }
}

