/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.bayes;

import cern.jet.stat.Gamma;
import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.ProbUtils;

/**
 * Scores a discrete Bayes parametric model against a data set using a
 * Dirichlet-prior marginal-likelihood style product of gamma-function ratios.
 *
 * <p>For every node {@code i}, every conditional-table row {@code j} (one per
 * parent-value combination reported by the Bayes IM) and every column
 * {@code k} (one per category of the node), the score combines
 *
 * <pre>
 *   Gamma(a_ij) / Gamma(a_ij + N_ij)  *  prod_k Gamma(a_ijk + N_ijk) / Gamma(a_ijk)
 * </pre>
 *
 * where {@code N_ijk} is the number of data cases falling in that cell,
 * {@code a_ijk} is a prior pseudo-count fixed at {@code 1.0} for every cell,
 * and {@code a_ij}, {@code N_ij} are the sums of those quantities over
 * {@code k}.
 *
 * <p>Each scoring call rebuilds the instance's internal count tables, so a
 * single instance must not be scored from several threads at once.
 */
public final class BdeMetric {
    /** Prior pseudo-count assigned to every cell of every conditional table. */
    private static final double PRIOR_PSEUDO_COUNT = 1.0;

    private final DataSet dataSet;
    private final BayesPm bayesPm;

    /** Instantiated model used only for its table layout; rebuilt on every scoring call. */
    private BayesIm bayesIm;

    /** Cell counts indexed as {@code [node index in bayesIm][row][column]}. */
    private int[][][] observedCounts;

    /**
     * Creates a metric for the given data and model.
     *
     * @param dataSet the data whose columns are read as discrete variables
     * @param bayesPm the parametric model whose DAG defines the parent sets
     */
    public BdeMetric(DataSet dataSet, BayesPm bayesPm) {
        this.dataSet = dataSet;
        this.bayesPm = bayesPm;
    }

    /**
     * Computes the score as a direct product of gamma-function ratios.
     *
     * <p>Because the gamma function grows factorially, individual factors can
     * leave the range of {@code double} once rows hold many cases;
     * {@link #scoreLnGam()} accumulates the same terms as differences of
     * {@code ProbUtils.lngamma} values instead.
     *
     * <p>Failures raised by the gamma routine or by the counting step are
     * propagated to the caller rather than printed and skipped, because
     * dropping a factor would silently yield a wrong score.
     *
     * @return the product over all nodes, rows and columns described on the class
     */
    public double score() {
        this.prepareObservedCounts();

        // Gamma of the constant prior is the same for every cell; evaluate it once.
        double gammaOfPrior = Gamma.gamma(PRIOR_PSEUDO_COUNT);

        double product = 1.0;
        for (int[][] nodeCounts : this.observedCounts) {
            double nodeProduct = 1.0;
            for (int[] rowCounts : nodeCounts) {
                double priorRowSum = priorRowSum(rowCounts);
                int observedRowSum = observedRowSum(rowCounts);
                nodeProduct *= Gamma.gamma(priorRowSum)
                        / Gamma.gamma(priorRowSum + (double) observedRowSum);

                double rowProduct = 1.0;
                for (int count : rowCounts) {
                    rowProduct *= Gamma.gamma(PRIOR_PSEUDO_COUNT + (double) count) / gammaOfPrior;
                }
                nodeProduct *= rowProduct;
            }
            product *= nodeProduct;
        }
        return product;
    }

    /**
     * Computes the score by summing differences of {@code ProbUtils.lngamma}
     * values, term for term where {@link #score()} multiplies gamma ratios.
     * Provided {@code lngamma} is the log-gamma function, the result is the
     * natural logarithm of the quantity {@link #score()} computes.
     *
     * <p>Failures raised while counting or evaluating a term are propagated to
     * the caller rather than printed and skipped.
     *
     * @return the sum over all nodes, rows and columns of the log-space terms
     */
    public double scoreLnGam() {
        this.prepareObservedCounts();

        double lnGammaOfPrior = ProbUtils.lngamma(PRIOR_PSEUDO_COUNT);

        double sum = 0.0;
        for (int[][] nodeCounts : this.observedCounts) {
            double nodeSum = 0.0;
            for (int[] rowCounts : nodeCounts) {
                double priorRowSum = priorRowSum(rowCounts);
                int observedRowSum = observedRowSum(rowCounts);
                nodeSum += ProbUtils.lngamma(priorRowSum)
                        - ProbUtils.lngamma(priorRowSum + (double) observedRowSum);

                double rowSum = 0.0;
                for (int count : rowCounts) {
                    rowSum += ProbUtils.lngamma(PRIOR_PSEUDO_COUNT + (double) count) - lnGammaOfPrior;
                }
                nodeSum += rowSum;
            }
            sum += nodeSum;
        }
        return sum;
    }

