/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.bayes;

import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.bayes.StoredCellProbsObs;
import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.VerticalDoubleDataBox;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.Paths;
import edu.cmu.tetrad.graph.TimeLagGraph;
import edu.cmu.tetrad.util.RandomUtil;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

/**
 * A Bayes instantiated model that stores one conditional probability table per node of a
 * {@link BayesPm}'s DAG and, in addition, a joint probability table ({@link StoredCellProbsObs})
 * over the measured variables only.
 *
 * <p>Two auxiliary {@link MlBayesIm}s are kept alongside the tables. The first is built over the full
 * PM and is used as a source of random conditional tables by {@link #createRandomCellTable()}. The
 * second is built over the DAG with its latent nodes removed and is exposed through
 * {@link #getBayesImObs()}.
 */
public final class MlBayesImObs
implements BayesIm {
    static final long serialVersionUID = 23L;

    /** Absolute tolerance used by {@link #equals(Object)} when comparing probabilities. */
    private static final double ALLOWABLE_DIFFERENCE = 1.0E-10;

    /** Initialization method: the joint table is cleared unless it can be copied from an old IM. */
    private static final int MANUAL = 0;

    /** Initialization method: the joint table is randomized unless it can be copied from an old IM. */
    private static final int RANDOM = 1;

    /** Largest number of rows a single conditional probability table is allowed to have. */
    private static final int MAX_TABLE_ROWS = 1000000;

    /** The parametric model this IM instantiates. */
    private final BayesPm bayesPm;

    /** Nodes of the model; a node's position in this array is its node index. */
    private final Node[] nodes;

    /** parents[i] holds the (sorted) node indices of the parents of node i. */
    private int[][] parents;

    /** parentDims[i][j] is the number of categories of the j-th parent of node i. */
    private int[][] parentDims;

    /** probs[i][row][col] is P(node i = col | parent combination encoded by row). */
    private double[][][] probs;

    /** Joint probability table over the measured variables. */
    private StoredCellProbsObs jpd;

    /** IM over the full PM whose tables are randomized to build a random joint table. */
    private BayesIm bayesImRandomize;

    /** IM over the DAG with latent nodes removed. */
    private BayesIm bayesImObs;

    /**
     * Creates an IM for the given PM using the manual initialization method.
     *
     * @param bayesPm the parametric model; must not be null
     * @throws NullPointerException if {@code bayesPm} is null
     */
    public MlBayesImObs(BayesPm bayesPm) throws IllegalArgumentException {
        this(bayesPm, MANUAL);
    }

    /**
     * Creates an IM for the given PM.
     *
     * @param bayesPm              the parametric model; must not be null
     * @param initializationMethod 0 (manual: clear the joint table) or 1 (random: randomize it)
     * @throws NullPointerException     if {@code bayesPm} is null
     * @throws IllegalArgumentException if {@code initializationMethod} is not recognized
     */
    public MlBayesImObs(BayesPm bayesPm, int initializationMethod) throws IllegalArgumentException {
        this(bayesPm, null, initializationMethod);
    }

    /**
     * Creates an IM for the given PM, copying the joint table from {@code oldBayesIm} when possible.
     *
     * @param bayesPm              the parametric model; must not be null (a copy is stored)
     * @param oldBayesIm           an IM to copy from, or null
     * @param initializationMethod 0 (manual) or 1 (random); used when nothing can be copied
     * @throws NullPointerException     if {@code bayesPm} is null
     * @throws IllegalArgumentException if {@code initializationMethod} is not recognized
     */
    public MlBayesImObs(BayesPm bayesPm, BayesIm oldBayesIm, int initializationMethod) throws IllegalArgumentException {
        if (bayesPm == null) {
            throw new NullPointerException("BayesPm must not be null.");
        }
        this.bayesPm = new BayesPm(bayesPm);
        this.nodes = bayesPm.getDag().getNodes().toArray(new Node[0]);
        this.initialize(oldBayesIm, initializationMethod);
    }

    /**
     * Creates an IM that shares the PM and node order of {@code bayesIm} and copies its joint table
     * when {@code bayesIm} is an {@link MlBayesIm} or {@link MlBayesImObs}.
     *
     * @param bayesIm the IM to copy from; must not be null
     * @throws NullPointerException if {@code bayesIm} is null
     */
    public MlBayesImObs(BayesIm bayesIm) throws IllegalArgumentException {
        if (bayesIm == null) {
            throw new NullPointerException("BayesIm must not be null.");
        }
        this.bayesPm = bayesIm.getBayesPm();
        this.nodes = new Node[bayesIm.getNumNodes()];
        for (int i = 0; i < this.nodes.length; ++i) {
            this.nodes[i] = bayesIm.getNode(i);
        }
        this.initialize(bayesIm, MANUAL);
    }

    /**
     * Returns a simple instance, used by serialization tests.
     *
     * @return an IM over {@link BayesPm#serializableInstance()}
     */
    public static MlBayesImObs serializableInstance() {
        return new MlBayesImObs(BayesPm.serializableInstance());
    }

    @Override
    public BayesPm getBayesPm() {
        return this.bayesPm;
    }

    @Override
    public Graph getDag() {
        return this.bayesPm.getDag();
    }

    @Override
    public int getNumNodes() {
        return this.nodes.length;
    }

    @Override
    public Node getNode(int nodeIndex) {
        return this.nodes[nodeIndex];
    }

    /** Looks the node up by name in the PM's DAG. */
    @Override
    public Node getNode(String name) {
        return this.getDag().getNode(name);
    }

    /**
     * Returns the index of {@code node}, compared by identity, or -1 if it is not one of this IM's
     * nodes.
     */
    @Override
    public int getNodeIndex(Node node) {
        for (int i = 0; i < this.nodes.length; ++i) {
            if (node == this.nodes[i]) {
                return i;
            }
        }
        return -1;
    }

    /** Returns the PM's variable for each node, in node-index order. */
    @Override
    public List<Node> getVariables() {
        List<Node> variables = new ArrayList<Node>(this.nodes.length);
        for (Node node : this.nodes) {
            variables.add(this.bayesPm.getVariable(node));
        }
        return variables;
    }

    @Override
    public List<Node> getMeasuredNodes() {
        return this.bayesPm.getMeasuredNodes();
    }

    /** Returns the names of the PM's variables, in node-index order. */
    @Override
    public List<String> getVariableNames() {
        List<String> variableNames = new ArrayList<String>(this.nodes.length);
        for (Node node : this.nodes) {
            variableNames.add(this.bayesPm.getVariable(node).getName());
        }
        return variableNames;
    }

    /** Returns the number of categories (table columns) of the given node. */
    @Override
    public int getNumColumns(int nodeIndex) {
        return this.probs[nodeIndex][0].length;
    }

    /** Returns the number of parent-value combinations (table rows) of the given node. */
    @Override
    public int getNumRows(int nodeIndex) {
        return this.probs[nodeIndex].length;
    }

    @Override
    public int getNumParents(int nodeIndex) {
        return this.parents[nodeIndex].length;
    }

    @Override
    public int getParent(int nodeIndex, int parentIndex) {
        return this.parents[nodeIndex][parentIndex];
    }

    @Override
    public int getParentDim(int nodeIndex, int parentIndex) {
        return this.parentDims[nodeIndex][parentIndex];
    }

    /** Returns a copy of the parent category counts of the given node. */
    @Override
    public int[] getParentDims(int nodeIndex) {
        return this.parentDims[nodeIndex].clone();
    }

    /** Returns a copy of the parent node indices of the given node. */
    @Override
    public int[] getParents(int nodeIndex) {
        return this.parents[nodeIndex].clone();
    }

    /**
     * Decodes a row index into the parent values it represents. The last parent varies fastest,
     * which is the inverse of {@link #getRowIndex(int, int[])}.
     */
    @Override
    public int[] getParentValues(int nodeIndex, int rowIndex) {
        int[] dims = this.parentDims[nodeIndex];
        int[] values = new int[dims.length];
        for (int i = dims.length - 1; i >= 0; --i) {
            values[i] = rowIndex % dims[i];
            rowIndex /= dims[i];
        }
        return values;
    }

    @Override
    public int getParentValue(int nodeIndex, int rowIndex, int colIndex) {
        return this.getParentValues(nodeIndex, rowIndex)[colIndex];
    }

    @Override
    public double getProbability(int nodeIndex, int rowIndex, int colIndex) {
        return this.probs[nodeIndex][rowIndex][colIndex];
    }

    /**
     * Encodes parent values as a row index, in mixed radix with the last parent varying fastest.
     * Only the first {@code getNumParents(nodeIndex)} entries of {@code values} are read.
     */
    @Override
    public int getRowIndex(int nodeIndex, int[] values) {
        int[] dims = this.parentDims[nodeIndex];
        int rowIndex = 0;
        for (int i = 0; i < dims.length; ++i) {
            rowIndex = rowIndex * dims[i] + values[i];
        }
        return rowIndex;
    }

    @Override
    public void normalizeAll() {
        for (int nodeIndex = 0; nodeIndex < this.nodes.length; ++nodeIndex) {
            this.normalizeNode(nodeIndex);
        }
    }

    @Override
    public void normalizeNode(int nodeIndex) {
        for (int rowIndex = 0; rowIndex < this.getNumRows(nodeIndex); ++rowIndex) {
            this.normalizeRow(nodeIndex, rowIndex);
        }
    }

    /**
     * Rescales a row so that it sums to one. A row summing to exactly zero becomes uniform; NaN
     * entries make the total NaN and so propagate.
     */
    @Override
    public void normalizeRow(int nodeIndex, int rowIndex) {
        int numColumns = this.getNumColumns(nodeIndex);
        double total = 0.0;
        for (int colIndex = 0; colIndex < numColumns; ++colIndex) {
            total += this.getProbability(nodeIndex, rowIndex, colIndex);
        }
        for (int colIndex = 0; colIndex < numColumns; ++colIndex) {
            double prob = total != 0.0
                    ? this.getProbability(nodeIndex, rowIndex, colIndex) / total
                    : 1.0 / (double) numColumns;
            this.setProbability(nodeIndex, rowIndex, colIndex, prob);
        }
    }

    /**
     * Copies {@code probMatrix} row by row into the table of the given node. The values are not
     * range-checked.
     */
    @Override
    public void setProbability(int nodeIndex, double[][] probMatrix) {
        for (int i = 0; i < probMatrix.length; ++i) {
            System.arraycopy(probMatrix[i], 0, this.probs[nodeIndex][i], 0, probMatrix[i].length);
        }
    }

    /**
     * Sets one table entry.
     *
     * @throws IllegalArgumentException if the column is out of range, or the value is neither NaN
     *                                  nor in [0, 1]
     */
    @Override
    public void setProbability(int nodeIndex, int rowIndex, int colIndex, double value) {
        if (colIndex >= this.getNumColumns(nodeIndex)) {
            throw new IllegalArgumentException("Column out of range: " + colIndex + " >= " + this.getNumColumns(nodeIndex));
        }
        if (!(0.0 <= value && value <= 1.0 || Double.isNaN(value))) {
            throw new IllegalArgumentException("Probability value must be between 0.0 and 1.0 or Double.NaN.");
        }
        this.probs[nodeIndex][rowIndex][colIndex] = value;
    }

    /**
     * Returns the index in {@code otherBayesIm} of the node with the same name as node
     * {@code nodeIndex} here, as reported by {@code otherBayesIm.getNodeIndex}.
     */
    @Override
    public int getCorrespondingNodeIndex(int nodeIndex, BayesIm otherBayesIm) {
        String nodeName = this.getNode(nodeIndex).getName();
        Node oldNode = otherBayesIm.getNode(nodeName);
        return otherBayesIm.getNodeIndex(oldNode);
    }

    /** Sets every entry of the row to NaN. */
    @Override
    public void clearRow(int nodeIndex, int rowIndex) {
        for (int colIndex = 0; colIndex < this.getNumColumns(nodeIndex); ++colIndex) {
            this.setProbability(nodeIndex, rowIndex, colIndex, Double.NaN);
        }
    }

    /** Replaces the row with freshly drawn random weights that sum to one. */
    @Override
    public void randomizeRow(int nodeIndex, int rowIndex) {
        this.probs[nodeIndex][rowIndex] = MlBayesImObs.getRandomWeights(this.getNumColumns(nodeIndex));
    }

    /** Randomizes exactly those rows that contain at least one NaN. */
    @Override
    public void randomizeIncompleteRows(int nodeIndex) {
        for (int rowIndex = 0; rowIndex < this.getNumRows(nodeIndex); ++rowIndex) {
            if (this.isIncomplete(nodeIndex, rowIndex)) {
                this.randomizeRow(nodeIndex, rowIndex);
            }
        }
    }

    @Override
    public void randomizeTable(int nodeIndex) {
        for (int rowIndex = 0; rowIndex < this.getNumRows(nodeIndex); ++rowIndex) {
            this.randomizeRow(nodeIndex, rowIndex);
        }
    }

    @Override
    public void clearTable(int nodeIndex) {
        for (int rowIndex = 0; rowIndex < this.getNumRows(nodeIndex); ++rowIndex) {
            this.clearRow(nodeIndex, rowIndex);
        }
    }

    /** Returns true if any entry of the row is NaN. */
    @Override
    public boolean isIncomplete(int nodeIndex, int rowIndex) {
        for (int colIndex = 0; colIndex < this.getNumColumns(nodeIndex); ++colIndex) {
            if (Double.isNaN(this.getProbability(nodeIndex, rowIndex, colIndex))) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if any row of the node's table contains a NaN. */
    @Override
    public boolean isIncomplete(int nodeIndex) {
        for (int rowIndex = 0; rowIndex < this.getNumRows(nodeIndex); ++rowIndex) {
            if (this.isIncomplete(nodeIndex, rowIndex)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Draws {@code sampleSize} rows from the model. Time-lag models are routed to the time-series
     * simulator, which ignores {@code latentDataSaved}.
     */
    @Override
    public DataSet simulateData(int sampleSize, boolean latentDataSaved) {
        if (this.getBayesPm().getDag().isTimeLagModel()) {
            return this.simulateTimeSeries(sampleSize);
        }
        return this.simulateDataHelper(sampleSize, latentDataSaved);
    }

    /** Overwrites the values of {@code dataSet} with simulated data and returns it. */
    @Override
    public DataSet simulateData(DataSet dataSet, boolean latentDataSaved) {
        return this.simulateDataHelper(dataSet, latentDataSaved);
    }

    /**
     * Simulates data for the lag-0 (contemporaneous) slice of a time-lag model, writing one column
     * per lag-0 node.
     *
     * <p>NOTE(review): values of lagged parents are not carried over between rows. A lagged parent
     * always reads 0 here, so this is not a full time-series simulation.
     */
    private DataSet simulateTimeSeries(int sampleSize) {
        TimeLagGraph timeSeriesGraph = this.getBayesPm().getDag().getTimeLagGraph();
        List<Node> lag0Nodes = timeSeriesGraph.getLag0Nodes();
        List<Node> variables = new ArrayList<Node>(lag0Nodes.size());
        for (Node node : lag0Nodes) {
            variables.add(new DiscreteVariable(timeSeriesGraph.getNodeId(node).getName()));
        }
        BoxDataSet fullData = new BoxDataSet(new VerticalDoubleDataBox(sampleSize, variables.size()), variables);
        Graph contemporaneousDag = timeSeriesGraph.subgraph(lag0Nodes);
        Paths paths = contemporaneousDag.paths();
        List<Node> tierOrdering = paths.validOrder(contemporaneousDag.getNodes(), true);
        int[] tiers = this.toNodeIndices(tierOrdering);

        // Output column for each node in tier order; -1 if the node is not a lag-0 column.
        int[] columns = new int[tierOrdering.size()];
        for (int t = 0; t < columns.length; ++t) {
            columns[t] = lag0Nodes.indexOf(tierOrdering.get(t));
        }

        // Sized by the full node count: a node may have more (lagged) parents than there are lag-0 nodes.
        int[] combination = new int[this.nodes.length];
        for (int i = 0; i < sampleSize; ++i) {
            int[] point = new int[this.nodes.length];
            for (int t = 0; t < tiers.length; ++t) {
                int nodeIndex = tiers[t];
                point[nodeIndex] = this.sampleValue(nodeIndex, point, combination);
                if (columns[t] >= 0) {
                    fullData.setInt(i, columns[t], point[nodeIndex]);
                }
            }
        }
        return fullData;
    }

    /** Builds a fresh data set of {@code sampleSize} simulated rows. */
    private DataSet simulateDataHelper(int sampleSize, boolean latentDataSaved) {
        int[] map = new int[this.nodes.length];
        List<Node> variables = this.createOutputVariables(latentDataSaved, map);
        BoxDataSet dataSet = new BoxDataSet(new VerticalDoubleDataBox(sampleSize, variables.size()), variables);
        this.constructSample(sampleSize, variables.size(), dataSet, map);
        return dataSet;
    }

    /**
     * Replaces the leading variables of {@code dataSet} with this model's discrete variables and
     * overwrites those columns with simulated values.
     *
     * @throws IllegalArgumentException if the column count differs from the number of nodes
     */
    private DataSet simulateDataHelper(DataSet dataSet, boolean latentDataSaved) {
        if (dataSet.getNumColumns() != this.nodes.length) {
            throw new IllegalArgumentException("When rewriting the old data set, number of variables in data set must equal number of variables in Bayes net.");
        }
        int sampleSize = dataSet.getNumRows();
        int[] map = new int[this.nodes.length];
        List<Node> variables = this.createOutputVariables(latentDataSaved, map);
        for (int i = 0; i < variables.size(); ++i) {
            dataSet.changeVariable(dataSet.getVariable(i), variables.get(i));
        }
        this.constructSample(sampleSize, variables.size(), dataSet, map);
        return dataSet;
    }

    /**
     * Creates one discrete variable, with the PM's categories, per node that will be written out.
     *
     * @param latentDataSaved whether non-measured nodes get a column too
     * @param map             filled so that {@code map[c]} is the node index written to column c
     * @return the output variables in column order
     */
    private List<Node> createOutputVariables(boolean latentDataSaved, int[] map) {
        List<Node> variables = new ArrayList<Node>();
        for (int j = 0; j < this.nodes.length; ++j) {
            Node node = this.nodes[j];
            if (!latentDataSaved && node.getNodeType() != NodeType.MEASURED) {
                continue;
            }
            int numCategories = this.bayesPm.getNumCategories(node);
            List<String> categories = new ArrayList<String>(numCategories);
            for (int k = 0; k < numCategories; ++k) {
                categories.add(this.bayesPm.getCategory(node, k));
            }
            map[variables.size()] = j;
            variables.add(new DiscreteVariable(node.getName(), categories));
        }
        return variables;
    }

    /**
     * Forward-samples {@code sampleSize} points in a valid causal order and writes the first
     * {@code numMeasured} mapped node values of each point into {@code dataSet}.
     */
    private void constructSample(int sampleSize, int numMeasured, DataSet dataSet, int[] map) {
        Dag dag = new Dag(this.getBayesPm().getDag());
        Paths paths = dag.paths();
        int[] tiers = this.toNodeIndices(paths.validOrder(dag.getNodes(), true));
        int[] combination = new int[this.nodes.length];
        for (int i = 0; i < sampleSize; ++i) {
            int[] point = new int[this.nodes.length];
            for (int nodeIndex : tiers) {
                point[nodeIndex] = this.sampleValue(nodeIndex, point, combination);
            }
            for (int j = 0; j < numMeasured; ++j) {
                dataSet.setInt(i, j, point[map[j]]);
            }
        }
    }

    /** Maps each node of {@code ordered} to its index in this IM (identity lookup). */
    private int[] toNodeIndices(List<Node> ordered) {
        int[] indices = new int[ordered.size()];
        for (int i = 0; i < indices.length; ++i) {
            indices[i] = this.getNodeIndex(ordered.get(i));
        }
        return indices;
    }

    /**
     * Draws a category for one node given its parents' values already present in {@code point}.
     *
     * @param combination scratch buffer, at least as long as the node's parent count
     * @return the sampled category index
     * @throws IllegalStateException if a NaN is encountered before the cutoff is reached
     */
    private int sampleValue(int nodeIndex, int[] point, int[] combination) {
        double cutoff = RandomUtil.getInstance().nextDouble();
        for (int k = 0; k < this.getNumParents(nodeIndex); ++k) {
            combination[k] = point[this.getParent(nodeIndex, k)];
        }
        int rowIndex = this.getRowIndex(nodeIndex, combination);
        int numColumns = this.getNumColumns(nodeIndex);
        double sum = 0.0;
        for (int k = 0; k < numColumns; ++k) {
            double probability = this.getProbability(nodeIndex, rowIndex, k);
            if (Double.isNaN(probability)) {
                throw new IllegalStateException("Some probability values in the BayesIm are not filled in; cannot simulate data.");
            }
            sum += probability;
            if (sum >= cutoff) {
                return k;
            }
        }
        // Rounding can leave the cumulative sum just below the cutoff; the cutoff then belongs
        // to the last category, not the first.
        return numColumns - 1;
    }

    /**
     * Two IMs are equal when they have the same number of nodes and every node here has a
     * same-named node in the other IM with an equally shaped table. Corresponding entries must be
     * both NaN or differ by at most {@link #ALLOWABLE_DIFFERENCE}.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof BayesIm)) {
            return false;
        }
        BayesIm otherIm = (BayesIm) o;
        if (this.getNumNodes() != otherIm.getNumNodes()) {
            return false;
        }
        for (int i = 0; i < this.getNumNodes(); ++i) {
            int otherIndex = this.getCorrespondingNodeIndex(i, otherIm);
            if (otherIndex == -1) {
                return false;
            }
            int numRows = this.getNumRows(i);
            int numColumns = this.getNumColumns(i);
            if (numColumns != otherIm.getNumColumns(otherIndex) || numRows != otherIm.getNumRows(otherIndex)) {
                return false;
            }
            for (int j = 0; j < numRows; ++j) {
                for (int k = 0; k < numColumns; ++k) {
                    if (!MlBayesImObs.probabilitiesMatch(this.getProbability(i, j, k), otherIm.getProbability(otherIndex, j, k))) {
                        return false;
                    }
                }
            }
        }
        return true;
    }

    /** True if both values are NaN, or neither is and they differ by at most the tolerance. */
    private static boolean probabilitiesMatch(double a, double b) {
        boolean aNaN = Double.isNaN(a);
        boolean bNaN = Double.isNaN(b);
        if (aNaN || bNaN) {
            return aNaN && bNaN;
        }
        return FastMath.abs(a - b) <= ALLOWABLE_DIFFERENCE;
    }

    /**
     * Hashes only the table shapes, summed so that node order does not matter. This keeps the hash
     * consistent with the tolerance-based, name-matched {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        int hash = this.getNumNodes();
        for (int i = 0; i < this.getNumNodes(); ++i) {
            hash += 31 * this.getNumRows(i) + this.getNumColumns(i);
        }
        return hash;
    }

    @Override
    public String toString() {
        return "MlBayesImObs\n";
    }

    /** Returns the IM built over the DAG with latent nodes removed. */
    public BayesIm getBayesImObs() {
        return this.bayesImObs;
    }

    /** Returns the joint probability table over the measured variables. */
    public StoredCellProbsObs getJPD() {
        return this.jpd;
    }

    /** Returns the number of cells (rows) in the joint probability table. */
    public int getNumRows() {
        return this.jpd.getNumRows();
    }

    /** Returns the measured-variable values that identify the given joint-table row. */
    public int[] getRowValues(int rowIndex) {
        return this.jpd.getVariableValues(rowIndex);
    }

    /** Returns the joint probability stored for the given joint-table row. */
    public double getProbability(int rowIndex) {
        return this.jpd.getCellProb(this.getRowValues(rowIndex));
    }

    /**
     * Sets the joint probability of the given joint-table row.
     *
     * @throws IllegalArgumentException if the value is neither NaN nor in [0, 1]
     */
    public void setProbability(int rowIndex, double value) {
        if (!(0.0 <= value && value <= 1.0 || Double.isNaN(value))) {
            throw new IllegalArgumentException("Probability value must be between 0.0 and 1.0 or Double.NaN.");
        }
        this.jpd.setCellProbability(this.getRowValues(rowIndex), value);
    }

    /** Randomizes every table of the full-PM helper IM, then rebuilds the joint table from it. */
    public void createRandomCellTable() {
        for (int nodeIndex = 0; nodeIndex < this.nodes.length; ++nodeIndex) {
            this.bayesImRandomize.randomizeTable(nodeIndex);
        }
        this.jpd.createCellTable((MlBayesIm) this.bayesImRandomize);
    }

    /**
     * Allocates the conditional tables, builds the helper IMs and initializes the joint table.
     *
     * <p>The joint table is copied from {@code oldBayesIm} when it is an {@link MlBayesIm}, or an
     * {@link MlBayesImObs} with an equal PM. Otherwise it is randomized ({@link #RANDOM}) or cleared
     * ({@link #MANUAL}).
     *
     * @throws IllegalArgumentException if {@code initializationMethod} is not recognized
     */
    private void initialize(BayesIm oldBayesIm, int initializationMethod) {
        if (initializationMethod != MANUAL && initializationMethod != RANDOM) {
            throw new IllegalArgumentException("Unrecognized state.");
        }
        this.parents = new int[this.nodes.length][];
        this.parentDims = new int[this.nodes.length][];
        this.probs = new double[this.nodes.length][][];
        for (int nodeIndex = 0; nodeIndex < this.nodes.length; ++nodeIndex) {
            this.initializeNode(nodeIndex);
        }
        this.bayesImRandomize = new MlBayesIm(this.bayesPm);

        Dag dag = new Dag(this.bayesPm.getDag());
        for (Node node : this.nodes) {
            if (node.getNodeType() == NodeType.LATENT) {
                dag.removeNode(node);
            }
        }
        this.bayesImObs = new MlBayesIm(new BayesPm(dag, this.bayesPm));

        ArrayList<Node> obsNodes = new ArrayList<Node>();
        for (Node node : this.nodes) {
            Node variable = this.bayesPm.getVariable(node);
            if (variable.getNodeType() == NodeType.MEASURED) {
                obsNodes.add(variable);
            }
        }
        this.jpd = new StoredCellProbsObs(obsNodes);

        if (oldBayesIm instanceof MlBayesIm) {
            this.jpd.createCellTable((MlBayesIm) oldBayesIm);
        } else if (oldBayesIm instanceof MlBayesImObs && this.bayesPm.equals(oldBayesIm.getBayesPm())) {
            this.jpd.createCellTable((MlBayesImObs) oldBayesIm);
        } else if (initializationMethod == RANDOM) {
            this.createRandomCellTable();
        } else {
            this.jpd.clearCellTable();
        }
    }

    /**
     * Records the sorted parent indices and parent category counts of a node, and allocates its
     * table.
     *
     * @throws IllegalArgumentException if the table would have more than 1,000,000 rows
     */
    private void initializeNode(int nodeIndex) {
        Node node = this.nodes[nodeIndex];
        List<Node> parentList = this.getBayesPm().getDag().getParents(node);
        int[] parentArray = new int[parentList.size()];
        for (int i = 0; i < parentArray.length; ++i) {
            parentArray[i] = this.getNodeIndex(parentList.get(i));
        }
        Arrays.sort(parentArray);
        this.parents[nodeIndex] = parentArray;

        int[] dims = new int[parentArray.length];
        // A long product, checked after every factor, so neither the limit nor int range can be exceeded.
        long numRows = 1;
        for (int i = 0; i < dims.length; ++i) {
            dims[i] = this.getBayesPm().getNumCategories(this.nodes[parentArray[i]]);
            numRows *= dims[i];
            if (numRows > MAX_TABLE_ROWS) {
                throw new IllegalArgumentException("The number of rows in the conditional probability table for " + this.nodes[nodeIndex] + " is greater than 1,000,000 and cannot be represented.");
            }
        }
        this.parentDims[nodeIndex] = dims;
        this.probs[nodeIndex] = new double[(int) numRows][this.getBayesPm().getNumCategories(node)];
    }

    /** Returns {@code size} uniform random weights normalized to sum to one. */
    private static double[] getRandomWeights(int size) {
        assert (size >= 0);
        double[] row = new double[size];
        // One cell index is drawn and deliberately unused. Earlier versions applied a zero "bias" to
        // it, and keeping the draw leaves seeded random streams unchanged.
        RandomUtil.getInstance().nextInt(size);
        double sum = 0.0;
        for (int i = 0; i < size; ++i) {
            row[i] = RandomUtil.getInstance().nextDouble();
            sum += row[i];
        }
        for (int i = 0; i < size; ++i) {
            row[i] /= sum;
        }
        return row;
    }

    /** Restores default-serialized state and rejects streams missing required fields. */
    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
        if (this.bayesPm == null) {
            throw new NullPointerException();
        }
        if (this.nodes == null) {
            throw new NullPointerException();
        }
        if (this.parents == null) {
            throw new NullPointerException();
        }
        if (this.parentDims == null) {
            throw new NullPointerException();
        }
        if (this.probs == null) {
            throw new NullPointerException();
        }
    }
}

