/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IndTestChiSquare;
import edu.cmu.tetrad.search.IndTestFisherZGeneralizedInverse;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.MbUtils;
import edu.cmu.tetrad.search.Mbfs;
import edu.cmu.tetrad.sem.LargeSemSimulator;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.RandomUtil;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;

/**
 * Generates tab-separated reports on how well {@link Mbfs} recovers the Markov blanket of
 * randomly chosen targets in a single simulated data set.
 *
 * <p>Each report row compares the estimated Markov blanket graph with the true Markov blanket DAG
 * and lists false positives (FP) and false negatives (FN) for the blanket as a whole, for the
 * target's parents (PFP/PFN), for its children (CFP/CFN) and for the other parents of its children
 * (PCFP/PCFN), followed by the search time.
 *
 * <p>The report methods are named {@code rtest...} rather than {@code test...}, so the JUnit 3
 * {@link TestSuite} built in {@link #suite()} only picks up {@link #testBlank()}; the reports are
 * meant to be run by hand, e.g. via {@link #main(String[])}.
 */
public class TestMbfReport
extends TestCase {
    /** Column header shared by the continuous and discrete reports. */
    private static final String REPORT_HEADER = "TARGET\tSIZE\tFP\tFN\tPFP\tPFN\tCFP\tCFN\tPCFP\tPCFN\tTIME";

    /** Formats alpha (in file names and headers) and elapsed times. */
    private final NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
    // Not referenced within this class; presumably kept for external callers.
    public static int CONTINUOUS = 0;
    public static int DISCRETE = 1;
    // Not referenced within this class.
    static int[][] testCrosstabs = new int[][]{{487, 49}, {81, 366}};
    static int[][] testCrosstabsNew = new int[][]{{38, 9, 7}, {10, 15, 14}, {5, 3, 59}};
    /** Report destination; opened and closed by {@link #doRun}. */
    private PrintWriter fileOut;

    public TestMbfReport(String name) {
        super(name);
    }

    /** Placeholder so the JUnit suite contains at least one runnable test. */
    public void testBlank() {
    }

    /**
     * Writes a small continuous-data report (200 variables, alpha 0.001) under {@code test_data/}.
     *
     * @throws RuntimeException wrapping the {@link IOException} if the report file cannot be written
     */
    public void rtestReportOut() {
        double alpha = 0.001;
        int dimension = 200;
        int sampleSize = 1000;
        int depth = 2;
        String variableType = "Continuous";
        int numTargets = 25;
        try {
            this.doRun(alpha, dimension, sampleSize, depth, variableType, numTargets);
        }
        catch (IOException e) {
            throw new RuntimeException("Could not write MBF report", e);
        }
    }

    /**
     * Writes a single large continuous-data report (10000 variables, alpha 0.01) under
     * {@code test_data/}.
     *
     * @throws RuntimeException wrapping the {@link IOException} if the report file cannot be written
     */
    public void rtestReportOutOvernightRun() {
        int sampleSize = 1000;
        int depth = 2;
        int numTargets = 25;
        try {
            this.doRun(0.01, 10000, sampleSize, depth, "Continuous", numTargets);
        }
        catch (IOException e) {
            throw new RuntimeException("Could not write MBF report", e);
        }
    }

    /**
     * Opens a report file whose name encodes the run parameters, writes the report into it and
     * always closes it, even if report generation fails.
     *
     * @throws IOException if the file cannot be opened or a write to it failed
     */
    private void doRun(double alpha, int dimension, int sampleSize, int depth, String variableType, int numTargets) throws IOException {
        String fileOutName = "test_data/mbfsreport_" + this.nf.format(alpha) + "_" + dimension + "_" + sampleSize + "_" + depth + "_" + variableType + "_" + numTargets + ".txt";
        this.fileOut = new PrintWriter(new FileWriter(new File(fileOutName)));
        try {
            this.generateReport(alpha, dimension, sampleSize, depth, variableType, numTargets);
            // PrintWriter never throws on write failure; surface it here instead of losing it.
            if (this.fileOut.checkError()) {
                throw new IOException("Error while writing " + fileOutName);
            }
        }
        finally {
            this.fileOut.close();
        }
    }

    /**
     * Prints the run parameters, then simulates data of the requested type and reports on it.
     *
     * @throws IllegalStateException if {@code variableType} is neither "Continuous" nor "Discrete"
     */
    private void generateReport(double alpha, int dimension, int sampleSize, int depth, String variableType, int numTargets) {
        this.printLine("MBF trials from random targets in a single simulated data set.");
        this.printLine("");
        this.printLine("Alpha = " + this.nf.format(alpha));
        this.printLine("Dimension = " + dimension);
        this.printLine("Sample size  = " + sampleSize);
        this.printLine("Depth = " + depth);
        this.printLine("Variable type = " + variableType);
        this.printLine("Num targets = " + numTargets);
        this.printLine("");
        this.printLine("");
        if ("Continuous".equals(variableType)) {
            this.examineContinuousDatabase(numTargets, dimension, sampleSize, alpha, depth);
        } else if ("Discrete".equals(variableType)) {
            this.examineDiscreteDatabase(numTargets, dimension, sampleSize, alpha, depth);
        } else {
            throw new IllegalStateException("Unknown variable type: " + variableType);
        }
    }

    /**
     * Simulates a continuous data set from a random DAG using {@link LargeSemSimulator} and
     * evaluates MBF with a generalized-inverse Fisher Z test.
     */
    private void examineContinuousDatabase(int numTargets, int dimension, int sampleSize, double alpha, int depth) {
        this.printLine(REPORT_HEADER);
        System.out.println("Creating graph.");
        // Argument meanings are those of GraphUtils.randomDag; values kept as originally chosen.
        Dag randomGraph = GraphUtils.randomDag(dimension, 0, dimension, 40, 40, 40, false);
        System.out.println("Starting simulation.");
        LargeSemSimulator simulator = new LargeSemSimulator(randomGraph);
        DataSet dataSet = simulator.simulateDataAcyclic(sampleSize);
        IndTestFisherZGeneralizedInverse test = new IndTestFisherZGeneralizedInverse(dataSet, alpha);
        System.out.println("Running MBF");
        this.examineRandomTargets(numTargets, randomGraph, test, depth);
    }

    /**
     * Simulates a discrete data set from a random DAG using a Bayes IM and evaluates MBF with a
     * chi-square test.
     */
    private void examineDiscreteDatabase(int numTargets, int dimension, int sampleSize, double alpha, int depth) {
        this.printLine(REPORT_HEADER);
        Dag randomGraph = GraphUtils.randomDag(dimension, 0, dimension, 40, 40, 40, false);
        // NOTE(review): (2, 2) presumably bounds the number of categories per variable, and 1 selects
        // the IM initialization mode -- confirm against BayesPm / MlBayesIm.
        BayesPm bayesPm = new BayesPm(randomGraph, 2, 2);
        MlBayesIm bayesIm = new MlBayesIm(bayesPm, 1);
        DataSet dataSet = bayesIm.simulateData(sampleSize, false);
        IndTestChiSquare test = new IndTestChiSquare(dataSet, alpha);
        this.examineRandomTargets(numTargets, randomGraph, test, depth);
    }

    /**
     * Runs MBF on {@code numTargets} randomly chosen targets (drawn with replacement) and prints one
     * report row per target comparing the estimate with the true Markov blanket DAG.
     */
    private void examineRandomTargets(int numTargets, Dag trueGraph, IndependenceTest test, int depth) {
        RandomUtil random = RandomUtil.getInstance();
        List<Node> trueNodes = trueGraph.getNodes();
        int numNodes = trueGraph.getNumNodes();
        for (int i = 0; i < numTargets; ++i) {
            Node tTrue = trueNodes.get(random.nextInt(numNodes));
            String targetName = tTrue.getName();
            Dag gTrue = GraphUtils.markovBlanketDag(tTrue, trueGraph);
            Mbfs mbSearch = new Mbfs(test, depth);
            Graph gEst = mbSearch.search(targetName);
            // NOTE(review): assumes getElapsedTime() is in milliseconds, so this is seconds.
            double elapsedTime = (double) mbSearch.getElapsedTime() / 1000.0;
            Node tEst = gEst.getNode(targetName);
            // Borrow orientations from the truth for edges MBF left unoriented, then trim the
            // estimate with MbUtils before scoring.
            this.orientNondirectedEdgesAs(gEst, gTrue);
            MbUtils.trimToMbNodes(gEst, tEst, false);
            MbUtils.trimEdgesAmongParents(gEst, tEst);
            MbUtils.trimEdgesAmongParentsOfChildren(gEst, tEst);
            int fp = this.getFp(gEst.getNodes(), gTrue.getNodes(), targetName);
            int fn = this.getFn(gEst.getNodes(), gTrue.getNodes(), targetName);
            int pfp = this.getFp(gEst.getParents(tEst), gTrue.getParents(tTrue), targetName);
            int pfn = this.getFn(gEst.getParents(tEst), gTrue.getParents(tTrue), targetName);
            int cfp = this.getFp(gEst.getChildren(tEst), gTrue.getChildren(tTrue), targetName);
            int cfn = this.getFn(gEst.getChildren(tEst), gTrue.getChildren(tTrue), targetName);
            List<Node> pcEst = this.parentsOfChildren(gEst, tEst);
            List<Node> pcTrue = this.parentsOfChildren(gTrue, tTrue);
            int pcfp = this.getFp(pcEst, pcTrue, targetName);
            int pcfn = this.getFn(pcEst, pcTrue, targetName);
            int mbSize = this.extractVarNames(gTrue.getNodes(), targetName).size();
            this.printLine(i + 1 + "\t" + mbSize + "\t" + fp + "\t" + fn + "\t" + pfp + "\t" + pfn + "\t" + cfp + "\t" + cfn + "\t" + pcfp + "\t" + pcfn + "\t" + this.nf.format(elapsedTime));
        }
    }

    /**
     * Collects the parents of every child of {@code target} in {@code graph}. The result includes
     * {@code target} itself (once per child) and may repeat shared parents; callers compare names
     * via {@link #extractVarNames}, which removes the target and duplicates.
     */
    private List<Node> parentsOfChildren(Graph graph, Node target) {
        List<Node> result = new ArrayList<Node>();
        for (Node child : graph.getChildren(target)) {
            result.addAll(graph.getParents(child));
        }
        return result;
    }

    /**
     * Counts distinct variable names (other than the target) present in {@code nodesEst} but absent
     * from {@code nodesTrue}.
     */
    private int getFp(List<Node> nodesEst, List<Node> nodesTrue, String targetName) {
        List<String> falsePositives = this.extractVarNames(nodesEst, targetName);
        falsePositives.removeAll(this.extractVarNames(nodesTrue, targetName));
        return falsePositives.size();
    }

    /**
     * Counts distinct variable names (other than the target) present in {@code nodesTrue} but absent
     * from {@code nodesEst}.
     */
    private int getFn(List<Node> nodesEst, List<Node> nodesTrue, String targetName) {
        List<String> falseNegatives = this.extractVarNames(nodesTrue, targetName);
        falseNegatives.removeAll(this.extractVarNames(nodesEst, targetName));
        return falseNegatives.size();
    }

    /**
     * For every undirected or bidirected edge in {@code gEst} whose endpoints (matched by name) are
     * also adjacent in {@code gTrue}, copies both endpoint marks from the true edge.
     */
    private void orientNondirectedEdgesAs(Graph gEst, Graph gTrue) {
        // Iterate over a snapshot: setEndpoint replaces edges in gEst while we loop.
        for (Edge edge : new ArrayList<Edge>(gEst.getEdges())) {
            if (!Edges.isUndirectedEdge(edge) && !Edges.isBidirectedEdge(edge)) continue;
            Node a1 = edge.getNode1();
            Node a2 = edge.getNode2();
            Node b1 = gTrue.getNode(a1.getName());
            Node b2 = gTrue.getNode(a2.getName());
            if (b1 == null || b2 == null || gTrue.getEdge(b1, b2) == null) continue;
            gEst.setEndpoint(a1, a2, gTrue.getEndpoint(b1, b2));
            gEst.setEndpoint(a2, a1, gTrue.getEndpoint(b2, b1));
        }
    }

    /**
     * Returns the sorted, duplicate-free names of {@code nodes}, excluding every occurrence of
     * {@code targetName}. The returned list is a fresh mutable list.
     */
    private List<String> extractVarNames(List<Node> nodes, String targetName) {
        List<String> varNames = new ArrayList<String>();
        for (Node node : nodes) {
            String name = node.getName();
            // Lists such as parents-of-children can repeat names and contain the target several
            // times, so skip the target and any name already seen.
            if (!name.equals(targetName) && !varNames.contains(name)) {
                varNames.add(name);
            }
        }
        Collections.sort(varNames);
        return varNames;
    }

    /** Echoes {@code s} to standard output and to the open report file, flushing immediately. */
    private void printLine(String s) {
        System.out.println(s);
        this.fileOut.println(s);
        this.fileOut.flush();
    }

    public static void main(String[] args) {
        new TestMbfReport("").rtestReportOutOvernightRun();
    }

    public static Test suite() {
        return new TestSuite(TestMbfReport.class);
    }
}

