/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.bayes;

import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.BayesUpdater;
import edu.cmu.tetrad.bayes.BayesUtils;
import edu.cmu.tetrad.bayes.DiscreteProbs;
import edu.cmu.tetrad.bayes.Evidence;
import edu.cmu.tetrad.bayes.IntAveDataSetProbs;
import edu.cmu.tetrad.bayes.Proposition;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.VariableSource;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Node;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

public final class OnTheFlyMarginalCalculator
implements BayesUpdater,
VariableSource {
    static final long serialVersionUID = 23L;
    private BayesPm bayesPm;
    private Node[] nodes;
    private int[][] parents;
    private int[][] parentDims;
    private DataSet dataSet;
    private Evidence evidence;
    private transient DiscreteProbs discreteProbs;

    public OnTheFlyMarginalCalculator(BayesPm bayesPm, DataSet dataSet) throws IllegalArgumentException {
        if (bayesPm == null) {
            throw new NullPointerException();
        }
        if (dataSet == null) {
            throw new NullPointerException();
        }
        BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet);
        this.bayesPm = new BayesPm(bayesPm);
        Dag graph = bayesPm.getDag();
        this.nodes = graph.getNodes().toArray(new Node[0]);
        this.initialize();
        List<Node> variables = this.getVariables();
        this.dataSet = dataSet.subsetColumns(variables);
        this.evidence = new Evidence(Proposition.tautology(this));
    }

    public static OnTheFlyMarginalCalculator serializableInstance() {
        return new OnTheFlyMarginalCalculator(BayesPm.serializableInstance(), DataUtils.discreteSerializableInstance());
    }

    @Override
    public void setEvidence(Evidence evidence) {
        if (evidence == null) {
            throw new NullPointerException();
        }
        if (evidence.getProposition().getVariableSource() != this) {
            throw new IllegalArgumentException("Can only take evidence for this particular object; please convert the evidence.");
        }
        this.evidence = evidence;
    }

    public int getNodeIndex(Node node) {
        for (int i = 0; i < this.nodes.length; ++i) {
            if (node != this.nodes[i]) continue;
            return i;
        }
        return -1;
    }

    @Override
    public List<Node> getVariables() {
        LinkedList<Node> variables = new LinkedList<Node>();
        for (int i = 0; i < this.getNumNodes(); ++i) {
            Node node = this.getNode(i);
            variables.add(this.bayesPm.getVariable(node));
        }
        return variables;
    }

    @Override
    public List<String> getVariableNames() {
        LinkedList<String> variableNames = new LinkedList<String>();
        for (int i = 0; i < this.getNumNodes(); ++i) {
            Node node = this.getNode(i);
            variableNames.add(this.bayesPm.getVariable(node).getName());
        }
        return variableNames;
    }

    public int[] getParentDims(int nodeIndex) {
        int[] dims = this.parentDims[nodeIndex];
        int[] copy = new int[dims.length];
        System.arraycopy(dims, 0, copy, 0, dims.length);
        return copy;
    }

    @Override
    public double getMarginal(int variable, int category) {
        if (category >= this.getNumCategories(variable)) {
            throw new IllegalArgumentException();
        }
        return this.getUpdatedMarginalFromModel(variable, category);
    }

    @Override
    public boolean isJointMarginalSupported() {
        return false;
    }

    @Override
    public double getJointMarginal(int[] variables, int[] values) {
        throw new UnsupportedOperationException();
    }

    @Override
    public BayesIm getBayesIm() {
        throw new UnsupportedOperationException();
    }

    @Override
    public double[] calculatePriorMarginals(int nodeIndex) {
        Evidence evidence = this.getEvidence();
        this.setEvidence(Evidence.tautology(evidence.getVariableSource()));
        double[] marginals = new double[evidence.getNumCategories(nodeIndex)];
        for (int i = 0; i < this.getBayesIm().getNumColumns(nodeIndex); ++i) {
            marginals[i] = this.getMarginal(nodeIndex, i);
        }
        this.setEvidence(evidence);
        return marginals;
    }

    @Override
    public double[] calculateUpdatedMarginals(int nodeIndex) {
        double[] marginals = new double[this.evidence.getNumCategories(nodeIndex)];
        for (int i = 0; i < this.getBayesIm().getNumColumns(nodeIndex); ++i) {
            marginals[i] = this.getMarginal(nodeIndex, i);
        }
        return marginals;
    }

    private BayesPm getBayesPm() {
        return this.bayesPm;
    }

    private Evidence getEvidence() {
        return new Evidence(this.evidence.getProposition());
    }

    private int getNumNodes() {
        return this.nodes.length;
    }

    private Node getNode(int nodeIndex) {
        return this.nodes[nodeIndex];
    }

    private int getNumCategories(int nodeIndex) {
        Node node = this.nodes[nodeIndex];
        return this.getBayesPm().getNumCategories(node);
    }

    private int[] getParents(int nodeIndex) {
        int[] nodeParents = this.parents[nodeIndex];
        int[] copy = new int[nodeParents.length];
        System.arraycopy(nodeParents, 0, copy, 0, nodeParents.length);
        return copy;
    }

    private DiscreteProbs getDiscreteProbs() {
        if (this.discreteProbs == null) {
            this.discreteProbs = new IntAveDataSetProbs(this.dataSet);
        }
        return this.discreteProbs;
    }

    private void initialize() {
        this.parents = new int[this.nodes.length][];
        this.parentDims = new int[this.nodes.length][];
        for (int nodeIndex = 0; nodeIndex < this.nodes.length; ++nodeIndex) {
            this.initializeNode(nodeIndex);
        }
    }

    private void initializeNode(int nodeIndex) {
        Node node = this.nodes[nodeIndex];
        Dag graph = this.getBayesPm().getDag();
        ArrayList<Node> parentList = new ArrayList<Node>(graph.getParents(node));
        int[] parentArray = new int[parentList.size()];
        for (int i = 0; i < parentList.size(); ++i) {
            parentArray[i] = this.getNodeIndex((Node)parentList.get(i));
        }
        Arrays.sort(parentArray);
        this.parents[nodeIndex] = parentArray;
        int[] dims = new int[parentArray.length];
        for (int i = 0; i < dims.length; ++i) {
            Node parNode = this.nodes[parentArray[i]];
            dims[i] = this.getBayesPm().getNumCategories(parNode);
        }
        this.parentDims[nodeIndex] = dims;
    }

    private double getUpdatedMarginalFromModel(int variable, int category) {
        Proposition evidence = this.getEvidence().getProposition();
        int[] variableValues = new int[evidence.getNumVariables()];
        for (int i = 0; i < evidence.getNumVariables(); ++i) {
            variableValues[i] = OnTheFlyMarginalCalculator.nextValue(evidence, i, -1);
        }
        variableValues[variableValues.length - 1] = -1;
        double sum = 0.0;
        block1: while (true) {
            for (int i = evidence.getNumVariables() - 1; i >= 0; --i) {
                if (!OnTheFlyMarginalCalculator.hasNextValue(evidence, i, variableValues[i])) continue;
                variableValues[i] = OnTheFlyMarginalCalculator.nextValue(evidence, i, variableValues[i]);
                for (int j = i + 1; j < evidence.getNumVariables(); ++j) {
                    if (!OnTheFlyMarginalCalculator.hasNextValue(evidence, j, -1)) break block1;
                    variableValues[j] = OnTheFlyMarginalCalculator.nextValue(evidence, j, -1);
                }
                double product = 1.0;
                for (int m = 0; m < this.getNumNodes(); ++m) {
                    int[] parents;
                    Proposition assertion = Proposition.tautology(this);
                    assertion.setCategory(variable, category);
                    Proposition condition = new Proposition(evidence);
                    for (int parent : parents = this.getParents(m)) {
                        condition.disallowComplement(parent, variableValues[parent]);
                    }
                    if (!condition.existsCombination()) continue;
                    product *= this.getDiscreteProbs().getConditionalProb(assertion, condition);
                }
                sum += product;
            }
            break;
        }
        return sum;
    }

    private static boolean hasNextValue(Proposition proposition, int variable, int currentIndex) {
        return OnTheFlyMarginalCalculator.nextValue(proposition, variable, currentIndex) != -1;
    }

    private static int nextValue(Proposition proposition, int variable, int currentIndex) {
        for (int i = currentIndex + 1; i < proposition.getNumCategories(variable); ++i) {
            if (!proposition.isAllowed(variable, i)) continue;
            return i;
        }
        return -1;
    }

    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
        if (this.bayesPm == null) {
            throw new NullPointerException();
        }
        if (this.nodes == null) {
            throw new NullPointerException();
        }
        if (this.dataSet == null) {
            throw new NullPointerException();
        }
        if (this.evidence == null) {
            throw new NullPointerException();
        }
        if (this.parents == null) {
            throw new NullPointerException();
        }
        if (this.parentDims == null) {
            throw new NullPointerException();
        }
    }
}

