/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.utils;

import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.score.Score;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import org.jetbrains.annotations.NotNull;

/**
 * Caches greedy grow-shrink parent searches for a single target node.
 *
 * <p>Each call to {@link #trace} walks a lazily built tree whose nodes correspond to parent sets
 * reached by adding one variable at a time. When a tree node is first visited it is "grown":
 * every available variable is scored as an additional parent; required variables are always kept
 * as branches (and placed first), while the other candidates are kept only if adding them does
 * not lower the score, ordered from best to worst. The walk follows the first kept branch whose
 * variable appears in the caller's prefix. When no such branch exists, the tree node is "shrunk":
 * non-required parents are removed greedily while a single removal strictly improves the score.
 * Grow and shrink results are cached on the tree node, so repeated traces reuse earlier score
 * evaluations. The grow and shrink steps are {@code synchronized} per tree node.
 */
public class GrowShrinkTree {
    /** Scoring function used for every local-score evaluation. */
    private final Score score;
    /** Maps each node to the variable index passed to {@link #score}. */
    private final Map<Node, Integer> index;
    /** The target node whose parents are being searched for. */
    private final Node node;
    /** Cached {@code index.get(node)}. */
    private final int nodeIndex;
    /** Variables that are always kept as branches and never removed while shrinking. */
    private List<Node> required;
    /** Variables that are never offered as candidate parents. */
    private List<Node> forbidden;
    /** Root of the cached search tree (the empty parent set). */
    private GSTNode root;

    /**
     * Creates a tree for {@code node} with no required or forbidden parents.
     *
     * @param score the score used to evaluate parent sets
     * @param index maps every node to the variable index understood by {@code score}; must
     *              contain {@code node}
     * @param node  the target node
     */
    public GrowShrinkTree(Score score, Map<Node, Integer> index, Node node) {
        this.score = score;
        this.index = index;
        this.node = node;
        this.nodeIndex = index.get(node);
        this.required = new ArrayList<>();
        this.forbidden = new ArrayList<>();
        this.root = new GSTNode(this);
    }

    /**
     * Runs a grow-shrink search restricted to {@code prefix}, starting from no parents.
     *
     * <p>The selected parent set is discarded; use {@link #trace(Set, Set, Set)} to receive it.
     *
     * @param prefix variables that may be added as parents
     * @param all    every variable under consideration; the target node and forbidden variables
     *               are excluded automatically. Not modified.
     * @return the score of the selected parent set
     */
    public double trace(Set<Node> prefix, Set<Node> all) {
        return this.trace(prefix, all, new HashSet<>());
    }

    /**
     * Runs a grow-shrink search restricted to {@code prefix}, starting from {@code parents}.
     *
     * <p>NOTE(review): results are cached per tree node, so this assumes the same initial
     * {@code parents} are supplied on every call; different initial sets reuse results computed
     * for earlier ones.
     *
     * @param prefix  variables that may be added as parents
     * @param all     every variable under consideration; the target node and forbidden variables
     *                are excluded automatically. Not modified.
     * @param parents initial parent set; modified in place and, on return, holds the selected
     *                parent set
     * @return the score of the selected parent set
     */
    public double trace(Set<Node> prefix, Set<Node> all, Set<Node> parents) {
        Set<Node> available = new HashSet<>(all);
        available.remove(this.node);
        available.removeAll(this.forbidden);
        return this.root.trace(prefix, available, parents);
    }

    /** Returns the target node of this tree. */
    public Node getNode() {
        return this.node;
    }

    /**
     * Returns the variables of the root's branches, in branch order (required first, then the
     * rest from best to worst grow score).
     *
     * @return a new list; empty if the root has not been grown yet (no trace since construction
     *         or since the last {@link #reset()})
     */
    public List<Node> getFirstLayer() {
        // Snapshot the root: reset() may replace it concurrently.
        GSTNode top = this.root;
        List<Node> firstLayer = new ArrayList<>();
        if (!top.grow.get()) {
            // branches is still null until the first grow; avoid the NPE.
            return firstLayer;
        }
        for (GSTNode branch : top.branches) {
            firstLayer.add(branch.getAdd());
        }
        return firstLayer;
    }

