/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.DataBox;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.VerticalIntDataBox;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.LocalDiscreteScore;
import java.text.DecimalFormat;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

/**
 * Local score for discrete data computed from per-parent-configuration cell
 * counts smoothed by a symmetric sample prior.
 *
 * <p>The data are held column-wise as {@code int[variable][record]}. For a node
 * with {@code r} categories and a parent set with {@code q} joint
 * configurations, {@link #localScore(int, int[])} tabulates the {@code q x r}
 * count table and sums, over all cells,
 * {@code (alpha - 1) * log(alpha / rowSum)} where
 * {@code alpha = samplePrior + n_jk} and
 * {@code rowSum = r * samplePrior + n_j}; each contributing row is additionally
 * penalized by twice the number of cells it contributed.
 *
 * <p>The structure prior is stored and reported by {@link #toString()} but is
 * not used by the score computation in this class.
 */
public class DirichletScore implements LocalDiscreteScore {
    /** Marker used in the integer data for a missing observation. */
    private static final int MISSING_VALUE = -99;

    private final List<Node> variables;
    /** Column-major data: {@code data[variable][record]}. */
    private final int[][] data;
    private final int sampleSize;
    private double samplePrior = 1.0;
    private double structurePrior = 1.0;
    /** Number of categories of each variable, indexed like {@link #variables}. */
    private final int[] numCategories;
    /**
     * Threshold used by {@link #isEffectEdge(double)}. It is only assigned
     * (to 0.01) inside {@link #localScore(int, int[])}, so it is 0.0 until a
     * score has been computed.
     */
    private double lastBumpThreshold;

