/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.bayes;

import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.Evidence;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.bayes.RowSummingExactUpdater;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math3.special.Gamma;

/**
 * Computes and memoizes BDe-style log scores for (node, parent set) pairs over a discrete data set.
 *
 * <p>Scores are cached by {@code (node, parents)} only: the {@code BayesPm}/{@code BayesIm}
 * arguments passed to {@link #scoreLnGam} are not part of the cache key. The counts table from
 * the most recent computation is kept in a shared field and the caches are plain
 * {@link HashMap}s, so instances are not synchronized.
 */
public final class BdeMetricCache {
    /** Cell value that the missing-data counting code treats as "missing". */
    private static final int MISSING_VALUE = -99;

    /** Data set the counts are taken from. */
    private final DataSet dataSet;
    /** Variables of {@link #dataSet}, as returned by {@code getVariables()}. */
    private final List<Node> variables;
    /** PM supplying category counts for the score and the complete-data counting path. */
    private final BayesPm bayesPm;
    /** Memoized scores keyed by node and parent set. */
    private final Map<NodeParentsPair, Double> scores;
    /** Memoized parameter counts keyed by node and parent set. */
    private final Map<NodeParentsPair, Integer> scoreCounts;
    /** Counts from the most recent computation: row = parent configuration, column = child category. */
    private double[][] observedCounts;

    /**
     * Creates an empty cache.
     *
     * @param dataSet the discrete data set to count over
     * @param bayesPm PM supplying the number of categories of each variable
     */
    public BdeMetricCache(DataSet dataSet, BayesPm bayesPm) {
        this.bayesPm = bayesPm;
        this.dataSet = dataSet;
        this.scores = new HashMap<>();
        this.scoreCounts = new HashMap<>();
        this.variables = dataSet.getVariables();
    }

    /**
     * Returns the log marginal-likelihood score of {@code node} given {@code parents}.
     *
     * <p>With q rows (parent configurations), r columns (child categories), counts N_jk, and a
     * uniform prior α_jk = 1/(q·r) (so the priors sum to 1), the score is
     * Σ_j [lnΓ(α_j) − lnΓ(α_j + N_j) + Σ_k (lnΓ(α_jk + N_jk) − lnΓ(α_jk))], where α_j and N_j are
     * row sums. If {@code (node, parents)} was scored before, the cached value is returned
     * regardless of {@code bayesPmMod} and {@code bayesIm}.
     *
     * @param node       the child variable
     * @param parents    its parent set (copied into the cache key, so later mutation is harmless)
     * @param bayesPmMod PM whose CPT for {@code node} determines the number of rows
     * @param bayesIm    if {@code null}, counts are taken directly from the data; otherwise
     *                   missing cells are filled in by inference in this IM
     * @return the log score
     */
    public double scoreLnGam(Node node, Set<Node> parents, BayesPm bayesPmMod, BayesIm bayesIm) {
        NodeParentsPair nodeAndParents = new NodeParentsPair(node, parents);
        Double cached = this.scores.get(nodeAndParents);
        if (cached != null) {
            System.out.println(node + " Score came from map--counts not computed.");
            return cached;
        }

        // One toArray call; calling toArray() per element was quadratic in the parent count.
        Node[] parentArray = parents.toArray(new Node[0]);

        MlBayesIm bayesImMod = new MlBayesIm(bayesPmMod);
        int numRows = bayesImMod.getNumRows(bayesImMod.getNodeIndex(node));
        int numCols = this.bayesPm.getNumCategories(node);

        this.observedCounts = new double[numRows][numCols];
        if (bayesIm == null) {
            this.computeObservedCounts(node, parentArray);
        } else {
            this.computeObservedCountsMD(node, bayesPmMod, bayesIm);
        }
        double[][] counts = this.observedCounts;

        // Every cell gets the same prior; compute in double to avoid int overflow of q*r.
        double cellPrior = 1.0 / ((double) numRows * numCols);

        // Gamma.logGamma never throws (it returns NaN for x <= 0), so no exception handling here.
        double sum = 0.0;
        for (int j = 0; j < numRows; ++j) {
            double countRowSum = 0.0;
            double priorRowSum = 0.0;
            for (int k = 0; k < numCols; ++k) {
                countRowSum += counts[j][k];
                priorRowSum += cellPrior;
            }
            sum += Gamma.logGamma(priorRowSum) - Gamma.logGamma(priorRowSum + countRowSum);

            double sumk = 0.0;
            for (int k = 0; k < numCols; ++k) {
                sumk += Gamma.logGamma(cellPrior + counts[j][k]) - Gamma.logGamma(cellPrior);
            }
            sum += sumk;
        }

        this.scores.put(nodeAndParents, sum);
        return sum;
    }

