/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.bayes;

import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.Evidence;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.bayes.UpdatedBayesIm;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.TetradSerializable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;

/**
 * Computes marginal probabilities {@code P(variable = category)} for the variables of a
 * Bayes IM given some evidence.
 *
 * <p>For a variable {@code v} and category {@code c}, the marginal is the sum over the rows
 * of {@code v}'s CPT of {@code P_updated(v = c | row) * P(row)}, where {@code P_updated} is
 * read from an {@link UpdatedBayesIm} built from the IM and the evidence, and {@code P(row)}
 * is the product of the marginals of {@code v}'s parents at that row's parent values. Terms
 * that are NaN are skipped. Parent marginals are obtained recursively from this same
 * calculator, and every computed marginal (including NaN results) is cached, so each
 * (variable, category) pair is computed at most once per instance.
 *
 * <p>The cache is filled lazily without synchronization, so an instance should not be
 * shared between threads without external locking.
 *
 * <p>NOTE(review): multiplying the parent marginals treats the parents of a variable as
 * mutually independent. When two parents share an ancestor this is only an approximation;
 * an exact computation would condition each parent on the values of the earlier parents.
 * Confirm whether that approximation is acceptable for callers.
 */
public final class CptInvariantMarginalCalculator
implements TetradSerializable {
    private static final long serialVersionUID = 23L;

    /** Sentinel stored in {@link #storedMarginals} for entries not yet computed. */
    private static final double NOT_COMPUTED = -99.0;

    private final BayesIm bayesIm;
    private final Evidence evidence;
    private final UpdatedBayesIm updatedBayesIm;

    /**
     * Cache of marginals indexed as {@code [variable][category]}; entries equal to
     * {@link #NOT_COMPUTED} have not been computed yet.
     */
    private final double[][] storedMarginals;

    /**
     * Creates a calculator for the given Bayes IM and evidence.
     *
     * @param bayesIm  the Bayes IM whose marginals are computed; must not be null.
     * @param evidence the evidence to update on; must not be null and must be compatible
     *                 with {@code bayesIm}.
     * @throws NullPointerException     if either argument is null.
     * @throws IllegalArgumentException if the evidence is incompatible with the IM.
     */
    public CptInvariantMarginalCalculator(BayesIm bayesIm, Evidence evidence) {
        if (bayesIm == null) {
            throw new NullPointerException("bayesIm");
        }
        if (evidence == null) {
            throw new NullPointerException("evidence");
        }
        if (evidence.isIncompatibleWith(bayesIm)) {
            throw new IllegalArgumentException("The variables for the given Bayes IM and evidence must be compatible.");
        }
        this.bayesIm = bayesIm;
        this.evidence = evidence;
        this.updatedBayesIm = new UpdatedBayesIm(bayesIm, evidence);
        this.storedMarginals = this.initStoredMarginals();
    }

    /**
     * Returns a simple exemplar of this class, built from
     * {@code MlBayesIm.serializableInstance()} and a tautological evidence.
     *
     * @return a calculator instance suitable for serialization tests.
     */
    public static CptInvariantMarginalCalculator serializableInstance() {
        MlBayesIm bayesIm = MlBayesIm.serializableInstance();
        Evidence evidence = Evidence.tautology(bayesIm);
        return new CptInvariantMarginalCalculator(bayesIm, evidence);
    }

    /**
     * Returns the marginal probability that {@code variable} takes {@code category} given
     * the evidence, computing it on first request and caching the result.
     *
     * @param variable the index of the variable in the Bayes IM.
     * @param category the index of the category (CPT column) of that variable.
     * @return the marginal probability, or {@code Double.NaN} if no CPT row of the variable
     *         contributes a non-NaN term.
     */
    public double getMarginal(int variable, int category) {
        double cached = this.storedMarginals[variable][category];
        // NaN != NOT_COMPUTED, so cached NaN results are returned without recomputation.
        if (cached != NOT_COMPUTED) {
            return cached;
        }
        double marginal = 0.0;
        boolean foundANumber = false;
        for (int row = 0; row < this.bayesIm.getNumRows(variable); ++row) {
            double probability = this.updatedBayesIm.getProbability(variable, row, category);
            if (Double.isNaN(probability)) {
                continue;
            }
            // Only computed when the conditional probability is usable, as before.
            double probabilityOfRow = this.getProbabilityOfRow(variable, row);
            if (Double.isNaN(probabilityOfRow)) {
                continue;
            }
            marginal += probability * probabilityOfRow;
            foundANumber = true;
        }
        if (!foundANumber) {
            marginal = Double.NaN;
        }
        this.storedMarginals[variable][category] = marginal;
        return marginal;
    }

    /**
     * Returns the updated Bayes IM used to obtain the conditional probabilities.
     *
     * @return the {@link UpdatedBayesIm} built from this calculator's IM and evidence.
     */
    public UpdatedBayesIm getUpdatedBayesIm() {
        return this.updatedBayesIm;
    }

    /**
     * Allocates the marginal cache, one array per node sized to that node's number of
     * CPT columns, with every entry set to {@link #NOT_COMPUTED}.
     */
    private double[][] initStoredMarginals() {
        int numNodes = this.bayesIm.getNumNodes();
        double[][] marginals = new double[numNodes][];
        for (int i = 0; i < numNodes; ++i) {
            marginals[i] = new double[this.bayesIm.getNumColumns(i)];
            Arrays.fill(marginals[i], NOT_COMPUTED);
        }
        return marginals;
    }

    /**
     * Returns the probability of the given CPT row of {@code variable}. This is the product
     * of the cached marginals of the variable's parents at the row's parent values, with
     * NaN parent marginals skipped (treated as a factor of 1).
     *
     * <p>An earlier version rebuilt a fresh calculator, on an unmodified copy of the same
     * evidence, whenever parents shared ancestors. That produced the same value as
     * {@link #getMarginal} but recomputed it without caching. Here the cached marginal is
     * always used.
     */
    private double getProbabilityOfRow(int variable, int row) {
        int[] parents = this.bayesIm.getParents(variable);
        int[] parentValues = this.bayesIm.getParentValues(variable, row);
        double probabilityOfRow = 1.0;
        for (int index = 0; index < parents.length; ++index) {
            double marginal = this.getMarginal(parents[index], parentValues[index]);
            if (!Double.isNaN(marginal)) {
                probabilityOfRow *= marginal;
            }
        }
        return probabilityOfRow;
    }

    /**
     * Restores the default serialized fields and rejects streams in which any required
     * field is missing.
     *
     * @throws NullPointerException if a required field is null after deserialization.
     */
    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
        if (this.bayesIm == null) {
            throw new NullPointerException("bayesIm");
        }
        if (this.evidence == null) {
            throw new NullPointerException("evidence");
        }
        if (this.storedMarginals == null) {
            throw new NullPointerException("storedMarginals");
        }
        if (this.updatedBayesIm == null) {
            throw new NullPointerException("updatedBayesIm");
        }
    }
}

