/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.CorrelationMatrix;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.search.utils.DeltaSextadTest;
import edu.cmu.tetrad.search.utils.Sextad;
import edu.cmu.tetrad.sem.SemEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemOptimizer;
import edu.cmu.tetrad.sem.SemOptimizerEm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.ProbUtils;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.commons.math3.util.FastMath;

public class Ftfc {
    private final CorrelationMatrix corr;
    private final List<Node> variables;
    private final double alpha;
    private final DeltaSextadTest test;
    private final transient DataModel dataModel;
    private final Algorithm algorithm;
    private List<List<Node>> clusters;
    private boolean verbose;

    public Ftfc(ICovarianceMatrix cov, Algorithm algorithm, double alpha) {
        cov = new CovarianceMatrix(cov);
        this.variables = cov.getVariables();
        this.alpha = alpha;
        this.test = new DeltaSextadTest(cov);
        this.dataModel = cov;
        this.algorithm = algorithm;
        this.corr = new CorrelationMatrix(cov);
    }

    public Ftfc(DataSet dataSet, Algorithm algorithm, double alpha) {
        this.variables = dataSet.getVariables();
        this.alpha = alpha;
        this.test = new DeltaSextadTest(dataSet);
        this.dataModel = dataSet;
        this.algorithm = algorithm;
        this.corr = new CorrelationMatrix(dataSet);
    }

    public Graph search() {
        Set<List<Integer>> allClusters;
        if (this.algorithm == Algorithm.SAG) {
            allClusters = this.estimateClustersSAG();
        } else if (this.algorithm == Algorithm.GAP) {
            allClusters = this.estimateClustersGAP();
        } else {
            throw new IllegalStateException("Expected SAG or GAP: " + (Object)((Object)this.algorithm));
        }
        this.clusters = this.variablesForIndices(allClusters);
        return this.convertToGraph(allClusters);
    }