    /**
     * Fills {@link #observedCounts} for {@code node}, using the CPT layout of {@code bayesPmTest}
     * and resolving missing cells ({@value #MISSING_VALUE}) with {@code bayesIm}.
     *
     * <p>{@link #observedCounts} must already be allocated with at least as many rows and columns
     * as the CPT of {@code node} in an {@link MlBayesIm} built from {@code bayesPmTest}. How each
     * case is counted depends on whether {@code node} has parents:
     * <ul>
     *   <li><b>No parents.</b> An observed child adds 1.0 to its category. A missing child adds,
     *       for each category m, the marginal P(child = m) computed with every non-missing value
     *       of the case set as evidence.</li>
     *   <li><b>Parents.</b> A case belongs to a row when each parent is missing or equals the
     *       row's value. If the child and all parents are observed, it adds 1.0. Otherwise it
     *       adds, for each m, the joint marginal of (child = m, parents = row values) under
     *       {@code Evidence.tautology(bayesIm)}; the case's own values are not applied as
     *       evidence.</li>
     * </ul>
     * Finally every count is multiplied by the number of cases.
     *
     * @throws IllegalStateException if the updater returns NaN for a marginal
     */
    private void computeObservedCountsMD(Node node, BayesPm bayesPmTest, BayesIm bayesIm) {
        int numCases = this.dataSet.getNumRows();
        int numVariables = this.variables.size();
        Graph graph = bayesIm.getBayesPm().getDag();
        RowSummingExactUpdater rseu = new RowSummingExactUpdater(bayesIm);
        int index = this.getVarIndex(node.getName());
        int numCols = bayesPmTest.getNumCategories(node);
        MlBayesIm bayesImTest = new MlBayesIm(bayesPmTest);
        int varIndex = bayesImTest.getNodeIndex(node);
        int numRows = bayesImTest.getNumRows(varIndex);
        // NOTE(review): varIndex and parentVarIndices are node indices in bayesImTest, yet they are
        // also used as data-set column indices and as indices into rseu (built on bayesIm). This
        // assumes all three orderings coincide -- TODO confirm.
        int[] parentVarIndices = bayesImTest.getParents(varIndex);

        if (parentVarIndices.length == 0) {
            double[] counts = this.observedCounts[0];
            for (int col = 0; col < numCols; ++col) {
                counts[col] = 0.0;
            }
            for (int i = 0; i < numCases; ++i) {
                int childValue = this.dataSet.getInt(i, index);
                if (childValue != MISSING_VALUE) {
                    counts[childValue] += 1.0;
                    continue;
                }
                // Child missing: condition on every observed value in this case.
                Evidence evidenceThisCase = Evidence.tautology(bayesIm);
                boolean existsEvidence = false;
                for (int k = 0; k < numVariables; ++k) {
                    int otherValue = this.dataSet.getInt(i, k);
                    if (otherValue == MISSING_VALUE) {
                        continue;
                    }
                    existsEvidence = true;
                    Node otherNode = graph.getNode(this.variables.get(k).getName());
                    int otherIndex = bayesIm.getNodeIndex(otherNode);
                    evidenceThisCase.getProposition().setCategory(otherIndex, otherValue);
                }
                rseu.setEvidence(evidenceThisCase);
                for (int m = 0; m < numCols; ++m) {
                    double p = rseu.getMarginal(varIndex, m);
                    if (Double.isNaN(p)) {
                        // Previously System.exit(0): never terminate the JVM (with a success code) from a library.
                        throw new IllegalStateException("getMarginal returned NaN for node " + node
                                + ", category " + m + ", case " + i
                                + " (existsEvidence = " + existsEvidence + ")");
                    }
                    counts[m] += p;
                }
            }
        } else {
            for (int row = 0; row < numRows; ++row) {
                int[] parValues = bayesImTest.getParentValues(varIndex, row);
                double[] counts = this.observedCounts[row];
                for (int col = 0; col < numCols; ++col) {
                    counts[col] = 0.0;
                }
                for (int i = 0; i < numCases; ++i) {
                    // A case matches this row if every parent is missing or equals the row's value.
                    boolean parentMatch = true;
                    boolean parentMissing = false;
                    for (int p = 0; p < parentVarIndices.length; ++p) {
                        int parentValue = this.dataSet.getInt(i, parentVarIndices[p]);
                        if (parentValue == MISSING_VALUE) {
                            parentMissing = true;
                            continue;
                        }
                        if (parentValue != parValues[p]) {
                            parentMatch = false;
                            break;
                        }
                    }
                    if (!parentMatch) {
                        continue;
                    }
                    int childValue = this.dataSet.getInt(i, index);
                    if (childValue != MISSING_VALUE && !parentMissing) {
                        counts[childValue] += 1.0;
                        continue;
                    }
                    rseu.setEvidence(Evidence.tautology(bayesIm));
                    // Index/value vectors for (child, parent_1, ..., parent_n); slot 0 is the child.
                    int[] parPlusChildIndices = new int[parentVarIndices.length + 1];
                    int[] parPlusChildValues = new int[parentVarIndices.length + 1];
                    parPlusChildIndices[0] = varIndex;
                    for (int pc = 1; pc < parPlusChildIndices.length; ++pc) {
                        parPlusChildIndices[pc] = parentVarIndices[pc - 1];
                        parPlusChildValues[pc] = parValues[pc - 1];
                    }
                    for (int m = 0; m < numCols; ++m) {
                        parPlusChildValues[0] = m;
                        double p = rseu.getJointMarginal(parPlusChildIndices, parPlusChildValues);
                        if (Double.isNaN(p)) {
                            // Previously System.exit(0): never terminate the JVM (with a success code) from a library.
                            throw new IllegalStateException("getJointMarginal returned NaN for node " + node
                                    + ", category " + m + ", parent row " + row + ", case " + i);
                        }
                        counts[m] += p;
                    }
                }
            }
        }

        // NOTE(review): this multiplies every count -- including whole-case counts of 1.0 -- by
        // the number of cases, so fully observed data yields counts numCases times larger than
        // computeObservedCounts produces. Kept as-is; confirm whether this scaling is intended.
        for (int row = 0; row < numRows; ++row) {
            double[] counts = this.observedCounts[row];
            for (int col = 0; col < numCols; ++col) {
                counts[col] *= numCases;
            }
        }
    }

