/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.Triple;
import edu.cmu.tetrad.search.IGraphSearch;
import edu.cmu.tetrad.search.IMbSearch;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.utils.GraphSearchUtils;
import edu.cmu.tetrad.search.utils.MeekRules;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.SublistGenerator;
import edu.cmu.tetrad.util.TetradLogger;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import org.apache.commons.math3.util.FastMath;

/**
 * PC-MB: a local, PC-style search that builds the adjacencies of one or more target variables and,
 * when {@link #setFindMb(boolean)} is enabled, extends and trims the result to
 * {T} U PC U {Parents(Children(T))}.
 *
 * <p>The search grows "fans" of associated variables around nodes. It removes edges whose endpoints
 * test independent given subsets of possible parents, up to {@link #getDepth()} in size. It then
 * orients unshielded colliders from separating-set membership and finally applies Meek rules.
 *
 * <p>Per-search state is kept in mutable instance fields without synchronization.
 */
public final class PcMb implements IMbSearch, IGraphSearch {
    /** Number of conditioning depths for which maximum remaining adjacency counts are tracked. */
    private static final int TRACKED_DEPTHS = 20;
    /** Flag set by {@link #classifySepsets} when a separating set containing the middle node is found. */
    private static final int SEPSET_WITH_Y = 1;
    /** Flag set by {@link #classifySepsets} when a separating set without the middle node is found. */
    private static final int SEPSET_WITHOUT_Y = 2;

    /** Conditional independence test used for every adjacency and orientation decision. */
    private final IndependenceTest test;
    private final TetradLogger logger = TetradLogger.getInstance();
    /** Variables considered as candidate associates when a fan is constructed. */
    private List<Node> variables;
    /** Targets of the most recent {@link #search(List)} call; {@code null} until one is made. */
    private List<Node> targets;
    /** Maximum conditioning-set size used when pruning edges. */
    private int depth;
    /** Graph produced by the most recently finished search. */
    private Graph resultGraph;
    private int numIndependenceTests;
    /** Largest adjacency count left after pruning at each depth; -1 where nothing was recorded. */
    private int[] maxRemainingAtDepth;
    /** Nodes whose fans have been constructed in the current search; this set only grows. */
    private Set<Node> a;
    private long elapsedTime;
    private Knowledge knowledge = new Knowledge();
    /** Unshielded triples classified as ambiguous; empty (never null) before the first search. */
    private Set<Triple> ambiguousTriples = new HashSet<>();
    private boolean meekPreventCycles;
    private boolean findMb = false;

    /**
     * Creates a PC-MB search.
     *
     * @param test  the conditional independence test to use; must not be null
     * @param depth the maximum conditioning-set size, or -1 for no limit
     * @throws NullPointerException     if {@code test} is null
     * @throws IllegalArgumentException if {@code depth < -1}
     */
    public PcMb(IndependenceTest test, int depth) {
        if (test == null) {
            throw new NullPointerException("Independence test must not be null");
        }
        if (depth < -1) {
            throw new IllegalArgumentException("Depth must be >= -1: " + depth);
        }
        this.test = test;
        this.depth = depth == -1 ? Integer.MAX_VALUE : depth;
        this.variables = test.getVariables();
    }

    /**
     * Returns whether an arrowhead may be placed at {@code to} on the edge {@code from -> to}.
     * It is disallowed when knowledge requires {@code to -> from} or forbids {@code from -> to}.
     * A null knowledge object allows everything.
     */
    private static boolean isArrowheadAllowed1(Node from, Node to, Knowledge knowledge) {
        if (knowledge == null) {
            return true;
        }
        return !knowledge.isRequired(to.toString(), from.toString())
                && !knowledge.isForbidden(from.toString(), to.toString());
    }

    /**
     * Sets the flag forwarded to {@link MeekRules#setMeekPreventCycles(boolean)} during orientation.
     *
     * @param meekPreventCycles the flag value
     */
    public void setMeekPreventCycles(boolean meekPreventCycles) {
        this.meekPreventCycles = meekPreventCycles;
    }

