/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.bayes;

import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.EmBayesEstimator;
import edu.cmu.tetrad.bayes.GraphTools;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.TetradSerializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.math3.util.FastMath;

/**
 * Exact inference for discrete Bayes networks using the junction tree (clique tree) algorithm.
 *
 * <p>Construction does the following:
 * <ul>
 *   <li>moralizes the network's DAG;</li>
 *   <li>adds fill-in edges along a maximum cardinality ordering;</li>
 *   <li>groups the nodes into cliques joined by separators;</li>
 *   <li>multiplies each node's CPT into the first clique (in construction order) that contains the
 *       node and all of its parents;</li>
 *   <li>calibrates the tree with one upward and one downward pass.</li>
 * </ul>
 * After calibration {@code margins[i]} holds the (possibly unnormalized) marginal table of node
 * {@code i}.
 *
 * <p>Query methods that accept evidence do three things: they enter the evidence, read the result,
 * and re-calibrate the tree before returning (also on failure), so they leave no evidence behind.
 * {@link #setEvidence(int, int)}, in contrast, leaves its evidence in the tree.
 *
 * <p>Node indices are positions in {@link #getNodes()}. NOTE(review): {@code margins} is written
 * using {@code bayesIm.getNodeIndex(node)}, which is assumed to agree with that order — confirm.
 *
 * <p>Instances are not synchronized, and queries mutate internal state.
 */
public class JunctionTreeAlgorithm implements TetradSerializable {
    private static final long serialVersionUID = 23L;

    /** Root of the clique tree; used by {@link #getJointProbabilityAll} and {@link #toString}. */
    private final TreeNode root;
    /** Network nodes; a node's position in this array is its public node index. */
    private final Node[] graphNodes;
    /** Per-node marginal tables; each entry aliases the table of the clique that last refreshed it. */
    private final double[][] margins;
    /** Maximum cardinality ordering of the filled-in moral graph. */
    private final Node[] maxCardOrdering;
    private final BayesPm bayesPm;
    private final BayesIm bayesIm;
    /** Clique-tree nodes, keyed by the node under which {@code GraphTools.getCliques} listed them. */
    private final Map<Node, TreeNode> treeNodes;

    /**
     * Builds a junction tree for {@code graph}, with parameters estimated from {@code dataModel} by
     * {@link EmBayesEstimator}.
     *
     * @param graph     the network structure; its edges are re-created between the data's
     *                  variables by node name
     * @param dataModel the data; it is cast to {@link DataSet}
     */
    public JunctionTreeAlgorithm(Graph graph, DataModel dataModel) {
        this.bayesPm = createBayesPm(dataModel, graph);
        this.bayesIm = createBayesIm(dataModel, this.bayesPm);
        this.treeNodes = new HashMap<>();
        // Size everything by the DAG used for inference. It holds every data variable, which can be
        // more than graph.getNumNodes().
        this.graphNodes = this.bayesIm.getDag().getNodes().toArray(new Node[0]);
        int numOfNodes = this.graphNodes.length;
        this.margins = new double[numOfNodes][];
        this.maxCardOrdering = new Node[numOfNodes];
        this.root = buildJunctionTree();
        initialize();
    }

    /**
     * Builds a junction tree for an already parameterized Bayes IM.
     *
     * @param bayesIm the instantiated model; its CPTs seed the clique potentials
     */
    public JunctionTreeAlgorithm(BayesIm bayesIm) {
        this.bayesPm = bayesIm.getBayesPm();
        this.bayesIm = bayesIm;
        this.treeNodes = new HashMap<>();
        this.graphNodes = bayesIm.getDag().getNodes().toArray(new Node[0]);
        int numOfNodes = this.graphNodes.length;
        this.margins = new double[numOfNodes][];
        this.maxCardOrdering = new Node[numOfNodes];
        this.root = buildJunctionTree();
        initialize();
    }

    /**
     * Re-calibrates every clique from its CPT-derived potential, discarding any evidence.
     *
     * <p>The upward pass visits cliques in reverse maximum-cardinality order. The downward pass
     * visits them in maximum-cardinality order.
     */
    private void initialize() {
        for (int i = this.maxCardOrdering.length - 1; i >= 0; --i) {
            TreeNode treeNode = this.treeNodes.get(this.maxCardOrdering[i]);
            if (treeNode != null) {
                treeNode.initializeUp();
            }
        }
        for (Node node : this.maxCardOrdering) {
            TreeNode treeNode = this.treeNodes.get(node);
            if (treeNode != null) {
                treeNode.initializeDown(false);
            }
        }
    }

