/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.data.KnowledgeEdge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.RegressionCovariance;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.search.Hbsms;
import edu.cmu.tetrad.search.MeekRules;
import edu.cmu.tetrad.search.SearchGraphUtils;
import edu.cmu.tetrad.sem.DagScorer;
import edu.cmu.tetrad.sem.ParamType;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.Scorer;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.util.TetradLogger;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.util.FastMath;

/**
 * Beam-search variant of the HBSM search.
 *
 * <p>Starting from a DAG derived from the supplied (or empty) graph, the search
 * keeps a "beam" of at most {@code beamWidth} candidate DAGs. It repeatedly
 * tries single-edge additions and reversals on each beam member, accepting
 * candidates that beat the worst graph currently in the beam. Graphs are
 * scored as {@code -BIC}, where {@code BIC = chisq - dof * ln(n)} and
 * {@code chisq = (n - 1) * FML}, using a {@link DagScorer} over the covariance
 * matrix. After the search, edges whose regression coefficients are not
 * significant (p above {@link #getHighPValueAlpha()}) are pruned, unless
 * background knowledge requires them.
 */
public final class HbmsBeam implements Hbsms {
    private final CovarianceMatrix cov;
    private Knowledge knowledge;
    /** Copy of the starting graph; {@link #search()} always starts from this. */
    private final Graph externalGraph;
    /** Starting graph; replaced by the best beam graph once the search loop finishes. */
    private Graph graph;
    private double alpha = 0.05;
    private double highPValueAlpha = 0.05;
    private final NumberFormat nf = new DecimalFormat("0.0#########");
    private Graph trueModel;
    private SemIm originalSemIm;
    private SemIm newSemIm;
    /** Shared, stateful scorer: every call to {@link #scoreGraph} re-scores with it. */
    private final Scorer scorer;
    private int beamWidth = 1;

    /**
     * Creates a search over a data set.
     *
     * @param graph     starting graph, or {@code null} for an empty graph over the data's variables
     * @param data      the data; a covariance matrix is computed from it
     * @param knowledge background knowledge used to constrain moves and pruning
     */
    public HbmsBeam(Graph graph, DataSet data, Knowledge knowledge) {
        this(graph == null ? new EdgeListGraph(data.getVariables()) : graph,
                new CovarianceMatrix(data), knowledge);
    }

    /**
     * Creates a search over a covariance matrix.
     *
     * @param graph     starting graph, or {@code null} for an empty graph over the matrix's variables
     * @param cov       the covariance matrix to score against
     * @param knowledge background knowledge used to constrain moves and pruning
     */
    public HbmsBeam(Graph graph, CovarianceMatrix cov, Knowledge knowledge) {
        if (graph == null) {
            graph = new EdgeListGraph(cov.getVariables());
        }
        this.knowledge = knowledge;
        this.graph = graph;
        this.externalGraph = new EdgeListGraph(graph);
        this.cov = cov;
        this.scorer = new DagScorer(cov);
    }

    /**
     * Runs the beam search and then prunes insignificant edges.
     *
     * <p>As side effects, records the estimated SEM of the initial DAG
     * ({@link #getOriginalSemIm()}) and of the returned DAG ({@link #getNewSemIm()}).
     *
     * @return the best DAG found, after zero-edge pruning
     */
    @Override
    public Graph search() {
        Graph start = new EdgeListGraph(this.externalGraph);
        this.addRequiredEdges(start);
        Graph bestGraph = SearchGraphUtils.dagFromCPDAG(start);

        Score initial = this.scoreGraph(bestGraph);
        double bestScore = initial.getScore();
        // Read the SEM immediately: the scorer is shared and is overwritten by the next scoring call.
        this.originalSemIm = initial.getEstimatedSem();
        System.out.println("Graph from search = " + bestGraph);

        if (this.trueModel != null) {
            this.trueModel = GraphUtils.replaceNodes(this.trueModel, bestGraph.getNodes());
            this.trueModel = SearchGraphUtils.cpdagForDag(this.trueModel);
        }

        System.out.println("Initial Score = " + this.nf.format(bestScore));
        bestGraph = this.increaseScoreLoop(bestGraph, this.getAlpha());
        bestGraph = this.removeZeroEdges(bestGraph);
        this.newSemIm = this.scoreGraph(bestGraph).getEstimatedSem();
        return bestGraph;
    }

