/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.CorrelationMatrix;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.search.ClusterSignificance;
import edu.cmu.tetrad.search.ContinuousTetradTest;
import edu.cmu.tetrad.search.DeltaTetradTest;
import edu.cmu.tetrad.search.TestType;
import edu.cmu.tetrad.search.Tetrad;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math3.util.FastMath;

public class FindOneFactorClusters {
    private final CorrelationMatrix corr;
    private final List<Node> variables;
    private final double alpha;
    private final DeltaTetradTest test;
    private final ContinuousTetradTest test2;
    private final transient DataModel dataModel;
    private final TestType testType;
    private List<List<Node>> clusters;
    private boolean verbose;
    private boolean significanceChecked;
    private final Algorithm algorithm;
    private ClusterSignificance.CheckType checkType = ClusterSignificance.CheckType.Clique;

    public FindOneFactorClusters(ICovarianceMatrix cov, TestType testType, Algorithm algorithm, double alpha) {
        if (testType == null) {
            throw new NullPointerException("Null indepTest type.");
        }
        cov = new CovarianceMatrix(cov);
        this.variables = cov.getVariables();
        this.alpha = alpha;
        this.testType = testType;
        this.test = new DeltaTetradTest(cov);
        this.test2 = new ContinuousTetradTest(cov, testType, alpha);
        this.dataModel = cov;
        this.algorithm = algorithm;
        this.corr = new CorrelationMatrix(cov);
    }

    public FindOneFactorClusters(DataSet dataSet, TestType testType, Algorithm algorithm, double alpha) {
        if (testType == null) {
            throw new NullPointerException("Null test type.");
        }
        this.variables = dataSet.getVariables();
        this.alpha = alpha;
        this.testType = testType;
        this.test = new DeltaTetradTest(dataSet);
        this.test2 = new ContinuousTetradTest(dataSet, testType, alpha);
        this.dataModel = dataSet;
        this.algorithm = algorithm;
        this.corr = new CorrelationMatrix(dataSet);
    }

    private int findFrequentestIndex(Integer[] outliers) {
        HashMap<Integer, Integer> map = new HashMap<Integer, Integer>();
        for (Integer outlier : outliers) {
            if (map.containsKey(outlier)) {
                map.put(outlier, (Integer)map.get(outlier) + 1);
                continue;
            }
            map.put(outlier, 1);
        }
        Set set = map.entrySet();
        Iterator it = set.iterator();
        int nums = 0;
        int key = 0;
        while (it.hasNext()) {
            Map.Entry entry = it.next();
            if ((Integer)entry.getValue() <= nums) continue;
            nums = (Integer)entry.getValue();
            key = (Integer)entry.getKey();
        }
        return key;
    }

    private ArrayList<Integer> removeVariables(Matrix correlationMatrix, double lowerBound, double upperBound, double percentBound) {
        Integer[] outlier = new Integer[correlationMatrix.rows() * (correlationMatrix.rows() - 1)];
        int count = 0;
        for (int i = 2; i < correlationMatrix.rows() + 1; ++i) {
            for (int j = 1; j < i; ++j) {
                if (FastMath.abs(correlationMatrix.get(i - 1, j - 1)) < lowerBound || FastMath.abs(correlationMatrix.get(i - 1, j - 1)) > upperBound) {
                    outlier[count * 2] = i;
                    outlier[count * 2 + 1] = j;
                } else {
                    outlier[count * 2] = 0;
                    outlier[count * 2 + 1] = 0;
                }
                ++count;
            }
        }
        ArrayList<Integer> removedVariables = new ArrayList<Integer>();
        while (outlier.length > 1 && (double)removedVariables.size() < percentBound * (double)correlationMatrix.rows()) {
            int worstVariable = this.findFrequentestIndex(outlier);
            if (worstVariable > 0) {
                removedVariables.add(worstVariable);
            }
            for (int i = 1; i < outlier.length + 1; ++i) {
                if (outlier[i - 1] != worstVariable) continue;
                outlier[i - 1] = 0;
                if (i % 2 != 0) {
                    outlier[i] = 0;
                    continue;
                }
                outlier[i - 2] = 0;
            }
            outlier = this.removeZeroIndex(outlier);
        }
        this.log(removedVariables.size() + " variables removed: " + ClusterSignificance.variablesForIndices(removedVariables, this.variables));
        return removedVariables;
    }

