/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.utils;

import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodePair;
import edu.cmu.tetrad.graph.OrderedPair;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.score.GraphScore;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.search.utils.GrowShrinkTree;
import edu.cmu.tetrad.search.utils.MeekRules;
import edu.cmu.tetrad.util.RandomUtil;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math3.util.FastMath;
import org.jetbrains.annotations.NotNull;

/**
 * Scores permutations ("orders") of a fixed set of variables.
 * <p>
 * For each position {@code p} of the current order, a parent set for the variable at {@code p} is
 * chosen from the variables that precede it (its prefix), using one of three strategies:
 * <ul>
 *     <li>grow-shrink over a {@link Score}, through one {@link GrowShrinkTree} per variable (used
 *     when a non-{@link GraphScore} score is supplied);</li>
 *     <li>grow-shrink over an {@link IndependenceTest}, scored as minus the number of parents;</li>
 *     <li>the Raskutti-Uhler rule: a prefix variable is a parent when it is dependent on the
 *     variable given all other prefix variables, also scored as minus the number of parents.</li>
 * </ul>
 * Results are cached per position and invalidated whenever the order changes, so parents and local
 * scores are recomputed lazily on the next query. The state of the order can be saved and restored
 * with bookmarks.
 * <p>
 * NOTE(review): the class has no synchronization; presumably instances are used from one thread.
 */
public class TeyssierScorer {
    private final List<Node> variables;
    private final IndependenceTest test;
    private final Score score;
    // Saved states, keyed by bookmark id; the four maps are always written and cleared together.
    private final Map<Object, ArrayList<Node>> bookmarkedOrders = new HashMap<>();
    private final Map<Object, ArrayList<Pair>> bookmarkedScores = new HashMap<>();
    private final Map<Object, Map<Node, Integer>> bookmarkedOrderHashes = new HashMap<>();
    private final Map<Object, Double> bookmarkedRunningScores = new HashMap<>();
    // One grow-shrink tree per variable; populated only when score mode is active at construction.
    private final Map<Node, GrowShrinkTree> trees = new HashMap<>();
    // The current order.
    private ArrayList<Node> pi;
    // Position of each variable in pi.
    private Map<Node, Integer> orderHash = new HashMap<>();
    // Per-position prefix cache consulted by recalculate/lastMoveSame.
    // NOTE(review): nothing in this class stores a non-null entry here, so the cache is inert.
    private ArrayList<Set<Node>> prefixes;
    // Cached (parents, local score) per position; null means "recompute on demand".
    private ArrayList<Pair> scores;
    private Knowledge knowledge = new Knowledge();
    private boolean useScore;
    private boolean useRaskuttiUhler = false;
    // Sum of the local scores of the non-null entries of scores.
    private double runningScore = 0.0;

    /**
     * Creates a scorer over the variables of {@code score}, or of {@code test} when no score is
     * given. The initial order is the variable list itself.
     *
     * @param test  the independence test used in test-based and Raskutti-Uhler modes; may be null
     *              when a score is supplied.
     * @param score the score used in score mode; may be null when a test is supplied.
     * @throws IllegalArgumentException if both {@code test} and {@code score} are null.
     */
    public TeyssierScorer(IndependenceTest test, Score score) {
        if (test == null && score == null) {
            throw new IllegalArgumentException("Required: test or score");
        }
        this.test = test;
        this.score = score;
        // Fall back to the test's variables so a test-only scorer can be constructed.
        this.variables = score != null ? score.getVariables() : test.getVariables();
        this.pi = new ArrayList<>(this.variables);
        HashMap<Node, Integer> variablesHash = new HashMap<>();
        nodesHash(variablesHash, this.variables);
        nodesHash(this.orderHash, this.pi);
        // Give the caches a consistent (all-null) state so queries work before score(order).
        resetCaches();
        setUseScore(true);
        if (this.useScore) {
            for (Node node : this.variables) {
                this.trees.put(node, new GrowShrinkTree(score, variablesHash, node));
            }
        }
    }