    /**
     * Builds a score over the given data set.
     *
     * <p>If the data set is a {@link BoxDataSet} its data box must be a
     * {@link VerticalIntDataBox}, whose column vectors are used directly
     * (not copied). Otherwise the values are copied out cell by cell via
     * {@code dataSet.getInt(row, column)}.
     *
     * @param dataSet the data to score; every variable is cast to
     *                {@link DiscreteVariable}
     * @throws NullPointerException     if {@code dataSet} is null
     * @throws IllegalArgumentException if {@code dataSet} is a
     *                                  {@code BoxDataSet} not backed by a
     *                                  {@code VerticalIntDataBox}
     */
    public DirichletScore(DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet must not be null");
        }
        this.variables = dataSet.getVariables();
        if (dataSet instanceof BoxDataSet) {
            DataBox dataBox = ((BoxDataSet) dataSet).getDataBox();
            if (!(dataBox instanceof VerticalIntDataBox)) {
                throw new IllegalArgumentException(
                        "BoxDataSet must be backed by a VerticalIntDataBox but found "
                                + (dataBox == null ? "null" : dataBox.getClass().getName()));
            }
            this.data = ((VerticalIntDataBox) dataBox).getVariableVectors();
        } else {
            int numColumns = dataSet.getNumColumns();
            int numRows = dataSet.getNumRows();
            this.data = new int[numColumns][];
            for (int j = 0; j < numColumns; ++j) {
                int[] column = new int[numRows];
                for (int i = 0; i < numRows; ++i) {
                    column[i] = dataSet.getInt(i, j);
                }
                this.data[j] = column;
            }
        }
        this.sampleSize = dataSet.getNumRows();
        this.numCategories = new int[this.variables.size()];
        for (int i = 0; i < this.numCategories.length; ++i) {
            this.numCategories[i] = this.getVariable(i).getNumCategories();
        }
    }

    /** Returns the variable at index {@code i}, cast to a discrete variable. */
    private DiscreteVariable getVariable(int i) {
        return (DiscreteVariable) this.variables.get(i);
    }

    /**
     * Scores {@code node} given the parent set {@code parents}.
     *
     * @param node    index of the child variable
     * @param parents indices of the parent variables (may be empty)
     * @return the score, or {@code Double.NaN} if the computed value is NaN
     *         or infinite
     * @throws IllegalStateException if the child or any parent has a missing
     *                               value ({@code -99}) in some record
     */
    @Override
    public double localScore(int node, int[] parents) {
        int r = this.numCategories[node];

        // Category counts of the parents and the number q of joint parent configurations.
        int[] dims = new int[parents.length];
        int q = 1;
        for (int p = 0; p < parents.length; ++p) {
            dims[p] = this.numCategories[parents[p]];
            q *= dims[p];
        }

        int[][] n_jk = new int[q][r];
        int[] n_j = new int[q];
        int[] parentValues = new int[parents.length];
        int[][] myParents = new int[parents.length][];
        for (int p = 0; p < parents.length; ++p) {
            myParents[p] = this.data[parents[p]];
        }
        int[] myChild = this.data[node];

        // Tabulate the counts record by record.
        for (int i = 0; i < this.sampleSize; ++i) {
            for (int p = 0; p < parents.length; ++p) {
                int parentValue = myParents[p][i];
                // A missing parent value would otherwise corrupt the row index
                // (out-of-bounds, or silently counted in the wrong row).
                if (parentValue == MISSING_VALUE) {
                    throw new IllegalStateException("Please remove or impute missing values (record " + i + " column " + parents[p] + ")");
                }
                parentValues[p] = parentValue;
            }
            int childValue = myChild[i];
            if (childValue == MISSING_VALUE) {
                // Report the child's column index, not the record index twice.
                throw new IllegalStateException("Please remove or impute missing values (record " + i + " column " + node + ")");
            }
            int rowIndex = DirichletScore.getRowIndex(dims, parentValues);
            n_jk[rowIndex][childValue]++;
            n_j[rowIndex]++;
        }

        double score = 0.0;
        double cellPrior = this.getSamplePrior();
        double rowPrior = (double) r * this.getSamplePrior();
        for (int j = 0; j < q; ++j) {
            double rowSum = rowPrior + (double) n_j[j];
            int cellCount = 0;
            double rowScore = 0.0;
            for (int k = 0; k < r; ++k) {
                double alpha = cellPrior + (double) n_jk[j][k];
                double pk = alpha / rowSum;
                if (Double.isInfinite(pk)) {
                    continue;
                }
                rowScore += (alpha - 1.0) * FastMath.log(pk);
                ++cellCount;
            }
            // Rows whose cells sum to exactly zero contribute neither score nor penalty.
            if (rowScore == 0.0) {
                continue;
            }
            score += rowScore;
            score -= (double) (2 * cellCount);
        }

        this.lastBumpThreshold = 0.01;
        if (Double.isNaN(score) || Double.isInfinite(score)) {
            return Double.NaN;
        }
        return score;
    }

    /**
     * Returns the change in the score of {@code y} when {@code x} is added to
     * the parent set {@code z}.
     */
    @Override
    public double localScoreDiff(int x, int y, int[] z) {
        return this.localScore(y, this.append(z, x)) - this.localScore(y, z);
    }

    /**
     * Returns the change in the score of {@code y} when {@code x} becomes its
     * only parent, relative to having no parents.
     */
    @Override
    public double localScoreDiff(int x, int y) {
        return this.localScore(y, x) - this.localScore(y);
    }

    /** Returns a new array consisting of {@code parents} followed by {@code extra}. */
    int[] append(int[] parents, int extra) {
        int[] all = new int[parents.length + 1];
        System.arraycopy(parents, 0, all, 0, parents.length);
        all[parents.length] = extra;
        return all;
    }

    /** Scores {@code node} with the single parent {@code parent}. */
    @Override
    public double localScore(int node, int parent) {
        return this.localScore(node, new int[]{parent});
    }

    /** Scores {@code node} with no parents. */
    @Override
    public double localScore(int node) {
        return this.localScore(node, new int[0]);
    }

    /** Returns the variable list obtained from the data set at construction. */
    @Override
    public List<Node> getVariables() {
        return this.variables;
    }

    /** Returns the number of records in the data. */
    @Override
    public int getSampleSize() {
        return this.sampleSize;
    }

    /**
     * Returns true if {@code bump} exceeds the current threshold, which is 0.0
     * before the first call to {@link #localScore(int, int[])} and 0.01 after.
     */
    @Override
    public boolean isEffectEdge(double bump) {
        return bump > this.lastBumpThreshold;
    }

    /**
     * Not supported: this score does not keep a reference to the data set.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public DataSet getDataSet() {
        throw new UnsupportedOperationException();
    }

    /**
     * Maps a combination of parent values to a single row index using
     * mixed-radix encoding, with {@code dim[i]} as the radix of position
     * {@code i} and the first parent as the most significant digit.
     */
    private static int getRowIndex(int[] dim, int[] values) {
        int rowIndex = 0;
        for (int i = 0; i < dim.length; ++i) {
            rowIndex *= dim[i];
            rowIndex += values[i];
        }
        return rowIndex;
    }

    /** Returns the stored structure prior. */
    public double getStructurePrior() {
        return this.structurePrior;
    }

    /** Returns the sample prior added to every cell count. */
    public double getSamplePrior() {
        return this.samplePrior;
    }

    /** Sets the structure prior (stored only; not used by the score here). */
    @Override
    public void setStructurePrior(double structurePrior) {
        this.structurePrior = structurePrior;
    }

    /** Sets the sample prior added to every cell count. */
    @Override
    public void setSamplePrior(double samplePrior) {
        this.samplePrior = samplePrior;
    }

    /**
     * Looks up a variable by name.
     *
     * @return the first variable whose name equals {@code targetName}, or
     *         {@code null} if there is none
     */
    @Override
    public Node getVariable(String targetName) {
        for (Node node : this.variables) {
            if (node.getName().equals(targetName)) {
                return node;
            }
        }
        return null;
    }

    /** Returns the fixed maximum degree of 1000. */
    @Override
    public int getMaxDegree() {
        return 1000;
    }

    /** Always returns false. */
    @Override
    public boolean determines(List<Node> z, Node y) {
        return false;
    }

    /** Returns a short description including both priors, formatted to two decimals. */
    @Override
    public String toString() {
        DecimalFormat nf = new DecimalFormat("0.00");
        return "Dirichlet Score StructP " + nf.format(this.structurePrior) + " SampP " + nf.format(this.samplePrior);
    }
}

