/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.mb;

import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.Evidence;
import edu.cmu.tetrad.bayes.MlBayesEstimator;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.bayes.OnTheFlyMarginalCalculator;
import edu.cmu.tetrad.bayes.Proposition;
import edu.cmu.tetrad.bayes.RowSummingExactUpdater;
import edu.cmu.tetrad.data.DataReader;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataWriter;
import edu.cmu.tetrad.data.DelimiterType;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.GesMbFilter;
import edu.cmu.tetrad.search.IndTestFisherZ;
import edu.cmu.tetrad.search.IndTestGSquare;
import edu.cmu.tetrad.search.MbClassify;
import edu.cmu.tetrad.search.MbUtils;
import edu.cmu.tetrad.search.Mbfs;
import edu.cmu.tetrad.search.mb.Mmmb;
import edu.cmu.tetrad.sem.LargeSemSimulator;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.TetradLogger;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Writer;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;

/**
 * Exploratory tests for {@link MbClassify} and for the Markov-blanket searches
 * ({@link Mmmb}, {@link Mbfs}) that feed Bayes-net classifiers.
 *
 * <p>Only {@link #test1()} follows JUnit 3's {@code test*} naming convention; the
 * {@code rtest*} / {@code r*} methods are disabled experiments kept for manual runs.
 * Several of them read fixtures from {@code test_data/} or simulate large random
 * networks, so they can be slow.
 */
public class TestMbClassify extends TestCase {
    /** Expected cross-tabulation rows checked by {@link #rtest2()}. */
    static final int[][] testCrosstabs = new int[][]{{496, 53}, {93, 358}};
    /** Expected cross-tabulation rows checked by {@link #rtest3()}. */
    static final int[][] testCrosstabsNew = new int[][]{{38, 9, 7}, {10, 15, 14}, {5, 3, 59}};

    public TestMbClassify(String name) {
        super(name);
    }

    /** Routes TetradLogger output to stdout and forces logging on for each test. */
    @Override
    public void setUp() throws Exception {
        TetradLogger.getInstance().addOutputStream(System.out);
        TetradLogger.getInstance().setForceLog(true);
    }

    /** Undoes the logger changes made in {@link #setUp()}. */
    @Override
    public void tearDown() {
        TetradLogger.getInstance().setForceLog(false);
        TetradLogger.getInstance().removeOutputStream(System.out);
    }

    /** Placeholder so the suite contains at least one runnable test. */
    public void test1() {
    }

