/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.utils;

import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.search.utils.MeekRules;
import edu.cmu.tetrad.util.SublistGenerator;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.SortedSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import org.apache.commons.math3.util.FastMath;
import org.jetbrains.annotations.NotNull;

public class BesPermutation {
    /** Variable list obtained from {@code score.getVariables()} in the constructor. */
    private final List<Node> variables;
    /** Score whose {@code localScoreDiff} is used to evaluate edge deletions. */
    private final Score score;
    /** Knowledge consulted by validDelete, calculateArrowsBackward and revertToCPDAG; initially a fresh instance. */
    private Knowledge knowledge = new Knowledge();
    /** When true, deletions and re-orientations are logged through {@code TetradLogger}. */
    private boolean verbose = true;
    // Passed (after min with the NaYX size) as the second argument of SublistGenerator.
    // No setter is visible in this class, so the value stays -1; presumably -1 means
    // "no limit" for SublistGenerator -- TODO confirm against that class.
    private int depth = -1;

    /**
     * Creates an instance that evaluates edge deletions with the given score.
     *
     * @param score the score to use; its variable list becomes this object's variable list.
     * @throws NullPointerException if {@code score} is null.
     */
    public BesPermutation(@NotNull Score score) {
        // @NotNull is only a static-analysis hint; fail fast with a clear message at runtime.
        this.score = Objects.requireNonNull(score, "score must not be null");
        this.variables = score.getVariables();
    }

    /**
     * Returns the variables taken from the score at construction time.
     *
     * @return the list held by this object (the same instance, not a copy).
     */
    @NotNull
    public List<Node> getVariables() {
        return variables;
    }

    /**
     * Turns logging of deletions and re-orientations on or off.
     *
     * @param verbose true to log each step through {@code TetradLogger}.
     */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /**
     * Records the position of each node of {@code nodes} in {@code hashIndices}.
     * The first node is mapped to 0, the second to 1, and so on; an existing
     * entry for the same node is overwritten.
     *
     * @param nodes       the nodes, in the order that defines their indices.
     * @param hashIndices the map that receives the node-to-index entries.
     */
    private void buildIndexing(List<Node> nodes, Map<Node, Integer> hashIndices) {
        int position = 0;
        for (Node node : nodes) {
            hashIndices.put(node, position);
            position++;
        }
    }

    /**
     * Repeatedly deletes edges from {@code graph} while a deletion with a positive bump is available.
     * <p>
     * Candidate deletions ("arrows") are kept in a sorted set ordered by {@code Arrow.compareTo}
     * (largest bump first). Each candidate is re-checked against the current graph before it is
     * applied; stale or invalid candidates are skipped. After a deletion the graph is re-oriented
     * with {@code revertToCPDAG} and the affected nodes are re-evaluated.
     *
     * @param graph    the graph to modify in place.
     * @param order    the nodes to evaluate initially; their positions define the score indices.
     * @param suborder the node list handed to {@code validDelete}.
     */
    public void bes(Graph graph, List<Node> order, List<Node> suborder) {
        Map<Node, Integer> hashIndices = new HashMap<Node, Integer>();
        buildIndexing(order, hashIndices);

        ConcurrentSkipListSet<Arrow> sortedArrowsBack = new ConcurrentSkipListSet<Arrow>();
        Map<Edge, ArrowConfigBackward> arrowsMapBackward = new ConcurrentHashMap<Edge, ArrowConfigBackward>();
        int[] arrowIndex = new int[1];

        reevaluateBackward(new HashSet<Node>(order), graph, hashIndices, arrowIndex, sortedArrowsBack, arrowsMapBackward);

        while (!sortedArrowsBack.isEmpty()) {
            Arrow arrow = sortedArrowsBack.first();
            sortedArrowsBack.remove(arrow);

            Node x = arrow.getA();
            Node y = arrow.getB();

            // Skip candidates that no longer describe the current graph.
            if (!graph.isAdjacentTo(x, y)) {
                continue;
            }
            if (graph.getEdge(x, y).pointsTowards(x)) {
                continue;
            }
            if (!getNaYX(x, y, graph).equals(arrow.getNaYX())) {
                continue;
            }
            Set<Node> currentParents = new HashSet<Node>(graph.getParents(y));
            if (!currentParents.equals(new HashSet<Node>(arrow.getParents()))) {
                continue;
            }
            if (!validDelete(x, y, arrow.getHOrT(), arrow.getNaYX(), graph, suborder)) {
                continue;
            }

            // The complement of H within NaYX is what the deletion is scored against.
            Set<Node> complement = new HashSet<Node>(arrow.getNaYX());
            complement.removeAll(arrow.getHOrT());
            double bump = deleteEval(x, y, complement, arrow.getParents(), hashIndices);
            delete(x, y, arrow.getHOrT(), bump, arrow.getNaYX(), graph);

            Set<Node> process = revertToCPDAG(graph);
            process.add(x);
            process.add(y);
            process.addAll(graph.getAdjacentNodes(x));
            process.addAll(graph.getAdjacentNodes(y));

            reevaluateBackward(new HashSet<Node>(process), graph, hashIndices, arrowIndex, sortedArrowsBack, arrowsMapBackward);
        }
    }

