/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.Triple;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.MbUtils;
import edu.cmu.tetrad.search.MeekRules;
import edu.cmu.tetrad.search.SearchGraphUtils;
import edu.cmu.tetrad.search.SearchLogUtils;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.TetradLogger;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

/**
 * Markov-blanket "fan" search around a single target variable.
 *
 * <p>The search proceeds in stages:
 * <ol>
 *   <li>Connect the target to every variable it is not found independent of, then prune those
 *       edges with conditional independence tests of increasing depth.</li>
 *   <li>Repeat that fan construction for every surviving neighbor of the target.</li>
 *   <li>Orient the skeleton using background knowledge, unshielded-triple classification and the
 *       Meek rules.</li>
 *   <li>Trim the graph with {@code MbUtils.trimToAdjacents} relative to the target.</li>
 * </ol>
 *
 * <p>Instances are stateful: every call to {@link #search(String)} overwrites the results
 * exposed by the getters.
 */
public final class Cefs {
    /** Number of depths for which the largest remaining adjacency count is recorded. */
    private static final int NUM_RECORDED_DEPTHS = 20;

    private final IndependenceTest test;
    private final List<Node> variables;
    private Node target;
    private final int depth;
    private Graph resultGraph;
    private long numIndependenceTests;
    private int[] maxRemainingAtDepth;
    private Node[] maxVariableAtDepth;
    private Set<Node> visited;
    private long elapsedTime;
    private Dag trueMb;
    private Knowledge knowledge = new Knowledge();
    // Initialized eagerly so that the triple getters are safe to call before search().
    private Set<Triple> allTriples = new HashSet<Triple>();
    private Set<Triple> colliderTriples = new HashSet<Triple>();
    private Set<Triple> noncolliderTriples = new HashSet<Triple>();
    private Set<Triple> ambiguousTriples = new HashSet<Triple>();
    private Graph graph;
    private boolean aggressivelyPreventCycles = false;
    private final TetradLogger logger = TetradLogger.getInstance();

    /**
     * Creates a search over the variables of the given independence test.
     *
     * @param test  the independence test used for every adjacency decision; must not be null
     * @param depth the maximum conditioning-set size, or -1 for unlimited
     * @throws NullPointerException     if {@code test} is null
     * @throws IllegalArgumentException if {@code depth} is less than -1
     */
    public Cefs(IndependenceTest test, int depth) {
        if (test == null) {
            throw new NullPointerException();
        }
        if (depth == -1) {
            depth = Integer.MAX_VALUE;
        }
        if (depth < 0) {
            throw new IllegalArgumentException("Depth must be >= -1: " + depth);
        }
        this.test = test;
        this.depth = depth;
        this.variables = test.getVariables();
    }

    /** @return whether Meek-rule orientation is asked to aggressively prevent cycles */
    public boolean isAggressivelyPreventCycles() {
        return this.aggressivelyPreventCycles;
    }

    /** @param aggressivelyPreventCycles passed through to {@code MeekRules} during orientation */
    public void setAggressivelyPreventCycles(boolean aggressivelyPreventCycles) {
        this.aggressivelyPreventCycles = aggressivelyPreventCycles;
    }

