/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IPurify;
import edu.cmu.tetrad.search.Tetrad;
import edu.cmu.tetrad.search.TetradTest;
import edu.cmu.tetrad.util.ChoiceGenerator;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Purifies a measurement clustering by greedily discarding the variables that take part in the
 * most "impure" tetrads, i.e. tetrads whose p-value from the supplied {@link TetradTest} falls
 * below that test's significance level.
 *
 * <p>Candidate tetrads are drawn from within each cluster (every 4-subset of a cluster) and
 * across every pair of clusters (three members of one cluster plus one member of the other, in
 * both directions, and two members from each).
 */
public class PurifyTetradBased2 implements IPurify {
    /** Supplies tetrad p-values and the significance cutoff. */
    private final TetradTest tetradTest;
    /** The test's variables; a node's position here is the index handed to the test. */
    private final List<Node> nodes;

    /**
     * Creates a purifier backed by the given tetrad test.
     *
     * @param tetradTest the test used to score tetrads; its variables define the node universe
     */
    public PurifyTetradBased2(TetradTest tetradTest) {
        this.tetradTest = tetradTest;
        this.nodes = tetradTest.getVariables();
    }

    /**
     * Purifies the given clustering.
     *
     * <p>Clusters are translated onto the test's node objects with
     * {@link GraphUtils#replaceNodes}. After purification the surviving members are translated
     * back onto the caller's node objects the same way.
     *
     * @param clustering the clusters to purify; must not be empty
     * @return the purified clusters in input order, or an empty list if no tetrad at all could be
     *     formed from the clustering
     * @throws NullPointerException if {@code clustering} is empty (exception type kept for
     *     compatibility with existing callers)
     */
    @Override
    public List<List<Node>> purify(List<List<Node>> clustering) {
        if (clustering.isEmpty()) {
            throw new NullPointerException("Clusters not specified.");
        }

        List<Node> originalNodes = new ArrayList<Node>();
        for (List<Node> cluster : clustering) {
            originalNodes.addAll(cluster);
        }

        List<List<Node>> testClustering = new ArrayList<List<Node>>();
        for (List<Node> cluster : clustering) {
            testClustering.add(GraphUtils.replaceNodes(cluster, this.nodes));
        }

        List<List<Node>> purified = this.combinedSearch(testClustering);

        List<List<Node>> result = new ArrayList<List<Node>>();
        for (List<Node> cluster : purified) {
            result.add(GraphUtils.replaceNodes(cluster, originalNodes));
        }
        return result;
    }

    /** Ignored: this purifier does not consult a true graph. */
    @Override
    public void setTrueGraph(Graph mim) {
    }

    /**
     * Collects every impure tetrad.
     *
     * <p>It then repeatedly eliminates the node that appears in the largest number of impure
     * tetrads not already touched by an eliminated node. Elimination stops once no such tetrad
     * remains.
     *
     * @param clustering clusters expressed in the test's node objects
     * @return the clusters with eliminated nodes removed, or an empty list when no tetrad at all
     *     could be formed (no cluster of size >= 4 and no testable cross-cluster combination)
     */
    private List<List<Node>> combinedSearch(List<List<Node>> clustering) {
        Set<Node> eliminated = new HashSet<Node>();
        double cutoff = this.tetradTest.getSignificance();

        Set<Tetrad> allImpurities = new HashSet<Tetrad>();
        boolean countable = false;

        for (List<Node> cluster : clustering) {
            Set<Tetrad> impurities = this.listTetrads(cluster, eliminated, cutoff);
            if (impurities != null) {
                countable = true;
                allImpurities.addAll(impurities);
            }
        }

        Set<Tetrad> crossImpurities = this.listCrossConstructTetrads(clustering, eliminated, cutoff);
        if (crossImpurities != null) {
            countable = true;
            allImpurities.addAll(crossImpurities);
        }

        if (!countable) {
            // Existing behavior: when nothing could be tested, every cluster is dropped.
            return new ArrayList<List<Node>>();
        }

        // Each pass eliminates a node with a positive count. Eliminated nodes always score 0,
        // so the loop terminates after at most nodes.size() passes.
        Node worst;
        while ((worst = this.mostImpureNode(allImpurities, eliminated)) != null) {
            eliminated.add(worst);
        }
        return this.buildSolution(clustering, eliminated);
    }

    /**
     * Finds the node that takes part in the most impure tetrads that involve no eliminated node.
     * Ties go to the node appearing first in {@link #nodes}.
     *
     * @return that node, or {@code null} if no such tetrad remains
     */
    private Node mostImpureNode(Set<Tetrad> allImpurities, Set<Node> eliminated) {
        Map<Node, Set<Tetrad>> impuritiesPerNode = this.getImpuritiesPerNode(allImpurities, eliminated);
        Node worst = null;
        int worstCount = 0;
        for (Node node : this.nodes) {
            int count = impuritiesPerNode.get(node).size();
            if (count > worstCount) {
                worstCount = count;
                worst = node;
            }
        }
        return worst;
    }