    /**
     * Replaces {@link #observedCounts} with complete-data counts of {@code node} per configuration
     * of {@code parentArray}, using category counts from {@link #bayesPm}.
     *
     * <p>Rows enumerate parent configurations in mixed radix, with the first parent most
     * significant and the last varying fastest. Cases whose parent values fall outside a parent's
     * category range match no row and are skipped. Runs in a single pass over the cases.
     *
     * <p>NOTE(review): there is no missing-value handling here. A child value of
     * {@value #MISSING_VALUE} in a counted case makes the array index fail; presumably this path
     * is only used for complete data -- confirm.
     */
    private void computeObservedCounts(Node node, Node[] parentArray) {
        int index = this.getVarIndex(node.getName());
        int numCols = this.bayesPm.getNumCategories(node);
        int numParents = parentArray.length;
        int[] parentVarIndices = new int[numParents];
        int[] parDims = new int[numParents];
        int numRows = 1;
        for (int i = 0; i < numParents; ++i) {
            parentVarIndices[i] = this.getVarIndex(parentArray[i].getName());
            parDims[i] = this.bayesPm.getNumCategories(parentArray[i]);
            numRows *= parDims[i];
        }

        // A freshly allocated array is already zero-filled.
        this.observedCounts = new double[numRows][numCols];
        int numCases = this.dataSet.getNumRows();
        for (int i = 0; i < numCases; ++i) {
            int row = this.parentRowIndex(i, parentVarIndices, parDims);
            if (row < 0) {
                continue;
            }
            this.observedCounts[row][this.dataSet.getInt(i, index)] += 1.0;
        }
    }

