/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.Triple;
import edu.cmu.tetrad.search.Boss;
import edu.cmu.tetrad.search.FciOrient;
import edu.cmu.tetrad.search.GraphSearch;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.Score;
import edu.cmu.tetrad.search.SepsetsGreedy;
import edu.cmu.tetrad.search.TeyssierScorer;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.SublistGenerator;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.commons.math3.util.FastMath;
import org.jetbrains.annotations.NotNull;

/**
 * Search that starts from a BOSS permutation search over a {@link TeyssierScorer} and then
 * post-processes the BOSS graph using swap/tuck/move operations on the scorer.
 *
 * <p>Each variant follows the same three steps:
 * <ol>
 *   <li>Reset the BOSS graph to circle endpoints, keeping only its unshielded colliders.</li>
 *   <li>Use scorer moves to find triples whose x&mdash;z "shield" edge is removed and whose
 *       middle node is oriented as a collider.</li>
 *   <li>Finish with {@link FciOrient}'s final orientation rules.</li>
 * </ol>
 *
 * <p>{@link #search()} dispatches to {@link #lvswap1()}, {@link #lvswap2a()} or
 * {@link #lvswap2b()} according to {@link #setAlgType(AlgType)}. {@link #lvswap3()} is
 * public but is not reachable through {@link #search()}.
 *
 * <p>Configuration is held in mutable, unsynchronized fields.
 */
public final class LvSwap implements GraphSearch {

    /** Number of find/remove/orient rounds performed by {@link #lvswap2a()}. */
    private static final int LVSWAP2A_ROUNDS = 3;

    private AlgType algType = AlgType.LVSwap1;
    private Boss.AlgType bossAlgType = Boss.AlgType.BOSS1;
    private final Score score;
    // Stored by setCovarianceMatrix; not read anywhere in this class.
    ICovarianceMatrix covarianceMatrix;
    private IndependenceTest test;
    private boolean completeRuleSetUsed = true;
    private int maxPathLength = -1;
    private int numStarts = 1;
    private int depth = -1;
    private boolean useRaskuttiUhler;
    private boolean useDataOrder = true;
    private boolean useScore = true;
    private Knowledge knowledge = new Knowledge();
    private boolean verbose = false;
    private PrintStream out = System.out;
    private boolean doDiscriminatingPathTailRule = true;

    /**
     * Creates the search.
     *
     * @param test  independence test given to the scorer and to the final sepset-based orientation
     * @param score score given to the scorer; its variables are the nodes BOSS orders
     */
    public LvSwap(IndependenceTest test, Score score) {
        this.test = test;
        this.score = score;
    }

    /**
     * Runs the variant selected by {@link #setAlgType(AlgType)}.
     *
     * @return the resulting graph
     * @throws IllegalArgumentException if the algorithm type is {@code null}
     */
    @Override
    public Graph search() {
        if (this.algType == null) {
            throw new IllegalArgumentException("Unexpected alg type: " + this.algType);
        }
        switch (this.algType) {
            case LVSwap1:
                return lvswap1();
            case LVSwap2a:
                return lvswap2a();
            case LVSwap2b:
                return lvswap2b();
            default:
                throw new IllegalArgumentException("Unexpected alg type: " + this.algType);
        }
    }