    /**
     * Maps each node of {@link #nodes} to the impure tetrads it belongs to, ignoring tetrads
     * that contain any eliminated node. Every node in {@link #nodes} gets an entry, possibly an
     * empty set.
     */
    private Map<Node, Set<Tetrad>> getImpuritiesPerNode(Set<Tetrad> allImpurities, Set<Node> eliminated) {
        Map<Node, Set<Tetrad>> impuritiesPerNode = new HashMap<Node, Set<Tetrad>>();
        for (Node node : this.nodes) {
            impuritiesPerNode.put(node, new HashSet<Tetrad>());
        }
        for (Tetrad tetrad : allImpurities) {
            if (eliminated.contains(tetrad.getI()) || eliminated.contains(tetrad.getJ())
                    || eliminated.contains(tetrad.getK()) || eliminated.contains(tetrad.getL())) {
                continue;
            }
            // NOTE(review): assumes every tetrad member is in this.nodes; otherwise get() is null.
            impuritiesPerNode.get(tetrad.getI()).add(tetrad);
            impuritiesPerNode.get(tetrad.getJ()).add(tetrad);
            impuritiesPerNode.get(tetrad.getK()).add(tetrad);
            impuritiesPerNode.get(tetrad.getL()).add(tetrad);
        }
        return impuritiesPerNode;
    }

    /**
     * Lists impure tetrads formed across every pair of clusters.
     *
     * <p>For each pair it considers three nodes from the first cluster plus one from the second,
     * then the mirror image, then two nodes from each cluster.
     *
     * @return the impure tetrads found, or {@code null} if no cross-cluster tetrad could be formed
     */
    private Set<Tetrad> listCrossConstructTetrads(List<List<Node>> clustering, Set<Node> eliminated, double cutoff) {
        Set<Tetrad> allTetrads = new HashSet<Tetrad>();
        boolean countable = false;
        for (int p1 = 0; p1 < clustering.size(); ++p1) {
            List<Node> cluster1 = clustering.get(p1);
            for (int p2 = p1 + 1; p2 < clustering.size(); ++p2) {
                List<Node> cluster2 = clustering.get(p2);
                // Non-short-circuit |= so every helper runs (they add into allTetrads).
                countable |= this.addThreePlusOneTetrads(cluster1, cluster2, eliminated, cutoff, allTetrads);
                countable |= this.addThreePlusOneTetrads(cluster2, cluster1, eliminated, cutoff, allTetrads);
                countable |= this.addTwoPlusTwoTetrads(cluster1, cluster2, eliminated, cutoff, allTetrads);
            }
        }
        return countable ? allTetrads : null;
    }

    /**
     * Tests every quartet made of three nodes from {@code tripleSource} followed by one node from
     * {@code singleSource}. Impure tetrads are added to {@code sink}.
     *
     * @return true if at least one quartet was testable
     */
    private boolean addThreePlusOneTetrads(List<Node> tripleSource, List<Node> singleSource,
                                           Set<Node> eliminated, double cutoff, Set<Tetrad> sink) {
        if (tripleSource.size() < 3 || singleSource.isEmpty()) {
            return false;
        }
        boolean countable = false;
        ChoiceGenerator tripleGen = new ChoiceGenerator(tripleSource.size(), 3);
        int[] triple;
        while ((triple = tripleGen.next()) != null) {
            for (Node single : singleSource) {
                List<Node> quartet = new ArrayList<Node>(4);
                quartet.add(tripleSource.get(triple[0]));
                quartet.add(tripleSource.get(triple[1]));
                quartet.add(tripleSource.get(triple[2]));
                quartet.add(single);
                Set<Tetrad> impure = this.listTetrads(quartet, eliminated, cutoff);
                if (impure != null) {
                    countable = true;
                    sink.addAll(impure);
                }
            }
        }
        return countable;
    }