    /**
     * Returns the row for case {@code caseIndex}. The row is the case's parent values read as a
     * mixed-radix number: the first parent is most significant and the last varies fastest.
     * Returns -1 if any parent value lies outside {@code [0, parDims[p])}, and 0 when there are no
     * parents.
     */
    private int parentRowIndex(int caseIndex, int[] parentVarIndices, int[] parDims) {
        int row = 0;
        for (int p = 0; p < parentVarIndices.length; ++p) {
            int value = this.dataSet.getInt(caseIndex, parentVarIndices[p]);
            if (value < 0 || value >= parDims[p]) {
                return -1;
            }
            row = row * parDims[p] + value;
        }
        return row;
    }

    /** Returns the data-set column of the variable named {@code name}. */
    private int getVarIndex(String name) {
        return this.dataSet.getColumn(this.dataSet.getVariable(name));
    }

    /**
     * Computes and returns missing-data-aware counts for {@code node} under the CPT layout of
     * {@code bayesPm}, using {@code bayesIm} to resolve missing cells.
     *
     * @return the counts table (row = CPT row, column = category); this is the internal array,
     *         which later computations replace rather than mutate
     * @throws IllegalStateException if the updater returns NaN for a marginal
     */
    public double[][] getObservedCounts(Node node, BayesPm bayesPm, BayesIm bayesIm) {
        System.out.println("In getObservedCounts for node = " + node.getName());
        MlBayesIm pmIm = new MlBayesIm(bayesPm);
        int inode = pmIm.getNodeIndex(node);
        int numPars = pmIm.getNumParents(inode);
        int numRows = pmIm.getNumRows(inode);
        System.out.println("Has " + numPars + " parents " + numRows + " rows in CPT.");
        // computeObservedCountsMD fills observedCounts in place. Size it for this node first;
        // otherwise a null (fresh instance) or stale array from another node would be used.
        this.observedCounts = new double[numRows][bayesPm.getNumCategories(node)];
        this.computeObservedCountsMD(node, bayesPm, bayesIm);
        return this.observedCounts;
    }

    /**
     * Returns the memoized parameter count for {@code node} given {@code parents}: the number of
     * parents plus one.
     */
    public int getScoreCount(Node node, Set<Node> parents) {
        NodeParentsPair nodeParents = new NodeParentsPair(node, parents);
        Integer cached = this.scoreCounts.get(nodeParents);
        if (cached != null) {
            System.out.println(node + " Score came from map.");
            return cached;
        }
        int count = nodeParents.calcCount();
        this.scoreCounts.put(nodeParents, count);
        return count;
    }

    /** Immutable cache key pairing a node with a snapshot of its parent set. */
    private static final class NodeParentsPair {
        private final Node node;
        private final Set<Node> parents;

        public NodeParentsPair(Node node, Set<Node> parents) {
            this.node = node;
            // Defensive copy: instances are HashMap keys, so mutating the caller's set afterwards
            // must not change this key's hashCode/equals.
            this.parents = new HashSet<>(parents);
        }

        /** Returns the number of parents plus one. */
        public int calcCount() {
            return this.parents.size() + 1;
        }

        @Override
        public int hashCode() {
            int hash = 91;
            hash = 43 * hash + this.node.hashCode();
            hash = 43 * hash + this.parents.hashCode();
            return hash;
        }

        @Override
        public boolean equals(Object other) {
            if (other == this) {
                return true;
            }
            if (!(other instanceof NodeParentsPair)) {
                return false;
            }
            NodeParentsPair npp = (NodeParentsPair) other;
            return npp.node.equals(this.node) && npp.parents.equals(this.parents);
        }
    }
}