    /**
     * Removes the edge between {@code x} and {@code y} from {@code graph} and, for every node
     * {@code h} in {@code H} that is not already a parent of {@code y} or {@code x}, replaces the
     * y-h edge by {@code y --> h} and an undirected x-h edge by {@code x --> h}.
     *
     * @param x     one endpoint of the edge to delete.
     * @param y     the other endpoint of the edge to delete.
     * @param H     the nodes whose edges to y (and x) are directed away from y (and x).
     * @param bump  the score change of the deletion; only used in the log message.
     * @param naYX  the NaYX set of the deletion; used for the log message and the "diff" set.
     * @param graph the graph to modify in place.
     */
    private void delete(Node x, Node y, Set<Node> H, double bump, Set<Node> naYX, Graph graph) {
        Edge removedEdge = graph.getEdge(x, y);

        Set<Node> diff = new HashSet<Node>(naYX);
        diff.removeAll(H);

        graph.removeEdge(removedEdge);

        // Coarse progress output every 1000 edges.
        int edgeCount = graph.getNumEdges();
        if (edgeCount > 0 && edgeCount % 1000 == 0) {
            System.out.println("Num edges (backwards) = " + edgeCount);
        }

        if (this.verbose) {
            int cond = diff.size() + graph.getParents(y).size();
            String message = graph.getNumEdges() + ". DELETE " + x + " --> " + y
                    + " H = " + H
                    + " NaYX = " + naYX
                    + " degree = " + GraphUtils.getDegree(graph)
                    + " indegree = " + GraphUtils.getIndegree(graph)
                    + " diff = " + diff
                    + " (" + bump + ")  cond = " + cond;
            TetradLogger.getInstance().forceLogMessage(message);
        }

        for (Node h : H) {
            boolean alreadyParent = graph.isParentOf(h, y) || graph.isParentOf(h, x);
            if (alreadyParent) {
                continue;
            }

            Edge previousYh = graph.getEdge(y, h);
            graph.removeEdge(previousYh);
            graph.addEdge(Edges.directedEdge(y, h));
            if (this.verbose) {
                TetradLogger.getInstance().forceLogMessage("--- Directing " + previousYh + " to " + graph.getEdge(y, h));
            }

            Edge previousXh = graph.getEdge(x, h);
            if (Edges.isUndirectedEdge(previousXh)) {
                graph.removeEdge(previousXh);
                graph.addEdge(Edges.directedEdge(x, h));
                if (this.verbose) {
                    TetradLogger.getInstance().forceLogMessage("--- Directing " + previousXh + " to " + graph.getEdge(x, h));
                }
            }
        }
    }