    public List<List<Node>> getClusters() {
        return this.clusters;
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    private Set<List<Integer>> estimateClustersGAP() {
        List<Integer> _variables = this.allVariables();
        Set<List<Integer>> pentads = this.findPurepentads(_variables);
        Set<List<Integer>> combined = this.combinePurePentads(pentads, _variables);
        HashSet<List<Integer>> _combined = new HashSet<List<Integer>>();
        for (List<Integer> c : combined) {
            ArrayList<Integer> a = new ArrayList<Integer>(c);
            Collections.sort(a);
            _combined.add(a);
        }
        return _combined;
    }

    private List<Integer> allVariables() {
        ArrayList<Integer> _variables = new ArrayList<Integer>();
        for (int i = 0; i < this.variables.size(); ++i) {
            _variables.add(i);
        }
        return _variables;
    }

    private Set<List<Integer>> estimateClustersSAG() {
        List<Integer> _variables = this.allVariables();
        Set<List<Integer>> pureClusters = this.findPureClusters(_variables);
        Set<List<Integer>> mixedClusters = this.findMixedClusters(pureClusters, _variables, this.unionPure(pureClusters));
        HashSet<List<Integer>> allClusters = new HashSet<List<Integer>>(pureClusters);
        allClusters.addAll(mixedClusters);
        return allClusters;
    }

    private Set<List<Integer>> findPurepentads(List<Integer> variables) {
        int[] choice;
        if (variables.size() < 6) {
            return new HashSet<List<Integer>>();
        }
        this.log("Finding pure pentads.", true);
        ChoiceGenerator gen = new ChoiceGenerator(variables.size(), 5);
        HashSet<List<Integer>> purePentads = new HashSet<List<Integer>>();
        block0: while ((choice = gen.next()) != null) {
            int n5;
            int n4;
            int n3;
            int n2;
            int n1 = variables.get(choice[0]);
            List<Integer> pentad = this.pentad(n1, n2 = variables.get(choice[1]).intValue(), n3 = variables.get(choice[2]).intValue(), n4 = variables.get(choice[3]).intValue(), n5 = variables.get(choice[4]).intValue());
            if (this.zeroCorr(pentad, 4)) continue;
            for (int o : variables) {
                if (pentad.contains(o)) continue;
                List<Integer> sextet = this.sextet(n1, n2, n3, n4, n5, o);
                Collections.sort(sextet);
                boolean vanishes = this.vanishes(sextet);
                if (vanishes) continue;
                continue block0;
            }
            ArrayList<Integer> _cluster = new ArrayList<Integer>(pentad);
            if (this.verbose) {
                System.out.println(this.variablesForIndices(pentad));
                this.log("++" + this.variablesForIndices(pentad), false);
            }
            purePentads.add(_cluster);
        }
        return purePentads;
    }

    /*
     * WARNING - void declaration
     */
    private Set<List<Integer>> combinePurePentads(Set<List<Integer>> purePentads, List<Integer> _variables) {
        this.log("Growing pure pentads.", true);
        HashSet<ArrayList<Integer>> grown = new HashSet<ArrayList<Integer>>();
        ArrayList t = new ArrayList();
        boolean bl = false;
        int total = purePentads.size();
        while (purePentads.iterator().hasNext()) {
            int[] choice2;
            List<Integer> cluster = purePentads.iterator().next();
            ArrayList<Integer> _cluster = new ArrayList<Integer>(cluster);
            block1: for (int o : _variables) {
                int[] choice;
                if (_cluster.contains(o)) continue;
                ArrayList<Integer> _cluster2 = new ArrayList<Integer>(_cluster);
                ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 4);
                while ((choice = gen.next()) != null) {
                    int n1 = (Integer)_cluster2.get(choice[0]);
                    int n2 = (Integer)_cluster2.get(choice[1]);
                    int n3 = (Integer)_cluster2.get(choice[2]);
                    int n4 = (Integer)_cluster2.get(choice[3]);
                    t.clear();
                    t.add(n1);
                    t.add(n2);
                    t.add(n3);
                    t.add(n4);
                    t.add(o);
                    Collections.sort(t);
                    if (purePentads.contains(t)) continue;
                    continue block1;
                }
                _cluster.add(o);
            }
            ChoiceGenerator choiceGenerator = new ChoiceGenerator(_cluster.size(), 5);
            ArrayList<Integer> _cluster3 = new ArrayList<Integer>(_cluster);
            while ((choice2 = choiceGenerator.next()) != null) {
                int n1 = (Integer)_cluster3.get(choice2[0]);
                int n2 = (Integer)_cluster3.get(choice2[1]);
                int n3 = (Integer)_cluster3.get(choice2[2]);
                int n4 = (Integer)_cluster3.get(choice2[3]);
                int n5 = (Integer)_cluster3.get(choice2[4]);
                t.clear();
                t.add(n1);
                t.add(n2);
                t.add(n3);
                t.add(n4);
                t.add(n5);
                Collections.sort(t);
                purePentads.remove(t);
            }
            if (this.verbose) {
                void var5_6;
                System.out.println("Grown " + (int)(++var5_6) + " of " + total + ": " + _cluster);
            }
            grown.add(_cluster);
            if (!purePentads.isEmpty()) continue;
        }
        this.log("Choosing among grown clusters.", true);
        for (List list : grown) {
            ArrayList<Integer> _l = new ArrayList<Integer>(list);
            Collections.sort(_l);
            if (!this.verbose) continue;
            this.log("Grown: " + this.variablesForIndices(_l), false);
        }
        HashSet<List<Integer>> out = new HashSet<List<Integer>>();
        ArrayList arrayList = new ArrayList(grown);
        arrayList.sort((o1, o2) -> o2.size() - o1.size());
        ArrayList all = new ArrayList();
        block5: for (List cluster : arrayList) {
            for (Integer i : cluster) {
                if (!all.contains(i)) continue;
                continue block5;
            }
            out.add(cluster);
            all.addAll(cluster);
        }
        boolean significanceCalculated = false;
        for (List list : out) {
            this.log("OUT: " + this.variablesForIndices(new ArrayList<Integer>(list)), true);
        }
        return out;
    }