    /**
     * Runs the fan search around the named target and returns the oriented result graph.
     *
     * @param targetName name of the target variable; must name one of the test's variables
     * @return the resulting graph, also available from {@link #getGraph()} and {@link #resultGraph()}
     * @throws IllegalArgumentException if {@code targetName} is null or not a known variable
     */
    public Graph search(String targetName) {
        long start = System.currentTimeMillis();
        // Validate before touching any state so that a bad call leaves earlier results intact.
        if (targetName == null) {
            throw new IllegalArgumentException("Null target name not permitted");
        }
        this.target = this.getVariableForName(targetName);
        this.numIndependenceTests = 0L;
        this.allTriples = new HashSet<Triple>();
        this.ambiguousTriples = new HashSet<Triple>();
        this.colliderTriples = new HashSet<Triple>();
        this.noncolliderTriples = new HashSet<Triple>();
        this.maxRemainingAtDepth = new int[NUM_RECORDED_DEPTHS];
        this.maxVariableAtDepth = new Node[NUM_RECORDED_DEPTHS];
        Arrays.fill(this.maxRemainingAtDepth, -1);
        TetradLogger.getInstance().log("info", "target = " + this.getTarget());
        EdgeListGraph graph = new EdgeListGraph();
        this.visited = new HashSet<Node>();

        TetradLogger.getInstance().log("info", "BEGINNING step 1 (prune target).");
        graph.addNode(this.getTarget());
        this.constructFan(this.getTarget(), graph);
        TetradLogger.getInstance().log("graph", "After step 1 (prune target)" + graph);

        TetradLogger.getInstance().log("info", "BEGINNING step 2 (prune PC).");
        // Iterate over a snapshot: constructFan adds and removes edges of this same graph.
        List<Node> targetNeighbors = new LinkedList<Node>(graph.getAdjacentNodes(this.getTarget()));
        for (Node v : targetNeighbors) {
            this.constructFan(v, graph);
        }
        TetradLogger.getInstance().log("graph", "After step 2 (prune PC)" + graph);

        TetradLogger.getInstance().log("info", "BEGINNING step 4 (PC Orient).");
        SearchGraphUtils.pcOrientbk(this.knowledge, graph, graph.getNodes());
        LinkedList<Node> visitedNodes = new LinkedList<Node>(this.getVisited());
        this.orientUnshieldedTriples(this.knowledge, graph, this.getTest(), this.getDepth(), visitedNodes);
        MeekRules meekRules = new MeekRules();
        meekRules.setAggressivelyPreventCycles(this.aggressivelyPreventCycles);
        meekRules.setKnowledge(this.knowledge);
        meekRules.orientImplied(graph);
        TetradLogger.getInstance().log("graph", "After step 4 (PC Orient)" + graph);

        MbUtils.trimToAdjacents(graph, this.target);
        TetradLogger.getInstance().log("graph", "After step 6 (Remove edges among P and P of C)" + graph);
        this.finishUp(start, graph);
        this.logger.log("graph", "\nReturning this graph: " + graph);
        this.graph = graph;
        return graph;
    }

    /** @return a copy of the triples classified as ambiguous by the last search (empty before any search) */
    public Set<Triple> getAmbiguousTriples() {
        return new HashSet<Triple>(this.ambiguousTriples);
    }

    /** @return a copy of the triples classified as colliders by the last search (empty before any search) */
    public Set<Triple> getColliderTriples() {
        return new HashSet<Triple>(this.colliderTriples);
    }

    /** @return a copy of the triples classified as noncolliders by the last search (empty before any search) */
    public Set<Triple> getNoncolliderTriples() {
        return new HashSet<Triple>(this.noncolliderTriples);
    }

    /**
     * @return the number of tests counted during the last search. Only tests made while building and
     *     pruning fans are counted; the tests made while classifying triples are not.
     */
    public long getNumIndependenceTests() {
        return this.numIndependenceTests;
    }

    /** @return the target node of the last search, or null before any search */
    public Node getTarget() {
        return this.target;
    }

    /** @return wall-clock duration of the last search in milliseconds */
    public double getElapsedTime() {
        return this.elapsedTime;
    }

    /** @return the optional reference DAG used only for diagnostic output, or null */
    public Dag getTrueMb() {
        return this.trueMb;
    }

    /**
     * Sets a reference DAG. When set, every independence judgment that removes an edge present in
     * this DAG is reported on standard output.
     */
    public void setTrueMb(Dag trueMb) {
        this.trueMb = trueMb;
    }

    /** @return the maximum conditioning-set size ({@link Integer#MAX_VALUE} when unlimited) */
    public int getDepth() {
        return this.depth;
    }

    /** @return the graph produced by the last search, or null before any search */
    public Graph resultGraph() {
        return this.resultGraph;
    }

    /**
     * Runs {@link #search(String)} and returns the nodes of the result graph other than the target.
     *
     * @param targetName name of the target variable
     * @return a fresh, caller-owned list of the non-target nodes in the result graph
     */
    public List<Node> findMb(String targetName) {
        Graph result = this.search(targetName);
        // Copy so that removing the target can never mutate the graph's own node list.
        List<Node> nodes = new LinkedList<Node>(result.getNodes());
        nodes.remove(this.target);
        return nodes;
    }

    /** @return the independence test supplied at construction */
    public IndependenceTest getTest() {
        return this.test;
    }

    /** @return the background knowledge used for edge and orientation constraints */
    public Knowledge getKnowledge() {
        return this.knowledge;
    }

    /**
     * @param knowledge background knowledge; must not be null
     * @throws NullPointerException if {@code knowledge} is null (it is dereferenced during search)
     */
    public void setKnowledge(Knowledge knowledge) {
        if (knowledge == null) {
            throw new NullPointerException("Knowledge must not be null.");
        }
        this.knowledge = knowledge;
    }

    /** @return the graph produced by the last search, or null before any search */
    public Graph getGraph() {
        return this.graph;
    }

