/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.search.DeltaSextadTest;
import edu.cmu.tetrad.search.IntSextad;
import edu.cmu.tetrad.util.ChoiceGenerator;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Removes nodes from a clustering until no "impure" sextad remains.
 *
 * <p>A sextad is recorded as an impurity when the p-value reported for it by the
 * supplied {@link DeltaSextadTest} is below {@code alpha}. Nodes are then eliminated
 * greedily: on every pass the node that takes part in the largest number of remaining
 * impurities is dropped, until no node takes part in any.
 *
 * <p>Nodes are identified by their integer index into
 * {@code sextadTest.getVariables()}.
 */
public class PurifySextadBased {
    // Retained from the original source; nothing in this class reads it.
    private final boolean outputMessage = true;
    private final DeltaSextadTest sextadTest;
    // Indices 0 .. getVariables().size() - 1, i.e. every node the test knows about.
    private final List<Integer> nodes;
    private final double alpha;

    /**
     * @param sextadTest test used to obtain p-values for sextads; its variable list
     *                   determines the set of node indices considered
     * @param alpha      cutoff; a sextad whose p-value is below this value is an impurity
     */
    public PurifySextadBased(DeltaSextadTest sextadTest, double alpha) {
        this.sextadTest = sextadTest;
        this.nodes = new ArrayList<Integer>();
        for (int i = 0; i < sextadTest.getVariables().size(); ++i) {
            this.nodes.add(i);
        }
        this.alpha = alpha;
    }

    /**
     * Purifies the given clustering and prints the result to standard output.
     *
     * <p>The argument is not modified. (The previous implementation appended the list
     * to itself, which duplicated every cluster in both the caller's list and the
     * result.)
     *
     * @param clustering clusters of node indices
     * @return a new list with one cluster per input cluster, minus the eliminated
     *         nodes; or an empty list when no group of six nodes could be examined
     * @throws NullPointerException if {@code clustering} is null or empty
     */
    public List<List<Integer>> purify(List<List<Integer>> clustering) {
        // NullPointerException is kept for the empty case so existing callers that
        // catch it keep working.
        if (clustering == null || clustering.isEmpty()) {
            throw new NullPointerException("Clusters not specified.");
        }
        // Shallow copy: the search only reads the clusters, so copying the outer list
        // is enough to keep the caller's data untouched.
        List<List<Integer>> working = new ArrayList<List<Integer>>(clustering);
        List<List<Integer>> result = this.combinedSearch(working);
        List<List<Integer>> convertedResult = new ArrayList<List<Integer>>(result);
        System.out.println(convertedResult);
        return convertedResult;
    }

    /**
     * Collects impurities within each cluster and across pairs of clusters, then
     * greedily eliminates the node with the most impurities until none remain.
     *
     * @return the clustering with eliminated nodes removed, or an empty list when
     *         neither the within-cluster nor the cross-cluster pass found anything
     *         to examine
     */
    private List<List<Integer>> combinedSearch(List<List<Integer>> clustering) {
        Set<Integer> eliminated = new HashSet<Integer>();
        // Stays null until some pass reports a non-null (countable) result.
        Set<IntSextad> allImpurities = null;
        double cutoff = this.alpha;

        int count = 0;
        for (List<Integer> cluster : clustering) {
            System.out.println("Within cluster: " + ++count);
            Set<IntSextad> impurities = this.listSextads(cluster, eliminated, cutoff);
            if (impurities == null) {
                continue;
            }
            if (allImpurities == null) {
                allImpurities = new HashSet<IntSextad>();
            }
            allImpurities.addAll(impurities);
        }

        Set<IntSextad> crossImpurities = this.listCrossConstructSextads(clustering, eliminated, cutoff);
        if (crossImpurities != null) {
            if (allImpurities == null) {
                allImpurities = new HashSet<IntSextad>();
            }
            allImpurities.addAll(crossImpurities);
        }

        if (allImpurities == null) {
            return new ArrayList<List<Integer>>();
        }

        DecimalFormat nf = new DecimalFormat("0.####E00");
        while (true) {
            // Rebuilt on every pass so that impurities involving already eliminated
            // nodes are no longer counted.
            Map<Integer, Set<IntSextad>> impuritiesPerNode = this.getImpuritiesPerNode(allImpurities, eliminated);

            int max = 0;
            Integer maxNode = null;
            for (Integer node : this.nodes) {
                int size = impuritiesPerNode.get(node).size();
                // Strict comparison: on ties the first node in index order wins.
                if (size > max) {
                    max = size;
                    maxNode = node;
                }
            }
            if (max == 0) {
                break;
            }

            // The p-value range is computed only for the progress message below.
            double minP = Double.POSITIVE_INFINITY;
            double maxP = Double.NEGATIVE_INFINITY;
            for (IntSextad sextad : impuritiesPerNode.get(maxNode)) {
                double pValue = this.sextadTest.getPValue(new IntSextad[]{sextad});
                if (pValue < minP) {
                    minP = pValue;
                }
                if (pValue > maxP) {
                    maxP = pValue;
                }
            }

            eliminated.add(maxNode);
            System.out.println("Eliminated " + maxNode + " impurities = " + max + "q = " + nf.format(minP) + " maxP = " + nf.format(maxP));
        }
        return this.buildSolution(clustering, eliminated);
    }