    /**
     * Builds the clique tree and returns its root.
     *
     * <p>Side effects: fills {@link #maxCardOrdering} and {@link #treeNodes}.
     *
     * <p>NOTE(review): if several cliques have no parent (a forest), the returned root is whichever
     * one HashMap iteration reaches last.
     */
    private TreeNode buildJunctionTree() {
        Graph undirectedGraph = GraphTools.moralize(this.bayesIm.getDag());

        // Order the moral graph and add fill-in edges along that order (presumably
        // triangulating it). Then re-order the filled-in graph.
        computeMaximumCardinalityOrdering(undirectedGraph, this.maxCardOrdering);
        GraphTools.fillIn(undirectedGraph, this.maxCardOrdering);
        computeMaximumCardinalityOrdering(undirectedGraph, this.maxCardOrdering);

        Map<Node, Set<Node>> cliques = GraphTools.getCliques(this.maxCardOrdering, undirectedGraph);
        Map<Node, Set<Node>> separators = GraphTools.getSeparators(this.maxCardOrdering, cliques);
        Map<Node, Node> parentCliques = GraphTools.getCliqueTree(this.maxCardOrdering, cliques, separators);

        // Shared across all cliques so each node's CPT is multiplied into exactly one potential.
        Set<Node> finishedCalculated = new HashSet<>();
        for (Node node : this.maxCardOrdering) {
            if (cliques.containsKey(node)) {
                this.treeNodes.put(node, new TreeNode(cliques.get(node), finishedCalculated));
            }
        }

        for (Node node : this.maxCardOrdering) {
            if (!cliques.containsKey(node) || !parentCliques.containsKey(node)) {
                continue;
            }
            TreeNode parent = this.treeNodes.get(parentCliques.get(node));
            TreeNode child = this.treeNodes.get(node);
            child.setParentSeparator(new TreeSeparator(separators.get(node), child, parent));
            parent.addChildClique(child);
        }

        TreeNode rootNode = null;
        for (Map.Entry<Node, TreeNode> entry : this.treeNodes.entrySet()) {
            if (!parentCliques.containsKey(entry.getKey())) {
                rootNode = entry.getValue();
            }
        }
        return rootNode;
    }

    /**
     * Fills {@code nodes} with a maximum cardinality ordering of {@code graph}.
     *
     * <p>Each position receives the unnumbered node that has the most already-numbered neighbors.
     * Ties go to the first such node in {@code graph.getNodes()} order.
     */
    private void computeMaximumCardinalityOrdering(Graph graph, Node[] nodes) {
        Set<Node> numbered = new HashSet<>();
        for (int i = 0; i < nodes.length; ++i) {
            Node best = null;
            int bestCount = -1;
            for (Node v : graph.getNodes()) {
                if (numbered.contains(v)) {
                    continue;
                }
                int count = (int) graph.getAdjacentNodes(v).stream().filter(numbered::contains).count();
                if (count > bestCount) {
                    bestCount = count;
                    best = v;
                }
            }
            nodes[i] = best;
            numbered.add(best);
        }
    }

    /**
     * Creates a PM over the data's variables. Each edge of {@code graph} is re-created between the
     * data variables with the same names.
     */
    private BayesPm createBayesPm(DataModel dataModel, Graph graph) {
        Dag dag = new Dag(dataModel.getVariables());
        new Dag(graph).getEdges().forEach(edge -> {
            Node node1 = dag.getNode(edge.getNode1().getName());
            Node node2 = dag.getNode(edge.getNode2().getName());
            Endpoint endpoint1 = edge.getEndpoint1();
            Endpoint endpoint2 = edge.getEndpoint2();
            dag.addEdge(new Edge(node1, node2, endpoint1, endpoint2));
        });
        return new BayesPm(dag);
    }

    /** Estimates the IM's parameters with EM. {@code dataModel} must be a {@link DataSet}. */
    private BayesIm createBayesIm(DataModel dataModel, BayesPm bayesPm) {
        return new EmBayesEstimator(bayesPm, (DataSet) dataModel).getEstimatedIm();
    }

