/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.Pattern;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.MeekRules;
import edu.cmu.tetrad.search.PatternToDag;
import edu.cmu.tetrad.search.SearchGraphUtils;
import edu.cmu.tetrad.util.ChoiceGenerator;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

/**
 * Static helpers for trimming a graph down to the neighborhood of a target node and for
 * enumerating orientations of the unoriented edges of a pattern (optionally of a pattern that
 * is re-trimmed around a target at every step).
 *
 * <p>All {@code trim*} methods modify the graph passed to them in place.
 */
public class MbUtils {

    /**
     * Removes from {@code graph} every node except {@code target}, a selection of its adjacents,
     * and a selection of the neighbors of its children.
     *
     * <p>With {@code includeBidirected == true}: every adjacent of {@code target} is kept; in
     * addition, for each {@code v} in {@code graph.getNodesOutTo(target, ARROW)}, a neighbor
     * {@code w} of {@code v} is kept when {@code graph.isDefCollider(target, v, w)} holds, or when
     * {@code target} is in {@code graph.getNodesInTo(v, ARROW)} and
     * {@code graph.isUndirectedFromTo(v, w)} holds.
     *
     * <p>With {@code includeBidirected == false}: only adjacents joined to {@code target} by a
     * directed edge (either direction) or an undirected edge are kept; in addition, for each
     * {@code v} with {@code target -> v}, every neighbor {@code w} with {@code w -> v} is kept.
     *
     * @param graph             the graph to trim in place
     * @param target            the node whose surroundings are retained
     * @param includeBidirected selects which of the two retention rules above is applied
     */
    public static void trimToMbNodes(Graph graph, Node target, boolean includeBidirected) {
        Set<Node> keep = includeBidirected
                ? mbNodesIncludingBidirected(graph, target)
                : mbNodesDirectedOnly(graph, target);
        removeAllExcept(graph, keep);
    }

    /** Computes the retained node set for {@code trimToMbNodes(..., true)}. */
    private static Set<Node> mbNodesIncludingBidirected(Graph graph, Node target) {
        // The target and all of its adjacents are always kept.
        Set<Node> keep = new HashSet<Node>(graph.getAdjacentNodes(target));
        keep.add(target);

        for (Node child : graph.getNodesOutTo(target, Endpoint.ARROW)) {
            // Does not depend on the inner-loop node, so evaluate it once per child.
            boolean targetPointsIntoChild = graph.getNodesInTo(child, Endpoint.ARROW).contains(target);
            for (Node w : graph.getAdjacentNodes(child)) {
                if (keep.contains(w)) {
                    continue; // target, an adjacent of target, or already selected
                }
                if (graph.isDefCollider(target, child, w)
                        || (targetPointsIntoChild && graph.isUndirectedFromTo(child, w))) {
                    keep.add(w);
                }
            }
        }
        return keep;
    }

    /** Computes the retained node set for {@code trimToMbNodes(..., false)}. */
    private static Set<Node> mbNodesDirectedOnly(Graph graph, Node target) {
        Set<Node> keep = new HashSet<Node>();
        keep.add(target);
        List<Node> children = new ArrayList<Node>();

        for (Node node : graph.getAdjacentNodes(target)) {
            if (graph.isDirectedFromTo(target, node)) {
                keep.add(node);
                children.add(node);
            } else if (graph.isDirectedFromTo(node, target) || graph.isUndirectedFromTo(node, target)) {
                keep.add(node);
            }
        }

        // Keep every w with w -> child for some child of target.
        for (Node child : children) {
            for (Node w : graph.getAdjacentNodes(child)) {
                if (!keep.contains(w) && graph.isDirectedFromTo(w, child)) {
                    keep.add(w);
                }
            }
        }
        return keep;
    }

    /**
     * Removes every edge between each pair of nodes in {@code graph.getParents(target)}.
     *
     * @param graph  the graph to modify in place
     * @param target the node whose parents are pairwise disconnected
     */
    public static void trimEdgesAmongParents(Graph graph, Node target) {
        List<Node> parents = graph.getParents(target);
        if (parents.size() < 2) {
            return;
        }
        ChoiceGenerator cg = new ChoiceGenerator(parents.size(), 2);
        int[] choice;
        while ((choice = cg.next()) != null) {
            Node v = parents.get(choice[0]);
            Node w = parents.get(choice[1]);
            if (graph.getEdge(v, w) != null) {
                // NOTE(review): this uses removeEdges while trimEdgesAmongParentsOfChildren uses
                // removeEdge; confirm whether the difference is intentional.
                graph.removeEdges(v, w);
            }
        }
    }