    /**
     * Variant 1.
     *
     * <p>Nodes x are visited in reverse of the scorer's order after BOSS. For each x, the
     * search looks for subsets S of x's BOSS parents with two properties:
     * <ul>
     *   <li>after S is moved to x's position in the order, some other BOSS parent z of x is no
     *       longer a parent of x in the scorer;</li>
     *   <li>that z is adjacent (in the BOSS graph) to a member of S.</li>
     * </ul>
     * For each such z whose recorded S is non-empty, the edge x&mdash;z is removed. Every y in
     * S still adjacent to both x and z is then oriented x *-&gt; y &lt;-* z.
     *
     * @return the resulting graph
     */
    public Graph lvswap1() {
        TeyssierScorer scorer = new TeyssierScorer(this.test, this.score);
        Graph G = runBoss(scorer).getGraph(true);
        // Snapshot of the BOSS graph: parents/adjacencies are read from here while G is edited.
        EdgeListGraph GBoss = new EdgeListGraph(G);
        retainUnshieldedColliders(G);
        scorer.bookmark();

        // Reverse a private copy so the scorer's own list is never mutated.
        List<Node> pi = new ArrayList<>(scorer.getPi());
        Collections.reverse(pi);

        for (Node x : pi) {
            HashMap<Node, List<Node>> T = findSeparatingParentSubsets(scorer, GBoss, x);
            for (Node z : T.keySet()) {
                List<Node> subset = T.get(z);
                if (subset.isEmpty()) {
                    continue;
                }
                G.removeEdge(x, z);
                for (Node y : subset) {
                    if (G.isAdjacentTo(x, y) && G.isAdjacentTo(y, z)) {
                        G.setEndpoint(x, y, Endpoint.ARROW);
                        G.setEndpoint(z, y, Endpoint.ARROW);
                    }
                }
            }
        }

        scorer.goToBookmark();
        finalOrientation(this.knowledge, G);
        return G;
    }

    /**
     * Helper for {@link #lvswap1()}.
     *
     * <p>For each subset S of x's parents in {@code gBoss} (size bounded by the depth
     * setting), and each remaining parent z of x that is adjacent in {@code gBoss} to a member
     * of S, this method:
     * <ol>
     *   <li>restores the bookmarked order;</li>
     *   <li>moves every member of S to x's current index;</li>
     *   <li>records z &rarr; S when z is no longer a parent of x.</li>
     * </ol>
     * If several subsets qualify for the same z, the last one produced by the generator wins.
     */
    private HashMap<Node, List<Node>> findSeparatingParentSubsets(TeyssierScorer scorer,
                                                                  EdgeListGraph gBoss, Node x) {
        HashMap<Node, List<Node>> result = new HashMap<>();
        List<Node> xParents = gBoss.getParents(x);
        SublistGenerator gen = new SublistGenerator(xParents.size(), effectiveDepth(xParents.size()));

        int[] choice;
        while ((choice = gen.next()) != null) {
            List<Node> subset = GraphUtils.asList(choice, xParents);
            // Depends only on x and the subset, so compute once rather than once per z.
            Set<Node> subsetNeighbors = adj(x, subset, gBoss);

            for (Node z : getComplement(xParents, subset)) {
                if (!subsetNeighbors.contains(z)) {
                    continue;
                }
                scorer.goToBookmark();
                // z comes from the complement, so it is never a member of subset.
                for (Node w : subset) {
                    scorer.moveTo(w, scorer.index(x));
                }
                if (!scorer.parent(z, x)) {
                    result.put(z, subset);
                }
            }
        }
        return result;
    }

    /**
     * Variant 2a. Runs {@value #LVSWAP2A_ROUNDS} rounds.
     *
     * <p>In each round, for every node y and every pair of distinct neighbours x, z of y in the
     * current graph (skipping triples already recorded), the search:
     * <ol>
     *   <li>restores the bookmarked order;</li>
     *   <li>applies {@code swaptuck(x, y, z, true)};</li>
     *   <li>if the scorer then has x, z nonadjacent with a collider at y, records every common
     *       neighbour w of x and z at which the scorer has a collider x, w, z.</li>
     * </ol>
     *
     * <p>After each round:
     * <ol>
     *   <li>recorded x&mdash;z edges are removed;</li>
     *   <li>endpoints are reset to keep only unshielded colliders;</li>
     *   <li>recorded triples are oriented as colliders.</li>
     * </ol>
     *
     * @return the resulting graph
     */
    public Graph lvswap2a() {
        TeyssierScorer scorer = new TeyssierScorer(this.test, this.score);
        Graph G = runBoss(scorer).getGraph(true);
        retainUnshieldedColliders(G);
        scorer.bookmark();

        Set<Triple> T = new HashSet<>();
        List<Node> pi = scorer.getPi();

        for (int round = 0; round < LVSWAP2A_ROUNDS; round++) {
            for (Node y : pi) {
                List<Node> adjy = G.getAdjacentNodes(y);
                for (Node x : adjy) {
                    for (Node z : adjy) {
                        if (x == z || T.contains(new Triple(x, y, z))) {
                            continue;
                        }
                        scorer.goToBookmark();
                        scorer.swaptuck(x, y, z, true);
                        if (createsUnshieldedCollider(scorer, x, y, z)) {
                            addCollidersOverCommonNeighbors(scorer, x, z, T);
                        }
                    }
                }
            }
            removeShields(G, T);
            retainUnshieldedColliders(G);
            orientColliders(G, T);
        }

        finalOrientation(this.knowledge, G);
        scorer.goToBookmark();
        return G;
    }