    private Set<Node> getVisited() {
        return this.visited;
    }

    /** Adds every allowable associate of {@code target}, then prunes the new adjacencies. */
    private void constructFan(Node target, Graph graph) {
        this.addAllowableAssociates(target, graph);
        this.prune(target, graph);
    }

    /**
     * Marks {@code v} visited and links it to each unvisited variable that is not already adjacent,
     * not marginally independent of it, and not forbidden by knowledge.
     */
    private void addAllowableAssociates(Node v, Graph graph) {
        this.getVisited().add(v);
        int numAssociated = 0;
        for (Node w : this.variables) {
            // Short-circuit order matters: the independence test is run (and counted) only when
            // the cheaper checks pass.
            if (this.getVisited().contains(w)
                    || graph.containsNode(w) && graph.isAdjacentTo(v, w)
                    || this.independent(v, w, new LinkedList<Node>())
                    || this.edgeForbidden(v, w)) {
                continue;
            }
            this.addEdge(graph, w, v);
            ++numAssociated;
        }
        this.noteMaxAtDepth(0, numAssociated, v);
    }

    /** Prunes the adjacencies of {@code node} at depths 1, 2, ... until too few neighbors remain. */
    private void prune(Node node, Graph graph) {
        for (int d = 1; d <= this.getDepth(); ++d) {
            if (graph.getAdjacentNodes(node).size() < d) {
                return;
            }
            this.prune(node, graph, d);
        }
    }

    /** Tries to remove each edge at {@code node} using conditioning sets of exactly {@code depth} nodes. */
    private void prune(Node node, Graph graph, int depth) {
        TetradLogger.getInstance().log("pruning", "Trying to remove edges adjacent to node " + node + ", depth = " + depth + ".");
        List<Node> adjacents = new LinkedList<Node>(graph.getAdjacentNodes(node));
        for (Node y : adjacents) {
            this.tryRemoveEdge(node, y, graph, depth);
        }
        int numAdjacents = graph.getAdjacentNodes(node).size();
        this.noteMaxAtDepth(depth, numAdjacents, node);
    }

    /**
     * Removes the edge node--y at the first conditioning set (drawn from node's other possible
     * parents) that separates them, unless knowledge requires the edge. If y is then isolated and
     * is not the search target, it is removed from the graph.
     */
    private void tryRemoveEdge(Node node, Node y, Graph graph, int depth) {
        List<Node> candidates = new LinkedList<Node>(graph.getAdjacentNodes(node));
        candidates.remove(y);
        candidates = this.possibleParents(node, candidates);
        if (candidates.size() < depth) {
            return;
        }
        ChoiceGenerator cg = new ChoiceGenerator(candidates.size(), depth);
        int[] choice;
        while ((choice = cg.next()) != null) {
            List<Node> condSet = SearchGraphUtils.asList(choice, candidates);
            // The independence test runs (and is counted) before the knowledge check.
            if (!this.independent(node, y, condSet) || this.edgeRequired(node, y)) {
                continue;
            }
            graph.removeEdge(node, y);
            if (graph.getEdges(y).isEmpty() && y != this.getTarget()) {
                graph.removeNode(y);
            }
            return;
        }
    }

    /** Records the elapsed time, logs summary statistics and stores the result graph. */
    private void finishUp(long start, Graph graph) {
        long stop = System.currentTimeMillis();
        this.elapsedTime = stop - start;
        double seconds = (double) this.elapsedTime / 1000.0;
        NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
        TetradLogger.getInstance().log("info", "MB fan search took " + nf.format(seconds) + " seconds.");
        TetradLogger.getInstance().log("info", "Number of independence tests performed = " + this.getNumIndependenceTests());
        this.resultGraph = graph;
    }

    /**
     * Runs and counts one independence test. When a reference DAG is set and the test judges two
     * variables independent that are adjacent in that DAG, the judgment is printed to stdout.
     */
    private boolean independent(Node v, Node w, List<Node> z) {
        boolean independent = this.getTest().isIndependent(v, w, z);
        if (independent && this.getTrueMb() != null) {
            Node node1 = this.getTrueMb().getNode(v.getName());
            Node node2 = this.getTrueMb().getNode(w.getName());
            Edge edge = (node1 != null && node2 != null) ? this.getTrueMb().getEdge(node1, node2) : null;
            if (edge != null) {
                NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
                System.out.println("Edge removed that was in the true MB:");
                System.out.println("\tTrue edge = " + edge);
                System.out.println("\t" + SearchLogUtils.independenceFact(v, w, z) + "\tp = " + nf.format(this.getTest().getPValue()));
            }
        }
        ++this.numIndependenceTests;
        return independent;
    }

