/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.OrderedPair;
import edu.cmu.tetrad.search.GraphScore;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.Score;
import edu.cmu.tetrad.search.TeyssierScorer;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.PermutationGenerator;
import edu.cmu.tetrad.util.RandomUtil;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import org.jetbrains.annotations.NotNull;

public class OtherPermAlgs {
    private final List<Node> variables;
    private long start;
    private Score score;
    private IndependenceTest test;
    private int numStarts = 1;
    private Method method = Method.GSP;
    private Knowledge knowledge = new Knowledge();
    private int depth = 4;
    private TeyssierScorer scorer;
    private int numRounds = 50;
    private boolean useScore = true;
    private boolean usePearl = false;
    private boolean verbose = false;
    private boolean cachingScores = true;
    private boolean useDataOrder = false;
    private int numVars;

    public OtherPermAlgs(@NotNull Score score) {
        this.score = score;
        this.variables = new ArrayList<Node>(score.getVariables());
        this.useScore = true;
    }

    public OtherPermAlgs(@NotNull IndependenceTest test) {
        this.test = test;
        this.variables = new ArrayList<Node>(test.getVariables());
        this.useScore = false;
    }

    public OtherPermAlgs(@NotNull IndependenceTest test, Score score) {
        this.test = test;
        this.score = score;
        this.variables = new ArrayList<Node>(test.getVariables());
    }

    public List<Node> bestOrder(@NotNull List<Node> _order) {
        ArrayList<Node> order = new ArrayList<Node>(_order);
        long start = MillisecondTimes.timeMillis();
        if (this.useScore && !(this.score instanceof GraphScore)) {
            this.scorer = new TeyssierScorer(this.test, this.score);
            this.scorer.setUseScore(true);
        } else {
            this.scorer = new TeyssierScorer(this.test, this.score);
            this.scorer.setUseRaskuttiUhler(this.usePearl);
            this.scorer.score(this.variables);
            if (this.usePearl) {
                this.scorer.setUseScore(false);
            } else {
                this.scorer.setUseScore(this.useScore);
            }
        }
        this.scorer.setKnowledge(this.knowledge);
        this.scorer.clearBookmarks();
        this.scorer.setCachingScores(this.cachingScores);
        List<Node> bestPerm = new ArrayList<Node>(order);
        double best = Double.NEGATIVE_INFINITY;
        for (int r = 0; r < (this.useDataOrder ? 1 : this.numStarts); ++r) {
            List<Node> perm;
            if (!this.useDataOrder) {
                RandomUtil.shuffle(order);
            }
            this.start = MillisecondTimes.timeMillis();
            this.makeValidKnowledgeOrder(order);
            this.scorer.score(order);
            if (this.verbose) {
                System.out.println("Using " + (Object)((Object)this.method));
            }
            if (this.method == Method.RCG) {
                perm = this.rcg(this.scorer);
            } else if (this.method == Method.GSP) {
                perm = this.gasp(this.scorer);
            } else if (this.method == Method.ESP) {
                perm = this.esp(this.scorer);
            } else if (this.method == Method.SP) {
                this.useDataOrder = true;
                perm = this.sp(this.scorer);
            } else {
                throw new IllegalArgumentException("Unrecognized method: " + (Object)((Object)this.method));
            }
            this.scorer.score(perm);
            if (!(this.scorer.score() > best)) continue;
            best = this.scorer.score();
            bestPerm = perm;
        }
        long stop = MillisecondTimes.timeMillis();
        if (this.verbose) {
            System.out.println("Final order = " + this.scorer.getPi());
            System.out.println("Elapsed time = " + (double)(stop - start) / 1000.0 + " s");
        }
        return bestPerm;
    }

    public int getNumEdges() {
        return this.scorer.getNumEdges();
    }

    private void makeValidKnowledgeOrder(List<Node> order) {
        if (!this.knowledge.isEmpty()) {
            order.sort((o1, o2) -> {
                if (o1.getName().equals(o2.getName())) {
                    return 0;
                }
                if (this.knowledge.isRequired(o1.getName(), o2.getName())) {
                    return 1;
                }
                if (this.knowledge.isRequired(o2.getName(), o1.getName())) {
                    return -1;
                }
                if (this.knowledge.isForbidden(o2.getName(), o1.getName())) {
                    return -1;
                }
                if (this.knowledge.isForbidden(o1.getName(), o2.getName())) {
                    return 1;
                }
                return 1;
            });
        }
    }

    public List<Node> esp(@NotNull TeyssierScorer scorer) {
        double sOld;
        if (this.depth <= 0) {
            throw new IllegalArgumentException("Form ESP, max depth should be > 0");
        }
        double sNew = scorer.score();
        do {
            sOld = sNew;
            this.espDfs(scorer, sOld, this.depth < 0 ? 100 : this.depth, 1);
        } while ((sNew = scorer.score()) > sOld);
        if (this.verbose) {
            System.out.println("# Edges = " + scorer.getNumEdges() + " Score = " + scorer.score() + " (ESP) Elapsed " + (double)(MillisecondTimes.timeMillis() - this.start) / 1000.0 + " s");
        }
        return scorer.getPi();
    }