    /**
     * Returns the members of {@code nodes} in {@link #graphNodes} order.
     *
     * <p>All clique and separator arrays use this canonical order. The three-argument
     * {@code getIndexOfCPT} relies on that.
     */
    private Node[] toArray(Set<Node> nodes) {
        int size = nodes.size();
        Node[] order = new Node[size];
        int index = 0;
        for (Node node : this.graphNodes) {
            if (!nodes.contains(node)) {
                continue;
            }
            order[index++] = node;
            if (index == size) {
                break;
            }
        }
        return order;
    }

    /** Divides every entry by the total, in place. An all-zero array becomes all NaN. */
    private void normalize(double[] values) {
        double sum = 0.0;
        for (double value : values) {
            sum += value;
        }
        for (int i = 0; i < values.length; ++i) {
            values[i] /= sum;
        }
    }

    /** Returns a normalized copy of {@code values}; the argument is left untouched. */
    private double[] normalizedCopy(double[] values) {
        double[] copy = Arrays.copyOf(values, values.length);
        normalize(copy);
        return copy;
    }

    /** Returns the number of joint configurations of {@code nodes}, i.e. the product of their category counts. */
    private int getCardinality(Set<Node> nodes) {
        int count = 1;
        for (Node node : nodes) {
            count *= this.bayesPm.getNumCategories(node);
        }
        return count;
    }

    /**
     * Advances {@code values} to the next joint configuration of {@code nodes}, odometer style.
     *
     * <p>The last position varies fastest. After the final configuration, {@code values} wraps back
     * to all zeros.
     */
    private void updateValues(int size, int[] values, Node[] nodes) {
        int j = size - 1;
        values[j]++;
        while (j >= 0 && values[j] == this.bayesPm.getNumCategories(nodes[j])) {
            values[j] = 0;
            --j;
            if (j >= 0) {
                values[j]++;
            }
        }
    }

    /**
     * Returns the row-major index into a table over {@code nodes}.
     *
     * <p>Each node's value is read from the position where that node appears in {@code order}.
     * This assumes {@code nodes} occurs in {@code order} with the same relative order, which holds
     * for arrays produced by {@link #toArray}.
     */
    private int getIndexOfCPT(Node[] nodes, int[] values, Node[] order) {
        int index = 0;
        int j = 0;
        for (int i = 0; i < order.length && j < nodes.length; ++i) {
            if (order[i] != nodes[j]) {
                continue;
            }
            index *= this.bayesPm.getNumCategories(nodes[j]);
            index += values[i];
            ++j;
        }
        return index;
    }

    /** Returns the row-major index into a table over {@code nodes}; {@code values[i]} belongs to {@code nodes[i]}. */
    private int getIndexOfCPT(Node[] nodes, int[] values) {
        int index = 0;
        for (int i = 0; i < nodes.length; ++i) {
            index *= this.bayesPm.getNumCategories(nodes[i]);
            index += values[i];
        }
        return index;
    }

    private void clear(double[] array) {
        Arrays.fill(array, 0.0);
    }

    /** Returns the first clique (scanning keys in {@link #graphNodes} order) containing {@code node}, or null. */
    private TreeNode getCliqueContainsNode(Node node) {
        for (Node k : this.graphNodes) {
            TreeNode treeNode = this.treeNodes.get(k);
            if (treeNode != null && treeNode.contains(node)) {
                return treeNode;
            }
        }
        return null;
    }

    /** @throws IllegalArgumentException if {@code iNode} is not a valid node index */
    private void validate(int iNode) {
        int maxIndex = this.margins.length - 1;
        if (iNode < 0 || iNode > maxIndex) {
            String msg = String.format("Invalid node index %d. Node index must be between 0 and %d.", iNode, maxIndex);
            throw new IllegalArgumentException(msg);
        }
    }

