/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodePair;
import edu.cmu.tetrad.search.FasDci;
import edu.cmu.tetrad.search.IndTestChiSquare;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.ResolveSepsets;
import edu.cmu.tetrad.search.SepsetMapDci;
import edu.cmu.tetrad.util.RandomUtil;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import junit.framework.TestCase;

/**
 * Simulation harness comparing the sepset-resolution methods of {@link ResolveSepsets}.
 *
 * <p>A random discrete Bayes net is simulated, and its variables are split into overlapping subsets.
 * Each subset is observed on its own disjoint block of rows, and {@link FasDci} is run on each block.
 * The resulting sepsets are then resolved with every method in {@link #methods}. The resolved
 * independence/dependence claims are scored against d-separation in the true DAG. The per-method
 * statistics are appended as comma-terminated values to result files under {@link #dir}.
 */
public class TestResolveSepsets extends TestCase {
    /** Output directory for the result files. NOTE(review): hard-coded absolute path; it must exist before running. */
    private final String dir = "/home/rtillman/Desktop/temp/";
    /** Per method: number of resolved claims consistent with the true DAG. */
    private final File correctFile = new File(this.dir + "correct.dat");
    /** Per method: number of resolved claims inconsistent with the true DAG. */
    private final File incorrectFile = new File(this.dir + "incorrect.dat");
    /** Per method: correct / (correct + incorrect), or 1.0 when there were no claims. */
    private final File accuracy = new File(this.dir + "resolve.dat");
    /** Per method: fraction of scored claims whose pair is d-separated in the true DAG (NaN when none). */
    private final File indFile = new File(this.dir + "independent.dat");
    /** Resolution method names passed to {@link ResolveSepsets}; also the column order of every result row. */
    private final String[] methods = new String[]{"fisher", "tippett", "worsleyfriston", "stouffer", "mudholkergeorge", "averagetest", "average", "random"};
    /** Rows simulated per variable subset, one sweep entry each. */
    private final int[] nsizes = new int[]{50, 100, 500, 1000, 2500};

    public TestResolveSepsets(String name) {
        super(name);
    }

    /**
     * Currently empty. {@link #discreteTest(int)} takes a parameter, so JUnit 3 will not run it
     * automatically; it has to be invoked explicitly.
     */
    public void testSimulation() {
    }

    /**
     * Runs one simulation over a fresh random DAG and appends one row to each result file.
     *
     * <p>For every sample size {@code n} in {@link #nsizes}, each variable subset receives its own
     * {@code n} simulated rows. One value per method is then appended to each file. After all sample
     * sizes have run, each file receives a line break, so one call produces exactly one row per file.
     *
     * @param d size argument forwarded to {@code GraphUtils.randomDag} (presumably the node count and
     *          edge bound; confirm against that API)
     * @throws RuntimeException if a result file cannot be opened or written
     */
    public void discreteTest(int d) {
        Dag graph = GraphUtils.randomDag(d, 0, d, 3, 2, 1, true);
        BayesPm pm = new BayesPm(graph, 4, 4);
        MlBayesIm im = new MlBayesIm(pm, 1);
        List<Set<Node>> subsetsNodes = TestResolveSepsets.subsetsFromDag(graph, 2);
        for (int n : this.nsizes) {
            // missingDatasets splits rows evenly across subsets, so this gives each subset n rows.
            DataSet dataset = im.simulateData(subsetsNodes.size() * n, false);
            List<DataSet> missingDatasets = TestResolveSepsets.missingDatasets(dataset, subsetsNodes);
            List<IndependenceTest> independenceTests = new ArrayList<IndependenceTest>();
            for (DataSet missingDataset : missingDatasets) {
                independenceTests.add(new IndTestChiSquare(missingDataset, 0.01));
            }
            Map<String, Integer> correct = this.zeroCounts();
            Map<String, Integer> incorrect = this.zeroCounts();
            Map<String, Integer> independent = this.zeroCounts();
            Map<String, Integer> associated = this.zeroCounts();
            this.tryMethods(graph, independenceTests, correct, incorrect, independent, associated);

            StringBuilder correctRow = new StringBuilder();
            StringBuilder incorrectRow = new StringBuilder();
            StringBuilder accuracyRow = new StringBuilder();
            StringBuilder indRow = new StringBuilder();
            for (String method : this.methods) {
                int correctnum = correct.get(method);
                int incorrectnum = incorrect.get(method);
                int indnum = independent.get(method);
                int assnum = associated.get(method);
                // With no resolved claims, accuracy is reported as perfect.
                double acc = correctnum + incorrectnum > 0
                        ? (double) correctnum / (double) (correctnum + incorrectnum)
                        : 1.0;
                // Always emit a value (NaN when nothing was scored) so each row has one column per method.
                double indRatio = indnum + assnum > 0
                        ? (double) indnum / (double) (indnum + assnum)
                        : Double.NaN;
                correctRow.append(correctnum).append(',');
                incorrectRow.append(incorrectnum).append(',');
                accuracyRow.append(acc).append(',');
                indRow.append(indRatio).append(',');
            }
            TestResolveSepsets.appendToFile(this.correctFile, correctRow.toString(), false);
            TestResolveSepsets.appendToFile(this.incorrectFile, incorrectRow.toString(), false);
            TestResolveSepsets.appendToFile(this.accuracy, accuracyRow.toString(), false);
            TestResolveSepsets.appendToFile(this.indFile, indRow.toString(), false);
        }
        // Terminate this simulation's row in every file.
        TestResolveSepsets.appendToFile(this.correctFile, "", true);
        TestResolveSepsets.appendToFile(this.incorrectFile, "", true);
        TestResolveSepsets.appendToFile(this.accuracy, "", true);
        TestResolveSepsets.appendToFile(this.indFile, "", true);
    }