    /**
     * Scores the deletion of the edge from {@code x} to {@code y}.
     * <p>
     * The conditioning set is {@code complement} plus {@code parents}, with {@code x} removed;
     * the result is the negated value of {@code scoreGraphChange} for that set.
     *
     * @param x           the source node of the edge.
     * @param y           the target node of the edge.
     * @param complement  the nodes of NaYX that are kept.
     * @param parents     the parents of {@code y}.
     * @param hashIndices node-to-index map used to address the score.
     * @return the negated score difference.
     */
    private double deleteEval(Node x, Node y, Set<Node> complement, Set<Node> parents, Map<Node, Integer> hashIndices) {
        Set<Node> conditioning = new HashSet<Node>(complement);
        conditioning.addAll(parents);
        conditioning.remove(x);

        double change = scoreGraphChange(x, y, conditioning, hashIndices);
        return -change;
    }

    /**
     * Returns {@code score.localScoreDiff(x, y, parents)} using the indices recorded in
     * {@code hashIndices}.
     *
     * @param x           the node whose addition to the parent set is scored.
     * @param y           the target node.
     * @param parents     the conditioning set; must not contain {@code y}.
     * @param hashIndices node-to-index map; must contain {@code x}, {@code y} and every parent.
     * @return the value reported by the score.
     * @throws IllegalArgumentException if {@code x} and {@code y} are the same node, if
     *                                  {@code parents} contains {@code y}, or if a node has no
     *                                  recorded index.
     */
    private double scoreGraphChange(Node x, Node y, Set<Node> parents, Map<Node, Integer> hashIndices) {
        // Validate before touching the index map so that errors name the real problem.
        if (x == y) {
            throw new IllegalArgumentException("x and y must be different nodes, but both are " + x);
        }
        if (parents.contains(y)) {
            throw new IllegalArgumentException("The parent set must not contain the target node " + y);
        }

        int xIndex = lookupIndex(x, hashIndices);
        int yIndex = lookupIndex(y, hashIndices);

        int[] parentIndices = new int[parents.size()];
        int count = 0;
        for (Node parent : parents) {
            parentIndices[count++] = lookupIndex(parent, hashIndices);
        }
        return this.score.localScoreDiff(xIndex, yIndex, parentIndices);
    }

    /** Returns the index recorded for {@code node}, failing with a descriptive message if there is none. */
    private static int lookupIndex(Node node, Map<Node, Integer> hashIndices) {
        Integer index = hashIndices.get(node);
        if (index == null) {
            throw new IllegalArgumentException("No index recorded for node " + node);
        }
        return index;
    }

    /**
     * Returns the knowledge currently held by this object.
     *
     * @return the stored knowledge instance (not a copy).
     */
    public Knowledge getKnowledge() {
        return knowledge;
    }