    /** @throws IllegalArgumentException if {@code iNode} or its category {@code value} is out of range */
    private void validate(int iNode, int value) {
        validate(iNode);
        int maxValue = this.margins[iNode].length - 1;
        if (value < 0 || value > maxValue) {
            String msg = String.format("Invalid value %d for node index %d. Value must be between 0 and %d.", value, iNode, maxValue);
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Checks that the index array is non-null, non-empty and no longer than the node count.
     * It does not check the individual indices.
     */
    private void validate(int[] nodes) {
        if (nodes == null) {
            throw new IllegalArgumentException("Node indices cannot be null.");
        }
        if (nodes.length == 0) {
            throw new IllegalArgumentException("Node indices are required.");
        }
        if (nodes.length > this.graphNodes.length) {
            String msg = String.format("Number of nodes cannot exceed %d.", this.graphNodes.length);
            throw new IllegalArgumentException(msg);
        }
    }

    /** Validates parallel arrays of node indices and their category values. */
    private void validate(int[] nodes, int[] values) {
        validate(nodes);
        if (values == null) {
            throw new IllegalArgumentException("Node values cannot be null.");
        }
        if (values.length == 0) {
            throw new IllegalArgumentException("Node values are required.");
        }
        if (values.length != nodes.length) {
            throw new IllegalArgumentException("Number of nodes values must be equal to the number of nodes.");
        }
        for (int i = 0; i < nodes.length; ++i) {
            validate(nodes[i], values[i]);
        }
    }

    /** Validates one category value per node, indexed by node index. */
    private void validateAll(int[] values) {
        if (values == null) {
            throw new IllegalArgumentException("Node values cannot be null.");
        }
        if (values.length == 0) {
            throw new IllegalArgumentException("Node values are required.");
        }
        if (values.length != this.graphNodes.length) {
            throw new IllegalArgumentException("Number of nodes values must be equal to the number of nodes.");
        }
        for (int i = 0; i < values.length; ++i) {
            int maxValue = this.margins[i].length - 1;
            // Reject values outside [0, maxValue]. The original test used || and never fired.
            if (values[i] < 0 || values[i] > maxValue) {
                String msg = String.format("Invalid value %d for node index %d. Value must be between 0 and %d.", values[i], i, maxValue);
                throw new IllegalArgumentException(msg);
            }
        }
    }

    /**
     * Enters the observation {@code node iNode = value} into the tree and propagates it.
     *
     * <p>Evidence accumulates across calls. It stays in the tree until the next re-calibration,
     * which every evidence-taking query method performs before returning.
     *
     * @throws IllegalArgumentException if the index or value is out of range, or the node is in no clique
     */
    public void setEvidence(int iNode, int value) {
        validate(iNode, value);
        Node node = this.graphNodes[iNode];
        TreeNode treeNode = getCliqueContainsNode(node);
        if (treeNode == null) {
            String msg = String.format("Node %s is not in junction tree.", node.getName());
            throw new IllegalArgumentException(msg);
        }
        treeNode.setEvidence(node, value);
    }

    /** Returns true iff {@code nodes} is a permutation of all node indices 0..n-1. */
    private boolean isAllNodes(int[] nodes) {
        if (nodes.length != this.graphNodes.length) {
            return false;
        }
        boolean[] seen = new boolean[this.graphNodes.length];
        for (int node : nodes) {
            if (node < 0 || node >= seen.length || seen[node]) {
                return false;
            }
            seen[node] = true;
        }
        return true;
    }

    /**
     * Returns P(nodes = values | parents = parentValues).
     *
     * <p>The parents are entered as evidence. The result is then built with the chain rule
     * P(x1..xk | e) = P(x1 | e) * P(x2 | e, x1) * ...: each queried value is entered as evidence
     * before the next factor is read. The queried nodes are therefore not assumed to be
     * conditionally independent. The tree is re-calibrated before returning.
     */
    public double getConditionalProbabilities(int[] nodes, int[] values, int[] parents, int[] parentValues) {
        validate(nodes, values);
        validate(parents, parentValues);
        try {
            for (int i = 0; i < parents.length; ++i) {
                setEvidence(parents[i], parentValues[i]);
            }
            double prob = 1.0;
            for (int i = 0; i < nodes.length; ++i) {
                prob *= normalizedCopy(this.margins[nodes[i]])[values[i]];
                if (prob == 0.0) {
                    // Entering impossible evidence would zero the tree and normalize to NaN.
                    break;
                }
                if (i + 1 < nodes.length) {
                    setEvidence(nodes[i], values[i]);
                }
            }
            return prob;
        } finally {
            initialize();
        }
    }

    /**
     * Returns the distribution of node {@code iNode} given {@code parents = parentValues}, as a
     * normalized array indexed by category. The tree is re-calibrated before returning.
     */
    public double[] getConditionalProbabilities(int iNode, int[] parents, int[] parentValues) {
        validate(iNode);
        validate(parents, parentValues);
        try {
            for (int i = 0; i < parents.length; ++i) {
                setEvidence(parents[i], parentValues[i]);
            }
            return normalizedCopy(this.margins[iNode]);
        } finally {
            initialize();
        }
    }

    /** Returns P(node iNode = value | parents = parentValues). */
    public double getConditionalProbability(int iNode, int value, int[] parents, int[] parentValues) {
        validate(iNode, value);
        return getConditionalProbabilities(iNode, parents, parentValues)[value];
    }

    /**
     * Returns the joint probability of a full assignment; {@code nodeValues[i]} is the value of node i.
     *
     * <p>Computed as the product of clique tables over the product of separator tables, evaluated
     * at the assignment. NOTE(review): this reads the current tables, so evidence previously entered
     * with {@link #setEvidence} affects the result.
     */
    public double getJointProbabilityAll(int[] nodeValues) {
        validateAll(nodeValues);
        double logJointClusterPotentials = this.root.getLogJointClusterPotentials(nodeValues);
        double logJointSeparatorPotentials = this.root.getLogJointSeparatorPotentials(nodeValues);
        return FastMath.exp(logJointClusterPotentials - logJointSeparatorPotentials);
    }

    /**
     * Returns P(nodes = values).
     *
     * <p>When {@code nodes} names every node (in any order), this delegates to
     * {@link #getJointProbabilityAll}. Otherwise it enters the values as evidence and returns the sum
     * of the unnormalized marginal of the first node that is not part of the query. The tree is
     * re-calibrated before returning.
     */
    public double getJointProbability(int[] nodes, int[] values) {
        validate(nodes, values);
        if (isAllNodes(nodes)) {
            // getJointProbabilityAll expects values indexed by node index, not by query position.
            int[] valuesByIndex = new int[this.graphNodes.length];
            for (int i = 0; i < nodes.length; ++i) {
                valuesByIndex[nodes[i]] = values[i];
            }
            return getJointProbabilityAll(valuesByIndex);
        }

        boolean[] isEvidence = new boolean[this.margins.length];
        for (int node : nodes) {
            isEvidence[node] = true;
        }
        try {
            for (int i = 0; i < nodes.length; ++i) {
                setEvidence(nodes[i], values[i]);
            }
            // After propagation a node's unnormalized marginal sums to P(evidence).
            // NOTE(review): unverified when the junction tree is a forest.
            for (int i = 0; i < this.margins.length; ++i) {
                if (!isEvidence[i]) {
                    return Arrays.stream(this.margins[i]).sum();
                }
            }
            return 0.0;
        } finally {
            initialize();
        }
    }

    /** Returns the normalized marginal distribution of node {@code iNode} under the current tree state. */
    public double[] getMarginalProbability(int iNode) {
        validate(iNode);
        return normalizedCopy(this.margins[iNode]);
    }

    /**
     * Returns the normalized marginal probability that node {@code iNode} takes {@code value}.
     * The result is consistent with {@link #getMarginalProbability(int)}.
     */
    public double getMarginalProbability(int iNode, int value) {
        validate(iNode, value);
        return getMarginalProbability(iNode)[value];
    }

    /** Returns an unmodifiable view of the nodes; list positions are the public node indices. */
    public List<Node> getNodes() {
        return Collections.unmodifiableList(Arrays.asList(this.graphNodes));
    }

    public int getNumberOfNodes() {
        return this.graphNodes.length;
    }

    /** Lists each clique's per-node marginals, depth-first from the root; empty if there is no root. */
    @Override
    public String toString() {
        return this.root == null ? "" : this.root.toString().trim();
    }

    /** A clique of the junction tree together with its tables. */
    private class TreeNode implements TetradSerializable {
        private static final long serialVersionUID = 23L;
        /** Current clique table: potentials times incoming messages and evidence. */
        private final double[] prob;
        /** Marginals of {@link #prob}; {@code margProb[i]} belongs to {@code nodes[i]}. */
        private final double[][] margProb;
        /** Product of the CPTs assigned to this clique; written only during construction. */
        private final double[] potentials;
        private final List<TreeNode> children;
        /** Number of joint configurations of the clique (length of the tables). */
        private final int cardinality;
        private final Set<Node> clique;
        /** Clique members in canonical ({@code graphNodes}) order; defines the table layout. */
        private final Node[] nodes;
        /** Separator to the parent clique, or null for a root. */
        private TreeSeparator parentSeparator;

        /**
         * @param clique             the clique's members
         * @param finishedCalculated nodes whose CPT already belongs to another clique; updated here
         */
        public TreeNode(Set<Node> clique, Set<Node> finishedCalculated) {
            this.clique = clique;
            this.nodes = toArray(clique);
            this.children = new LinkedList<>();
            this.cardinality = getCardinality(clique);
            this.potentials = new double[this.cardinality];
            this.prob = new double[this.cardinality];
            this.margProb = new double[this.nodes.length][];
            for (int iNode = 0; iNode < this.nodes.length; ++iNode) {
                this.margProb[iNode] = new double[bayesPm.getNumCategories(this.nodes[iNode])];
            }
            calculatePotentials(clique, finishedCalculated);
        }

        /**
         * Fills {@link #potentials} with the product of the CPTs of members whose parents all lie
         * in this clique and whose CPT no earlier clique has claimed.
         */
        private void calculatePotentials(Set<Node> cliques, Set<Node> finishedCalculated) {
            Graph dag = bayesIm.getDag();

            Set<Node> assigned = new HashSet<>();
            for (Node node : this.nodes) {
                if (!finishedCalculated.contains(node) && cliques.containsAll(dag.getParents(node))) {
                    assigned.add(node);
                    finishedCalculated.add(node);
                }
            }

            int size = this.nodes.length;
            int[] values = new int[size];
            for (int i = 0; i < this.cardinality; ++i) {
                int indexCPT = getIndexOfCPT(this.nodes, values);
                double potential = 1.0;
                for (int iNode = 0; iNode < size; ++iNode) {
                    Node node = this.nodes[iNode];
                    if (!assigned.contains(node)) {
                        continue;
                    }
                    int nodeIndex = bayesIm.getNodeIndex(node);
                    int rowIndex = getRowIndex(nodeIndex, values, this.nodes);
                    potential *= bayesIm.getProbability(nodeIndex, rowIndex, values[iNode]);
                }
                this.potentials[indexCPT] = potential;
                updateValues(size, values, this.nodes);
            }
        }

        /**
         * Performs one upward-pass step.
         *
         * <p>It resets {@link #prob} to {@link #potentials}, then multiplies in each child
         * separator's child-side table. Finally it refreshes the child-side table of this clique's
         * own parent separator.
         */
        public void initializeUp() {
            System.arraycopy(this.potentials, 0, this.prob, 0, this.cardinality);
            int size = this.nodes.length;
            for (TreeNode childNode : this.children) {
                TreeSeparator separator = childNode.parentSeparator;
                int[] values = new int[size];
                for (int i = 0; i < this.cardinality; ++i) {
                    int indexSepCPT = getIndexOfCPT(separator.nodes, values, this.nodes);
                    int indexNodeCPT = getIndexOfCPT(this.nodes, values);
                    this.prob[indexNodeCPT] *= separator.childPotentials[indexSepCPT];
                    updateValues(size, values, this.nodes);
                }
            }
            if (this.parentSeparator != null) {
                this.parentSeparator.updateFromChild();
            }
        }

        /**
         * Performs one downward-pass step.
         *
         * <p>It scales {@link #prob} by (parent-side / child-side) separator tables and zeroes
         * entries whose child-side value is not positive. It then refreshes the separator and this
         * clique's marginals. If {@code recursively} is true, it continues into the children.
         */
        public void initializeDown(boolean recursively) {
            if (this.parentSeparator != null) {
                TreeSeparator separator = this.parentSeparator;
                separator.updateFromParent();
                int size = this.nodes.length;
                int[] values = new int[size];
                for (int i = 0; i < this.cardinality; ++i) {
                    int indexSepCPT = getIndexOfCPT(separator.nodes, values, this.nodes);
                    int indexNodeCPT = getIndexOfCPT(this.nodes, values);
                    if (separator.childPotentials[indexSepCPT] > 0.0) {
                        this.prob[indexNodeCPT] *= separator.parentPotentials[indexSepCPT] / separator.childPotentials[indexSepCPT];
                    } else {
                        this.prob[indexNodeCPT] = 0.0;
                    }
                    updateValues(size, values, this.nodes);
                }
                separator.updateFromChild();
            }
            calculateMarginalProbabilities();
            if (recursively) {
                for (TreeNode childNode : this.children) {
                    childNode.initializeDown(true);
                }
            }
        }

        /**
         * Recomputes {@link #margProb} by summing {@link #prob} over the other members. The result is
         * published into the outer {@code margins} array, which aliases these arrays.
         */
        private void calculateMarginalProbabilities() {
            for (double[] marginal : this.margProb) {
                clear(marginal);
            }
            int size = this.nodes.length;
            int[] values = new int[size];
            for (int i = 0; i < this.cardinality; ++i) {
                double p = this.prob[getIndexOfCPT(this.nodes, values)];
                for (int iNode = 0; iNode < size; ++iNode) {
                    this.margProb[iNode][values[iNode]] += p;
                }
                updateValues(size, values, this.nodes);
            }
            for (int iNode = 0; iNode < size; ++iNode) {
                margins[bayesIm.getNodeIndex(this.nodes[iNode])] = this.margProb[iNode];
            }
        }

        /**
         * Returns the CPT row of node {@code nodeIndex} for the parent values found in {@code values}.
         * The row is a row-major index over the parents in {@code bayesIm.getParents} order.
         */
        private int getRowIndex(int nodeIndex, int[] values, Node[] nodes) {
            int index = 0;
            for (int parent : bayesIm.getParents(nodeIndex)) {
                Node node = bayesIm.getNode(parent);
                index *= bayesPm.getNumCategories(node);
                for (int j = 0; j < nodes.length; ++j) {
                    if (node == nodes[j]) {
                        index += values[j];
                    }
                }
            }
            return index;
        }

        /** Returns the position of {@code node} in {@link #nodes}, or -1. */
        private int getNodeIndex(Node node) {
            for (int i = 0; i < this.nodes.length; ++i) {
                if (this.nodes[i] == node) {
                    return i;
                }
            }
            return -1;
        }

        /**
         * Zeroes every configuration in which {@code node} does not equal {@code value}, then
         * propagates the change through the tree.
         *
         * @throws IllegalArgumentException if {@code node} is not in this clique
         */
        public void setEvidence(Node node, int value) {
            int nodeIndex = getNodeIndex(node);
            if (nodeIndex < 0) {
                String msg = String.format("Unable to find node %s in clique.", node.getName());
                throw new IllegalArgumentException(msg);
            }
            int size = this.nodes.length;
            int[] values = new int[size];
            for (int i = 0; i < this.cardinality; ++i) {
                if (values[nodeIndex] != value) {
                    this.prob[getIndexOfCPT(this.nodes, values)] = 0.0;
                }
                updateValues(size, values, this.nodes);
            }
            calculateMarginalProbabilities();
            updateEvidence(this);
        }

        /**
         * Propagates a change that arrived from {@code source} (a child, or this clique itself).
         *
         * <p>If the change came from a child, this clique first absorbs the child's message. The
         * change is then pushed down into the other children and up to the parent.
         */
        private void updateEvidence(TreeNode source) {
            if (source != this) {
                TreeSeparator separator = source.parentSeparator;
                int size = this.nodes.length;
                int[] values = new int[size];
                for (int i = 0; i < this.cardinality; ++i) {
                    int indexNodeCPT = getIndexOfCPT(this.nodes, values);
                    int indexSepCPT = getIndexOfCPT(separator.nodes, values, this.nodes);
                    if (separator.parentPotentials[indexSepCPT] != 0.0) {
                        this.prob[indexNodeCPT] *= separator.childPotentials[indexSepCPT] / separator.parentPotentials[indexSepCPT];
                    } else {
                        this.prob[indexNodeCPT] = 0.0;
                    }
                    updateValues(size, values, this.nodes);
                }
                calculateMarginalProbabilities();
            }
            for (TreeNode child : this.children) {
                if (child != source) {
                    child.initializeDown(true);
                }
            }
            if (this.parentSeparator != null) {
                this.parentSeparator.updateFromChild();
                this.parentSeparator.parentNode.updateEvidence(this);
                this.parentSeparator.updateFromParent();
            }
        }

        /** Returns the sum over this subtree of log(child-side separator table) at {@code nodeValues}. */
        private double getLogJointSeparatorPotentials(int[] nodeValues) {
            double logJointPotentials = 0.0;
            if (this.parentSeparator != null) {
                Node[] sepNodes = this.parentSeparator.nodes;
                int[] values = new int[sepNodes.length];
                for (int iNode = 0; iNode < sepNodes.length; ++iNode) {
                    values[iNode] = nodeValues[bayesIm.getNodeIndex(sepNodes[iNode])];
                }
                logJointPotentials += FastMath.log(this.parentSeparator.childPotentials[getIndexOfCPT(sepNodes, values)]);
            }
            for (TreeNode child : this.children) {
                logJointPotentials += child.getLogJointSeparatorPotentials(nodeValues);
            }
            return logJointPotentials;
        }

        /** Returns the sum over this subtree of log(clique table) at {@code nodeValues}. */
        private double getLogJointClusterPotentials(int[] nodeValues) {
            int size = this.nodes.length;
            int[] values = new int[size];
            for (int iNode = 0; iNode < size; ++iNode) {
                values[iNode] = nodeValues[bayesIm.getNodeIndex(this.nodes[iNode])];
            }
            double logJointPotentials = FastMath.log(this.prob[getIndexOfCPT(this.nodes, values)]);
            for (TreeNode child : this.children) {
                logJointPotentials += child.getLogJointClusterPotentials(nodeValues);
            }
            return logJointPotentials;
        }

        public void setParentSeparator(TreeSeparator parentSeparator) {
            this.parentSeparator = parentSeparator;
        }

        public void addChildClique(TreeNode child) {
            this.children.add(child);
        }

        public Set<Node> getClique() {
            return this.clique;
        }

        public boolean contains(Node node) {
            return this.clique.contains(node);
        }

        /** One "name: p0 p1 ..." line per member, followed by each child's subtree behind a dashed rule. */
        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < this.nodes.length; ++i) {
                sb.append(this.nodes[i].getName());
                sb.append(": ");
                sb.append(Arrays.stream(this.margProb[i]).mapToObj(String::valueOf).collect(Collectors.joining(" ")));
                sb.append('\n');
            }
            for (TreeNode childNode : this.children) {
                sb.append("----------------\n");
                sb.append(childNode.toString());
            }
            return sb.toString();
        }
    }