    /**
     * Main beam loop.
     *
     * <p>For each graph in the beam, tries every add and redirect move. A move
     * is rejected if it violates knowledge, creates a directed cycle, or
     * produces a graph already in the beam. While the beam is below
     * {@code beamWidth}, candidates are added unconditionally. Otherwise a
     * candidate replaces the beam's worst graph only if it scores strictly
     * higher.
     *
     * <p>The loop stops when a full pass changes nothing, or after finishing
     * the moves of a beam member in which an accepted candidate, once pruned of
     * zero edges, has a model p-value above {@code alpha}.
     *
     * @param bestGraph starting DAG
     * @param alpha     p-value threshold used for the early-stopping test
     * @return the highest-scoring graph in the final beam (also stored as {@link #getGraph()})
     */
    private Graph increaseScoreLoop(Graph bestGraph, double alpha) {
        System.out.println("Increase score loop2");
        double initialScore = this.scoreGraph(bestGraph).getScore();
        Map<Graph, Double> beam = new HashMap<>();
        beam.put(bestGraph, initialScore);

        boolean changed = true;
        search:
        while (changed) {
            changed = false;
            // Iterate over a snapshot: the beam is modified inside the loop.
            for (Graph s : new ArrayList<>(beam.keySet())) {
                List<Move> moves = new ArrayList<>();
                moves.addAll(this.getAddMoves(s));
                moves.addAll(this.getRedirectMoves(s));
                boolean found = false;

                for (Move move : moves) {
                    Graph candidate = this.makeMove(s, move);
                    if (this.getKnowledge().isViolatedBy(candidate)
                            || (this.isCheckingCycles() && candidate.paths().existsDirectedCycle())
                            || beam.containsKey(candidate)) {
                        continue;
                    }

                    double score = this.scoreGraph(candidate).getScore();
                    if (beam.size() < this.beamWidth) {
                        beam.put(candidate, score);
                        changed = true;
                        continue;
                    }
                    if (!this.increasesScore(beam, score)) {
                        continue;
                    }

                    System.out.println("Increase score (" + move.getType() + "): score = " + score);
                    this.removeMinimalScore(beam);
                    beam.put(candidate, score);
                    changed = true;
                    if (this.scoreGraph(this.removeZeroEdges(candidate)).getPValue() > alpha) {
                        found = true;
                    }
                }

                if (found) {
                    break search;
                }
            }
        }

        Graph best = this.maximumScore(beam);
        System.out.println("DOF = " + this.scoreGraph(best).getDof());
        this.graph = best;
        return best;
    }

    /**
     * Tests whether {@code score} is strictly greater than the smallest score in the beam.
     * For an empty beam, this compares against positive infinity and so returns {@code false}.
     */
    private boolean increasesScore(Map<Graph, Double> s, double score) {
        double minScore = Double.POSITIVE_INFINITY;
        for (double value : s.values()) {
            if (value < minScore) {
                minScore = value;
            }
        }
        return score > minScore;
    }

    /**
     * Returns the graph with the strictly largest score, or {@code null} if no
     * score exceeds negative infinity.
     *
     * @throws NullPointerException if the map contains a {@code null} graph key
     */
    private Graph maximumScore(Map<Graph, Double> s) {
        double maxScore = Double.NEGATIVE_INFINITY;
        Graph maxGraph = null;
        for (Map.Entry<Graph, Double> entry : s.entrySet()) {
            if (entry.getKey() == null) {
                throw new NullPointerException();
            }
            double score = entry.getValue();
            if (score > maxScore) {
                maxScore = score;
                maxGraph = entry.getKey();
            }
        }
        return maxGraph;
    }

