/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.PermutationSearch;
import edu.cmu.tetrad.search.SuborderSearch;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.search.utils.BesPermutation;
import edu.cmu.tetrad.search.utils.GrowShrinkTree;
import edu.cmu.tetrad.util.RandomUtil;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;

/**
 * Suborder search that improves the order of a suborder by repeatedly relocating each variable to
 * the position that maximizes the summed score of the suborder. A variable's score at a position
 * is read from its {@link GrowShrinkTree} given the prefix and the variables placed before it.
 *
 * <p>Relocation sweeps repeat until no variable can be moved with a gain above {@code 1e-6}. The
 * whole procedure runs {@code numStarts} times (restarts begin from a shuffled order), each start
 * optionally followed by a {@link BesPermutation} pass, and the best-scoring suborder is kept.
 * {@link Knowledge} tiers and required edges shape the starting order; required edges also
 * constrain every move.
 *
 * <p>{@link #searchSuborder} stores per-call state in fields (trees, pool, score histories), so a
 * single instance should not run concurrent searches.
 */
public class Boss implements SuborderSearch {
    private final Score score;
    private final List<Node> variables;
    // Parent set of each variable, refreshed by update(); exposed through getParents().
    private final Map<Node, Set<Node>> parents;
    // Per-call state, (re)assigned at the start of searchSuborder.
    private Map<Node, GrowShrinkTree> gsts;
    private Set<Node> all;
    private ForkJoinPool pool;
    private Knowledge knowledge = new Knowledge();
    // Non-null only while BES post-processing is enabled via setUseBes(true).
    private BesPermutation bes = null;
    private int numStarts = 1;
    private boolean useDataOrder = true;
    private boolean resetAfterBM = false;
    private boolean resetAfterRS = true;
    private int numThreads = 1;
    // Score and elapsed wall-clock time (ms) of each start of the most recent search.
    private List<Double> bics;
    private List<Double> times;
    private boolean verbose = false;

    /**
     * Creates a search over the variables of the given score.
     *
     * @param score the score whose variables define the parent map
     */
    public Boss(Score score) {
        this.score = score;
        this.variables = score.getVariables();
        this.parents = new HashMap<>();
        for (Node x : this.variables) {
            this.parents.put(x, new HashSet<>());
        }
    }

    /**
     * Reorders {@code suborder} in place to the highest-scoring order found over
     * {@code numStarts} starts. A final scoring pass then refreshes the map returned by
     * {@link #getParents()}.
     *
     * <p>If no start produces a score above negative infinity (for example when
     * {@code numStarts < 1}), {@code suborder} is left in the order reached by the last start
     * instead of being emptied.
     *
     * @param prefix   variables that precede the suborder; they are conditioned on but never moved
     * @param suborder the variables to reorder; modified in place
     * @param gsts     grow-shrink trees supplying per-variable scores, keyed by variable
     */
    @Override
    public void searchSuborder(List<Node> prefix, List<Node> suborder, Map<Node, GrowShrinkTree> gsts) {
        assert (this.numStarts > 0);
        this.gsts = gsts;
        this.all = new HashSet<>(prefix);
        this.all.addAll(suborder);
        this.bics = new ArrayList<>();
        this.times = new ArrayList<>();

        List<Node> bestSuborder = null;
        double bestScore = Double.NEGATIVE_INFINITY;

        // Decide the execution mode once per call, so the pool and the mode cannot disagree.
        // 1: sequential, no pool; > 1: private pool (shut down below); < 1: shared common pool.
        final boolean sequential = this.numThreads == 1;
        final boolean ownsPool = this.numThreads > 1;
        if (ownsPool) {
            this.pool = new ForkJoinPool(this.numThreads);
        } else if (!sequential) {
            this.pool = ForkJoinPool.commonPool();
        }

        try {
            for (int i = 0; i < this.numStarts; ++i) {
                double start = System.currentTimeMillis();

                // Every restart is shuffled; the first start keeps the given order only if useDataOrder.
                if (i > 0 || !this.useDataOrder) {
                    RandomUtil.shuffle(suborder);
                }
                if (i > 0 && this.resetAfterRS) {
                    for (Node root : suborder) {
                        this.gsts.get(root).reset();
                    }
                }
                this.makeValidKnowledgeOrder(suborder);

                boolean improved;
                do {
                    improved = false;
                    // Iterate over a snapshot because every move reorders suborder in place.
                    for (Node x : new ArrayList<>(suborder)) {
                        if (this.verbose && suborder.size() > 1) {
                            System.out.println(x);
                        }
                        if (sequential) {
                            improved |= this.betterMutation(prefix, suborder, x);
                        } else {
                            improved |= this.betterMutationAsync(prefix, suborder, x);
                        }
                    }
                    if (this.verbose && suborder.size() > 1) {
                        System.out.printf("\nScore: %.3f\n\n", this.update(prefix, suborder));
                    }
                } while (improved);

                if (this.bes != null) {
                    this.bes(prefix, suborder);
                }

                double total = this.update(prefix, suborder);
                double elapsed = (double) System.currentTimeMillis() - start;
                if (suborder.size() > 1) {
                    this.bics.add(total);
                    this.times.add(elapsed);
                    if (this.verbose) {
                        System.out.printf("\nRestart: %d\t Score: %.3f\t Time: %.3f\n\n", i, total, elapsed / 1000.0);
                    }
                }

                if (total > bestScore) {
                    bestSuborder = new ArrayList<>(suborder);
                    bestScore = total;
                }
            }
        } finally {
            // Release the private pool even if a start throws; the common pool is never shut down.
            if (ownsPool) {
                this.pool.shutdown();
            }
        }

        // Only overwrite the caller's suborder when a best order exists; clearing it otherwise
        // would silently drop every suborder variable.
        if (bestSuborder != null) {
            suborder.clear();
            suborder.addAll(bestSuborder);
        }
        this.update(prefix, suborder);
    }

