/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.mb;

import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphConverter;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.CpcMb;
import edu.cmu.tetrad.search.Ges;
import edu.cmu.tetrad.search.GrowShrink;
import edu.cmu.tetrad.search.HitonVariant;
import edu.cmu.tetrad.search.IndTestDSep;
import edu.cmu.tetrad.search.IndTestFisherZ;
import edu.cmu.tetrad.search.IndTestGSquare;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.MbSearch;
import edu.cmu.tetrad.search.MbUtils;
import edu.cmu.tetrad.search.Mbfs;
import edu.cmu.tetrad.search.Pcmb;
import edu.cmu.tetrad.search.mb.Mmmb;
import edu.cmu.tetrad.search.mb.VanderbiltHitonMb;
import edu.cmu.tetrad.search.mb.VanderbiltIamb;
import edu.cmu.tetrad.search.mb.VanderbiltIambnPc;
import edu.cmu.tetrad.search.mb.VanderbiltInterIamb;
import edu.cmu.tetrad.search.mb.VanderbiltInterIambnPc;
import edu.cmu.tetrad.sem.LargeSemSimulator;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.RandomUtil;
import java.io.File;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;

public class TestMarkovBlanketSearches
extends TestCase {
    // Shared graph fixtures; neither is read or written anywhere else in this class.
    static Graph testGraphSub;
    static Graph testGraphSubCorrect;
    // Number format obtained from NumberFormatUtil; not used elsewhere in this class.
    NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
    // Right-aligned integer format used for the per-target count columns in printNodeStats.
    NumberFormat nf2 = new DecimalFormat("     0");
    // Number format used for the averaged columns of the summary table written by testLoop.
    NumberFormat nf3 = NumberFormatUtil.getInstance().getNumberFormat();

    /**
     * Creates a test case with the given name, which is passed straight to the
     * {@link TestCase} constructor.
     *
     * @param name the test case name.
     */
    public TestMarkovBlanketSearches(String name) {
        super(name);
    }

    /**
     * Command-line entry point: builds a test instance and runs the
     * {@link #overnight()} benchmark.
     *
     * @param args ignored.
     */
    public static void main(String[] args) {
        TestMarkovBlanketSearches benchmark = new TestMarkovBlanketSearches("name");
        benchmark.overnight();
    }

    /**
     * Prints a small four-node graph and the Markov blanket of {@code T} that
     * Grow-Shrink finds for it using a d-separation test on the graph itself.
     * Nothing is asserted; the output is for visual inspection.
     */
    public static void testSubgraph1() {
        Graph trueGraph = GraphConverter.convert("T-->X,X-->Y,W-->X,W-->Y");
        System.out.println(trueGraph);
        GrowShrink growShrink = new GrowShrink(new IndTestDSep(trueGraph));
        List<Node> markovBlanket = growShrink.findMb("T");
        System.out.println(markovBlanket);
    }

    /**
     * Prints a larger hand-written graph (parents, children and parents of
     * children around {@code T}, plus two isolated nodes) and the Markov blanket
     * of {@code T} that Grow-Shrink finds using a d-separation test.
     * Nothing is asserted; the output is for visual inspection.
     */
    public static void testSubgraph2() {
        Graph trueGraph = GraphConverter.convert("P1-->T,P2-->T,T-->C1,T-->C2,T-->C3,PC1a-->C1,PC1b-->C1,PC2a-->C2,PC2b<--C2,PC3a-->C3,PC3b-->C3,PC1b-->PC2a,PC1a<--PC3b,U,V");
        System.out.println("True graph is: " + trueGraph);
        GrowShrink growShrink = new GrowShrink(new IndTestDSep(trueGraph));
        List<Node> markovBlanket = growShrink.findMb("T");
        System.out.println(markovBlanket);
    }

    /**
     * Generates a random 10-node DAG and, for every node, prints the true
     * Markov blanket (from {@code GraphUtils.markovBlanketDag}) followed by the
     * blanket found by MBFS with a d-separation test. Both lists are sorted by
     * node name so they can be compared by eye; nothing is asserted.
     */
    public static void testRandom() {
        Dag dag = GraphUtils.randomDag(10, 0, 10, 5, 5, 5, false);
        IndTestDSep test = new IndTestDSep(dag);
        Mbfs search = new Mbfs(test, -1);
        System.out.println("INDEPENDENT GRAPH: " + dag);

        // One name-based ordering shared by both sorts below.
        Comparator<Node> byName = new Comparator<Node>(){

            @Override
            public int compare(Node left, Node right) {
                return left.getName().compareTo(right.getName());
            }
        };

        for (Node target : dag.getNodes()) {
            List<Node> found = search.findMb(target.getName());
            List<Node> expected = GraphUtils.markovBlanketDag(target, dag).getNodes();
            // The true MB DAG contains the target itself; drop it before comparing.
            expected.remove(target);
            Collections.sort(expected, byName);
            Collections.sort(found, byName);
            System.out.println();
            System.out.println(expected);
            System.out.println(found);
        }
    }

    /**
     * Long-running benchmark that writes its results to {@code overnight.txt}.
     * <p>
     * Runs {@link #testLoop} with HITON-MB, MMMB and MBFS on continuous data at
     * 500 and 1000 variables, then on discrete data at 100, 500 and 1000
     * variables. Any exception is printed and ends the run; the output file is
     * closed in all cases.
     */
    public void overnight() {
        PrintWriter out = null;
        try {
            File file = new File("overnight.txt");
            System.out.println(file);
            out = new PrintWriter(file);
            SimulationParams params = new SimulationParams();
            params.setSampleSize(1000);
            params.setDiscrete(false);
            params.setRandomGraphEveryTime(true);
            params.setTimeLimit(600000L);
            params.setDepth(3);
            params.setNumTests(30);
            params.setMinMbSize(8);
            params.setAlpha(0.05);
            params.setMinNumCategories(2);
            params.setMaxNumCategories(4);
            params.setEdgeMultipler(1.2);
            // Only these three algorithms are benchmarked. An earlier call that set the
            // full ten-algorithm list was overwritten before use and has been dropped.
            params.setAlgNames(Arrays.asList("HITON-MB", "MMMB", "MBFS"));

            // Continuous data.
            params.setNumVars(500);
            this.testLoop(out, params);
            params.setNumVars(1000);
            this.testLoop(out, params);

            // Discrete data.
            params.setDiscrete(true);
            params.setRandomGraphEveryTime(true);
            params.setNumVars(100);
            this.testLoop(out, params);
            params.setNumVars(500);
            this.testLoop(out, params);
            params.setNumVars(1000);
            this.testLoop(out, params);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        finally {
            // Close the writer even when a test loop fails part-way through.
            if (out != null) {
                out.close();
            }
        }
    }

    /**
     * Benchmark on a single 5000-variable random graph, first with continuous
     * and then with discrete data, using HITON-MB, MMMB and MBFS.
     * <p>
     * Results are written to {@code overnight.txt}, the same file name that
     * {@link #overnight()} uses. Any exception is printed and ends the run; the
     * output file is closed in all cases.
     */
    public void large() {
        PrintWriter out = null;
        try {
            File file = new File("overnight.txt");
            System.out.println(file);
            out = new PrintWriter(file);
            SimulationParams params = new SimulationParams();
            params.setSampleSize(1000);
            params.setRandomGraphEveryTime(false);
            params.setTimeLimit(450000L);
            params.setDepth(2);
            params.setNumTests(30);
            params.setMinMbSize(8);
            params.setAlpha(0.01);
            params.setMinNumCategories(2);
            params.setMaxNumCategories(4);
            params.setAlgNames(Arrays.asList("HITON-MB", "MMMB", "MBFS"));
            params.setEdgeMultipler(1.2);
            params.setNumVars(5000);
            params.setDiscrete(false);
            this.testLoop(out, params);
            params.setDiscrete(true);
            this.testLoop(out, params);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        finally {
            // Close the writer even when a test loop fails part-way through.
            if (out != null) {
                out.close();
            }
        }
    }

    /**
     * Single benchmark run on continuous data with 1000 variables, using
     * HITON-MB, MMMB and MBFS. Results are written to {@code tryout.txt}.
     * <p>
     * The edge multiplier is not set here, so the {@link SimulationParams}
     * default applies. Any exception is printed and ends the run; the output
     * file is closed in all cases.
     */
    public void tryout() {
        PrintWriter out = null;
        try {
            File file = new File("tryout.txt");
            System.out.println(file);
            out = new PrintWriter(file);
            SimulationParams params = new SimulationParams();
            params.setSampleSize(1000);
            params.setNumVars(1000);
            params.setDiscrete(false);
            params.setRandomGraphEveryTime(true);
            params.setTimeLimit(600000L);
            params.setDepth(3);
            params.setNumTests(30);
            params.setMinMbSize(6);
            params.setAlpha(0.05);
            params.setMinNumCategories(2);
            params.setMaxNumCategories(4);
            params.setAlgNames(Arrays.asList("HITON-MB", "MMMB", "MBFS"));
            this.testLoop(out, params);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        finally {
            // Close the writer even when the test loop fails part-way through.
            if (out != null) {
                out.close();
            }
        }
    }

    /**
     * Runs one benchmark configuration and writes per-target results followed
     * by a LaTeX summary table.
     * <p>
     * For each of {@code params.getNumTests()} rounds a target node is picked
     * at random whose true Markov blanket has at least
     * {@code params.getMinMbSize()} members besides the target. Every
     * configured algorithm is then run on that target via
     * {@link #printNodeStats}. A new random graph and data set are generated
     * on the first round, on every round when
     * {@code params.isRandomGraphEveryTime()} is set, and whenever no suitable
     * unused target could be found (that round is then retried).
     *
     * @param out    writer that receives the report; every reported line is
     *               also echoed to standard output.
     * @param params the simulation settings.
     */
    public void testLoop(PrintWriter out, SimulationParams params) {
        List<String> algNames = new LinkedList<String>(params.getAlgNames());
        int numEdges = (int)(params.getEdgeMultipler() * (double)params.getNumVars());
        int minNumCategories = params.getMinNumCategories();
        int maxNumCategories = params.getMaxNumCategories();
        this.printRunHeader(out, params, numEdges);
        List<MbSearch> algorithms = new LinkedList<MbSearch>();
        List<Stats> collectedStats = new ArrayList<Stats>();
        Dag randomGraph = null;
        DataSet dataSet = null;
        // Nodes that already took part in a tested Markov blanket on the current graph.
        HashSet<Node> usedMbNodes = new HashSet<Node>();
        boolean createRandomGraph = true;
        IndependenceTest test = null;
        for (int n = 0; n < params.getNumTests(); ++n) {
            System.gc();
            if (params.isRandomGraphEveryTime() || createRandomGraph) {
                randomGraph = GraphUtils.randomDag(params.getNumVars(), 0, numEdges, 9, 3, 9, false);
                createRandomGraph = false;
                if (params.isDiscrete()) {
                    dataSet = this.simulateDiscrete(randomGraph, dataSet, params.getSampleSize(), minNumCategories, maxNumCategories);
                    test = new IndTestGSquare(dataSet, params.getAlpha());
                } else {
                    dataSet = this.simulateContinuous(randomGraph, params.getSampleSize(), dataSet);
                    test = new IndTestFisherZ(dataSet, params.getAlpha());
                }
                // The algorithms are bound to the independence test, so rebuild them with it.
                algorithms.clear();
                for (String algName : algNames) {
                    algorithms.add(this.getAlgorithm(algName, test, params.getDepth(), dataSet));
                }
            }

            // Pick a random target with a large enough true Markov blanket. Only draws
            // that hit an already rejected index or an already used node count as tries.
            HashSet<Integer> visited = new HashSet<Integer>();
            int tried = 0;
            Graph trueMbDag = null;
            Node t = null;
            int i = -1;
            while (tried <= 30) {
                i = RandomUtil.getInstance().nextInt(params.getNumVars());
                if (visited.contains(i)) {
                    ++tried;
                    continue;
                }
                t = randomGraph.getNodes().get(i);
                if (usedMbNodes.contains(t)) {
                    ++tried;
                    continue;
                }
                trueMbDag = GraphUtils.markovBlanketDag(t, randomGraph);
                // The MB DAG includes the target itself, hence the + 1.
                if (trueMbDag.getNumNodes() >= params.getMinMbSize() + 1) break;
                trueMbDag = null;
                t = null;
                visited.add(i);
            }
            if (t == null || trueMbDag == null) {
                // No usable target left on this graph: start over with fresh data and
                // repeat this round.
                this.println(out, "new data");
                createRandomGraph = true;
                usedMbNodes.clear();
                --n;
                continue;
            }
            List<Node> nodes2 = trueMbDag.getNodes();
            usedMbNodes.addAll(nodes2);
            nodes2.remove(t);
            List<String> truth = this.extractVarNames(nodes2, t);
            this.println(out, "n = " + (n + 1));
            // Iterate over a typed snapshot of the algorithm list.
            for (MbSearch algorithm : new LinkedList<MbSearch>(algorithms)) {
                Stats stats = this.printNodeStats(algorithm, t, truth, i, out, params.getTimeLimit());
                if (stats == null) continue;
                collectedStats.add(stats);
            }
            this.println(out, "");
        }
        this.printSummaryTable(out, params.getNumVars(), algorithms, collectedStats);
    }

    /** Writes the parameter summary and the column header that precede the per-target rows. */
    private void printRunHeader(PrintWriter out, SimulationParams params, int numEdges) {
        this.println(out, "Alpha = " + params.getAlpha());
        this.println(out, "# variables = " + params.getNumVars());
        this.println(out, "# edges = " + numEdges);
        this.println(out, "# samples = " + params.getSampleSize());
        this.println(out, "Depth = " + params.getDepth());
        this.println(out, params.isDiscrete() ? "Discrete" : "Continuous");
        if (params.isDiscrete()) {
            this.println(out, params.getMinNumCategories() + " to " + params.getMaxNumCategories() + " categories.");
        }
        // These two blank lines go to the file only; they are not echoed to stdout.
        out.println();
        out.println();
        this.println(out, "\t FP\t FN\t Err\t Corr\t Truth\t Time");
    }

    /** Writes the LaTeX table of per-algorithm averages over all collected statistics. */
    private void printSummaryTable(PrintWriter out, int numVars, List<MbSearch> algorithms, List<Stats> collectedStats) {
        this.println(out, "\\begin{tabular}{llllllll}");
        this.println(out, "\\hline");
        this.println(out, "#vars&Algorithm&FP&FN&Err&Corr&Truth&Time\\\\");
        this.println(out, "\\hline");
        this.println(out, "\tFP\tFN\tErr\tCorr\tTruth\tTime");
        for (MbSearch algorithm : algorithms) {
            int fpSum = 0;
            int fnSum = 0;
            int errorsSum = 0;
            int truthSum = 0;
            long timeSum = 0L;
            int n = 0;
            // Statistics are matched to algorithms by name, not by instance, because
            // the instances are rebuilt whenever new data is generated.
            for (Stats stats : collectedStats) {
                if (!stats.getAlgorithm().getAlgorithmName().equals(algorithm.getAlgorithmName())) continue;
                fpSum += stats.getFp();
                fnSum += stats.getFn();
                errorsSum += stats.getErrors();
                truthSum += stats.getTruth();
                timeSum += stats.getTime();
                ++n;
            }
            // If no run of this algorithm produced statistics, n is 0 and the averages
            // below are NaN.
            double fpAvg = (double)fpSum / (double)n;
            double fnAvg = (double)fnSum / (double)n;
            double errorsAve = (double)errorsSum / (double)n;
            double truthAvg = (double)truthSum / (double)n;
            double timeAvg = (double)timeSum / (double)n;
            this.println(out, numVars + "&" + algorithm.getAlgorithmName() + "&" + this.nf3.format(fpAvg) + "&" + this.nf3.format(fnAvg) + "&" + this.nf3.format(errorsAve) + "&" + this.nf3.format(truthAvg - fnAvg) + "&" + this.nf3.format(truthAvg) + "&" + this.nf3.format(timeAvg) + "\\\\");
        }
        this.println(out, "\\hline");
        this.println(out, "\\end{tabular}");
    }

    /**
     * Writes a line to the report file, flushes it, and echoes the same line
     * to standard output.
     *
     * @param out the report writer.
     * @param x   the line to write.
     */
    private void println(PrintWriter out, String x) {
        out.println(x);
        // Flush immediately so the file is up to date after every reported line.
        out.flush();
        System.out.println(x);
    }

    /**
     * Creates the Markov blanket search registered under the given name.
     *
     * @param name    one of the algorithm names recognised below, e.g.
     *                {@code "PCMB"}, {@code "GS"}, {@code "HITON-MB"}, {@code "MBFS"}.
     * @param test    the independence test handed to the algorithm.
     * @param depth   the depth handed to those algorithms whose constructor takes one.
     * @param dataSet not used by this method.
     * @return a new search instance.
     * @throws IllegalStateException if the name is not recognised (including {@code null}).
     */
    private MbSearch getAlgorithm(String name, IndependenceTest test, int depth, DataSet dataSet) {
        MbSearch search;
        if ("PCMB".equals(name)) {
            search = new Pcmb(test, depth);
        } else if ("CPCMB".equals(name)) {
            search = new CpcMb(test, depth);
        } else if ("GS".equals(name)) {
            search = new GrowShrink(test);
        } else if ("IAMB".equals(name)) {
            search = new VanderbiltIamb(test);
        } else if ("IAMBnPC".equals(name)) {
            search = new VanderbiltIambnPc(test);
        } else if ("InterIAMB".equals(name)) {
            search = new VanderbiltInterIamb(test);
        } else if ("InterIAMBnPC".equals(name)) {
            search = new VanderbiltInterIambnPc(test);
        } else if ("HITON-VARIANT".equals(name)) {
            search = new HitonVariant(test, depth);
        } else if ("HITON-MB".equals(name)) {
            search = new VanderbiltHitonMb(test, depth, false);
        } else if ("HITON-MB-SYM".equals(name)) {
            search = new VanderbiltHitonMb(test, depth, true);
        } else if ("MMMB".equals(name)) {
            search = new Mmmb(test, depth, false);
        } else if ("MMMB-SYM".equals(name)) {
            search = new Mmmb(test, depth, true);
        } else if ("MBFS".equals(name)) {
            search = new Mbfs(test, depth);
        } else {
            throw new IllegalStateException("Unrecognized algorithm name: " + name);
        }
        return search;
    }

    /**
     * Runs one algorithm on one target in a worker thread, subject to a time
     * limit, then prints and returns the node-level error counts against the
     * true Markov blanket.
     * <p>
     * The reported time is the time the worker spent inside
     * {@code findMb}, as measured by the worker itself.
     *
     * @param algorithm the search to run.
     * @param t         the target node.
     * @param _truth    names of the true Markov blanket members (target excluded).
     * @param i         index of the target, printed as the row label.
     * @param out       the report writer.
     * @param timeLimit maximum run time in milliseconds, or {@code -1} for no limit.
     * @return the statistics for this run, or {@code null} if the run exceeded
     *         the time limit, was interrupted, or ended without a result.
     */
    private Stats printNodeStats(MbSearch algorithm, Node t, List<String> _truth, int i, PrintWriter out, long timeLimit) {
        class MyThread
        extends Thread {
            private final MbSearch algorithm;
            private final Node t;
            // Written by the worker and read by the calling thread, hence volatile.
            private volatile List<Node> nodes;
            private volatile long startTime = System.currentTimeMillis();
            private volatile long endTime;

            public MyThread(MbSearch algorithm, Node t) {
                this.algorithm = algorithm;
                this.t = t;
            }

            @Override
            public void run() {
                this.startTime = System.currentTimeMillis();
                this.nodes = this.algorithm.findMb(this.t.getName());
                this.endTime = System.currentTimeMillis();
            }

            public List<Node> getNodes() {
                return this.nodes;
            }
        }
        MyThread thread = new MyThread(algorithm, t);
        thread.start();
        // Wait on the thread itself rather than on a flag set by run(): this also ends
        // the wait when the worker dies with an exception, and join() returns as soon
        // as the worker finishes instead of after a fixed sleep.
        while (thread.isAlive()) {
            long diff = System.currentTimeMillis() - thread.startTime;
            if (timeLimit != -1L && diff > timeLimit) {
                System.out.println("Took too long: " + algorithm.getAlgorithmName());
                // Thread.stop() is deprecated; it is kept because MbSearch offers no
                // cancellation hook that is visible from this class.
                thread.stop();
                return null;
            }
            try {
                thread.join(100L);
            }
            catch (InterruptedException e) {
                // Restore the interrupt status and abandon this run, as for a timeout.
                Thread.currentThread().interrupt();
                System.out.println("Interrupted while waiting for: " + algorithm.getAlgorithmName());
                thread.stop();
                return null;
            }
        }
        List<Node> nodes = thread.getNodes();
        if (nodes == null) {
            // The worker ended without storing a result (e.g. findMb threw).
            System.out.println("No result from: " + algorithm.getAlgorithmName());
            return null;
        }
        long elapsedTime = thread.endTime - thread.startTime;
        List<String> mbf = this.extractVarNames(nodes, t);
        // Names both found and true.
        ArrayList<String> mbfAndTruth = new ArrayList<String>(mbf);
        mbfAndTruth.retainAll(_truth);
        // False positives: found but not true.
        ArrayList<String> mbfFp = new ArrayList<String>(mbf);
        mbfFp.removeAll(mbfAndTruth);
        int fp = mbfFp.size();
        // False negatives: true but not found.
        ArrayList<String> mbfFn = new ArrayList<String>(_truth);
        mbfFn.removeAll(mbfAndTruth);
        int fn = mbfFn.size();
        int truth = _truth.size();
        this.println(out, i + ".\t" + this.nf2.format(fp) + "\t" + this.nf2.format(fn) + "\t" + this.nf2.format(fp + fn) + "\t" + this.nf2.format(truth - fn) + "\t" + this.nf2.format(truth) + "\t" + elapsedTime + " ms " + algorithm.getNumIndependenceTests() + "\t" + algorithm.getAlgorithmName());
        return new Stats(algorithm, fp, fn, fp + fn, truth, elapsedTime);
    }

    /**
     * Estimates the Markov blanket graph of a target and prints and returns
     * edge-level error counts against the true Markov blanket DAG.
     * <p>
     * For {@link Mbfs} the graph built by the search itself is used. For any
     * other algorithm, GES is run on the data columns of the found blanket
     * plus the target, and the result is trimmed with
     * {@code MbUtils.trimToMbNodes}. Edges are compared by the names of their
     * end nodes only; orientation is not checked.
     *
     * @param algorithm the search to run.
     * @param target    the target node.
     * @param trueMbDag the true Markov blanket DAG.
     * @param dataSet   the data the non-MBFS branch takes its columns from.
     * @param i         index of the target, printed as the row label.
     * @param nf2       format for the count columns.
     * @param out       the report writer.
     * @return the statistics for this run.
     */
    private Stats printGraphStats(MbSearch algorithm, Node target, Graph trueMbDag, DataSet dataSet, int i, NumberFormat nf2, PrintWriter out) {
        Graph estimatedMbDag;
        long time = System.currentTimeMillis();
        if (algorithm instanceof Mbfs) {
            algorithm.findMb(target.getName());
            estimatedMbDag = ((Mbfs)algorithm).getGraph();
        } else {
            // Copy the result so that adding the target does not modify a list owned
            // by the algorithm.
            List<Node> nodes = new ArrayList<Node>(algorithm.findMb(target.getName()));
            nodes.add(target);
            // Map the found nodes to the data set's own variables by name.
            ArrayList<Node> _nodes = new ArrayList<Node>();
            for (Node node : nodes) {
                _nodes.add(dataSet.getVariable(node.getName()));
            }
            DataSet _dataSet = dataSet.subsetColumns(_nodes);
            Ges search = new Ges(_dataSet);
            estimatedMbDag = search.search();
            MbUtils.trimToMbNodes(estimatedMbDag, estimatedMbDag.getNode(target.getName()), false);
        }
        long elapsedTime = System.currentTimeMillis() - time;
        int truth = trueMbDag.getNumEdges();
        int fp = this.countEdgesAbsentFrom(estimatedMbDag, trueMbDag);
        int fn = this.countEdgesAbsentFrom(trueMbDag, estimatedMbDag);
        this.println(out, i + ".\t" + nf2.format(fp) + "\t" + nf2.format(fn) + "\t" + nf2.format(fp + fn) + "\t" + nf2.format(truth - fn) + "\t" + nf2.format(truth) + "\t" + elapsedTime + " ms " + algorithm.getAlgorithmName() + "\t");
        return new Stats(algorithm, fp, fn, fp + fn, truth, elapsedTime);
    }

    /**
     * Counts the edges of {@code source} for which {@code reference} has no edge
     * between the nodes of the same names (or lacks one of those nodes).
     */
    private int countEdgesAbsentFrom(Graph source, Graph reference) {
        int absent = 0;
        for (Edge edge : source.getEdges()) {
            Node node1 = reference.getNode(edge.getNode1().getName());
            Node node2 = reference.getNode(edge.getNode2().getName());
            if (node1 == null || node2 == null || reference.getEdge(node1, node2) == null) {
                ++absent;
            }
        }
        return absent;
    }

    /**
     * Runs a fresh MBFS search on the target and prints (to standard output
     * only) and returns edge-level error counts against the true Markov
     * blanket DAG. Edges are compared by the names of their end nodes only.
     *
     * @param target    the target node.
     * @param test      the independence test for the search.
     * @param depth     the depth for the search.
     * @param trueMbDag the true Markov blanket DAG.
     * @param i         index of the target, printed as the row label.
     * @param nf2       format for the count columns.
     * @return the statistics for this run.
     */
    private Stats printMbfsGraphStats(Node target, IndependenceTest test, int depth, Graph trueMbDag, int i, NumberFormat nf2) {
        long startMillis = System.currentTimeMillis();
        Mbfs algorithm = new Mbfs(test, depth);
        Graph estimatedMbDag = algorithm.search(target.getName());
        long elapsedTime = System.currentTimeMillis() - startMillis;

        int truth = trueMbDag.getNumEdges();

        // False positives: estimated edges with no counterpart in the true graph.
        int fp = 0;
        for (Edge estimatedEdge : estimatedMbDag.getEdges()) {
            Node trueEnd1 = trueMbDag.getNode(estimatedEdge.getNode1().getName());
            Node trueEnd2 = trueMbDag.getNode(estimatedEdge.getNode2().getName());
            boolean inTruth = trueEnd1 != null && trueEnd2 != null && trueMbDag.getEdge(trueEnd1, trueEnd2) != null;
            if (!inTruth) {
                ++fp;
            }
        }

        // False negatives: true edges with no counterpart in the estimated graph.
        int fn = 0;
        for (Edge trueEdge : trueMbDag.getEdges()) {
            Node estimatedEnd1 = estimatedMbDag.getNode(trueEdge.getNode1().getName());
            Node estimatedEnd2 = estimatedMbDag.getNode(trueEdge.getNode2().getName());
            boolean inEstimate = estimatedEnd1 != null && estimatedEnd2 != null && estimatedMbDag.getEdge(estimatedEnd1, estimatedEnd2) != null;
            if (!inEstimate) {
                ++fn;
            }
        }

        int errors = fp + fn;
        System.out.println(i + ". (M)\t" + nf2.format(fp) + "\t" + nf2.format(fn) + "\t" + nf2.format(errors) + "\t" + nf2.format(truth - fn) + "\t" + nf2.format(truth) + "\t" + elapsedTime + " ms " + algorithm.getAlgorithmName() + "\t");
        return new Stats(algorithm, fp, fn, errors, truth, elapsedTime);
    }

    /**
     * Simulates discrete data from a Bayes net built on the given graph.
     *
     * @param randomGraph      the graph defining the Bayes net structure.
     * @param dataSet          an existing data set to pass to the simulator, or
     *                         {@code null} to have a new one of {@code sampleSize}
     *                         rows created.
     * @param sampleSize       the number of rows; only used when {@code dataSet} is {@code null}.
     * @param minNumCategories lower bound passed to {@code BayesPm}.
     * @param maxNumCategories upper bound passed to {@code BayesPm}.
     * @return the simulated data set.
     */
    private DataSet simulateDiscrete(Dag randomGraph, DataSet dataSet, int sampleSize, int minNumCategories, int maxNumCategories) {
        BayesPm parametricModel = new BayesPm(randomGraph, minNumCategories, maxNumCategories);
        // NOTE(review): the literal 1 presumably selects random initialisation of the
        // instantiated model - confirm against MlBayesIm.
        MlBayesIm instantiatedModel = new MlBayesIm(parametricModel, 1);
        if (dataSet == null) {
            return instantiatedModel.simulateData(sampleSize, false);
        }
        return instantiatedModel.simulateData(dataSet, false);
    }

    /**
     * Simulates continuous data from the given graph using {@code LargeSemSimulator}.
     *
     * @param randomGraph the graph to simulate from.
     * @param sampleSize  the number of rows; only used when {@code dataSet} is {@code null}.
     * @param dataSet     an existing data set to pass to the simulator, or
     *                    {@code null} to have a new one created.
     * @return the simulated data set.
     */
    private DataSet simulateContinuous(Dag randomGraph, int sampleSize, DataSet dataSet) {
        LargeSemSimulator semSimulator = new LargeSemSimulator(randomGraph);
        if (dataSet == null) {
            return semSimulator.simulateDataAcyclic(sampleSize);
        }
        return semSimulator.simulateDataAcyclic(dataSet);
    }

    /**
     * Returns the names of the given nodes in sorted order, with one occurrence
     * of the target's name removed if present.
     *
     * @param nodes  the nodes whose names are collected.
     * @param target the node whose name is dropped from the result.
     * @return a new, sorted list of names.
     */
    private List<String> extractVarNames(List<Node> nodes, Node target) {
        ArrayList<String> names = new ArrayList<String>(nodes.size());
        for (Node member : nodes) {
            String memberName = member.getName();
            names.add(memberName);
        }
        String targetName = target.getName();
        names.remove(targetName);
        Collections.sort(names);
        return names;
    }

    /**
     * Checks MBFS against the true Markov blanket for every node of a random
     * 10-node DAG, using a d-separation test.
     * <p>
     * For each target it asserts that the result graph and the true MB DAG
     * have the same node names, that every directed result edge matches the
     * endpoints of the corresponding true edge (apart from the tolerated case
     * noted below), and that every true edge is an adjacency in the result.
     */
    public static void findExample() {
        Dag dag = GraphUtils.randomDag(10, 0, 10, 5, 5, 5, false);
        IndTestDSep test = new IndTestDSep(dag);
        Mbfs search = new Mbfs(test, -1);
        System.out.println("INDEPENDENT GRAPH: " + dag);
        for (Node target : dag.getNodes()) {
            Graph resultMb = search.search(target.getName());
            Dag trueMb = GraphUtils.markovBlanketDag(target, dag);
            List<Node> resultNodes = resultMb.getNodes();
            List<Node> trueNodes = trueMb.getNodes();

            // Both graphs must contain exactly the same node names.
            HashSet<String> resultNames = new HashSet<String>();
            for (Node resultNode : resultNodes) {
                resultNames.add(resultNode.getName());
            }
            HashSet<String> trueNames = new HashSet<String>();
            for (Node trueNode : trueNodes) {
                trueNames.add(trueNode.getName());
            }
            assertTrue(resultNames.equals(trueNames));

            // Every directed result edge must agree with the true graph.
            for (Edge resultEdge : resultMb.getEdges()) {
                if (!Edges.isDirectedEdge(resultEdge)) {
                    continue;
                }
                String name1 = resultEdge.getNode1().getName();
                String name2 = resultEdge.getNode2().getName();
                Node trueNode1 = trueMb.getNode(name1);
                Node trueNode2 = trueMb.getNode(name2);
                if (trueNode1 == null) {
                    System.err.println("Node " + name1 + " is not in the true graph.");
                    continue;
                }
                if (trueNode2 == null) {
                    System.err.println("Node " + name2 + " is not in the true graph.");
                    continue;
                }
                Edge trueEdge = trueMb.getEdge(trueNode1, trueNode2);
                if (trueEdge == null) {
                    Node resultNode1 = resultMb.getNode(trueNode1.getName());
                    Node resultNode2 = resultMb.getNode(trueNode2.getName());
                    Node resultTarget = resultMb.getNode(target.getName());
                    Edge toTarget1 = resultMb.getEdge(resultNode1, resultTarget);
                    Edge toTarget2 = resultMb.getEdge(resultNode2, resultTarget);
                    // An extra edge is tolerated when either end is not adjacent to the
                    // target in the result, or when exactly one of the two edges to the
                    // target is directed and the other undirected.
                    boolean tolerated = toTarget1 == null || toTarget2 == null
                            || Edges.isDirectedEdge(toTarget1) && Edges.isUndirectedEdge(toTarget2)
                            || Edges.isUndirectedEdge(toTarget1) && Edges.isDirectedEdge(toTarget2);
                    if (tolerated) {
                        continue;
                    }
                    // fail() throws, so the endpoint checks below never see a null trueEdge.
                    fail("EXTRA EDGE: Edge in result MB but not true MB = " + resultEdge);
                }
                assertEquals(resultEdge.getEndpoint1(), trueEdge.getEndpoint1());
                assertEquals(resultEdge.getEndpoint2(), trueEdge.getEndpoint2());
                System.out.println("Result edge = " + resultEdge + ", true edge = " + trueEdge);
            }

            // Every true edge must be an adjacency in the result.
            for (Edge trueEdge : trueMb.getEdges()) {
                Node resultNode1 = resultMb.getNode(trueEdge.getNode1().getName());
                Node resultNode2 = resultMb.getNode(trueEdge.getNode2().getName());
                assertTrue("Expected adjacency " + resultNode1 + "---" + resultNode2, resultMb.isAdjacentTo(resultNode1, resultNode2));
            }
        }
    }

    /**
     * Builds the JUnit 3 suite for this class by passing the class to the
     * {@link TestSuite} constructor.
     *
     * @return the test suite.
     */
    public static Test suite() {
        return new TestSuite(TestMarkovBlanketSearches.class);
    }

    /**
     * Mutable bag of settings for one benchmark configuration. Every setting
     * has a default (given by the field initialisers) and a plain getter/setter
     * pair; no validation is performed.
     */
    static class SimulationParams {
        /** Whether discrete rather than continuous data is simulated. */
        private boolean discrete = false;
        /** Number of variables in the random graph. */
        private int numVars = 100;
        /** Number of edges expressed as a multiple of the number of variables. */
        private double edgeMultipler = 1.0;
        /** Number of simulated samples. */
        private int sampleSize = 1000;
        /** Whether a new random graph is generated for every test round. */
        private boolean randomGraphEveryTime = true;
        /** Time limit for a single algorithm run, in milliseconds. */
        private long timeLimit = 300000L;
        /** Depth handed to the search algorithms. */
        private int depth = 3;
        /** Number of test rounds. */
        private int numTests = 30;
        /** Minimum Markov blanket size a target must have to be tested. */
        private int minMbSize = 8;
        /** Significance level of the independence test. */
        private double alpha = 0.05;
        /** Minimum number of categories for discrete variables. */
        private int minNumCategories = 2;
        /** Maximum number of categories for discrete variables. */
        private int maxNumCategories = 4;
        /** Names of the algorithms to run. */
        private List<String> algNames = new LinkedList<String>();

        SimulationParams() {
        }

        public boolean isDiscrete() {
            return discrete;
        }

        public void setDiscrete(boolean discrete) {
            this.discrete = discrete;
        }

        public int getNumVars() {
            return numVars;
        }

        public void setNumVars(int numVars) {
            this.numVars = numVars;
        }

        public double getEdgeMultipler() {
            return edgeMultipler;
        }

        public void setEdgeMultipler(double edgeMultipler) {
            this.edgeMultipler = edgeMultipler;
        }

        public int getSampleSize() {
            return sampleSize;
        }

        public void setSampleSize(int sampleSize) {
            this.sampleSize = sampleSize;
        }

        public boolean isRandomGraphEveryTime() {
            return randomGraphEveryTime;
        }

        public void setRandomGraphEveryTime(boolean randomGraphEveryTime) {
            this.randomGraphEveryTime = randomGraphEveryTime;
        }

        public long getTimeLimit() {
            return timeLimit;
        }

        public void setTimeLimit(long timeLimit) {
            this.timeLimit = timeLimit;
        }

        public int getDepth() {
            return depth;
        }

        public void setDepth(int depth) {
            this.depth = depth;
        }

        public int getNumTests() {
            return numTests;
        }

        public void setNumTests(int numTests) {
            this.numTests = numTests;
        }

        public int getMinMbSize() {
            return minMbSize;
        }

        public void setMinMbSize(int minMbSize) {
            this.minMbSize = minMbSize;
        }

        public double getAlpha() {
            return alpha;
        }

        public void setAlpha(double alpha) {
            this.alpha = alpha;
        }

        public int getMinNumCategories() {
            return minNumCategories;
        }

        public void setMinNumCategories(int minNumCategories) {
            this.minNumCategories = minNumCategories;
        }

        public int getMaxNumCategories() {
            return maxNumCategories;
        }

        public void setMaxNumCategories(int maxNumCategories) {
            this.maxNumCategories = maxNumCategories;
        }

        /** Returns the stored list itself, not a copy. */
        public List<String> getAlgNames() {
            return algNames;
        }

        /** Stores the given list itself, not a copy. */
        public void setAlgNames(List<String> algNames) {
            this.algNames = algNames;
        }
    }

    /**
     * Immutable record of one algorithm run on one target: the algorithm, its
     * false-positive, false-negative and total error counts, the size of the
     * truth, and the elapsed time.
     */
    private static class Stats {
        private final MbSearch algorithm;
        private final int fp;
        private final int fn;
        private final int errors;
        private final int truth;
        private final long time;

        /**
         * @param algorithm the algorithm that was run.
         * @param fp        number of false positives.
         * @param fn        number of false negatives.
         * @param errors    total error count as supplied by the caller.
         * @param truth     size of the truth being compared against.
         * @param time      elapsed time as supplied by the caller.
         */
        public Stats(MbSearch algorithm, int fp, int fn, int errors, int truth, long time) {
            this.algorithm = algorithm;
            this.fp = fp;
            this.fn = fn;
            this.errors = errors;
            this.truth = truth;
            this.time = time;
        }

        public MbSearch getAlgorithm() {
            return algorithm;
        }

        public int getFp() {
            return fp;
        }

        public int getFn() {
            return fn;
        }

        public int getErrors() {
            return errors;
        }

        public int getTruth() {
            return truth;
        }

        public long getTime() {
            return time;
        }
    }
}

