/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.DagInPatternIterator;
import edu.cmu.tetrad.search.SearchGraphUtils;
import edu.cmu.tetrad.sem.SemEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemOptimizer;
import edu.cmu.tetrad.sem.SemOptimizerScattershot;
import edu.cmu.tetrad.sem.SemPm;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Greedy forward search over patterns that scores candidate edge additions with
 * the FML value reported by an estimated {@link SemIm} and stops once the
 * estimated model's p-value exceeds {@code alpha}.
 *
 * <p>Every score requires a full SEM estimation (see {@link #estimate(Graph)}),
 * so a search performs many estimations per round. Progress is reported on
 * {@code System.out}.
 */
public class FmlSearch {
    /** Continuous data set every candidate model is estimated against. */
    private final DataSet data;

    /** Variables of {@link #data}; candidate edges are drawn from pairs of these. */
    private final List<Node> nodes;

    /** P-value threshold: the search stops once the first pattern's p-value exceeds it. */
    private final double alpha;

    /**
     * Validated and stored, but not consulted by any method of this class.
     * NOTE(review): presumably meant to bound the number of edges added - confirm.
     */
    private final int maxEdges;

    /** Optional reference DAG, only used for diagnostic output at the end of {@link #search()}. */
    private Dag trueDag;

    /** Optional reference model; stored but not read anywhere in this class. */
    private SemIm trueIm;

    /**
     * Creates a search over the variables of the given data set.
     *
     * @param data     the data set; must be non-null and continuous
     * @param alpha    p-value threshold used as the stopping/trimming criterion; must be &gt;= 0
     * @param maxEdges must be &gt;= 0 (currently unused by the search itself)
     * @throws IllegalArgumentException if any argument violates the constraints above
     */
    public FmlSearch(DataSet data, double alpha, int maxEdges) {
        if (data == null || !data.isContinuous()) {
            throw new IllegalArgumentException("Please provide a continuous dataset.");
        }
        if (alpha < 0.0) {
            throw new IllegalArgumentException("Alpha must be >= 0: " + alpha);
        }
        if (maxEdges < 0) {
            throw new IllegalArgumentException("Max edges must be >= 0: " + maxEdges);
        }
        this.data = data;
        this.nodes = data.getVariables();
        // Use the caller-supplied threshold (previously a hard-coded 0.05 overrode it).
        this.alpha = alpha;
        this.maxEdges = maxEdges;
    }

    /**
     * Sets the reference DAG whose FML and p-value are printed after a search.
     *
     * @param dag the reference DAG, or {@code null} to disable that output
     */
    public void setTrueDag(Dag dag) {
        this.trueDag = dag;
    }

    /**
     * Stores a reference SEM instantiated model.
     *
     * @param trueIm the reference model
     */
    public void setTrueIm(SemIm trueIm) {
        this.trueIm = trueIm;
    }

    /**
     * Runs the search starting from the empty graph over {@link #nodes}.
     *
     * <p>Each round scores every admissible single-edge addition to a DAG of each
     * current pattern, keeps the additions achieving the minimum FML, converts the
     * extended DAGs back to patterns, and re-derives each pattern from its
     * best-scoring member DAG. The loop ends when the first current pattern's
     * p-value exceeds {@code alpha}; the surviving patterns are then trimmed with
     * {@link #trimPattern(Graph)}.
     *
     * @return the trimmed patterns, or an empty list if a round produces no pattern
     */
    public List<Graph> search() {
        List<Graph> currentPatterns = new ArrayList<Graph>();
        currentPatterns.add(new EdgeListGraph(this.nodes));

        double pValue;
        do {
            List<Map<Edge, Double>> edgeFmls = new ArrayList<Map<Edge, Double>>();
            for (Graph pattern : currentPatterns) {
                edgeFmls.add(this.scoreCandidateEdges(pattern));
            }

            double minFml = minimumFml(edgeFmls);
            System.out.println("Min fml = " + minFml);
            printNearby(edgeFmls, minFml);
            retainMinimal(edgeFmls, minFml);

            List<Graph> extended = extendPatterns(currentPatterns, edgeFmls);
            currentPatterns = this.refinePatterns(extended);
            System.out.println("New current Patterns = " + currentPatterns);

            if (currentPatterns.isEmpty()) {
                return new ArrayList<Graph>();
            }
            pValue = this.pValue(SearchGraphUtils.dagFromPattern(currentPatterns.get(0)));
            // Negated form keeps looping when the p-value is NaN.
        } while (!(pValue > this.alpha));

        System.out.println("P value of found models = " + pValue);

        List<Graph> trimmedPatterns = new ArrayList<Graph>();
        for (Graph pattern : currentPatterns) {
            trimmedPatterns.add(this.trimPattern(pattern));
        }
        currentPatterns = trimmedPatterns;

        // The reference DAG is optional; estimating a model for null would fail.
        if (this.trueDag != null) {
            System.out.println("True DAG " + this.trueDag);
            System.out.println("FML = " + this.fml(this.trueDag));
            System.out.println("P Value of true model = " + this.pValue(this.trueDag));
        }

        for (int i = 0; i < currentPatterns.size(); ++i) {
            Graph pattern = currentPatterns.get(i);
            System.out.println("Output pattern # " + (i + 1));
            System.out.println(pattern);
            System.out.println("P Value of that = " + this.pValue(SearchGraphUtils.dagFromPattern(pattern)));
        }
        return currentPatterns;
    }

    /**
     * Scores every single directed edge that can be added to a DAG of the given
     * pattern between two non-adjacent nodes. An orientation is skipped when a
     * directed path already runs the opposite way.
     *
     * @param pattern the pattern to extend
     * @return candidate edge to FML score, in the order the candidates were tried
     */
    private Map<Edge, Double> scoreCandidateEdges(Graph pattern) {
        Graph dag = SearchGraphUtils.dagFromPattern(pattern);
        Map<Edge, Double> scores = new LinkedHashMap<Edge, Double>();

        for (int i = 0; i < this.nodes.size(); ++i) {
            for (int j = 0; j < i; ++j) {
                Node node1 = this.nodes.get(i);
                Node node2 = this.nodes.get(j);
                if (dag.isAdjacentTo(node1, node2)) {
                    continue;
                }
                if (!dag.existsDirectedPathFromTo(node2, node1)) {
                    this.scoreTemporaryEdge(dag, node1, node2, scores);
                }
                if (!dag.existsDirectedPathFromTo(node1, node2)) {
                    this.scoreTemporaryEdge(dag, node2, node1, scores);
                }
            }
        }
        return scores;
    }

    /** Adds {@code from -> to} to {@code dag}, records the resulting FML, then removes the edge again. */
    private void scoreTemporaryEdge(Graph dag, Node from, Node to, Map<Edge, Double> scores) {
        dag.addDirectedEdge(from, to);
        Edge edge = dag.getEdge(from, to);
        scores.put(edge, this.fml(dag));
        dag.removeEdge(edge);
    }

    /** Returns the smallest score over all maps, or positive infinity if there is none. */
    private static double minimumFml(List<Map<Edge, Double>> edgeFmls) {
        double minFml = Double.POSITIVE_INFINITY;
        for (Map<Edge, Double> scores : edgeFmls) {
            for (Double score : scores.values()) {
                if (score < minFml) {
                    minFml = score;
                }
            }
        }
        return minFml;
    }

    /** Prints every candidate whose score lies in {@code [minFml, 1.1 * minFml)}. */
    private static void printNearby(List<Map<Edge, Double>> edgeFmls, double minFml) {
        for (Map<Edge, Double> scores : edgeFmls) {
            for (Map.Entry<Edge, Double> entry : scores.entrySet()) {
                double score = entry.getValue();
                if (score >= minFml && score < 1.1 * minFml) {
                    System.out.println("Nearby: " + entry.getKey() + " FML = " + entry.getValue());
                }
            }
        }
    }

    /** Removes, in place, every candidate whose score is strictly greater than {@code minFml}. */
    private static void retainMinimal(List<Map<Edge, Double>> edgeFmls, double minFml) {
        for (Map<Edge, Double> scores : edgeFmls) {
            // Iterate a snapshot of the keys so entries can be removed from the live map.
            for (Edge edge : new ArrayList<Edge>(scores.keySet())) {
                if (scores.get(edge) > minFml) {
                    scores.remove(edge);
                }
            }
        }
    }

    /**
     * Builds the distinct patterns obtained by adding each retained edge to a DAG
     * of the pattern it was scored against.
     *
     * @param patterns the current patterns
     * @param edgeFmls retained candidate edges, index-aligned with {@code patterns}
     * @return the distinct resulting patterns, in discovery order
     */
    private static List<Graph> extendPatterns(List<Graph> patterns, List<Map<Edge, Double>> edgeFmls) {
        List<Graph> result = new ArrayList<Graph>();
        for (int i = 0; i < edgeFmls.size(); ++i) {
            Graph graph = SearchGraphUtils.dagFromPattern(patterns.get(i));
            for (Edge edge : edgeFmls.get(i).keySet()) {
                EdgeListGraph extended = new EdgeListGraph(graph);
                extended.addEdge(edge);

                // NOTE(review): this re-adds head -> node edges on the shared base
                // DAG, not on the extended copy. Kept as found; confirm whether the
                // copy was the intended target.
                Node head = Edges.getDirectedEdgeHead(edge);
                for (Node node : graph.getNodesOutTo(head, Endpoint.ARROW)) {
                    graph.removeEdge(head, node);
                    graph.addDirectedEdge(head, node);
                }

                Graph newPattern = SearchGraphUtils.patternFromDag(extended);
                if (!result.contains(newPattern)) {
                    result.add(newPattern);
                }
            }
        }
        return result;
    }

    /**
     * Replaces each pattern by the pattern of its lowest-FML member DAG.
     * Patterns for which no DAG yields a comparable score are dropped.
     */
    private List<Graph> refinePatterns(List<Graph> patterns) {
        List<Graph> refined = new ArrayList<Graph>();
        for (Graph pattern : patterns) {
            DagInPatternIterator iterator = new DagInPatternIterator(pattern);
            double bestFml = Double.POSITIVE_INFINITY;
            Graph bestDag = null;
            while (iterator.hasNext()) {
                Graph dag = iterator.next();
                double score = this.fml(dag);
                if (score < bestFml) {
                    bestFml = score;
                    bestDag = dag;
                }
            }
            // No DAG was produced, or every score was NaN/infinite: nothing to convert.
            if (bestDag != null) {
                refined.add(SearchGraphUtils.patternFromDag(bestDag));
            }
        }
        return refined;
    }

    /**
     * Tries to drop each edge of a DAG of the pattern, keeping the removal unless
     * the p-value of the reduced model falls below {@code alpha}.
     *
     * @param pattern the pattern to trim
     * @return the pattern of the trimmed DAG
     */
    private Graph trimPattern(Graph pattern) {
        Graph dag = SearchGraphUtils.dagFromPattern(pattern);
        // Snapshot the edges: the loop removes and re-adds edges on the same graph.
        for (Edge edge : new ArrayList<Edge>(dag.getEdges())) {
            dag.removeEdge(edge);
            if (this.pValue(dag) < this.alpha) {
                dag.addEdge(edge);
            } else {
                System.out.println("Removing " + edge);
            }
        }
        return SearchGraphUtils.patternFromDag(dag);
    }

    /** Estimates a SEM for the graph on {@link #data} using a scattershot optimizer. */
    private SemIm estimate(Graph graph) {
        SemPm semPm = new SemPm(graph);
        SemEstimator semEstimator = new SemEstimator(this.data, semPm, (SemOptimizer)new SemOptimizerScattershot());
        semEstimator.estimate();
        return semEstimator.getEstimatedSem();
    }

    /** Returns the FML value of the SEM estimated for the given graph. */
    private double fml(Graph graph) {
        return this.estimate(graph).getFml();
    }

    /** Returns the p-value of the SEM estimated for the given graph. */
    private double pValue(Graph graph) {
        return this.estimate(graph).getPValue();
    }

    /**
     * Merges DAGs that share the same adjacencies into one graph in which an edge
     * stays directed only if every DAG orients it the same way.
     *
     * @param dags the DAGs to merge
     * @return the merged graph, or {@code null} if {@code dags} is null or empty or
     *         the DAGs do not all have the same adjacencies
     */
    public Graph convertToPattern(List<Dag> dags) {
        if (dags == null || dags.isEmpty()) {
            return null;
        }
        EdgeListGraph pattern = new EdgeListGraph(dags.get(0));
        for (int i = 1; i < dags.size(); ++i) {
            Dag dag = dags.get(i);
            for (Edge edge : pattern.getEdges()) {
                if (!dag.isAdjacentTo(edge.getNode1(), edge.getNode2())) {
                    System.out.println("Not all DAGs have the same adjacencies");
                    return null;
                }
            }
            for (Edge edge : dag.getEdges()) {
                if (!pattern.isAdjacentTo(edge.getNode1(), edge.getNode2())) {
                    System.out.println("Not all DAGs have the same adjacencies");
                    return null;
                }
            }
            // Snapshot the edges: disagreeing edges are replaced on the same graph below.
            for (Edge patternEdge : new ArrayList<Edge>(pattern.getEdges())) {
                if (Edges.isUndirectedEdge(patternEdge)) {
                    continue;
                }
                Node node1 = patternEdge.getNode1();
                Node node2 = patternEdge.getNode2();
                Edge dagEdge = dag.getEdge(node1, node2);
                if (Edges.getDirectedEdgeHead(patternEdge) != Edges.getDirectedEdgeHead(dagEdge)) {
                    pattern.removeEdge(patternEdge);
                    pattern.addUndirectedEdge(node1, node2);
                }
            }
        }
        return pattern;
    }
}