    /** Returns the variable index of {@code node}, or {@code null} if it is not indexed. */
    public Integer getIndex(Node node) {
        return this.index.get(node);
    }

    /**
     * Scores the target node with no parents.
     *
     * <p>NOTE(review): a NaN score is mapped to {@code 0.0} here, whereas
     * {@link #localScore(int[])} maps NaN to negative infinity; confirm the asymmetry is intended.
     *
     * @return the local score, or {@code 0.0} if the score is NaN
     */
    public Double localScore() {
        double s = this.score.localScore(this.nodeIndex);
        return Double.isNaN(s) ? 0.0 : s;
    }

    /**
     * Scores the target node with the given parents.
     *
     * @param X variable indices of the parents
     * @return the local score, or negative infinity if the score is NaN
     */
    public Double localScore(int[] X) {
        double s = this.score.localScore(this.nodeIndex, X);
        return Double.isNaN(s) ? Double.NEGATIVE_INFINITY : s;
    }

    /** Returns whether {@code node} is in the required list. */
    public boolean isRequired(Node node) {
        return this.required.contains(node);
    }

    /** Returns whether {@code node} is in the forbidden list. */
    public boolean isForbidden(Node node) {
        return this.forbidden.contains(node);
    }

    /** Returns the variables of the underlying score. */
    public List<Node> getVariables() {
        return this.score.getVariables();
    }

    /** Returns the live required list (not a copy). */
    public List<Node> getRequired() {
        return this.required;
    }

    /** Returns the live forbidden list (not a copy). */
    public List<Node> getForbidden() {
        return this.forbidden;
    }

    /**
     * Replaces the required and forbidden lists and discards all cached search results.
     * The lists are stored by reference, not copied.
     *
     * @param required  variables always kept as branches and never removed while shrinking
     * @param forbidden variables never offered as candidate parents
     */
    public void setKnowledge(List<Node> required, List<Node> forbidden) {
        this.required = required;
        this.forbidden = forbidden;
        this.reset();
    }

    /** Discards all cached grow/shrink results by replacing the root. */
    public void reset() {
        this.root = new GSTNode(this);
    }

    /**
     * One node of the cached search tree. Ordered by {@link #growScore} ascending.
     */
    private static class GSTNode implements Comparable<GSTNode> {
        private final GrowShrinkTree tree;
        /** Variable added to reach this node; {@code null} for the root. */
        private final Node add;
        /** Score of the parent set at this node (parents on the path plus {@link #add}). */
        private final double growScore;
        /** Set once {@link #branches} has been computed. */
        private final AtomicBoolean grow;
        /** Set once {@link #remove} and {@link #shrinkScore} have been computed. */
        private final AtomicBoolean shrink;
        /** Best score reached by greedy removal. */
        private double shrinkScore;
        /** Kept branches: required first, then the rest by descending grow score. */
        private List<GSTNode> branches;
        /** Parents removed by the shrink step. */
        private Set<Node> remove;

        /** Creates the root node, scored with no parents. */
        private GSTNode(GrowShrinkTree tree) {
            this.tree = tree;
            this.add = null;
            this.grow = new AtomicBoolean(false);
            this.shrink = new AtomicBoolean(false);
            this.growScore = this.tree.localScore();
        }

        /** Creates a child node scored with {@code parents} plus {@code add}. */
        private GSTNode(GrowShrinkTree tree, Node add, Set<Node> parents) {
            this.tree = tree;
            this.add = add;
            this.grow = new AtomicBoolean(false);
            this.shrink = new AtomicBoolean(false);
            int[] X = new int[parents.size() + 1];
            int i = 0;
            for (Node parent : parents) {
                X[i++] = this.tree.getIndex(parent);
            }
            X[i] = this.tree.getIndex(add);
            this.growScore = this.tree.localScore(X);
        }