    /**
     * Enables or disables a {@link BesPermutation} pass after each start. Enabling it creates a
     * fresh, non-verbose instance that uses the current knowledge.
     *
     * @param use whether to run BES after each start
     */
    public void setUseBes(boolean use) {
        this.bes = null;
        if (use) {
            this.bes = new BesPermutation(this.score);
            this.bes.setVerbose(false);
            this.bes.setKnowledge(this.knowledge);
        }
    }

    /**
     * Sets the knowledge used for ordering and move constraints, and forwards it to the BES pass
     * if one is enabled.
     *
     * @param knowledge the background knowledge
     */
    @Override
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = knowledge;
        if (this.bes != null) {
            this.bes.setKnowledge(knowledge);
        }
    }

    /**
     * Sets how many starts {@link #searchSuborder} runs.
     *
     * @param numStarts number of starts; expected to be positive
     */
    public void setNumStarts(int numStarts) {
        this.numStarts = numStarts;
    }

    /**
     * If true, the relocated variable's tree is reset after each move evaluation. Currently this
     * is only honored by the parallel path ({@code numThreads != 1}).
     *
     * @param reset whether to reset after each move evaluation
     */
    public void setResetAfterBM(boolean reset) {
        this.resetAfterBM = reset;
    }

    /**
     * If true, every suborder variable's tree is reset before each start after the first.
     *
     * @param reset whether to reset trees between restarts
     */
    public void setResetAfterRS(boolean reset) {
        this.resetAfterRS = reset;
    }

    /**
     * Enables progress output on standard out.
     *
     * @param verbose whether to print progress
     */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /**
     * Sets the degree of parallelism. A value of 1 evaluates moves sequentially. Values above 1
     * use a dedicated pool of that size. Values below 1 use {@link ForkJoinPool#commonPool()}.
     *
     * @param numThreads the number of threads
     */
    public void setNumThreads(int numThreads) {
        this.numThreads = numThreads;
    }

    @Override
    public List<Node> getVariables() {
        return this.variables;
    }

    @Override
    public Map<Node, Set<Node>> getParents() {
        return this.parents;
    }

    @Override
    public Score getScore() {
        return this.score;
    }

    /**
     * Returns the final score of each start of the most recent search. Starts on suborders of size
     * at most 1 are not recorded. Returns null before the first search.
     *
     * @return the per-start scores
     */
    public List<Double> getBics() {
        return this.bics;
    }

    /**
     * Returns the elapsed wall-clock time in milliseconds of each start of the most recent search,
     * aligned with {@link #getBics()}. Returns null before the first search.
     *
     * @return the per-start times
     */
    public List<Double> getTimes() {
        return this.times;
    }

    /**
     * If true, the first start uses the suborder as given. Otherwise it is shuffled too.
     *
     * @param useDataOrder whether the first start keeps the given order
     */
    public void setUseDataOrder(boolean useDataOrder) {
        this.useDataOrder = useDataOrder;
    }

    /**
     * Parallel counterpart of {@link #betterMutation}. It evaluates every admissible slot for
     * {@code x} on {@link #pool} and moves {@code x} to the best slot if that slot beats the
     * current one by more than {@code 1e-6}.
     *
     * <p>Slot {@code k} means "x inserted before the k-th variable of the suborder with x
     * removed". Its score is the sum of three parts:
     * <ul>
     *   <li>the trace of x at that slot;</li>
     *   <li>the traces of the later variables, with x among their predecessors;</li>
     *   <li>the traces of the earlier variables, without x.</li>
     * </ul>
     *
     * @return true if {@code x} was moved
     * @throws IllegalStateException if a scoring task fails or the thread is interrupted
     */
    private boolean betterMutationAsync(List<Node> prefix, List<Node> suborder, Node x) {
        List<Trace> tasks = new ArrayList<>();
        double[] scores = new double[suborder.size()];
        // with[j] / without[j]: trace of the j-th non-x variable with / without x as a predecessor.
        double[] with = new double[suborder.size() - 1];
        double[] without = new double[suborder.size() - 1];
        Set<Node> preceding = new HashSet<>(prefix);
        int i = 0;
        int curr = 0;

        tasks.add(new Trace(this.gsts.get(x), this.all, preceding, scores, i));
        for (Node z : suborder) {
            // x must precede z: no slot after z is admissible, so stop scoring here.
            if (this.knowledge.isRequired(x.getName(), z.getName())) break;
            if (x == z) {
                curr = i;
                continue;
            }
            preceding.add(x);
            tasks.add(new Trace(this.gsts.get(z), this.all, preceding, with, i));
            preceding.remove(x);
            tasks.add(new Trace(this.gsts.get(z), this.all, preceding, without, i));
            preceding.add(z);
            tasks.add(new Trace(this.gsts.get(x), this.all, preceding, scores, ++i));
        }
        // Slots 0..last were scored; any later slot was never traced and is inadmissible.
        final int last = i;

        // NOTE(review): several tasks share gsts.get(x); this assumes GrowShrinkTree.trace is safe
        // to call concurrently - confirm.
        RandomUtil.shuffle(tasks);
        try {
            // invokeAll stores task failures in the futures; get() surfaces them instead of
            // silently leaving a 0.0 score in the arrays.
            for (Future<Void> task : this.pool.invokeAll(tasks)) {
                task.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while scoring moves for " + x, e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Failed to score moves for " + x, e.getCause());
        }
        if (this.resetAfterBM) {
            this.gsts.get(x).reset();
        }

        // Slot k precedes non-x variables j >= k, which are scored with x.
        double runningScore = 0.0;
        for (i = with.length - 1; i >= 0; --i) {
            runningScore += with[i];
            scores[i] += runningScore;
        }
        // Slot k follows non-x variables j < k, which are scored without x.
        runningScore = 0.0;
        for (i = 0; i < without.length; ++i) {
            runningScore += without[i];
            scores[i + 1] += runningScore;
        }

        int best = curr;
        for (i = last; i >= 0 && !this.knowledge.isRequired(suborder.get(i).getName(), x.getName()); --i) {
            if (scores[i] + 1.0E-6 > scores[best]) {
                best = i;
            }
        }
        if (scores[curr] + 1.0E-6 > scores[best]) {
            return false;
        }
        suborder.remove(x);
        suborder.add(best, x);
        return true;
    }

    /**
     * Finds the best admissible slot for {@code x} using two sequential sweeps over the suborder.
     * It moves {@code x} there if that slot beats the current one by more than {@code 1e-6}.
     *
     * <p>Here slot {@code k} means "before the k-th element of the suborder as it currently
     * stands, x included". The forward sweep scores x at each slot and adds the earlier
     * elements' traces without x. The backward sweep adds the later elements' traces with x among
     * their predecessors.
     *
     * @return true if {@code x} was moved
     */
    private boolean betterMutation(List<Node> prefix, List<Node> suborder, Node x) {
        ListIterator<Node> itr = suborder.listIterator();
        double[] scores = new double[suborder.size() + 1];
        Set<Node> preceding = new HashSet<>(prefix);
        int i = 0;
        double running = 0.0;
        int curr = 0;

        while (itr.hasNext()) {
            Node z = itr.next();
            if (this.knowledge.isRequired(x.getName(), z.getName())) {
                // x must precede z: step back so the backward sweep starts at the last scored element.
                itr.previous();
                break;
            }
            scores[i++] = this.gsts.get(x).trace(preceding, this.all) + running;
            if (z != x) {
                running += this.gsts.get(z).trace(preceding, this.all);
                preceding.add(z);
            } else {
                curr = i - 1;
            }
        }
        scores[i] = this.gsts.get(x).trace(preceding, this.all) + running;
        int best = i;

        preceding.add(x);
        running = 0.0;
        while (itr.hasPrevious()) {
            Node z = itr.previous();
            // z must precede x: x cannot be placed at or before z's slot.
            if (this.knowledge.isRequired(z.getName(), x.getName())) {
                break;
            }
            if (z != x) {
                preceding.remove(z);
                running += this.gsts.get(z).trace(preceding, this.all);
            }
            scores[--i] += running;
            if (scores[i] + 1.0E-6 > scores[best]) {
                best = i;
            }
        }

        if (scores[curr] + 1.0E-6 > scores[best]) {
            return false;
        }
        // Slots were counted with x still in the list; removing x shifts later slots down by one.
        if (best > curr) {
            --best;
        }
        suborder.remove(x);
        suborder.add(best, x);
        return true;
    }

    /**
     * Runs the configured BES pass on the graph implied by the current parents. It then restores
     * an order of {@code suborder} that is valid for the resulting graph.
     */
    private void bes(List<Node> prefix, List<Node> suborder) {
        ArrayList<Node> nodes = new ArrayList<>(prefix);
        nodes.addAll(suborder);
        Graph graph = PermutationSearch.getGraph(nodes, this.parents, this.knowledge, true);
        this.bes.bes(graph, nodes, suborder);
        graph.paths().makeValidOrder(suborder);
    }

    /**
     * Re-scores {@code suborder} in order and returns the summed trace score. Each variable's
     * parent set in {@link #parents} is cleared and passed to its tree's trace. That trace
     * presumably refills it with the selected parents - confirm against GrowShrinkTree.
     */
    private double update(List<Node> prefix, List<Node> suborder) {
        double total = 0.0;
        Set<Node> preceding = new HashSet<>(prefix);
        for (Node x : suborder) {
            Set<Node> parentsOfX = this.parents.get(x);
            parentsOfX.clear();
            total += this.gsts.get(x).trace(preceding, this.all, parentsOfX);
            preceding.add(x);
        }
        return total;
    }

    /**
     * Rearranges {@code order} in place to respect the knowledge.
     *
     * <p>Variables outside any tier come first, then tier 0, tier 1, and so on. Variables moved
     * into a group keep their relative order. A single pass then moves each variable in front of
     * the first earlier variable that it is required to precede. Does nothing when the knowledge
     * is empty.
     */
    private void makeValidKnowledgeOrder(List<Node> order) {
        if (this.knowledge.isEmpty()) {
            return;
        }
        int index = 0;
        Set<String> tier = new HashSet<>(this.knowledge.getVariablesNotInTiers());
        for (int i = 0; i < order.size(); ++i) {
            if (tier.contains(order.get(i).getName())) {
                Node moved = order.remove(i);
                order.add(index++, moved);
            }
        }
        for (int t = 0; t < this.knowledge.getNumTiers(); ++t) {
            tier = new HashSet<>(this.knowledge.getTier(t));
            for (int j = 0; j < order.size(); ++j) {
                if (tier.contains(order.get(j).getName())) {
                    Node moved = order.remove(j);
                    order.add(index++, moved);
                }
            }
        }
        for (int i = 1; i < order.size(); ++i) {
            String a = order.get(i).getName();
            for (int j = 0; j < i; ++j) {
                if (this.knowledge.isRequired(a, order.get(j).getName())) {
                    Node moved = order.remove(i);
                    order.add(j, moved);
                    break;
                }
            }
        }
    }

    /** Task that stores one {@link GrowShrinkTree} trace score into a slot of a shared array. */
    private static class Trace implements Callable<Void> {
        private final GrowShrinkTree gst;
        private final Set<Node> all;
        private final Set<Node> prefix;
        private final double[] scores;
        private final int index;

        Trace(GrowShrinkTree gst, Set<Node> all, Set<Node> prefix, double[] scores, int index) {
            this.gst = gst;
            this.all = all;
            // Snapshot: the caller keeps mutating its prefix set after creating the task.
            this.prefix = new HashSet<>(prefix);
            this.scores = scores;
            this.index = index;
        }

        @Override
        public Void call() {
            this.scores[this.index] = this.gst.trace(this.prefix, this.all);
            return null;
        }
    }
}

