/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Score;
import edu.cmu.tetrad.util.SublistGenerator;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.SortedSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import org.apache.commons.math3.util.FastMath;
import org.jetbrains.annotations.NotNull;

/**
 * Backward equivalence search (BES) step: starting from a given graph, greedily removes adjacencies
 * whose deletion improves the score, best improvement first.
 *
 * <p>Candidate deletions ("arrows") are scored once, in parallel on the common fork-join pool, and
 * then applied sequentially in order of decreasing bump. Adjacency information used for scoring is
 * read from a snapshot of the input graph taken at the start of {@link #bes(Graph, List)}.
 */
public class LvBesJoe {
    /** Variables reported by the score, in score order. */
    private final List<Node> variables;
    /** Score used to evaluate local deletions. */
    private final Score score;
    /** Background knowledge; stored and returned, but not consulted by the search in this class. */
    private Knowledge knowledge = new Knowledge();
    /** Maximum subset size of common adjacents to consider; -1 means no limit (capped at 100000). */
    private int depth = -1;
    /** Snapshot of the graph passed to bes(); all adjacency lookups during scoring read from it. */
    private EdgeListGraph origGraph = null;

    /**
     * Creates a search over the variables of the given score.
     *
     * @param score the score used to evaluate deletions; its variables become this search's variables.
     */
    public LvBesJoe(@NotNull Score score) {
        this.score = score;
        this.variables = score.getVariables();
    }

    /**
     * Returns the variables of the score passed to the constructor.
     *
     * @return the score's variable list (not copied).
     */
    @NotNull
    public List<Node> getVariables() {
        return this.variables;
    }