    /** Returns a fresh mutable map holding a zero count for every method in {@link #methods}. */
    private Map<String, Integer> zeroCounts() {
        Map<String, Integer> counts = new HashMap<String, Integer>();
        for (String method : this.methods) {
            counts.put(method, 0);
        }
        return counts;
    }

    /**
     * Appends {@code text} to {@code file}, optionally followed by a line separator.
     *
     * <p>The file is created if absent. The writer is always closed, and any write error that
     * {@link PrintWriter} would otherwise swallow is reported.
     *
     * @throws RuntimeException if the file cannot be opened or written
     */
    private static void appendToFile(File file, String text, boolean endLine) {
        PrintWriter out;
        try {
            out = new PrintWriter(new FileWriter(file, true));
        }
        catch (IOException e) {
            throw new RuntimeException("Cannot open " + file + " for appending", e);
        }
        try {
            out.print(text);
            if (endLine) {
                out.println();
            }
            if (out.checkError()) {
                throw new RuntimeException("Error writing to " + file);
            }
        }
        finally {
            out.close();
        }
    }

    /**
     * Randomly splits the graph's nodes into overlapping variable subsets.
     *
     * <p>First, round(0.6 * #nodes) randomly chosen nodes form an overlap shared by the first two
     * subsets. Each remaining node then goes to exactly one of those two subsets, chosen by a fair
     * coin flip.
     *
     * <p>For {@code s > 2}, each further subset starts with overlap-many random members of the
     * previous subset. It is then padded, in {@code graph.getNodes()} order, with nodes outside the
     * previous subset until its size reaches the integer average size of the subsets built so far.
     *
     * @param graph source of the nodes; its node list is copied, never modified
     * @param s     requested number of subsets (at least two are always produced)
     * @return {@code max(s, 2)} node subsets
     */
    private static List<Set<Node>> subsetsFromDag(Graph graph, int s) {
        int n = graph.getNumNodes();
        // Copy: the list is consumed below and must not alias the graph's own storage.
        List<Node> nodes = new ArrayList<Node>(graph.getNodes());
        List<Node> overlap = new ArrayList<Node>();
        Set<Node> set1 = new HashSet<Node>();
        Set<Node> set2 = new HashSet<Node>();
        RandomUtil generator = RandomUtil.getInstance();
        int overlapSize = (int) Math.round((double) n * 0.6);
        while (nodes.size() > n - overlapSize) {
            overlap.add(nodes.remove(generator.nextInt(nodes.size())));
        }
        while (!nodes.isEmpty()) {
            if (generator.nextInt(2) == 0) {
                set1.add(nodes.remove(generator.nextInt(nodes.size())));
            } else {
                set2.add(nodes.remove(generator.nextInt(nodes.size())));
            }
        }
        set1.addAll(overlap);
        set2.addAll(overlap);
        List<Set<Node>> dds = new ArrayList<Set<Node>>();
        dds.add(set1);
        dds.add(set2);
        List<Node> allNodes = graph.getNodes();
        for (int remaining = s - 2; remaining > 0; --remaining) {
            // Sample the overlap from the previous subset; what is left in lastSet is excluded from padding.
            List<Node> lastSet = new ArrayList<Node>(dds.get(dds.size() - 1));
            Set<Node> newSet = new HashSet<Node>();
            while (newSet.size() < overlapSize) {
                newSet.add(lastSet.remove(generator.nextInt(lastSet.size())));
            }
            int averageSetSize = 0;
            for (Set<Node> set : dds) {
                averageSetSize += set.size();
            }
            averageSetSize /= dds.size();
            for (Node node : allNodes) {
                if (!lastSet.contains(node)) {
                    newSet.add(node);
                }
                if (newSet.size() >= averageSetSize) {
                    break;
                }
            }
            dds.add(newSet);
        }
        return dds;
    }