    /**
     * Maps every known node to the impurities it takes part in, ignoring any impurity
     * that involves an eliminated node.
     *
     * @return a map with an entry (possibly an empty set) for every node in
     *         {@code nodes}
     */
    private Map<Integer, Set<IntSextad>> getImpuritiesPerNode(Set<IntSextad> allImpurities, Set<Integer> _eliminated) {
        Map<Integer, Set<IntSextad>> impuritiesPerNode = new HashMap<Integer, Set<IntSextad>>();
        for (Integer node : this.nodes) {
            impuritiesPerNode.put(node, new HashSet<IntSextad>());
        }
        for (IntSextad sextad : allImpurities) {
            boolean involvesEliminated = _eliminated.contains(sextad.getI())
                    || _eliminated.contains(sextad.getJ())
                    || _eliminated.contains(sextad.getK())
                    || _eliminated.contains(sextad.getL())
                    || _eliminated.contains(sextad.getM())
                    || _eliminated.contains(sextad.getN());
            if (involvesEliminated) {
                continue;
            }
            impuritiesPerNode.get(sextad.getI()).add(sextad);
            impuritiesPerNode.get(sextad.getJ()).add(sextad);
            impuritiesPerNode.get(sextad.getK()).add(sextad);
            impuritiesPerNode.get(sextad.getL()).add(sextad);
            impuritiesPerNode.get(sextad.getM()).add(sextad);
            impuritiesPerNode.get(sextad.getN()).add(sextad);
        }
        return impuritiesPerNode;
    }

    /**
     * Lists impurities among groups made of five nodes from one cluster plus one node
     * from another, for every unordered pair of clusters and in both directions.
     *
     * @return the impurities found, or {@code null} when no group could be examined
     */
    private Set<IntSextad> listCrossConstructSextads(List<List<Integer>> clustering, Set<Integer> eliminated, double cutoff) {
        Set<IntSextad> allSextads = new HashSet<IntSextad>();
        boolean countable = false;
        for (int p1 = 0; p1 < clustering.size(); ++p1) {
            for (int p2 = p1 + 1; p2 < clustering.size(); ++p2) {
                List<Integer> cluster1 = clustering.get(p1);
                List<Integer> cluster2 = clustering.get(p2);
                // Both calls must run, so the results are not combined with "||".
                if (this.collectCrossSextads(cluster1, cluster2, eliminated, cutoff, allSextads)) {
                    countable = true;
                }
                if (this.collectCrossSextads(cluster2, cluster1, eliminated, cutoff, allSextads)) {
                    countable = true;
                }
            }
        }
        return countable ? allSextads : null;
    }

