/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.work_in_progress;

import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.Triple;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.search.IGraphSearch;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.test.IndTestFisherZ;
import edu.cmu.tetrad.search.utils.GraphSearchUtils;
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.search.utils.MeekRules;
import edu.cmu.tetrad.search.work_in_progress.VcFas;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.CombinationGenerator;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Work-in-progress, sample-based variant of a VCPC-style search.
 *
 * <p>{@link #search()} runs the {@link VcFas} adjacency search and orients the result CPC-style
 * (background knowledge, unshielded triples, Meek rules). It then enumerates every collider /
 * non-collider resolution of the ambiguous triples and keeps the resolutions that are free of
 * contradictions and directed cycles. An apparent nonadjacency reported by {@link VcFas} is
 * promoted to a definite nonadjacency when every retained CPDAG supports it. Finally, sample
 * regression coefficients computed from the data are printed next to the edge coefficients of
 * the {@link SemIm} supplied through {@link #setSemIm(SemIm)}.
 *
 * <p>The independence test must be an {@link IndTestFisherZ} whose data is a {@link DataSet},
 * and a {@code SemIm} must be set before {@link #search()} is called.
 */
public final class SampleVcpcFast implements IGraphSearch {
    private final IndependenceTest independenceTest;
    private final TetradLogger logger = TetradLogger.getInstance();
    private final DataSet dataSet;
    private final ICovarianceMatrix covMatrix;
    private Knowledge knowledge = new Knowledge();
    private int depth = 1000;
    private Graph graph;
    private long elapsedTime;
    private Set<Triple> colliderTriples;
    private Set<Triple> noncolliderTriples;
    private Set<Triple> ambiguousTriples;
    private Set<Edge> definitelyNonadjacencies;
    private boolean meekPreventCycles;
    private Map<Edge, Set<Node>> apparentlyNonadjacencies;
    private boolean verbose;
    private SemPm semPm;
    private SemIm semIm;

    /**
     * Creates the search.
     *
     * @param independenceTest a Fisher Z test whose {@code getData()} is a {@link DataSet}; the
     *                         raw data is needed for the sample regressions.
     * @throws NullPointerException     if {@code independenceTest} is null.
     * @throws IllegalArgumentException if the test is not an {@link IndTestFisherZ} or its data
     *                                  is not a {@link DataSet}.
     */
    public SampleVcpcFast(IndependenceTest independenceTest) {
        if (independenceTest == null) {
            throw new NullPointerException("independenceTest");
        }
        if (!(independenceTest instanceof IndTestFisherZ)) {
            throw new IllegalArgumentException("Need Fisher Z test to proceed with algorithm");
        }
        if (!(independenceTest.getData() instanceof DataSet)) {
            // Previously this surfaced as a ClassCastException from the cast below.
            throw new IllegalArgumentException(
                    "Need a Fisher Z test built from a DataSet; sample regressions require the raw data.");
        }
        this.independenceTest = independenceTest;
        this.dataSet = (DataSet) independenceTest.getData();
        this.covMatrix = new CovarianceMatrix(this.dataSet);
    }

    /**
     * Decides whether a "future" path arriving at {@code node} along {@code edge1} may continue
     * along {@code edge2}.
     *
     * @return the far endpoint of {@code edge2} if the path may continue, otherwise null. The
     *     path stops when both edges have an arrowhead at {@code node} and {@code edge2} has a tail
     *     at its far end; when {@code edge1} has an arrowhead at its far end; or when both edges
     *     are tail-tail.
     */
    private static Node traverseFuturePath(Node node, Edge edge1, Edge edge2) {
        Endpoint E1 = edge1.getProximalEndpoint(node);
        Endpoint E2 = edge2.getProximalEndpoint(node);
        Endpoint E3 = edge2.getDistalEndpoint(node);
        Endpoint E4 = edge1.getDistalEndpoint(node);
        if (E1 == Endpoint.ARROW && E2 == Endpoint.ARROW && E3 == Endpoint.TAIL) {
            return null;
        }
        if (E4 == Endpoint.ARROW) {
            return null;
        }
        if (E4 == Endpoint.TAIL && E1 == Endpoint.TAIL && E2 == Endpoint.TAIL && E3 == Endpoint.TAIL) {
            return null;
        }
        return edge2.getDistalNode(node);
    }

    /**
     * Depth-first walk from {@code b}, adding every node reachable along paths accepted by
     * {@link #traverseFuturePath} to {@code futureNodes}.
     *
     * @param graph       the graph to walk.
     * @param b           the node being visited.
     * @param path        the current path; {@code b} is pushed on entry and popped on exit, so the
     *                    caller sees it unchanged.
     * @param futureNodes accumulator for visited nodes (includes the start node).
     */
    public static void futureNodeVisit(Graph graph, Node b, LinkedList<Node> path, Set<Node> futureNodes) {
        path.addLast(b);
        futureNodes.add(b);
        for (Edge edge2 : graph.getEdges(b)) {
            Node c;
            int size = path.size();
            if (size < 2) {
                // First step from the start node: any neighbour may be entered.
                c = edge2.getDistalNode(b);
            } else {
                Node a = path.get(size - 2);
                Edge edge1 = graph.getEdge(a, b);
                c = SampleVcpcFast.traverseFuturePath(b, edge1, edge2);
            }
            if (c == null || path.contains(c)) {
                continue;
            }
            SampleVcpcFast.futureNodeVisit(graph, c, path, futureNodes);
        }
        path.removeLast();
    }

    /**
     * Returns true if {@code knowledge} permits an arrowhead at {@code to} on the edge
     * {@code from -- to}: the reverse edge is not required and {@code from -> to} is not
     * forbidden. A null {@code knowledge} permits everything.
     */
    public static boolean isArrowheadAllowed1(Node from, Node to, Knowledge knowledge) {
        return knowledge == null || !knowledge.isRequired(to.toString(), from.toString()) && !knowledge.isForbidden(from.toString(), to.toString());
    }

    /** Returns the SEM whose edge coefficients are compared against sample regressions. */
    public SemIm getSemIm() {
        return this.semIm;
    }

    /** Sets the SEM whose edge coefficients are compared against sample regressions; required before {@link #search()}. */
    public void setSemIm(SemIm semIm) {
        this.semIm = semIm;
    }

    /** Returns whether Meek rule application should prevent cycles. */
    public boolean isMeekPreventCycles() {
        return this.meekPreventCycles;
    }

    /** Sets whether Meek rule application should prevent cycles. */
    public void setMeekPreventCycles(boolean meekPreventCycles) {
        this.meekPreventCycles = meekPreventCycles;
    }

    /** Returns the wall-clock duration of the last {@link #search()} in milliseconds. */
    public long getElapsedTime() {
        return this.elapsedTime;
    }

    /** Returns the background knowledge (the internal instance, not a copy). */
    public Knowledge getKnowledge() {
        return this.knowledge;
    }

    /** Stores a copy of the given background knowledge. */
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = new Knowledge(knowledge);
    }

    /** Returns the independence test supplied at construction. */
    public IndependenceTest getIndependenceTest() {
        return this.independenceTest;
    }

    /** Returns the maximum conditioning-set depth passed to the adjacency search. */
    public int getDepth() {
        return this.depth;
    }

    /**
     * Sets the maximum conditioning-set depth.
     *
     * @throws IllegalArgumentException if {@code depth < -1} or {@code depth == Integer.MAX_VALUE}.
     */
    public void setDepth(int depth) {
        if (depth < -1) {
            throw new IllegalArgumentException("Depth must be -1 or >= 0: " + depth);
        }
        if (depth == Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Depth must not be Integer.MAX_VALUE, due to a known bug.");
        }
        this.depth = depth;
    }

    /** Returns a copy of the ambiguous unshielded triples found by the last search. */
    public Set<Triple> getAmbiguousTriples() {
        return new HashSet<>(this.ambiguousTriples);
    }

    /** Returns a copy of the collider triples found by the last search. */
    public Set<Triple> getColliderTriples() {
        return new HashSet<>(this.colliderTriples);
    }

    /** Returns a copy of the non-collider triples found by the last search. */
    public Set<Triple> getNoncolliderTriples() {
        return new HashSet<>(this.noncolliderTriples);
    }

    /** Returns a copy of the edges of the current graph. */
    public Set<Edge> getAdjacencies() {
        return new HashSet<>(this.graph.getEdges());
    }

    /** Returns a copy of the nonadjacencies still classified as only apparent. */
    public Set<Edge> getApparentNonadjacencies() {
        return new HashSet<>(this.apparentlyNonadjacencies.keySet());
    }

    /** Returns a copy of the nonadjacencies classified as definite. */
    public Set<Edge> getDefiniteNonadjacencies() {
        return new HashSet<>(this.definitelyNonadjacencies);
    }

    /**
     * Runs the search, prints the coefficient comparison and summary statistics to stdout, and
     * returns the oriented graph (with nodes replaced by the data set's variables).
     *
     * @throws NullPointerException if no {@link SemIm} has been set.
     */
    @Override
    public Graph search() {
        if (this.semIm == null) {
            // The coefficient comparison below dereferences the SemIm unconditionally; fail
            // before doing any work instead of after the whole search has run.
            throw new NullPointerException("A SemIm must be set via setSemIm(...) before calling search().");
        }

        this.logger.log("info", "Starting VCCPC algorithm");
        this.logger.log("info", "Independence test = " + this.getIndependenceTest() + ".");
        this.ambiguousTriples = new HashSet<>();
        this.colliderTriples = new HashSet<>();
        this.noncolliderTriples = new HashSet<>();
        VcFas fas = new VcFas(this.getIndependenceTest());
        this.definitelyNonadjacencies = new HashSet<>();
        long startTime = MillisecondTimes.timeMillis();

        List<Node> allNodes = this.getIndependenceTest().getVariables();
        fas.setKnowledge(this.getKnowledge());
        fas.setDepth(this.getDepth());
        fas.setVerbose(this.verbose);
        this.graph = fas.search();
        this.apparentlyNonadjacencies = fas.getApparentlyNonadjacencies();

        if (this.isDoOrientation()) {
            if (this.verbose) {
                System.out.println("CPC orientation...");
            }
            GraphSearchUtils.pcOrientbk(this.knowledge, this.graph, allNodes);
            this.orientUnshieldedTriples(this.knowledge, this.getIndependenceTest(), this.getDepth());
            MeekRules meekRules = new MeekRules();
            meekRules.setMeekPreventCycles(this.meekPreventCycles);
            meekRules.setKnowledge(this.knowledge);
            meekRules.orientImplied(this.graph);
        }

        List<Graph> cpdags = this.disambiguateAmbiguousTriples();
        this.classifyApparentNonadjacencies(cpdags);
        this.compareCoefficientsWithSemIm();

        System.out.println("Sample VCPC:");
        System.out.println("# of CPDAGs: " + cpdags.size());
        long endTime = MillisecondTimes.timeMillis();
        this.elapsedTime = endTime - startTime;
        System.out.println("Search Time (seconds):" + this.elapsedTime / 1000L + " s");
        System.out.println("Search Time (milli):" + this.elapsedTime + " ms");
        System.out.println("# of Apparent Nonadj: " + this.apparentlyNonadjacencies.size());
        System.out.println("# of Definite Nonadj: " + this.definitelyNonadjacencies.size());
        TetradLogger.getInstance().log("apparentlyNonadjacencies", "\n Apparent Non-adjacencies" + this.apparentlyNonadjacencies);
        TetradLogger.getInstance().log("definitelyNonadjacencies", "\n Definite Non-adjacencies" + this.definitelyNonadjacencies);
        TetradLogger.getInstance().log("CPDAGs", "Disambiguated CPDAGs: " + cpdags);
        TetradLogger.getInstance().log("info", "Elapsed time = " + (double) this.elapsedTime / 1000.0 + " s");
        TetradLogger.getInstance().log("info", "Finishing CPC algorithm.");
        this.logTriples();
        TetradLogger.getInstance().flush();
        return this.graph;
    }

    /**
     * Builds one copy of the current graph for every collider / non-collider assignment of its
     * ambiguous triples and returns the copies that survive {@link #orientCandidate}. Exactly
     * one candidate is produced when there are no ambiguous triples.
     */
    private List<Graph> disambiguateAmbiguousTriples() {
        List<Triple> ambiguous = new ArrayList<>(this.graph.getAmbiguousTriples());
        int[] dims = new int[ambiguous.size()];
        Arrays.fill(dims, 2); // each triple is either a collider (0) or a non-collider (1)

        List<Graph> cpdags = new ArrayList<>();
        CombinationGenerator generator = new CombinationGenerator(dims);
        int[] combination;
        while ((combination = generator.next()) != null) {
            Graph candidate = new EdgeListGraph(this.graph);
            List<Triple> colliders = new ArrayList<>();
            List<Triple> nonColliders = new ArrayList<>();
            for (int k = 0; k < combination.length; k++) {
                Triple triple = ambiguous.get(k);
                candidate.removeAmbiguousTriple(triple.getX(), triple.getY(), triple.getZ());
                if (combination[k] == 0) {
                    colliders.add(triple);
                    candidate.setEndpoint(triple.getX(), triple.getY(), Endpoint.ARROW);
                    candidate.setEndpoint(triple.getZ(), triple.getY(), Endpoint.ARROW);
                } else {
                    nonColliders.add(triple);
                }
            }
            // Keep survivors in a fresh list rather than removing rejects from a shared list:
            // ArrayList.remove(Object) is equals-based, and structurally equal graphs could
            // otherwise cause a different (already validated) candidate to be dropped.
            if (orientCandidate(candidate, colliders, nonColliders)) {
                cpdags.add(candidate);
            }
        }
        return cpdags;
    }

    /**
     * Applies the chosen collider / non-collider decisions to {@code candidate} (mutating it),
     * turns bidirected edges into undirected ones, and closes under Meek rules.
     *
     * @return false if a chosen collider conflicts with an existing orientation or the result
     *     contains a directed cycle; true otherwise.
     */
    private static boolean orientCandidate(Graph candidate, List<Triple> colliders, List<Triple> nonColliders) {
        for (Triple triple : colliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (candidate.getEdge(x, y).pointsTowards(x) || candidate.getEdge(y, z).pointsTowards(z)) {
                return false;
            }
        }

        for (Triple triple : colliders) {
            candidate.setEndpoint(triple.getX(), triple.getY(), Endpoint.ARROW);
            candidate.setEndpoint(triple.getZ(), triple.getY(), Endpoint.ARROW);
        }

        for (Triple triple : nonColliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (candidate.getEdge(x, y).pointsTowards(y)) {
                candidate.removeEdge(y, z);
                candidate.addDirectedEdge(y, z);
            }
            if (candidate.getEdge(y, z).pointsTowards(y)) {
                candidate.removeEdge(x, y);
                candidate.addDirectedEdge(y, x);
            }
        }

        // Snapshot the edge set: the loop body removes and re-adds edges.
        for (Edge edge : new ArrayList<>(candidate.getEdges())) {
            if (!Edges.isBidirectedEdge(edge)) {
                continue;
            }
            Node x = edge.getNode1();
            Node y = edge.getNode2();
            candidate.removeEdge(x, y);
            candidate.addUndirectedEdge(x, y);
        }

        new MeekRules().orientImplied(candidate);
        return !candidate.paths().existsDirectedCycle();
    }

    /**
     * Promotes each apparent nonadjacency to a definite one when every retained CPDAG supports it,
     * then removes all definite nonadjacencies from the apparent map.
     */
    private void classifyApparentNonadjacencies(List<Graph> cpdags) {
        for (Edge edge : this.apparentlyNonadjacencies.keySet()) {
            if (this.supportedByAllCpdags(edge.getNode1(), edge.getNode2(), cpdags)) {
                this.definitelyNonadjacencies.add(edge);
            }
        }
        for (Edge edge : this.definitelyNonadjacencies) {
            this.apparentlyNonadjacencies.remove(edge);
        }
    }

    /**
     * Returns true if, in every CPDAG where neither node is in the other's boundary, each node
     * either lies in the other's future or is judged independent of it given that other node's
     * boundary.
     */
    private boolean supportedByAllCpdags(Node x, Node y, List<Graph> cpdags) {
        for (Graph cpdag : cpdags) {
            Set<Node> boundaryX = this.boundary(x, cpdag);
            Set<Node> boundaryY = this.boundary(y, cpdag);
            if (y == x || boundaryX.contains(y) || boundaryY.contains(x)) {
                continue;
            }
            // Same short-circuit order as the original: the future check precedes the test,
            // and the y side is only evaluated when the x side holds.
            boolean xSide = this.future(x, cpdag).contains(y)
                    || this.independenceTest.checkIndependence(x, y, boundaryX).isIndependent();
            if (!xSide) {
                return false;
            }
            boolean ySide = this.future(y, cpdag).contains(x)
                    || this.independenceTest.checkIndependence(y, x, boundaryY).isIndependent();
            if (!ySide) {
                return false;
            }
        }
        return true;
    }

    /**
     * Computes sample regression coefficients for edges of the current graph and prints them
     * alongside the corresponding {@link SemIm} edge coefficients. Replaces the graph's nodes with
     * the data set's variables as a side effect.
     */
    private void compareCoefficientsWithSemIm() {
        RegressionDataset sampleRegression = new RegressionDataset(this.dataSet);
        System.out.println(sampleRegression.getGraph());
        this.graph = GraphUtils.replaceNodes(this.graph, this.dataSet.getVariables());

        Map<Edge, double[]> sampleRegress = new HashMap<>();
        Map<Edge, Double> edgeCoefs = new HashMap<>();

        for (Node z : this.graph.getNodes()) {
            Set<Edge> adj = this.getAdj(z, this.graph);

            // Near an apparent nonadjacency no coefficient is estimated for z's edges.
            if (touches(z, this.apparentlyNonadjacencies.keySet())) {
                for (Edge adjacency : adj) {
                    this.recordEdge(adjacency, null, sampleRegress, edgeCoefs);
                }
                continue;
            }

            for (Edge nonadj : this.definitelyNonadjacencies) {
                if (nonadj.getNode1() == z || nonadj.getNode2() == z) {
                    this.recordEdge(nonadj, new double[]{0.0, 0.0}, sampleRegress, edgeCoefs);
                }
            }

            boolean hasNonDirected = false;
            Set<Edge> parentsOfZ = new HashSet<>();
            for (Edge adjacency : adj) {
                if (!adjacency.isDirected()) {
                    hasNonDirected = true;
                }
                if (adjacency.pointsTowards(z)) {
                    parentsOfZ.add(adjacency);
                }
            }
            if (hasNonDirected) {
                for (Edge adjacency : adj) {
                    this.recordEdge(adjacency, null, sampleRegress, edgeCoefs);
                }
            }

            for (Edge edge : parentsOfZ) {
                if (!edge.pointsTowards(edge.getNode2())) {
                    continue;
                }
                // NOTE(review): this regresses node2 on the single parent node1, not on all
                // parents of z jointly, so it is not directly an estimate of the SEM edge
                // coefficient when z has several parents — confirm this is intended.
                RegressionResult result = sampleRegression.regress(edge.getNode2(), edge.getNode1());
                System.out.println(result);
                this.recordEdge(edge, result.getCoef(), sampleRegress, edgeCoefs);
            }
        }

        System.out.println("All IM: " + this.semIm + "Finish");
        System.out.println("Just IM coefs: " + this.semIm.getEdgeCoef());
        System.out.println("IM Coef Map: " + edgeCoefs);
        System.out.println("Regress Coef Map: " + sampleRegress);
        for (Map.Entry<Edge, double[]> entry : sampleRegress.entrySet()) {
            System.out.println(" Sample Regression: " + entry.getKey() + Arrays.toString(entry.getValue()));
        }
        for (Edge edge : this.graph.getEdges()) {
            System.out.println("Sample edge: " + Arrays.toString(sampleRegress.get(edge)));
        }
    }

    /** Returns true if {@code node} (by identity) is an endpoint of any edge in {@code edges}. */
    private static boolean touches(Node node, Iterable<Edge> edges) {
        for (Edge edge : edges) {
            if (node == edge.getNode1() || node == edge.getNode2()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Stores {@code sample} for {@code edge} and the SEM coefficient for the same node pair
     * (0.0 when the SEM has no such coefficient).
     */
    private void recordEdge(Edge edge, double[] sample, Map<Edge, double[]> sampleRegress, Map<Edge, Double> edgeCoefs) {
        sampleRegress.put(edge, sample);
        Node a = edge.getNode1();
        Node b = edge.getNode2();
        if (this.semIm.existsEdgeCoef(a, b)) {
            Double coef = this.semIm.getEdgeCoef(a, b);
            edgeCoefs.put(edge, coef);
        } else {
            edgeCoefs.put(edge, 0.0);
        }
    }

    /** Returns the covariance matrix computed from the data set at construction. */
    public ICovarianceMatrix getCov() {
        return this.covMatrix;
    }

    /** Returns a fresh set of the edges of {@code graph} that have {@code node} as an endpoint. */
    private Set<Edge> getAdj(Node node, Graph graph) {
        return new HashSet<>(graph.getEdges(node));
    }

    /** Returns the parents of {@code x} plus the nodes joined to it by an undirected edge. */
    private Set<Node> boundary(Node x, Graph graph) {
        Set<Node> boundary = new HashSet<>();
        for (Node y : graph.getAdjacentNodes(x)) {
            if (graph.isParentOf(y, x) || Edges.isUndirectedEdge(graph.getEdge(x, y))) {
                boundary.add(y);
            }
        }
        return boundary;
    }

    /**
     * Returns the nodes reachable from {@code x} via {@link #futureNodeVisit}, excluding
     * {@code x} itself and its boundary (parents and undirected neighbours).
     */
    private Set<Node> future(Node x, Graph graph) {
        Set<Node> futureNodes = new HashSet<>();
        SampleVcpcFast.futureNodeVisit(graph, x, new LinkedList<>(), futureNodes);
        futureNodes.remove(x);
        for (Node y : graph.getAdjacentNodes(x)) {
            if (graph.isParentOf(y, x) || Edges.isUndirectedEdge(graph.getEdge(x, y))) {
                futureNodes.remove(y);
            }
        }
        return futureNodes;
    }

    /** Logs the collider, non-collider, and ambiguous triples of the last search. */
    private void logTriples() {
        TetradLogger.getInstance().log("info", "\nCollider triples:");
        for (Triple triple : this.colliderTriples) {
            TetradLogger.getInstance().log("info", "Collider: " + triple);
        }
        TetradLogger.getInstance().log("info", "\nNoncollider triples:");
        for (Triple triple : this.noncolliderTriples) {
            TetradLogger.getInstance().log("info", "Noncollider: " + triple);
        }
        TetradLogger.getInstance().log("info", "\nAmbiguous triples (i.e. list of triples for which \nthere is ambiguous data about whether they are colliders or not):");
        for (Triple triple : this.getAmbiguousTriples()) {
            TetradLogger.getInstance().log("info", "Ambiguous: " + triple);
        }
    }

    /**
     * Classifies every unshielded triple of the current graph with
     * {@link GraphSearchUtils#getCpcTripleType}. Colliders are oriented when knowledge allows;
     * ambiguous triples are marked on the graph and their end pair is recorded as a definite
     * nonadjacency.
     */
    private void orientUnshieldedTriples(Knowledge knowledge, IndependenceTest test, int depth) {
        TetradLogger.getInstance().log("info", "Starting Collider Orientation:");
        this.colliderTriples = new HashSet<>();
        this.noncolliderTriples = new HashSet<>();
        this.ambiguousTriples = new HashSet<>();
        for (Node y : this.graph.getNodes()) {
            List<Node> adjacentNodes = new ArrayList<>(this.graph.getAdjacentNodes(y));
            if (adjacentNodes.size() < 2) {
                continue;
            }
            ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2);
            int[] combination;
            while ((combination = cg.next()) != null) {
                Node x = adjacentNodes.get(combination[0]);
                Node z = adjacentNodes.get(combination[1]);
                if (this.graph.isAdjacentTo(x, z)) {
                    continue; // shielded triple
                }
                GraphSearchUtils.CpcTripleType type = GraphSearchUtils.getCpcTripleType(x, y, z, test, depth, this.graph);
                if (type == GraphSearchUtils.CpcTripleType.COLLIDER) {
                    if (this.colliderAllowed(x, y, z, knowledge)) {
                        this.graph.setEndpoint(x, y, Endpoint.ARROW);
                        this.graph.setEndpoint(z, y, Endpoint.ARROW);
                        TetradLogger.getInstance().log("colliderOrientations", LogUtilsSearch.colliderOrientedMsg(x, y, z));
                    }
                    this.colliderTriples.add(new Triple(x, y, z));
                } else if (type == GraphSearchUtils.CpcTripleType.AMBIGUOUS) {
                    Triple triple = new Triple(x, y, z);
                    this.ambiguousTriples.add(triple);
                    this.graph.addAmbiguousTriple(triple.getX(), triple.getY(), triple.getZ());
                    this.definitelyNonadjacencies.add(Edges.undirectedEdge(x, z));
                } else {
                    this.noncolliderTriples.add(new Triple(x, y, z));
                }
            }
        }
        TetradLogger.getInstance().log("info", "Finishing Collider Orientation.");
    }

    /** Returns true if knowledge allows arrowheads at {@code y} from both {@code x} and {@code z}. */
    private boolean colliderAllowed(Node x, Node y, Node z, Knowledge knowledge) {
        return SampleVcpcFast.isArrowheadAllowed1(x, y, knowledge) && SampleVcpcFast.isArrowheadAllowed1(z, y, knowledge);
    }

    /** Orientation is always performed by this search. */
    public boolean isDoOrientation() {
        return true;
    }

    /** Returns the graph produced by the last search (or set via {@link #setGraph}). */
    public Graph getGraph() {
        return this.graph;
    }

    /** Replaces the stored graph. */
    public void setGraph(Graph graph) {
        this.graph = graph;
    }

    /** Enables or disables verbose console output. */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /** Returns the stored SEM parametric model. */
    public SemPm getSemPm() {
        return this.semPm;
    }

    /** Stores a SEM parametric model. */
    public void setSemPm(SemPm semPm) {
        this.semPm = semPm;
    }
}