    /**
     * Variant 2b. Starts from {@code getGraph(false)} and considers only neighbour pairs x, z
     * of y that the scorer reports as adjacent.
     *
     * <p>For each such triple:
     * <ol>
     *   <li>{@code swaptuck(x, y, z, false)} is tried from the bookmark. If it yields x, z
     *       nonadjacent with a collider at y, colliders over common neighbours of x and z are
     *       recorded; otherwise the current scorer state becomes the new bookmark.</li>
     *   <li>{@code swaptuck(x, y, z, true)} is then tried from the (possibly updated) bookmark
     *       and checked the same way.</li>
     * </ol>
     *
     * <p>Shield removal and collider orientation are applied once, at the end.
     *
     * @return the resulting graph
     */
    public Graph lvswap2b() {
        TeyssierScorer scorer = new TeyssierScorer(this.test, this.score);
        Graph G = runBoss(scorer).getGraph(false);
        retainUnshieldedColliders(G);
        scorer.bookmark();

        Set<Triple> T = new HashSet<>();
        List<Node> pi = scorer.getPi();

        for (Node y : pi) {
            List<Node> adjy = G.getAdjacentNodes(y);
            for (Node x : adjy) {
                for (Node z : adjy) {
                    // NOTE(review): this adjacency check uses the scorer state left by the
                    // previous iteration's swaptuck, not the bookmark -- confirm intended.
                    if (!scorer.adjacent(x, z) || T.contains(new Triple(x, y, z))) {
                        continue;
                    }

                    scorer.goToBookmark();
                    scorer.swaptuck(x, y, z, false);
                    if (createsUnshieldedCollider(scorer, x, y, z)) {
                        addCollidersOverCommonNeighbors(scorer, x, z, T);
                    } else {
                        // Re-base: later moves start from this state.
                        scorer.bookmark();
                    }

                    scorer.goToBookmark();
                    scorer.swaptuck(x, y, z, true);
                    if (createsUnshieldedCollider(scorer, x, y, z)) {
                        addCollidersOverCommonNeighbors(scorer, x, z, T);
                    }
                }
            }
        }

        removeShields(G, T);
        retainUnshieldedColliders(G);
        orientColliders(G, T);
        finalOrientation(this.knowledge, G);
        scorer.goToBookmark();
        return G;
    }

