/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeEqualityMode;
import edu.cmu.tetrad.graph.NodePair;
import edu.cmu.tetrad.graph.OrderedPair;
import edu.cmu.tetrad.search.Score;
import edu.cmu.tetrad.search.SearchGraphUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.jetbrains.annotations.NotNull;

/**
 * Scores a permutation of variables by choosing, for every variable, a parent set
 * from the variables that precede it in the permutation (its "prefix") using a
 * grow-shrink procedure over a local {@link Score}.
 *
 * <p>Typical use: call {@link #score(List)} once to evaluate a permutation, then
 * modify the permutation with {@link #moveTo(Node, int)}, {@link #demote(Node)} or
 * {@link #swap(Node, Node)}; these re-score only the affected index range. The
 * accessors that read per-position results ({@link #score()}, {@link #getParents(int)},
 * {@link #getGraph(boolean)}, ...) rely on {@link #score(List)} having been called
 * first, because the per-position score list starts out empty.
 *
 * <p>Local scores are memoized per (node, parent set) for the lifetime of this object.
 */
public class TeyssierScorerOpt {
    /** Variables as reported by the score; their positions define the indices passed to it. */
    private final List<Node> variables;
    /** Maps each variable to its index in {@link #variables}. */
    private final Map<Node, Integer> variablesHash;
    private final Score score;
    /** Memoized local scores: node -> (parent set -> score). */
    private final Map<Node, Map<Set<Node>, Float>> cache = new HashMap<>();
    /** Maps each node to its current position in {@link #pi}. */
    private final Map<Node, Integer> orderHash;
    /** The current permutation. */
    private List<Node> pi;
    /** Per-position (parents, score) results for the current permutation. */
    private List<Pair> scores = new ArrayList<>();
    private Knowledge knowledge = new Knowledge();
    /** Per-position prefix slots; reset to {@code null} entries by {@link #score(List)}. */
    private List<Set<Node>> prefixes;

    /**
     * Creates a scorer whose initial permutation is the score's variable order.
     *
     * <p>Note: this sets the global node equality mode to
     * {@code NodeEqualityMode.Type.OBJECT} as a side effect.
     *
     * @param score the local score used to evaluate parent sets
     */
    public TeyssierScorerOpt(@NotNull Score score) {
        NodeEqualityMode.setEqualityMode(NodeEqualityMode.Type.OBJECT);
        this.score = score;
        this.variables = score.getVariables();
        this.pi = new ArrayList<>(this.variables);
        this.orderHash = new HashMap<>();
        this.nodesHash(this.orderHash, this.pi);
        this.variablesHash = new HashMap<>();
        this.nodesHash(this.variablesHash, this.variables);
    }

