/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.data.KnowledgeEdge;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Fas;
import edu.cmu.tetrad.search.GraphSearch;
import edu.cmu.tetrad.search.IndTestDSep;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.MeekRules;
import edu.cmu.tetrad.search.SearchLogUtils;
import edu.cmu.tetrad.search.SepsetMap;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Research variant of the PC search.
 *
 * <p>{@link #search(List)} runs FAS, applies background knowledge, orients colliders from the
 * sepsets recorded by FAS and then applies Meek rules. While orienting colliders it compares
 * each decision with a reference ("true") graph supplied through {@link #setTrueGraph(Dag)},
 * accumulates the counters exposed by {@link #getCefp()}, {@link #getCefn()},
 * {@link #getCindfp()} and {@link #getCollfp()}, and prints one diagnostic line per oriented
 * collider to standard output.
 *
 * <p>A true graph must be set before searching: collider orientation dereferences it for every
 * unshielded triple it examines.
 */
public class PcSearchRsch
implements GraphSearch {
    /** Test used for the adjacency search and for local sepset checks. */
    private IndependenceTest independenceTest;
    /** d-separation test over the true graph; assigned by setTrueGraph but not read in this class. */
    private IndependenceTest graphicalTest;
    /** Background knowledge (required / forbidden edges). */
    private Knowledge knowledge;
    /** Sepsets recorded by FAS during the last search. */
    private SepsetMap sepset;
    /** Maximum conditioning-set size; the local sepset checks treat -1 as unlimited. */
    private int depth = Integer.MAX_VALUE;
    /** Graph produced by the last search. */
    private Graph graph;
    /** Reference graph used only for diagnostics. */
    private Graph trueGraph;
    private int cefp;
    private int cefn;
    private int cindfp;
    private int collfp;
    /** Wall-clock duration of the last search, in milliseconds. */
    private long elapsedTime;
    /** Number of tests performed by the most recent local sepset check. */
    private int numTests;

    /**
     * Creates a search over the variables of the given independence test.
     *
     * @param independenceTest test used to decide adjacencies; must not be null
     * @param knowledge background knowledge; must not be null
     * @throws NullPointerException if either argument is null
     */
    public PcSearchRsch(IndependenceTest independenceTest, Knowledge knowledge) {
        if (independenceTest == null) {
            throw new NullPointerException();
        }
        if (knowledge == null) {
            throw new NullPointerException();
        }
        this.independenceTest = independenceTest;
        this.knowledge = knowledge;
    }

    public IndependenceTest getIndependenceTest() {
        return this.independenceTest;
    }

    public Knowledge getKnowledge() {
        return this.knowledge;
    }

    /** Returns the sepsets recorded by the last search, or null before any search. */
    public SepsetMap getSepset() {
        return this.sepset;
    }

    public int getDepth() {
        return this.depth;
    }

    public void setDepth(int depth) {
        this.depth = depth;
    }

    /** Returns a copy of the graph produced by the last search. */
    public Graph getPartialGraph() {
        return new EdgeListGraph(this.graph);
    }

    /** Runs the search over all variables of the independence test. */
    @Override
    public Graph search() {
        return this.search(this.independenceTest.getVariables());
    }

    /** Returns the wall-clock duration of the last search, in milliseconds. */
    @Override
    public long getElapsedTime() {
        return this.elapsedTime;
    }

    /**
     * Runs the search over the given subset of the test's variables.
     *
     * @param nodes nodes to search over; every one must be a variable of the independence test
     * @return the resulting graph, which is also retained for getPartialGraph/getBide/getAlle
     * @throws IllegalArgumentException if some node is not a variable of the independence test
     */
    public Graph search(List<Node> nodes) {
        TetradLogger.getInstance().log("info", "Starting PC algorithm.");
        TetradLogger.getInstance().log("info", "Independence test = " + this.independenceTest + ".");
        long startTime = System.currentTimeMillis();
        if (this.getIndependenceTest() == null) {
            throw new NullPointerException();
        }
        List<Node> allNodes = this.getIndependenceTest().getVariables();
        if (!allNodes.containsAll(nodes)) {
            throw new IllegalArgumentException("All of the given nodes must be in the domain of the independence test provided.");
        }
        // Start from a complete undirected graph and let FAS remove edges.
        this.graph = new EdgeListGraph(nodes);
        this.graph.fullyConnect(Endpoint.TAIL);
        Fas fas = new Fas(this.graph, this.getIndependenceTest());
        fas.setKnowledge(this.getKnowledge());
        fas.setDepth(this.getDepth());
        this.graph = fas.search();
        this.sepset = fas.getSepsets();
        this.pcOrientbk(this.knowledge, this.graph, nodes);
        this.orientCollidersUsingSepsets(this.sepset, this.knowledge, this.graph);
        MeekRules rules = new MeekRules();
        rules.setKnowledge(this.knowledge);
        rules.orientImplied(this.graph);
        TetradLogger.getInstance().log("graph", "\nReturning this graph: " + this.graph);
        long endTime = System.currentTimeMillis();
        this.elapsedTime = endTime - startTime;
        TetradLogger.getInstance().log("info", "Elapsed time = " + (double)this.elapsedTime / 1000.0 + " s");
        TetradLogger.getInstance().log("info", "Finishing PC algorithm.");
        return this.graph;
    }

    /**
     * Orients existing edges of {@code graph} according to background knowledge.
     *
     * <p>For a forbidden edge from -> to that is present, the edge is re-oriented so that the
     * arrowhead sits at {@code from}. For a required edge from -> to that is present, the
     * endpoints are set so the arrowhead sits at {@code to}. Knowledge edges naming unknown
     * nodes, or whose nodes are not adjacent, are skipped.
     *
     * @param bk background knowledge
     * @param graph graph to modify in place
     * @param nodes nodes used to resolve the knowledge's node names
     */
    public void pcOrientbk(Knowledge bk, Graph graph, List<Node> nodes) {
        Node to;
        Node from;
        KnowledgeEdge edge;
        TetradLogger.getInstance().log("info", "Starting BK Orientation.");
        Iterator<KnowledgeEdge> it = bk.forbiddenEdgesIterator();
        while (it.hasNext()) {
            edge = it.next();
            from = this.translate(edge.getFrom(), nodes);
            to = this.translate(edge.getTo(), nodes);
            if (from == null || to == null || graph.getEdge(from, to) == null) continue;
            // Replace the edge, then flip both endpoints: tail at `to`, arrow at `from`.
            graph.removeEdge(from, to);
            graph.addDirectedEdge(from, to);
            graph.setEndpoint(from, to, Endpoint.TAIL);
            graph.setEndpoint(to, from, Endpoint.ARROW);
            TetradLogger.getInstance().log("impliedOrientation", SearchLogUtils.edgeOrientedMsg("Knowledge", graph.getEdge(to, from)));
        }
        it = bk.requiredEdgesIterator();
        while (it.hasNext()) {
            edge = it.next();
            from = this.translate(edge.getFrom(), nodes);
            to = this.translate(edge.getTo(), nodes);
            if (from == null || to == null || graph.getEdge(from, to) == null) continue;
            graph.setEndpoint(to, from, Endpoint.TAIL);
            graph.setEndpoint(from, to, Endpoint.ARROW);
            TetradLogger.getInstance().log("impliedOrientation", SearchLogUtils.edgeOrientedMsg("Knowledge", graph.getEdge(from, to)));
        }
        TetradLogger.getInstance().log("info", "Finishing BK Orientation.");
    }

    /**
     * Orients unshielded colliders b *-> a <-* c whenever a is absent from sepset(b, c) and
     * knowledge allows both arrowheads, and records diagnostics against the true graph.
     *
     * @param set sepsets for non-adjacent pairs
     * @param knowledge background knowledge (may be null, meaning no restrictions)
     * @param graph graph to orient in place
     * @throws IllegalArgumentException if a non-adjacent pair has no recorded sepset
     */
    public void orientCollidersUsingSepsets(SepsetMap set, Knowledge knowledge, Graph graph) {
        TetradLogger.getInstance().log("info", "Starting Collider Orientation:");
        Graph trueGraph = this.getTrueGraph();
        List<Node> nodes = graph.getNodes();
        for (Node a : nodes) {
            List<Node> adjacentNodes = graph.getAdjacentNodes(a);
            if (adjacentNodes.size() < 2) {
                continue;
            }
            ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2);
            int[] combination;
            while ((combination = cg.next()) != null) {
                Node b = adjacentNodes.get(combination[0]);
                Node c = adjacentNodes.get(combination[1]);
                // Only unshielded triples b - a - c are collider candidates.
                if (graph.isAdjacentTo(b, c)) {
                    continue;
                }
                Node trueA = trueGraph.getNode(a.getName());
                Node trueB = trueGraph.getNode(b.getName());
                Node trueC = trueGraph.getNode(c.getName());
                List<Node> sepset = set.get(b, c);
                if (sepset == null) {
                    throw new IllegalArgumentException("No sepset recorded for non-adjacent pair " + b + ", " + c + ".");
                }
                if (!sepset.contains(a) && this.isArrowpointAllowed(b, a, knowledge) && this.isArrowpointAllowed(c, a, knowledge)) {
                    graph.setEndpoint(b, a, Endpoint.ARROW);
                    graph.setEndpoint(c, a, Endpoint.ARROW);
                    TetradLogger.getInstance().log("info", SearchLogUtils.colliderOrientedMsg(b, a, c, sepset));
                    // A collider was oriented although b-a or c-a is missing from the true graph.
                    if (!trueGraph.isAdjacentTo(trueB, trueA) || !trueGraph.isAdjacentTo(trueC, trueA)) {
                        ++this.cefp;
                    }
                    // b and c were judged non-adjacent but are adjacent in the true graph.
                    if (trueGraph.isAdjacentTo(trueB, trueC)) {
                        ++this.cefn;
                    }
                    // NOTE(review): as written, collfp increases here when the true graph DOES
                    // have the collider, and in the else-branch below when it does NOT, i.e.
                    // whenever the decision agrees with the true graph. Confirm that this is the
                    // intended meaning of a "false positive" count.
                    if (PcSearchRsch.isTrueCollider(trueGraph, trueB, trueA, trueC)) {
                        ++this.collfp;
                    }
                    this.printSubsetMessage(graph, b, a, c, sepset);
                    List<Node> trueS = new LinkedList<Node>();
                    for (Node s : sepset) {
                        trueS.add(trueGraph.getNode(s.getName()));
                    }
                    // The recorded sepset does not d-separate b and c in the true graph.
                    if (!trueGraph.isDSeparatedFrom(trueB, trueC, trueS)) {
                        ++this.cindfp;
                    }
                } else if (!PcSearchRsch.isTrueCollider(trueGraph, trueB, trueA, trueC)) {
                    ++this.collfp;
                }
            }
        }
        TetradLogger.getInstance().log("info", "Finishing Collider Orientation.");
    }

    /**
     * Returns whether {@code trueGraph.getEndpoint} reports ARROW for both
     * (trueB, trueA) and (trueC, trueA).
     */
    private static boolean isTrueCollider(Graph trueGraph, Node trueB, Node trueA, Node trueC) {
        return trueGraph.getEndpoint(trueB, trueA) == Endpoint.ARROW
                && trueGraph.getEndpoint(trueC, trueA) == Endpoint.ARROW;
    }

    /**
     * Returns false if x and z are adjacent in the search graph; otherwise true unless a local
     * sepset containing y is found with the independence test.
     */
    private boolean isCollider(Node x, Node y, Node z) {
        if (this.graph.isAdjacentTo(x, z)) {
            return false;
        }
        return !this.existsLocalSepsetWith(x, y, z, this.independenceTest, this.graph, this.depth);
    }

    /**
     * Same as {@link #isCollider(Node, Node, Node)} but decides independence by d-separation
     * in the true graph.
     */
    private boolean isGraphicalCollider(Node x, Node y, Node z) {
        if (this.graph.isAdjacentTo(x, z)) {
            return false;
        }
        return !this.existsLocalSepsetWithGraphical(x, y, z, this.depth);
    }

    /**
     * Searches for a conditioning set that contains {@code y} and makes x and z independent
     * according to {@code test}. Candidates are the true-graph neighbours of x and z, excluding
     * x and z, mapped to the nodes of {@code graph} by name. Resets and updates the test counter.
     *
     * @param depth maximum conditioning-set size; -1 means unlimited
     * @return true as soon as such a set is found
     */
    public boolean existsLocalSepsetWith(Node x, Node y, Node z, IndependenceTest test, Graph graph, int depth) {
        this.numTests = 0;
        Node trueX = this.getTrueGraph().getNode(x.getName());
        Node trueZ = this.getTrueGraph().getNode(z.getName());
        HashSet<Node> __nodes = new HashSet<Node>(this.trueGraph.getAdjacentNodes(trueX));
        __nodes.addAll(this.trueGraph.getAdjacentNodes(trueZ));
        __nodes.remove(trueX);
        __nodes.remove(trueZ);
        // NOTE(review): graph.getNode may return null for a true-graph node absent from `graph`.
        LinkedList<Node> _nodes = new LinkedList<Node>();
        for (Node node : __nodes) {
            _nodes.add(graph.getNode(node.getName()));
        }
        TetradLogger.getInstance().log("details", "Adjacents for " + x + "--" + y + "--" + z + " = " + _nodes);
        int _depth = depth;
        if (_depth == -1) {
            _depth = Integer.MAX_VALUE;
        }
        _depth = Math.min(_depth, _nodes.size());
        for (int d = 1; d <= _depth; ++d) {
            ChoiceGenerator cg2 = new ChoiceGenerator(_nodes.size(), d);
            int[] choice;
            while ((choice = cg2.next()) != null) {
                List<Node> condSet = PcSearchRsch.asList(choice, _nodes);
                if (!condSet.contains(y)) continue;
                boolean independent = test.isIndependent(x, z, condSet);
                ++this.numTests;
                if (independent) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Graphical counterpart of
     * {@link #existsLocalSepsetWith(Node, Node, Node, IndependenceTest, Graph, int)}: searches
     * for a conditioning set containing y that d-separates x and z in the true graph.
     *
     * <p>The candidate nodes are the distinct true-graph neighbours of x and z, excluding x and
     * z themselves, which is the same candidate set used by existsLocalSepsetWith.
     *
     * @param depth maximum conditioning-set size; -1 means unlimited
     * @return true as soon as such a set is found
     */
    public boolean existsLocalSepsetWithGraphical(Node x, Node y, Node z, int depth) {
        this.numTests = 0;
        Node trueX = this.trueGraph.getNode(x.getName());
        Node trueY = this.trueGraph.getNode(y.getName());
        Node trueZ = this.trueGraph.getNode(z.getName());
        // De-duplicate common neighbours and drop the endpoints; a plain concatenation would
        // repeat shared neighbours and could put x or z into their own conditioning set.
        HashSet<Node> candidates = new HashSet<Node>(this.trueGraph.getAdjacentNodes(trueX));
        candidates.addAll(this.trueGraph.getAdjacentNodes(trueZ));
        candidates.remove(trueX);
        candidates.remove(trueZ);
        List<Node> _nodes = new LinkedList<Node>(candidates);
        TetradLogger.getInstance().log("details", "Adjacents for " + x + "--" + y + "--" + z + " = " + _nodes);
        int _depth = depth;
        if (_depth == -1) {
            _depth = Integer.MAX_VALUE;
        }
        _depth = Math.min(_depth, _nodes.size());
        for (int d = 1; d <= _depth; ++d) {
            ChoiceGenerator cg2 = new ChoiceGenerator(_nodes.size(), d);
            int[] choice;
            while ((choice = cg2.next()) != null) {
                List<Node> condSet = PcSearchRsch.asList(choice, _nodes);
                if (!condSet.contains(trueY)) continue;
                boolean independent = this.trueGraph.isDSeparatedFrom(trueX, trueZ, condSet);
                ++this.numTests;
                if (independent) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Returns a new list holding {@code nodes.get(i)} for each index, in order. */
    public static List<Node> asList(int[] indices, List<Node> nodes) {
        LinkedList<Node> list = new LinkedList<Node>();
        for (int i : indices) {
            list.add(nodes.get(i));
        }
        return list;
    }

    /**
     * Prints a tab-separated diagnostic line for the triple b - a - c to standard output. The
     * line holds the triple drawn with its true-graph orientation, the sepset, and T/F flags for
     * the following, in order:
     * <ul>
     *   <li>unshielded in the true graph;</li>
     *   <li>marginally d-separated in the true graph;</li>
     *   <li>local collider by test;</li>
     *   <li>local collider by d-separation.</li>
     * </ul>
     */
    private void printSubsetMessage(Graph graph, Node b, Node a, Node c, List<Node> sepset) {
        Node trueA = this.trueGraph.getNode(a.getName());
        Node trueB = this.trueGraph.getNode(b.getName());
        Node trueC = this.trueGraph.getNode(c.getName());
        StringBuilder triple = new StringBuilder();
        triple.append(b);
        if (this.trueGraph.isAdjacentTo(trueA, trueB)) {
            if (this.trueGraph.isDirectedFromTo(trueA, trueB)) {
                triple.append("<--");
            } else {
                triple.append("-->");
            }
        } else {
            triple.append("   ");
        }
        triple.append(a);
        if (this.trueGraph.isAdjacentTo(trueA, trueC)) {
            if (this.trueGraph.isDirectedFromTo(trueA, trueC)) {
                triple.append("-->");
            } else {
                triple.append("<--");
            }
        } else {
            triple.append("   ");
        }
        triple.append(c);
        boolean unshielded = !this.trueGraph.isAdjacentTo(trueB, trueC);
        boolean dsep = this.trueGraph.isDSeparatedFrom(trueB, trueC, new LinkedList<Node>());
        boolean localCol = this.isCollider(b, a, c);
        boolean graphicalCol = this.isGraphicalCollider(b, a, c);
        System.out.println(triple + "\t" + sepset + "\t" + (unshielded ? "T" : "F") + "\t" + (dsep ? "T" : "F") + "\t" + (localCol ? "T" : "F") + "\t" + (graphicalCol ? "T" : "F"));
    }

    /** Returns the first node whose name equals {@code a}, or null if none matches. */
    public Node translate(String a, List<Node> nodes) {
        for (Node node : nodes) {
            if (node.getName().equals(a)) {
                return node;
            }
        }
        return null;
    }

    /**
     * Returns whether an arrowhead may be placed at {@code to} on the edge from {@code from}.
     * It returns true when knowledge is null. Otherwise it returns false if the edge to -> from
     * is required or the edge from -> to is forbidden.
     */
    public boolean isArrowpointAllowed(Object from, Object to, Knowledge knowledge) {
        if (knowledge == null) {
            return true;
        }
        return !knowledge.edgeRequired(to.toString(), from.toString()) && !knowledge.edgeForbidden(from.toString(), to.toString());
    }

    /** Oriented colliders for which b-a or c-a is missing from the true graph. */
    public int getCefp() {
        return this.cefp;
    }

    /** Oriented colliders whose outer nodes b and c are adjacent in the true graph. */
    public int getCefn() {
        return this.cefn;
    }

    /** Oriented colliders whose sepset does not d-separate b and c in the true graph. */
    public int getCindfp() {
        return this.cindfp;
    }

    /** Collider-decision counter; see the NOTE(review) in orientCollidersUsingSepsets. */
    public int getCollfp() {
        return this.collfp;
    }

    /** Returns the number of bidirected edges in the graph from the last search. */
    public int getBide() {
        int numBidirected = 0;
        for (Edge edge : this.graph.getEdges()) {
            if (Edges.isBidirectedEdge(edge)) {
                ++numBidirected;
            }
        }
        return numBidirected;
    }

    /** Returns the number of edges in the graph from the last search. */
    public int getAlle() {
        return this.graph.getNumEdges();
    }

    /** Sets the reference graph used for diagnostics and builds a d-separation test over it. */
    public void setTrueGraph(Dag trueGraph) {
        this.trueGraph = trueGraph;
        this.graphicalTest = new IndTestDSep(trueGraph);
    }

    public Graph getTrueGraph() {
        return this.trueGraph;
    }
}