    /** Adds {@code w} to the graph if needed and connects it to {@code v} with an undirected edge. */
    private void addEdge(Graph graph, Node w, Node v) {
        if (!graph.containsNode(w)) {
            graph.addNode(w);
        }
        graph.addUndirectedEdge(v, w);
    }

    /** Returns the first variable with the given name, or throws if there is none. */
    private Node getVariableForName(String targetVariableName) {
        for (Node candidate : this.variables) {
            if (candidate.getName().equals(targetVariableName)) {
                return candidate;
            }
        }
        throw new IllegalArgumentException("Target variable not in dataset: " + targetVariableName);
    }

    /** Records {@code to} if it has the most remaining adjacents seen so far at {@code depth}. */
    private void noteMaxAtDepth(int depth, int numAdjacents, Node to) {
        if (depth < this.maxRemainingAtDepth.length && numAdjacents > this.maxRemainingAtDepth[depth]) {
            this.maxRemainingAtDepth[depth] = numAdjacents;
            this.maxVariableAtDepth[depth] = to;
        }
    }

    /**
     * Classifies every unshielded triple x--y--z centred on a node of {@code nodes} (or on every
     * graph node when {@code nodes} is null). Colliders allowed by knowledge are oriented
     * x-->y<--z, and ambiguous triples are marked on the graph.
     */
    private void orientUnshieldedTriples(Knowledge knowledge, Graph graph, IndependenceTest test, int depth, List<Node> nodes) {
        TetradLogger.getInstance().log("info", "Starting Collider Orientation:");
        this.colliderTriples = new HashSet<Triple>();
        this.noncolliderTriples = new HashSet<Triple>();
        this.ambiguousTriples = new HashSet<Triple>();
        if (nodes == null) {
            nodes = graph.getNodes();
        }
        for (Node y : nodes) {
            // Pruning may have removed a visited node from the graph entirely.
            if (!graph.containsNode(y)) {
                continue;
            }
            List<Node> adjacentNodes = graph.getAdjacentNodes(y);
            if (adjacentNodes.size() < 2) {
                continue;
            }
            ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2);
            int[] combination;
            while ((combination = cg.next()) != null) {
                Node x = adjacentNodes.get(combination[0]);
                Node z = adjacentNodes.get(combination[1]);
                if (graph.isAdjacentTo(x, z)) {
                    continue; // shielded triple
                }
                this.allTriples.add(new Triple(x, y, z));
                TripleType type = this.getTripleType(graph, x, y, z, test, depth);
                if (type == TripleType.COLLIDER) {
                    if (this.colliderAllowed(x, y, z, knowledge)) {
                        graph.setEndpoint(x, y, Endpoint.ARROW);
                        graph.setEndpoint(z, y, Endpoint.ARROW);
                        this.logger.log("tripleClassifications", "Collider oriented: " + Triple.pathString(graph, x, y, z));
                    }
                    this.colliderTriples.add(new Triple(x, y, z));
                } else if (type == TripleType.AMBIGUOUS) {
                    Triple triple = new Triple(x, y, z);
                    this.ambiguousTriples.add(triple);
                    graph.addAmbiguousTriple(triple.getX(), triple.getY(), triple.getZ());
                    this.logger.log("tripleClassifications", "Ambiguous triple oriented: " + Triple.pathString(graph, x, y, z));
                } else {
                    this.noncolliderTriples.add(new Triple(x, y, z));
                    this.logger.log("tripleClassifications", "Noncollider oriented: " + Triple.pathString(graph, x, y, z));
                }
            }
        }
        TetradLogger.getInstance().log("info", "Finishing Collider Orientation.");
    }

    /**
     * Classifies x--y--z by checking every subset (up to {@code depth}) of adj(x)\{z} and then of
     * adj(z)\{x} that renders x and z independent. Sepsets found only without y give a collider,
     * only with y a noncollider, and both or neither give an ambiguous triple.
     */
    private TripleType getTripleType(Graph graph, Node x, Node y, Node z, IndependenceTest test, int depth) {
        boolean[] found = new boolean[2];
        Cefs.recordSepsets(test, x, y, z, Cefs.adjacentsExcluding(graph, x, z), depth, found);
        Cefs.recordSepsets(test, x, y, z, Cefs.adjacentsExcluding(graph, z, x), depth, found);
        boolean existsSepsetContainingY = found[SEPSET_WITH_Y];
        boolean existsSepsetNotContainingY = found[SEPSET_WITHOUT_Y];
        if (existsSepsetContainingY == existsSepsetNotContainingY) {
            return TripleType.AMBIGUOUS;
        }
        if (!existsSepsetNotContainingY) {
            return TripleType.NONCOLLIDER;
        }
        return TripleType.COLLIDER;
    }

    /** Index into the flags array of {@link #recordSepsets}: a separating set containing y was found. */
    private static final int SEPSET_WITH_Y = 0;
    /** Index into the flags array of {@link #recordSepsets}: a separating set without y was found. */
    private static final int SEPSET_WITHOUT_Y = 1;

    /** Returns the adjacents of {@code node} minus {@code excluded}, in HashSet iteration order. */
    private static List<Node> adjacentsExcluding(Graph graph, Node node, Node excluded) {
        HashSet<Node> adjacents = new HashSet<Node>(graph.getAdjacentNodes(node));
        adjacents.remove(excluded);
        return new LinkedList<Node>(adjacents);
    }

    /**
     * Tests x against z given every subset of {@code candidates} of size 0..min(depth, |candidates|),
     * and sets {@code found[SEPSET_WITH_Y]} or {@code found[SEPSET_WITHOUT_Y]} for each separating set.
     */
    private static void recordSepsets(IndependenceTest test, Node x, Node y, Node z, List<Node> candidates, int depth, boolean[] found) {
        int maxDepth = depth == -1 ? Integer.MAX_VALUE : depth;
        maxDepth = Math.min(maxDepth, candidates.size());
        for (int d = 0; d <= maxDepth; ++d) {
            ChoiceGenerator cg = new ChoiceGenerator(candidates.size(), d);
            int[] choice;
            while ((choice = cg.next()) != null) {
                List<Node> condSet = Cefs.asList(choice, candidates);
                if (!test.isIndependent(x, z, condSet)) {
                    continue;
                }
                if (condSet.contains(y)) {
                    found[SEPSET_WITH_Y] = true;
                } else {
                    found[SEPSET_WITHOUT_Y] = true;
                }
            }
        }
    }

    /** An undirected adjacency is disallowed only when knowledge forbids both directions. */
    private boolean edgeForbidden(Node x1, Node x2) {
        return this.getKnowledge().edgeForbidden(x1.toString(), x2.toString())
                && this.getKnowledge().edgeForbidden(x2.toString(), x1.toString());
    }

    /** An adjacency is required when knowledge requires either direction. */
    private boolean edgeRequired(Node x1, Node x2) {
        return this.getKnowledge().edgeRequired(x1.toString(), x2.toString())
                || this.getKnowledge().edgeRequired(x2.toString(), x1.toString());
    }

    /** Filters {@code adjNode} to the nodes that knowledge still permits as parents of {@code node}. */
    private List<Node> possibleParents(Node node, List<Node> adjNode) {
        LinkedList<Node> possibleParents = new LinkedList<Node>();
        String nodeName = node.getName();
        for (Node z : adjNode) {
            if (this.possibleParentOf(z.getName(), nodeName, this.knowledge)) {
                possibleParents.add(z);
            }
        }
        return possibleParents;
    }

    /** z can be a parent of x unless z-->x is forbidden or x-->z is required. */
    private boolean possibleParentOf(String z, String x, Knowledge knowledge) {
        return !knowledge.edgeForbidden(z, x) && !knowledge.edgeRequired(x, z);
    }

    /** Maps choice indices to the corresponding nodes. */
    private static List<Node> asList(int[] indices, List<Node> nodes) {
        LinkedList<Node> list = new LinkedList<Node>();
        for (int i : indices) {
            list.add(nodes.get(i));
        }
        return list;
    }

    /** A collider x-->y<--z is allowed when knowledge permits both arrowheads into y. */
    private boolean colliderAllowed(Node x, Node y, Node z, Knowledge knowledge) {
        return Cefs.isArrowpointAllowed1(x, y, knowledge) && Cefs.isArrowpointAllowed1(z, y, knowledge);
    }

    /** An arrowhead from-->to is allowed unless to-->from is required or from-->to is forbidden. */
    private static boolean isArrowpointAllowed1(Node from, Node to, Knowledge knowledge) {
        if (knowledge == null) {
            return true;
        }
        return !knowledge.edgeRequired(to.toString(), from.toString())
                && !knowledge.edgeForbidden(from.toString(), to.toString());
    }

    /** Classification of an unshielded triple. */
    private static enum TripleType {
        COLLIDER,
        NONCOLLIDER,
        AMBIGUOUS;
    }
}