    private Set<List<Integer>> findPureClusters(List<Integer> _variables) {
        HashSet<List<Integer>> clusters = new HashSet<List<Integer>>();
        block0: for (int k = 6; k >= 6; --k) {
            block1: while (!_variables.isEmpty()) {
                int[] choice;
                if (this.verbose) {
                    System.out.println(_variables);
                }
                if (_variables.size() < 6) continue block0;
                ChoiceGenerator gen = new ChoiceGenerator(_variables.size(), 6);
                while ((choice = gen.next()) != null) {
                    int n6;
                    int n5;
                    int n4;
                    int n3;
                    int n2;
                    int n1 = _variables.get(choice[0]);
                    List<Integer> cluster = this.sextet(n1, n2 = _variables.get(choice[1]).intValue(), n3 = _variables.get(choice[2]).intValue(), n4 = _variables.get(choice[3]).intValue(), n5 = _variables.get(choice[4]).intValue(), n6 = _variables.get(choice[5]).intValue());
                    if (!this.pure(cluster)) continue;
                    if (this.verbose) {
                        this.log("Found a pure: " + this.variablesForIndices(cluster), false);
                    }
                    this.addOtherVariables(_variables, cluster);
                    if (cluster.size() < k) continue;
                    if (this.verbose) {
                        this.log("Cluster found: " + this.variablesForIndices(cluster), true);
                        System.out.println("Indices for cluster = " + cluster);
                    }
                    clusters.add(cluster);
                    _variables.removeAll(cluster);
                    continue block1;
                }
                break block1;
            }
        }
        return clusters;
    }

