/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.work_in_progress;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphTransforms;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.utils.DagInCpcagIterator;
import edu.cmu.tetrad.search.utils.GraphSearchUtils;
import edu.cmu.tetrad.search.utils.MeekRules;
import edu.cmu.tetrad.search.work_in_progress.Hbsms;
import edu.cmu.tetrad.sem.DagScorer;
import edu.cmu.tetrad.sem.Scorer;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.TetradLogger;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import org.apache.commons.math3.util.FastMath;

/**
 * GES-style implementation of {@link Hbsms} that searches over CPDAGs.
 *
 * <p>{@link #search()} runs two phases on the working CPDAG.
 * <ul>
 *   <li>A forward equivalence search inserts one edge {@code x --> y} per step, together with a
 *       set T of undirected neighbors of {@code y} that are oriented into {@code y}.</li>
 *   <li>A backward elimination search then deletes one edge per step.</li>
 * </ul>
 * Each step takes the best-scoring valid move and rebuilds the CPDAG. A candidate is scored by
 * converting it to a DAG and fitting that DAG with a {@link DagScorer}. The score is the negated
 * BIC, so larger is better.
 *
 * <p>Every visited model whose p-value exceeds {@code alpha} is recorded in
 * {@link #getSignificantModels()}. The forward phase stops early as soon as the current model's
 * p-value exceeds {@code alpha}. The working CPDAG returned by {@link #getGraph()} is mutated in
 * place by {@link #search()}.
 */
