/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Score;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.jetbrains.annotations.NotNull;

/**
 * A lazily expanded search tree that caches grow-shrink parent-set searches for the
 * variables of a {@link Score}.
 *
 * <p>Each variable has its own root, representing the empty parent set. A tree node
 * represents the parent set obtained by following a path of added variables from the
 * root.
 *
 * <p>At each node the "grow" step lazily scores every possible single-variable
 * addition. It keeps only those that strictly improve on the node's own score, sorted
 * best first. The "shrink" step runs greedy backward elimination once the search
 * reaches a node with no usable branch. Both results are cached, so repeated queries
 * reuse earlier score evaluations.
 *
 * <p>Instances are not synchronized: lazy expansion mutates tree state on every query.
 */
public class GrowShrinkTree {
    /** Score used to evaluate every candidate parent set of this tree. */
    private final Score score;
    /** Maps each variable to its position in {@code score.getVariables()}. */
    private final Map<Node, Integer> index;
    /** One root (empty parent set) per variable. */
    private final Map<Node, GSTNode> roots;

    /**
     * Builds a tree with one unexpanded root per variable of {@code score}.
     *
     * @param score the score used for all local-score evaluations made by this tree
     */
    public GrowShrinkTree(Score score) {
        // Instance (not static) state: several trees over different scores must not
        // overwrite each other's score or variable indices.
        this.score = score;
        this.index = new HashMap<>();
        int i = 0;
        for (Node node : score.getVariables()) {
            this.index.put(node, i++);
        }
        // Roots are created only after the full index exists, because constructing
        // a node evaluates its score.
        this.roots = new HashMap<>();
        for (Node node : score.getVariables()) {
            this.roots.put(node, new GSTNode(node));
        }
    }

    /**
     * Runs (or replays from cache) a grow-shrink search for the parents of {@code node}.
     *
     * <p>Starting at the root for {@code node}, the search repeatedly follows the
     * best-scoring branch whose added variable is still in {@code prefix}. Each time it
     * does so, it moves that variable from {@code prefix} into {@code parents}. When no
     * branch is usable, the variables selected by backward elimination at that node are
     * removed from {@code parents}.
     *
     * @param node    the variable whose parents are being searched
     * @param prefix  candidate parents; variables chosen during the grow phase are
     *                removed from this set
     * @param parents receives the final parent set; expected to be empty on entry when
     *                starting from the root (the root's score is that of no parents)
     * @return the local score of {@code node} given the final {@code parents}
     */
    public double GrowShrink(Node node, Set<Node> prefix, LinkedHashSet<Node> parents) {
        return this.roots.get(node).GrowShrink(node, prefix, parents);
    }

    /**
     * One node of the tree: the parent set reached by adding {@link #add} to the parent
     * set of the node above it. This is an inner (non-static) class so that it uses the
     * enclosing tree's {@link #score} and {@link #index}.
     */
    private class GSTNode implements Comparable<GSTNode> {
        /** Variable added on the edge into this node; {@code null} for a root. */
        private final Node add;
        /** Whether {@link #branches} has been computed. */
        private boolean grow;
        /** Whether {@link #remove} and {@link #shrinkScore} have been computed. */
        private boolean shrink;
        /** Local score of the node's variable given this node's parent set. */
        private final double growScore;
        /** Score after backward elimination; valid once {@link #shrink} is true. */
        private double shrinkScore;
        /** Strictly improving single-variable additions, best first. */
        private List<GSTNode> branches;
        /** Parents dropped by backward elimination at this node. */
        private Set<Node> remove;

        /** Creates a root node whose score is that of {@code node} with no parents. */
        private GSTNode(Node node) {
            this.add = null;
            this.grow = false;
            this.shrink = false;
            int y = index.get(node);
            this.growScore = score.localScore(y);
        }