    /**
     * Sets the background knowledge; a copy of the argument is stored.
     *
     * @param knowledge the knowledge to copy
     */
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = new Knowledge(knowledge);
    }

    /**
     * Evaluates the given permutation from scratch and makes it the current one.
     *
     * @param order the permutation to score; it is copied
     * @return the sum of the per-position scores
     */
    public float score(List<Node> order) {
        this.pi = new ArrayList<>(order);
        int n = order.size();

        // Placeholders so that recalculate() can use set(index, ...).
        this.scores = new ArrayList<>(n);
        this.prefixes = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            this.scores.add(null);
            this.prefixes.add(null);
        }

        this.initializeScores();
        return this.score();
    }

    /**
     * @return the sum of the per-position scores of the current permutation
     */
    public float score() {
        return this.sum();
    }

    /** Adds up the stored score of every position in the current permutation. */
    private float sum() {
        float total = 0.0f;
        for (int i = 0; i < this.pi.size(); ++i) {
            total += this.scores.get(i).getScore();
        }
        return total;
    }

    /**
     * Moves {@code v} to position {@code toIndex} and re-scores the positions
     * between its old and new index (inclusive).
     *
     * @param v       the node to move
     * @param toIndex the target position, in {@code [0, size() - 1]}
     * @throws IllegalArgumentException  if {@code v} has no recorded position
     * @throws IndexOutOfBoundsException if {@code toIndex} is out of range; the
     *                                   permutation is left unchanged in that case
     */
    public void moveTo(Node v, int toIndex) {
        int vIndex = this.index(v);
        if (vIndex == toIndex) {
            return;
        }
        // Validate before mutating: failing inside add() after remove() would
        // silently drop v from the permutation.
        if (toIndex < 0 || toIndex >= this.pi.size()) {
            throw new IndexOutOfBoundsException("Index: " + toIndex + ", Size: " + this.pi.size());
        }
        this.pi.remove(v);
        this.pi.add(toIndex, v);
        this.updateScores(Math.min(vIndex, toIndex), Math.max(vIndex, toIndex));
    }

    /**
     * Moves {@code v} one position later in the permutation, if it is not already last.
     *
     * @param v the node to demote
     * @throws IllegalArgumentException if {@code v} has no recorded position
     */
    public void demote(Node v) {
        int vIndex = this.index(v);
        int toIndex = vIndex + 1;
        if (toIndex > this.size() - 1) {
            return;
        }
        this.pi.remove(v);
        this.pi.add(toIndex, v);
        this.updateScores(vIndex, toIndex);
    }

    /**
     * Exchanges the positions of {@code m} and {@code n}. If the resulting order
     * violates the knowledge, the exchange is undone.
     *
     * @param m the first node
     * @param n the second node
     * @return {@code true} if the swap was applied, {@code false} if it was rejected
     * @throws IllegalArgumentException if either node has no recorded position
     */
    public boolean swap(Node m, Node n) {
        int i = this.index(m);
        int j = this.index(n);
        this.pi.set(i, n);
        this.pi.set(j, m);
        if (this.violatesKnowledge(this.pi)) {
            // Roll back.
            this.pi.set(i, m);
            this.pi.set(j, n);
            return false;
        }
        this.updateScores(Math.min(i, j), Math.max(i, j));
        return true;
    }

    /**
     * @return a copy of the current permutation
     */
    public List<Node> getPi() {
        return new ArrayList<>(this.pi);
    }

    /**
     * Returns the recorded position of {@code v}.
     *
     * @param v the node to look up
     * @return the position of {@code v}
     * @throws IllegalArgumentException if no position is recorded for {@code v}
     */
    public int index(Node v) {
        Integer position = this.orderHash.get(v);
        if (position == null) {
            throw new IllegalArgumentException("First 'evaluate' a permutation containing variable " + v + ".");
        }
        return position;
    }

    /**
     * @param p a position in the current permutation
     * @return a copy of the parent set stored for position {@code p}
     */
    public Set<Node> getParents(int p) {
        return new HashSet<>(this.scores.get(p).getParents());
    }

    /**
     * @param v a node
     * @return a copy of the parent set stored for {@code v}
     * @throws IllegalArgumentException if {@code v} has no recorded position
     */
    public Set<Node> getParents(Node v) {
        return this.getParents(this.index(v));
    }

    /**
     * @param v a node
     * @return every node in the permutation that is a parent of {@code v} or has
     *         {@code v} as a parent
     */
    public Set<Node> getAdjacentNodes(Node v) {
        // Loop-invariant: fetch (and copy) v's parents once rather than per node.
        Set<Node> parentsOfV = this.getParents(v);
        Set<Node> adj = new HashSet<>();
        for (Node w : this.pi) {
            if (parentsOfV.contains(w) || this.getParents(w).contains(v)) {
                adj.add(w);
            }
        }
        return adj;
    }

    /**
     * Builds a graph over the score's variables with a directed edge from every
     * stored parent to its child.
     *
     * @param cpDag if {@code true}, the result of
     *              {@code SearchGraphUtils.cpdagForDag} on that graph is returned instead
     * @return the graph
     */
    public Graph getGraph(boolean cpDag) {
        List<Node> order = this.getPi();
        EdgeListGraph dag = new EdgeListGraph(this.variables);
        for (int p = 0; p < order.size(); ++p) {
            Node child = order.get(p);
            for (Node parent : this.getParents(p)) {
                dag.addDirectedEdge(parent, child);
            }
        }
        GraphUtils.replaceNodes(dag, this.variables);
        if (cpDag) {
            return SearchGraphUtils.cpdagForDag(dag);
        }
        return dag;
    }

    /**
     * @return the distinct pairs of nodes in the permutation that are adjacent
     *         (one is a stored parent of the other)
     */
    public List<NodePair> getAdjacencies() {
        List<Node> order = this.getPi();
        Set<NodePair> pairs = new HashSet<>();
        for (int i = 0; i < order.size(); ++i) {
            Node x = order.get(i);
            for (int j = 0; j < i; ++j) {
                Node y = order.get(j);
                if (this.adjacent(x, y)) {
                    pairs.add(new NodePair(x, y));
                }
            }
        }
        return new ArrayList<>(pairs);
    }

    /**
     * @param node a node
     * @return {@code node} itself together with every node reachable from it by
     *         repeatedly following stored parents
     */
    public Set<Node> getAncestors(Node node) {
        Set<Node> ancestors = new HashSet<>();
        this.collectAncestorsVisit(node, ancestors);
        return ancestors;
    }

    /** Depth-first visit that adds {@code node} and, recursively, its parents to {@code ancestors}. */
    private void collectAncestorsVisit(Node node, Set<Node> ancestors) {
        if (ancestors.contains(node)) {
            return;
        }
        ancestors.add(node);
        for (Node parent : this.getParents(node)) {
            this.collectAncestorsVisit(parent, ancestors);
        }
    }

    /**
     * @return one ordered pair (parent, child) for every stored parent relation
     */
    public List<OrderedPair<Node>> getEdges() {
        List<OrderedPair<Node>> edges = new ArrayList<>();
        for (Node y : this.getPi()) {
            for (Node x : this.getParents(y)) {
                edges.add(new OrderedPair<>(x, y));
            }
        }
        return edges;
    }

    /**
     * @return the total number of stored parents over all positions
     */
    public int getNumEdges() {
        int numEdges = 0;
        for (int p = 0; p < this.pi.size(); ++p) {
            numEdges += this.getParents(p).size();
        }
        return numEdges;
    }

    /**
     * @param j a position
     * @return the node at position {@code j} of the current permutation
     */
    public Node get(int j) {
        return this.pi.get(j);
    }

    /**
     * @return the number of nodes in the current permutation
     */
    public int size() {
        return this.pi.size();
    }

    /**
     * @return {@code true} if {@code a} and {@code b} are different objects and one
     *         is a stored parent of the other
     */
    public boolean adjacent(Node a, Node b) {
        if (a == b) {
            return false;
        }
        return this.getParents(a).contains(b) || this.getParents(b).contains(a);
    }

    /**
     * @return {@code true} if both {@code a} and {@code c} are stored parents of {@code b}
     */
    public boolean collider(Node a, Node b, Node c) {
        Set<Node> parentsOfB = this.getParents(b);
        return parentsOfB.contains(a) && parentsOfB.contains(c);
    }

    /**
     * @return {@code true} if {@code a}, {@code b} and {@code c} are pairwise adjacent
     */
    public boolean triangle(Node a, Node b, Node c) {
        return this.adjacent(a, b) && this.adjacent(b, c) && this.adjacent(a, c);
    }

    /**
     * @param W the nodes to test
     * @return {@code true} if every pair of nodes in {@code W} is adjacent
     */
    public boolean clique(List<Node> W) {
        for (int i = 0; i < W.size(); ++i) {
            for (int j = i + 1; j < W.size(); ++j) {
                if (!this.adjacent(W.get(i), W.get(j))) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns {@code true} if the knowledge forbids an edge from some node to a
     * node that comes later in {@code order}.
     */
    private boolean violatesKnowledge(List<Node> order) {
        if (this.knowledge.isEmpty()) {
            return false;
        }
        for (int i = 0; i < order.size(); ++i) {
            for (int j = i + 1; j < order.size(); ++j) {
                if (this.knowledge.isForbidden(order.get(i).getName(), order.get(j).getName())) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Scores every position of the current permutation and records each node's position. */
    private void initializeScores() {
        for (int i = 0; i < this.pi.size(); ++i) {
            this.recalculate(i);
            this.orderHash.put(this.pi.get(i), i);
        }
    }

    /**
     * Re-scores positions {@code i1} through {@code i2} (inclusive) and refreshes
     * the recorded position of the nodes found there.
     *
     * @param i1 the first position to update
     * @param i2 the last position to update
     */
    public void updateScores(int i1, int i2) {
        for (int i = i1; i <= i2; ++i) {
            this.recalculate(i);
            this.orderHash.put(this.pi.get(i), i);
        }
    }

    /**
     * Returns the local score of {@code n} given parents {@code pi}, consulting
     * and populating the memoization cache.
     */
    private float score(Node n, Set<Node> pi) {
        // Private copy: callers keep mutating their set, and it is used as a map key.
        Set<Node> parentSet = new HashSet<>(pi);
        Map<Set<Node>, Float> nodeCache = this.cache.computeIfAbsent(n, key -> new HashMap<>());
        Float cached = nodeCache.get(parentSet);
        if (cached != null) {
            return cached;
        }

        int[] parentIndices = new int[parentSet.size()];
        int k = 0;
        for (Node p : parentSet) {
            parentIndices[k++] = this.variablesHash.get(p);
        }
        int nodeIndex = this.variablesHash.get(n);
        float v = (float) this.score.localScore(nodeIndex, parentIndices);
        nodeCache.put(parentSet, v);
        return v;
    }

    /** Returns the set of nodes at positions {@code 0 .. i-1} of the current permutation. */
    private Set<Node> getPrefix(int i) {
        Set<Node> prefix = new HashSet<>();
        for (int j = 0; j < i; ++j) {
            prefix.add(this.pi.get(j));
        }
        return prefix;
    }

    /** Recomputes and stores the (parents, score) result for position {@code p}. */
    private void recalculate(int p) {
        Set<Node> prefix = this.getPrefix(p);
        this.scores.set(p, this.getGrowShrinkScore(p, prefix));
    }

    /** Records in {@code nodesHash} the list index of every element of {@code variables}. */
    private void nodesHash(Map<Node, Integer> nodesHash, List<Node> variables) {
        for (int i = 0; i < variables.size(); ++i) {
            nodesHash.put(variables.get(i), i);
        }
    }

    /**
     * Returns {@code true} if, for every position between {@code i1} and {@code i2}
     * (inclusive, in either order), the current prefix equals the stored one.
     * Currently not called from within this class.
     */
    private boolean lastMoveSame(int i1, int i2) {
        int from = Math.min(i1, i2);
        int to = Math.max(i1, i2);
        for (int i = from; i <= to; ++i) {
            if (!this.getPrefix(i).equals(this.prefixes.get(i))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Chooses parents for the node at position {@code p} from {@code prefix}.
     *
     * <p>Grow phase: repeatedly try each candidate not yet chosen (and not forbidden
     * by knowledge) and add the one whose addition gives the best score that is at
     * least the current best. Shrink phase: repeatedly try removing each chosen parent
     * and drop the one whose removal gives the best score that is at least the current
     * best. Ties ({@code >=}) count as acceptable in both phases.
     *
     * @return the chosen parents and their score; a NaN score is reported as
     *         negative infinity
     */
    @NotNull
    private Pair getGrowShrinkScore(int p, Set<Node> prefix) {
        Node n = this.pi.get(p);
        Set<Node> parents = new HashSet<>();
        float sMax = this.score(n, new HashSet<>());
        boolean checkKnowledge = !this.knowledge.isEmpty();

        // Grow.
        boolean grew = true;
        while (grew) {
            grew = false;
            Node best = null;
            for (Node candidate : prefix) {
                if (parents.contains(candidate)) {
                    continue;
                }
                if (checkKnowledge && this.knowledge.isForbidden(candidate.getName(), n.getName())) {
                    continue;
                }
                parents.add(candidate);
                float s2 = this.score(n, parents);
                if (s2 >= sMax) {
                    sMax = s2;
                    best = candidate;
                }
                parents.remove(candidate);
            }
            if (best != null) {
                parents.add(best);
                grew = true;
            }
        }

        // Shrink.
        boolean shrank = true;
        while (shrank) {
            shrank = false;
            Node worst = null;
            // Iterate over a snapshot because parents is modified inside the loop.
            for (Node candidate : new HashSet<>(parents)) {
                parents.remove(candidate);
                float s2 = this.score(n, parents);
                if (s2 >= sMax) {
                    sMax = s2;
                    worst = candidate;
                }
                parents.add(candidate);
            }
            if (worst != null) {
                parents.remove(worst);
                shrank = true;
            }
        }

        return new Pair(parents, Float.isNaN(sMax) ? Float.NEGATIVE_INFINITY : sMax);
    }

    /** A parent set together with its score. Equality and hashing use the parent set only. */
    private static class Pair {
        private final Set<Node> parents;
        private final float score;

        private Pair(Set<Node> parents, float score) {
            this.parents = parents;
            this.score = score;
        }

        public Set<Node> getParents() {
            return this.parents;
        }

        public float getScore() {
            return this.score;
        }

        @Override
        public int hashCode() {
            return this.parents.hashCode();
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof Pair)) {
                return false;
            }
            Pair thatPair = (Pair) o;
            return this.parents.equals(thatPair.parents);
        }
    }
}