    /**
     * Replaces the stored knowledge with a new {@code Knowledge} constructed from the argument
     * (presumably a copy -- confirm against the {@code Knowledge} constructor).
     *
     * @param knowledge the knowledge to take over.
     */
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = new Knowledge(knowledge);
    }

    /**
     * Runs {@code MeekRules.orientImplied} on {@code graph}, configured with this object's
     * knowledge, cycle prevention enabled and verbose output disabled.
     *
     * @param graph the graph to orient in place.
     * @return the node set returned by {@code orientImplied}.
     */
    private Set<Node> revertToCPDAG(Graph graph) {
        MeekRules meek = new MeekRules();
        meek.setKnowledge(getKnowledge());
        meek.setMeekPreventCycles(true);
        meek.setVerbose(false);
        return meek.orientImplied(graph);
    }

    /**
     * Decides whether deleting the edge between {@code x} and {@code y} with the set {@code H}
     * is admissible.
     * <p>
     * The deletion is rejected when knowledge exists and {@code isForbidden(x, h)} or
     * {@code isForbidden(y, h)} holds for some {@code h} in {@code H}, or when {@code naYX}
     * minus {@code H} is not a clique. When knowledge exists, the deletion is additionally
     * applied to a copy of the graph, the copy is re-oriented, and nodes are removed one by one:
     * in each round the first node of the reversed {@code suborder} for which
     * {@code invalidSink} is false is removed. If a round finds no such node the deletion is
     * rejected.
     *
     * @param x        one endpoint of the edge.
     * @param y        the other endpoint of the edge.
     * @param H        the candidate H set.
     * @param naYX     the NaYX set for the pair.
     * @param graph    the current graph; it is not modified (a copy is used).
     * @param suborder the nodes used for the sink-removal check.
     * @return true if the deletion passes all checks.
     */
    private boolean validDelete(Node x, Node y, Set<Node> H, Set<Node> naYX, Graph graph, List<Node> suborder) {
        if (existsKnowledge()) {
            for (Node h : H) {
                if (this.knowledge.isForbidden(x.getName(), h.getName())
                        || this.knowledge.isForbidden(y.getName(), h.getName())) {
                    return false;
                }
            }
        }

        Set<Node> diff = new HashSet<Node>(naYX);
        diff.removeAll(H);
        if (!isClique(diff, graph)) {
            return false;
        }

        if (!existsKnowledge()) {
            return true;
        }

        // Apply the deletion to a scratch copy so the caller's graph stays untouched.
        Graph trial = new EdgeListGraph(graph);
        trial.removeEdge(trial.getEdge(x, y));
        for (Node h : H) {
            if (trial.isParentOf(h, y) || trial.isParentOf(h, x)) {
                continue;
            }
            trial.removeEdge(trial.getEdge(y, h));
            trial.addEdge(Edges.directedEdge(y, h));

            Edge xh = trial.getEdge(x, h);
            if (Edges.isUndirectedEdge(xh)) {
                trial.removeEdge(xh);
                trial.addEdge(Edges.directedEdge(x, h));
            }
        }
        revertToCPDAG(trial);

        List<Node> remaining = new ArrayList<Node>(suborder);
        Collections.reverse(remaining);
        while (!remaining.isEmpty()) {
            boolean removedOne = false;
            Iterator<Node> candidates = remaining.iterator();
            while (candidates.hasNext()) {
                Node candidate = candidates.next();
                if (invalidSink(candidate, trial)) {
                    continue;
                }
                trial.removeNode(candidate);
                candidates.remove();
                removedOne = true;
                break;
            }
            if (!removedOne) {
                return false;
            }
        }
        return true;
    }

    /**
     * Tests whether {@code x} fails the sink check in {@code graph}.
     * <p>
     * Returns true as soon as an edge at {@code x} has an arrow endpoint at the other node.
     * Otherwise the other nodes of all edges that have a tail endpoint at {@code x} are
     * collected, and true is returned if any two of them are not adjacent.
     *
     * @param x     the node to test.
     * @param graph the graph to inspect.
     * @return true if {@code x} is not a valid sink according to the checks above.
     */
    private boolean invalidSink(Node x, Graph graph) {
        List<Node> tailNeighbors = new ArrayList<Node>();
        for (Edge edge : graph.getEdges(x)) {
            if (edge.getDistalEndpoint(x) == Endpoint.ARROW) {
                return true;
            }
            if (edge.getProximalEndpoint(x) == Endpoint.TAIL) {
                tailNeighbors.add(edge.getDistalNode(x));
            }
        }

        // Every pair of collected neighbours must be adjacent.
        int size = tailNeighbors.size();
        for (int i = 0; i < size; i++) {
            Node first = tailNeighbors.get(i);
            for (int j = i + 1; j < size; j++) {
                if (!graph.isAdjacentTo(first, tailNeighbors.get(j))) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Reports whether the stored knowledge contains anything.
     *
     * @return true if {@code knowledge.isEmpty()} is false.
     */
    private boolean existsKnowledge() {
        return !knowledge.isEmpty();
    }

    /**
     * Checks whether every pair of distinct nodes in {@code nodes} is adjacent in {@code graph}.
     * An empty or single-element set is a clique.
     *
     * @param nodes the nodes to check.
     * @param graph the graph that defines adjacency.
     * @return true if all pairs are adjacent.
     */
    private boolean isClique(Set<Node> nodes, Graph graph) {
        List<Node> members = new ArrayList<Node>(nodes);
        int size = members.size();
        for (int i = 0; i < size - 1; i++) {
            Node first = members.get(i);
            for (Node second : members.subList(i + 1, size)) {
                if (!graph.isAdjacentTo(first, second)) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Collects the nodes {@code z} other than {@code x} that are connected to {@code y} by an
     * undirected edge and are adjacent to {@code x}.
     *
     * @param x     the first node of the pair.
     * @param y     the second node of the pair.
     * @param graph the graph to inspect.
     * @return a new set with the matching nodes.
     */
    private Set<Node> getNaYX(Node x, Node y, Graph graph) {
        Set<Node> result = new HashSet<Node>();
        for (Node z : graph.getAdjacentNodes(y)) {
            if (z == x) {
                continue;
            }
            if (!Edges.isUndirectedEdge(graph.getEdge(y, z))) {
                continue;
            }
            if (graph.isAdjacentTo(z, x)) {
                result.add(z);
            }
        }
        return result;
    }

    /**
     * Recomputes the candidate deletions for every edge between two nodes of {@code toProcess}.
     * <p>
     * For each node {@code r} in {@code toProcess}, all nodes {@code w} of {@code toProcess} are
     * scanned on the common fork-join pool. If an edge between {@code w} and {@code r} exists,
     * {@code calculateArrowsBackward} is called for the direction the edge points in, or for
     * both directions when the edge points towards neither node.
     *
     * @param toProcess         the nodes whose mutual edges are re-evaluated.
     * @param graph             the current graph.
     * @param hashIndices       node-to-index map used to address the score.
     * @param arrowIndex        single-element counter used to number new arrows.
     * @param sortedArrowsBack  the sorted set that receives new arrows.
     * @param arrowsMapBackward the per-edge configurations already evaluated.
     */
    private void reevaluateBackward(Set<Node> toProcess, Graph graph, Map<Node, Integer> hashIndices, int[] arrowIndex, SortedSet<Arrow> sortedArrowsBack, Map<Edge, ArrowConfigBackward> arrowsMapBackward) {
        // The snapshot and the chunk size do not depend on r, so build them once.
        List<Node> nodes = new ArrayList<Node>(toProcess);
        int chunk = getChunkSize(nodes.size());

        /** Scans nodes[from, to) against a fixed node r, splitting the range when it exceeds the chunk size. */
        class BackwardTask extends RecursiveTask<Boolean> {
            private final Node r;
            private final int from;
            private final int to;

            BackwardTask(Node r, int from, int to) {
                this.r = r;
                this.from = from;
                this.to = to;
            }

            @Override
            protected Boolean compute() {
                if (this.to - this.from <= chunk) {
                    for (int i = this.from; i < this.to; ++i) {
                        Node w = nodes.get(i);
                        Edge e = graph.getEdge(w, this.r);
                        if (e == null) {
                            continue;
                        }
                        if (e.pointsTowards(this.r)) {
                            BesPermutation.this.calculateArrowsBackward(w, this.r, graph, arrowsMapBackward, hashIndices, arrowIndex, sortedArrowsBack);
                        } else if (e.pointsTowards(w)) {
                            BesPermutation.this.calculateArrowsBackward(this.r, w, graph, arrowsMapBackward, hashIndices, arrowIndex, sortedArrowsBack);
                        } else {
                            // The edge points towards neither node: evaluate both directions.
                            BesPermutation.this.calculateArrowsBackward(w, this.r, graph, arrowsMapBackward, hashIndices, arrowIndex, sortedArrowsBack);
                            BesPermutation.this.calculateArrowsBackward(this.r, w, graph, arrowsMapBackward, hashIndices, arrowIndex, sortedArrowsBack);
                        }
                    }
                } else {
                    int mid = this.from + (this.to - this.from) / 2;
                    invokeAll(new BackwardTask(this.r, this.from, mid), new BackwardTask(this.r, mid, this.to));
                }
                return true;
            }
        }

        for (Node r : toProcess) {
            ForkJoinPool.commonPool().invoke(new BackwardTask(r, 0, nodes.size()));
        }
    }

    /**
     * Computes the range size below which a task is processed without further splitting:
     * {@code n} divided by the number of available processors, but never less than 100.
     *
     * @param n the number of items to process.
     * @return the chunk size, at least 100.
     */
    private int getChunkSize(int n) {
        int perProcessor = n / Runtime.getRuntime().availableProcessors();
        return Math.max(100, perProcessor);
    }

    /**
     * Evaluates the deletion of the edge from {@code a} to {@code b} and, if the best bump is
     * positive, registers it as an arrow.
     * <p>
     * Nothing happens when knowledge exists and {@code noEdgeRequired(a, b)} is false, or when
     * the configuration (NaYX and parents of {@code b}) equals the one already stored for this
     * edge. Otherwise the configuration is stored, the subsets of NaYX produced by
     * {@code SublistGenerator} are scored with {@code deleteEval}, and the subset with the
     * largest bump is kept as the complement; H is NaYX minus that complement.
     *
     * @param a                 the source node of the edge.
     * @param b                 the target node of the edge.
     * @param graph             the current graph.
     * @param arrowsMapBackward the per-edge configurations already evaluated.
     * @param hashIndices       node-to-index map used to address the score.
     * @param arrowIndex        single-element counter used to number new arrows.
     * @param sortedArrowsBack  the sorted set that receives new arrows.
     */
    private void calculateArrowsBackward(Node a, Node b, Graph graph, Map<Edge, ArrowConfigBackward> arrowsMapBackward, Map<Node, Integer> hashIndices, int[] arrowIndex, SortedSet<Arrow> sortedArrowsBack) {
        if (existsKnowledge() && !getKnowledge().noEdgeRequired(a.getName(), b.getName())) {
            return;
        }

        Set<Node> naYX = getNaYX(a, b, graph);
        Set<Node> parents = new HashSet<Node>(graph.getParents(b));
        List<Node> candidates = new ArrayList<Node>(naYX);

        // Skip the work if this edge was already evaluated for the same configuration.
        Edge key = Edges.directedEdge(a, b);
        ArrowConfigBackward current = new ArrowConfigBackward(naYX, parents);
        ArrowConfigBackward stored = arrowsMapBackward.get(key);
        if (stored != null && stored.equals(current)) {
            return;
        }
        arrowsMapBackward.put(key, current);

        int subsetDepth = FastMath.min(this.depth, candidates.size());
        SublistGenerator generator = new SublistGenerator(candidates.size(), subsetDepth);

        Set<Node> bestComplement = null;
        double bestBump = Double.NEGATIVE_INFINITY;
        for (int[] choice = generator.next(); choice != null; choice = generator.next()) {
            Set<Node> complement = GraphUtils.asSet(choice, candidates);
            double bump = deleteEval(a, b, complement, parents, hashIndices);
            if (bump > bestBump) {
                bestBump = bump;
                bestComplement = complement;
            }
        }

        if (bestBump > 0.0) {
            Set<Node> h = new HashSet<Node>(naYX);
            h.removeAll(bestComplement);
            addArrowBackward(a, b, h, naYX, parents, bestBump, arrowIndex, sortedArrowsBack);
        }
    }

    /**
     * Creates an arrow for deleting the edge from {@code a} to {@code b} and adds it to
     * {@code sortedArrowsBack}. The arrow is numbered with the current value of
     * {@code arrowIndex[0]}, which is then incremented.
     *
     * @param a                the source node of the edge.
     * @param b                the target node of the edge.
     * @param hOrT             the H set of the deletion.
     * @param naYX             the NaYX set of the pair.
     * @param parents          the parents of {@code b}.
     * @param bump             the score improvement of the deletion.
     * @param arrowIndex       single-element counter shared by all arrows of one search.
     * @param sortedArrowsBack the sorted set that receives the arrow.
     */
    private void addArrowBackward(Node a, Node b, Set<Node> hOrT, Set<Node> naYX, Set<Node> parents, double bump, int[] arrowIndex, SortedSet<Arrow> sortedArrowsBack) {
        // This method is reached from parallel fork-join subtasks that share arrowIndex. The
        // read-and-increment must be atomic: two arrows with the same index and the same bump
        // compare as equal, and the sorted set would silently drop one of them.
        int index;
        synchronized (arrowIndex) {
            index = arrowIndex[0];
            arrowIndex[0] = index + 1;
        }
        Arrow arrow = new Arrow(bump, a, b, hOrT, null, naYX, parents, index);
        sortedArrowsBack.add(arrow);
    }

    /**
     * A candidate edge operation between two nodes together with the data needed to re-check it
     * later: its bump, the H-or-T set, NaYX, the parents of {@code b} and a sequence number.
     * <p>
     * Arrows sort by descending bump; ties are broken by ascending index.
     */
    private static class Arrow
    implements Comparable<Arrow> {
        private final double bump;
        private final Node a;
        private final Node b;
        private final Set<Node> hOrT;
        private final Set<Node> naYX;
        private final Set<Node> parents;
        private final int index;
        private Set<Node> TNeighbors;

        /**
         * @param bump    the score improvement of the operation.
         * @param a       the first node.
         * @param b       the second node.
         * @param hOrT    the H-or-T set.
         * @param capTorH stored as the T-neighbors set; may be null.
         * @param naYX    the NaYX set.
         * @param parents the parents of {@code b}.
         * @param index   the sequence number used to break ties between equal bumps.
         */
        Arrow(double bump, Node a, Node b, Set<Node> hOrT, Set<Node> capTorH, Set<Node> naYX, Set<Node> parents, int index) {
            this.bump = bump;
            this.a = a;
            this.b = b;
            this.TNeighbors = capTorH;
            this.hOrT = hOrT;
            this.naYX = naYX;
            this.index = index;
            this.parents = parents;
        }

        public double getBump() {
            return bump;
        }

        public Node getA() {
            return a;
        }

        public Node getB() {
            return b;
        }

        Set<Node> getHOrT() {
            return hOrT;
        }

        Set<Node> getNaYX() {
            return naYX;
        }

        public Set<Node> getParents() {
            return parents;
        }

        public int getIndex() {
            return index;
        }

        public Set<Node> getTNeighbors() {
            return TNeighbors;
        }

        public void setTNeighbors(Set<Node> TNeighbors) {
            this.TNeighbors = TNeighbors;
        }

        /** Orders by bump, largest first; equal bumps are ordered by index, smallest first. */
        @Override
        public int compareTo(@NotNull Arrow arrow) {
            int byBump = Double.compare(arrow.getBump(), getBump());
            return byBump != 0 ? byBump : Integer.compare(getIndex(), arrow.getIndex());
        }

        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("Arrow<");
            text.append(a).append("->").append(b);
            text.append(" bump = ").append(bump);
            text.append(" t/h = ").append(hOrT);
            text.append(" TNeighbors = ").append(getTNeighbors());
            text.append(" parents = ").append(parents);
            text.append(" naYX = ").append(naYX);
            text.append(">");
            return text.toString();
        }
    }

    /**
     * The pair (NaYX, parents) under which an edge was last evaluated. Two configurations are
     * equal when both sets are equal.
     */
    private static class ArrowConfigBackward {
        private Set<Node> nayx;
        private Set<Node> parents;

        /**
         * @param nayx    the NaYX set; stored by reference.
         * @param parents the parent set; stored by reference.
         */
        public ArrowConfigBackward(Set<Node> nayx, Set<Node> parents) {
            this.nayx = nayx;
            this.parents = parents;
        }

        public void setNayx(Set<Node> nayx) {
            this.nayx = nayx;
        }

        public Set<Node> getParents() {
            return parents;
        }

        public void setParents(Set<Node> parents) {
            this.parents = parents;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (o == null || o.getClass() != getClass()) {
                return false;
            }
            ArrowConfigBackward other = (ArrowConfigBackward) o;
            if (!nayx.equals(other.nayx)) {
                return false;
            }
            return parents.equals(other.parents);
        }

        @Override
        public int hashCode() {
            return Objects.hash(nayx, parents);
        }
    }
}

