/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Bes;
import edu.cmu.tetrad.search.Score;
import edu.cmu.tetrad.search.TeyssierScorer2;
import edu.cmu.tetrad.util.MillisecondTimes;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import org.jetbrains.annotations.NotNull;

/**
 * Runs a BOSS-style permutation search separately around each variable ("target") and merges the
 * per-target results into a single graph.
 *
 * <p>For each target, a private copy of an initial {@link TeyssierScorer2} is refined by
 * alternating {@link #betterMutationBossTuck(TeyssierScorer2, List)} and
 * {@link #besOrder(TeyssierScorer2)} until the order stops changing. The per-target graphs are
 * computed in parallel on {@link ForkJoinPool#commonPool()}, unioned, and pairs of conflicting
 * edges are pruned.
 */
public class BossMB2 {
    /** Copy of the score's variable list, exposed through {@link #getVariables()}. */
    private final List<Node> variables;
    /** Score used by every scorer copy and by BES. */
    private final Score score;
    /** Background knowledge; empty unless {@link #setKnowledge(Knowledge)} is called. */
    private Knowledge knowledge = new Knowledge();
    /** Timestamp (ms) taken during {@link #search(List)}; used in the "Elapsed" log lines. */
    private long start;
    /** Whether progress is printed to standard out. */
    private boolean verbose = true;
    /** Depth handed to {@link Bes#setDepth(int)}. */
    private int depth = 4;
    /**
     * If true, each per-target graph is cut down to the target, its adjacents and the other
     * parents of its children; otherwise edges among the target's adjacents are removed.
     */
    private boolean findMb = false;
    /** Per-target graphs produced by the most recent {@link #search(List)} call. */
    private final List<Graph> graphs = new ArrayList<Graph>();

    /**
     * Creates a search over all variables of the given score.
     *
     * @param score the score to optimize; its variable list is copied.
     */
    public BossMB2(@NotNull Score score) {
        this.score = score;
        this.variables = new ArrayList<Node>(score.getVariables());
    }

    /**
     * Runs one refinement per target variable, starting from {@code order}, and returns the merged
     * graph. Results of earlier calls are discarded; afterwards {@link #getGraphs()} holds one graph
     * per target, in sorted target order.
     *
     * @param order initial variable order; it is copied, never modified.
     * @return the union of the per-target graphs with conflicting edge pairs removed.
     * @throws RuntimeException wrapping an {@link InterruptedException} or
     *     {@link ExecutionException} raised while running the per-target tasks.
     */
    public Graph search(@NotNull List<Node> order) {
        long searchStart = MillisecondTimes.timeMillis();
        // Without this, graphs from a previous call would be merged into this call's result.
        this.graphs.clear();

        List<Node> initialOrder = new ArrayList<Node>(order);
        // Apply knowledge before scoring so the scorer actually starts from the reordered list.
        this.makeValidKnowledgeOrder(initialOrder);

        TeyssierScorer2 scorer0 = new TeyssierScorer2(this.score);
        scorer0.setKnowledge(this.knowledge);
        scorer0.score(initialOrder);
        this.start = MillisecondTimes.timeMillis();
        if (this.verbose) {
            System.out.println("Initial score = " + scorer0.score() + " Elapsed = "
                    + (MillisecondTimes.timeMillis() - searchStart) / 1000.0 + " s");
        }

        List<Node> targets = new ArrayList<Node>(scorer0.getPi());
        Collections.sort(targets);
        this.graphs.addAll(this.runTargetSearches(scorer0, targets));

        EdgeListGraph combinedGraph = new EdgeListGraph(targets);
        for (Graph graph : this.graphs) {
            for (Edge e : graph.getEdges()) {
                combinedGraph.addEdge(e);
            }
        }
        removeConflictingEdges(combinedGraph);

        if (this.verbose) {
            System.out.println("Elapsed time = "
                    + (MillisecondTimes.timeMillis() - searchStart) / 1000.0 + " s");
        }
        return combinedGraph;
    }