        /**
         * Creates a child node whose score is that of {@code node} given
         * {@code parents} plus {@code add}.
         */
        private GSTNode(Node node, Node add, Set<Node> parents) {
            this.add = add;
            this.grow = false;
            this.shrink = false;
            int y = index.get(node);
            int[] X = new int[parents.size() + 1];
            int i = 0;
            for (Node parent : parents) {
                X[i++] = index.get(parent);
            }
            X[i] = index.get(add);
            this.growScore = score.localScore(y, X);
        }

        /**
         * Continues the grow-shrink search from this node.
         *
         * @param node    the variable whose parents are being searched
         * @param prefix  remaining candidate parents (mutated)
         * @param parents the parent set represented by this node on entry (mutated)
         * @return the score after backward elimination at the node where the grow phase
         *         stops
         */
        public double GrowShrink(Node node, Set<Node> prefix, LinkedHashSet<Node> parents) {
            if (!this.grow) {
                this.grow = true;
                this.branches = new ArrayList<>();
                for (Node add : score.getVariables()) {
                    if (parents.contains(add) || add == node) {
                        continue;
                    }
                    GSTNode branch = new GSTNode(node, add, parents);
                    // Keep only additions that strictly improve on this node's score.
                    if (this.compareTo(branch) < 0) {
                        this.branches.add(branch);
                    }
                }
                this.branches.sort(Collections.reverseOrder());
            }

            // Follow the best-scoring branch whose variable is still available.
            for (GSTNode branch : this.branches) {
                Node add = branch.getAdd();
                if (prefix.contains(add)) {
                    prefix.remove(add);
                    parents.add(add);
                    return branch.GrowShrink(node, prefix, parents);
                }
            }

            if (!this.shrink) {
                computeShrink(node, parents);
            }
            // On a cached replay, the parents dropped earlier are removed again.
            parents.removeAll(this.remove);
            return this.shrinkScore;
        }

        /**
         * Greedy backward elimination, run once per node and cached.
         *
         * <p>Each pass scores {@code parents} minus each single parent. It drops the
         * parent whose removal gives the best strict improvement, and repeats until no
         * removal improves the score or no parents remain. Dropped parents are removed
         * from {@code parents} and recorded in {@link #remove}.
         */
        private void computeShrink(Node node, LinkedHashSet<Node> parents) {
            this.shrink = true;
            this.remove = new HashSet<>();
            this.shrinkScore = this.growScore;
            int y = index.get(node);

            // Stop once every parent has been removed: an empty set has nothing left
            // to drop (and would otherwise need an array of size -1).
            while (!parents.isEmpty()) {
                List<Node> candidates = new ArrayList<>(parents);
                int k = candidates.size();
                int[] columns = new int[k];
                for (int j = 0; j < k; j++) {
                    columns[j] = index.get(candidates.get(j));
                }

                // X starts as every parent except candidates[0].
                int[] X = new int[k - 1];
                System.arraycopy(columns, 1, X, 0, k - 1);

                Node best = null;
                for (int j = 0; j < k; j++) {
                    if (j > 0) {
                        // Restore the previously excluded parent into the slot that
                        // held candidates[j], so X now excludes exactly candidates[j].
                        X[j - 1] = columns[j - 1];
                    }
                    double s = score.localScore(y, X);
                    if (s > this.shrinkScore) {
                        this.shrinkScore = s;
                        best = candidates.get(j);
                    }
                }

                if (best == null) {
                    break;
                }
                parents.remove(best);
                this.remove.add(best);
            }
        }

        /** Returns the variable added on the edge into this node ({@code null} for a root). */
        public Node getAdd() {
            return this.add;
        }

        /** Returns the score of this node's parent set. */
        public double getGrowScore() {
            return this.growScore;
        }

        /** Orders nodes by {@link #growScore} using {@link Double#compare}. */
        @Override
        public int compareTo(@NotNull GSTNode branch) {
            return Double.compare(this.growScore, branch.getGrowScore());
        }
    }
}