    /**
     * Requests score-based parent selection. The request is honored only when a score is present
     * and it is not a {@link GraphScore}; enabling it disables the Raskutti-Uhler mode.
     *
     * @param useScore whether to use the score (true) or the independence test (false).
     */
    public void setUseScore(boolean useScore) {
        this.useScore = useScore && this.score != null && !(this.score instanceof GraphScore);
        if (this.useScore) {
            this.useRaskuttiUhler = false;
        }
    }

    /**
     * Sets the background knowledge and forwards each variable's required and forbidden parents to
     * its grow-shrink tree (when such a tree exists).
     *
     * @param knowledge the knowledge to use.
     */
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = knowledge;
        for (Node node : this.variables) {
            GrowShrinkTree tree = this.trees.get(node);
            if (tree == null) {
                // Trees exist only in score mode; the test-based paths read this.knowledge directly.
                continue;
            }
            ArrayList<Node> required = new ArrayList<>();
            ArrayList<Node> forbidden = new ArrayList<>();
            for (Node parent : this.variables) {
                if (knowledge.isRequired(parent.getName(), node.getName())) {
                    required.add(parent);
                }
                if (knowledge.isForbidden(parent.getName(), node.getName())) {
                    forbidden.add(parent);
                }
            }
            if (!required.isEmpty() || !forbidden.isEmpty()) {
                tree.setKnowledge(required, forbidden);
            }
        }
    }

    /**
     * Enables or disables the Raskutti-Uhler parent rule; enabling it turns score mode off.
     *
     * @param useRaskuttiUhler whether to use the Raskutti-Uhler rule.
     */
    public void setUseRaskuttiUhler(boolean useRaskuttiUhler) {
        this.useRaskuttiUhler = useRaskuttiUhler;
        if (useRaskuttiUhler) {
            this.useScore = false;
        }
    }

    /**
     * Makes {@code order} the current order, discards all cached results, and returns its score.
     *
     * @param order the new order (copied).
     * @return the sum of the local scores of all positions.
     */
    public double score(List<Node> order) {
        this.pi = new ArrayList<>(order);
        // Rebuilt below so variables absent from the new order keep no stale position.
        this.orderHash = new HashMap<>();
        resetCaches();
        initializeScores();
        return score();
    }

    /**
     * Returns the score of the current order, computing any missing local scores.
     *
     * @return the sum of the local scores of all positions.
     */
    public double score() {
        return sum();
    }

    /**
     * Moves {@code x} to the position of {@code y} if {@code y} currently precedes {@code x}.
     *
     * @param x the node to move.
     * @param y the node whose position is the target.
     */
    public void swaptuck(Node x, Node y) {
        int yIndex = index(y);
        if (yIndex < index(x)) {
            moveTo(x, yIndex);
        }
    }

    /**
     * Moves {@code k} and all of its ancestors that lie strictly between position {@code j} and
     * {@code k} to start at position {@code j}, preserving their relative order.
     *
     * @param k the node to tuck.
     * @param j the target position.
     * @return false (and leaves the order unchanged) if {@code k} is adjacent to the node at
     * {@code j} or does not come after {@code j}; true otherwise.
     */
    public boolean tuck(Node k, int j) {
        if (adjacent(k, get(j))) {
            return false;
        }
        if (j >= index(k)) {
            return false;
        }
        // Ancestors (including k itself) are computed once, against the order before any move.
        Set<Node> ancestors = getAncestors(k);
        for (int i = j + 1; i <= index(k); ++i) {
            if (ancestors.contains(get(i))) {
                // The move shifts [j, i) right by one, so position i + 1 is still unvisited.
                moveTo(get(i), j++);
            }
        }
        return true;
    }

    /**
     * Moves {@code v} to {@code toIndex}, shifting the nodes in between by one position.
     *
     * @param v       the node to move.
     * @param toIndex the target position.
     * @throws IndexOutOfBoundsException if {@code toIndex} is not a valid position.
     */
    public void moveTo(Node v, int toIndex) {
        int vIndex = index(v);
        if (vIndex == toIndex) {
            return;
        }
        // Validate before mutating pi; otherwise v would be removed and the insert would then fail.
        if (toIndex < 0 || toIndex >= this.pi.size()) {
            throw new IndexOutOfBoundsException(
                    "Index " + toIndex + " out of bounds for order of size " + this.pi.size());
        }
        if (lastMoveSame(vIndex, toIndex)) {
            return;
        }
        this.pi.remove(v);
        this.pi.add(toIndex, v);
        updateScores(Math.min(vIndex, toIndex), Math.max(vIndex, toIndex));
    }

    /**
     * Swaps the positions of {@code m} and {@code n}, unless the result would place a required
     * edge's head before its tail.
     *
     * @param m one node.
     * @param n the other node.
     * @return true if the swap was made; false if it was rejected by the knowledge.
     * @throws IllegalArgumentException if either node is not in the current order.
     */
    public boolean swap(Node m, Node n) {
        int i = index(m);
        int j = index(n);
        this.pi.set(i, n);
        this.pi.set(j, m);
        if (violatesKnowledge(this.pi)) {
            this.pi.set(i, m);
            this.pi.set(j, n);
            return false;
        }
        updateScores(Math.min(i, j), Math.max(i, j));
        return true;
    }

    /**
     * Returns whether {@code x} and {@code y} are adjacent and have the same parents apart from
     * each other.
     *
     * @param x one endpoint.
     * @param y the other endpoint.
     * @return true if the edge between x and y is covered.
     */
    public boolean coveredEdge(Node x, Node y) {
        if (!adjacent(x, y)) {
            return false;
        }
        Set<Node> px = getParents(x);
        Set<Node> py = getParents(y);
        px.remove(y);
        py.remove(x);
        return px.equals(py);
    }

    /**
     * Returns a copy of the current order.
     *
     * @return a new list holding the current order.
     */
    public List<Node> getPi() {
        return new ArrayList<>(this.pi);
    }

    /**
     * Returns the internal order list itself (not a copy); changes to it bypass all cache updates.
     *
     * @return the live order list.
     */
    public List<Node> getOrderShallow() {
        return this.pi;
    }

    /**
     * Returns the position of {@code v} in the current order.
     *
     * @param v the node.
     * @return its position.
     * @throws IllegalArgumentException if {@code v} has no recorded position.
     */
    public int index(Node v) {
        Integer position = this.orderHash.get(v);
        if (position == null) {
            throw new IllegalArgumentException("First 'evaluate' a permutation containing variable " + v + ".");
        }
        return position;
    }

    /**
     * Returns (a copy of) the parents of the node at position {@code p}, computing them if needed.
     *
     * @param p the position.
     * @return a new set of parents.
     */
    public Set<Node> getParents(int p) {
        if (this.scores.get(p) == null) {
            recalculate(p);
        }
        return new HashSet<>(this.scores.get(p).getParents());
    }

    /**
     * Returns the nodes adjacent to the node at position {@code p} that are not its parents.
     *
     * @param p the position.
     * @return a new set of children.
     */
    public Set<Node> getChildren(int p) {
        Set<Node> adj = getAdjacentNodes(get(p));
        adj.removeAll(getParents(p));
        return adj;
    }

    /**
     * Returns (a copy of) the parents of {@code v}.
     *
     * @param v the node.
     * @return a new set of parents.
     */
    public Set<Node> getParents(Node v) {
        return getParents(index(v));
    }

    /**
     * Returns the children of {@code v}.
     *
     * @param v the node.
     * @return a new set of children.
     */
    public Set<Node> getChildren(Node v) {
        return getChildren(index(v));
    }

    /**
     * Returns every node of the order that is a parent or a child of {@code v}.
     *
     * @param v the node.
     * @return a new set of adjacent nodes.
     */
    public Set<Node> getAdjacentNodes(Node v) {
        // Loop-invariant: computed once instead of once per candidate.
        Set<Node> parentsOfV = getParents(v);
        Set<Node> adj = new HashSet<>();
        for (Node w : this.pi) {
            if (parentsOfV.contains(w) || getParents(w).contains(v)) {
                adj.add(w);
            }
        }
        return adj;
    }

    /**
     * Builds the DAG implied by the current parent sets, optionally passed through Meek rules
     * (with the current knowledge) to obtain a CPDAG.
     *
     * @param cpDag whether to apply the Meek rules.
     * @return a new graph over the scorer's variables.
     */
    public Graph getGraph(boolean cpDag) {
        EdgeListGraph graph = new EdgeListGraph(this.variables);
        for (Node child : this.variables) {
            for (Node parent : getParents(child)) {
                graph.addDirectedEdge(parent, child);
            }
        }
        if (cpDag) {
            MeekRules rules = new MeekRules();
            rules.setKnowledge(this.knowledge);
            rules.orientImplied(graph);
        }
        return graph;
    }

    /**
     * Returns every adjacent pair of the current model, each unordered pair once.
     *
     * @return a new list of node pairs.
     */
    public List<NodePair> getAdjacencies() {
        List<Node> order = getPi();
        Set<NodePair> pairs = new HashSet<>();
        for (int i = 0; i < order.size(); ++i) {
            Node x = order.get(i);
            for (int j = 0; j < i; ++j) {
                Node y = order.get(j);
                if (adjacent(x, y)) {
                    pairs.add(new NodePair(x, y));
                }
            }
        }
        return new ArrayList<>(pairs);
    }

    /**
     * Returns the ancestors of {@code node} in the current model, including {@code node} itself.
     * The walk is iterative, so deep graphs cannot overflow the call stack.
     *
     * @param node the node.
     * @return a new set of ancestors.
     */
    public Set<Node> getAncestors(Node node) {
        Set<Node> ancestors = new HashSet<>();
        Deque<Node> toVisit = new ArrayDeque<>();
        toVisit.push(node);
        while (!toVisit.isEmpty()) {
            Node current = toVisit.pop();
            if (ancestors.add(current)) {
                for (Node parent : getParents(current)) {
                    toVisit.push(parent);
                }
            }
        }
        return ancestors;
    }

    /**
     * Returns every directed edge (parent, child) of the current model, grouped by child in order.
     *
     * @return a new list of ordered pairs.
     */
    public List<OrderedPair<Node>> getEdges() {
        List<Node> order = getPi();
        List<OrderedPair<Node>> edges = new ArrayList<>();
        for (Node y : order) {
            for (Node x : getParents(y)) {
                edges.add(new OrderedPair<>(x, y));
            }
        }
        return edges;
    }

    /**
     * Returns the number of directed edges of the current model.
     *
     * @return the total number of parents over all positions.
     */
    public int getNumEdges() {
        int numEdges = 0;
        for (int p = 0; p < this.pi.size(); ++p) {
            numEdges += getParents(p).size();
        }
        return numEdges;
    }

    /**
     * Returns the node at position {@code j} of the current order.
     *
     * @param j the position.
     * @return the node there.
     */
    public Node get(int j) {
        return this.pi.get(j);
    }

    /**
     * Saves the current order, cached results, positions and running total under {@code key},
     * replacing any earlier bookmark with the same key.
     *
     * @param key the bookmark key.
     */
    public void bookmark(int key) {
        this.bookmarkedOrders.put(key, new ArrayList<>(this.pi));
        this.bookmarkedScores.put(key, new ArrayList<>(this.scores));
        this.bookmarkedOrderHashes.put(key, new HashMap<>(this.orderHash));
        this.bookmarkedRunningScores.put(key, this.runningScore);
    }

    /**
     * Saves the current state under the default key {@link Integer#MIN_VALUE}.
     */
    public void bookmark() {
        bookmark(Integer.MIN_VALUE);
    }

    /**
     * Restores the state saved under {@code key}. The bookmark itself is kept.
     *
     * @param key the bookmark key.
     * @throws IllegalArgumentException if nothing was bookmarked under {@code key}.
     */
    public void goToBookmark(int key) {
        if (!this.bookmarkedOrders.containsKey(key)) {
            throw new IllegalArgumentException("That key was not bookmarked: " + key);
        }
        this.pi = new ArrayList<>(this.bookmarkedOrders.get(key));
        this.scores = new ArrayList<>(this.bookmarkedScores.get(key));
        this.orderHash = new HashMap<>(this.bookmarkedOrderHashes.get(key));
        this.runningScore = this.bookmarkedRunningScores.get(key);
        // The prefix cache is not bookmarked; size it to the restored order so indexing stays valid.
        this.prefixes = new ArrayList<>(this.pi.size());
        for (int i = 0; i < this.pi.size(); ++i) {
            this.prefixes.add(null);
        }
    }

    /**
     * Restores the state saved under the default key {@link Integer#MIN_VALUE}.
     */
    public void goToBookmark() {
        goToBookmark(Integer.MIN_VALUE);
    }

    /**
     * Removes all bookmarks.
     */
    public void clearBookmarks() {
        this.bookmarkedOrders.clear();
        this.bookmarkedScores.clear();
        this.bookmarkedOrderHashes.clear();
        this.bookmarkedRunningScores.clear();
    }

    /**
     * Returns the length of the current order.
     *
     * @return the number of nodes in the order.
     */
    public int size() {
        return this.pi.size();
    }

    /**
     * Returns a randomly shuffled copy of the current order.
     *
     * @return a new, shuffled list.
     */
    public List<Node> getShuffledVariables() {
        List<Node> shuffled = getPi();
        RandomUtil.shuffle(shuffled);
        return shuffled;
    }

    /**
     * Returns whether {@code a} and {@code b} are distinct and one is a parent of the other.
     *
     * @param a one node.
     * @param b the other node.
     * @return true if adjacent.
     */
    public boolean adjacent(Node a, Node b) {
        if (a == b) {
            return false;
        }
        return getParents(a).contains(b) || getParents(b).contains(a);
    }

    /**
     * Returns whether both {@code a} and {@code c} are parents of {@code b}.
     *
     * @param a one endpoint.
     * @param b the middle node.
     * @param c the other endpoint.
     * @return true if a -> b <- c.
     */
    public boolean collider(Node a, Node b, Node c) {
        Set<Node> parentsOfB = getParents(b);
        return parentsOfB.contains(a) && parentsOfB.contains(c);
    }

    /**
     * Returns whether {@code a}, {@code b} and {@code c} are pairwise adjacent.
     *
     * @param a first node.
     * @param b second node.
     * @param c third node.
     * @return true if they form a triangle.
     */
    public boolean triangle(Node a, Node b, Node c) {
        return adjacent(a, b) && adjacent(b, c) && adjacent(a, c);
    }

    /**
     * Returns whether every pair of distinct entries of {@code W} is adjacent.
     *
     * @param W the nodes to test.
     * @return true if W is a clique (trivially true for fewer than two nodes).
     */
    public boolean clique(List<Node> W) {
        for (int i = 0; i < W.size(); ++i) {
            for (int j = i + 1; j < W.size(); ++j) {
                if (!adjacent(W.get(i), W.get(j))) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns the set of nodes at positions {@code 0 .. i-1} of the current order.
     *
     * @param i the position whose prefix is wanted.
     * @return a new set of the preceding nodes.
     */
    public Set<Node> getPrefix(int i) {
        Set<Node> prefix = new HashSet<>();
        for (int j = 0; j < i; ++j) {
            prefix.add(this.pi.get(j));
        }
        return prefix;
    }

    /**
     * Returns the skeleton of the current model as a set of two-element node sets.
     *
     * @return a new set of adjacencies.
     */
    public Set<Set<Node>> getSkeleton() {
        List<Node> order = getPi();
        Set<Set<Node>> skeleton = new HashSet<>();
        for (Node y : order) {
            for (Node x : getParents(y)) {
                Set<Node> adj = new HashSet<>();
                adj.add(x);
                adj.add(y);
                skeleton.add(adj);
            }
        }
        return skeleton;
    }

    /**
     * Returns whether {@code k} is a parent of {@code j}.
     *
     * @param k the candidate parent.
     * @param j the candidate child.
     * @return true if k -> j.
     */
    public boolean parent(Node k, Node j) {
        return getParents(j).contains(k);
    }

    // Replaces the per-position caches with all-null lists sized to the current order and zeroes the
    // running total, so every position is recomputed on demand.
    private void resetCaches() {
        int n = this.pi.size();
        this.scores = new ArrayList<>(n);
        this.prefixes = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            this.scores.add(null);
            this.prefixes.add(null);
        }
        this.runningScore = 0.0;
    }

    // True if some required edge points from a later node of the order to an earlier one.
    private boolean violatesKnowledge(List<Node> order) {
        if (this.knowledge.isEmpty()) {
            return false;
        }
        for (int i = 0; i < order.size(); ++i) {
            for (int j = 0; j < i; ++j) {
                if (this.knowledge.isRequired(order.get(i).getName(), order.get(j).getName())) {
                    return true;
                }
            }
        }
        return false;
    }

    // Clears the prefix cache and invalidates every position of the current order.
    private void initializeScores() {
        for (int i = 0; i < this.pi.size(); ++i) {
            this.prefixes.set(i, null);
        }
        updateScores(0, this.pi.size() - 1);
    }

    // Records the positions i1..i2 (inclusive) in orderHash and discards their cached results.
    private void updateScores(int i1, int i2) {
        for (int i = i1; i <= i2; ++i) {
            this.orderHash.put(this.pi.get(i), i);
            Pair old = this.scores.get(i);
            if (old != null) {
                // Keep runningScore equal to the sum over the entries that remain cached.
                this.runningScore -= old.getScore();
            }
            this.scores.set(i, null);
        }
    }

    // Computes and caches the parents/score of position p unless a valid cached result exists.
    private void recalculate(int p) {
        Set<Node> cachedPrefix = this.prefixes.get(p);
        Pair old = this.scores.get(p);
        // An empty score slot is always filled, so callers can dereference it afterwards.
        if (old != null && cachedPrefix != null && cachedPrefix.containsAll(getPrefix(p))) {
            return;
        }
        Pair fresh = getParentsInternal(p);
        this.runningScore += old == null ? fresh.getScore() : fresh.getScore() - old.getScore();
        this.scores.set(p, fresh);
    }

    // Sum of the local scores of all positions, computing missing ones.
    private double sum() {
        double total = 0.0;
        for (int i = 0; i < this.pi.size(); ++i) {
            if (this.scores.get(i) == null) {
                recalculate(i);
            }
            total += this.scores.get(i).getScore();
        }
        return total;
    }

    // Maps each entry of the list to its index.
    private void nodesHash(Map<Node, Integer> nodesHash, List<Node> variables) {
        for (int i = 0; i < variables.size(); ++i) {
            nodesHash.put(variables.get(i), i);
        }
    }

    // True if, for every position in the span between i1 and i2, the cumulative node set up to and
    // including that position equals the cached prefix stored there. The walk starts from the prefix
    // of the lower endpoint regardless of move direction.
    private boolean lastMoveSame(int i1, int i2) {
        int lo = Math.min(i1, i2);
        int hi = Math.max(i1, i2);
        Set<Node> prefix = getPrefix(lo);
        for (int i = lo; i <= hi; ++i) {
            prefix.add(get(i));
            if (!prefix.equals(this.prefixes.get(i))) {
                return false;
            }
        }
        return true;
    }

    // Score mode: asks the node's grow-shrink tree for the best parents within the prefix.
    // A NaN score is mapped to negative infinity.
    @NotNull
    private Pair getGrowShrinkScore(int p) {
        Node n = this.pi.get(p);
        HashSet<Node> prefix = new HashSet<>(getPrefix(p));
        HashSet<Node> all = new HashSet<>(this.variables);
        LinkedHashSet<Node> parents = new LinkedHashSet<>();
        double sMax = this.trees.get(n).trace(prefix, all, parents);
        return new Pair(parents, Double.isNaN(sMax) ? Double.NEGATIVE_INFINITY : sMax);
    }

    // Test mode: grow-shrink over the prefix with the independence test, honoring required and
    // forbidden edges, repeated until a full pass changes nothing. Score is minus the parent count.
    private Pair getGrowShrinkIndependent(int p) {
        Node n = this.pi.get(p);
        boolean hasKnowledge = !this.knowledge.isEmpty();
        HashSet<Node> parents = new HashSet<>();
        Set<Node> prefix = getPrefix(p);
        boolean changed = true;
        while (changed) {
            changed = false;
            // Grow: add prefix nodes that are dependent on n given the current parents.
            for (Node z0 : prefix) {
                if (parents.contains(z0)) {
                    continue;
                }
                if (hasKnowledge && this.knowledge.isForbidden(z0.getName(), n.getName())) {
                    continue;
                }
                if (hasKnowledge && this.knowledge.isRequired(z0.getName(), n.getName())) {
                    parents.add(z0);
                    continue;
                }
                if (this.test.checkIndependence(n, z0, new HashSet<>(parents)).isDependent()) {
                    parents.add(z0);
                    changed = true;
                }
            }
            // Shrink: drop non-required parents that are independent of n given the others.
            for (Node z1 : new HashSet<>(parents)) {
                if (hasKnowledge && this.knowledge.isRequired(z1.getName(), n.getName())) {
                    continue;
                }
                parents.remove(z1);
                if (this.test.checkIndependence(n, z1, new HashSet<>(parents)).isDependent()) {
                    parents.add(z1);
                } else {
                    changed = true;
                }
            }
        }
        return new Pair(parents, -parents.size());
    }

    // Dispatches to the active parent-selection strategy.
    private Pair getParentsInternal(int p) {
        if (this.useRaskuttiUhler) {
            return getRaskuttiUhlerParents(p);
        }
        if (this.useScore) {
            return getGrowShrinkScore(p);
        }
        return getGrowShrinkIndependent(p);
    }

    // Raskutti-Uhler: y in the prefix is a parent iff it is dependent on x given the rest of the
    // prefix. Score is minus the parent count.
    private Pair getRaskuttiUhlerParents(int p) {
        Node x = this.pi.get(p);
        HashSet<Node> parents = new HashSet<>();
        Set<Node> prefix = getPrefix(p);
        for (Node y : prefix) {
            HashSet<Node> z = new HashSet<>(prefix);
            z.remove(y);
            if (this.test.checkIndependence(x, y, z).isDependent()) {
                parents.add(y);
            }
        }
        return new Pair(parents, -parents.size());
    }

    /**
     * A parent set together with its local score.
     */
    private static class Pair {
        private final Set<Node> parents;
        private final double score;

        private Pair(Set<Node> parents, double score) {
            this.parents = parents;
            this.score = score;
        }

        public Set<Node> getParents() {
            return this.parents;
        }

        public double getScore() {
            return this.score;
        }

        @Override
        public int hashCode() {
            return this.parents.hashCode() + (int) FastMath.floor(10000.0 * this.score);
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof Pair)) {
                return false;
            }
            Pair thatPair = (Pair) o;
            return this.parents.equals(thatPair.parents) && this.score == thatPair.score;
        }
    }
}

