/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.FasConcurrent;
import edu.cmu.tetrad.search.GraphSearch;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.OrientCollidersMaxP;
import edu.cmu.tetrad.search.SepsetMap;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.commons.math3.util.FastMath;

/**
 * Graph search that builds an adjacency skeleton with {@link FasConcurrent} and then orients it
 * in stages: background knowledge, two-shield ("Bishop's hat") constructs, which may introduce
 * 2-cycles, max-p collider orientation ({@link OrientCollidersMaxP}), optional propagation away
 * from arrowheads, and optional orientation toward d-connections. If tier collapsing is enabled,
 * lagged variables (names of the form {@code name:lag}) are folded back onto their un-lagged
 * counterparts before the graph is returned.
 */
public final class CcdMax
implements GraphSearch {
    /** Independence test supplying the variables and all conditional-independence judgments. */
    private final IndependenceTest independenceTest;
    /**
     * Maximum conditioning-set size. It is passed to the adjacency search as is. The local
     * sepset searches treat -1 as 1000.
     */
    private int depth = -1;
    /** Whether orientations are propagated away from arrowheads. */
    private boolean applyOrientAwayFromCollider;
    /** Duration of the most recent adjacency search, from {@code MillisecondTimes.timeMillis()}. */
    private long elapsed;
    /** Background knowledge: required and forbidden edges, and tiers. */
    private Knowledge knowledge = new Knowledge();
    /** Passed through to {@code OrientCollidersMaxP.setUseHeuristic}. */
    private boolean useHeuristic = true;
    /** Passed through to {@code OrientCollidersMaxP.setMaxPathLength}. */
    private int maxPathLength = 3;
    /** Whether the orient-toward-d-connection step runs. Also controls whether sepsets are kept. */
    private boolean useOrientTowardDConnections = true;
    /** Whether the two-shield ("Bishop's hat") feedback-loop step runs. */
    private boolean orientConcurrentFeedbackLoops = true;
    /** Whether max-p collider orientation runs. */
    private boolean doColliderOrientations = true;
    /** Whether the lagged result is collapsed onto un-lagged variables before returning. */
    private boolean collapseTiers;
    /** Separating sets from the adjacency search. Only populated when d-connection orientation is on. */
    private SepsetMap sepsetMap;

    /**
     * Creates a search driven by the given independence test.
     *
     * @param test the independence test used for adjacency search and orientation
     * @throws NullPointerException if {@code test} is null
     */
    public CcdMax(IndependenceTest test) {
        if (test == null) {
            throw new NullPointerException();
        }
        this.independenceTest = test;
    }

    /**
     * Runs the full search: adjacency search followed by each enabled orientation stage.
     * Progress messages are printed to standard out.
     *
     * @return the oriented graph, collapsed over tiers if {@link #setCollapseTiers} was enabled
     */
    @Override
    public Graph search() {
        System.out.println("FAS");
        Graph graph = this.fastAdjacencySearch();
        System.out.println("Orienting from background knowledge");
        this.orientFromKnowledge(graph);
        System.out.println("Bishop's hat");
        if (this.orientConcurrentFeedbackLoops) {
            this.orientTwoShieldConstructs(graph);
        }
        System.out.println("Max P collider orientation");
        if (this.doColliderOrientations) {
            OrientCollidersMaxP orientCollidersMaxP = new OrientCollidersMaxP(this.independenceTest);
            orientCollidersMaxP.setUseHeuristic(this.useHeuristic);
            orientCollidersMaxP.setMaxPathLength(this.maxPathLength);
            orientCollidersMaxP.setKnowledge(this.knowledge);
            orientCollidersMaxP.orient(graph);
        }
        System.out.println("Orient away from collider");
        if (this.applyOrientAwayFromCollider) {
            this.orientAwayFromArrow(graph);
        }
        System.out.println("Toward D-connection");
        if (this.useOrientTowardDConnections) {
            this.orientTowardDConnection(graph);
        }
        System.out.println("Done");
        return this.collapseTiers ? this.collapseGraph(graph) : graph;
    }

    /**
     * Orients every adjacency whose direction is fixed by background knowledge. The edge points
     * x -> y when y -> x is forbidden or x -> y is required. Otherwise it points y -> x when
     * x -> y is forbidden or y -> x is required.
     */
    private void orientFromKnowledge(Graph graph) {
        for (Edge edge : graph.getEdges()) {
            Node x = edge.getNode1();
            Node y = edge.getNode2();
            String xName = x.getName();
            String yName = y.getName();
            if (this.knowledge.isForbidden(yName, xName) || this.knowledge.isRequired(xName, yName)) {
                graph.removeEdge(x, y);
                graph.addDirectedEdge(x, y);
            } else if (this.knowledge.isForbidden(xName, yName) || this.knowledge.isRequired(yName, xName)) {
                graph.removeEdge(y, x);
                graph.addDirectedEdge(y, x);
            }
        }
    }

    /**
     * Folds a lagged graph onto the un-lagged variables. A variable name {@code "X:2"} has base
     * {@code X} and lag 2, and a name without {@code ':'} has lag 0. An edge is kept only if it
     * has a non-arrow end at one node and the opposite node's lag is below the last tier. When
     * both an undirected and a directed version of a collapsed pair exist, the undirected one is
     * dropped.
     */
    private Graph collapseGraph(Graph graph) {
        List<Node> nodes = new ArrayList<>();
        for (String name : this.independenceTest.getVariableNames()) {
            String[] parts = name.split(":");
            if (parts.length == 1) {
                nodes.add(this.independenceTest.getVariable(parts[0]));
            }
        }
        EdgeListGraph collapsed = new EdgeListGraph(nodes);
        int maxInto = this.knowledge.getNumTiers() - 1;
        for (Edge edge : graph.getEdges()) {
            Node x = edge.getNode1();
            Node y = edge.getNode2();
            String[] sx = x.getName().split(":");
            String[] sy = y.getName().split(":");
            int lagX = sx.length == 1 ? 0 : Integer.parseInt(sx[1]);
            int lagY = sy.length == 1 ? 0 : Integer.parseInt(sy[1]);
            boolean keep = (!edge.pointsTowards(x) && lagY < maxInto)
                    || (!edge.pointsTowards(y) && lagX < maxInto);
            if (!keep) {
                continue;
            }
            Node xx = this.independenceTest.getVariable(sx[0]);
            Node yy = this.independenceTest.getVariable(sy[0]);
            if (xx == yy) {
                continue;
            }
            Edge collapsedEdge = new Edge(xx, yy, edge.getEndpoint1(), edge.getEndpoint2());
            if (!collapsed.containsEdge(collapsedEdge)) {
                collapsed.addEdge(collapsedEdge);
            }
            Edge undirected = Edges.undirectedEdge(xx, yy);
            if (collapsed.getEdges(xx, yy).size() > 1 && collapsed.containsEdge(undirected)) {
                collapsed.removeEdge(undirected);
            }
        }
        return collapsed;
    }

    /** @return the maximum conditioning-set size (-1 for no explicit limit) */
    public int getDepth() {
        return this.depth;
    }

    /** @param depth the maximum conditioning-set size, or -1 for no explicit limit */
    public void setDepth(int depth) {
        this.depth = depth;
    }

    /** @return the duration of the most recent adjacency search, as measured by {@code MillisecondTimes} */
    public long getElapsedTime() {
        return this.elapsed;
    }

    /** @param applyOrientAwayFromCollider whether to propagate orientations away from arrowheads */
    public void setApplyOrientAwayFromCollider(boolean applyOrientAwayFromCollider) {
        this.applyOrientAwayFromCollider = applyOrientAwayFromCollider;
    }

    /**
     * Runs a stable {@link FasConcurrent} adjacency search, records its elapsed time, and keeps
     * its sepsets when d-connection orientation will need them.
     */
    private Graph fastAdjacencySearch() {
        long start = MillisecondTimes.timeMillis();
        FasConcurrent fas = new FasConcurrent(this.independenceTest);
        fas.setStable(true);
        fas.setDepth(this.getDepth());
        fas.setKnowledge(this.knowledge);
        fas.setVerbose(false);
        Graph graph = fas.search();
        if (this.useOrientTowardDConnections) {
            this.sepsetMap = fas.getSepsets();
        }
        this.elapsed = MillisecondTimes.timeMillis() - start;
        return new EdgeListGraph(graph);
    }

    /**
     * For every node c and every non-adjacent pair a, b of its neighbours, looks for a fourth
     * node d adjacent to c, a and b that completes a two-shield construct. When one is found,
     * it orients a -> c <- b and a -> d <- b and makes c and d a feedback pair.
     */
    private void orientTwoShieldConstructs(Graph graph) {
        TetradLogger.getInstance().log("info", "\nStep E");
        for (Node c : graph.getNodes()) {
            // Copied so later edge rewrites between c and d cannot disturb this iteration.
            List<Node> adj = new ArrayList<>(graph.getAdjacentNodes(c));
            for (int i = 0; i < adj.size(); ++i) {
                Node a = adj.get(i);
                for (int j = i + 1; j < adj.size(); ++j) {
                    Node b = adj.get(j);
                    if (a == b || graph.isAdjacentTo(a, b)) {
                        continue;
                    }
                    for (Node d : adj) {
                        if (!this.isTwoShieldConstruct(graph, a, b, c, d)) {
                            continue;
                        }
                        this.orientCollider(graph, a, c, b);
                        this.orientCollider(graph, a, d, b);
                        this.addFeedback(graph, c, d);
                    }
                }
            }
        }
    }

    /**
     * Tests whether a, b, c and d form an orientable two-shield construct. The conditions are:
     * <ul>
     * <li>d is adjacent to both a and b;</li>
     * <li>c–d is not already a 2-cycle or directed;</li>
     * <li>none of a–c, a–d, b–c, b–d points back into a or b;</li>
     * <li>a and b are separable without c and d, and also separable by a set containing c and d.</li>
     * </ul>
     * The cheap structural checks run before the sepset searches.
     */
    private boolean isTwoShieldConstruct(Graph graph, Node a, Node b, Node c, Node d) {
        if (d == a || d == b) {
            return false;
        }
        if (!graph.isAdjacentTo(d, a) || !graph.isAdjacentTo(d, b)) {
            return false;
        }
        // Previously this read graph.getEdges().size() == 2, the size of the whole graph,
        // which can never be 2 here. The intent is "c and d are already a feedback pair".
        if (graph.getEdges(c, d).size() == 2 || Edges.isDirectedEdge(graph.getEdge(c, d))) {
            return false;
        }
        if (graph.getEdge(a, c).pointsTowards(a) || graph.getEdge(a, d).pointsTowards(a)
                || graph.getEdge(b, c).pointsTowards(b) || graph.getEdge(b, d).pointsTowards(b)) {
            return false;
        }
        return this.sepset(graph, a, b, this.set(), this.set(c, d)) != null
                && this.sepset(graph, a, b, this.set(c, d), this.set()) != null;
    }

    /**
     * For each undirected edge b–c, inspects nodes a within two hops of b that are adjacent to
     * neither b nor c. Orients c -> b when the max-p sepsets all agree that sepset(a, b) strictly
     * contains sepset(a, c). It also orients c -> b when the stored adjacency-search sepsets
     * indicate that a and b are dependent given sepset(a, c).
     */
    private void orientTowardDConnection(Graph graph) {
        for (Edge listed : graph.getEdges()) {
            Node b = listed.getNode1();
            Node c = listed.getNode2();
            // Check the edge as it is now: an earlier iteration, or propagation triggered by
            // addDirectedEdge, may already have oriented b–c since the edge list was taken.
            Edge current = graph.getEdge(b, c);
            if (current == null || !Edges.isUndirectedEdge(current)) {
                continue;
            }
            Set<Node> surround = this.twoHopNonNeighbours(graph, b, c);

            boolean orient = false;
            boolean agree = true;
            for (Node a : surround) {
                List<Node> sepsetAb = this.maxPSepset(a, b, graph).getCond();
                List<Node> sepsetAc = this.maxPSepset(a, c, graph).getCond();
                if (sepsetAb == null || sepsetAc == null || sepsetAb.equals(sepsetAc)) {
                    continue;
                }
                if (sepsetAb.containsAll(sepsetAc)) {
                    orient = true;
                } else {
                    agree = false;
                }
            }
            if (orient && agree) {
                this.addDirectedEdge(graph, c, b);
            }

            // surround already excludes b, c and their neighbours, and orientation never
            // changes adjacency, so no further membership checks are needed here.
            for (Node a : surround) {
                List<Node> sepsetAb = this.sepsetMap.get(a, b);
                List<Node> sepsetAc = this.sepsetMap.get(a, c);
                if (sepsetAb == null || sepsetAc == null
                        || sepsetAc.contains(b) || sepsetAc.containsAll(sepsetAb)) {
                    continue;
                }
                if (this.independenceTest.checkIndependence(a, b, sepsetAc).independent()) {
                    continue;
                }
                this.addDirectedEdge(graph, c, b);
                break;
            }
        }
    }

    /** Returns the nodes within two adjacency hops of b, excluding b, c and every neighbour of b or c. */
    private Set<Node> twoHopNonNeighbours(Graph graph, Node b, Node c) {
        Set<Node> surround = new HashSet<>();
        surround.add(b);
        for (int hop = 0; hop < 2; ++hop) {
            for (Node z : new ArrayList<>(surround)) {
                surround.addAll(graph.getAdjacentNodes(z));
            }
        }
        surround.remove(b);
        surround.remove(c);
        surround.removeAll(graph.getAdjacentNodes(b));
        surround.removeAll(graph.getAdjacentNodes(c));
        return surround;
    }

    /** Replaces all edges between a and b with a -> b, then propagates away from the new arrowhead. */
    private void addDirectedEdge(Graph graph, Node a, Node b) {
        graph.removeEdges(a, b);
        graph.addDirectedEdge(a, b);
        this.orientAwayFromArrow(graph, a, b);
    }

    /** Replaces all edges between a and b with the 2-cycle a -> b, b -> a. */
    private void addFeedback(Graph graph, Node a, Node b) {
        graph.removeEdges(a, b);
        graph.addEdge(Edges.directedEdge(a, b));
        graph.addEdge(Edges.directedEdge(b, a));
    }

    /**
     * Orients a -> b <- c unless any of the following holds:
     * <ul>
     * <li>it would meet an existing definite collider at b;</li>
     * <li>either pair already has more than one edge;</li>
     * <li>knowledge forbids a -> b or c -> b.</li>
     * </ul>
     */
    private void orientCollider(Graph graph, Node a, Node b, Node c) {
        boolean blocked = this.wouldCreateBadCollider(graph, a, b)
                || this.wouldCreateBadCollider(graph, c, b)
                || graph.getEdges(a, b).size() > 1
                || graph.getEdges(b, c).size() > 1
                || this.knowledge.isForbidden(a.getName(), b.getName())
                || this.knowledge.isForbidden(c.getName(), b.getName());
        if (blocked) {
            return;
        }
        graph.removeEdge(a, b);
        graph.removeEdge(c, b);
        graph.addDirectedEdge(a, b);
        graph.addDirectedEdge(c, b);
    }

    /** Given a -> b, tries to orient each other undirected edge at b away from b. No-op unless enabled. */
    private void orientAwayFromArrow(Graph graph, Node a, Node b) {
        if (!this.applyOrientAwayFromCollider) {
            return;
        }
        for (Node c : graph.getAdjacentNodes(b)) {
            if (c != a) {
                this.orientAwayFromArrowVisit(a, b, c, graph);
            }
        }
    }

    /** Returns true if {@code graph.isDefCollider(x, y, z)} holds for some neighbour z of y other than x. */
    private boolean wouldCreateBadCollider(Graph graph, Node x, Node y) {
        for (Node z : graph.getAdjacentNodes(y)) {
            if (x != z && graph.isDefCollider(x, y, z)) {
                return true;
            }
        }
        return false;
    }

    /** @return the background knowledge in use; this is the internal copy, not the caller's object */
    public Knowledge getKnowledge() {
        return this.knowledge;
    }

    /** @param knowledge background knowledge; a defensive copy is stored */
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = new Knowledge(knowledge);
    }

    /** @param useHeuristic forwarded to the max-p collider orientation step */
    public void setUseHeuristic(boolean useHeuristic) {
        this.useHeuristic = useHeuristic;
    }

    /** @return the max path length forwarded to the max-p collider orientation step */
    public int getMaxPathLength() {
        return this.maxPathLength;
    }

    /** @param maxPathLength forwarded to the max-p collider orientation step */
    public void setMaxPathLength(int maxPathLength) {
        this.maxPathLength = maxPathLength;
    }

    /** @param useOrientTowardDConnections whether to run the orient-toward-d-connection step */
    public void setUseOrientTowardDConnections(boolean useOrientTowardDConnections) {
        this.useOrientTowardDConnections = useOrientTowardDConnections;
    }

    /** @param orientConcurrentFeedbackLoops whether to run the two-shield feedback-loop step */
    public void setOrientConcurrentFeedbackLoops(boolean orientConcurrentFeedbackLoops) {
        this.orientConcurrentFeedbackLoops = orientConcurrentFeedbackLoops;
    }

    /** @param doColliderOrientations whether to run max-p collider orientation */
    public void setDoColliderOrientations(boolean doColliderOrientations) {
        this.doColliderOrientations = doColliderOrientations;
    }

    /** @param collapseTiers whether to collapse the lagged result onto un-lagged variables */
    public void setCollapseTiers(boolean collapseTiers) {
        this.collapseTiers = collapseTiers;
    }

    /**
     * Searches subsets of the neighbours of i, then of k, at each size up to the depth limit.
     * Subsets that knowledge forbids are skipped. Returns the subset with the lowest test score
     * (the test's {@code getScore()}).
     *
     * @return the best conditioning set and its score. The set is null when none was tested.
     *         On any test exception the trace is printed and {@code Pair(null, +inf)} is
     *         returned, which preserves the original best-effort behaviour.
     */
    private Pair maxPSepset(Node i, Node k, Graph graph) {
        double bestScore = Double.POSITIVE_INFINITY;
        List<Node> bestCond = null;
        // Copies, so removing i/k never touches a list the graph might own.
        List<Node> adji = new ArrayList<>(graph.getAdjacentNodes(i));
        List<Node> adjk = new ArrayList<>(graph.getAdjacentNodes(k));
        adji.remove(k);
        adjk.remove(i);
        List<List<Node>> pools = new ArrayList<>();
        pools.add(adji);
        pools.add(adjk);
        int maxDepth = FastMath.min(this.depth == -1 ? 1000 : this.depth, FastMath.max(adji.size(), adjk.size()));
        for (int d = 0; d <= maxDepth; ++d) {
            for (List<Node> pool : pools) {
                if (d > pool.size()) {
                    continue;
                }
                ChoiceGenerator gen = new ChoiceGenerator(pool.size(), d);
                int[] choice;
                while ((choice = gen.next()) != null) {
                    List<Node> cond = GraphUtils.asList(choice, pool);
                    // Apply the knowledge filter to both pools. It previously covered only adj(i).
                    if (this.isForbidden(i, k, cond)) {
                        continue;
                    }
                    try {
                        this.getIndependenceTest().checkIndependence(i, k, cond);
                        double score = this.getIndependenceTest().getScore();
                        if (score < bestScore) {
                            bestScore = score;
                            bestCond = cond;
                        }
                    } catch (Exception e) {
                        e.printStackTrace();
                        return new Pair(null, Double.POSITIVE_INFINITY);
                    }
                }
            }
        }
        return new Pair(bestCond, bestScore);
    }

    /** Returns true if knowledge forbids some member w of v from pointing into i or into k. */
    private boolean isForbidden(Node i, Node k, List<Node> v) {
        for (Node w : v) {
            String wName = w.getName();
            if (this.knowledge.isForbidden(wName, i.getName()) || this.knowledge.isForbidden(wName, k.getName())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Looks for a conditioning set that makes a and c independent. The set is built from
     * subsets of the union of their neighbours, plus {@code containing}, minus
     * {@code notContaining}. Sets forbidden by knowledge are skipped. A set counts as
     * separating when the test's {@code getScore()} is negative.
     *
     * @return the first separating set found, or null if none
     */
    private List<Node> sepset(Graph graph, Node a, Node c, Set<Node> containing, Set<Node> notContaining) {
        // Copy the list and de-duplicate shared neighbours. The same subsets get tested
        // without redundant repeats, and the graph's own list is never mutated.
        List<Node> adj = new ArrayList<>(graph.getAdjacentNodes(a));
        for (Node n : graph.getAdjacentNodes(c)) {
            if (!adj.contains(n)) {
                adj.add(n);
            }
        }
        adj.remove(c);
        adj.remove(a);
        int maxDepth = FastMath.min(this.depth == -1 ? 1000 : this.depth, adj.size());
        for (int d = 0; d <= maxDepth; ++d) {
            ChoiceGenerator gen = new ChoiceGenerator(adj.size(), d);
            int[] choice;
            while ((choice = gen.next()) != null) {
                Set<Node> v2 = GraphUtils.asSet(choice, adj);
                v2.addAll(containing);
                v2.removeAll(notContaining);
                v2.remove(a);
                v2.remove(c);
                List<Node> cond = new ArrayList<>(v2);
                if (this.isForbidden(a, c, cond)) {
                    continue;
                }
                this.getIndependenceTest().checkIndependence(a, c, cond);
                if (this.getIndependenceTest().getScore() < 0.0) {
                    return cond;
                }
            }
        }
        return null;
    }

    /** Returns a new mutable set containing the given nodes. */
    private Set<Node> set(Node ... n) {
        Set<Node> s = new HashSet<>();
        Collections.addAll(s, n);
        return s;
    }

    private IndependenceTest getIndependenceTest() {
        return this.independenceTest;
    }

    /** Propagates away from every arrowhead currently in the graph, looking at each edge's current form. */
    private void orientAwayFromArrow(Graph graph) {
        for (Edge listed : graph.getEdges()) {
            Node n1 = listed.getNode1();
            Node n2 = listed.getNode2();
            Edge current = graph.getEdge(n1, n2);
            if (current.pointsTowards(n1)) {
                this.orientAwayFromArrow(graph, n2, n1);
            } else if (current.pointsTowards(n2)) {
                this.orientAwayFromArrow(graph, n1, n2);
            }
        }
    }

    /**
     * Given a -> b, orients b -> c if all of the following hold:
     * <ul>
     * <li>b–c is a single undirected edge;</li>
     * <li>a and c are not adjacent;</li>
     * <li>knowledge allows b -> c;</li>
     * <li>the orientation would not meet an existing definite collider at c.</li>
     * </ul>
     * It then recurses from c. The decompiled original also collected c's undirected edges and
     * walked them in an empty loop with no effect; that dead code has been removed.
     */
    private void orientAwayFromArrowVisit(Node a, Node b, Node c, Graph graph) {
        if (graph.getEdges(b, c).size() > 1) {
            return;
        }
        if (!Edges.isUndirectedEdge(graph.getEdge(b, c))) {
            return;
        }
        if (graph.isAdjacentTo(a, c)) {
            return;
        }
        if (this.knowledge.isForbidden(b.getName(), c.getName())) {
            return;
        }
        if (this.wouldCreateBadCollider(graph, b, c)) {
            return;
        }
        this.addDirectedEdge(graph, b, c);
        for (Node d : graph.getAdjacentNodes(c)) {
            if (d != b) {
                this.orientAwayFromArrowVisit(b, c, d, graph);
            }
        }
    }

    /** Conditioning set paired with the independence-test score it produced. */
    private static class Pair {
        private final List<Node> cond;
        private final double score;

        Pair(List<Node> cond, double score) {
            this.cond = cond;
            this.score = score;
        }

        /** @return the conditioning set, or null if none was found */
        public List<Node> getCond() {
            return this.cond;
        }

        /** @return the test score for {@link #getCond()} */
        public double getScore() {
            return this.score;
        }
    }
}