    public List<Node> gasp(@NotNull TeyssierScorer scorer) {
        double sOld;
        if (this.depth < 0) {
            throw new IllegalArgumentException("Form GRASP, max depth should be >= 0");
        }
        scorer.clearBookmarks();
        double sNew = scorer.score();
        do {
            sOld = sNew;
            this.graspDfs(scorer, sOld, this.depth < 0 ? Integer.MAX_VALUE : this.depth, 0, true);
        } while ((sNew = scorer.score()) > sOld);
        if (this.verbose) {
            System.out.println("# Edges = " + scorer.getNumEdges() + " Score = " + scorer.score() + " (GASP)) Elapsed " + (double)(MillisecondTimes.timeMillis() - this.start) / 1000.0 + " s");
        }
        return scorer.getPi();
    }

    public List<Node> rcg(@NotNull TeyssierScorer scorer) {
        if (this.numRounds <= 0) {
            throw new IllegalArgumentException("For RCG, #rounds should be > 0");
        }
        scorer.clearBookmarks();
        DecimalFormat nf = new DecimalFormat("0.0");
        if (this.verbose) {
            System.out.println("\nInitial # edges = " + scorer.getNumEdges());
        }
        scorer.bookmark(1);
        int maxRounds = this.numRounds < 0 ? Integer.MAX_VALUE : this.numRounds;
        int unimproved = 0;
        for (int r = 1; r <= maxRounds; ++r) {
            double sNew;
            double s0;
            Node y;
            Node x;
            if (this.verbose) {
                System.out.println("### Round " + r);
            }
            ArrayList<OrderedPair<Node>> pairs = new ArrayList<OrderedPair<Node>>();
            for (Node y2 : scorer.getPi()) {
                for (Node x2 : scorer.getParents(y2)) {
                    pairs.add(new OrderedPair<Node>(x2, y2));
                }
            }
            RandomUtil.shuffle(pairs);
            int numImprovements = 0;
            int numEquals = 0;
            int visited = 0;
            int numPairs = pairs.size();
            for (OrderedPair orderedPair : pairs) {
                ++visited;
                x = (Node)orderedPair.getFirst();
                if (!scorer.adjacent(x, y = (Node)orderedPair.getSecond())) continue;
                s0 = scorer.score();
                scorer.bookmark(0);
                scorer.moveTo(y, scorer.index(x));
                if (this.violatesKnowledge(scorer.getOrderShallow())) {
                    scorer.goToBookmark(0);
                    continue;
                }
                sNew = scorer.score();
                if (sNew < s0) {
                    scorer.goToBookmark(0);
                    continue;
                }
                scorer.bookmark(1);
                if (sNew > s0) {
                    ++numImprovements;
                }
                if (!this.verbose) continue;
                if (sNew == s0) {
                    ++numEquals;
                }
                if (!(sNew > s0)) continue;
                System.out.println("Round " + r + " # improvements = " + numImprovements + " # unimproved = " + numEquals + " # edges = " + scorer.getNumEdges() + " progress this round = " + nf.format((double)(100 * visited) / (double)numPairs) + "%");
            }
            for (OrderedPair orderedPair : pairs) {
                ++visited;
                x = (Node)orderedPair.getFirst();
                if (!scorer.adjacent(x, y = (Node)orderedPair.getSecond())) continue;
                s0 = scorer.score();
                scorer.bookmark(0);
                scorer.moveTo(x, scorer.index(y));
                if (this.violatesKnowledge(scorer.getOrderShallow())) {
                    scorer.goToBookmark(0);
                    continue;
                }
                sNew = scorer.score();
                if (sNew < s0) {
                    scorer.goToBookmark(0);
                    continue;
                }
                scorer.bookmark(1);
                if (sNew > s0) {
                    ++numImprovements;
                }
                if (!this.verbose) continue;
                if (sNew == s0) {
                    ++numEquals;
                }
                if (!(sNew > s0)) continue;
                System.out.println("Round " + r + " # improvements = " + numImprovements + " # unimproved = " + numEquals + " # edges = " + scorer.getNumEdges() + " progress this round = " + nf.format((double)(100 * visited) / (double)numPairs) + "%");
            }
            if (numImprovements == 0) {
                ++unimproved;
            }
            if (unimproved >= this.depth) break;
        }
        if (this.verbose) {
            System.out.println("# Edges = " + scorer.getNumEdges() + " Score = " + scorer.score() + " #round = " + this.numRounds + " (RCG) Elapsed " + (double)(MillisecondTimes.timeMillis() - this.start) / 1000.0 + " s");
        }
        scorer.goToBookmark(1);
        return scorer.getPi();
    }