    private Integer[] removeZeroIndex(Integer[] outlier) {
        ArrayList list = new ArrayList();
        Collections.addAll(list, outlier);
        for (Integer element : outlier) {
            if (element >= 1) continue;
            list.remove(element);
        }
        return list.toArray(new Integer[1]);
    }

    public Graph search() {
        Set<List<Integer>> allClusters;
        if (this.algorithm == Algorithm.SAG) {
            allClusters = this.estimateClustersTetradsFirst();
        } else if (this.algorithm == Algorithm.GAP) {
            allClusters = this.estimateClustersTriplesFirst();
        } else {
            throw new IllegalStateException("Expected SAG or GAP: " + this.testType);
        }
        this.clusters = ClusterSignificance.variablesForIndices2(allClusters, this.variables);
        System.out.println("allClusters = " + allClusters);
        System.out.println("this.clusters = " + this.clusters);
        ClusterSignificance clusterSignificance = new ClusterSignificance(this.variables, this.dataModel);
        clusterSignificance.printClusterPValues(allClusters);
        return this.convertToGraph(allClusters);
    }

    private Set<List<Integer>> estimateClustersTriplesFirst() {
        List<Integer> _variables = this.allVariables();
        Set<Set<Integer>> triples = this.findPuretriples(_variables);
        Set<Set<Integer>> combined = this.combinePuretriples(triples, _variables);
        HashSet<List<Integer>> _combined = new HashSet<List<Integer>>();
        for (Set<Integer> c : combined) {
            ArrayList<Integer> a = new ArrayList<Integer>(c);
            Collections.sort(a);
            _combined.add(a);
        }
        return _combined;
    }

    private List<Integer> allVariables() {
        ArrayList<Integer> _variables = new ArrayList<Integer>();
        for (int i = 0; i < this.variables.size(); ++i) {
            _variables.add(i);
        }
        return _variables;
    }

    private Set<List<Integer>> estimateClustersTetradsFirst() {
        List<Integer> _variables = this.allVariables();
        Set<List<Integer>> pureClusters = this.findPureClusters(_variables);
        Set<List<Integer>> mixedClusters = this.findMixedClusters(_variables, this.unionPure(pureClusters));
        HashSet<List<Integer>> allClusters = new HashSet<List<Integer>>(pureClusters);
        allClusters.addAll(mixedClusters);
        return allClusters;
    }

    private Set<Set<Integer>> findPuretriples(List<Integer> allVariables) {
        int[] choice;
        if (allVariables.size() < 4) {
            return new HashSet<Set<Integer>>();
        }
        this.log("Finding pure triples.");
        ChoiceGenerator gen = new ChoiceGenerator(allVariables.size(), 3);
        HashSet<Set<Integer>> puretriples = new HashSet<Set<Integer>>();
        block0: while ((choice = gen.next()) != null && !Thread.currentThread().isInterrupted()) {
            int n3;
            int n2;
            int n1 = allVariables.get(choice[0]);
            List<Integer> triple = this.triple(n1, n2 = allVariables.get(choice[1]).intValue(), n3 = allVariables.get(choice[2]).intValue());
            if (this.zeroCorr(triple)) continue;
            for (int o : allVariables) {
                List<Integer> quartet;
                boolean vanishes;
                if (Thread.currentThread().isInterrupted()) break;
                if (triple.contains(o) || (vanishes = this.vanishes(quartet = this.quartet(n1, n2, n3, o)))) continue;
                continue block0;
            }
            HashSet<Integer> _cluster = new HashSet<Integer>(triple);
            if (this.verbose) {
                this.log("++" + ClusterSignificance.variablesForIndices(triple, this.variables));
            }
            puretriples.add(_cluster);
        }
        return puretriples;
    }