    /**
     * Adds to {@code sink} the impurities among every group of five nodes chosen from
     * {@code fiveFrom} combined with one node chosen from {@code oneFrom}.
     *
     * @return {@code true} if at least one group was examined
     */
    private boolean collectCrossSextads(List<Integer> fiveFrom, List<Integer> oneFrom, Set<Integer> eliminated, double cutoff, Set<IntSextad> sink) {
        if (fiveFrom.size() < 5 || oneFrom.size() < 1) {
            return false;
        }
        boolean countable = false;
        int[] fiveChoice;
        ChoiceGenerator fiveGen = new ChoiceGenerator(fiveFrom.size(), 5);
        while ((fiveChoice = fiveGen.next()) != null) {
            int[] oneChoice;
            ChoiceGenerator oneGen = new ChoiceGenerator(oneFrom.size(), 1);
            while ((oneChoice = oneGen.next()) != null) {
                List<Integer> crossCluster = new ArrayList<Integer>();
                for (int i : fiveChoice) {
                    crossCluster.add(fiveFrom.get(i));
                }
                for (int i : oneChoice) {
                    crossCluster.add(oneFrom.get(i));
                }
                Set<IntSextad> found = this.listSextads(crossCluster, eliminated, cutoff);
                if (found == null) {
                    continue;
                }
                countable = true;
                sink.addAll(found);
            }
        }
        return countable;
    }

    /**
     * Tests every choice of six nodes from {@code cluster} and returns the sextads
     * whose p-value is below {@code cutoff}.
     *
     * @return the impure sextads (possibly empty), or {@code null} when the cluster has
     *         fewer than six nodes or every choice of six contains an eliminated node.
     *         Callers rely on {@code null} to tell "nothing examined" apart from
     *         "examined, nothing impure".
     */
    private Set<IntSextad> listSextads(List<Integer> cluster, Set<Integer> eliminated, double cutoff) {
        if (cluster.size() < 6) {
            return null;
        }
        List<Integer> members = new ArrayList<Integer>(cluster);
        boolean countable = false;
        Set<IntSextad> impure = new HashSet<IntSextad>();
        ChoiceGenerator gen = new ChoiceGenerator(members.size(), 6);
        int[] choice;
        while ((choice = gen.next()) != null) {
            int m1 = members.get(choice[0]);
            int m2 = members.get(choice[1]);
            int m3 = members.get(choice[2]);
            int m4 = members.get(choice[3]);
            int m5 = members.get(choice[4]);
            int m6 = members.get(choice[5]);
            if (eliminated.contains(m1) || eliminated.contains(m2) || eliminated.contains(m3)
                    || eliminated.contains(m4) || eliminated.contains(m5) || eliminated.contains(m6)) {
                continue;
            }
            countable = true;
            // Ten arrangements of the six nodes, each keeping m1 in first position.
            // NOTE(review): presumably these are the ten ways of splitting the six
            // nodes into two triples - confirm against IntSextad.
            IntSextad[] candidates = new IntSextad[]{
                new IntSextad(m1, m2, m3, m4, m5, m6),
                new IntSextad(m1, m2, m4, m3, m5, m6),
                new IntSextad(m1, m2, m5, m3, m4, m6),
                new IntSextad(m1, m2, m6, m3, m4, m5),
                new IntSextad(m1, m3, m4, m2, m5, m6),
                new IntSextad(m1, m3, m5, m2, m4, m6),
                new IntSextad(m1, m3, m6, m2, m4, m5),
                new IntSextad(m1, m4, m5, m2, m3, m6),
                new IntSextad(m1, m4, m6, m2, m3, m5),
                new IntSextad(m1, m5, m6, m2, m3, m4)
            };
            for (IntSextad candidate : candidates) {
                double pValue = this.sextadTest.getPValue(candidate);
                if (pValue < cutoff) {
                    impure.add(candidate);
                }
            }
        }
        return countable ? impure : null;
    }

    /**
     * Copies every cluster, dropping the eliminated nodes. Clusters that end up empty
     * are still included.
     */
    private List<List<Integer>> buildSolution(List<List<Integer>> clustering, Set<Integer> eliminated) {
        List<List<Integer>> solution = new ArrayList<List<Integer>>();
        for (List<Integer> cluster : clustering) {
            List<Integer> remaining = new ArrayList<Integer>(cluster);
            remaining.removeAll(eliminated);
            solution.add(remaining);
        }
        return solution;
    }
}