    /**
     * Runs the local search around the given targets.
     *
     * <p>The search proceeds in these steps:
     * <ol>
     *   <li>Construct each target's fan.</li>
     *   <li>With findMb, prune around each adjacent of each target.</li>
     *   <li>With findMb, construct fans for adjacents-of-adjacents not yet processed.</li>
     *   <li>Orient the graph with background knowledge, colliders and Meek rules.</li>
     *   <li>Trim the graph: with findMb, keep only the targets and nodes adjacent to a target or
     *       parents of a target's child; otherwise keep only edges touching a target.</li>
     * </ol>
     *
     * @param targets the target nodes; must not be null and must not contain null
     * @return the resulting graph, which is also stored for {@link #resultGraph()}
     * @throws IllegalArgumentException if {@code targets} is null
     * @throws NullPointerException     if any target is null
     */
    public Graph search(List<Node> targets) {
        long start = MillisecondTimes.timeMillis();
        this.numIndependenceTests = 0;
        this.ambiguousTriples = new HashSet<>();
        if (targets == null) {
            throw new IllegalArgumentException("Null targets name not permitted");
        }
        this.targets = targets;
        this.logger.log("info", "Target = " + targets);
        this.maxRemainingAtDepth = new int[TRACKED_DEPTHS];
        Arrays.fill(this.maxRemainingAtDepth, -1);
        this.logger.log("info", "targets = " + this.getTargets());
        EdgeListGraph graph = new EdgeListGraph();
        this.a = new HashSet<>();

        this.logger.log("info", "BEGINNING step 1 (prune targets).");
        for (Node target : targets) {
            if (target == null) {
                throw new NullPointerException("Target not specified");
            }
            graph.addNode(target);
            this.constructFan(target, graph);
            this.logger.log("graph", "After step 1 (prune targets)" + graph);
        }

        this.logger.log("info", "BEGINNING step 2 (prune PC).");
        if (this.findMb) {
            for (Node target : targets) {
                // Iterate over snapshots: constructFan/prune add and remove edges and nodes.
                for (Node node : new ArrayList<>(graph.getAdjacentNodes(target))) {
                    if (Thread.currentThread().isInterrupted()) {
                        break;
                    }
                    this.constructFan(node, graph);
                    this.pruneAroundAdjacent(target, node, graph);
                }
                this.logger.log("graph", "After step 2 (prune PC)" + graph);
                this.logger.log("info", "BEGINNING step 3 (prune PCPC).");
                for (Node node : new ArrayList<>(graph.getAdjacentNodes(target))) {
                    for (Node w : new ArrayList<>(graph.getAdjacentNodes(node))) {
                        if (!this.a.contains(w)) {
                            this.constructFan(w, graph);
                        }
                    }
                }
            }
        }

        this.logger.log("graph", "After step 3 (prune PCPC)" + graph);
        this.logger.log("info", "BEGINNING step 4 (PC Orient).");
        GraphSearchUtils.pcOrientbk(this.knowledge, graph, graph.getNodes());
        this.orientUnshieldedTriples(this.knowledge, graph, this.getDepth(), new ArrayList<>(this.a));
        this.applyMeekRules(graph);
        this.logger.log("graph", "After step 4 (PC Orient)" + graph);

        this.logger.log("info", "BEGINNING step 5 (Trim graph to {T} U PC U {Parents(Children(T))}).");
        if (this.findMb) {
            trimToMarkovBlanket(graph, targets);
        } else {
            removeEdgesNotTouching(graph, targets);
        }
        this.logger.log("graph", "After step 6 (Remove edges among P and P of C)" + graph);
        this.finishUp(start, graph);
        return graph;
    }

    /**
     * Runs the adjacency search for every variable of the test, then orients colliders and applies
     * Meek rules. Every test variable appears in the returned graph.
     *
     * <p>NOTE(review): unlike {@link #search(List)}, this path does not call
     * {@code GraphSearchUtils.pcOrientbk}. Confirm whether that is intended.
     *
     * @return the resulting graph, which is also stored for {@link #resultGraph()}
     */
    @Override
    public Graph search() {
        long start = MillisecondTimes.timeMillis();
        this.numIndependenceTests = 0;
        this.ambiguousTriples = new HashSet<>();
        this.maxRemainingAtDepth = new int[TRACKED_DEPTHS];
        Arrays.fill(this.maxRemainingAtDepth, -1);
        EdgeListGraph graph = new EdgeListGraph();
        this.a = new HashSet<>();
        this.variables = this.test.getVariables();
        for (Node node : this.variables) {
            if (!graph.containsNode(node)) {
                graph.addNode(node);
            }
            this.constructFan(node, graph);
        }
        // prune() may drop nodes left without edges; restore them so every variable is present.
        for (Node node : this.variables) {
            if (!graph.containsNode(node)) {
                graph.addNode(node);
            }
        }
        this.orientUnshieldedTriples(this.knowledge, graph, this.getDepth(), graph.getNodes());
        this.applyMeekRules(graph);
        this.finishUp(start, graph);
        return graph;
    }

