/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.bayes;

import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.CptInvariantUpdater;
import edu.cmu.tetrad.bayes.Evidence;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.bayes.RowSummingExactUpdater;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.util.TetradLogger;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;

/**
 * JUnit 3 tests for {@link RowSummingExactUpdater}.
 *
 * <p>{@code testUpdate1}..{@code testUpdate3} compare the CPTs of the updated {@link BayesIm}
 * against hand-computed (rounded) conditional probabilities on small discrete networks.
 * {@code testUpdate1}, {@code testUpdate4} and {@code testUpdate5} check that the marginals
 * reported by {@link RowSummingExactUpdater} agree with those of {@link CptInvariantUpdater}
 * under the same evidence.
 */
public final class TestRowSummingUpdater extends TestCase {

    /** Tolerance for comparisons against hand-computed, rounded probabilities. */
    private static final double TOLERANCE = 0.001;

    /** Tolerance for comparing marginals produced by two exact updaters. */
    private static final double AGREEMENT_TOLERANCE = 1.0E-6;

    public TestRowSummingUpdater(String name) {
        super(name);
    }

    /** Attaches stdout to the TetradLogger and forces logging for the duration of a test. */
    @Override
    public void setUp() throws Exception {
        TetradLogger.getInstance().addOutputStream(System.out);
        TetradLogger.getInstance().setForceLog(true);
    }

    /** Reverses {@link #setUp()}: disables forced logging and detaches stdout. */
    @Override
    public void tearDown() {
        TetradLogger.getInstance().setForceLog(false);
        TetradLogger.getInstance().removeOutputStream(System.out);
    }

    /**
     * Network x -> z (see {@link #sampleBayesIm1()}) with evidence z = 1.
     * By hand: P(x=0 | z=1) = 0.3*0.2 / (0.3*0.2 + 0.7*0.6) = 0.125.
     */
    public static void testUpdate1() {
        BayesIm bayesIm = sampleBayesIm1();
        RowSummingExactUpdater updater = new RowSummingExactUpdater(bayesIm);
        Evidence evidence = Evidence.tautology(bayesIm);
        int xIndex = evidence.getNodeIndex("x");
        int zIndex = evidence.getNodeIndex("z");
        int valueIndex = evidence.getCategoryIndex("z", "1");
        evidence.getProposition().setCategory(zIndex, valueIndex);
        updater.setEvidence(evidence);

        BayesIm updatedIm = updater.getUpdatedBayesIm();
        System.out.println(bayesIm.getBayesPm());
        System.out.println(bayesIm);
        System.out.println(updatedIm);

        assertRow(updatedIm, 0, 0, 0.125, 0.875);
        // z is clamped to category 1, so each of its rows puts all mass there.
        assertRow(updatedIm, 1, 0, 0.0, 1.0);
        assertRow(updatedIm, 1, 1, 0.0, 1.0);

        // Check the marginal directly, then cross-check it against CptInvariantUpdater
        // (given a copy of the evidence, as the original test did).
        double marginal = updater.getMarginal(xIndex, 0);
        System.out.println(marginal);
        assertEquals(0.125, marginal, TOLERANCE);
        assertUpdatersAgree(bayesIm, new Evidence(evidence, bayesIm), xIndex, 0);
    }

    /**
     * Network a -> b, a -> c, b -> c (see {@link #sampleBayesIm2()}) with evidence c = 1.
     * Expected values are P(a | c=1) and P(b | a, c=1) computed by hand and rounded.
     */
    public static void testUpdate2() {
        BayesIm bayesIm = sampleBayesIm2();
        RowSummingExactUpdater updater = new RowSummingExactUpdater(bayesIm);
        Evidence evidence = Evidence.tautology(bayesIm);
        int nodeIndex = evidence.getNodeIndex("c");
        int valueIndex = evidence.getCategoryIndex("c", "1");
        evidence.getProposition().setCategory(nodeIndex, valueIndex);
        updater.setEvidence(evidence);

        BayesIm updatedIm = updater.getUpdatedBayesIm();
        System.out.println(bayesIm.getBayesPm());
        System.out.println(bayesIm);
        System.out.println(updatedIm);

        assertRow(updatedIm, 0, 0, 0.275, 0.725);
        assertRow(updatedIm, 1, 0, 0.0556, 0.6667, 0.2778);
        assertRow(updatedIm, 1, 1, 0.7869, 0.0656, 0.1475);
        // c is clamped to category 1, so every one of its six rows puts all mass there.
        for (int row = 0; row < 6; row++) {
            assertRow(updatedIm, 2, row, 0.0, 1.0);
        }
    }