    /**
     * Sets the background knowledge; a copy of the argument is stored.
     *
     * @param knowledge the knowledge to copy.
     */
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = new Knowledge(knowledge);
    }

    /**
     * Sets the maximum size of common-adjacent subsets examined per candidate deletion.
     *
     * @param depth the depth; -1 means unlimited.
     * @throws IllegalArgumentException if depth is less than -1.
     */
    public void setDepth(int depth) {
        if (depth < -1) {
            throw new IllegalArgumentException("Depth should be >= -1.");
        }
        this.depth = depth;
    }

    /**
     * Assigns each node its position in {@code nodes} (0, 1, 2, ...) in {@code hashIndices}.
     */
    private void buildIndexing(List<Node> nodes, Map<Node, Integer> hashIndices) {
        int i = -1;
        for (Node n : nodes) {
            hashIndices.put(n, ++i);
        }
    }

    /**
     * Runs the backward search, removing edges from {@code graph} in place.
     *
     * <p>All candidate deletions are scored once up front; after each deletion, only adjacency is
     * re-checked (candidates whose pair is no longer adjacent are skipped), not the score.
     *
     * @param graph     the graph to prune; modified in place.
     * @param variables the nodes to index for scoring. NOTE(review): presumably must contain every
     *                  node of {@code graph}, since indices are looked up for all adjacents.
     */
    public void bes(Graph graph, List<Node> variables) {
        this.origGraph = new EdgeListGraph(graph);
        Map<Node, Integer> hashIndices = new HashMap<Node, Integer>();
        ConcurrentSkipListSet<Arrow> sortedArrowsBack = new ConcurrentSkipListSet<Arrow>();
        Map<Edge, ArrowConfigBackward> arrowsMapBackward = new ConcurrentHashMap<Edge, ArrowConfigBackward>();
        int[] arrowIndex = new int[1];
        buildIndexing(variables, hashIndices);
        reevaluateBackward(new HashSet<Node>(variables), graph, hashIndices, arrowIndex,
                sortedArrowsBack, arrowsMapBackward);

        while (!sortedArrowsBack.isEmpty()) {
            // Best bump first (see Arrow.compareTo).
            Arrow arrow = sortedArrowsBack.pollFirst();
            Node x = arrow.getA();
            Node y = arrow.getB();
            if (!graph.isAdjacentTo(x, y)) {
                continue;
            }
            Set<Node> complement = new HashSet<Node>(arrow.getCommonAdjacents());
            complement.removeAll(arrow.getHOrT());
            double bump = deleteEval(x, y, complement, arrow.getParents(), hashIndices);
            delete(x, y, arrow.getHOrT(), bump, arrow.getCommonAdjacents(), graph);
        }
    }

    /**
     * Removes the x–y edge from {@code graph}, then, for each h in {@code H}, sets an ARROW
     * endpoint via {@code setEndpoint(x, h, ...)} / {@code setEndpoint(y, h, ...)} on whichever
     * of the edges x–h and y–h are still present.
     *
     * @param bump currently unused; kept for signature compatibility.
     * @param ca   currently unused; kept for signature compatibility.
     */
    private void delete(Node x, Node y, Set<Node> H, double bump, Set<Node> ca, Graph graph) {
        Edge oldxy = graph.getEdge(x, y);
        graph.removeEdge(oldxy);
        int numEdges = graph.getNumEdges();
        if (numEdges % 1000 == 0 && numEdges > 0) {
            System.out.println("Num edges (backwards) = " + numEdges);
        }
        for (Node h : H) {
            if (graph.isAdjacentTo(x, h)) {
                graph.setEndpoint(x, h, Endpoint.ARROW);
            }
            if (graph.isAdjacentTo(y, h)) {
                graph.setEndpoint(y, h, Endpoint.ARROW);
            }
        }
    }

    /**
     * Scores deleting x–y as the negation of {@code score.localScoreDiff(x, y, Z)}, where
     * Z = (complement ∪ parents) \ {x}.
     */
    private double deleteEval(Node x, Node y, Set<Node> complement, Set<Node> parents, Map<Node, Integer> hashIndices) {
        Set<Node> set = new HashSet<Node>(complement);
        set.addAll(parents);
        set.remove(x);
        return -scoreGraphChange(x, y, set, hashIndices);
    }

    /**
     * Translates nodes to indices and returns {@code score.localScoreDiff(x, y, parents)}.
     *
     * @throws IllegalArgumentException if x == y or parents contains y.
     */
    private double scoreGraphChange(Node x, Node y, Set<Node> parents, Map<Node, Integer> hashIndices) {
        int xIndex = hashIndices.get(x);
        int yIndex = hashIndices.get(y);
        if (x == y) {
            throw new IllegalArgumentException();
        }
        if (parents.contains(y)) {
            throw new IllegalArgumentException();
        }
        int[] parentIndices = new int[parents.size()];
        int count = 0;
        for (Node parent : parents) {
            parentIndices[count++] = hashIndices.get(parent);
        }
        return this.score.localScoreDiff(xIndex, yIndex, parentIndices);
    }

    /**
     * Returns the stored background knowledge.
     */
    public Knowledge getKnowledge() {
        return this.knowledge;
    }

    /**
     * Returns the nodes adjacent to both x and y in the snapshot graph.
     */
    private Set<Node> getCommonAdjacents(Node x, Node y) {
        // Copy before mutating: this runs concurrently from BackwardTask, so never modify a list
        // the graph handed back.
        Set<Node> ca = new HashSet<Node>(this.origGraph.getAdjacentNodes(x));
        ca.retainAll(this.origGraph.getAdjacentNodes(y));
        return ca;
    }

    /**
     * Scores, in parallel, every candidate deletion between nodes of {@code toProcess} that are
     * adjacent in {@code graph}, adding the profitable ones to {@code sortedArrowsBack}.
     *
     * <p>Each unordered adjacent pair is visited from both endpoints, and each visit scores both
     * directions, so duplicate arrows (with distinct indices) can be produced; bes() tolerates
     * this by skipping pairs that are no longer adjacent.
     *
     * @param arrowsMapBackward currently unused by the scoring; passed through for compatibility.
     */
    private void reevaluateBackward(Set<Node> toProcess, Graph graph, Map<Node, Integer> hashIndices, int[] arrowIndex, SortedSet<Arrow> sortedArrowsBack, Map<Edge, ArrowConfigBackward> arrowsMapBackward) {
        // The candidate list and chunk size do not depend on r, so build them once.
        List<Node> candidates = new ArrayList<Node>(toProcess);
        int chunk = getChunkSize(candidates.size());
        for (Node r : toProcess) {
            ForkJoinPool.commonPool().invoke(new BackwardTask(r, candidates, chunk, 0, candidates.size(),
                    graph, hashIndices, arrowIndex, sortedArrowsBack, arrowsMapBackward));
        }
    }

    /**
     * Fork-join task that scores deletions between a fixed node {@code r} and the candidates in
     * {@code adj[from, to)}, splitting the range in half until it is at most {@code chunk} long.
     */
    private final class BackwardTask extends RecursiveTask<Boolean> {
        private final Node r;
        private final List<Node> adj;
        private final int chunk;
        private final int from;
        private final int to;
        private final Graph graph;
        private final Map<Node, Integer> hashIndices;
        private final int[] arrowIndex;
        private final SortedSet<Arrow> sortedArrowsBack;
        private final Map<Edge, ArrowConfigBackward> arrowsMapBackward;

        BackwardTask(Node r, List<Node> adj, int chunk, int from, int to, Graph graph,
                     Map<Node, Integer> hashIndices, int[] arrowIndex,
                     SortedSet<Arrow> sortedArrowsBack, Map<Edge, ArrowConfigBackward> arrowsMapBackward) {
            this.r = r;
            this.adj = adj;
            this.chunk = chunk;
            this.from = from;
            this.to = to;
            this.graph = graph;
            this.hashIndices = hashIndices;
            this.arrowIndex = arrowIndex;
            this.sortedArrowsBack = sortedArrowsBack;
            this.arrowsMapBackward = arrowsMapBackward;
        }

        @Override
        protected Boolean compute() {
            if (to - from <= chunk) {
                for (int i = from; i < to; i++) {
                    Node w = adj.get(i);
                    if (graph.getEdge(w, r) == null) {
                        continue;
                    }
                    calculateArrowsBackward(w, r, graph, hashIndices, arrowIndex, sortedArrowsBack);
                    calculateArrowsBackward(r, w, graph, hashIndices, arrowIndex, sortedArrowsBack);
                }
            } else {
                int mid = from + (to - from) / 2;
                invokeAll(
                        new BackwardTask(r, adj, chunk, from, mid, graph, hashIndices, arrowIndex,
                                sortedArrowsBack, arrowsMapBackward),
                        new BackwardTask(r, adj, chunk, mid, to, graph, hashIndices, arrowIndex,
                                sortedArrowsBack, arrowsMapBackward));
            }
            return true;
        }
    }

    /**
     * Returns the leaf size for BackwardTask splitting. The argument is currently ignored.
     */
    private int getChunkSize(int n) {
        return 5;
    }

    /**
     * Finds the best-scoring way to delete a–b and, if its bump is positive, records it as an arrow.
     *
     * <p>Enumerates subsets of the common adjacents of a and b (up to {@code depth}) as the
     * "complement", and for each, every subset of b's remaining snapshot adjacents as extra
     * conditioning nodes, keeping the maximum {@link #deleteEval} bump.
     */
    private void calculateArrowsBackward(Node a, Node b, Graph graph, Map<Node, Integer> hashIndices, int[] arrowIndex, SortedSet<Arrow> sortedArrowsBack) {
        Set<Node> ca = getCommonAdjacents(a, b);
        // Copy before mutating: the adjacency list came from the shared snapshot graph and this
        // method runs concurrently.
        List<Node> parents = new ArrayList<Node>(this.origGraph.getAdjacentNodes(b));
        parents.removeAll(ca);

        List<Node> caList = new ArrayList<Node>(ca);
        int maxDepth = FastMath.min(this.depth == -1 ? 100000 : this.depth, caList.size());
        SublistGenerator gen = new SublistGenerator(caList.size(), maxDepth);
        Set<Node> maxComplement = null;
        double maxBump = Double.NEGATIVE_INFINITY;
        Set<Node> maxParents = new HashSet<Node>();

        int[] choice;
        while ((choice = gen.next()) != null) {
            Set<Node> complement = GraphUtils.asSet(choice, caList);
            SublistGenerator gen2 = new SublistGenerator(parents.size(), -1);
            int[] choice2;
            while ((choice2 = gen2.next()) != null) {
                Set<Node> p = GraphUtils.asSet(choice2, parents);
                double bump = deleteEval(a, b, complement, p, hashIndices);
                if (bump > maxBump) {
                    maxBump = bump;
                    maxComplement = complement;
                    maxParents = p;
                }
            }
        }

        if (maxBump > 0.0) {
            Set<Node> h = new HashSet<Node>(ca);
            h.removeAll(maxComplement);
            addArrowBackward(a, b, h, ca, maxParents, maxBump, arrowIndex, sortedArrowsBack);
        }
    }

    /**
     * Records a candidate deletion with the next sequence index.
     */
    private void addArrowBackward(Node a, Node b, Set<Node> hOrT, Set<Node> naYX, Set<Node> parents, double bump, int[] arrowIndex, SortedSet<Arrow> sortedArrowsBack) {
        int n;
        // Called concurrently from BackwardTask instances. Without this lock two arrows can get the
        // same index; arrows with equal bump and equal index compare as equal, so the sorted set
        // would silently drop one of them.
        synchronized (arrowIndex) {
            n = arrowIndex[0]++;
        }
        sortedArrowsBack.add(new Arrow(bump, a, b, hOrT, null, naYX, parents, n));
    }

    /**
     * A candidate deletion of the a–b adjacency, with the score improvement ("bump") it yields.
     * Ordered by decreasing bump, ties broken by increasing index.
     */
    private static class Arrow
    implements Comparable<Arrow> {
        private final double bump;
        private final Node a;
        private final Node b;
        private final Set<Node> hOrT;
        private final Set<Node> naYX;
        private final Set<Node> parents;
        private final int index;
        private Set<Node> TNeighbors;

        Arrow(double bump, Node a, Node b, Set<Node> hOrT, Set<Node> capTorH, Set<Node> naYX, Set<Node> parents, int index) {
            this.bump = bump;
            this.a = a;
            this.b = b;
            this.setTNeighbors(capTorH);
            this.hOrT = hOrT;
            this.naYX = naYX;
            this.index = index;
            this.parents = parents;
        }

        public double getBump() {
            return this.bump;
        }

        public Node getA() {
            return this.a;
        }

        public Node getB() {
            return this.b;
        }

        /** The subset of common adjacents that is NOT in the chosen complement. */
        Set<Node> getHOrT() {
            return this.hOrT;
        }

        /** Nodes adjacent to both a and b in the snapshot graph. */
        Set<Node> getCommonAdjacents() {
            return this.naYX;
        }

        /**
         * Higher bump sorts first; equal bumps sort by ascending index. Arrows comparing as 0 are
         * treated as duplicates by sorted sets, so indices must be unique.
         */
        @Override
        public int compareTo(@NotNull Arrow arrow) {
            int compare = Double.compare(arrow.getBump(), this.getBump());
            if (compare == 0) {
                return Integer.compare(this.getIndex(), arrow.getIndex());
            }
            return compare;
        }

        @Override
        public String toString() {
            return "Arrow<" + this.a + "->" + this.b + " bump = " + this.bump + " t/h = " + this.hOrT + " TNeighbors = " + this.getTNeighbors() + " parents = " + this.parents + " naYX = " + this.naYX + ">";
        }

        public int getIndex() {
            return this.index;
        }

        public Set<Node> getTNeighbors() {
            return this.TNeighbors;
        }

        public void setTNeighbors(Set<Node> TNeighbors) {
            this.TNeighbors = TNeighbors;
        }

        /** Extra conditioning nodes chosen when the bump was computed. */
        public Set<Node> getParents() {
            return this.parents;
        }
    }

    /**
     * Value pair (common adjacents, parents) for a backward configuration. NOTE(review): no entries
     * are ever put into the map of these in this class; it appears to be vestigial.
     */
    private static class ArrowConfigBackward {
        private Set<Node> ca;
        private Set<Node> parents;

        public ArrowConfigBackward(Set<Node> nayx, Set<Node> parents) {
            this.setCa(nayx);
            this.setParents(parents);
        }

        public void setCa(Set<Node> ca) {
            this.ca = ca;
        }

        public Set<Node> getParents() {
            return this.parents;
        }

        public void setParents(Set<Node> parents) {
            this.parents = parents;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            ArrowConfigBackward that = (ArrowConfigBackward) o;
            return this.ca.equals(that.ca) && this.parents.equals(that.parents);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.ca, this.parents);
        }
    }
}