    /**
     * Variant 3 (not dispatched by {@link #search()}). Iterates until no new triples are found.
     *
     * <p>Each pass rebuilds the graph from the post-BOSS unshielded-collider graph, removing
     * the shields of all triples found so far and orienting them as colliders. It then records
     * a shielded triple x, y, z (x&mdash;z adjacent in the graph) when some non-empty subset of
     * y's scorer neighbours, moved to y's index, produces a collider at y with x, z
     * nonadjacent in the scorer, and the triple is not already a definite collider in the graph.
     *
     * @return the resulting graph
     */
    public Graph lvswap3() {
        TeyssierScorer scorer = new TeyssierScorer(this.test, this.score);
        Graph G = runBoss(scorer).getGraph(true);
        retainUnshieldedColliders(G);
        EdgeListGraph G2 = new EdgeListGraph(G);
        scorer.bookmark();

        Set<Triple> T = new HashSet<>();
        Set<Triple> allT = new HashSet<>();
        do {
            allT.addAll(T);
            G = new EdgeListGraph(G2);
            removeShields(G, allT);
            retainUnshieldedColliders(G);
            orientColliders(G, allT);

            T = new HashSet<>();
            for (Node y : G.getNodes()) {
                List<Node> adjy = G.getAdjacentNodes(y);
                for (Node x : adjy) {
                    for (Node z : adjy) {
                        if (x == y || x == z || y == z || !G.isAdjacentTo(x, z)) {
                            continue;
                        }
                        if (neighborMoveRevealsCollider(scorer, G, x, y, z)) {
                            T.add(new Triple(x, y, z));
                        }
                    }
                }
            }
        } while (!allT.containsAll(T));

        finalOrientation(this.knowledge, G);
        return G;
    }

    /**
     * Helper for {@link #lvswap3()}. Tries each non-empty subset (bounded by depth) of y's
     * scorer neighbours. For each subset it restores the bookmark and moves the subset to y's
     * index.
     *
     * @return {@code true} as soon as one subset gives a scorer collider at y with x, z
     *     nonadjacent, while x, y, z is not already a definite collider in {@code G}
     */
    private boolean neighborMoveRevealsCollider(TeyssierScorer scorer, Graph G,
                                                Node x, Node y, Node z) {
        scorer.goToBookmark();
        List<Node> neighbors = new ArrayList<>(scorer.getAdjacentNodes(y));
        SublistGenerator gen = new SublistGenerator(neighbors.size(), effectiveDepth(neighbors.size()));

        int[] choice;
        while ((choice = gen.next()) != null) {
            if (choice.length == 0) {
                continue;
            }
            scorer.goToBookmark();
            for (Node w : GraphUtils.asList(choice, neighbors)) {
                scorer.moveTo(w, scorer.index(y));
            }
            if (scorer.collider(x, y, z) && !scorer.adjacent(x, z) && !G.isDefCollider(x, y, z)) {
                return true;
            }
        }
        return false;
    }

    /** Configures a {@link Boss} on {@code scorer} from this object's settings and runs it. */
    private Boss runBoss(TeyssierScorer scorer) {
        Boss alg = new Boss(scorer);
        alg.setAlgType(this.bossAlgType);
        alg.setUseScore(this.useScore);
        alg.setUseRaskuttiUhler(this.useRaskuttiUhler);
        alg.setUseDataOrder(this.useDataOrder);
        alg.setDepth(this.depth);
        alg.setNumStarts(this.numStarts);
        alg.setVerbose(this.verbose);
        alg.bestOrder(this.score.getVariables());
        return alg;
    }

    /** Subset-size bound: {@code size} when depth is negative, else {@code min(depth, size)}. */
    private int effectiveDepth(int size) {
        return this.depth < 0 ? size : FastMath.min(this.depth, size);
    }

    /** True when the scorer currently has x, z nonadjacent and a collider at y. */
    private static boolean createsUnshieldedCollider(TeyssierScorer scorer, Node x, Node y, Node z) {
        return !scorer.adjacent(x, z) && scorer.collider(x, y, z);
    }

    /**
     * Adds (x, w, z) to {@code T} for every common scorer neighbour w of x and z at which the
     * scorer reports a collider.
     */
    private static void addCollidersOverCommonNeighbors(TeyssierScorer scorer, Node x, Node z,
                                                        Set<Triple> T) {
        // Copy before retainAll so we never mutate a set the scorer might own.
        Set<Node> common = new HashSet<>(scorer.getAdjacentNodes(x));
        common.retainAll(scorer.getAdjacentNodes(z));
        for (Node w : common) {
            if (scorer.collider(x, w, z)) {
                T.add(new Triple(x, w, z));
            }
        }
    }