    /**
     * The separator between a child clique and its parent.
     *
     * <p>It holds the projection of each side's clique table onto the shared nodes.
     */
    private class TreeSeparator implements TetradSerializable {
        private static final long serialVersionUID = 23L;
        /** Projection of the parent clique's table onto {@link #nodes}. */
        private final double[] parentPotentials;
        /** Projection of the child clique's table onto {@link #nodes}. */
        private final double[] childPotentials;
        /** Separator members in canonical ({@code graphNodes}) order. */
        private final Node[] nodes;
        private final TreeNode childNode;
        private final TreeNode parentNode;

        public TreeSeparator(Set<Node> separator, TreeNode childNode, TreeNode parentNode) {
            this.childNode = childNode;
            this.parentNode = parentNode;
            this.nodes = toArray(separator);
            int cardinality = getCardinality(separator);
            this.parentPotentials = new double[cardinality];
            this.childPotentials = new double[cardinality];
        }

        /** Overwrites {@code potentials} with {@code node}'s table summed down to the separator's nodes. */
        public void update(TreeNode node, double[] potentials) {
            clear(potentials);
            int size = node.nodes.length;
            int[] values = new int[size];
            for (int i = 0; i < node.cardinality; ++i) {
                int indexNodeCPT = getIndexOfCPT(node.nodes, values);
                int indexSepCPT = getIndexOfCPT(this.nodes, values, node.nodes);
                potentials[indexSepCPT] += node.prob[indexNodeCPT];
                updateValues(size, values, node.nodes);
            }
        }

        public void updateFromParent() {
            update(this.parentNode, this.parentPotentials);
        }

        public void updateFromChild() {
            update(this.childNode, this.childPotentials);
        }
    }
}