    /**
     * Same network as {@link #testUpdate2()} but with evidence b = 0.
     * By hand: P(a=0 | b=0) = 0.3*0.3 / (0.3*0.3 + 0.7*0.6) ~= 0.1765.
     */
    public static void testUpdate3() {
        BayesIm bayesIm = sampleBayesIm2();
        RowSummingExactUpdater updater = new RowSummingExactUpdater(bayesIm);
        Evidence evidence = Evidence.tautology(bayesIm);
        int nodeIndex = evidence.getNodeIndex("b");
        int valueIndex = evidence.getCategoryIndex("b", "0");
        evidence.getProposition().setCategory(nodeIndex, valueIndex);
        System.out.println(evidence);
        updater.setEvidence(evidence);

        BayesIm updatedIm = updater.getUpdatedBayesIm();
        System.out.println(bayesIm.getBayesPm());
        System.out.println(bayesIm);
        System.out.println(updatedIm);

        assertRow(updatedIm, 0, 0, 0.1765, 0.8235);
        // b is clamped to category 0, so both of its rows put all mass there.
        assertRow(updatedIm, 1, 0, 1.0, 0.0, 0.0);
        assertRow(updatedIm, 1, 1, 1.0, 0.0, 0.0);
        // Rows 0 and 3 of c keep their original values; the other rows are expected to be
        // NaN -- presumably those condition on b != 0, which the evidence rules out.
        assertRow(updatedIm, 2, 0, 0.9, 0.1);
        assertRowUndefined(updatedIm, 2, 1, 2);
        assertRowUndefined(updatedIm, 2, 2, 2);
        assertRow(updatedIm, 2, 3, 0.2, 0.8);
        assertRowUndefined(updatedIm, 2, 4, 2);
        assertRowUndefined(updatedIm, 2, 5, 2);
    }

    /**
     * Diamond X0 -> {X1, X2} -> X3 with evidence X2 = 0; both exact updaters must report
     * the same P(X3 = 0 | evidence).
     */
    public static void testUpdate4() {
        GraphNode x0Node = new GraphNode("X0");
        GraphNode x1Node = new GraphNode("X1");
        GraphNode x2Node = new GraphNode("X2");
        GraphNode x3Node = new GraphNode("X3");
        Dag graph = new Dag();
        graph.addNode(x0Node);
        graph.addNode(x1Node);
        graph.addNode(x2Node);
        graph.addNode(x3Node);
        graph.addDirectedEdge(x0Node, x1Node);
        graph.addDirectedEdge(x0Node, x2Node);
        graph.addDirectedEdge(x1Node, x3Node);
        graph.addDirectedEdge(x2Node, x3Node);
        System.out.println(graph);

        BayesPm bayesPm = new BayesPm(graph);
        // NOTE(review): initialization mode 1 is presumably random parameters; confirm in MlBayesIm.
        MlBayesIm bayesIm = new MlBayesIm(bayesPm, 1);
        int x2 = bayesIm.getNodeIndex(x2Node);
        int x3 = bayesIm.getNodeIndex(x3Node);
        System.out.println(bayesIm);

        Evidence evidence = Evidence.tautology(bayesIm);
        evidence.getProposition().setCategory(x2, 0);
        System.out.println(evidence);

        assertUpdatersAgree(bayesIm, evidence, x3, 0);
    }

    /**
     * Diamond X0 -> {X1, X2} -> X3 plus X4 -> X0 and X4 -> X2, with evidence X1 = 1 and
     * X2 = 0; both exact updaters must report the same P(X3 = 0 | evidence).
     */
    public static void testUpdate5() {
        GraphNode x0Node = new GraphNode("X0");
        GraphNode x1Node = new GraphNode("X1");
        GraphNode x2Node = new GraphNode("X2");
        GraphNode x3Node = new GraphNode("X3");
        GraphNode x4Node = new GraphNode("X4");
        Dag graph = new Dag();
        graph.addNode(x0Node);
        graph.addNode(x1Node);
        graph.addNode(x2Node);
        graph.addNode(x3Node);
        graph.addNode(x4Node);
        graph.addDirectedEdge(x0Node, x1Node);
        graph.addDirectedEdge(x0Node, x2Node);
        graph.addDirectedEdge(x1Node, x3Node);
        graph.addDirectedEdge(x2Node, x3Node);
        graph.addDirectedEdge(x4Node, x0Node);
        graph.addDirectedEdge(x4Node, x2Node);
        System.out.println(graph);

        BayesPm bayesPm = new BayesPm(graph);
        // NOTE(review): initialization mode 1 is presumably random parameters; confirm in MlBayesIm.
        MlBayesIm bayesIm = new MlBayesIm(bayesPm, 1);
        int x1 = bayesIm.getNodeIndex(x1Node);
        int x2 = bayesIm.getNodeIndex(x2Node);
        int x3 = bayesIm.getNodeIndex(x3Node);
        System.out.println(bayesIm);

        Evidence evidence = Evidence.tautology(bayesIm);
        evidence.getProposition().setCategory(x1, 1);
        evidence.getProposition().setCategory(x2, 0);
        System.out.println(evidence);

        assertUpdatersAgree(bayesIm, evidence, x3, 0);
    }

