/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.bayes;

import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.Evidence;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.bayes.RowSummingExactUpdater;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;

/**
 * Tests {@link RowSummingExactUpdater#getJointMarginal(int[], int[])} on a small
 * discrete Bayes net: a binary node L1 that is the sole parent of the binary
 * nodes X1, X2 and X3.
 */
public final class TestUpdaterJointMarginal
extends TestCase {
    /** Absolute tolerance used when comparing computed probabilities. */
    private static final double TOLERANCE = 1.0E-4;

    public TestUpdaterJointMarginal(String name) {
        super(name);
    }

    /**
     * Builds L1 -> {X1, X2, X3} with P(L1 = 0) = 1/2 and, for every child,
     * P(Xi = 0 | L1 = 0) = 1/3 and P(Xi = 0 | L1 = 1) = 2/3. It then sets the
     * evidence X1 = X2 = X3 = 0. By Bayes' rule:
     *
     * <pre>
     *   P(L1 = 0 | e) = (1/3)^3 / ((1/3)^3 + (2/3)^3) = 1/9
     *   P(L1 = 1 | e) = 8/9
     * </pre>
     *
     * The evidence fixes X1 = 0, so the joint marginal P(L1 = v, X1 = 0 | e)
     * equals P(L1 = v | e).
     */
    public static void testEstimate1() {
        Dag graph = new Dag();
        GraphNode L1 = new GraphNode("L1");
        GraphNode X1 = new GraphNode("X1");
        GraphNode X2 = new GraphNode("X2");
        GraphNode X3 = new GraphNode("X3");
        L1.setNodeType(NodeType.MEASURED);
        X1.setNodeType(NodeType.MEASURED);
        X2.setNodeType(NodeType.MEASURED);
        X3.setNodeType(NodeType.MEASURED);
        graph.addNode(L1);
        graph.addNode(X1);
        graph.addNode(X2);
        graph.addNode(X3);
        graph.addDirectedEdge(L1, X1);
        graph.addDirectedEdge(L1, X2);
        graph.addDirectedEdge(L1, X3);

        BayesPm bayesPm = new BayesPm(graph);
        bayesPm.setNumCategories(L1, 2);
        bayesPm.setNumCategories(X1, 2);
        bayesPm.setNumCategories(X2, 2);
        bayesPm.setNumCategories(X3, 2);

        MlBayesIm estimatedIm = new MlBayesIm(bayesPm);
        int l1Index = estimatedIm.getNodeIndex(graph.getNode("L1"));
        int x1Index = estimatedIm.getNodeIndex(graph.getNode("X1"));
        int x2Index = estimatedIm.getNodeIndex(graph.getNode("X2"));
        int x3Index = estimatedIm.getNodeIndex(graph.getNode("X3"));

        // Uniform prior on L1 (it has no parents, so its table has a single row).
        estimatedIm.setProbability(l1Index, 0, 0, 0.5);
        estimatedIm.setProbability(l1Index, 0, 1, 0.5);
        // All three children share the same conditional table. Exact thirds are
        // used (rather than 0.33333/0.66667) so the expected posteriors are exact.
        setChildTable(estimatedIm, x1Index);
        setChildTable(estimatedIm, x2Index);
        setChildTable(estimatedIm, x3Index);

        Evidence evidence = Evidence.tautology(estimatedIm);
        evidence.getProposition().setCategory(x1Index, 0);
        evidence.getProposition().setCategory(x2Index, 0);
        evidence.getProposition().setCategory(x3Index, 0);

        RowSummingExactUpdater rseu = new RowSummingExactUpdater(estimatedIm);
        rseu.setEvidence(evidence);

        int[] vars1 = new int[]{l1Index};
        int[] vals1 = new int[]{0};
        double p1 = rseu.getJointMarginal(vars1, vals1);
        assertEquals(1.0 / 9.0, p1, TOLERANCE);
        System.out.println("p1 = " + p1);

        int[] vars2 = new int[]{l1Index, x1Index};
        int[] vals2 = new int[]{0, 0};
        double p2 = rseu.getJointMarginal(vars2, vals2);
        assertEquals(1.0 / 9.0, p2, TOLERANCE);
        System.out.println("p2 = " + p2);

        // The expected value is the exact 8/9. The previous literal 0.8888 was
        // truncated and sat only ~1e-5 inside the tolerance.
        int[] vals3 = new int[]{1, 0};
        double p3 = rseu.getJointMarginal(vars2, vals3);
        assertEquals(8.0 / 9.0, p3, TOLERANCE);
        System.out.println("p3 = " + p3);
    }

    /**
     * Fills the conditional table of a binary child of L1. Row r is taken to be
     * L1 = r, which is consistent with the posteriors asserted above. The table
     * gives P(child = 0 | L1 = 0) = 1/3 and P(child = 0 | L1 = 1) = 2/3.
     *
     * @param im         the instantiated model to modify
     * @param childIndex node index of the child within {@code im}
     */
    private static void setChildTable(MlBayesIm im, int childIndex) {
        double oneThird = 1.0 / 3.0;
        double twoThirds = 2.0 / 3.0;
        im.setProbability(childIndex, 0, 0, oneThird);
        im.setProbability(childIndex, 0, 1, twoThirds);
        im.setProbability(childIndex, 1, 0, twoThirds);
        im.setProbability(childIndex, 1, 1, oneThird);
    }

    public static Test suite() {
        return new TestSuite(TestUpdaterJointMarginal.class);
    }
}