    /**
     * Runs {@link #targetVisit(TeyssierScorer2, Node)} for every target on the common fork-join
     * pool and returns the resulting graphs in the same order as {@code targets}.
     */
    private List<Graph> runTargetSearches(TeyssierScorer2 scorer0, List<Node> targets) {
        List<MyTask2> tasks = new ArrayList<MyTask2>();
        for (Node target : targets) {
            tasks.add(new MyTask2(scorer0, target));
        }
        List<Graph> result = new ArrayList<Graph>();
        try {
            List<Future<Graph>> futures = ForkJoinPool.commonPool().invokeAll(tasks);
            for (Future<Graph> future : futures) {
                result.add(future.get());
            }
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        }
        return result;
    }

    /**
     * Prunes edges that disagree across per-target graphs. If both X->Y and Y->X are present,
     * both are removed. An undirected edge is removed when any other edge joins the same pair.
     */
    private static void removeConflictingEdges(EdgeListGraph graph) {
        // Iterate over a snapshot because edges are removed inside the loop.
        for (Edge edge : new ArrayList<Edge>(graph.getEdges())) {
            if (edge.isDirected()) {
                if (graph.containsEdge(edge) && graph.containsEdge(edge.reverse())) {
                    graph.removeEdge(edge);
                    graph.removeEdge(edge.reverse());
                }
            } else if (Edges.isUndirectedEdge(edge)) {
                for (Edge other : graph.getEdges(edge.getNode1(), edge.getNode2())) {
                    if (!edge.equals(other)) {
                        graph.removeEdge(edge);
                        break;
                    }
                }
            }
        }
    }

    /**
     * Refines a private copy of {@code scorer0} around {@code target} and returns the resulting
     * graph, trimmed according to {@link #setFindMb(boolean)}.
     *
     * @throws InterruptedException if the current thread is interrupted during the search.
     */
    private Graph targetVisit(TeyssierScorer2 scorer0, Node target) throws InterruptedException {
        TeyssierScorer2 scorer = new TeyssierScorer2(scorer0);
        List<Node> targets = Collections.singletonList(target);
        List<Node> previous;
        List<Node> current = scorer.getPi();
        // Alternate tucking and BES until BES returns the same order as the previous round.
        do {
            checkInterrupted();
            previous = current;
            this.betterMutationBossTuck(scorer, targets);
            current = this.besOrder(scorer);
        } while (!previous.equals(current));

        scorer.score(current);
        Graph graph = scorer.getGraph(true);
        if (this.findMb) {
            restrictToMarkovBlanket(graph, target);
        } else {
            removeEdgesAmongNeighbors(graph, target);
        }
        if (this.verbose) {
            System.out.println("Graph for " + target + " = " + graph);
            System.out.println();
        }
        return graph;
    }

    /**
     * Removes every node except {@code target}, the nodes adjacent to it, and nodes that are
     * parents of one of its children.
     */
    private static void restrictToMarkovBlanket(Graph graph, Node target) {
        // The graph is not modified until the blanket is computed, so this list stays valid.
        List<Node> children = graph.getChildren(target);
        HashSet<Node> mb = new HashSet<Node>();
        for (Node n : graph.getNodes()) {
            if (graph.isAdjacentTo(target, n)) {
                mb.add(n);
                continue;
            }
            for (Node child : children) {
                if (graph.isParentOf(n, child)) {
                    mb.add(n);
                    break;
                }
            }
        }
        // Snapshot: nodes are removed inside the loop.
        for (Node n : new ArrayList<Node>(graph.getNodes())) {
            if (n != target && !mb.contains(n)) {
                graph.removeNode(n);
            }
        }
    }