    /** Runs MbClassify on the soybean data set, using it as both training and test data. */
    public void rtest1() {
        try {
            DataReader reader = new DataReader();
            reader.setDelimiter(DelimiterType.COMMA);
            reader.setIdsSupplied(true);
            reader.setIdLabel(null);
            DataSet data = reader.parseTabular(new File("test_data/soybean.data"));
            System.out.println(data);
            reader.setKnownVariables(data.getVariables());
            double alpha = 1.0E-6;
            int depth = 2;
            double prior = 1.0E-4;
            int maxMissing = 9;
            new MbClassify(data, data, "class", alpha, depth, prior, maxMissing).classify();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Checks MbClassify's cross-tabulation for A6 against {@link #testCrosstabs}. */
    public void rtest2() {
        String train = "test_data/markovBlanketTestDisc.dat";
        String test = "test_data/markovBlanketTestDisc.dat";
        String variable = "A6";
        String alpha = "0.01";
        String depth = "3";
        String prior = "0.001";
        String maxMissing = "4";
        MbClassify mbClassify = new MbClassify(train, test, variable, alpha, depth, prior, maxMissing);
        mbClassify.classify();
        assertRowsEqual(testCrosstabs, mbClassify.crossTabulation());
    }

    /** Checks MbClassify's cross-tabulation for col2 against {@link #testCrosstabsNew}. */
    public void rtest3() {
        String train = "test_data/sampledata.txt";
        String test = "test_data/sampledata.txt";
        String variable = "col2";
        String alpha = "0.05";
        String depth = "2";
        String prior = "0.05";
        String maxMissing = "1";
        MbClassify mbClassify = new MbClassify(train, test, variable, alpha, depth, prior, maxMissing);
        mbClassify.classify();
        assertRowsEqual(testCrosstabsNew, mbClassify.crossTabulation());
    }

    /** Runs MbClassify for X280 on the ar_test fixture; no assertions are made. */
    public void rtest5() {
        String train = "test_data/ar_test_01-20-07.txt";
        String test = "test_data/ar_test_01-20-07.txt";
        String variable = "X280";
        String alpha = "0.05";
        String depth = "2";
        String prior = "0.05";
        String maxMissing = "1";
        MbClassify mbClassify = new MbClassify(train, test, variable, alpha, depth, prior, maxMissing);
        mbClassify.classify();
    }

    /**
     * Simulates 300 rows of continuous data from a random 1400-node DAG and writes
     * them, comma-delimited, to {@code test_data/myout.dat}.
     *
     * @throws RuntimeException wrapping any {@link IOException} raised while writing
     */
    public void makeData() {
        try {
            Dag randomGraph = GraphUtils.randomDag(1400, 0, 30, 9, 3, 9, false);
            DataSet dataSet = this.simulateContinuous(randomGraph, 300);
            FileWriter out = new FileWriter(new File("test_data/myout.dat"));
            try {
                DataWriter.writeRectangularData(dataSet, out, ',');
            } finally {
                // Close even if writing fails so the file handle is not leaked.
                out.close();
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Simulates {@code sampleSize} rows from a large SEM over {@code randomGraph}. */
    private DataSet simulateContinuous(Dag randomGraph, int sampleSize) {
        LargeSemSimulator simulator = new LargeSemSimulator(randomGraph);
        return simulator.simulateDataAcyclic(sampleSize);
    }

    /** Times an Mmmb Markov-blanket search for X3 and logs the elapsed time and result. */
    public void rtest6() {
        try {
            String trainPath = "test_data/sp1s_aa_train.txt";
            String target = "X3";
            double alpha = 0.05;
            int depth = 3;
            DataReader reader = new DataReader();
            reader.setVariablesSupplied(false);
            DataSet train = reader.parseTabular(new File(trainPath));
            IndTestFisherZ indTest = new IndTestFisherZ(train, alpha);
            long time0 = System.currentTimeMillis();
            System.out.println("Start");
            Mmmb search = new Mmmb(indTest, depth, false);
            List<Node> mb = search.findMb(target);
            long diff = System.currentTimeMillis() - time0;
            TetradLogger.getInstance().log("info", "Elapsed: " + diff);
            TetradLogger.getInstance().log("info", "MB = " + mb);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Runs an Mbfs Markov-blanket search for X3 on whitespace-delimited data. */
    public void rtest7() {
        try {
            String trainPath = "test_data/sp1s_aa_train.txt";
            String target = "X3";
            double alpha = 1.0E-4;
            int depth = 2;
            DataReader reader = new DataReader();
            reader.setDelimiter(DelimiterType.WHITESPACE);
            reader.setVariablesSupplied(false);
            DataSet train = reader.parseTabular(new File(trainPath));
            IndTestFisherZ indTest = new IndTestFisherZ(train, alpha);
            Mbfs search = new Mbfs(indTest, depth);
            search.findMb(target);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Simulates discrete data from a random Bayes net. For 10 random targets whose
     * true Markov blanket has at least 4 nodes, it compares classification using the
     * true blanket against classification using the blankets found by Mbfs
     * (with and without the GES filter).
     */
    public void rtestSimulatedClassify() {
        int numVars = 1000;
        int numEdges = 1000;
        int numTrainSamples = 1000;
        int numTestSamples = 1000;
        int minNumCategories = 2;
        int maxNumCategories = 2;
        double alpha = 0.001;
        int depth = 2;
        System.out.println("Number of variables = " + numVars);
        System.out.println("Number of randomly selected edges = " + numEdges);
        System.out.println("Number of training samples  = " + numTrainSamples);
        System.out.println("Number of test samples  = " + numTestSamples);
        System.out.println();
        System.out.println();
        System.out.println("... creating random DAG");
        Dag randomGraph = GraphUtils.randomDag(numVars, 0, numEdges, 40, 40, 40, false);
        System.out.println("... creating Bayes PM");
        BayesPm bayesPm = new BayesPm(randomGraph, minNumCategories, maxNumCategories);
        System.out.println("... creating Bayes IM");
        MlBayesIm bayesIm = new MlBayesIm(bayesPm, 1);
        System.out.println("... simulating data");
        DataSet trainData = bayesIm.simulateData(numTrainSamples, false);
        DataSet testData = bayesIm.simulateData(numTestSamples, false);
        List<Node> variablesTest = testData.getVariables();
        int example = 0;
        while (example < 10) {
            int i = RandomUtil.getInstance().nextInt(numVars);
            DiscreteVariable target = (DiscreteVariable) variablesTest.get(i);
            Node node = randomGraph.getNode(target.getName());
            Dag trueMb0 = GraphUtils.markovBlanketDag(node, randomGraph);
            if (trueMb0.getNumNodes() < 4) {
                // Blanket too small to be an interesting example; draw another target.
                continue;
            }
            System.out.println();
            System.out.println("***************************************************");
            System.out.println();
            System.out.println("EXAMPLE #" + (example + 1) + " TARGET = " + target);
            Dag trueMb1 = new Dag(trueMb0);
            Dag trueMb2 = this.useVariablesFromData(trueMb1, trainData);
            System.out.println("\nTrue MB: " + trueMb2);
            System.out.println("\nClassification using true MB:");
            this.classifyMbOfm(trainData, testData, trueMb2, target);
            this.trainAndTest(trainData, testData, target, alpha, depth);
            ++example;
        }
    }

    /**
     * For 10 random targets whose true Markov blanket has at least 7 nodes, it
     * classifies with the true blanket. It then runs Mbfs at alpha = 0.01, 0.02, ..., 0.05.
     */
    public void rCompareAlphas() {
        int numVars = 1000;
        int numEdges = 1000;
        int numTrainSamples = 500;
        int numTestSamples = 100;
        int depth = 2;
        int minNumCategories = 3;
        int maxNumCategories = 3;
        System.out.println("Number of variables = " + numVars);
        System.out.println("Number of randomly selected edges = " + numEdges);
        System.out.println("Number of training samples  = " + numTrainSamples);
        System.out.println("Number of test samples  = " + numTestSamples);
        System.out.println();
        System.out.println();
        System.out.println("... creating random DAG");
        Dag randomGraph = GraphUtils.randomDag(numVars, 0, numEdges, 40, 40, 40, false);
        System.out.println("... creating Bayes PM");
        BayesPm bayesPm = new BayesPm(randomGraph, minNumCategories, maxNumCategories);
        System.out.println("... creating Bayes IM");
        MlBayesIm bayesIm = new MlBayesIm(bayesPm, 1);
        System.out.println("... simulating data");
        DataSet trainData = bayesIm.simulateData(numTrainSamples, false);
        DataSet testData = bayesIm.simulateData(numTestSamples, false);
        List<Node> variables = trainData.getVariables();
        int example = 0;
        while (example < 10) {
            int i = RandomUtil.getInstance().nextInt(numVars);
            DiscreteVariable target = (DiscreteVariable) variables.get(i);
            Node node = randomGraph.getNode(target.getName());
            Dag trueMb0 = GraphUtils.markovBlanketDag(node, randomGraph);
            if (trueMb0.getNumNodes() < 7) {
                // Blanket too small to be an interesting example; draw another target.
                continue;
            }
            System.out.println();
            System.out.println("***************************************************");
            System.out.println();
            System.out.println("TARGET: " + target);
            Dag trueMb1 = new Dag(trueMb0);
            Dag trueMb2 = this.useVariablesFromData(trueMb1, trainData);
            System.out.println("\nTrue MB: " + trueMb2);
            System.out.println("\nClassification using true MB:");
            this.classifyMbRsu(trainData, testData, trueMb2, target);
            // Integer counter: accumulating 0.01 in a double drifts and can drop the last step.
            for (int hundredths = 1; hundredths <= 5; ++hundredths) {
                double alpha = hundredths / 100.0;
                this.trainAndTest(trainData, testData, target, alpha, depth);
            }
            ++example;
        }
    }

    /**
     * For one random target, classifies a fixed test set with the true Markov blanket.
     * The blanket's Bayes IM is re-estimated from training sets of 100, 200, ..., 1900 rows.
     */
    public void rtestCompareTainingSampleSizes() {
        int numVars = 1500;
        int numEdges = 1500;
        int numTestSamples = 100;
        int minNumCategories = 3;
        int maxNumCategories = 3;
        System.out.println("Number of variables = " + numVars);
        System.out.println("Number of randomly selected edges = " + numEdges);
        System.out.println("Number of test samples  = " + numTestSamples);
        System.out.println();
        System.out.println();
        System.out.println("... creating random DAG");
        Dag randomGraph = GraphUtils.randomDag(numVars, 0, numEdges, 40, 40, 40, false);
        System.out.println("... creating Bayes PM");
        BayesPm bayesPm = new BayesPm(randomGraph, minNumCategories, maxNumCategories);
        System.out.println("... creating Bayes IM");
        MlBayesIm bayesIm = new MlBayesIm(bayesPm, 1);
        System.out.println("... simulating data");
        DataSet testData = bayesIm.simulateData(numTestSamples, false);
        List<Node> variables = testData.getVariables();
        int i = RandomUtil.getInstance().nextInt(numVars);
        DiscreteVariable target = (DiscreteVariable) variables.get(i);
        Node node = randomGraph.getNode(target.getName());
        Dag trueMb0 = GraphUtils.markovBlanketDag(node, randomGraph);
        System.out.println("TARGET: " + target);
        for (int numTrainSamples = 100; numTrainSamples < 2000; numTrainSamples += 100) {
            System.out.println();
            System.out.println("***************************************************");
            System.out.println();
            System.out.println("Number of training samples  = " + numTrainSamples);
            DataSet trainData = bayesIm.simulateData(numTrainSamples, false);
            Dag trueMb1 = new Dag(trueMb0);
            Dag trueMb2 = this.useVariablesFromData(trueMb1, trainData);
            System.out.println("\nTrue MB: " + trueMb2);
            System.out.println("\nClassification using true MB:");
            this.classifyMbRsu(trainData, testData, trueMb2, target);
        }
    }

    /**
     * Runs Mbfs (G-square test) for {@code target}, classifies with every MB DAG in the
     * resulting pattern, then repeats after filtering the nodes through GesMbFilter.
     */
    private void trainAndTest(DataSet trainData, DataSet testData, DiscreteVariable target, double alpha, int depth) {
        System.out.println();
        System.out.println("** RUNNING MBF alpha = " + alpha + " " + "depth = " + depth);
        IndTestGSquare test = new IndTestGSquare(trainData, alpha);
        Mbfs search = new Mbfs(test, depth);
        Graph untrimmedGraph = search.search(target.getName());
        List<Node> untrimmedNodes = untrimmedGraph.getNodes();
        System.out.println(untrimmedNodes.size() + " GES nodes: " + untrimmedNodes);
        System.out.println("=================MBF RESULT=================");
        this.classifyEachDag(untrimmedGraph, search, trainData, testData, target);
        GesMbFilter gesFilter = new GesMbFilter(trainData);
        Graph pattern = gesFilter.filter(untrimmedNodes, target);
        System.out.println("\n=================MBF+GES RESULT=================");
        this.classifyEachDag(pattern, search, trainData, testData, target);
    }

    /** Generates the MB DAGs for {@code pattern} and classifies the test data with each one. */
    private void classifyEachDag(Graph pattern, Mbfs search, DataSet trainData, DataSet testData, DiscreteVariable target) {
        System.out.println("\nPattern = " + pattern);
        List<Graph> dags = MbUtils.generateMbDags(pattern, true, search.getTest(), search.getDepth(), search.getTarget());
        for (int j = 0; j < dags.size(); ++j) {
            Dag estimatedMb = new Dag(dags.get(j));
            System.out.println("\nMBD # " + (j + 1) + " in pattern: " + estimatedMb);
            this.classifyMbRsu(trainData, testData, estimatedMb, target);
        }
    }

    /**
     * Classifies each test case by the most probable target category.
     * The Bayes IM is ML-estimated on {@code estimatedMb}, and the updater is a
     * RowSummingExactUpdater. Every non-target blanket variable of the case is entered
     * as evidence. The result is printed as a cross-tabulation.
     */
    private void classifyMbRsu(DataSet trainData, DataSet testData, Dag estimatedMb, DiscreteVariable target) {
        List<Node> mbNodes = estimatedMb.getNodes();
        DataSet trainDataSubset = trainData.subsetColumns(mbNodes);
        BayesPm mbBayesPm = this.createMbBayesPm(estimatedMb, trainDataSubset);
        BayesIm mbBayesIm = new MlBayesEstimator().estimate(mbBayesPm, trainDataSubset);
        RowSummingExactUpdater updater = new RowSummingExactUpdater(mbBayesIm);
        DataSet testDataSubset = testData.subsetColumns(mbNodes);
        int ncases = testDataSubset.getNumRows();
        int[] estimatedCategories = new int[ncases];
        List<Node> varsClassify = testDataSubset.getVariables();
        double[] marginals = new double[target.getNumCategories()];
        for (int k = 0; k < ncases; ++k) {
            Evidence evidence = Evidence.tautology(mbBayesIm);
            Proposition proposition = evidence.getProposition();
            for (int col = 0; col < varsClassify.size(); ++col) {
                DiscreteVariable testVar = (DiscreteVariable) varsClassify.get(col);
                if (testVar.equals(target)) {
                    continue;
                }
                int iother = proposition.getNodeIndex(testVar.getName());
                proposition.setCategory(iother, testDataSubset.getInt(k, col));
            }
            updater.setEvidence(evidence);
            int indexTargetBN = proposition.getNodeIndex(target.getName());
            for (int category = 0; category < marginals.length; ++category) {
                marginals[category] = updater.getMarginal(indexTargetBN, category);
            }
            estimatedCategories[k] = mostProbableCategory(marginals);
        }
        this.reportClassification(testDataSubset, target, estimatedCategories);
    }

    /**
     * Like {@link #classifyMbRsu}, but no test-case values are entered as evidence.
     * Each case is scored against the unconditioned (tautology) evidence, so presumably
     * every case gets the same prediction. This acts as a baseline.
     */
    private void classifyMbRsu2(DataSet trainData, DataSet testData, Dag estimatedMb, DiscreteVariable target) {
        List<Node> mbNodes = estimatedMb.getNodes();
        DataSet trainDataSubset = trainData.subsetColumns(mbNodes);
        BayesPm mbBayesPm = this.createMbBayesPm(estimatedMb, trainDataSubset);
        BayesIm mbBayesIm = new MlBayesEstimator().estimate(mbBayesPm, trainDataSubset);
        RowSummingExactUpdater updater = new RowSummingExactUpdater(mbBayesIm);
        DataSet testDataSubset = testData.subsetColumns(mbNodes);
        int ncases = testDataSubset.getNumRows();
        int[] estimatedCategories = new int[ncases];
        double[] marginals = new double[target.getNumCategories()];
        for (int k = 0; k < ncases; ++k) {
            Evidence evidence = Evidence.tautology(mbBayesIm);
            Proposition proposition = evidence.getProposition();
            updater.setEvidence(evidence);
            int indexTargetBN = proposition.getNodeIndex(target.getName());
            for (int category = 0; category < marginals.length; ++category) {
                marginals[category] = updater.getMarginal(indexTargetBN, category);
            }
            estimatedCategories[k] = mostProbableCategory(marginals);
        }
        this.reportClassification(testDataSubset, target, estimatedCategories);
    }

    /**
     * Classifies each test case by the most probable target category.
     * Marginals come from an OnTheFlyMarginalCalculator built directly from the PM and
     * training subset. Every non-target blanket variable of the case is entered as
     * evidence. The result is printed as a cross-tabulation.
     */
    private void classifyMbOfm(DataSet trainData, DataSet testData, Dag estimatedMb, DiscreteVariable target) {
        List<Node> mbNodes = estimatedMb.getNodes();
        DataSet trainDataSubset = trainData.subsetColumns(mbNodes);
        BayesPm mbBayesPm = this.createMbBayesPm(estimatedMb, trainDataSubset);
        OnTheFlyMarginalCalculator bayesUpdater = new OnTheFlyMarginalCalculator(mbBayesPm, trainDataSubset);
        DataSet testDataSubset = testData.subsetColumns(mbNodes);
        int ncases = testDataSubset.getNumRows();
        int[] estimatedCategories = new int[ncases];
        List<Node> varsClassify = testDataSubset.getVariables();
        double[] marginals = new double[target.getNumCategories()];
        for (int k = 0; k < ncases; ++k) {
            Proposition proposition = Proposition.tautology(bayesUpdater);
            for (int col = 0; col < varsClassify.size(); ++col) {
                DiscreteVariable testVar = (DiscreteVariable) varsClassify.get(col);
                if (testVar.equals(target)) {
                    continue;
                }
                int iother = proposition.getNodeIndex(testVar.getName());
                proposition.setCategory(iother, testDataSubset.getInt(k, col));
            }
            bayesUpdater.setEvidence(new Evidence(proposition));
            int indexTargetBN = proposition.getNodeIndex(target.getName());
            for (int category = 0; category < marginals.length; ++category) {
                marginals[category] = bayesUpdater.getMarginal(indexTargetBN, category);
            }
            estimatedCategories[k] = mostProbableCategory(marginals);
        }
        this.reportClassification(testDataSubset, target, estimatedCategories);
    }

    /**
     * Like {@link #classifyMbOfm}, but no test-case values are entered as evidence.
     * Each case is scored against the unconditioned (tautology) proposition. This acts
     * as a baseline.
     */
    private void classifyMbOfm2(DataSet trainData, DataSet testData, Dag estimatedMb, DiscreteVariable target) {
        List<Node> mbNodes = estimatedMb.getNodes();
        DataSet trainDataSubset = trainData.subsetColumns(mbNodes);
        BayesPm mbBayesPm = this.createMbBayesPm(estimatedMb, trainDataSubset);
        OnTheFlyMarginalCalculator bayesUpdater = new OnTheFlyMarginalCalculator(mbBayesPm, trainDataSubset);
        DataSet testDataSubset = testData.subsetColumns(mbNodes);
        int ncases = testDataSubset.getNumRows();
        int[] estimatedCategories = new int[ncases];
        double[] marginals = new double[target.getNumCategories()];
        for (int k = 0; k < ncases; ++k) {
            Proposition proposition = Proposition.tautology(bayesUpdater);
            bayesUpdater.setEvidence(new Evidence(proposition));
            int indexTargetBN = proposition.getNodeIndex(target.getName());
            for (int category = 0; category < marginals.length; ++category) {
                marginals[category] = bayesUpdater.getMarginal(indexTargetBN, category);
            }
            estimatedCategories[k] = mostProbableCategory(marginals);
        }
        this.reportClassification(testDataSubset, target, estimatedCategories);
    }

    /**
     * Builds a Bayes PM over {@code estimatedMb} whose category counts match the
     * discrete variables of {@code trainDataSubset}.
     * Variables are matched to DAG nodes by name rather than by position, so the
     * pairing does not depend on the column order chosen by {@code subsetColumns}.
     */
    private BayesPm createMbBayesPm(Dag estimatedMb, DataSet trainDataSubset) {
        BayesPm mbBayesPm = new BayesPm(estimatedMb);
        for (Node variable : trainDataSubset.getVariables()) {
            DiscreteVariable dv = (DiscreteVariable) variable;
            mbBayesPm.setNumCategories(estimatedMb.getNode(dv.getName()), dv.getNumCategories());
        }
        return mbBayesPm;
    }

    /**
     * Returns the index of the largest marginal, or -1 if none exceeds -0.1
     * (e.g. all NaN). Ties keep the lowest index.
     */
    private static int mostProbableCategory(double[] marginals) {
        double highestProb = -0.1;
        int best = -1;
        for (int category = 0; category < marginals.length; ++category) {
            if (marginals[category] > highestProb) {
                highestProb = marginals[category];
                best = category;
            }
        }
        return best;
    }

    /**
     * Tallies estimated vs. observed target categories. It prints the cross-tabulation
     * (rows observed, columns estimated) and the number and percent correct.
     * Cases with no estimate (-1), or whose observed value is not a valid category
     * index, are not counted.
     */
    private void reportClassification(DataSet testDataSubset, DiscreteVariable target, int[] estimatedCategories) {
        NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
        int targetIndex = testDataSubset.getVariables().indexOf(target);
        int numCategories = target.getNumCategories();
        int[][] crossTabs = new int[numCategories][numCategories];
        int numberCorrect = 0;
        int numberCounted = 0;
        for (int k = 0; k < estimatedCategories.length; ++k) {
            int estimatedCategory = estimatedCategories[k];
            if (estimatedCategory < 0) {
                continue;
            }
            int observedValue = testDataSubset.getInt(k, targetIndex);
            if (observedValue < 0 || observedValue >= numCategories) {
                // Missing or out-of-range observation; it would otherwise index out of bounds.
                continue;
            }
            crossTabs[observedValue][estimatedCategory]++;
            ++numberCounted;
            if (observedValue == estimatedCategory) {
                ++numberCorrect;
            }
        }
        System.out.println();
        System.out.println("\t\t\tEstimated\t");
        System.out.print("Observed\t");
        for (int m = 0; m < numCategories; ++m) {
            System.out.print(target.getCategory(m) + "\t");
        }
        System.out.println();
        for (int k = 0; k < numCategories; ++k) {
            System.out.print(target.getCategory(k) + "\t");
            for (int m = 0; m < numCategories; ++m) {
                System.out.print(crossTabs[k][m] + "\t\t");
            }
            System.out.println();
        }
        System.out.println();
        System.out.println("Number correct = " + numberCorrect);
        System.out.println("Number counted = " + numberCounted);
        if (numberCounted == 0) {
            // Avoid formatting 0/0 = NaN.
            System.out.println("Percent correct = undefined (no cases counted)");
        } else {
            double percentCorrect = 100.0 * numberCorrect / numberCounted;
            System.out.println("Percent correct = " + nf.format(percentCorrect) + "%");
        }
    }

    /**
     * Rebuilds {@code oldDag} with the same-named variables from {@code dataSet} as
     * its nodes, copying every edge and its endpoints.
     */
    private Dag useVariablesFromData(Dag oldDag, DataSet dataSet) {
        Dag newDag = new Dag();
        List<Node> oldNodes = oldDag.getNodes();
        List<Node> newNodes = new ArrayList<Node>();
        for (Node oldNode : oldNodes) {
            Node variable = dataSet.getVariable(oldNode.getName());
            newNodes.add(variable);
            newDag.addNode(variable);
        }
        for (Edge edge : oldDag.getEdges()) {
            Node node1 = newNodes.get(oldNodes.indexOf(edge.getNode1()));
            Node node2 = newNodes.get(oldNodes.indexOf(edge.getNode2()));
            Endpoint endpoint1 = edge.getEndpoint1();
            Endpoint endpoint2 = edge.getEndpoint2();
            newDag.addEdge(new Edge(node1, node2, endpoint1, endpoint2));
        }
        return newDag;
    }

    /** Asserts that each row of {@code expected} equals the same row of {@code actual}. */
    private static void assertRowsEqual(int[][] expected, int[][] actual) {
        for (int row = 0; row < expected.length; ++row) {
            assertEquals("Cross-tab row " + row, Arrays.toString(expected[row]), Arrays.toString(actual[row]));
        }
    }

    public static Test suite() {
        return new TestSuite(TestMbClassify.class);
    }
}