    private void addOtherVariables(List<Integer> _variables, List<Integer> cluster) {
        block0: for (int o : _variables) {
            int[] choice;
            if (cluster.contains(o)) continue;
            ArrayList<Integer> _cluster = new ArrayList<Integer>(cluster);
            ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 6);
            while ((choice = gen2.next()) != null) {
                int t1 = (Integer)_cluster.get(choice[0]);
                int t2 = (Integer)_cluster.get(choice[1]);
                int t3 = (Integer)_cluster.get(choice[2]);
                int t4 = (Integer)_cluster.get(choice[3]);
                int t5 = (Integer)_cluster.get(choice[4]);
                List<Integer> sextad = this.pentad(t1, t2, t3, t4, t5);
                sextad.add(o);
                if (this.pure(sextad)) continue;
                continue block0;
            }
            this.log("Extending by " + this.variables.get(o), false);
            cluster.add(o);
        }
    }

    private Set<List<Integer>> findMixedClusters(Set<List<Integer>> clusters, List<Integer> remaining, Set<Integer> unionPure) {
        HashSet<List<Integer>> pentads = new HashSet<List<Integer>>();
        HashSet<List<Integer>> _clusters = new HashSet<List<Integer>>(clusters);
        if (unionPure.isEmpty()) {
            return new HashSet<List<Integer>>();
        }
        block0: while (remaining.size() >= 5) {
            int[] choice;
            if (this.verbose) {
                this.log("UnionPure = " + this.variablesForIndices(new ArrayList<Integer>(unionPure)), false);
            }
            ChoiceGenerator gen = new ChoiceGenerator(remaining.size(), 5);
            while ((choice = gen.next()) != null) {
                int t2 = remaining.get(choice[0]);
                int t3 = remaining.get(choice[1]);
                int t4 = remaining.get(choice[2]);
                int t5 = remaining.get(choice[3]);
                int t6 = remaining.get(choice[4]);
                ArrayList<Integer> cluster = new ArrayList<Integer>();
                cluster.add(t2);
                cluster.add(t3);
                cluster.add(t4);
                cluster.add(t5);
                cluster.add(t6);
                if (this.zeroCorr(cluster, 4)) continue;
                boolean allVanish = true;
                boolean someVanish = false;
                for (int t1 : this.allVariables()) {
                    if (cluster.contains(t1)) continue;
                    ArrayList<Integer> _cluster = new ArrayList<Integer>(cluster);
                    _cluster.add(t1);
                    if (this.vanishes(_cluster)) {
                        someVanish = true;
                        continue;
                    }
                    allVanish = false;
                    break;
                }
                if (!someVanish || !allVanish) continue;
                pentads.add(cluster);
                _clusters.add(cluster);
                unionPure.addAll(cluster);
                remaining.removeAll(cluster);
                if (!this.verbose) continue block0;
                this.log("3-cluster found: " + this.variablesForIndices(cluster), false);
                continue block0;
            }
            break block0;
        }
        return pentads;
    }

    private double significance(List<Integer> cluster) {
        double chisq = this.getClusterChiSquare(cluster);
        int n = cluster.size();
        int dof = this.dofHarman(n);
        double q = ProbUtils.chisqCdf(chisq, dof);
        return 1.0 - q;
    }

    private int dofHarman(int n) {
        int dof = n * (n - 5) / 2 + 1;
        if (dof < 0) {
            dof = 0;
        }
        return dof;
    }

    private List<Node> variablesForIndices(List<Integer> cluster) {
        ArrayList<Node> _cluster = new ArrayList<Node>();
        for (int c : cluster) {
            _cluster.add(this.variables.get(c));
        }
        return _cluster;
    }

    private List<List<Node>> variablesForIndices(Set<List<Integer>> clusters) {
        ArrayList<List<Node>> variables = new ArrayList<List<Node>>();
        for (List<Integer> cluster : clusters) {
            variables.add(this.variablesForIndices(cluster));
        }
        return variables;
    }

    private boolean pure(List<Integer> sextet) {
        if (this.zeroCorr(sextet, 5)) {
            return false;
        }
        if (this.vanishes(sextet)) {
            for (int o : this.allVariables()) {
                if (sextet.contains(o)) continue;
                for (int i = 0; i < sextet.size(); ++i) {
                    ArrayList<Integer> _sextet = new ArrayList<Integer>(sextet);
                    _sextet.remove(sextet.get(i));
                    _sextet.add(i, o);
                    if (this.vanishes(_sextet)) continue;
                    return false;
                }
            }
            System.out.println("PURE: " + this.variablesForIndices(sextet));
            return true;
        }
        return false;
    }

    private double getClusterChiSquare(List<Integer> cluster) {
        SemIm im = this.estimateClusterModel(cluster);
        return im.getChiSquare();
    }

    private SemIm estimateClusterModel(List<Integer> sextet) {
        EdgeListGraph g = new EdgeListGraph();
        GraphNode l1 = new GraphNode("L1");
        l1.setNodeType(NodeType.LATENT);
        GraphNode l2 = new GraphNode("L2");
        l2.setNodeType(NodeType.LATENT);
        g.addNode(l1);
        g.addNode(l2);
        for (Integer aQuartet : sextet) {
            Node n = this.variables.get(aQuartet);
            g.addNode(n);
            g.addDirectedEdge(l1, n);
            g.addDirectedEdge(l2, n);
        }
        SemPm pm = new SemPm(g);
        SemEstimator est = this.dataModel instanceof DataSet ? new SemEstimator((DataSet)this.dataModel, pm, (SemOptimizer)new SemOptimizerEm()) : new SemEstimator((CovarianceMatrix)this.dataModel, pm, (SemOptimizer)new SemOptimizerEm());
        return est.estimate();
    }

    private List<Integer> sextet(int n1, int n2, int n3, int n4, int n5, int n6) {
        ArrayList<Integer> sextet = new ArrayList<Integer>();
        sextet.add(n1);
        sextet.add(n2);
        sextet.add(n3);
        sextet.add(n4);
        sextet.add(n5);
        sextet.add(n6);
        if (new HashSet(sextet).size() < 6) {
            throw new IllegalArgumentException("sextet elements must be unique: <" + n1 + ", " + n2 + ", " + n3 + ", " + n4 + ", " + n5 + ", " + n6 + ">");
        }
        return sextet;
    }

    private List<Integer> pentad(int n1, int n2, int n3, int n4, int n5) {
        ArrayList<Integer> pentad = new ArrayList<Integer>();
        pentad.add(n1);
        pentad.add(n2);
        pentad.add(n3);
        pentad.add(n4);
        pentad.add(n5);
        if (new HashSet(pentad).size() < 5) {
            throw new IllegalArgumentException("pentad elements must be unique: <" + n1 + ", " + n2 + ", " + n3 + ", " + n4 + ", " + n5 + ">");
        }
        return pentad;
    }

    private boolean vanishes(List<Integer> sextet) {
        int n6;
        int n5;
        int n4;
        int n3;
        int n2;
        int n1 = sextet.get(0);
        return this.vanishes(n1, n2 = sextet.get(1).intValue(), n3 = sextet.get(2).intValue(), n4 = sextet.get(3).intValue(), n5 = sextet.get(4).intValue(), n6 = sextet.get(5).intValue()) && this.vanishes(n3, n2, n1, n6, n5, n4) && this.vanishes(n4, n5, n6, n1, n2, n3) && this.vanishes(n6, n5, n4, n3, n2, n1);
    }

    private boolean zeroCorr(List<Integer> cluster, int n) {
        int count = 0;
        for (int i = 0; i < cluster.size(); ++i) {
            for (int j = i + 1; j < cluster.size(); ++j) {
                double r = this.corr.getValue(cluster.get(i), cluster.get(j));
                int N = this.corr.getSampleSize();
                double f = FastMath.sqrt(N) * FastMath.log((1.0 + r) / (1.0 - r));
                double p = 2.0 * (1.0 - RandomUtil.getInstance().normalCdf(0.0, 1.0, FastMath.abs(f)));
                if (!(p > this.alpha)) continue;
                ++count;
            }
        }
        return count >= n;
    }

    private boolean vanishes(int n1, int n2, int n3, int n4, int n5, int n6) {
        Sextad t1 = new Sextad(n1, n2, n3, n4, n5, n6);
        Sextad t2 = new Sextad(n1, n5, n6, n2, n3, n4);
        Sextad t3 = new Sextad(n1, n4, n6, n2, n3, n5);
        Sextad t5 = new Sextad(n1, n3, n4, n2, n5, n6);
        Sextad t6 = new Sextad(n1, n3, n5, n2, n4, n6);
        Sextad t7 = new Sextad(n1, n3, n6, n2, n4, n5);
        Sextad t8 = new Sextad(n1, n2, n4, n3, n5, n6);
        Sextad t9 = new Sextad(n1, n2, n5, n3, n4, n6);
        Sextad t10 = new Sextad(n1, n2, n6, n3, n4, n5);
        ArrayList<Sextad[]> independents = new ArrayList<Sextad[]>();
        independents.add(new Sextad[]{t1, t2, t3, t5, t6});
        for (Sextad[] sextads : independents) {
            double p = this.test.getPValue(sextads);
            if (Double.isNaN(p)) {
                return false;
            }
            if (!(p < this.alpha)) continue;
            return false;
        }
        return true;
    }

    private Graph convertSearchGraphNodes(Set<Set<Node>> clusters) {
        EdgeListGraph graph = new EdgeListGraph(this.variables);
        ArrayList<GraphNode> latents = new ArrayList<GraphNode>();
        for (int i = 0; i < clusters.size(); ++i) {
            GraphNode latent = new GraphNode("_L" + (i + 1));
            latent.setNodeType(NodeType.LATENT);
            latents.add(latent);
            graph.addNode(latent);
        }
        ArrayList<Set<Node>> _clusters = new ArrayList<Set<Node>>(clusters);
        for (int i = 0; i < latents.size(); ++i) {
            for (Node node : (Set)_clusters.get(i)) {
                if (!graph.containsNode(node)) {
                    graph.addNode(node);
                }
                graph.addDirectedEdge((Node)latents.get(i), node);
            }
        }
        return graph;
    }

    private Graph convertToGraph(Set<List<Integer>> allClusters) {
        HashSet<Set<Node>> _clustering = new HashSet<Set<Node>>();
        for (List<Integer> cluster : allClusters) {
            HashSet<Node> nodes = new HashSet<Node>();
            for (int i : cluster) {
                nodes.add(this.variables.get(i));
            }
            _clustering.add(nodes);
        }
        return this.convertSearchGraphNodes(_clustering);
    }

    private Set<Integer> unionPure(Set<List<Integer>> pureClusters) {
        HashSet<Integer> unionPure = new HashSet<Integer>();
        for (List<Integer> cluster : pureClusters) {
            unionPure.addAll(cluster);
        }
        return unionPure;
    }

    private void log(String s, boolean toLog) {
        if (toLog) {
            TetradLogger.getInstance().log("info", s);
        }
    }

    public static enum Algorithm {
        SAG,
        GAP;

    }
}