        /**
         * Scores every available variable as an extra parent and records the kept branches.
         * Required variables are always kept (first); others only if they do not lower the score.
         */
        private synchronized void grow(Set<Node> available, Set<Node> parents) {
            if (this.grow.get()) {
                return;
            }
            List<GSTNode> requiredBranches = new ArrayList<>();
            List<GSTNode> kept = new ArrayList<>();
            for (Node candidate : available) {
                GSTNode branch = new GSTNode(this.tree, candidate, parents);
                if (this.tree.isRequired(candidate)) {
                    requiredBranches.add(branch);
                } else if (branch.getGrowScore() >= this.growScore) {
                    kept.add(branch);
                }
            }
            kept.sort(Collections.reverseOrder());
            kept.addAll(0, requiredBranches);
            // Publish the fully built list before the flag, which readers check first.
            this.branches = kept;
            this.grow.set(true);
        }

        /**
         * Greedily removes non-required parents while a single removal strictly improves the
         * score. Removes the chosen parents from {@code parents} and records them in
         * {@link #remove}.
         */
        private synchronized void shrink(Set<Node> parents) {
            if (this.shrink.get()) {
                return;
            }
            this.remove = new HashSet<>();
            this.shrinkScore = this.growScore;
            if (parents.isEmpty()) {
                // Nothing to remove. The cache flag is left unset here, so a later call that
                // reaches this node with a non-empty parent set still runs the elimination.
                return;
            }
            // The emptiness check matters: once every parent has been removed, sizing the
            // reduced array as parents.size() - 1 would throw NegativeArraySizeException.
            while (!parents.isEmpty()) {
                Node best = null;
                int[] reduced = new int[parents.size() - 1];
                for (Node candidate : parents) {
                    if (this.tree.isRequired(candidate)) {
                        continue;
                    }
                    int i = 0;
                    for (Node parent : parents) {
                        if (!parent.equals(candidate)) {
                            reduced[i++] = this.tree.getIndex(parent);
                        }
                    }
                    double s = this.tree.localScore(reduced);
                    if (s > this.shrinkScore) {
                        this.shrinkScore = s;
                        best = candidate;
                    }
                }
                if (best == null) {
                    break;
                }
                parents.remove(best);
                this.remove.add(best);
            }
            this.shrink.set(true);
        }

        /**
         * Follows the first kept branch whose variable is in {@code prefix}, recursively; at the
         * end of the path, applies the (cached) shrink step.
         *
         * <p>Branches passed over here, and the one taken, are removed from {@code available},
         * so they are not offered as candidates when deeper nodes are grown.
         *
         * @return the score after shrinking; {@code parents} holds the resulting parent set
         */
        public double trace(Set<Node> prefix, Set<Node> available, Set<Node> parents) {
            if (!this.grow.get()) {
                this.grow(available, parents);
            }
            for (GSTNode branch : this.branches) {
                Node candidate = branch.getAdd();
                available.remove(candidate);
                if (prefix.contains(candidate)) {
                    parents.add(candidate);
                    return branch.trace(prefix, available, parents);
                }
            }
            if (!this.shrink.get()) {
                this.shrink(parents);
            }
            // Needed when the shrink result was cached by an earlier call.
            parents.removeAll(this.remove);
            return this.shrinkScore;
        }

        /** Returns the variable added to reach this node ({@code null} for the root). */
        public Node getAdd() {
            return this.add;
        }

        /** Returns the score of the parent set at this node. */
        public double getGrowScore() {
            return this.growScore;
        }

        @Override
        public int compareTo(@NotNull GSTNode branch) {
            return Double.compare(this.growScore, branch.getGrowScore());
        }
    }
}