    /**
     * Removes the lowest-scoring graph from the beam.
     *
     * <p>Starts from positive infinity, so any finite score qualifies. The
     * previous {@code Integer.MAX_VALUE} start value let the beam overflow when
     * every score was larger than that.
     */
    private void removeMinimalScore(Map<Graph, Double> s) {
        double minScore = Double.POSITIVE_INFINITY;
        Graph minGraph = null;
        for (Map.Entry<Graph, Double> entry : s.entrySet()) {
            if (entry.getValue() < minScore) {
                minScore = entry.getValue();
                minGraph = entry.getKey();
            }
        }
        if (minGraph != null) {
            s.remove(minGraph);
        }
    }

    /**
     * Repeatedly removes edges whose coefficient is not significant.
     *
     * <p>For each coefficient parameter of the estimated SEM, the child is
     * regressed on its parents in the graph being pruned. The edge is removed
     * when the parent's p-value exceeds {@link #getHighPValueAlpha()}, unless
     * knowledge requires it. This repeats until a pass removes nothing.
     *
     * @param bestGraph graph to prune; it is copied, not modified
     * @return the pruned copy
     */
    public Graph removeZeroEdges(Graph bestGraph) {
        EdgeListGraph graph = new EdgeListGraph(bestGraph);
        RegressionCovariance regression = new RegressionCovariance(this.cov);
        boolean changed = true;

        while (changed) {
            changed = false;
            SemIm estSem = this.scoreGraph(graph).getEstimatedSem();

            for (Parameter param : estSem.getSemPm().getParameters()) {
                if (param.getType() != ParamType.COEF) {
                    continue;
                }
                Node nodeA = param.getNodeA();
                Node nodeB = param.getNodeB();

                // Orient using the graph being pruned, not the field this.graph:
                // the two differ when called on beam candidates.
                Node parent;
                Node child;
                if (graph.isParentOf(nodeA, nodeB)) {
                    parent = nodeA;
                    child = nodeB;
                } else {
                    parent = nodeB;
                    child = nodeA;
                }

                List<Node> parents = graph.getParents(child);
                int index = parents.indexOf(parent);
                if (index < 0) {
                    // Not a parent in this graph; index 0 would be the intercept's p-value.
                    continue;
                }
                RegressionResult result = regression.regress(child, parents);
                // Offset by one: p-value slot 0 belongs to the intercept.
                double p = result.getP()[index + 1];
                if (!(p > this.getHighPValueAlpha())) {
                    continue;
                }

                Edge edge = graph.getEdge(nodeA, nodeB);
                if (edge == null) {
                    continue;
                }
                if (this.getKnowledge().isRequired(edge.getNode1().getName(), edge.getNode2().getName())) {
                    System.out.println("Not removing " + edge + " because it is required.");
                    TetradLogger.getInstance().log("details", "Not removing " + edge + " because it is required.");
                    continue;
                }

                System.out.println("Removing edge " + edge + " because it has p = " + p);
                TetradLogger.getInstance().log("details", "Removing edge " + edge + " because it has p = " + p);
                graph.removeEdge(edge);
                changed = true;
            }
        }
        return graph;
    }