    /**
     * Removes the edge between each pair of parents of {@code target}'s children, excluding
     * {@code target} itself and any node adjacent to {@code target}.
     *
     * <p>"Children" here are the nodes returned by {@code graph.getNodesOutTo(target, ARROW)}.
     *
     * @param graph  the graph to modify in place
     * @param target the node whose children's other parents are pairwise disconnected
     */
    public static void trimEdgesAmongParentsOfChildren(Graph graph, Node target) {
        Set<Node> parents = new HashSet<Node>();
        for (Node child : graph.getNodesOutTo(target, Endpoint.ARROW)) {
            parents.addAll(graph.getParents(child));
        }
        parents.remove(target);
        parents.removeAll(graph.getAdjacentNodes(target));

        List<Node> parentsOfChildren = new ArrayList<Node>(parents);
        if (parentsOfChildren.size() < 2) {
            return;
        }
        ChoiceGenerator cg = new ChoiceGenerator(parentsOfChildren.size(), 2);
        int[] choice;
        while ((choice = cg.next()) != null) {
            Node v = parentsOfChildren.get(choice[0]);
            Node w = parentsOfChildren.get(choice[1]);
            if (graph.getEdge(v, w) != null) {
                graph.removeEdge(v, w);
            }
        }
    }

    /**
     * Removes every node other than {@code target} that is not adjacent to {@code target}.
     *
     * <p>Nodes to remove are collected first and removed afterwards, so correctness does not
     * depend on whether {@code graph.getNodes()} returns a copy or a live view.
     *
     * @param graph  the graph to trim in place
     * @param target the node whose adjacents are retained
     */
    public static void trimToAdjacents(Graph graph, Node target) {
        List<Node> toRemove = new ArrayList<Node>();
        for (Node node : graph.getNodes()) {
            if (node != target && !graph.isAdjacentTo(node, target)) {
                toRemove.add(node);
            }
        }
        for (Node node : toRemove) {
            graph.removeNode(node);
        }
    }

    /**
     * Removes every node that is not in {@code neighborhood}.
     *
     * @param graph        the graph to trim in place
     * @param neighborhood the nodes to retain
     */
    public static void trimToNeighborhood(Graph graph, List<Node> neighborhood) {
        removeAllExcept(graph, new HashSet<Node>(neighborhood));
    }

    /**
     * Removes every node farther than {@code distance} adjacency steps from {@code target}.
     *
     * @param graph    the graph to trim in place
     * @param target   the center node
     * @param distance maximum number of adjacency steps; must be at least 1
     * @throws IllegalArgumentException if {@code distance < 1}
     */
    public static void trimToDistance(Graph graph, Node target, int distance) {
        removeAllExcept(graph, MbUtils.getNeighborhood(graph, target, distance));
    }

    /**
     * Returns {@code target} together with every node reachable from it in at most
     * {@code distance} adjacency steps, ignoring edge direction.
     *
     * @param graph    the graph to search
     * @param target   the start node (always included in the result)
     * @param distance maximum number of adjacency steps; must be at least 1
     * @return a new mutable set of the nodes found
     * @throws IllegalArgumentException if {@code distance < 1}
     */
    public static Set<Node> getNeighborhood(Graph graph, Node target, int distance) {
        if (distance < 1) {
            throw new IllegalArgumentException("Distance must be >= 1.");
        }
        Set<Node> nodes = new HashSet<Node>();
        nodes.add(target);

        // Breadth-first search: only newly reached nodes are expanded in the next round, so each
        // node's adjacency list is fetched at most once.
        Set<Node> frontier = new HashSet<Node>(nodes);
        for (int i = 0; i < distance && !frontier.isEmpty(); ++i) {
            Set<Node> next = new HashSet<Node>();
            for (Node node : frontier) {
                for (Node adjacent : graph.getAdjacentNodes(node)) {
                    if (nodes.add(adjacent)) {
                        next.add(adjacent);
                    }
                }
            }
            frontier = next;
        }
        return nodes;
    }

    /**
     * Enumerates graphs obtained by orienting the undirected edges of {@code pattern} (and, if
     * requested, its bidirected edges) in both directions, applying {@link MeekRules} before
     * each choice. {@code pattern} itself is not modified.
     *
     * @param pattern               the pattern to expand
     * @param orientBidirectedEdges whether bidirected edges are also branched on
     * @return the distinct fully-expanded graphs; the number may be exponential in the number of
     *         edges branched on
     */
    public static List<Graph> generatePatternDags(Graph pattern, boolean orientBidirectedEdges) {
        return new LinkedList<Graph>(MbUtils.listPatternDags(new EdgeListGraph(pattern), orientBidirectedEdges));
    }