    /**
     * Returns a copy of the triples classified as ambiguous by the most recent search.
     *
     * @return a fresh, mutable set; empty before any search has run
     */
    public Set<Triple> getAmbiguousTriples() {
        return new HashSet<>(this.ambiguousTriples);
    }

    /** Returns the number of independence tests run by the most recent search. */
    @Override
    public int getNumIndependenceTests() {
        return this.numIndependenceTests;
    }

    /** Returns the targets of the most recent {@link #search(List)} call, or null if none. */
    public List<Node> getTargets() {
        return this.targets;
    }

    /** Returns the wall-clock duration in milliseconds of the most recently finished search. */
    public long getElapsedTime() {
        return this.elapsedTime;
    }

    @Override
    public String getAlgorithmName() {
        return "PC-MB";
    }

    /** Returns the maximum conditioning-set size. */
    public int getDepth() {
        return this.depth;
    }

    /**
     * Sets the maximum conditioning-set size.
     *
     * <p>Any negative value is mapped to 1000. The constructor, by contrast, maps -1 to
     * {@code Integer.MAX_VALUE}.
     *
     * @param depth the new depth
     */
    public void setDepth(int depth) {
        this.depth = depth < 0 ? 1000 : depth;
    }

    /** Returns the graph produced by the most recently finished search, or null if none. */
    public Graph resultGraph() {
        return this.resultGraph;
    }

    /**
     * Runs {@link #search(List)} for a single target and returns the nodes of the resulting graph,
     * excluding the target itself.
     *
     * @param target the target node
     * @return the nodes left in the trimmed graph, minus the target
     */
    @Override
    public Set<Node> findMb(Node target) {
        Graph graph = this.search(Collections.singletonList(target));
        Set<Node> nodes = new HashSet<>(graph.getNodes());
        nodes.remove(target);
        return nodes;
    }

    /** Returns the independence test used by this search. */
    public IndependenceTest getTest() {
        return this.test;
    }

    /** Returns the background knowledge used by this search. */
    public Knowledge getKnowledge() {
        return this.knowledge;
    }

