/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IPurify;
import edu.cmu.tetrad.search.TetradTest;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.RandomUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

/**
 * Purifies a clustering of measured variables by greedily removing variables that take part
 * in tetrad constraints rejected by a {@link TetradTest}.
 *
 * <p>Purification runs in two phases over one shared {@code eliminated} mask:
 * <ol>
 *   <li>an intra-construct phase, which counts rejected tetrads ("impurities") among the
 *       variables of each cluster separately, and</li>
 *   <li>a cross-construct phase, which counts rejected tetrads formed from 3+1 and 2+2
 *       variable combinations drawn from pairs of clusters.</li>
 * </ol>
 * In each phase, the variable whose removal leaves the fewest impurities is dropped, with ties
 * broken at random via {@link RandomUtil}. This repeats until no impurities remain or nothing
 * more can be evaluated. Progress is written to standard output.
 */
public class PurifyTetradBased
implements IPurify {
    /** Gates the messages printed by the printMessage/printlnMessage helpers. */
    private final boolean outputMessage = true;
    private final TetradTest tetradTest;
    /** Number of variables known to the tetrad test; sizes the elimination mask. */
    private final int numVars;
    /**
     * When true, each phase derives its p-value cutoff from the sorted p-values of all
     * testable tetrads (see {@link #fdrCutoff(List)}) instead of the test's significance.
     */
    boolean doFdr;
    /** When true, every tetrad counted as an impurity is printed. */
    boolean listTetrads;

    /**
     * @param tetradTest the test used to obtain tetrad p-values; must not be null
     */
    public PurifyTetradBased(TetradTest tetradTest) {
        this.tetradTest = Objects.requireNonNull(tetradTest, "tetradTest");
        this.numVars = tetradTest.getVarNames().length;
    }

    /**
     * Purifies the given clustering.
     *
     * @param clustering clusters of nodes; nodes are matched to the tetrad test's variables by name
     * @return one list per input cluster, in the same order, containing the nodes that were not
     *         eliminated (a cluster may come back empty)
     * @throws IllegalArgumentException if a node's name is not among the tetrad test's variables
     */
    @Override
    public List<List<Node>> purify(List<List<Node>> clustering) {
        System.out.println("*** " + clustering);
        List<int[]> intClustering = this.convertListToInt(clustering);
        System.out.println("&&&");
        this.printIntPartition(intClustering);
        List<int[]> purified = this.tetradBasedPurify(intClustering);
        System.out.println("%%%");
        this.printIntPartition(purified);
        return this.convertIntToList(purified);
    }

    /** Not supported by this purifier. */
    @Override
    public void setTrueGraph(Graph mim) {
        throw new UnsupportedOperationException();
    }

    /** Prints each cluster's variable indices on its own line, prefixed by the cluster number. */
    private void printIntPartition(List<int[]> clustering) {
        for (int i = 0; i < clustering.size(); ++i) {
            int[] cluster = clustering.get(i);
            System.out.print(i + ": ");
            for (int k : cluster) {
                System.out.print(k + " ");
            }
            System.out.println();
        }
        System.out.println();
    }

    /**
     * Converts node clusters into clusters of indices into {@code tetradTest.getVariables()}.
     *
     * @throws IllegalArgumentException if a node's name matches no tetrad-test variable
     */
    private List<int[]> convertListToInt(List<List<Node>> clustering) {
        List<Node> nodes = this.tetradTest.getVariables();
        Map<String, Integer> indexByName = new HashMap<>();
        for (int k = 0; k < nodes.size(); ++k) {
            // Later entries overwrite earlier ones, so duplicate names resolve to the last match.
            indexByName.put(nodes.get(k).getName(), k);
        }
        List<int[]> result = new ArrayList<>(clustering.size());
        for (List<Node> cluster : clustering) {
            int[] indices = new int[cluster.size()];
            for (int j = 0; j < cluster.size(); ++j) {
                String name = cluster.get(j).getName();
                Integer index = indexByName.get(name);
                if (index == null) {
                    // Silently defaulting to index 0 would purify the wrong variable.
                    throw new IllegalArgumentException("Variable not known to the tetrad test: " + name);
                }
                indices[j] = index;
            }
            result.add(indices);
        }
        return result;
    }

    /** Converts index clusters back into node clusters using {@code tetradTest.getVariables()}. */
    private List<List<Node>> convertIntToList(List<int[]> clustering) {
        List<Node> nodes = this.tetradTest.getVariables();
        List<List<Node>> result = new ArrayList<>(clustering.size());
        for (int[] cluster : clustering) {
            List<Node> nodeCluster = new ArrayList<>(cluster.length);
            for (int k : cluster) {
                nodeCluster.add(nodes.get(k));
            }
            result.add(nodeCluster);
        }
        return result;
    }

    /**
     * Runs the intra- and cross-construct phases and returns the surviving clusters.
     *
     * @param clustering clusters of variable indices
     * @return the input clusters, in order, with eliminated indices removed
     */
    private List<int[]> tetradBasedPurify(List<int[]> clustering) {
        // Java zero-initializes boolean arrays, so every variable starts out retained.
        boolean[] eliminated = new boolean[this.numVars];
        this.printlnMessage("TETRAD-BASED PURIFY:");
        this.printlnMessage("Finding Unidimensional Measurement Models");
        this.printlnMessage();
        this.printlnMessage("Initially Specified Measurement Model");
        this.printlnMessage();
        this.printClustering(clustering, eliminated);
        this.printlnMessage();
        this.printlnMessage("INTRA-CONSTRUCT PHASE.");
        this.printlnMessage("----------------------");
        this.printlnMessage();
        for (int[] cluster : clustering) {
            this.intraConstructPhase2(cluster, eliminated);
        }
        this.printlnMessage();
        this.printlnMessage("CROSS-CONSTRUCT PHASE.");
        this.printlnMessage("----------------------");
        this.printlnMessage();
        this.crossConstructPhase2(clustering, eliminated);
        this.printlnMessage();
        this.printlnMessage("------------------------------------------------------");
        this.printlnMessage("Output Measurement Model");
        List<int[]> output = this.buildSolution(clustering, eliminated);
        this.printClustering(output, eliminated);
        return output;
    }

    /**
     * Greedily eliminates variables of one cluster until no rejected intra-cluster tetrad
     * remains or no four retained variables are left to test.
     *
     * @param clusterIndices variable indices of the cluster
     * @param eliminated     shared elimination mask, updated in place
     */
    private void intraConstructPhase2(int[] clusterIndices, boolean[] eliminated) {
        List<Integer> cluster = new ArrayList<>(clusterIndices.length);
        for (int v : clusterIndices) {
            cluster.add(v);
        }
        double cutoff = this.tetradTest.getSignificance();
        if (this.doFdr) {
            List<Double> allPValues = this.listPValues(cluster, eliminated, Double.MAX_VALUE);
            if (allPValues == null) {
                // No fully retained quartet exists, so nothing is testable at any cutoff.
                System.out.println("Nothing to count.");
                return;
            }
            System.out.println("# p values for this cluster: " + allPValues.size());
            cutoff = this.fdrCutoff(allPValues);
            System.out.println("FDR cutoff = " + cutoff);
        }
        final double threshold = cutoff;
        List<Double> impurities = this.listPValues(cluster, eliminated, threshold);
        if (impurities == null) {
            System.out.println("Nothing to count.");
            return;
        }
        System.out.println("Num impurities going in = " + impurities.size());
        this.greedilyEliminate(cluster, eliminated, impurities.size(),
                mask -> this.listPValues(cluster, mask, threshold));
    }

    /**
     * Greedily eliminates clustered variables until no rejected cross-cluster tetrad remains
     * or nothing more can be tested, then prints the clustering.
     *
     * @param clustering clusters of variable indices
     * @param eliminated shared elimination mask, updated in place
     */
    private void crossConstructPhase2(List<int[]> clustering, boolean[] eliminated) {
        double cutoff = this.tetradTest.getSignificance();
        if (this.doFdr) {
            List<Double> allPValues = this.listCrossConstructPValues(clustering, eliminated, Double.MAX_VALUE);
            if (allPValues == null) {
                // No fully retained cross-cluster quartet exists, so nothing is testable.
                System.out.println("Nothing to count.");
                return;
            }
            System.out.println("Num p values cross clusters: " + allPValues.size());
            if (allPValues.isEmpty()) {
                return;
            }
            System.out.println("# p values = " + allPValues.size());
            System.out.println("significance = " + this.tetradTest.getSignificance());
            cutoff = this.fdrCutoff(allPValues);
        }
        final double threshold = cutoff;
        List<Double> impurities = this.listCrossConstructPValues(clustering, eliminated, threshold);
        if (impurities == null) {
            System.out.println("Nothing to count.");
            return;
        }
        // Variables outside every cluster never appear in a cross-cluster tetrad, so dropping
        // them cannot change the count; only clustered variables are candidates.
        this.greedilyEliminate(this.clusteredVariables(clustering), eliminated, impurities.size(),
                mask -> this.listCrossConstructPValues(clustering, mask, threshold));
        this.printClustering(clustering, eliminated);
    }

    /**
     * Computes the FDR-style cutoff used when {@link #doFdr} is set.
     *
     * <p>The p-values are sorted ascending. The cutoff is the first p-value {@code p(c)}
     * (0-based {@code c}) with {@code p(c) >= alpha * (c + 1) / m}, where {@code alpha} is the
     * test's significance and {@code m} the number of p-values. If no such p-value exists,
     * the cutoff is {@code 1.0}.
     */
    private double fdrCutoff(List<Double> pValues) {
        List<Double> sorted = new ArrayList<>(pValues);
        Collections.sort(sorted);
        double alpha = this.tetradTest.getSignificance();
        int m = sorted.size();
        for (int c = 0; c < m; ++c) {
            double p = sorted.get(c);
            if (p >= alpha * (c + 1.0) / m) {
                return p;
            }
        }
        return 1.0;
    }

    /** Returns, in ascending order, every variable index that appears in some cluster. */
    private List<Integer> clusteredVariables(List<int[]> clustering) {
        boolean[] inSomeCluster = new boolean[this.numVars];
        for (int[] cluster : clustering) {
            for (int v : cluster) {
                inSomeCluster[v] = true;
            }
        }
        List<Integer> variables = new ArrayList<>();
        for (int i = 0; i < inSomeCluster.length; ++i) {
            if (inSomeCluster[i]) {
                variables.add(i);
            }
        }
        return variables;
    }

    /**
     * Repeatedly drops the retained candidate whose removal leaves the fewest impurities.
     *
     * <p>Each round tentatively eliminates every retained candidate in turn. Candidates whose
     * removal leaves nothing testable ({@code impurities} returns null) are skipped. One of the
     * minimizers is then dropped at random. The loop stops when the impurity count reaches zero
     * or no candidate is evaluable. Each round drops one candidate, so the loop runs at most
     * {@code candidates.size()} rounds.
     *
     * @param candidates        variable indices that may be dropped
     * @param eliminated        shared elimination mask, updated in place
     * @param initialImpurities impurity count before any drop
     * @param impurities        rejected-tetrad p-values for a given mask, or null if nothing is testable
     */
    private void greedilyEliminate(List<Integer> candidates, boolean[] eliminated, int initialImpurities,
                                   Function<boolean[], List<Double>> impurities) {
        String[] varNames = this.tetradTest.getVarNames();
        int numImpurities = initialImpurities;
        while (numImpurities > 0) {
            System.out.println("Num impurities this round = " + numImpurities);
            int min = Integer.MAX_VALUE;
            List<Integer> minList = new ArrayList<>();
            for (int i : candidates) {
                if (eliminated[i]) {
                    continue;
                }
                eliminated[i] = true;
                List<Double> remaining = impurities.apply(eliminated);
                eliminated[i] = false;
                if (remaining == null) {
                    continue;
                }
                System.out.println("Tried dropping " + varNames[i] + " (" + remaining.size() + " impurities)");
                if (remaining.size() < min) {
                    min = remaining.size();
                    minList.clear();
                    minList.add(i);
                } else if (remaining.size() == min) {
                    minList.add(i);
                }
            }
            if (minList.isEmpty()) {
                break;
            }
            int index = minList.get(RandomUtil.getInstance().nextInt(minList.size()));
            eliminated[index] = true;
            numImpurities = min;
            System.out.println("Dropped " + varNames[index]);
        }
    }

    /**
     * Collects rejected tetrad p-values for every pair of clusters. For each pair it uses
     * 3+1 quartets in both directions and 2+2 quartets.
     *
     * @return the p-values below {@code cutoff}, or null if no fully retained quartet was found
     */
    private List<Double> listCrossConstructPValues(List<int[]> clustering, boolean[] eliminated, double cutoff) {
        List<Double> allPValues = new ArrayList<>();
        boolean countable = false;
        for (int p1 = 0; p1 < clustering.size(); ++p1) {
            for (int p2 = p1 + 1; p2 < clustering.size(); ++p2) {
                int[] cluster1 = clustering.get(p1);
                int[] cluster2 = clustering.get(p2);
                // Non-short-circuit |= so every helper runs and appends its p-values.
                countable |= this.collectThreeByOne(cluster1, cluster2, eliminated, cutoff, allPValues);
                countable |= this.collectThreeByOne(cluster2, cluster1, eliminated, cutoff, allPValues);
                countable |= this.collectTwoByTwo(cluster1, cluster2, eliminated, cutoff, allPValues);
            }
        }
        return countable ? allPValues : null;
    }

    /**
     * Appends p-values of quartets made of three variables from {@code tripleSource} and one
     * from {@code singleSource}.
     *
     * @return true if at least one quartet had no eliminated variable
     */
    private boolean collectThreeByOne(int[] tripleSource, int[] singleSource, boolean[] eliminated,
                                      double cutoff, List<Double> sink) {
        if (tripleSource.length < 3 || singleSource.length < 1) {
            return false;
        }
        boolean countable = false;
        ChoiceGenerator tripleGen = new ChoiceGenerator(tripleSource.length, 3);
        int[] triple;
        while ((triple = tripleGen.next()) != null) {
            ChoiceGenerator singleGen = new ChoiceGenerator(singleSource.length, 1);
            int[] single;
            while ((single = singleGen.next()) != null) {
                List<Integer> quartet = new ArrayList<>(4);
                for (int t : triple) {
                    quartet.add(tripleSource[t]);
                }
                for (int s : single) {
                    quartet.add(singleSource[s]);
                }
                List<Double> pValues = this.listPValues(quartet, eliminated, cutoff);
                if (pValues == null) {
                    continue;
                }
                countable = true;
                sink.addAll(pValues);
            }
        }
        return countable;
    }

    /**
     * Appends p-values of quartets made of two variables from each cluster.
     *
     * @return true if at least one quartet had no eliminated variable
     */
    private boolean collectTwoByTwo(int[] first, int[] second, boolean[] eliminated,
                                    double cutoff, List<Double> sink) {
        if (first.length < 2 || second.length < 2) {
            return false;
        }
        boolean countable = false;
        ChoiceGenerator firstGen = new ChoiceGenerator(first.length, 2);
        int[] firstPair;
        while ((firstPair = firstGen.next()) != null) {
            ChoiceGenerator secondGen = new ChoiceGenerator(second.length, 2);
            int[] secondPair;
            while ((secondPair = secondGen.next()) != null) {
                List<Integer> quartet = new ArrayList<>(4);
                for (int a : firstPair) {
                    quartet.add(first[a]);
                }
                for (int b : secondPair) {
                    quartet.add(second[b]);
                }
                List<Double> pValues = this.listPValues2by2(quartet, eliminated, cutoff);
                if (pValues == null) {
                    continue;
                }
                countable = true;
                sink.addAll(pValues);
            }
        }
        return countable;
    }

    /**
     * Tests the three tetrads of every 4-subset of {@code cluster} whose variables are all
     * retained.
     *
     * @return p-values strictly below {@code cutoff}, or null if {@code cluster} has fewer than
     *         four entries or every 4-subset contains an eliminated variable
     */
    private List<Double> listPValues(List<Integer> cluster, boolean[] eliminated, double cutoff) {
        if (cluster.size() < 4) {
            return null;
        }
        boolean countable = false;
        List<Double> pValues = new ArrayList<>();
        ChoiceGenerator gen = new ChoiceGenerator(cluster.size(), 4);
        int[] choice;
        while ((choice = gen.next()) != null) {
            int ci = cluster.get(choice[0]);
            int cj = cluster.get(choice[1]);
            int ck = cluster.get(choice[2]);
            int cl = cluster.get(choice[3]);
            if (eliminated[ci] || eliminated[cj] || eliminated[ck] || eliminated[cl]) {
                continue;
            }
            countable = true;
            double p1 = this.tetradTest.tetradPValue(ci, cj, ck, cl);
            double p2 = this.tetradTest.tetradPValue(ci, cl, cj, ck);
            double p3 = this.tetradTest.tetradPValue(ci, ck, cj, cl);
            if (p1 < cutoff) {
                this.printTetrad(ci, cj, ck, cl, p1);
                pValues.add(p1);
            }
            if (p2 < cutoff) {
                this.printTetrad(ci, cl, cj, ck, p2);
                pValues.add(p2);
            }
            if (p3 < cutoff) {
                this.printTetrad(ci, ck, cj, cl, p3);
                pValues.add(p3);
            }
        }
        return countable ? pValues : null;
    }

    /** Prints one tetrad and its p-value when {@link #listTetrads} is set. */
    private void printTetrad(int ci, int cj, int ck, int cl, double p1) {
        if (this.listTetrads) {
            String[] varNames = this.tetradTest.getVarNames();
            System.out.println("Tetrad <" + varNames[ci] + ", " + varNames[cj] + ", " + varNames[ck] + ", " + varNames[cl] + "> p = " + p1);
        }
    }

    /**
     * Tests the single cross-cluster tetrad of a 2+2 quartet.
     *
     * <p>In the only caller, entries 0 and 1 come from the first cluster and entries 2 and 3
     * from the second.
     *
     * @return a list holding the p-value if it is below {@code cutoff} (otherwise empty), or
     *         null if the quartet is incomplete or contains an eliminated variable (consistent
     *         with {@link #listPValues})
     */
    private List<Double> listPValues2by2(List<Integer> cluster, boolean[] eliminated, double cutoff) {
        if (cluster.size() < 4) {
            return null;
        }
        int x = cluster.get(0);
        int z = cluster.get(1);
        int y = cluster.get(2);
        int w = cluster.get(3);
        if (eliminated[x] || eliminated[z] || eliminated[y] || eliminated[w]) {
            return null;
        }
        List<Double> pValues = new ArrayList<>(1);
        double p1 = this.tetradTest.tetradPValue(x, y, w, z);
        if (p1 < cutoff) {
            this.printTetrad(x, y, w, z, p1);
            pValues.add(p1);
        }
        return pValues;
    }

    private void printMessage(String message) {
        if (this.outputMessage) {
            System.out.print(message);
        }
    }

    private void printlnMessage(String message) {
        if (this.outputMessage) {
            System.out.println(message);
        }
    }

    private void printlnMessage() {
        if (this.outputMessage) {
            System.out.println();
        }
    }

    /** Prints every cluster; see {@link #printCluster(int[], boolean[])}. */
    private void printClustering(List<int[]> clustering, boolean[] eliminated) {
        for (int[] cluster : clustering) {
            this.printCluster(cluster, eliminated);
        }
    }

    /** Prints a cluster's variable names sorted ascending, suffixing eliminated ones with "X". */
    private void printCluster(int[] c, boolean[] eliminated) {
        String[] varNames = this.tetradTest.getVarNames();
        List<String> labels = new ArrayList<>(c.length);
        for (int v : c) {
            labels.add(eliminated[v] ? varNames[v] + "X" : varNames[v]);
        }
        Collections.sort(labels);
        for (String label : labels) {
            this.printMessage(label + " ");
        }
        this.printlnMessage();
    }

    /** Returns each cluster, in order, with eliminated variable indices removed. */
    private List<int[]> buildSolution(List<int[]> clustering, boolean[] eliminated) {
        List<int[]> solution = new ArrayList<>(clustering.size());
        for (int[] cluster : clustering) {
            solution.add(Arrays.stream(cluster).filter(v -> !eliminated[v]).toArray());
        }
        return solution;
    }
}