    /** Recursive worker for {@link #generatePatternDags}; works on a private copy of its input. */
    private static Set<Graph> listPatternDags(Graph mbPattern, boolean orientBidirectedEdges) {
        // Deduplication relies on Graph.equals/hashCode (assumed structural — TODO confirm).
        Set<Graph> dags = new HashSet<Graph>();
        EdgeListGraph graph = new EdgeListGraph(mbPattern);
        new MeekRules().orientImplied(graph);

        Edge edge = findEdgeToOrient(graph, orientBidirectedEdges);
        if (edge == null) {
            dags.add(graph);
            return dags;
        }
        Node node1 = edge.getNode1();
        Node node2 = edge.getNode2();

        orient(graph, node1, node2);
        dags.addAll(MbUtils.listPatternDags(graph, orientBidirectedEdges));
        orient(graph, node2, node1);
        dags.addAll(MbUtils.listPatternDags(graph, orientBidirectedEdges));
        return dags;
    }

    /**
     * Like {@link #generatePatternDags}, but before each branching step it runs local Meek
     * orientation (using {@code test} and {@code depth}) and trims the graph around
     * {@code target}. {@code mbPattern} itself is not modified.
     *
     * @param mbPattern             the starting pattern
     * @param orientBidirectedEdges whether bidirected edges are also branched on
     * @param test                  independence test passed to the local orientation step
     * @param depth                 depth passed to the local orientation step
     * @param target                the node the graph is trimmed around
     * @return the distinct fully-expanded graphs
     */
    public static List<Graph> generateMbDags(Graph mbPattern, boolean orientBidirectedEdges, IndependenceTest test, int depth, Node target) {
        return new LinkedList<Graph>(MbUtils.listMbDags(new EdgeListGraph(mbPattern), orientBidirectedEdges, test, depth, target));
    }

    /** Recursive worker for {@link #generateMbDags}; works on a private copy of its input. */
    private static Set<Graph> listMbDags(Graph mbPattern, boolean orientBidirectedEdges, IndependenceTest test, int depth, Node target) {
        Set<Graph> dags = new HashSet<Graph>();
        EdgeListGraph graph = new EdgeListGraph(mbPattern);
        MbUtils.doAbbreviatedMbOrientation(graph, test, depth, target);

        Edge edge = findEdgeToOrient(graph, orientBidirectedEdges);
        if (edge == null) {
            dags.add(graph);
            return dags;
        }
        Node node1 = edge.getNode1();
        Node node2 = edge.getNode2();

        orient(graph, node1, node2);
        dags.addAll(MbUtils.listMbDags(graph, orientBidirectedEdges, test, depth, target));
        orient(graph, node2, node1);
        dags.addAll(MbUtils.listMbDags(graph, orientBidirectedEdges, test, depth, target));
        return dags;
    }

    /**
     * Returns the first edge (in {@code graph.getEdges()} order) that is undirected, or, when
     * {@code orientBidirectedEdges} is set, bidirected; {@code null} if there is none.
     */
    private static Edge findEdgeToOrient(Graph graph, boolean orientBidirectedEdges) {
        for (Edge candidate : graph.getEdges()) {
            if ((orientBidirectedEdges && Edges.isBidirectedEdge(candidate))
                    || Edges.isUndirectedEdge(candidate)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Sets a tail at {@code from} and an arrowhead at {@code to} on the edge between them
     * (presumably {@code setEndpoint(a, b, e)} sets the endpoint at {@code b} — TODO confirm).
     */
    private static void orient(Graph graph, Node from, Node to) {
        graph.setEndpoint(to, from, Endpoint.TAIL);
        graph.setEndpoint(from, to, Endpoint.ARROW);
    }

    /**
     * Returns the DAG produced by {@code PatternToDag.patternToDagMeek()} for
     * {@code new Pattern(mbPattern)}.
     *
     * @param mbPattern the pattern to convert
     * @return one DAG consistent with the pattern, as computed by {@link PatternToDag}
     */
    public static Dag getOneMbDag(Graph mbPattern) {
        PatternToDag search = new PatternToDag(new Pattern(mbPattern));
        return search.patternToDagMeek();
    }

    /**
     * Applies local Meek orientation with empty background knowledge, then trims {@code graph}
     * to the directed-only neighborhood of {@code target} and removes edges among its parents and
     * among its children's other parents.
     */
    private static void doAbbreviatedMbOrientation(Graph graph, IndependenceTest test, int depth, Node target) {
        SearchGraphUtils.orientUsingMeekRulesLocally(new Knowledge(), graph, test, depth);
        MbUtils.trimToMbNodes(graph, target, false);
        MbUtils.trimEdgesAmongParents(graph, target);
        MbUtils.trimEdgesAmongParentsOfChildren(graph, target);
    }

    /**
     * Removes every node of {@code graph} not contained in {@code keep}. It works on a copy of
     * {@code getNodes()}, so the graph's own node list is never modified directly.
     */
    private static void removeAllExcept(Graph graph, Set<Node> keep) {
        List<Node> toRemove = new ArrayList<Node>(graph.getNodes());
        toRemove.removeAll(keep);
        graph.removeNodes(toRemove);
    }
}