    /**
     * Tests every quartet made of two nodes from {@code cluster1} followed by two nodes from
     * {@code cluster2}. Impure tetrads are added to {@code sink}.
     *
     * @return true if at least one quartet was testable
     */
    private boolean addTwoPlusTwoTetrads(List<Node> cluster1, List<Node> cluster2,
                                         Set<Node> eliminated, double cutoff, Set<Tetrad> sink) {
        if (cluster1.size() < 2 || cluster2.size() < 2) {
            return false;
        }
        boolean countable = false;
        ChoiceGenerator gen1 = new ChoiceGenerator(cluster1.size(), 2);
        int[] pair1;
        while ((pair1 = gen1.next()) != null) {
            ChoiceGenerator gen2 = new ChoiceGenerator(cluster2.size(), 2);
            int[] pair2;
            while ((pair2 = gen2.next()) != null) {
                List<Node> quartet = new ArrayList<Node>(4);
                quartet.add(cluster1.get(pair1[0]));
                quartet.add(cluster1.get(pair1[1]));
                quartet.add(cluster2.get(pair2[0]));
                quartet.add(cluster2.get(pair2[1]));
                Set<Tetrad> impure = this.listTetrads2By2(quartet, eliminated, cutoff);
                if (impure != null) {
                    countable = true;
                    sink.addAll(impure);
                }
            }
        }
        return countable;
    }

    /**
     * Tests three tetrad orderings for every 4-subset {@code (ci, cj, ck, cl)} of {@code cluster}:
     * {@code (i,j,k,l)}, {@code (i,j,l,k)} and {@code (i,k,l,j)}. Subsets touching an
     * eliminated node are skipped.
     *
     * @return tetrads whose p-value is below {@code cutoff}, or {@code null} if the cluster has
     *     fewer than 4 nodes or every subset was skipped
     */
    private Set<Tetrad> listTetrads(List<Node> cluster, Set<Node> eliminated, double cutoff) {
        if (cluster.size() < 4) {
            return null;
        }
        // Random-access copy, since get() is called repeatedly below.
        List<Node> members = new ArrayList<Node>(cluster);
        Set<Tetrad> tetrads = new HashSet<Tetrad>();
        boolean countable = false;

        ChoiceGenerator gen = new ChoiceGenerator(members.size(), 4);
        int[] choice;
        while ((choice = gen.next()) != null) {
            Node ci = members.get(choice[0]);
            Node cj = members.get(choice[1]);
            Node ck = members.get(choice[2]);
            Node cl = members.get(choice[3]);
            if (eliminated.contains(ci) || eliminated.contains(cj)
                    || eliminated.contains(ck) || eliminated.contains(cl)) {
                continue;
            }
            countable = true;

            // Look each node's test index up once instead of once per p-value call.
            int i = this.nodes.indexOf(ci);
            int j = this.nodes.indexOf(cj);
            int k = this.nodes.indexOf(ck);
            int l = this.nodes.indexOf(cl);

            double p1 = this.tetradTest.tetradPValue(i, j, k, l);
            double p2 = this.tetradTest.tetradPValue(i, j, l, k);
            double p3 = this.tetradTest.tetradPValue(i, k, l, j);

            if (p1 < cutoff) {
                tetrads.add(new Tetrad(ci, cj, ck, cl, p1));
            }
            if (p2 < cutoff) {
                tetrads.add(new Tetrad(ci, cj, cl, ck, p2));
            }
            if (p3 < cutoff) {
                tetrads.add(new Tetrad(ci, ck, cl, cj, p3));
            }
        }
        return countable ? tetrads : null;
    }

    /**
     * Tests a single tetrad ordering for a quartet {@code [c0, c1, c2, c3]}, where {@code c0, c1}
     * come from one cluster and {@code c2, c3} from another. Only the ordering
     * {@code (c0, c2, c3, c1)} is tested.
     *
     * @return a set holding that tetrad if its p-value is below {@code cutoff} (else empty), or
     *     {@code null} if the quartet is too small or contains an eliminated node
     */
    private Set<Tetrad> listTetrads2By2(List<Node> cluster, Set<Node> eliminated, double cutoff) {
        if (cluster.size() < 4) {
            return null;
        }
        Node ci = cluster.get(0);
        Node cj = cluster.get(1);
        Node ck = cluster.get(2);
        Node cl = cluster.get(3);
        if (eliminated.contains(ci) || eliminated.contains(cj)
                || eliminated.contains(ck) || eliminated.contains(cl)) {
            return null;
        }
        Set<Tetrad> tetrads = new HashSet<Tetrad>();
        double p = this.tetradTest.tetradPValue(this.nodes.indexOf(ci), this.nodes.indexOf(ck),
                this.nodes.indexOf(cl), this.nodes.indexOf(cj));
        if (p < cutoff) {
            tetrads.add(new Tetrad(ci, ck, cl, cj, p));
        }
        return tetrads;
    }

    /**
     * Returns a copy of each cluster, in the original order, with every eliminated node removed.
     */
    private List<List<Node>> buildSolution(List<List<Node>> clustering, Set<Node> eliminated) {
        List<List<Node>> solution = new ArrayList<List<Node>>();
        for (List<Node> cluster : clustering) {
            List<Node> kept = new ArrayList<Node>(cluster);
            kept.removeAll(eliminated);
            solution.add(kept);
        }
        return solution;
    }
}