    /** Removes every edge whose two endpoints are both adjacent to {@code target}. */
    private static void removeEdgesAmongNeighbors(Graph graph, Node target) {
        // Snapshot: edges are removed inside the loop.
        for (Edge e : new ArrayList<Edge>(graph.getEdges())) {
            if (graph.isAdjacentTo(target, e.getNode1()) && graph.isAdjacentTo(target, e.getNode2())) {
                graph.removeEdge(e);
            }
        }
    }

    /** Throws {@link InterruptedException} if the current thread's interrupt flag is set. */
    private static void checkInterrupted() throws InterruptedException {
        if (Thread.currentThread().isInterrupted()) {
            throw new InterruptedException();
        }
    }

    /**
     * Selects how per-target graphs are trimmed in the search.
     *
     * @param findMb true to keep only the target's neighbourhood (adjacents plus other parents of
     *     its children); false to keep all nodes but drop edges among the target's adjacents.
     */
    public void setFindMb(boolean findMb) {
        this.findMb = findMb;
    }

    /**
     * Repeatedly restricts the scorer's order to the targets and their neighbourhood, then tries
     * tucking each variable into every earlier position. It keeps a tuck only if it raises the
     * score and does not violate knowledge. This repeats until a full pass leaves the order
     * unchanged.
     *
     * <p>Note: {@code scorer} is re-scored on the restricted order, so variables outside the kept
     * set are dropped from its order.
     *
     * @param scorer the scorer to modify in place.
     * @param targets variables whose neighbourhood is kept; with {@code findMb} set, the
     *     neighbourhood is widened by one more adjacency step.
     * @throws InterruptedException if the current thread is interrupted.
     */
    public void betterMutationBossTuck(@NotNull TeyssierScorer2 scorer, List<Node> targets) throws InterruptedException {
        List<Node> before;
        do {
            checkInterrupted();
            before = scorer.getPi();

            Graph g = scorer.getGraph(false);
            HashSet<Node> keep = new HashSet<Node>(targets);
            for (Node n : targets) {
                keep.addAll(g.getAdjacentNodes(n));
            }
            if (this.findMb) {
                // Iterate over a copy because keep grows inside the loop.
                for (Node k : new HashSet<Node>(keep)) {
                    keep.addAll(g.getAdjacentNodes(k));
                }
            }

            List<Node> restricted = new ArrayList<Node>();
            for (Node n : scorer.getPi()) {
                if (keep.contains(n)) {
                    restricted.add(n);
                }
            }
            double best = scorer.score(restricted);
            scorer.bookmark();
            if (this.verbose) {
                System.out.println("After snips: # vars = " + scorer.getPi().size() + " # Edges = " + scorer.getNumEdges() + " Score = " + scorer.score() + " (betterMutation) Elapsed " + (MillisecondTimes.timeMillis() - this.start) / 1000.0 + " s" + " order = " + scorer.getPi());
            }

            for (Node x : scorer.getPi()) {
                checkInterrupted();
                int i = scorer.index(x);
                for (int j = i - 1; j >= 0; --j) {
                    checkInterrupted();
                    if (!scorer.tuck(x, j)) {
                        continue;
                    }
                    if (scorer.score() > best && !this.violatesKnowledge(scorer.getPi())) {
                        best = scorer.score();
                        scorer.bookmark();
                        if (this.verbose) {
                            System.out.println("# vars = " + scorer.getPi().size() + " # Edges = " + scorer.getNumEdges() + " Score = " + scorer.score() + " (betterMutation) Elapsed " + (MillisecondTimes.timeMillis() - this.start) / 1000.0 + " s");
                        }
                    } else {
                        // Undo a tuck that did not improve the score or broke knowledge.
                        scorer.goToBookmark();
                    }
                }
            }
        } while (!before.equals(scorer.getPi()));
    }