    /*
     * WARNING - void declaration
     */
    private Set<Set<Integer>> combinePuretriples(Set<Set<Integer>> puretriples, List<Integer> _variables) {
        this.log("Growing pure triples.");
        HashSet<HashSet<Integer>> grown = new HashSet<HashSet<Integer>>();
        HashSet t = new HashSet();
        boolean bl = false;
        int total = puretriples.size();
        while (!Thread.currentThread().isInterrupted() && puretriples.iterator().hasNext()) {
            int[] choice2;
            Set<Integer> cluster = puretriples.iterator().next();
            HashSet<Integer> _cluster = new HashSet<Integer>(cluster);
            for (int o : _variables) {
                int[] choice;
                if (Thread.currentThread().isInterrupted()) break;
                if (_cluster.contains(o)) continue;
                ArrayList<Integer> _cluster2 = new ArrayList<Integer>(_cluster);
                int rejected = 0;
                int accepted = 0;
                ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 2);
                while ((choice = gen.next()) != null && !Thread.currentThread().isInterrupted()) {
                    t.clear();
                    t.add((Integer)_cluster2.get(choice[0]));
                    t.add((Integer)_cluster2.get(choice[1]));
                    t.add(o);
                    if (!puretriples.contains(t)) {
                        ++rejected;
                        continue;
                    }
                    ++accepted;
                }
                if (rejected > accepted) continue;
                _cluster.add(o);
                ClusterSignificance clusterSignificance = new ClusterSignificance(this.variables, this.dataModel);
                clusterSignificance.setCheckType(this.checkType);
                if (!this.significanceChecked || !clusterSignificance.significant(_cluster2, this.alpha)) continue;
                _cluster2.remove(o);
            }
            ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 3);
            ArrayList<Integer> _cluster3 = new ArrayList<Integer>(_cluster);
            while ((choice2 = gen2.next()) != null && !Thread.currentThread().isInterrupted()) {
                int n1 = (Integer)_cluster3.get(choice2[0]);
                int n2 = (Integer)_cluster3.get(choice2[1]);
                int n3 = (Integer)_cluster3.get(choice2[2]);
                t.clear();
                t.add(n1);
                t.add(n2);
                t.add(n3);
                puretriples.remove(t);
            }
            if (this.verbose) {
                void var5_6;
                this.log("Grown " + (int)(++var5_6) + " of " + total + ": " + ClusterSignificance.variablesForIndices(new ArrayList<Integer>(_cluster), this.variables));
            }
            grown.add(_cluster);
            if (!puretriples.isEmpty()) continue;
        }
        this.log("Choosing among grown clusters.");
        for (Set set : grown) {
            ArrayList<Integer> _l = new ArrayList<Integer>(set);
            Collections.sort(_l);
            if (!this.verbose) continue;
            this.log("Grown: " + ClusterSignificance.variablesForIndices(_l, this.variables));
        }
        HashSet<Set<Integer>> out = new HashSet<Set<Integer>>();
        ArrayList arrayList = new ArrayList(grown);
        arrayList.sort((o1, o2) -> o2.size() - o1.size());
        HashSet all = new HashSet();
        block5: for (Set cluster : arrayList) {
            for (Integer i : cluster) {
                if (!all.contains(i)) continue;
                continue block5;
            }
            out.add(cluster);
            all.addAll(cluster);
        }
        return out;
    }

    private Set<List<Integer>> findPureClusters(List<Integer> _variables) {
        HashSet<List<Integer>> clusters = new HashSet<List<Integer>>();
        block0: while (!_variables.isEmpty()) {
            int[] choice;
            if (this.verbose) {
                System.out.println(_variables);
            }
            if (_variables.size() < 4) break;
            ChoiceGenerator gen = new ChoiceGenerator(_variables.size(), 4);
            while ((choice = gen.next()) != null && !Thread.currentThread().isInterrupted()) {
                int n4;
                int n3;
                int n2;
                int n1 = _variables.get(choice[0]);
                List<Integer> cluster = this.quartet(n1, n2 = _variables.get(choice[1]).intValue(), n3 = _variables.get(choice[2]).intValue(), n4 = _variables.get(choice[3]).intValue());
                if (!this.pure(cluster)) continue;
                this.addOtherVariables(_variables, cluster);
                if (this.verbose) {
                    this.log("Cluster found: " + ClusterSignificance.variablesForIndices(cluster, this.variables));
                }
                clusters.add(cluster);
                _variables.removeAll(cluster);
                continue block0;
            }
            break block0;
        }
        return clusters;
    }

    private void addOtherVariables(List<Integer> _variables, List<Integer> cluster) {
        block0: for (int o : _variables) {
            int[] choice2;
            if (cluster.contains(o)) continue;
            ArrayList<Integer> _cluster = new ArrayList<Integer>(cluster);
            ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 3);
            while ((choice2 = gen2.next()) != null && !Thread.currentThread().isInterrupted()) {
                int t1 = (Integer)_cluster.get(choice2[0]);
                int t2 = (Integer)_cluster.get(choice2[1]);
                int t3 = (Integer)_cluster.get(choice2[2]);
                List<Integer> quartet = this.triple(t1, t2, t3);
                quartet.add(o);
                if (this.pure(quartet)) continue;
                continue block0;
            }
            this.log("Extending by " + this.variables.get(o));
            cluster.add(o);
        }
    }

    private boolean modelInsignificantWithNewCluster(Set<List<Integer>> clusters, List<Integer> cluster, List<Node> variable, DataModel dataModel) {
        ArrayList<List<Integer>> __clusters = new ArrayList<List<Integer>>(clusters);
        __clusters.add(cluster);
        ClusterSignificance clusterSignificance = new ClusterSignificance(this.variables, dataModel);
        clusterSignificance.setCheckType(this.checkType);
        double significance3 = clusterSignificance.getModelPValue(__clusters);
        if (this.verbose) {
            this.log("Significance * " + __clusters + " = " + significance3);
        }
        return significance3 < this.alpha;
    }

    private Set<List<Integer>> findMixedClusters(List<Integer> remaining, Set<Integer> unionPure) {
        HashSet<List<Integer>> triples = new HashSet<List<Integer>>();
        if (unionPure.isEmpty()) {
            return new HashSet<List<Integer>>();
        }
        block0: while (remaining.size() >= 3) {
            int[] choice;
            ChoiceGenerator gen = new ChoiceGenerator(remaining.size(), 3);
            while ((choice = gen.next()) != null && !Thread.currentThread().isInterrupted()) {
                int t2 = remaining.get(choice[0]);
                int t3 = remaining.get(choice[1]);
                int t4 = remaining.get(choice[2]);
                ArrayList<Integer> cluster = new ArrayList<Integer>();
                cluster.add(t2);
                cluster.add(t3);
                cluster.add(t4);
                if (this.zeroCorr(cluster)) continue;
                boolean allVanish = true;
                boolean someVanish = false;
                for (int t1 : this.allVariables()) {
                    if (Thread.currentThread().isInterrupted()) break;
                    if (cluster.contains(t1)) continue;
                    ArrayList<Integer> _cluster = new ArrayList<Integer>(cluster);
                    _cluster.add(t1);
                    if (this.vanishes(_cluster)) {
                        someVanish = true;
                        continue;
                    }
                    allVanish = false;
                    break;
                }
                if (!someVanish || !allVanish) continue;
                triples.add(cluster);
                unionPure.addAll(cluster);
                remaining.removeAll(cluster);
                if (!this.verbose) continue block0;
                this.log("3-cluster found: " + ClusterSignificance.variablesForIndices(cluster, this.variables));
                continue block0;
            }
            break block0;
        }
        return triples;
    }

    private int dofDrton(int n) {
        int dof = (n - 2) * (n - 3) / 2 - 2;
        if (dof < 0) {
            dof = 0;
        }
        return dof;
    }

    private boolean pure(List<Integer> quartet) {
        if (this.zeroCorr(quartet)) {
            return false;
        }
        if (this.vanishes(quartet)) {
            for (int o : this.allVariables()) {
                if (quartet.contains(o)) continue;
                for (int i = 0; i < quartet.size(); ++i) {
                    ArrayList<Integer> _quartet = new ArrayList<Integer>(quartet);
                    _quartet.remove(quartet.get(i));
                    _quartet.add(o);
                    if (this.vanishes(_quartet)) continue;
                    return false;
                }
            }
            return true;
        }
        return false;
    }

    private List<Integer> quartet(int n1, int n2, int n3, int n4) {
        ArrayList<Integer> quartet = new ArrayList<Integer>();
        quartet.add(n1);
        quartet.add(n2);
        quartet.add(n3);
        quartet.add(n4);
        if (new HashSet(quartet).size() < 4) {
            throw new IllegalArgumentException("quartet elements must be unique: <" + n1 + ", " + n2 + ", " + n3 + ", " + n4 + ">");
        }
        return quartet;
    }

    private List<Integer> triple(int n1, int n2, int n3) {
        ArrayList<Integer> triple = new ArrayList<Integer>();
        triple.add(n1);
        triple.add(n2);
        triple.add(n3);
        if (new HashSet(triple).size() < 3) {
            throw new IllegalArgumentException("triple elements must be unique: <" + n1 + ", " + n2 + ", " + n3 + ">");
        }
        return triple;
    }

    private boolean vanishes(List<Integer> quartet) {
        int n1 = quartet.get(0);
        int n2 = quartet.get(1);
        int n3 = quartet.get(2);
        int n4 = quartet.get(3);
        return this.vanishes(n1, n2, n3, n4);
    }

    private boolean zeroCorr(List<Integer> cluster) {
        int count = 0;
        for (int i = 0; i < cluster.size(); ++i) {
            for (int j = i + 1; j < cluster.size(); ++j) {
                double r = this.corr.getValue(cluster.get(i), cluster.get(j));
                int N = this.corr.getSampleSize();
                double f = FastMath.sqrt(N) * FastMath.log((1.0 + r) / (1.0 - r));
                double p = 2.0 * (1.0 - RandomUtil.getInstance().normalCdf(0.0, 1.0, FastMath.abs(f)));
                if (!(p > this.alpha)) continue;
                ++count;
            }
        }
        return count >= 1;
    }

    public List<List<Node>> getClusters() {
        return this.clusters;
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    private boolean vanishes(int x, int y, int z, int w) {
        if (this.testType == TestType.TETRAD_DELTA) {
            Tetrad t1 = new Tetrad(this.variables.get(x), this.variables.get(y), this.variables.get(z), this.variables.get(w));
            Tetrad t2 = new Tetrad(this.variables.get(x), this.variables.get(y), this.variables.get(w), this.variables.get(z));
            return this.test.getPValue(t1, t2) > this.alpha;
        }
        if (this.testType == TestType.TETRAD_WISHART) {
            return this.test2.tetradPValue(x, y, z, w) > this.alpha && this.test2.tetradPValue(x, y, w, z) > this.alpha;
        }
        throw new IllegalArgumentException("Only the delta and wishart tests are being used: " + this.testType);
    }

    private Graph convertSearchGraphNodes(Set<Set<Node>> clusters) {
        EdgeListGraph graph = new EdgeListGraph();
        ArrayList<GraphNode> latents = new ArrayList<GraphNode>();
        ArrayList<Set<Node>> _clusters = new ArrayList<Set<Node>>(clusters);
        for (int i = 0; i < _clusters.size(); ++i) {
            GraphNode latent = new GraphNode("_L" + (i + 1));
            latent.setNodeType(NodeType.LATENT);
            latents.add(latent);
            graph.addNode(latent);
            for (Node node : (Set)_clusters.get(i)) {
                if (!graph.containsNode(node)) {
                    graph.addNode(node);
                }
                graph.addDirectedEdge((Node)latents.get(i), node);
            }
        }
        return graph;
    }

    private Graph convertToGraph(Set<List<Integer>> allClusters) {
        HashSet<Set<Node>> _clustering = new HashSet<Set<Node>>();
        for (List<Integer> cluster : allClusters) {
            HashSet<Node> nodes = new HashSet<Node>();
            for (int i : cluster) {
                nodes.add(this.variables.get(i));
            }
            _clustering.add(nodes);
        }
        return this.convertSearchGraphNodes(_clustering);
    }

    private Set<Integer> unionPure(Set<List<Integer>> pureClusters) {
        HashSet<Integer> unionPure = new HashSet<Integer>();
        for (List<Integer> cluster : pureClusters) {
            unionPure.addAll(cluster);
        }
        return unionPure;
    }

    private void log(String s) {
        if (this.verbose) {
            TetradLogger.getInstance().forceLogMessage(s);
        }
    }

    public void setSignificanceChecked(boolean significanceChecked) {
        this.significanceChecked = significanceChecked;
    }

    public void setCheckType(ClusterSignificance.CheckType checkType) {
        this.checkType = checkType;
    }

    public static enum Algorithm {
        SAG,
        GAP;

    }
}