    /** Returns the total prior pseudo-count of one table row (one unit of prior per column). */
    private static double priorRowSum(int[] rowCounts) {
        return PRIOR_PSEUDO_COUNT * rowCounts.length;
    }

    /** Returns the number of data cases counted in one table row. */
    private static int observedRowSum(int[] rowCounts) {
        int total = 0;
        for (int count : rowCounts) {
            total += count;
        }
        return total;
    }

    /** Builds a fresh Bayes IM, allocates a zeroed count table per node, and fills it from the data. */
    private void prepareObservedCounts() {
        Dag graph = this.bayesPm.getDag();
        int numNodes = graph.getNumNodes();

        this.bayesIm = new MlBayesIm(this.bayesPm);

        int[][][] counts = new int[numNodes][][];
        for (int node = 0; node < numNodes; ++node) {
            counts[node] = new int[this.bayesIm.getNumRows(node)][this.bayesIm.getNumColumns(node)];
        }
        this.observedCounts = counts;

        this.computeObservedCounts();
    }

    /**
     * Tallies, for every data column, how often each category occurs under
     * each parent-value combination of the matching node.
     *
     * <p>Data columns and Bayes IM nodes are matched by variable name, and the
     * count tables are always addressed by the node's index in the Bayes IM.
     * Parent values are read from the data column that corresponds to the
     * parent node, so the data set's column order need not match the node
     * order of the Bayes IM.
     *
     * @throws IllegalArgumentException if a data variable has no node in the
     *         Bayes IM, or a parent node has no column in the data set
     */
    private void computeObservedCounts() {
        int numColumns = this.dataSet.getNumColumns();
        int numCases = this.dataSet.getNumRows();
        int numNodes = this.observedCounts.length;

        // First pass: establish the two-way mapping between data columns and node indices.
        int[] columnOfNode = new int[numNodes];
        for (int node = 0; node < numNodes; ++node) {
            columnOfNode[node] = -1;
        }
        int[] nodeOfColumn = new int[numColumns];
        for (int column = 0; column < numColumns; ++column) {
            DiscreteVariable var = (DiscreteVariable)this.dataSet.getVariables().get(column);
            String varName = var.getName();
            Node varNode = this.bayesPm.getDag().getNode(varName);
            int varIndex = this.bayesIm.getNodeIndex(varNode);
            if (varIndex < 0 || varIndex >= numNodes) {
                throw new IllegalArgumentException(
                        "Data variable '" + varName + "' has no node in the Bayes IM");
            }
            nodeOfColumn[column] = varIndex;
            columnOfNode[varIndex] = column;
        }

        // Second pass: count cases into the table of the node that owns each column.
        for (int column = 0; column < numColumns; ++column) {
            int varIndex = nodeOfColumn[column];
            int[][] nodeCounts = this.observedCounts[varIndex];
            int[] parentVarIndices = this.bayesIm.getParents(varIndex);

            if (parentVarIndices.length == 0) {
                // Without parents every case belongs to the single row 0.
                int[] onlyRow = nodeCounts[0];
                for (int col = 0; col < onlyRow.length; ++col) {
                    onlyRow[col] = 0;
                }
                for (int caseIndex = 0; caseIndex < numCases; ++caseIndex) {
                    onlyRow[this.dataSet.getInt(caseIndex, column)]++;
                }
                continue;
            }

            // Translate parent node indices into the data columns that hold their values.
            int[] parentColumns = new int[parentVarIndices.length];
            for (int p = 0; p < parentVarIndices.length; ++p) {
                int parentIndex = parentVarIndices[p];
                int parentColumn = (parentIndex >= 0 && parentIndex < numNodes)
                        ? columnOfNode[parentIndex]
                        : -1;
                if (parentColumn < 0) {
                    throw new IllegalArgumentException(
                            "No data column found for parent node index " + parentIndex
                                    + " of data column " + column);
                }
                parentColumns[p] = parentColumn;
            }

            int numRows = this.bayesIm.getNumRows(varIndex);
            for (int row = 0; row < numRows; ++row) {
                int[] parValues = this.bayesIm.getParentValues(varIndex, row);
                int[] rowCounts = nodeCounts[row];
                for (int col = 0; col < rowCounts.length; ++col) {
                    rowCounts[col] = 0;
                }
                for (int caseIndex = 0; caseIndex < numCases; ++caseIndex) {
                    if (!this.parentsMatch(caseIndex, parentColumns, parValues)) {
                        continue;
                    }
                    rowCounts[this.dataSet.getInt(caseIndex, column)]++;
                }
            }
        }
    }

    /** Returns whether the given data case carries exactly the given parent values in the given columns. */
    private boolean parentsMatch(int caseIndex, int[] parentColumns, int[] parValues) {
        for (int p = 0; p < parentColumns.length; ++p) {
            if (parValues[p] != this.dataSet.getInt(caseIndex, parentColumns[p])) {
                return false;
            }
        }
        return true;
    }
}

