/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.bayes;

import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.BayesImDistanceFunction;
import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.BayesUtils;
import edu.cmu.tetrad.bayes.DirichletBayesIm;
import edu.cmu.tetrad.bayes.DirichletEstimator;
import edu.cmu.tetrad.bayes.Evidence;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.bayes.RowSummingExactUpdater;
import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Estimates the parameters of a discrete Bayes net that may contain latent
 * nodes, following an expectation/maximization scheme.
 *
 * <p>On construction the estimator:
 * <ol>
 *   <li>builds a Bayes PM over the graph with all LATENT nodes removed and
 *       estimates it from {@code dataSet} using a symmetric Dirichlet prior
 *       with parameter 0.5;</li>
 *   <li>builds a "mixed" data set with one column per graph node, in graph
 *       node order. Measured columns are copied from {@code dataSet} and
 *       latent columns are filled with the missing marker {@code -99};</li>
 *   <li>creates {@code new MlBayesIm(bayesPm, 1)} and overwrites, for every
 *       measured node with no latent parent, its conditional table with the
 *       one estimated in step 1.</li>
 * </ol>
 *
 * <p>NOTE(review): {@code expectation} builds a new IM from the expected
 * counts but never stores it in {@code estimatedIm}. As a result,
 * {@link #maximization(double)} performs a single pass and returns the
 * initial estimate. In addition, for nodes with parents, cases in which the
 * node or one of its parents is missing are not counted at all. That covers
 * every case for a child of a latent node, so such rows get zero
 * denominators (NaN probabilities). Storing the output as-is would therefore
 * likely make the NaN-tolerant loop in {@code maximization} never terminate.
 * Both parts need to be fixed together.
 */
public final class EmBayesEstimator {
    /** Marker used in the mixed data set for an unobserved value (every latent cell). */
    private static final int MISSING = -99;

    /** The full PM, including latent nodes. */
    private final BayesPm bayesPm;
    /** The user-supplied data (measured variables only). */
    private final DataSet dataSet;
    /** One column per node of {@link #graph}, in {@link #nodes} order; latent cells are {@link #MISSING}. */
    private DataSet mixedData;
    /** Variables of {@link #mixedData}, in column order. */
    private List<Node> allVariables;
    /** Nodes of {@link #graph}, in the order returned by {@code graph.getNodes()}. */
    private final Node[] nodes;
    /** The DAG of {@link #bayesPm}. */
    private final Graph graph;
    /** PM over the graph restricted to non-latent nodes. */
    private BayesPm bayesPmObs;
    /** Dirichlet estimate of {@link #bayesPmObs} from {@link #dataSet}. */
    private BayesIm observedIm;
    /** Current estimate returned by {@link #getEstimatedIm()}. */
    private BayesIm estimatedIm;
    /** Expected counts, indexed [node][row][category]. */
    private double[][][] estimatedCounts;
    /** Per-row count totals for nodes with parents, indexed [node][row]. */
    private double[][] estimatedCountsDenom;
    /** Conditional probabilities derived from the counts, indexed [node][row][category]. */
    private double[][][] condProbs;

    /**
     * Creates an estimator for the given PM and data.
     *
     * @param bayesPm the PM whose parameters are estimated; every MEASURED
     *                node's variable must be present, by name, in {@code dataSet}
     * @param dataSet the observed data
     * @throws NullPointerException     if either argument is null
     * @throws IllegalArgumentException if some measured variable is not in the data set
     */
    public EmBayesEstimator(BayesPm bayesPm, DataSet dataSet) {
        if (bayesPm == null) {
            throw new NullPointerException();
        }
        if (dataSet == null) {
            throw new NullPointerException();
        }
        this.bayesPm = bayesPm;
        this.dataSet = dataSet;
        this.graph = bayesPm.getDag();
        this.nodes = new Node[this.graph.getNumNodes()];
        Iterator<Node> it = this.graph.getNodes().iterator();
        for (int i = 0; i < this.nodes.length; ++i) {
            this.nodes[i] = it.next();
        }

        List<Node> observedVars = new ArrayList<>();
        for (Node node : this.nodes) {
            if (node.getNodeType() == NodeType.MEASURED) {
                observedVars.add(bayesPm.getVariable(node));
            }
        }
        for (Node observedVar : observedVars) {
            this.requireInDataSet(observedVar);
        }

        this.findBayesNetObserved();
        this.initialize();
    }

    /**
     * Creates an estimator for the PM underlying {@code inputBayesIm}; the IM's
     * own parameters are not used.
     *
     * @param inputBayesIm IM whose {@code getBayesPm()} is estimated
     * @param dataSet      the observed data
     */
    public EmBayesEstimator(BayesIm inputBayesIm, DataSet dataSet) {
        this(inputBayesIm.getBayesPm(), dataSet);
    }

    /**
     * Fails unless {@code observedVar} can be found by name in {@link #dataSet}.
     * A {@code null} lookup result is treated as "not found", as is any runtime
     * exception thrown by the lookup (the exception is kept as the cause).
     */
    private void requireInDataSet(Node observedVar) {
        String message = "Some observed var in the Bayes net is not in the dataset: " + observedVar;
        Node inData;
        try {
            inData = this.dataSet.getVariable(observedVar.getName());
        } catch (RuntimeException e) {
            throw new IllegalArgumentException(message, e);
        }
        if (inData == null) {
            throw new IllegalArgumentException(message);
        }
    }

    /**
     * Builds the observed-only estimate, the mixed data set, the initial
     * {@link #estimatedIm}, and the count/probability work arrays.
     */
    private void initialize() {
        DirichletBayesIm prior = DirichletBayesIm.symmetricDirichletIm(this.bayesPmObs, 0.5);
        this.observedIm = DirichletEstimator.estimate(prior, this.dataSet);

        int numCases = this.dataSet.getNumRows();
        List<Node> variables = new ArrayList<>(this.nodes.length);
        // Column in dataSet for each measured node; unused for latent nodes.
        int[] sourceColumns = new int[this.nodes.length];
        for (int j = 0; j < this.nodes.length; ++j) {
            Node node = this.nodes[j];
            if (node.getNodeType() == NodeType.LATENT) {
                DiscreteVariable latentVar =
                        new DiscreteVariable(node.getName(), this.bayesPm.getNumCategories(node));
                latentVar.setNodeType(NodeType.LATENT);
                variables.add(latentVar);
            } else {
                Node variable = this.dataSet.getVariable(this.bayesPm.getVariable(node).getName());
                variables.add(variable);
                sourceColumns[j] = this.dataSet.getColumn(variable);
            }
        }

        BoxDataSet dsMixed = new BoxDataSet(new DoubleDataBox(numCases, variables.size()), variables);
        for (int j = 0; j < this.nodes.length; ++j) {
            boolean latent = this.nodes[j].getNodeType() == NodeType.LATENT;
            for (int i = 0; i < numCases; ++i) {
                dsMixed.setInt(i, j, latent ? MISSING : this.dataSet.getInt(i, sourceColumns[j]));
            }
        }
        this.mixedData = dsMixed;
        this.allVariables = this.mixedData.getVariables();
        this.estimateIM(this.bayesPm, this.mixedData);

        this.estimatedCounts = new double[this.nodes.length][][];
        this.estimatedCountsDenom = new double[this.nodes.length][];
        this.condProbs = new double[this.nodes.length][][];
        for (int i = 0; i < this.nodes.length; ++i) {
            int numRows = this.estimatedIm.getNumRows(i);
            int numCols = this.estimatedIm.getNumColumns(i);
            this.estimatedCounts[i] = new double[numRows][numCols];
            this.estimatedCountsDenom[i] = new double[numRows];
            this.condProbs[i] = new double[numRows][numCols];
        }
    }

    /**
     * E-step: recomputes expected counts under {@code inputBayesIm} and the
     * conditional probabilities derived from them (stored in {@link #condProbs}).
     *
     * <p>Assumes mixed-data column j and IM node index j refer to the same node
     * (parent IM indices are used directly as data columns). TODO confirm.
     */
    private void expectation(BayesIm inputBayesIm) {
        RowSummingExactUpdater updater = new RowSummingExactUpdater(inputBayesIm);
        for (int j = 0; j < this.allVariables.size(); ++j) {
            DiscreteVariable var = (DiscreteVariable) this.allVariables.get(j);
            Node varNode = this.graph.getNode(var.getName());
            int varIndex = inputBayesIm.getNodeIndex(varNode);
            int[] parentVarIndices = inputBayesIm.getParents(varIndex);
            if (parentVarIndices.length == 0) {
                this.countRootVariable(inputBayesIm, updater, j, var, varIndex);
            } else {
                this.countChildVariable(inputBayesIm, j, var, varIndex, parentVarIndices);
            }
        }
        // NOTE(review): the returned IM is discarded and estimatedIm is left
        // unchanged (see the class comment for why it is not simply stored).
        this.computeConditionalProbabilities(inputBayesIm);
    }

    /**
     * Expected counts for a parentless node. Observed cases add 1 to their
     * category. Missing cases add the updater's marginal for each category,
     * given the other observed values in that case; they contribute nothing
     * when no other value is observed.
     */
    private void countRootVariable(BayesIm im, RowSummingExactUpdater updater, int column,
                                   DiscreteVariable var, int varIndex) {
        double[] counts = this.estimatedCounts[column][0];
        for (int c = 0; c < var.getNumCategories(); ++c) {
            counts[c] = 0.0;
        }
        int numCases = this.mixedData.getNumRows();
        for (int i = 0; i < numCases; ++i) {
            int value = this.mixedData.getInt(i, column);
            if (value != MISSING) {
                counts[value] += 1.0;
                continue;
            }
            Evidence evidence = this.evidenceFromOtherColumns(im, i, column);
            if (evidence == null) {
                continue; // no other variable is observed in this case
            }
            updater.setEvidence(evidence);
            for (int c = 0; c < var.getNumCategories(); ++c) {
                counts[c] += updater.getMarginal(varIndex, c);
            }
        }
    }

    /**
     * Builds evidence from every observed value in case {@code caseIndex},
     * except the one in column {@code excludedColumn}.
     *
     * @return the evidence, or {@code null} if no other column is observed in this case
     */
    private Evidence evidenceFromOtherColumns(BayesIm im, int caseIndex, int excludedColumn) {
        Evidence evidence = Evidence.tautology(im);
        boolean existsEvidence = false;
        for (int k = 0; k < this.allVariables.size(); ++k) {
            if (k == excludedColumn) {
                continue;
            }
            int value = this.mixedData.getInt(caseIndex, k);
            if (value == MISSING) {
                continue;
            }
            existsEvidence = true;
            Node otherNode = this.graph.getNode(this.allVariables.get(k).getName());
            evidence.getProposition().setCategory(im.getNodeIndex(otherNode), value);
        }
        return existsEvidence ? evidence : null;
    }

    /**
     * Counts for a node with parents: for each parent configuration (row), adds
     * 1 for every case in which the node and all of its parents are observed
     * and the parents match that row.
     */
    private void countChildVariable(BayesIm im, int column, DiscreteVariable var, int varIndex,
                                    int[] parentVarIndices) {
        int numCases = this.mixedData.getNumRows();
        int numRows = im.getNumRows(varIndex);
        for (int row = 0; row < numRows; ++row) {
            int[] parValues = im.getParentValues(varIndex, row);
            this.estimatedCountsDenom[varIndex][row] = 0.0;
            for (int c = 0; c < var.getNumCategories(); ++c) {
                this.estimatedCounts[varIndex][row][c] = 0.0;
            }
            for (int i = 0; i < numCases; ++i) {
                int value = this.mixedData.getInt(i, column);
                if (value != MISSING && this.parentsObservedAndMatch(i, parentVarIndices, parValues)) {
                    this.estimatedCounts[column][row][value] += 1.0;
                    this.estimatedCountsDenom[column][row] += 1.0;
                }
                // TODO: cases where this node or a consistent parent is missing
                // should contribute expected (joint-marginal) counts; they are
                // currently skipped.
            }
        }
    }

    /** True iff every parent column is observed in this case and equals the corresponding row value. */
    private boolean parentsObservedAndMatch(int caseIndex, int[] parentColumns, int[] parentValues) {
        for (int p = 0; p < parentColumns.length; ++p) {
            int observed = this.mixedData.getInt(caseIndex, parentColumns[p]);
            if (observed == MISSING || observed != parentValues[p]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Normalizes the expected counts into {@link #condProbs} and into a new IM
     * over {@link #bayesPm}. Single-row tables are normalized by their row sum.
     * Other rows are normalized by their count total, yielding NaN when that
     * total is zero.
     */
    private MlBayesIm computeConditionalProbabilities(BayesIm inputBayesIm) {
        MlBayesIm outputBayesIm = new MlBayesIm(this.bayesPm);
        for (int j = 0; j < this.nodes.length; ++j) {
            Node varNode = this.graph.getNode(this.allVariables.get(j).getName());
            int varIndex = inputBayesIm.getNodeIndex(varNode);
            int numRows = inputBayesIm.getNumRows(j);
            int numCols = inputBayesIm.getNumColumns(j);
            if (numRows == 1) {
                double sum = 0.0;
                for (int c = 0; c < numCols; ++c) {
                    sum += this.estimatedCounts[j][0][c];
                }
                for (int c = 0; c < numCols; ++c) {
                    this.condProbs[j][0][c] = this.estimatedCounts[j][0][c] / sum;
                    outputBayesIm.setProbability(varIndex, 0, c, this.condProbs[j][0][c]);
                }
                continue;
            }
            for (int row = 0; row < numRows; ++row) {
                double denom = this.estimatedCountsDenom[j][row];
                for (int c = 0; c < numCols; ++c) {
                    this.condProbs[j][row][c] =
                            denom != 0.0 ? this.estimatedCounts[j][row][c] / denom : Double.NaN;
                    outputBayesIm.setProbability(varIndex, row, c, this.condProbs[j][row][c]);
                }
            }
        }
        return outputBayesIm;
    }

    /**
     * Repeats the E-step until the distance between successive estimates is at
     * most {@code threshhold} (a NaN distance keeps the loop going).
     *
     * @param threshhold convergence threshold on {@code BayesImDistanceFunction.distance}
     * @return the final estimate, or the current estimate if the threshold is
     *         already satisfied before any iteration (e.g. NaN or huge thresholds)
     */
    public BayesIm maximization(double threshhold) {
        BayesIm oldBayesIm = this.estimatedIm;
        BayesIm newBayesIm = oldBayesIm;
        double distance = Double.MAX_VALUE;
        while (Double.isNaN(distance) || distance > threshhold) {
            this.expectation(oldBayesIm);
            newBayesIm = this.getEstimatedIm();
            distance = BayesImDistanceFunction.distance(newBayesIm, oldBayesIm);
            oldBayesIm = newBayesIm;
        }
        return newBayesIm;
    }

    /** Builds {@link #bayesPmObs}: the PM over the graph with all LATENT nodes removed. */
    private void findBayesNetObserved() {
        Dag dagObs = new Dag(this.graph);
        for (Node node : this.nodes) {
            if (node.getNodeType() == NodeType.LATENT) {
                dagObs.removeNode(node);
            }
        }
        this.bayesPmObs = new BayesPm(dagObs, this.bayesPm);
    }

    /**
     * Creates {@link #estimatedIm} and copies into it, from {@link #observedIm},
     * the table of every non-latent node that has no latent parent.
     */
    private void estimateIM(BayesPm bayesPm, DataSet dataSet) {
        if (bayesPm == null) {
            throw new NullPointerException();
        }
        if (dataSet == null) {
            throw new NullPointerException();
        }
        BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet);
        this.estimatedIm = new MlBayesIm(bayesPm, 1);
        int numNodes = this.estimatedIm.getNumNodes();
        for (int node = 0; node < numNodes; ++node) {
            if (this.nodes[node].getNodeType() == NodeType.LATENT) {
                continue;
            }
            Node nodeObs = this.observedIm.getNode(this.nodes[node].getName());
            int nodeObsIndex = this.observedIm.getNodeIndex(nodeObs);
            if (this.hasLatentParent(node)) {
                continue;
            }
            int numRows = this.estimatedIm.getNumRows(node);
            int numCols = this.estimatedIm.getNumColumns(node);
            for (int row = 0; row < numRows; ++row) {
                for (int col = 0; col < numCols; ++col) {
                    double p = this.observedIm.getProbability(nodeObsIndex, row, col);
                    this.estimatedIm.setProbability(node, row, col, p);
                }
            }
        }
    }

    /** True iff some parent (in {@link #estimatedIm}) of the given node is LATENT. */
    private boolean hasLatentParent(int node) {
        for (int parent : this.estimatedIm.getParents(node)) {
            if (this.nodes[parent].getNodeType() == NodeType.LATENT) {
                return true;
            }
        }
        return false;
    }

    /** @return the mixed data set (one column per node; latent cells hold {@code -99}). */
    public DataSet getMixedDataSet() {
        return this.mixedData;
    }

    /** @return the current estimated IM. */
    public BayesIm getEstimatedIm() {
        return this.estimatedIm;
    }
}