    /**
     * Runs BES (with this search's depth, verbosity and knowledge) on the scorer's current graph and
     * returns a valid order of that graph seeded from the scorer's current order.
     *
     * @param scorer the scorer whose graph and order are used; its order is not changed here.
     * @return the order computed by {@code graph.paths().validOrder(..., true)}.
     */
    public List<Node> besOrder(TeyssierScorer2 scorer) {
        Graph graph = scorer.getGraph(true);
        Bes bes = new Bes(this.score);
        bes.setDepth(this.depth);
        bes.setVerbose(this.verbose);
        bes.setKnowledge(this.knowledge);
        bes.bes(graph, scorer.getPi());
        return graph.paths().validOrder(scorer.getPi(), true);
    }

    /**
     * Stable-sorts {@code order} so that, where knowledge relates two variables, the cause comes
     * first. If X->Y is required, or Y->X is forbidden, X is placed before Y. This matches
     * {@link #violatesKnowledge(List)}, which treats earlier positions as potential parents.
     * Unrelated pairs compare as equal, so they keep their relative order and the comparator stays
     * antisymmetric.
     */
    private void makeValidKnowledgeOrder(List<Node> order) {
        if (this.knowledge.isEmpty()) {
            return;
        }
        order.sort((o1, o2) -> {
            String n1 = o1.getName();
            String n2 = o2.getName();
            if (n1.equals(n2)) {
                return 0;
            }
            if (this.knowledge.isRequired(n1, n2)) {
                return -1;
            }
            if (this.knowledge.isRequired(n2, n1)) {
                return 1;
            }
            if (this.knowledge.isForbidden(n2, n1)) {
                return -1;
            }
            if (this.knowledge.isForbidden(n1, n2)) {
                return 1;
            }
            return 0;
        });
    }

    /**
     * Returns the per-target graphs from the most recent {@link #search(List)} call. This is the
     * internal list itself, not a copy.
     */
    @NotNull
    public List<Graph> getGraphs() {
        return this.graphs;
    }

    /** Returns the variables of the score supplied at construction (the internal copy). */
    public List<Node> getVariables() {
        return this.variables;
    }

    /** Returns whether progress is printed to standard out. */
    public boolean isVerbose() {
        return this.verbose;
    }

    /** Sets whether progress is printed to standard out; also forwarded to BES. */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /** Stores a copy of the given knowledge; later changes to the argument are not seen. */
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = new Knowledge(knowledge);
    }

    /**
     * Sets the depth passed to BES.
     *
     * @param depth the BES depth; must be at least -1.
     * @throws IllegalArgumentException if {@code depth < -1}.
     */
    public void setDepth(int depth) {
        if (depth < -1) {
            throw new IllegalArgumentException("Depth should be >= -1.");
        }
        this.depth = depth;
    }

    /**
     * Returns true if some variable appears before another variable such that the edge from the
     * earlier to the later one is forbidden by knowledge.
     */
    private boolean violatesKnowledge(List<Node> order) {
        if (this.knowledge.isEmpty()) {
            return false;
        }
        for (int i = 0; i < order.size(); ++i) {
            String earlier = order.get(i).getName();
            for (int j = i + 1; j < order.size(); ++j) {
                if (this.knowledge.isForbidden(earlier, order.get(j).getName())) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Returns the knowledge used by this search (the internal instance). */
    public Knowledge getKnowledge() {
        return this.knowledge;
    }

    /** Task that runs {@link BossMB2#targetVisit(TeyssierScorer2, Node)} for one target. */
    class MyTask2
    implements Callable<Graph> {
        final TeyssierScorer2 scorer0;
        final Node target;

        MyTask2(TeyssierScorer2 scorer0, Node target) {
            this.scorer0 = scorer0;
            this.target = target;
        }

        @Override
        public Graph call() throws InterruptedException {
            return BossMB2.this.targetVisit(this.scorer0, this.target);
        }
    }

    /** Algorithm variants; not referenced by this class itself. */
    public static enum AlgType {
        BOSS_OLD,
        BOSS;

    }
}