    /**
     * Applies a move to a copy of {@code graph}.
     *
     * @return a new graph with the move applied; the argument is not modified
     */
    private Graph makeMove(Graph graph, Move move) {
        graph = new EdgeListGraph(graph);
        Edge firstEdge = move.getFirstEdge();
        Edge secondEdge = move.getSecondEdge();
        if (firstEdge != null && move.getType() == Move.Type.ADD) {
            graph.removeEdge(firstEdge.getNode1(), firstEdge.getNode2());
            graph.addEdge(firstEdge);
        } else if (firstEdge != null && move.getType() == Move.Type.REMOVE) {
            graph.removeEdge(firstEdge);
        } else if (firstEdge != null && move.getType() == Move.Type.DOUBLE_REMOVE) {
            graph.removeEdge(firstEdge);
            graph.removeEdge(secondEdge);
        } else if (firstEdge != null && move.getType() == Move.Type.REDIRECT) {
            // Replace whatever edge currently joins the two nodes with the redirected edge.
            graph.removeEdge(graph.getEdge(firstEdge.getNode1(), firstEdge.getNode2()));
            graph.addEdge(firstEdge);
        } else if (firstEdge != null && secondEdge != null && move.getType() == Move.Type.ADD_COLLIDER) {
            Edge existingEdge1 = graph.getEdge(firstEdge.getNode1(), firstEdge.getNode2());
            Edge existingEdge2 = graph.getEdge(secondEdge.getNode1(), secondEdge.getNode2());
            if (existingEdge1 != null) {
                graph.removeEdge(existingEdge1);
            }
            if (existingEdge2 != null) {
                graph.removeEdge(existingEdge2);
            }
            graph.addEdge(firstEdge);
            graph.addEdge(secondEdge);
        } else if (firstEdge != null && secondEdge != null && move.getType() == Move.Type.REMOVE_COLLIDER) {
            graph.removeEdge(firstEdge);
            graph.removeEdge(secondEdge);
        } else if (firstEdge != null && secondEdge != null && move.getType() == Move.Type.SWAP) {
            graph.removeEdge(firstEdge);
            Edge secondEdgeStar = graph.getEdge(secondEdge.getNode1(), secondEdge.getNode2());
            if (secondEdgeStar != null) {
                graph.removeEdge(secondEdgeStar);
            }
            graph.addEdge(secondEdge);
        }
        return graph;
    }

    /**
     * Lists candidate {@code i -> j} additions between non-adjacent nodes.
     *
     * <p>A pair is skipped when {@code i -> j} is forbidden, {@code j -> i} is
     * required, or {@code j} is already an ancestor of {@code i}.
     */
    private List<Move> getAddMoves(Graph graph) {
        List<Move> moves = new ArrayList<>();
        // Sort a copy so the graph's own node list is never reordered.
        List<Node> nodes = new ArrayList<>(graph.getNodes());
        Collections.sort(nodes);
        Knowledge k = this.getKnowledge();

        for (Node from : nodes) {
            for (Node to : nodes) {
                if (from == to
                        || graph.isAdjacentTo(from, to)
                        || k.isForbidden(from.getName(), to.getName())
                        || k.isRequired(to.getName(), from.getName())
                        || graph.paths().isAncestorOf(to, from)) {
                    continue;
                }
                moves.add(new Move(Edges.directedEdge(from, to), Move.Type.ADD));
            }
        }
        return moves;
    }

    /**
     * Lists candidate reversals: for each edge {@code node1 - node2}, proposes
     * {@code node2 -> node1}, unless knowledge forbids the reversal or requires
     * the current orientation, or {@code node2} is an ancestor of {@code node1}.
     */
    private List<Move> getRedirectMoves(Graph graph) {
        List<Move> moves = new ArrayList<>();
        List<Edge> edges = new ArrayList<>(graph.getEdges());
        Collections.sort(edges);
        Knowledge k = this.getKnowledge();

        for (Edge edge : edges) {
            Node i = edge.getNode1();
            Node j = edge.getNode2();
            if (k.isForbidden(j.getName(), i.getName())
                    || k.isRequired(i.getName(), j.getName())
                    || graph.paths().isAncestorOf(j, i)) {
                continue;
            }
            moves.add(new Move(Edges.directedEdge(j, i), Move.Type.REDIRECT));
        }
        return moves;
    }