    /**
     * Gives each variable subset its own block of consecutive rows, keeping only that subset's columns.
     *
     * <p>Each block has {@code numRows / missingVars.size()} rows, and the blocks are disjoint.
     * Leftover rows are dropped. Columns are matched to subset members by variable name.
     *
     * @param dataset     full simulated data
     * @param missingVars one set per output dataset, naming the variables that dataset KEEPS
     * @return one dataset per entry of {@code missingVars}, in the same order
     */
    private static List<DataSet> missingDatasets(DataSet dataset, List<Set<Node>> missingVars) {
        List<DataSet> datasetSet = new ArrayList<DataSet>();
        if (missingVars.isEmpty()) {
            return datasetSet;
        }
        int rowsPerSubset = dataset.getNumRows() / missingVars.size();
        int datapoint = 0;
        for (Set<Node> observed : missingVars) {
            int[] rows = new int[rowsPerSubset];
            for (int i = 0; i < rowsPerSubset; ++i) {
                rows[i] = datapoint++;
            }
            DataSet newDataset = dataset.subsetRows(rows);
            Set<String> keepNames = new HashSet<String>();
            for (Node node : observed) {
                keepNames.add(node.getName());
            }
            // Iterate a snapshot so removing columns can never disturb the list being traversed.
            for (Node node : new ArrayList<Node>(dataset.getVariables())) {
                if (!keepNames.contains(node.getName())) {
                    newDataset.removeColumn(node);
                }
            }
            datasetSet.add(newDataset);
        }
        return datasetSet;
    }

    /**
     * Runs FasDci per dataset, resolves the sepsets with every method, and scores the resolved claims.
     *
     * <p>Each FasDci search uses depth 3. It starts from a copy of the complete graph, with circle
     * endpoints, over the union of all tests' variables.
     *
     * <p>A resolved independence claim (x, y | Z) counts as correct, and increments
     * {@code independent}, when Z d-separates x and y in {@code graph}. Otherwise it counts as
     * incorrect and increments {@code associated}.
     *
     * <p>A resolved dependence claim counts as correct, and increments {@code associated}, when
     * {@code graph} reports x and y d-connected given Z. Otherwise it counts as incorrect and
     * increments {@code independent}.
     *
     * @param graph             true generating graph; resolved nodes are mapped into it by name
     * @param independenceTests one test per dataset
     * @param correct           per-method counters, updated in place; must contain every method
     * @param incorrect         per-method counters, updated in place; must contain every method
     * @param independent       per-method counters, updated in place; must contain every method
     * @param associated        per-method counters, updated in place; must contain every method
     */
    public void tryMethods(Graph graph, List<IndependenceTest> independenceTests, Map<String, Integer> correct, Map<String, Integer> incorrect, Map<String, Integer> independent, Map<String, Integer> associated) {
        Set<Node> allVars = new HashSet<Node>();
        for (IndependenceTest independenceTest : independenceTests) {
            allVars.addAll(independenceTest.getVariables());
        }
        ArrayList<SepsetMapDci> sepsets = new ArrayList<SepsetMapDci>();
        for (IndependenceTest independenceTest : independenceTests) {
            EdgeListGraph fullGraph = new EdgeListGraph(new ArrayList<Node>(allVars));
            fullGraph.fullyConnect(Endpoint.CIRCLE);
            FasDci adj = new FasDci(new EdgeListGraph(fullGraph), independenceTest);
            adj.setDepth(3);
            sepsets.add(adj.search());
        }
        List<NodePair> allPairs = ResolveSepsets.allNodePairs(new ArrayList<Node>(allVars));
        for (String method : this.methods) {
            SepsetMapDci resolvedInd = new SepsetMapDci();
            SepsetMapDci resolvedDep = new SepsetMapDci();
            ResolveSepsets.ResolveSepsets(sepsets, independenceTests, method, resolvedInd, resolvedDep);
            for (NodePair pair : allPairs) {
                Node x = graph.getNode(pair.getFirst().getName());
                Node y = graph.getNode(pair.getSecond().getName());
                List<List<Node>> indCondSets = resolvedInd.getSet(pair.getFirst(), pair.getSecond());
                if (indCondSets != null) {
                    for (List<Node> indCondSet : indCondSets) {
                        if (graph.isDSeparatedFrom(x, y, TestResolveSepsets.toGraphNodes(graph, indCondSet))) {
                            TestResolveSepsets.increment(correct, method);
                            TestResolveSepsets.increment(independent, method);
                        } else {
                            TestResolveSepsets.increment(incorrect, method);
                            TestResolveSepsets.increment(associated, method);
                        }
                    }
                }
                List<List<Node>> depCondSets = resolvedDep.getSet(pair.getFirst(), pair.getSecond());
                if (depCondSets != null) {
                    for (List<Node> depCondSet : depCondSets) {
                        if (graph.isDConnectedTo(x, y, TestResolveSepsets.toGraphNodes(graph, depCondSet))) {
                            TestResolveSepsets.increment(correct, method);
                            TestResolveSepsets.increment(associated, method);
                        } else {
                            TestResolveSepsets.increment(incorrect, method);
                            TestResolveSepsets.increment(independent, method);
                        }
                    }
                }
            }
        }
    }

    /** Returns the nodes of {@code graph} whose names match {@code nodes}, in the same order. */
    private static List<Node> toGraphNodes(Graph graph, List<Node> nodes) {
        List<Node> mapped = new ArrayList<Node>(nodes.size());
        for (Node c : nodes) {
            mapped.add(graph.getNode(c.getName()));
        }
        return mapped;
    }

    /** Adds one to the count stored under {@code key}; the key must already be present. */
    private static void increment(Map<String, Integer> counts, String key) {
        counts.put(key, counts.get(key) + 1);
    }
}