    /**
     * Stores a copy of the given background knowledge.
     *
     * @param knowledge the knowledge to copy
     */
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = new Knowledge(knowledge);
    }

    /** Adds {@code target}'s associates to the graph and then prunes its edges. */
    private void constructFan(Node target, Graph graph) {
        this.addAllowableAssociates(target, graph);
        this.prune(target, graph);
    }

    /**
     * Marks {@code v} as processed and connects it with an undirected edge to every variable that is
     * not already processed, whose edge with {@code v} is not forbidden, and that is marginally
     * dependent on {@code v}.
     */
    private void addAllowableAssociates(Node v, Graph graph) {
        this.a.add(v);
        int numAssociated = 0;
        for (Node w : this.variables) {
            // Check knowledge before running the (potentially expensive) independence test.
            if (this.a.contains(w) || this.edgeForbidden(v, w)) {
                continue;
            }
            if (this.independent(v, w, new HashSet<>())) {
                continue;
            }
            this.addEdge(graph, w, v);
            numAssociated++;
        }
        this.noteMaxAtDepth(0, numAssociated);
    }

    /**
     * Step-2 pruning for one adjacent {@code node} of {@code target}.
     *
     * <p>For each neighbour {@code w} of {@code node} that has not been processed and is adjacent to
     * at most one processed node, the edge {@code node -- w} is removed if {@code target} and
     * {@code w} test independent given some subset of {@code target}'s adjacents that contains
     * {@code node}. Returns early if the thread is interrupted.
     */
    private void pruneAroundAdjacent(Node target, Node node, Graph graph) {
        for (Node w : new ArrayList<>(graph.getAdjacentNodes(node))) {
            if (Thread.currentThread().isInterrupted()) {
                return;
            }
            if (this.a.contains(w)) {
                continue;
            }
            List<Node> processedNeighborsOfW = new ArrayList<>(this.a);
            processedNeighborsOfW.retainAll(graph.getAdjacentNodes(w));
            if (processedNeighborsOfW.size() > 1) {
                continue;
            }
            List<Node> adjT = new ArrayList<>(graph.getAdjacentNodes(target));
            SublistGenerator cg = new SublistGenerator(adjT.size(), this.depth);
            int[] choice;
            while ((choice = cg.next()) != null && !Thread.currentThread().isInterrupted()) {
                Set<Node> s = GraphUtils.asSet(choice, adjT);
                if (s.contains(node) && this.independent(target, w, s)) {
                    graph.removeEdge(node, w);
                    break;
                }
            }
        }
    }

    /** Prunes edges around {@code node} with conditioning sets of size 1, 2, ... up to the depth. */
    private void prune(Node node, Graph graph) {
        for (int d = 1; d <= this.depth; d++) {
            // No conditioning set of size d can be drawn from fewer than d adjacents.
            if (graph.getAdjacentNodes(node).size() < d) {
                return;
            }
            this.prune(node, graph, d);
        }
    }

    /**
     * Removes each non-required edge {@code node -- y} for which {@code node} and {@code y} test
     * independent given some size-{@code depth} subset of {@code node}'s other possible parents.
     * A neighbour left with no edges is removed from the graph unless it is a target.
     */
    private void prune(Node node, Graph graph, int depth) {
        this.logger.log("pruning", "Trying to remove edges adjacent to node " + node + ", depth = " + depth + ".");
        List<Node> adjacents = new ArrayList<>(graph.getAdjacentNodes(node));
        for (Node y : adjacents) {
            // A required edge is never removed, so skip its tests entirely (loop-invariant check).
            if (this.edgeRequired(node, y)) {
                continue;
            }
            List<Node> others = new ArrayList<>(graph.getAdjacentNodes(node));
            others.remove(y);
            List<Node> candidates = this.possibleParents(node, others);
            if (candidates.size() < depth) {
                continue;
            }
            ChoiceGenerator cg = new ChoiceGenerator(candidates.size(), depth);
            int[] choice;
            while ((choice = cg.next()) != null && !Thread.currentThread().isInterrupted()) {
                Set<Node> condSet = GraphUtils.asSet(choice, candidates);
                if (!this.independent(node, y, condSet)) {
                    continue;
                }
                graph.removeEdge(node, y);
                if (graph.getEdges(y).isEmpty() && !this.isTarget(y)) {
                    graph.removeNode(y);
                }
                break;
            }
        }
        this.noteMaxAtDepth(depth, graph.getAdjacentNodes(node).size());
    }

    /** Returns true if {@code node} is one of the current targets; false when no targets are set. */
    private boolean isTarget(Node node) {
        return this.targets != null && this.targets.contains(node);
    }

    /** Records elapsed time and the result graph, and logs summary statistics. */
    private void finishUp(long start, Graph graph) {
        long stop = MillisecondTimes.timeMillis();
        this.elapsedTime = stop - start;
        double seconds = this.elapsedTime / 1000.0;
        NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
        this.logger.log("info", "PC-MB took " + nf.format(seconds) + " seconds.");
        this.logger.log("info", "Number of independence tests performed = " + this.getNumIndependenceTests());
        this.resultGraph = graph;
    }

    /** Runs one independence test and increments the test counter. */
    private boolean independent(Node v, Node w, Set<Node> z) {
        boolean independent = this.test.checkIndependence(v, w, z).isIndependent();
        this.numIndependenceTests++;
        return independent;
    }

    /** Adds {@code w} to the graph if it is missing, then adds an undirected edge {@code v -- w}. */
    private void addEdge(Graph graph, Node w, Node v) {
        if (!graph.containsNode(w)) {
            graph.addNode(w);
        }
        graph.addUndirectedEdge(v, w);
    }

    /** Records {@code numAdjacents} as the maximum for {@code depth} if it exceeds the stored value. */
    private void noteMaxAtDepth(int depth, int numAdjacents) {
        if (depth < this.maxRemainingAtDepth.length && numAdjacents > this.maxRemainingAtDepth[depth]) {
            this.maxRemainingAtDepth[depth] = numAdjacents;
        }
    }

    /** Applies Meek rules to {@code graph} using this search's knowledge and cycle-prevention flag. */
    private void applyMeekRules(Graph graph) {
        MeekRules meekRules = new MeekRules();
        meekRules.setMeekPreventCycles(this.meekPreventCycles);
        meekRules.setKnowledge(this.knowledge);
        meekRules.orientImplied(graph);
    }

    /**
     * Keeps only the targets and the nodes that are adjacent to a target or are parents of a
     * target's child. Membership is computed on the whole graph before any node is removed.
     */
    private static void trimToMarkovBlanket(Graph graph, List<Node> targets) {
        Set<Node> keep = new HashSet<>();
        for (Node n : graph.getNodes()) {
            if (isMarkovBlanketMember(graph, targets, n)) {
                keep.add(n);
            }
        }
        for (Node n : new ArrayList<>(graph.getNodes())) {
            if (!targets.contains(n) && !keep.contains(n)) {
                graph.removeNode(n);
            }
        }
    }

    /** Returns true if {@code n} is adjacent to some target or is a parent of some target's child. */
    private static boolean isMarkovBlanketMember(Graph graph, List<Node> targets, Node n) {
        for (Node t : targets) {
            if (graph.isAdjacentTo(t, n)) {
                return true;
            }
            for (Node child : graph.getChildren(t)) {
                if (graph.isParentOf(n, child)) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Removes every edge that has no target as an endpoint. */
    private static void removeEdgesNotTouching(Graph graph, List<Node> targets) {
        for (Edge edge : new ArrayList<>(graph.getEdges())) {
            if (!targets.contains(edge.getNode1()) && !targets.contains(edge.getNode2())) {
                graph.removeEdge(edge);
            }
        }
    }

    /**
     * Classifies every unshielded triple {@code x -- y -- z} centred on the given nodes.
     *
     * <p>Colliders are oriented {@code x -> y <- z} when knowledge allows it. Ambiguous triples are
     * recorded on the graph and in {@link #getAmbiguousTriples()}.
     *
     * @param nodes middle nodes to examine; null means all graph nodes
     */
    private void orientUnshieldedTriples(Knowledge knowledge, Graph graph, int depth, List<Node> nodes) {
        this.logger.log("info", "Starting Collider Orientation:");
        this.ambiguousTriples = new HashSet<>();
        if (nodes == null) {
            nodes = graph.getNodes();
        }
        for (Node y : nodes) {
            // The processed set never shrinks, but prune() may have removed some of its nodes.
            if (!graph.containsNode(y)) {
                continue;
            }
            List<Node> adjacentNodes = new ArrayList<>(graph.getAdjacentNodes(y));
            if (adjacentNodes.size() < 2) {
                continue;
            }
            ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2);
            int[] combination;
            while ((combination = cg.next()) != null && !Thread.currentThread().isInterrupted()) {
                Node x = adjacentNodes.get(combination[0]);
                Node z = adjacentNodes.get(combination[1]);
                if (graph.isAdjacentTo(x, z)) {
                    continue;
                }
                TripleType type = this.getTripleType(graph, x, y, z, depth);
                if (type == TripleType.COLLIDER) {
                    if (this.colliderAllowed(x, y, z, knowledge)) {
                        graph.setEndpoint(x, y, Endpoint.ARROW);
                        graph.setEndpoint(z, y, Endpoint.ARROW);
                        this.logger.log("tripleClassifications", "Collider oriented: " + Triple.pathString(graph, x, y, z));
                    }
                } else if (type == TripleType.AMBIGUOUS) {
                    Triple triple = new Triple(x, y, z);
                    this.ambiguousTriples.add(triple);
                    graph.addAmbiguousTriple(triple.getX(), triple.getY(), triple.getZ());
                    this.logger.log("tripleClassifications", "tripleClassifications: " + Triple.pathString(graph, x, y, z));
                } else {
                    this.logger.log("tripleClassifications", "tripleClassifications: " + Triple.pathString(graph, x, y, z));
                }
            }
        }
        this.logger.log("info", "Finishing Collider Orientation.");
    }

    /**
     * Classifies the unshielded triple {@code x -- y -- z}.
     *
     * <p>Separating sets for {@code x} and {@code z} are searched among the adjacents of {@code x}
     * (excluding {@code z}) and then among the adjacents of {@code z} (excluding {@code x}). The
     * triple is a collider if separating sets are found only without {@code y}, a non-collider if
     * they are found only with {@code y}, and ambiguous if both or neither kind is found.
     */
    private TripleType getTripleType(Graph graph, Node x, Node y, Node z, int depth) {
        int found = this.classifySepsets(x, y, z, adjacentsExcept(graph, x, z), depth)
                | this.classifySepsets(x, y, z, adjacentsExcept(graph, z, x), depth);
        boolean withY = (found & SEPSET_WITH_Y) != 0;
        boolean withoutY = (found & SEPSET_WITHOUT_Y) != 0;
        if (withY == withoutY) {
            return TripleType.AMBIGUOUS;
        }
        return withoutY ? TripleType.COLLIDER : TripleType.NONCOLLIDER;
    }

    /** Returns the adjacents of {@code node}, minus {@code excluded}, as a list. */
    private static List<Node> adjacentsExcept(Graph graph, Node node, Node excluded) {
        Set<Node> adjacents = new HashSet<>(graph.getAdjacentNodes(node));
        adjacents.remove(excluded);
        return new ArrayList<>(adjacents);
    }

    /**
     * Tests {@code x} and {@code z} against every subset of {@code candidates} of size 0 up to
     * min(depth, |candidates|); depth -1 means no limit.
     *
     * @return a combination of {@link #SEPSET_WITH_Y} and {@link #SEPSET_WITHOUT_Y} describing which
     *         kinds of separating sets were found
     */
    private int classifySepsets(Node x, Node y, Node z, List<Node> candidates, int depth) {
        int maxSize = depth == -1 ? candidates.size() : FastMath.min(depth, candidates.size());
        int found = 0;
        for (int d = 0; d <= maxSize && !Thread.currentThread().isInterrupted(); d++) {
            ChoiceGenerator cg = new ChoiceGenerator(candidates.size(), d);
            int[] choice;
            while ((choice = cg.next()) != null && !Thread.currentThread().isInterrupted()) {
                Set<Node> condSet = GraphUtils.asSet(choice, candidates);
                if (this.independent(x, z, condSet)) {
                    found |= condSet.contains(y) ? SEPSET_WITH_Y : SEPSET_WITHOUT_Y;
                }
            }
        }
        return found;
    }

    /** An adjacency is forbidden only when knowledge forbids both directions. */
    private boolean edgeForbidden(Node x1, Node x2) {
        return this.knowledge.isForbidden(x1.toString(), x2.toString())
                && this.knowledge.isForbidden(x2.toString(), x1.toString());
    }

    /** An adjacency is required when knowledge requires either direction. */
    private boolean edgeRequired(Node x1, Node x2) {
        return this.knowledge.isRequired(x1.toString(), x2.toString())
                || this.knowledge.isRequired(x2.toString(), x1.toString());
    }

    /** Filters {@code adjNode} to the nodes that knowledge allows to be parents of {@code node}. */
    private List<Node> possibleParents(Node node, List<Node> adjNode) {
        List<Node> possibleParents = new ArrayList<>();
        String x = node.getName();
        for (Node z : adjNode) {
            if (this.possibleParentOf(z.getName(), x, this.knowledge)) {
                possibleParents.add(z);
            }
        }
        return possibleParents;
    }

    /** True unless knowledge forbids {@code z -> x} or requires {@code x -> z}. */
    private boolean possibleParentOf(String z, String x, Knowledge knowledge) {
        return !knowledge.isForbidden(z, x) && !knowledge.isRequired(x, z);
    }

    /** True if knowledge allows arrowheads at {@code y} on both {@code x -> y} and {@code z -> y}. */
    private boolean colliderAllowed(Node x, Node y, Node z, Knowledge knowledge) {
        return isArrowheadAllowed1(x, y, knowledge) && isArrowheadAllowed1(z, y, knowledge);
    }

    /**
     * Sets the candidate variables used by {@link #search(List)}. {@link #search()} overwrites this
     * with the test's variables.
     *
     * @param variables the candidate variables
     */
    public void setVariables(List<Node> variables) {
        this.variables = variables;
    }

    /**
     * Enables or disables the Markov-blanket extension of {@link #search(List)} (steps 2, 3 and the
     * Markov-blanket trim in step 5).
     *
     * @param findMb true to enable the extension
     */
    public void setFindMb(boolean findMb) {
        this.findMb = findMb;
    }

    /** Classification of an unshielded triple. */
    private enum TripleType {
        COLLIDER,
        NONCOLLIDER,
        AMBIGUOUS
    }
}