    /** Returns the starting graph, or the best beam graph once the search loop has run. */
    public Graph getGraph() {
        return this.graph;
    }

    /** Returns the SEM estimated for the initial DAG during {@link #search()}. */
    @Override
    public SemIm getOriginalSemIm() {
        return this.originalSemIm;
    }

    /** Returns the SEM estimated for the final DAG returned by {@link #search()}. */
    @Override
    public SemIm getNewSemIm() {
        return this.newSemIm;
    }

    /** Returns the coefficient p-value above which edges are pruned. */
    public double getHighPValueAlpha() {
        return this.highPValueAlpha;
    }

    @Override
    public void setHighPValueAlpha(double highPValueAlpha) {
        this.highPValueAlpha = highPValueAlpha;
    }

    /** Always {@code true}: candidates containing a directed cycle are always rejected. */
    public boolean isCheckingCycles() {
        return true;
    }

    /**
     * Scores {@code graph} with the shared scorer.
     *
     * <p>The returned score's {@link Score#getEstimatedSem()} and
     * {@link Score#getPValue()} read the shared scorer, so call them before
     * scoring another graph.
     *
     * @return the score, or {@link Score#negativeInfinity()} for a {@code null} graph
     */
    public Score scoreGraph(Graph graph) {
        if (graph == null) {
            return Score.negativeInfinity();
        }
        this.scorer.score(graph);
        return new Score(this.scorer);
    }

    /**
     * Sets background knowledge. The knowledge is validated against the
     * current graph before it is stored.
     *
     * @throws NullPointerException     if {@code knowledge} is {@code null}
     * @throws IllegalArgumentException if the current graph violates {@code knowledge}
     */
    @Override
    public void setKnowledge(Knowledge knowledge) {
        if (knowledge == null) {
            throw new NullPointerException("knowledge");
        }
        if (knowledge.isViolatedBy(this.graph)) {
            throw new IllegalArgumentException("Graph violates knowledge.");
        }
        this.knowledge = knowledge;
    }

    /** Returns the p-value threshold used to stop the beam loop early. */
    public double getAlpha() {
        return this.alpha;
    }

    @Override
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /**
     * Sets the maximum number of graphs kept in the beam.
     *
     * @throws IllegalArgumentException if {@code beamWidth < 1}
     */
    @Override
    public void setBeamWidth(int beamWidth) {
        if (beamWidth < 1) {
            throw new IllegalArgumentException();
        }
        this.beamWidth = beamWidth;
    }

    public Knowledge getKnowledge() {
        return this.knowledge;
    }