public final class HbsmsGes
implements Hbsms {
    /** Working CPDAG; mutated in place by {@link #search()}. */
    private final Graph graph;
    /** Formatter for scores and p-values in log messages. */
    private final NumberFormat nf = new DecimalFormat("0.0#########");
    /** Snapshots of every visited model whose p-value exceeded {@link #alpha}. */
    private final Set<GraphWithPValue> significantModels = new HashSet<>();
    /** Scorer shared by all evaluations; its state reflects the most recently scored DAG. */
    private final Scorer scorer;
    private Knowledge knowledge = new Knowledge();
    private double alpha = 0.05;
    private SemIm originalSemIm;
    private SemIm newSemIm;

    /**
     * Creates a search whose starting point is the CPDAG of the first DAG produced by
     * {@link DagInCpcagIterator} for {@code graph}.
     *
     * @param graph starting graph; must not be {@code null}
     * @param data  data used to score candidate models; must not be {@code null}
     * @throws NullPointerException     if {@code graph} or {@code data} is {@code null}
     * @throws IllegalArgumentException if the resulting CPDAG contains a bidirected edge
     */
    public HbsmsGes(Graph graph, DataSet data) {
        if (graph == null) {
            throw new NullPointerException("Graph not specified.");
        }
        if (data == null) {
            throw new NullPointerException("Data not specified.");
        }
        // The knowledge is still the default (empty) one at this point. Knowledge supplied later
        // through setKnowledge() therefore does not influence this initial DAG choice.
        // NOTE(review): the two flags were held in locals named allowArbitraryOrientations and
        // allowNewColliders in earlier code; confirm their order against DagInCpcagIterator.
        DagInCpcagIterator iterator = new DagInCpcagIterator(graph, this.getKnowledge(), true, true);
        Graph cpdag = GraphTransforms.cpdagForDag(iterator.next());
        if (GraphUtils.containsBidirectedEdge(cpdag)) {
            throw new IllegalArgumentException("Contains bidirected edge.");
        }
        this.graph = cpdag;
        this.scorer = new DagScorer(data);
    }

    /**
     * Deletion validity test: the nodes of NA_{Y,X} that are not in {@code h} must form a clique
     * in {@code graph}.
     */
    private static boolean validDelete(Node x, Node y, Set<Node> h, Graph graph) {
        List<Node> naYXH = findNaYX(x, y, graph);
        naYXH.removeAll(h);
        return GraphUtils.isClique(naYXH, graph);
    }

    /**
     * Returns the candidate T nodes for inserting {@code x --> y}. These are the nodes joined to
     * {@code y} by an undirected edge that are not adjacent to {@code x}. Returns a fresh,
     * mutable list.
     */
    private static List<Node> getTNeighbors(Node x, Node y, Graph graph) {
        List<Node> tNeighbors = new ArrayList<>(graph.getAdjacentNodes(y));
        tNeighbors.removeAll(graph.getAdjacentNodes(x));
        tNeighbors.removeIf(z -> !Edges.isUndirectedEdge(graph.getEdge(y, z)));
        return tNeighbors;
    }

    /**
     * Returns the candidate H nodes for deleting the edge between {@code x} and {@code y}. This
     * is exactly NA_{Y,X}: nodes joined to {@code y} by an undirected edge that are also adjacent
     * to {@code x}.
     */
    private static List<Node> getHNeighbors(Node x, Node y, Graph graph) {
        return findNaYX(x, y, graph);
    }

    /**
     * Returns NA_{Y,X}: the nodes joined to {@code y} by an undirected edge that are also
     * adjacent to {@code x}. Returns a fresh, mutable list.
     */
    private static List<Node> findNaYX(Node x, Node y, Graph graph) {
        List<Node> naYX = new ArrayList<>(graph.getAdjacentNodes(y));
        naYX.retainAll(graph.getAdjacentNodes(x));
        // Filter with removeIf; removing while stepping forward by index would skip the element
        // that shifts into each removed slot.
        naYX.removeIf(z -> !Edges.isUndirectedEdge(graph.getEdge(y, z)));
        return naYX;
    }

    /**
     * Returns all subsets of {@code nodes}. Subset number {@code mask} contains
     * {@code nodes.get(k)} exactly when bit {@code k} of {@code mask} is set. Subsets are listed
     * in increasing mask order, starting with the empty set.
     *
     * @throws IllegalArgumentException if {@code nodes} is too large to enumerate with an int mask
     */
    private static List<Set<Node>> powerSet(List<Node> nodes) {
        int n = nodes.size();
        if (n >= Integer.SIZE - 1) {
            throw new IllegalArgumentException("Too many nodes to enumerate a power set: " + n);
        }
        int total = 1 << n;
        List<Set<Node>> subsets = new ArrayList<>();
        for (int mask = 0; mask < total; ++mask) {
            Set<Node> subset = new HashSet<>();
            for (int k = 0; k < n; ++k) {
                if ((mask & (1 << k)) != 0) {
                    subset.add(nodes.get(k));
                }
            }
            subsets.add(subset);
        }
        return subsets;
    }

    /**
     * Scores {@code graph}. If its p-value exceeds {@link #alpha}, a snapshot of the graph is
     * recorded in the significant-model set.
     *
     * @return {@code true} if the p-value exceeded alpha
     */
    private boolean saveModelIfSignificant(Graph graph) {
        double pValue = this.scoreGraph(graph).getPValue();
        if (!(pValue > this.alpha)) {
            return false;
        }
        // Store a copy. The working graph keeps being mutated, which would otherwise alias every
        // stored model and change its hash code while it is inside the HashSet.
        this.getSignificantModels().add(new GraphWithPValue(new EdgeListGraph(graph), pValue));
        return true;
    }

    /**
     * Scores a CPDAG. The CPDAG is first converted to a DAG consistent with the current
     * knowledge, and that DAG is then scored.
     */
    public Score scoreGraph(Graph graph) {
        Graph dag = GraphTransforms.dagFromCPDAG(graph, this.getKnowledge());
        this.scorer.score(dag);
        return new Score(this.scorer);
    }

    /** Returns the working CPDAG (mutated in place by {@link #search()}). */
    public Graph getGraph() {
        return this.graph;
    }

    /** Returns the SEM estimated for the starting CPDAG, or {@code null} before {@link #search()}. */
    @Override
    public SemIm getOriginalSemIm() {
        return this.originalSemIm;
    }

    /** Returns the SEM estimated for the final CPDAG, or {@code null} before {@link #search()}. */
    @Override
    public SemIm getNewSemIm() {
        return this.newSemIm;
    }

    /** Ignored by this implementation. */
    @Override
    public void setHighPValueAlpha(double highPValueAlpha) {
    }

    /** Scores {@code dag} directly, without CPDAG-to-DAG conversion. */
    public Score scoreDag(Graph dag) {
        this.scorer.score(dag);
        return new Score(this.scorer);
    }

    /**
     * Runs forward equivalence search followed by backward elimination on the working CPDAG.
     *
     * @return a copy of the resulting CPDAG
     */
    @Override
    public Graph search() {
        Graph working = this.getGraph();
        Score initial = this.scoreGraph(working);
        // Read the SEM immediately. Score.getEstimatedSem() reflects whatever the shared scorer
        // scored most recently.
        this.originalSemIm = initial.getEstimatedSem();
        TetradLogger.getInstance().log("info", "Starting CPDAG = " + working);
        this.saveModelIfSignificant(working);
        double score = this.fes(working, initial.getScore());
        this.bes(working, score);
        this.newSemIm = this.scoreGraph(working).getEstimatedSem();
        return new EdgeListGraph(working);
    }

    /**
     * Forward phase. Each round performs the best-scoring valid insertion that improves the
     * score, then rebuilds the CPDAG. The phase stops when no insertion improves the score, or as
     * soon as the current model's p-value exceeds alpha.
     *
     * @return the score of the model left in {@code graph}
     */
    private double fes(Graph graph, double score) {
        TetradLogger.getInstance().log("info", "** FORWARD EQUIVALENCE SEARCH");
        double bestScore = score;
        TetradLogger.getInstance().log("info", "Initial Score = " + this.nf.format(bestScore));

        while (true) {
            Node bestX = null;
            Node bestY = null;
            Set<Node> bestT = null;

            List<Node> nodes = new ArrayList<>(graph.getNodes());
            RandomUtil.shuffle(nodes);

            for (Node candX : nodes) {
                for (Node candY : nodes) {
                    if (candX == candY || graph.isAdjacentTo(candX, candY)
                            || this.getKnowledge().isForbidden(candX.getName(), candY.getName())) {
                        continue;
                    }
                    for (Set<Node> tSubset : powerSet(getTNeighbors(candX, candY, graph))) {
                        if (this.invalidSetByKnowledge(candX, candY, tSubset, true)) {
                            continue;
                        }
                        EdgeListGraph graph2 = new EdgeListGraph(graph);
                        this.tryInsert(candX, candY, tSubset, graph2);
                        if (graph2.paths().existsDirectedCycle()) {
                            continue;
                        }
                        double evalScore = this.scoreGraph(graph2).getScore();
                        TetradLogger.getInstance().log("edgeEvaluations", "Trying to add " + candX + "-->" + candY + " evalScore = " + evalScore);
                        // bestScore never falls below the current score, so this also demands an
                        // improvement. Written as !(a > b) so that NaN scores are rejected.
                        if (!(evalScore > bestScore) || !this.validInsert(candX, candY, tSubset, graph)) {
                            continue;
                        }
                        bestScore = evalScore;
                        bestX = candX;
                        bestY = candY;
                        bestT = tSubset;
                    }
                }
            }

            if (bestX == null) {
                return score;
            }
            score = bestScore;
            this.insert(bestX, bestY, bestT, graph);
            this.rebuildCPDAG(graph);
            if (this.saveModelIfSignificant(graph)) {
                return score;
            }
        }
    }

    /**
     * Backward phase. Each round performs the best-scoring valid deletion that beats the best
     * score seen so far, then rebuilds the CPDAG. The phase stops when no deletion qualifies.
     */
    private void bes(Graph graph, double initialScore) {
        TetradLogger.getInstance().log("info", "** BACKWARD ELIMINATION SEARCH");
        TetradLogger.getInstance().log("info", "Initial Score = " + this.nf.format(initialScore));
        double bestScore = initialScore;

        while (true) {
            Node bestX = null;
            Node bestY = null;
            Set<Node> bestH = null;

            List<Edge> graphEdges = new ArrayList<>(graph.getEdges());
            RandomUtil.shuffle(graphEdges);

            for (Edge edge : graphEdges) {
                Node candX;
                Node candY;
                if (Edges.isUndirectedEdge(edge)) {
                    candX = edge.getNode1();
                    candY = edge.getNode2();
                } else {
                    candX = Edges.getDirectedEdgeTail(edge);
                    candY = Edges.getDirectedEdgeHead(edge);
                }
                if (!this.getKnowledge().noEdgeRequired(candX.getName(), candY.getName())) {
                    continue;
                }
                for (Set<Node> hSubset : powerSet(getHNeighbors(candX, candY, graph))) {
                    if (this.invalidSetByKnowledge(candX, candY, hSubset, false)) {
                        continue;
                    }
                    EdgeListGraph graph2 = new EdgeListGraph(graph);
                    this.tryDelete(candX, candY, hSubset, graph2);
                    double evalScore = this.scoreGraph(graph2).getScore();
                    if (!(evalScore > bestScore) || !validDelete(candX, candY, hSubset, graph)) {
                        continue;
                    }
                    bestScore = evalScore;
                    bestX = candX;
                    bestY = candY;
                    bestH = hSubset;
                }
            }

            if (bestX == null) {
                return;
            }
            if (!graph.isAdjacentTo(bestX, bestY)) {
                throw new IllegalArgumentException("trying to delete a nonexistent edge! " + bestX + "---" + bestY);
            }
            this.delete(bestX, bestY, bestH, graph);
            this.rebuildCPDAG(graph);
            this.saveModelIfSignificant(graph);
        }
    }

    /**
     * Adds {@code x --> y} and orients each undirected edge {@code t -- y} (t in {@code subset})
     * as {@code t --> y}.
     *
     * @throws IllegalArgumentException if an edge {@code t - y} is not undirected
     */
    private void tryInsert(Node x, Node y, Set<Node> subset, Graph graph) {
        graph.addDirectedEdge(x, y);
        for (Node t : subset) {
            Edge oldEdge = graph.getEdge(t, y);
            if (!Edges.isUndirectedEdge(oldEdge)) {
                throw new IllegalArgumentException("Should be undirected: " + oldEdge);
            }
            graph.removeEdge(t, y);
            graph.addDirectedEdge(t, y);
            TetradLogger.getInstance().log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(t, y));
        }
    }

    /**
     * Removes the edge between {@code x} and {@code y}. Then, for each h in {@code subset}, any
     * undirected {@code x -- h} becomes {@code x --> h} and any undirected {@code y -- h} becomes
     * {@code y --> h}.
     */
    private void tryDelete(Node x, Node y, Set<Node> subset, Graph graph) {
        graph.removeEdge(x, y);
        for (Node h : subset) {
            directIfUndirected(x, h, graph);
            directIfUndirected(y, h, graph);
        }
    }

    /** Replaces an undirected edge {@code from -- to} with {@code from --> to}; other edges are left alone. */
    private static void directIfUndirected(Node from, Node to, Graph graph) {
        // Capture the old edge before replacing it so the log shows the actual change.
        Edge oldEdge = graph.getEdge(from, to);
        if (!Edges.isUndirectedEdge(oldEdge)) {
            return;
        }
        graph.removeEdge(from, to);
        graph.addDirectedEdge(from, to);
        TetradLogger.getInstance().log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(from, to));
    }

    /** Applies {@link #tryInsert} unless {@code x} and {@code y} are already adjacent. */
    private void insert(Node x, Node y, Set<Node> subset, Graph graph) {
        if (graph.isAdjacentTo(x, y)) {
            return;
        }
        this.tryInsert(x, y, subset, graph);
    }

    /** Logs the deletion (with the pre-deletion p-value) and applies {@link #tryDelete}. */
    private void delete(Node x, Node y, Set<Node> subset, Graph graph) {
        Edge oldEdge = graph.getEdge(x, y);
        TetradLogger.getInstance().log("info", graph.getNumEdges() + ". DELETE " + oldEdge + " " + subset
                + " (" + this.nf.format(this.scoreGraph(graph).getPValue()) + ")");
        this.tryDelete(x, y, subset, graph);
    }

    /**
     * Insertion validity test. NA_{Y,X} together with T must form a clique, and every
     * semi-directed path from {@code y} to {@code x} must pass through that set.
     */
    private boolean validInsert(Node x, Node y, Set<Node> subset, Graph graph) {
        List<Node> naYXT = new ArrayList<>(subset);
        naYXT.addAll(findNaYX(x, y, graph));
        return GraphUtils.isClique(naYXT, graph) && this.isSemiDirectedBlocked(x, y, naYXT, graph, new HashSet<>());
    }

    /**
     * Returns {@code true} if knowledge forbids the orientations the move would create.
     * <ul>
     *   <li>Insert mode: any {@code t --> y} with t in {@code subset} is forbidden.</li>
     *   <li>Delete mode: any {@code x --> h} or {@code y --> h} with h in {@code subset} is
     *       forbidden.</li>
     * </ul>
     */
    private boolean invalidSetByKnowledge(Node x, Node y, Set<Node> subset, boolean insertMode) {
        Knowledge k = this.getKnowledge();
        for (Node node : subset) {
            if (insertMode) {
                if (k.isForbidden(node.getName(), y.getName())) {
                    return true;
                }
            } else if (k.isForbidden(x.getName(), node.getName()) || k.isForbidden(y.getName(), node.getName())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns {@code true} if every path from {@code y} to {@code x} that never enters a node
     * through an arrowhead pointing back at the previous node passes through a node of
     * {@code naYXT}. Such paths are found by exhaustive backtracking DFS; {@code marked} holds
     * the nodes on the current path.
     */
    private boolean isSemiDirectedBlocked(Node x, Node y, List<Node> naYXT, Graph graph, Set<Node> marked) {
        if (naYXT.contains(y)) {
            return true;
        }
        if (y == x) {
            return false;
        }
        for (Node next : graph.getAdjacentNodes(y)) {
            // Skip nodes already on the path and edges that point back into y (next --> y).
            if (next == y || marked.contains(next) || graph.isParentOf(next, y)) {
                continue;
            }
            marked.add(next);
            if (!this.isSemiDirectedBlocked(x, next, naYXT, graph, marked)) {
                return false;
            }
            marked.remove(next);
        }
        return true;
    }

    /**
     * Recomputes the CPDAG in place in three steps: basic CPDAG, then required edges from
     * knowledge, then Meek rules with knowledge.
     */
    private void rebuildCPDAG(Graph graph) {
        GraphSearchUtils.basicCpdag(graph);
        this.addRequiredEdges(graph);
        this.pdagWithBk(graph, this.getKnowledge());
        TetradLogger.getInstance().log("rebuiltCPDAGs", "Rebuilt CPDAG = " + graph);
    }

    /** Applies Meek's orientation rules to {@code graph}, respecting {@code knowledge}. */
    private void pdagWithBk(Graph graph, Knowledge knowledge) {
        MeekRules rules = new MeekRules();
        rules.setKnowledge(knowledge);
        rules.orientImplied(graph);
    }

    /** Adds {@code a --> b} for every knowledge-required pair that is not already adjacent. */
    private void addRequiredEdges(Graph graph) {
        List<Node> nodes = graph.getNodes();
        for (Node from : nodes) {
            for (Node to : nodes) {
                if (from == to
                        || !this.getKnowledge().isRequired(from.getName(), to.getName())
                        || graph.isAdjacentTo(from, to)) {
                    continue;
                }
                graph.addDirectedEdge(from, to);
            }
        }
    }

    /** Returns the p-value threshold; models with a p-value above it are recorded as significant. */
    public double getAlpha() {
        return this.alpha;
    }

    @Override
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /** Ignored by this implementation. */
    @Override
    public void setBeamWidth(int beamWidth) {
    }

    public Knowledge getKnowledge() {
        return this.knowledge;
    }

    /** Stores a copy of {@code knowledge}; it does not affect the DAG chosen in the constructor. */
    @Override
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = new Knowledge(knowledge);
    }

    /** Returns the live (mutable) set of recorded significant models. */
    public Set<GraphWithPValue> getSignificantModels() {
        return this.significantModels;
    }

    /**
     * Snapshot of the fit statistics a {@link Scorer} reports right after scoring a DAG.
     *
     * <p>NOTE: {@link #getEstimatedSem()} is not snapshotted. It queries the scorer on every
     * call, so it reflects whatever model the scorer scored most recently. Call it before scoring
     * anything else.
     */
    public static class Score {
        private final Scorer scorer;
        private final double pValue;
        private final double fml;
        private final double chisq;
        private final double bic;
        private final int dof;

        public Score(Scorer scorer) {
            this.scorer = scorer;
            this.pValue = scorer.getPValue();
            this.fml = scorer.getFml();
            this.chisq = scorer.getChiSquare();
            this.bic = scorer.getBicScore();
            this.dof = scorer.getDof();
        }

        /** Returns the scorer's current estimated SEM (see the class note). */
        public SemIm getEstimatedSem() {
            return this.scorer.getEstSem();
        }

        public double getPValue() {
            return this.pValue;
        }

        /** Returns the search score, the negated BIC (larger is better). */
        public double getScore() {
            return -this.bic;
        }

        public double getFml() {
            return this.fml;
        }

        public int getDof() {
            return this.dof;
        }

        public double getChiSquare() {
            return this.chisq;
        }

        public double getBic() {
            return this.bic;
        }
    }

    /** A graph paired with its p-value; equality and hashing use the graph only. */
    public static class GraphWithPValue {
        private final Graph graph;
        private final double pValue;

        public GraphWithPValue(Graph graph, double pValue) {
            this.graph = graph;
            this.pValue = pValue;
        }

        public Graph getGraph() {
            return this.graph;
        }

        public double getPValue() {
            return this.pValue;
        }

        @Override
        public int hashCode() {
            return 17 * this.graph.hashCode();
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof GraphWithPValue)) {
                return false;
            }
            GraphWithPValue p = (GraphWithPValue) o;
            return p.graph.equals(this.graph);
        }
    }
}