    /** Returns a new list with the elements of {@code X} that are not in {@code Y}. */
    @NotNull
    private static List<Node> getComplement(List<Node> X, List<Node> Y) {
        List<Node> complement = new ArrayList<>(X);
        complement.removeAll(Y);
        return complement;
    }

    /** Union of the neighbours in {@code g} of every node in {@code Y}, excluding {@code x}. */
    private static Set<Node> adj(Node x, List<Node> Y, Graph g) {
        Set<Node> neighbors = new HashSet<>();
        for (Node y : Y) {
            neighbors.addAll(g.getAdjacentNodes(y));
        }
        neighbors.remove(x);
        return neighbors;
    }

    /**
     * Resets every endpoint of {@code graph} to a circle. It then restores arrowheads at b for
     * each a, b, c that was a definite collider in the input with a and c nonadjacent.
     */
    public static void retainUnshieldedColliders(Graph graph) {
        keepDefiniteColliders(graph, true);
    }

    /**
     * Resets every endpoint of {@code graph} to a circle. It then restores arrowheads at b for
     * each a, b, c that was a definite collider in the input, shielded or not.
     */
    public static void retainColliders(Graph graph) {
        keepDefiniteColliders(graph, false);
    }

    /**
     * Shared body of the two retain-collider methods. Works on a snapshot of the input, so
     * collider/adjacency tests are not affected by the in-place circle reset.
     */
    private static void keepDefiniteColliders(Graph graph, boolean unshieldedOnly) {
        EdgeListGraph orig = new EdgeListGraph(graph);
        graph.reorientAllWith(Endpoint.CIRCLE);

        for (Node b : graph.getNodes()) {
            List<Node> adjacentNodes = graph.getAdjacentNodes(b);
            if (adjacentNodes.size() < 2) {
                continue;
            }
            ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2);
            int[] combination;
            while ((combination = cg.next()) != null) {
                Node a = adjacentNodes.get(combination[0]);
                Node c = adjacentNodes.get(combination[1]);
                if (!orig.isDefCollider(a, b, c)) {
                    continue;
                }
                if (unshieldedOnly && orig.isAdjacentTo(a, c)) {
                    continue;
                }
                graph.setEndpoint(a, b, Endpoint.ARROW);
                graph.setEndpoint(c, b, Endpoint.ARROW);
            }
        }
    }

    /**
     * Resets every endpoint of {@code graph} to a circle, then restores every endpoint that
     * was an arrowhead in the input.
     */
    public static void retainArrows(Graph graph) {
        EdgeListGraph orig = new EdgeListGraph(graph);
        graph.reorientAllWith(Endpoint.CIRCLE);
        List<Node> nodes = graph.getNodes();

        for (Node x : nodes) {
            for (Node y : nodes) {
                if (x != y && orig.getEndpoint(x, y) == Endpoint.ARROW) {
                    graph.setEndpoint(x, y, Endpoint.ARROW);
                }
            }
        }
    }

    /**
     * Applies {@link FciOrient}'s final orientation to {@code G} in place.
     *
     * <p>Sepsets come from {@link SepsetsGreedy} over {@code G} and the test. The
     * discriminating-path collider rule is always disabled.
     */
    private void finalOrientation(Knowledge knowledge2, Graph G) {
        SepsetsGreedy sepsets = new SepsetsGreedy(G, this.test, null, this.depth);
        FciOrient fciOrient = new FciOrient(sepsets);
        fciOrient.setCompleteRuleSetUsed(this.completeRuleSetUsed);
        fciOrient.setDoDiscriminatingPathColliderRule(false);
        fciOrient.setDoDiscriminatingPathTailRule(this.doDiscriminatingPathTailRule);
        fciOrient.setMaxPathLength(this.maxPathLength);
        fciOrient.setKnowledge(knowledge2);
        fciOrient.setVerbose(this.verbose);
        fciOrient.doFinalOrientation(G);
    }

    /**
     * Removes the x&mdash;z edge (if present) for every triple. Logs each removal to
     * {@code out} when verbose.
     */
    private void removeShields(Graph graph, Set<Triple> unshieldedColliders) {
        for (Triple triple : unshieldedColliders) {
            Node x = triple.getX();
            Node z = triple.getZ();
            Edge edge = graph.getEdge(x, z);
            if (edge == null) {
                continue;
            }
            graph.removeEdge(x, z);
            if (this.verbose) {
                this.out.println("Removing (swap rule): " + edge);
            }
        }
    }

    /**
     * Orients x *-&gt; y &lt;-* z for every triple whose two edges exist and that is not
     * already a definite collider. Logs each orientation to {@code out} when verbose.
     */
    private void orientColliders(Graph graph, Set<Triple> unshieldedColliders) {
        for (Triple triple : unshieldedColliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (!graph.isAdjacentTo(x, y) || !graph.isAdjacentTo(y, z) || graph.isDefCollider(x, y, z)) {
                continue;
            }
            graph.setEndpoint(x, y, Endpoint.ARROW);
            graph.setEndpoint(z, y, Endpoint.ARROW);
            if (this.verbose) {
                this.out.println("Orienting collider (Swap rule): " + GraphUtils.pathString(graph, x, y, z));
            }
        }
    }

    /** Whether {@link FciOrient} uses its complete rule set. */
    public void setCompleteRuleSetUsed(boolean completeRuleSetUsed) {
        this.completeRuleSetUsed = completeRuleSetUsed;
    }

    /**
     * Maximum path length passed to {@link FciOrient}.
     *
     * @throws IllegalArgumentException if {@code maxPathLength < -1}
     */
    public void setMaxPathLength(int maxPathLength) {
        if (maxPathLength < -1) {
            throw new IllegalArgumentException("Max path length must be -1 (unlimited) or >= 0: " + maxPathLength);
        }
        this.maxPathLength = maxPathLength;
    }

    /** Replaces the independence test used for scoring and final orientation. */
    public void setTest(IndependenceTest test) {
        this.test = test;
    }

    /** Stores a covariance matrix (not read by this class). */
    public void setCovarianceMatrix(ICovarianceMatrix covarianceMatrix) {
        this.covarianceMatrix = covarianceMatrix;
    }

    /** Number of BOSS starts. */
    public void setNumStarts(int numStarts) {
        this.numStarts = numStarts;
    }

    /**
     * Depth used by BOSS, by the sepset search, and as the subset-size bound. A negative value
     * means unbounded for the subset searches.
     */
    public void setDepth(int depth) {
        this.depth = depth;
    }

    /** Passed through to BOSS. */
    public void setUseRaskuttiUhler(boolean useRaskuttiUhler) {
        this.useRaskuttiUhler = useRaskuttiUhler;
    }

    /** Passed through to BOSS. */
    public void setUseScore(boolean useScore) {
        this.useScore = useScore;
    }

    /** Passed through to BOSS. */
    public void setUseDataOrder(boolean useDataOrder) {
        this.useDataOrder = useDataOrder;
    }

    /** Stores a copy of {@code knowledge} for the final orientation step. */
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = new Knowledge(knowledge);
    }

    /** Enables diagnostic output (here, in BOSS, and in {@link FciOrient}). */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /** Stream used for this class's verbose diagnostic output. */
    public void setOut(PrintStream out) {
        this.out = out;
    }

    /** Selects which variant {@link #search()} runs. */
    public void setAlgType(AlgType algType) {
        this.algType = algType;
    }

    /** Selects the BOSS variant used for the initial permutation search. */
    public void setBossAlgType(Boss.AlgType bossAlgType) {
        this.bossAlgType = bossAlgType;
    }

    /** Whether {@link FciOrient} applies the discriminating-path tail rule. */
    public void setDoDefiniteDiscriminatingPathTailRule(boolean doDiscriminatingPathTailRule) {
        this.doDiscriminatingPathTailRule = doDiscriminatingPathTailRule;
    }

    /** Variants selectable through {@link #setAlgType(AlgType)}. */
    public enum AlgType {
        LVSwap1,
        LVSwap2a,
        LVSwap2b
    }
}