    /**
     * Imposes knowledge on {@code graph} in place.
     *
     * <p>For each required edge {@code A -> B} with both endpoints present, the
     * existing edge is replaced by {@code A -> B}, unless {@code B} is already
     * an ancestor of {@code A}. For each forbidden edge {@code A -> B} whose
     * endpoints are adjacent, the edges between them are replaced by
     * {@code B -> A}, unless {@code A} is a child of {@code B} or {@code A} is
     * an ancestor of {@code B}. Knowledge naming variables absent from the
     * graph is skipped.
     */
    private void addRequiredEdges(Graph graph) {
        Iterator<KnowledgeEdge> it = this.getKnowledge().requiredEdgesIterator();
        while (it.hasNext()) {
            KnowledgeEdge next = it.next();
            Node nodeA = findNodeNamed(graph, next.getFrom());
            Node nodeB = findNodeNamed(graph, next.getTo());
            // Null guard mirrors the forbidden-edge loop below. The ancestor
            // check avoids closing a directed cycle with A -> B.
            if (nodeA == null || nodeB == null || graph.paths().isAncestorOf(nodeB, nodeA)) {
                continue;
            }
            graph.removeEdge(nodeA, nodeB);
            graph.addDirectedEdge(nodeA, nodeB);
            TetradLogger.getInstance().log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeA, nodeB));
        }

        it = this.getKnowledge().forbiddenEdgesIterator();
        while (it.hasNext()) {
            KnowledgeEdge next = it.next();
            Node nodeA = findNodeNamed(graph, next.getFrom());
            Node nodeB = findNodeNamed(graph, next.getTo());
            if (nodeA == null || nodeB == null || !graph.isAdjacentTo(nodeA, nodeB)
                    || graph.isChildOf(nodeA, nodeB) || graph.paths().isAncestorOf(nodeA, nodeB)) {
                continue;
            }
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeB, nodeA);
            TetradLogger.getInstance().log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
        }
    }

    /** Returns the first node in {@code graph} named {@code name}, or {@code null} if there is none. */
    private static Node findNodeNamed(Graph graph, String name) {
        for (Node node : graph.getNodes()) {
            if (node.getName().equals(name)) {
                return node;
            }
        }
        return null;
    }

    /**
     * Snapshot of a scorer's fit statistics.
     *
     * <p>FML, degrees of freedom, chi-square and BIC are captured at
     * construction. The estimated SEM and p-value are read from the (shared)
     * scorer when requested.
     */
    public static class Score {
        private final Scorer scorer;
        private final double fml;
        private final double chisq;
        private final double bic;
        private final int dof;

        /** Captures the statistics of the graph most recently scored by {@code scorer}. */
        public Score(Scorer scorer) {
            this.scorer = scorer;
            this.fml = scorer.getFml();
            this.dof = scorer.getDof();
            int sampleSize = scorer.getSampleSize();
            this.chisq = (double) (sampleSize - 1) * this.fml;
            this.bic = this.chisq - (double) this.dof * FastMath.log(sampleSize);
        }

        /** Sentinel with infinite FML, so {@link #getScore()} is negative infinity. */
        private Score() {
            this.scorer = null;
            this.fml = Double.POSITIVE_INFINITY;
            this.dof = 0;
            this.chisq = Double.POSITIVE_INFINITY;
            this.bic = Double.POSITIVE_INFINITY;
        }

        /** Returns the scorer's estimated SEM, or {@code null} for the sentinel score. */
        public SemIm getEstimatedSem() {
            return this.scorer == null ? null : this.scorer.getEstSem();
        }

        /** Returns the scorer's model p-value, or {@code 0.0} for the sentinel (infinite chi-square) score. */
        public double getPValue() {
            return this.scorer == null ? 0.0 : this.scorer.getPValue();
        }

        /** Returns {@code -BIC}; higher is better. */
        public double getScore() {
            return -this.bic;
        }

        /** Returns the FML captured at construction. */
        public double getFml() {
            return this.fml;
        }

        public static Score negativeInfinity() {
            return new Score();
        }

        public int getDof() {
            return this.dof;
        }

        public double getChiSquare() {
            return this.chisq;
        }

        public double getBic() {
            return this.bic;
        }
    }

    /** A candidate edit to a graph: one or two edges plus the kind of edit. */
    public static class Move {
        private final Edge edge;
        private final Edge secondEdge;
        private final Type type;

        public Move(Edge edge, Type type) {
            this(edge, null, type);
        }

        public Move(Edge edge, Edge secondEdge, Type type) {
            this.edge = edge;
            this.secondEdge = secondEdge;
            this.type = type;
        }

        public Edge getFirstEdge() {
            return this.edge;
        }

        public Edge getSecondEdge() {
            return this.secondEdge;
        }

        public Type getType() {
            return this.type;
        }

        @Override
        public String toString() {
            String s = this.secondEdge != null ? this.secondEdge + ", " : "";
            return "<" + this.edge + ", " + s + this.type + ">";
        }

        public static enum Type {
            ADD,
            REMOVE,
            REDIRECT,
            ADD_COLLIDER,
            REMOVE_COLLIDER,
            SWAP,
            DOUBLE_REMOVE;
        }
    }
}