    /**
     * Asserts that {@link CptInvariantUpdater} and {@link RowSummingExactUpdater}, each given
     * {@code evidence}, report the same marginal probability of {@code category} for
     * {@code node}.
     */
    private static void assertUpdatersAgree(BayesIm bayesIm, Evidence evidence, int node,
                                            int category) {
        CptInvariantUpdater cptInvariantUpdater = new CptInvariantUpdater(bayesIm);
        cptInvariantUpdater.setEvidence(evidence);
        RowSummingExactUpdater rowSummingUpdater = new RowSummingExactUpdater(bayesIm);
        rowSummingUpdater.setEvidence(evidence);

        double marginal1 = cptInvariantUpdater.getMarginal(node, category);
        double marginal2 = rowSummingUpdater.getMarginal(node, category);
        System.out.println("Marginal from CPT Inv = " + marginal1);
        System.out.println("Marginal from Row Summer = " + marginal2);
        assertEquals(marginal1, marginal2, AGREEMENT_TOLERANCE);
    }

    /**
     * Asserts that row {@code row} of node {@code node}'s CPT in {@code im} matches
     * {@code expected} (one value per category, starting at category 0) within
     * {@link #TOLERANCE}.
     */
    private static void assertRow(BayesIm im, int node, int row, double... expected) {
        for (int category = 0; category < expected.length; category++) {
            assertEquals("node " + node + ", row " + row + ", category " + category,
                    expected[category], im.getProbability(node, row, category), TOLERANCE);
        }
    }

    /**
     * Asserts that the first {@code numCategories} entries of row {@code row} of node
     * {@code node}'s CPT in {@code im} are NaN. JUnit 3's delta-based assertEquals cannot
     * match NaN, hence the explicit isNaN check.
     */
    private static void assertRowUndefined(BayesIm im, int node, int row, int numCategories) {
        for (int category = 0; category < numCategories; category++) {
            assertTrue("expected NaN at node " + node + ", row " + row + ", category " + category,
                    Double.isNaN(im.getProbability(node, row, category)));
        }
    }

    /**
     * Builds the two-node network x -> z with fixed CPTs (each row is given for
     * categories 0 and 1).
     */
    private static BayesIm sampleBayesIm1() {
        GraphNode x = new GraphNode("x");
        GraphNode z = new GraphNode("z");
        Dag graph = new Dag();
        graph.addNode(x);
        graph.addNode(z);
        graph.addDirectedEdge(x, z);
        System.out.println(graph);

        BayesPm bayesPm = new BayesPm(graph);
        MlBayesIm bayesIm1 = new MlBayesIm(bayesPm);
        bayesIm1.setProbability(0, 0, 0, 0.3);
        bayesIm1.setProbability(0, 0, 1, 0.7);
        bayesIm1.setProbability(1, 0, 0, 0.8);
        bayesIm1.setProbability(1, 0, 1, 0.2);
        bayesIm1.setProbability(1, 1, 0, 0.4);
        bayesIm1.setProbability(1, 1, 1, 0.6);
        return bayesIm1;
    }

    /**
     * Builds the network a -> b, a -> c, b -> c with fixed CPTs. b has three categories;
     * rows of a and c are given for categories 0 and 1.
     */
    private static BayesIm sampleBayesIm2() {
        GraphNode a = new GraphNode("a");
        GraphNode b = new GraphNode("b");
        GraphNode c = new GraphNode("c");
        Dag graph = new Dag();
        graph.addNode(a);
        graph.addNode(b);
        graph.addNode(c);
        graph.addDirectedEdge(a, b);
        graph.addDirectedEdge(a, c);
        graph.addDirectedEdge(b, c);
        System.out.println(graph);

        BayesPm bayesPm = new BayesPm(graph);
        bayesPm.setNumCategories(b, 3);
        MlBayesIm bayesIm1 = new MlBayesIm(bayesPm);
        bayesIm1.setProbability(0, 0, 0, 0.3);
        bayesIm1.setProbability(0, 0, 1, 0.7);
        bayesIm1.setProbability(1, 0, 0, 0.3);
        bayesIm1.setProbability(1, 0, 1, 0.4);
        bayesIm1.setProbability(1, 0, 2, 0.3);
        bayesIm1.setProbability(1, 1, 0, 0.6);
        bayesIm1.setProbability(1, 1, 1, 0.1);
        bayesIm1.setProbability(1, 1, 2, 0.3);
        bayesIm1.setProbability(2, 0, 0, 0.9);
        bayesIm1.setProbability(2, 0, 1, 0.1);
        bayesIm1.setProbability(2, 1, 0, 0.1);
        bayesIm1.setProbability(2, 1, 1, 0.9);
        bayesIm1.setProbability(2, 2, 0, 0.5);
        bayesIm1.setProbability(2, 2, 1, 0.5);
        bayesIm1.setProbability(2, 3, 0, 0.2);
        bayesIm1.setProbability(2, 3, 1, 0.8);
        bayesIm1.setProbability(2, 4, 0, 0.6);
        bayesIm1.setProbability(2, 4, 1, 0.4);
        bayesIm1.setProbability(2, 5, 0, 0.7);
        bayesIm1.setProbability(2, 5, 1, 0.3);
        return bayesIm1;
    }

    public static Test suite() {
        return new TestSuite(TestRowSummingUpdater.class);
    }
}