    public List<Node> sp(@NotNull TeyssierScorer scorer) {
        int[] perm;
        double maxScore = Double.NEGATIVE_INFINITY;
        List<Node> maxP = null;
        List<Node> variables = scorer.getPi();
        PermutationGenerator gen = new PermutationGenerator(variables.size());
        HashSet<Graph> frugalCpdags = new HashSet<Graph>();
        int[] v = new int[this.numVars];
        for (int i = 0; i < this.numVars; ++i) {
            v[i] = i;
        }
        List<Node> pi0 = GraphUtils.asList(v, variables);
        scorer.score(pi0);
        System.out.println("\t\t# edges for " + pi0 + scorer.getNumEdges());
        while ((perm = gen.next()) != null) {
            List<Node> p = GraphUtils.asList(perm, variables);
            scorer.score(p);
            if (scorer.score() > maxScore) {
                maxScore = scorer.score();
                maxP = p;
                frugalCpdags.clear();
            }
            if (scorer.score() != maxScore) continue;
            frugalCpdags.add(scorer.getGraph(true));
        }
        System.out.println("\t\t# frugal cpdags BY SP = " + frugalCpdags.size());
        System.out.println("\t\t# edges for frugal = " + ((Graph)frugalCpdags.iterator().next()).getNumEdges());
        if (frugalCpdags.size() == 1) {
            System.out.println("\t!!!! U-FRUGAL BY SP");
        }
        if (this.verbose) {
            System.out.println("# Edges = " + scorer.getNumEdges() + " Score = " + scorer.score() + " (SP) Elapsed " + (double)(MillisecondTimes.timeMillis() - this.start) / 1000.0 + " sp");
        }
        System.out.println("Frugal CPDAGs: ");
        for (Graph g : frugalCpdags) {
            System.out.println(g);
        }
        return maxP;
    }

    @NotNull
    public Graph getGraph(boolean cpDag) {
        if (this.scorer == null) {
            throw new IllegalArgumentException("Please run algorithm first.");
        }
        Graph graph = this.scorer.getGraph(cpDag);
        graph.addAttribute("# edges", graph.getNumEdges());
        return graph;
    }

    public void setCacheScores(boolean cachingScores) {
        this.cachingScores = cachingScores;
    }

    public void setNumStarts(int numStarts) {
        this.numStarts = numStarts;
    }

    public Method getMethod() {
        return this.method;
    }

    public void setMethod(Method method) {
        this.method = method;
    }

    public List<Node> getVariables() {
        return this.variables;
    }

    public boolean isVerbose() {
        return this.verbose;
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = new Knowledge(knowledge);
    }

    public void setUseDataOrder(boolean useDataOrder) {
        this.useDataOrder = useDataOrder;
    }

    public void setDepth(int depth) {
        if (depth < -1) {
            throw new IllegalArgumentException("Depth should be >= -1.");
        }
        this.depth = depth;
    }

    public void setUseScore(boolean useScore) {
        this.useScore = useScore;
    }

    private boolean violatesKnowledge(List<Node> order) {
        if (!this.knowledge.isEmpty()) {
            for (int i = 0; i < order.size(); ++i) {
                for (int j = i + 1; j < order.size(); ++j) {
                    if (!this.knowledge.isForbidden(order.get(i).getName(), order.get(j).getName())) continue;
                    return true;
                }
            }
        }
        return false;
    }

    private void espDfs(@NotNull TeyssierScorer scorer, double sOld, int depth, int currentDepth) {
        for (int i = 0; i < scorer.size() - 1; ++i) {
            List<Node> pi = scorer.getPi();
            scorer.swap(scorer.get(i), scorer.get(i + 1));
            if (this.violatesKnowledge(scorer.getPi())) {
                scorer.score(pi);
                continue;
            }
            double sNew = scorer.score();
            if (sNew == sOld && currentDepth < depth) {
                this.espDfs(scorer, sNew, depth, currentDepth + 1);
                sNew = scorer.score();
            }
            if (!(sNew <= sOld)) break;
            scorer.score(pi);
        }
    }

    private void graspDfs(@NotNull TeyssierScorer scorer, double sOld, int depth, int currentDepth, boolean checkCovering) {
        for (OrderedPair<Node> adj : scorer.getEdges()) {
            Node x = adj.getFirst();
            Node y = adj.getSecond();
            if (checkCovering && !scorer.coveredEdge(x, y)) continue;
            scorer.bookmark(currentDepth);
            scorer.swaptuck(x, y);
            if (this.violatesKnowledge(scorer.getPi())) {
                scorer.goToBookmark(currentDepth);
                continue;
            }
            double sNew = scorer.score();
            if (sNew == sOld && currentDepth < depth) {
                this.graspDfs(scorer, sNew, depth, currentDepth + 1, checkCovering);
                sNew = scorer.score();
            }
            if (!(sNew <= sOld)) break;
            scorer.goToBookmark(currentDepth);
        }
    }

    public void setNumRounds(int numRounds) {
        this.numRounds = numRounds;
    }

    public void setUsePearl(boolean usePearl) {
        this.usePearl = usePearl;
    }

    public void setNumVariables(int numVars) {
        this.numVars = numVars;
    }

    public static enum Method {
        RCG,
        GSP,
        ESP,
        SP;

    }
}

