/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.bayes;

import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.RandomUtil;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

/**
 * A Bayes instantiated model (IM) built over a {@link BayesPm}.
 *
 * <p>For every node this class stores:
 * <ul>
 *   <li>the indices of the node's parents, sorted ascending;</li>
 *   <li>the number of categories of each parent;</li>
 *   <li>a conditional probability table laid out as {@code probs[node][row][column]}.</li>
 * </ul>
 * A row encodes one combination of parent values in mixed radix, with the last parent varying
 * fastest (see {@link #getRowIndex(int, int[])}). A column is one category of the node.
 * Probabilities that have not been filled in are represented by {@link Double#NaN}.
 */
public final class MlBayesIm
implements BayesIm {
    static final long serialVersionUID = 23L;
    /** Absolute tolerance used by {@link #equals(Object)} when comparing two probabilities. */
    private static final double ALLOWABLE_DIFFERENCE = 1.0E-10;
    /** Largest number of rows a single conditional probability table may have. */
    private static final int MAX_TABLE_ROWS = 1000000;
    /** Initialization method: every probability starts out unknown ({@link Double#NaN}). */
    public static final int MANUAL = 0;
    /** Initialization method: every row starts out as random weights normalized to sum to 1. */
    public static final int RANDOM = 1;
    private BayesPm bayesPm;
    private Node[] nodes;
    private int[][] parents;
    private int[][] parentDims;
    private double[][][] probs;
    private static final RandomUtil randomUtil = RandomUtil.getInstance();

    /**
     * Creates an IM for the given PM with every probability unknown (NaN).
     *
     * @param bayesPm the parametric model; must not be null.
     * @throws NullPointerException if {@code bayesPm} is null.
     */
    public MlBayesIm(BayesPm bayesPm) throws IllegalArgumentException {
        this(bayesPm, null, MANUAL);
    }

    /**
     * Creates an IM for the given PM.
     *
     * @param bayesPm the parametric model; must not be null.
     * @param initializationMethod {@link #MANUAL} or {@link #RANDOM}.
     * @throws NullPointerException if {@code bayesPm} is null.
     * @throws IllegalArgumentException if {@code initializationMethod} is not recognized.
     */
    public MlBayesIm(BayesPm bayesPm, int initializationMethod) throws IllegalArgumentException {
        this(bayesPm, null, initializationMethod);
    }

    /**
     * Creates an IM for a copy of the given PM.
     *
     * <p>Where possible, rows are copied from {@code oldBayesIm}. A row is copied only when:
     * <ul>
     *   <li>a node with the same name exists in the old IM;</li>
     *   <li>that node has the same number of categories;</li>
     *   <li>the new row maps to a unique compatible old row.</li>
     * </ul>
     * Every other row is initialized according to {@code initializationMethod}.
     *
     * @param bayesPm the parametric model; must not be null. A copy of it is stored.
     * @param oldBayesIm an IM to retain values from, or null.
     * @param initializationMethod {@link #MANUAL} or {@link #RANDOM}.
     * @throws NullPointerException if {@code bayesPm} is null.
     * @throws IllegalArgumentException if {@code initializationMethod} is not recognized, or a
     *     conditional probability table would exceed 1,000,000 rows.
     */
    public MlBayesIm(BayesPm bayesPm, BayesIm oldBayesIm, int initializationMethod) throws IllegalArgumentException {
        if (bayesPm == null) {
            throw new NullPointerException("BayesPm must not be null.");
        }
        this.bayesPm = new BayesPm(bayesPm);
        // Take the nodes from this IM's own PM copy. getNodeIndex compares nodes by identity, so they
        // must come from the same DAG that initializeNode and constructSample query.
        this.nodes = this.bayesPm.getDag().getNodes().toArray(new Node[0]);
        this.initialize(oldBayesIm, initializationMethod);
    }

    /**
     * Creates a copy of the given IM.
     *
     * <p>The node order and probability values are copied. The PM object is shared with
     * {@code bayesIm}, not copied.
     *
     * @param bayesIm the IM to copy; must not be null.
     * @throws NullPointerException if {@code bayesIm} is null.
     */
    public MlBayesIm(BayesIm bayesIm) throws IllegalArgumentException {
        if (bayesIm == null) {
            throw new NullPointerException("BayesIm must not be null.");
        }
        this.bayesPm = bayesIm.getBayesPm();
        this.nodes = new Node[bayesIm.getNumNodes()];
        for (int i = 0; i < this.nodes.length; ++i) {
            this.nodes[i] = bayesIm.getNode(i);
        }
        this.initialize(bayesIm, MANUAL);
    }

    /**
     * Returns a simple instance built from {@link BayesPm#serializableInstance()}, for serialization
     * tests.
     */
    public static MlBayesIm serializableInstance() {
        return new MlBayesIm(BayesPm.serializableInstance());
    }

    @Override
    public BayesPm getBayesPm() {
        return this.bayesPm;
    }

    @Override
    public Dag getDag() {
        return this.bayesPm.getDag();
    }

    @Override
    public int getNumNodes() {
        return this.nodes.length;
    }

    @Override
    public Node getNode(int nodeIndex) {
        return this.nodes[nodeIndex];
    }

    /** Looks the node up by name in the PM's DAG. */
    @Override
    public Node getNode(String name) {
        return this.getDag().getNode(name);
    }

    /**
     * Returns the index of the given node, or -1 if it is not one of this IM's nodes.
     * Nodes are compared by identity ({@code ==}), not by {@code equals}.
     */
    @Override
    public int getNodeIndex(Node node) {
        for (int i = 0; i < this.nodes.length; ++i) {
            if (node == this.nodes[i]) {
                return i;
            }
        }
        return -1;
    }

    /** Returns the PM's variable for each node, in node-index order. */
    @Override
    public List<Node> getVariables() {
        List<Node> variables = new ArrayList<>(this.nodes.length);
        for (Node node : this.nodes) {
            variables.add(this.bayesPm.getVariable(node));
        }
        return variables;
    }

    @Override
    public List<Node> getMeasuredNodes() {
        return this.bayesPm.getMeasuredNodes();
    }

    /** Returns the names of the PM's variables, in node-index order. */
    @Override
    public List<String> getVariableNames() {
        List<String> variableNames = new ArrayList<>(this.nodes.length);
        for (Node node : this.nodes) {
            variableNames.add(this.bayesPm.getVariable(node).getName());
        }
        return variableNames;
    }

    /** Returns the number of categories of the node, read from row 0 of its table. */
    @Override
    public int getNumColumns(int nodeIndex) {
        return this.probs[nodeIndex][0].length;
    }

    /** Returns the number of parent-value combinations, i.e. the rows of the node's table. */
    @Override
    public int getNumRows(int nodeIndex) {
        return this.probs[nodeIndex].length;
    }

    @Override
    public int getNumParents(int nodeIndex) {
        return this.parents[nodeIndex].length;
    }

    /** Returns the node index of the {@code parentIndex}-th parent (parents are sorted ascending). */
    @Override
    public int getParent(int nodeIndex, int parentIndex) {
        return this.parents[nodeIndex][parentIndex];
    }

    /** Returns the number of categories of the {@code parentIndex}-th parent. */
    @Override
    public int getParentDim(int nodeIndex, int parentIndex) {
        return this.parentDims[nodeIndex][parentIndex];
    }

    /** Returns a defensive copy of the parents' category counts. */
    @Override
    public int[] getParentDims(int nodeIndex) {
        return this.parentDims[nodeIndex].clone();
    }

    /** Returns a defensive copy of the parents' node indices. */
    @Override
    public int[] getParents(int nodeIndex) {
        return this.parents[nodeIndex].clone();
    }

    /**
     * Decodes a row index into one value per parent. This is the inverse of
     * {@link #getRowIndex(int, int[])}.
     */
    @Override
    public int[] getParentValues(int nodeIndex, int rowIndex) {
        int[] dims = this.parentDims[nodeIndex];
        int[] values = new int[dims.length];
        for (int i = dims.length - 1; i >= 0; --i) {
            values[i] = rowIndex % dims[i];
            rowIndex /= dims[i];
        }
        return values;
    }

    @Override
    public int getParentValue(int nodeIndex, int rowIndex, int colIndex) {
        return this.getParentValues(nodeIndex, rowIndex)[colIndex];
    }

    @Override
    public double getProbability(int nodeIndex, int rowIndex, int colIndex) {
        return this.probs[nodeIndex][rowIndex][colIndex];
    }

    /**
     * Encodes parent values as a row index, in mixed radix with the last parent varying fastest.
     * Only the first {@code getNumParents(nodeIndex)} entries of {@code values} are read.
     */
    @Override
    public int getRowIndex(int nodeIndex, int[] values) {
        int[] dims = this.parentDims[nodeIndex];
        int rowIndex = 0;
        for (int i = 0; i < dims.length; ++i) {
            rowIndex = rowIndex * dims[i] + values[i];
        }
        return rowIndex;
    }

    /** Normalizes every row of every node's table. */
    @Override
    public void normalizeAll() {
        for (int nodeIndex = 0; nodeIndex < this.nodes.length; ++nodeIndex) {
            this.normalizeNode(nodeIndex);
        }
    }

    /** Normalizes every row of one node's table. */
    @Override
    public void normalizeNode(int nodeIndex) {
        for (int rowIndex = 0; rowIndex < this.getNumRows(nodeIndex); ++rowIndex) {
            this.normalizeRow(nodeIndex, rowIndex);
        }
    }

    /**
     * Divides each entry of the row by the row total.
     *
     * <p>A row whose total is exactly 0 becomes uniform. A row containing NaN has a NaN total, so
     * every entry becomes NaN.
     */
    @Override
    public void normalizeRow(int nodeIndex, int rowIndex) {
        int numColumns = this.getNumColumns(nodeIndex);
        double total = 0.0;
        for (int colIndex = 0; colIndex < numColumns; ++colIndex) {
            total += this.getProbability(nodeIndex, rowIndex, colIndex);
        }
        if (total != 0.0) {
            for (int colIndex = 0; colIndex < numColumns; ++colIndex) {
                double prob = this.getProbability(nodeIndex, rowIndex, colIndex) / total;
                this.setProbability(nodeIndex, rowIndex, colIndex, prob);
            }
        } else {
            double prob = 1.0 / (double) numColumns;
            for (int colIndex = 0; colIndex < numColumns; ++colIndex) {
                this.setProbability(nodeIndex, rowIndex, colIndex, prob);
            }
        }
    }

    /**
     * Sets one probability.
     *
     * @throws IllegalArgumentException if {@code colIndex} is out of range, or {@code value} is
     *     neither in [0, 1] nor NaN.
     */
    @Override
    public void setProbability(int nodeIndex, int rowIndex, int colIndex, double value) {
        if (colIndex >= this.getNumColumns(nodeIndex)) {
            throw new IllegalArgumentException("Column out of range: " + colIndex + " >= " + this.getNumColumns(nodeIndex));
        }
        if (!(0.0 <= value && value <= 1.0 || Double.isNaN(value))) {
            throw new IllegalArgumentException("Probability value must be between 0.0 and 1.0 or Double.NaN.");
        }
        this.probs[nodeIndex][rowIndex][colIndex] = value;
    }

    /**
     * Returns the index, in {@code otherBayesIm}, of the node with the same name as node
     * {@code nodeIndex} here. The lookup is done through the other IM's {@code getNode(String)}
     * and {@code getNodeIndex}.
     */
    @Override
    public int getCorrespondingNodeIndex(int nodeIndex, BayesIm otherBayesIm) {
        String nodeName = this.getNode(nodeIndex).getName();
        Node oldNode = otherBayesIm.getNode(nodeName);
        return otherBayesIm.getNodeIndex(oldNode);
    }

    /** Sets every entry of the row to NaN (unknown). */
    @Override
    public void clearRow(int nodeIndex, int rowIndex) {
        for (int colIndex = 0; colIndex < this.getNumColumns(nodeIndex); ++colIndex) {
            this.setProbability(nodeIndex, rowIndex, colIndex, Double.NaN);
        }
    }

    /** Replaces the row with random weights normalized to sum to 1. */
    @Override
    public void randomizeRow(int nodeIndex, int rowIndex) {
        this.probs[nodeIndex][rowIndex] = MlBayesIm.getRandomWeights(this.getNumColumns(nodeIndex));
    }

    /** Randomizes only the rows that contain at least one NaN. */
    @Override
    public void randomizeIncompleteRows(int nodeIndex) {
        for (int rowIndex = 0; rowIndex < this.getNumRows(nodeIndex); ++rowIndex) {
            if (this.isIncomplete(nodeIndex, rowIndex)) {
                this.randomizeRow(nodeIndex, rowIndex);
            }
        }
    }

    /** Randomizes every row of the node's table. */
    @Override
    public void randomizeTable(int nodeIndex) {
        for (int rowIndex = 0; rowIndex < this.getNumRows(nodeIndex); ++rowIndex) {
            this.randomizeRow(nodeIndex, rowIndex);
        }
    }

    /** Sets every entry of the node's table to NaN. */
    @Override
    public void clearTable(int nodeIndex) {
        for (int rowIndex = 0; rowIndex < this.getNumRows(nodeIndex); ++rowIndex) {
            this.clearRow(nodeIndex, rowIndex);
        }
    }

    /** Returns true if any entry of the row is NaN. */
    @Override
    public boolean isIncomplete(int nodeIndex, int rowIndex) {
        for (int colIndex = 0; colIndex < this.getNumColumns(nodeIndex); ++colIndex) {
            if (Double.isNaN(this.getProbability(nodeIndex, rowIndex, colIndex))) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if any row of the node's table is incomplete. */
    @Override
    public boolean isIncomplete(int nodeIndex) {
        for (int rowIndex = 0; rowIndex < this.getNumRows(nodeIndex); ++rowIndex) {
            if (this.isIncomplete(nodeIndex, rowIndex)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Simulates {@code sampleSize} rows into a new data set.
     *
     * @param latentDataSaved if false, only nodes of type MEASURED become columns.
     * @throws IllegalStateException if a probability needed during sampling is NaN.
     */
    @Override
    public DataSet simulateData(int sampleSize, boolean latentDataSaved) {
        return this.simulateDataHelper(sampleSize, randomUtil, latentDataSaved);
    }

    /**
     * Overwrites the rows of {@code dataSet} with simulated values and returns the same object.
     * Its variables are replaced as well.
     *
     * @throws IllegalArgumentException if the data set's column count differs from the node count.
     */
    @Override
    public DataSet simulateData(DataSet dataSet, boolean latentDataSaved) {
        return this.simulateDataHelper(dataSet, randomUtil, latentDataSaved);
    }

    /**
     * Like {@link #simulateData(int, boolean)}, but first seeds the instance returned by
     * {@code RandomUtil.getInstance()}.
     */
    public DataSet simulateData(int sampleSize, long seed, boolean latentDataSaved) {
        RandomUtil random = RandomUtil.getInstance();
        random.setSeed(seed);
        return this.simulateDataHelper(sampleSize, random, latentDataSaved);
    }

    /**
     * Like {@link #simulateData(DataSet, boolean)}, but first seeds the instance returned by
     * {@code RandomUtil.getInstance()}.
     */
    public DataSet simulateData(DataSet dataSet, long seed, boolean latentDataSaved) {
        RandomUtil random = RandomUtil.getInstance();
        random.setSeed(seed);
        return this.simulateDataHelper(dataSet, random, latentDataSaved);
    }

    /** Creates a new data set with one discrete column per written node and fills it by sampling. */
    private DataSet simulateDataHelper(int sampleSize, RandomUtil random, boolean latentDataSaved) {
        int[] map = new int[this.nodes.length];
        List<Node> variables = this.createOutputVariables(latentDataSaved, map);
        ColtDataSet dataSet = new ColtDataSet(sampleSize, variables);
        this.constructSample(sampleSize, random, variables.size(), dataSet, map);
        return dataSet;
    }

    /** Replaces the leading columns' variables of {@code dataSet}, then fills every row by sampling. */
    private DataSet simulateDataHelper(DataSet dataSet, RandomUtil random, boolean latentDataSaved) {
        if (dataSet.getNumColumns() != this.nodes.length) {
            throw new IllegalArgumentException("When rewriting the old data set, number of variables in data set must equal number of variables in Bayes net.");
        }
        int sampleSize = dataSet.getNumRows();
        int[] map = new int[this.nodes.length];
        List<Node> variables = this.createOutputVariables(latentDataSaved, map);
        for (int i = 0; i < variables.size(); ++i) {
            dataSet.changeVariable(dataSet.getVariable(i), variables.get(i));
        }
        this.constructSample(sampleSize, random, variables.size(), dataSet, map);
        return dataSet;
    }

    /**
     * Builds one discrete variable, with the PM's categories, for each node that will be written.
     *
     * @param latentDataSaved if false, nodes whose type is not MEASURED are skipped.
     * @param map output array; on return, {@code map[column]} is the node index written to that column.
     * @return the variables in column order.
     */
    private List<Node> createOutputVariables(boolean latentDataSaved, int[] map) {
        List<Node> variables = new ArrayList<>();
        for (int j = 0; j < this.nodes.length; ++j) {
            Node node = this.nodes[j];
            if (!latentDataSaved && node.getNodeType() != NodeType.MEASURED) {
                continue;
            }
            int numCategories = this.bayesPm.getNumCategories(node);
            List<String> categories = new ArrayList<>(numCategories);
            for (int k = 0; k < numCategories; ++k) {
                categories.add(this.bayesPm.getCategory(node, k));
            }
            map[variables.size()] = j;
            variables.add(new DiscreteVariable(node.getName(), categories));
        }
        return variables;
    }

    /**
     * Fills rows {@code 0..sampleSize-1} of {@code dataSet} by forward sampling in the DAG's tier
     * order. Column {@code j} receives the sampled category of node {@code map[j]}.
     */
    private void constructSample(int sampleSize, RandomUtil random, int numMeasured, DataSet dataSet, int[] map) {
        List<Node> tierOrdering = this.getBayesPm().getDag().getTierOrdering();
        int[] tiers = new int[tierOrdering.size()];
        for (int i = 0; i < tiers.length; ++i) {
            tiers[i] = this.getNodeIndex(tierOrdering.get(i));
        }
        int[] combination = new int[this.nodes.length];
        for (int i = 0; i < sampleSize; ++i) {
            int[] point = new int[this.nodes.length];
            for (int nodeIndex : tiers) {
                // Draw before reading the parents so the random stream is consumed exactly as before.
                double cutoff = random.nextDouble();
                for (int k = 0; k < this.getNumParents(nodeIndex); ++k) {
                    combination[k] = point[this.getParent(nodeIndex, k)];
                }
                int rowIndex = this.getRowIndex(nodeIndex, combination);
                point[nodeIndex] = this.sampleCategory(nodeIndex, rowIndex, cutoff);
            }
            for (int j = 0; j < numMeasured; ++j) {
                dataSet.setInt(i, j, point[map[j]]);
            }
        }
    }

    /**
     * Returns the first category whose cumulative probability reaches {@code cutoff}.
     *
     * <p>If rounding leaves the row total just below {@code cutoff}, the last category is returned.
     * The sampled value therefore never silently defaults to category 0.
     *
     * @throws IllegalStateException if a NaN probability is encountered before a category is chosen.
     */
    private int sampleCategory(int nodeIndex, int rowIndex, double cutoff) {
        int numColumns = this.getNumColumns(nodeIndex);
        double sum = 0.0;
        for (int k = 0; k < numColumns; ++k) {
            double probability = this.getProbability(nodeIndex, rowIndex, k);
            if (Double.isNaN(probability)) {
                throw new IllegalStateException("Some probability values in the BayesIm are not filled in; cannot simulate data.");
            }
            sum += probability;
            if (sum >= cutoff) {
                return k;
            }
        }
        return Math.max(numColumns - 1, 0);
    }

    /**
     * Two IMs are equal when all of the following hold:
     * <ul>
     *   <li>they have the same number of nodes;</li>
     *   <li>every node here has a same-named node in the other IM;</li>
     *   <li>matched nodes have the same table dimensions;</li>
     *   <li>every pair of probabilities is either both NaN, or both numbers within
     *       {@link #ALLOWABLE_DIFFERENCE} of each other.</li>
     * </ul>
     * Rows are compared by position.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof BayesIm)) {
            return false;
        }
        BayesIm otherIm = (BayesIm) o;
        if (this.getNumNodes() != otherIm.getNumNodes()) {
            return false;
        }
        for (int i = 0; i < this.getNumNodes(); ++i) {
            int otherIndex = this.getCorrespondingNodeIndex(i, otherIm);
            if (otherIndex == -1) {
                return false;
            }
            if (this.getNumColumns(i) != otherIm.getNumColumns(otherIndex)) {
                return false;
            }
            if (this.getNumRows(i) != otherIm.getNumRows(otherIndex)) {
                return false;
            }
            for (int j = 0; j < this.getNumRows(i); ++j) {
                for (int k = 0; k < this.getNumColumns(i); ++k) {
                    double prob = this.getProbability(i, j, k);
                    double otherProb = otherIm.getProbability(otherIndex, j, k);
                    boolean nan = Double.isNaN(prob);
                    boolean otherNan = Double.isNaN(otherProb);
                    if (nan && otherNan) {
                        continue;
                    }
                    // Exactly one unknown value means the tables differ.
                    if (nan || otherNan) {
                        return false;
                    }
                    if (Math.abs(prob - otherProb) > ALLOWABLE_DIFFERENCE) {
                        return false;
                    }
                }
            }
        }
        return true;
    }

    /**
     * Hashes the table structure only: node names and table dimensions. This is consistent with the
     * tolerance-based {@link #equals(Object)}. Per-node hashes are summed because {@code equals}
     * matches nodes by name, not by position.
     */
    @Override
    public int hashCode() {
        int sum = 0;
        for (int i = 0; i < this.nodes.length; ++i) {
            String name = this.nodes[i].getName();
            int nodeHash = name == null ? 0 : name.hashCode();
            nodeHash = 31 * nodeHash + this.getNumRows(i);
            nodeHash = 31 * nodeHash + this.getNumColumns(i);
            sum += nodeHash;
        }
        return 31 * this.nodes.length + sum;
    }

    /**
     * Renders each node's table. Each row lists the parent values, tab-separated, followed by the
     * formatted probabilities.
     */
    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder();
        NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
        for (int i = 0; i < this.getNumNodes(); ++i) {
            int numParents = this.getNumParents(i);
            buf.append("\n\nNode: ").append(this.getNode(i));
            if (numParents == 0) {
                buf.append("\n");
            } else {
                buf.append("\n\n");
                for (int k = 0; k < numParents; ++k) {
                    buf.append(this.getNode(this.getParent(i, k))).append("\t");
                }
            }
            for (int j = 0; j < this.getNumRows(i); ++j) {
                buf.append("\n");
                int[] parentValues = this.getParentValues(i, j);
                for (int k = 0; k < numParents; ++k) {
                    buf.append(parentValues[k]);
                    if (k < numParents - 1) {
                        buf.append("\t");
                    }
                }
                if (numParents > 0) {
                    buf.append("\t");
                }
                for (int k = 0; k < this.getNumColumns(i); ++k) {
                    buf.append(nf.format(this.getProbability(i, j, k))).append("\t");
                }
            }
        }
        buf.append("\n");
        return buf.toString();
    }

    /** Allocates the per-node arrays and initializes every node's table. */
    private void initialize(BayesIm oldBayesIm, int initializationMethod) {
        this.parents = new int[this.nodes.length][];
        this.parentDims = new int[this.nodes.length][];
        this.probs = new double[this.nodes.length][][];
        for (int nodeIndex = 0; nodeIndex < this.nodes.length; ++nodeIndex) {
            this.initializeNode(nodeIndex, oldBayesIm, initializationMethod);
        }
    }

    /**
     * Computes the node's sorted parent indices and parent category counts, then allocates and fills
     * its table.
     *
     * @throws IllegalArgumentException if the table would have more than 1,000,000 rows.
     */
    private void initializeNode(int nodeIndex, BayesIm oldBayesIm, int initializationMethod) {
        Node node = this.nodes[nodeIndex];
        List<Node> parentList = this.getBayesPm().getDag().getParents(node);
        int[] parentArray = new int[parentList.size()];
        for (int i = 0; i < parentArray.length; ++i) {
            parentArray[i] = this.getNodeIndex(parentList.get(i));
        }
        Arrays.sort(parentArray);
        this.parents[nodeIndex] = parentArray;
        int[] dims = new int[parentArray.length];
        for (int i = 0; i < dims.length; ++i) {
            dims[i] = this.getBayesPm().getNumCategories(this.nodes[parentArray[i]]);
        }
        // Accumulate in a long and check after each multiplication. The old check ran before the
        // multiply, so the final product was never tested and could overflow int. The long cannot
        // overflow: MAX_TABLE_ROWS * Integer.MAX_VALUE < Long.MAX_VALUE.
        long numRows = 1L;
        for (int dim : dims) {
            numRows *= dim;
            if (numRows > MAX_TABLE_ROWS) {
                throw new IllegalArgumentException("The number of rows in the conditional probability table for " + this.nodes[nodeIndex] + " is greater than 1,000,000 and cannot be " + "represented.");
            }
        }
        int numCols = this.getBayesPm().getNumCategories(node);
        this.parentDims[nodeIndex] = dims;
        this.probs[nodeIndex] = new double[(int) numRows][numCols];
        for (int rowIndex = 0; rowIndex < numRows; ++rowIndex) {
            if (oldBayesIm == null) {
                this.overwriteRow(nodeIndex, rowIndex, initializationMethod);
            } else {
                this.retainOldRowIfPossible(nodeIndex, rowIndex, oldBayesIm, initializationMethod);
            }
        }
    }

    /**
     * Initializes one row according to {@code initializationMethod}.
     *
     * @throws IllegalArgumentException if the method is neither {@link #RANDOM} nor {@link #MANUAL}.
     */
    private void overwriteRow(int nodeIndex, int rowIndex, int initializationMethod) {
        if (initializationMethod == RANDOM) {
            this.randomizeRow(nodeIndex, rowIndex);
        } else if (initializationMethod == MANUAL) {
            this.initializeRowAsUnknowns(nodeIndex, rowIndex);
        } else {
            throw new IllegalArgumentException("Unrecognized state.");
        }
    }

    /** Returns {@code size} uniform draws from {@code randomUtil}, divided by their sum. */
    private static double[] getRandomWeights(int size) {
        assert size >= 0;
        // NOTE(review): this draw is never used. It was a "random cell" that received a bias of 0.0.
        // It is kept so that seeded simulations consume the random stream exactly as before.
        RandomUtil.getInstance().nextInt(size);
        double[] row = new double[size];
        double sum = 0.0;
        for (int i = 0; i < size; ++i) {
            row[i] = randomUtil.nextDouble();
            sum += row[i];
        }
        for (int i = 0; i < size; ++i) {
            row[i] /= sum;
        }
        return row;
    }

    /** Replaces the row with an all-NaN row. */
    private void initializeRowAsUnknowns(int nodeIndex, int rowIndex) {
        double[] row = new double[this.getNumColumns(nodeIndex)];
        Arrays.fill(row, Double.NaN);
        this.probs[nodeIndex][rowIndex] = row;
    }

    /**
     * Copies the matching row from {@code oldBayesIm} when one can be found. Otherwise falls back to
     * {@link #overwriteRow}.
     */
    private void retainOldRowIfPossible(int nodeIndex, int rowIndex, BayesIm oldBayesIm, int initializationMethod) {
        int oldNodeIndex = this.getCorrespondingNodeIndex(nodeIndex, oldBayesIm);
        if (oldNodeIndex == -1 || this.getNumColumns(nodeIndex) != oldBayesIm.getNumColumns(oldNodeIndex)) {
            this.overwriteRow(nodeIndex, rowIndex, initializationMethod);
            return;
        }
        int oldRowIndex = this.getUniqueCompatibleOldRow(nodeIndex, rowIndex, oldBayesIm);
        if (oldRowIndex >= 0) {
            this.copyValuesFromOldToNew(oldNodeIndex, oldRowIndex, nodeIndex, rowIndex, oldBayesIm);
        } else {
            this.overwriteRow(nodeIndex, rowIndex, initializationMethod);
        }
    }

    /**
     * Maps a new row to the old IM's row for the same-named node.
     *
     * <p>The mapping carries over each new parent's value when that parent is also a parent of the
     * old node. Returns -1 in either of these cases:
     * <ul>
     *   <li>some old parent receives no value;</li>
     *   <li>a value falls outside the old parent's category range.</li>
     * </ul>
     */
    private int getUniqueCompatibleOldRow(int nodeIndex, int rowIndex, BayesIm oldBayesIm) {
        int oldNodeIndex = this.getCorrespondingNodeIndex(nodeIndex, oldBayesIm);
        int oldNumParents = oldBayesIm.getNumParents(oldNodeIndex);
        int[] oldParentValues = new int[oldNumParents];
        Arrays.fill(oldParentValues, -1);
        int[] parentValues = this.getParentValues(nodeIndex, rowIndex);
        for (int i = 0; i < this.getNumParents(nodeIndex); ++i) {
            int oldParentNodeIndex = this.getCorrespondingNodeIndex(this.getParent(nodeIndex, i), oldBayesIm);
            int oldParentIndex = -1;
            for (int j = 0; j < oldNumParents; ++j) {
                if (oldParentNodeIndex == oldBayesIm.getParent(oldNodeIndex, j)) {
                    oldParentIndex = j;
                    break;
                }
            }
            if (oldParentIndex == -1) {
                continue;
            }
            int newParentValue = parentValues[i];
            if (newParentValue >= oldBayesIm.getParentDim(oldNodeIndex, oldParentIndex)) {
                return -1;
            }
            oldParentValues[oldParentIndex] = newParentValue;
        }
        for (int oldParentValue : oldParentValues) {
            if (oldParentValue == -1) {
                return -1;
            }
        }
        return oldBayesIm.getRowIndex(oldNodeIndex, oldParentValues);
    }

    /**
     * Copies one row of probabilities from the old IM into this one.
     *
     * @throws IllegalArgumentException if the two tables have different column counts.
     */
    private void copyValuesFromOldToNew(int oldNodeIndex, int oldRowIndex, int nodeIndex, int rowIndex, BayesIm oldBayesIm) {
        if (this.getNumColumns(nodeIndex) != oldBayesIm.getNumColumns(oldNodeIndex)) {
            throw new IllegalArgumentException("It's only possible to copy one row of probability values to another in a Bayes IM if the number of columns in the table are the same.");
        }
        for (int colIndex = 0; colIndex < this.getNumColumns(nodeIndex); ++colIndex) {
            double prob = oldBayesIm.getProbability(oldNodeIndex, oldRowIndex, colIndex);
            this.setProbability(nodeIndex, rowIndex, colIndex, prob);
        }
    }

    /**
     * Performs default deserialization, then checks that no required field is null.
     *
     * @throws NullPointerException if a required field was not restored.
     */
    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
        if (this.bayesPm == null) {
            throw new NullPointerException("bayesPm");
        }
        if (this.nodes == null) {
            throw new NullPointerException("nodes");
        }
        if (this.parents == null) {
            throw new NullPointerException("parents");
        }
        if (this.parentDims == null) {
            throw new NullPointerException("parentDims");
        }
        if (this.probs == null) {
            throw new NullPointerException("probs");
        }
    }
}

