/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.MeekRules;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.SemEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemOptimizer;
import edu.cmu.tetrad.sem.SemOptimizerEm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.CombinationGenerator;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;

public final class PValueImprover2 {
    /** Data set that every candidate model in this class is estimated against. */
    private final DataSet dataSet;
    /** Background knowledge consulted for required and forbidden edges. */
    private Knowledge knowledge;
    /** Initial model; {@code null} when the two-argument constructor is used. */
    private final SemIm semIm;
    /** Reference DAG read by {@link #search2()}; {@code null} until {@link #setTrueDag(Graph)} is called. */
    private Graph trueDag;

    /**
     * Creates an improver that starts from an existing SEM.
     *
     * @param semIm     the initial model whose graph {@link #search()} and {@link #search2()} copy
     * @param data      the data set used for every estimation
     * @param knowledge the background knowledge to respect
     */
    public PValueImprover2(SemIm semIm, DataSet data, Knowledge knowledge) {
        this.dataSet = data;
        this.knowledge = knowledge;
        this.semIm = semIm;
    }

    /**
     * Creates an improver without an initial SEM; the initial model stays {@code null}.
     *
     * @param data      the data set used for every estimation
     * @param knowledge the background knowledge to respect
     */
    public PValueImprover2(DataSet data, Knowledge knowledge) {
        this(null, data, knowledge);
    }

    /**
     * Runs the main search. Starting from a copy of the initial SEM's graph it adds
     * the edges required by knowledge, prunes edges whose removal lowers FML, and
     * then alternates re-orientation and p-value based edge removal until a pass
     * removes nothing.
     *
     * @return the SEM estimated for the final graph
     * @throws IllegalStateException if this instance was created without an initial SEM
     */
    public SemIm search() {
        if (this.semIm == null) {
            throw new IllegalStateException(
                    "search() needs an initial SemIm; construct with (SemIm, DataSet, Knowledge).");
        }

        final double alpha = 0.05;
        EdgeListGraph graph = new EdgeListGraph(this.semIm.getSemPm().getGraph());

        // Estimate the starting graph once. The original estimated the very same
        // graph a second time and overwrote this value, which only cost time.
        double pValue = this.scoreGraph(graph).getPValue();
        System.out.println(graph);
        System.out.println(pValue);

        this.addRequiredEdges(graph);
        this.removeEdgesToLowerFml(graph);

        // NOTE(review): pValue was measured before the two edits above, so this
        // guard reflects the starting graph rather than the current one. Kept as
        // in the original; confirm whether a re-score is wanted here.
        if (pValue < alpha) {
            this.adjustOrientations(alpha, graph);
        }

        // Repeat until an edge-removal pass leaves the graph untouched.
        boolean changed;
        do {
            this.adjustOrientations(alpha, graph);
            changed = this.removeEdges(alpha, graph, false);
        } while (changed);

        return this.scoreGraph(graph).getEstimatedSem();
    }

    /**
     * Diagnostic pass that prints, for pairs of variables, the p-value of the
     * edge-free model over that pair together with the matching edge and treks
     * in the true DAG. Every edge of the initial SEM's graph is reported first;
     * then, after the line {@code "Other edges..."}, each non-adjacent pair is
     * reported only when its p-value is below 0.005.
     *
     * @return the initial SEM, unchanged by this method
     */
    public SemIm search2() {
        EdgeListGraph graph = new EdgeListGraph(this.semIm.getSemPm().getGraph());
        for (Edge present : graph.getEdges()) {
            this.reportEdge(graph, present, false);
        }

        System.out.println("Other edges...");
        EdgeListGraph complete = new EdgeListGraph(graph.getNodes());
        complete.fullyConnect(Endpoint.TAIL);
        List<Edge> candidates = complete.getEdges();
        for (Edge candidate : candidates) {
            if (!graph.isAdjacentTo(candidate.getNode1(), candidate.getNode2())) {
                this.reportEdge(graph, candidate, true);
            }
        }
        return this.semIm;
    }

    /**
     * Scores the pair of nodes joined by {@code edge} and prints the result along
     * with the corresponding true-DAG edge and treks.
     *
     * @param graph           the graph the pair is taken from
     * @param edge            the edge (or candidate edge) whose endpoints are scored
     * @param significantOnly when {@code true}, nothing is printed unless the p-value is below 0.005
     */
    private void reportEdge(Graph graph, Edge edge, boolean significantOnly) {
        List<Node> pair = new ArrayList<Node>();
        pair.add(edge.getNode1());
        pair.add(edge.getNode2());

        Node trueNode1 = this.trueDag.getNode(edge.getNode1().getName());
        Node trueNode2 = this.trueDag.getNode(edge.getNode2().getName());
        Score score = this.scoreGraph(graph, pair, this.getTrueDag());
        Edge trueEdge = this.trueDag.getEdge(trueNode1, trueNode2);

        if (significantOnly && !(score.getPValue() < 0.005)) {
            return;
        }

        List<List<Node>> treks = GraphUtils.treks(this.trueDag, trueNode1, trueNode2);
        System.out.println(edge + " pValue = " + score.getPValue() + " in graph : " + trueEdge);
        for (List<Node> trek : treks) {
            System.out.println("\t" + trek);
        }
    }

    /**
     * Builds a graph directly from the data in three stages.
     * <ol>
     *   <li>A "trek graph" joins two variables whenever the edge-free model over
     *       just that pair has a p-value below 0.005.</li>
     *   <li>Variables are added one at a time to a reduced graph, carrying over
     *       trek-graph edges to variables already present. After each addition,
     *       every edge whose endpoints share a neighbour is re-tested with a
     *       three-node model (first shared neighbour pointing at both endpoints)
     *       and dropped when that model's p-value exceeds 0.005.</li>
     *   <li>Surviving edges are re-added as undirected, colliders are oriented
     *       with {@link #orientColliders(Graph, Graph)}, and Meek rules are applied.</li>
     * </ol>
     *
     * @return the oriented reduced graph
     */
    public Graph search3() {
        final double alpha = 0.005;
        List<Node> variables = this.dataSet.getVariables();

        // Stage 1: pairwise screening over every unordered pair of variables.
        EdgeListGraph complete = new EdgeListGraph(variables);
        complete.fullyConnect(Endpoint.TAIL);
        List<Edge> candidates = complete.getEdges();
        EdgeListGraph trekGraph = new EdgeListGraph(variables);
        for (Edge candidate : candidates) {
            if (trekGraph.isAdjacentTo(candidate.getNode1(), candidate.getNode2())) {
                continue;
            }
            ArrayList<Node> pair = new ArrayList<Node>();
            pair.add(candidate.getNode1());
            pair.add(candidate.getNode2());
            double pairP = this.scoreGraph(new EdgeListGraph(pair)).getPValue();
            if (pairP < alpha) {
                trekGraph.addEdge(candidate);
            }
        }
        System.out.println(trekGraph);

        // Stage 2: grow the reduced graph node by node, pruning as we go.
        EdgeListGraph reducedGraph = new EdgeListGraph();
        for (Node added : trekGraph.getNodes()) {
            System.out.println("Adding node to reduced graph: " + added);
            reducedGraph.addNode(added);

            for (Edge trekEdge : trekGraph.getEdges(added)) {
                Node end1 = trekEdge.getNode1();
                Node end2 = trekEdge.getNode2();
                boolean otherEndPresent = reducedGraph.containsNode(trekEdge.getDistalNode(added));
                if (otherEndPresent && reducedGraph.getEdge(end1, end2) == null) {
                    reducedGraph.addUndirectedEdge(end1, end2);
                    System.out.println("Adding edge " + reducedGraph.getEdge(end1, end2));
                }
            }

            for (Edge edge : reducedGraph.getEdges()) {
                System.out.println("### + " + edge);
                Node node1 = edge.getNode1();
                Node node2 = edge.getNode2();
                // An earlier step of this same sweep may already have removed the edge.
                if (!reducedGraph.isAdjacentTo(node1, node2)) {
                    continue;
                }

                LinkedHashSet<Node> shared = new LinkedHashSet<Node>(reducedGraph.getAdjacentNodes(node1));
                shared.retainAll(new LinkedHashSet<Node>(reducedGraph.getAdjacentNodes(node2)));
                List<Node> commonAdj = new ArrayList<Node>(shared);
                System.out.println("Common adjacencies for " + node1 + " and " + node2 + ": " + commonAdj);
                if (commonAdj.isEmpty()) {
                    continue;
                }

                // Three-node model: first shared neighbour plus the two endpoints.
                List<Node> triple = new ArrayList<Node>();
                triple.add(commonAdj.get(0));
                triple.add(node1);
                triple.add(node2);
                EdgeListGraph subGraph = new EdgeListGraph(triple);
                for (int i = 0; i < triple.size(); ++i) {
                    for (int j = i + 1; j < triple.size(); ++j) {
                        if (reducedGraph.isAdjacentTo(triple.get(i), triple.get(j))) {
                            subGraph.addDirectedEdge(triple.get(i), triple.get(j));
                        }
                    }
                }
                subGraph.removeEdge(node1, node2);

                double tripleP = this.scoreGraph(subGraph).getPValue();
                System.out.println("p = " + tripleP);
                if (tripleP > alpha) {
                    Edge removed = reducedGraph.getEdge(node1, node2);
                    reducedGraph.removeEdge(edge);
                    System.out.println("Removing edge " + removed);
                }
            }
        }

        // Stage 3: re-add every surviving edge as undirected, then orient.
        for (Edge survivor : reducedGraph.getEdges()) {
            reducedGraph.removeEdge(survivor);
            reducedGraph.addUndirectedEdge(survivor.getNode1(), survivor.getNode2());
        }
        PValueImprover2.orientColliders(trekGraph, reducedGraph);
        new MeekRules().orientImplied(reducedGraph);
        return reducedGraph;
    }

    /**
     * Orients colliders in {@code graph}. For each node and each pair of its
     * neighbours, arrowheads are placed at the node on both connecting edges
     * when the two neighbours are adjacent in neither {@code graph} nor
     * {@code trekMap}.
     *
     * @param trekMap graph whose adjacencies veto a collider orientation
     * @param graph   the graph whose endpoints are modified in place
     */
    public static void orientColliders(Graph trekMap, Graph graph) {
        for (Node center : graph.getNodes()) {
            List<Node> neighbors = graph.getAdjacentNodes(center);
            if (neighbors.size() < 2) {
                continue;
            }

            ChoiceGenerator pairs = new ChoiceGenerator(neighbors.size(), 2);
            for (int[] pick = pairs.next(); pick != null; pick = pairs.next()) {
                Node left = neighbors.get(pick[0]);
                Node right = neighbors.get(pick[1]);
                boolean unconnected = !graph.isAdjacentTo(left, right) && !trekMap.isAdjacentTo(left, right);
                if (unconnected) {
                    graph.setEndpoint(left, center, Endpoint.ARROW);
                    graph.setEndpoint(right, center, Endpoint.ARROW);
                }
            }
        }
    }

    /**
     * Adds a directed edge for every ordered pair of distinct nodes that knowledge
     * marks as required and that is not yet adjacent in {@code graph}.
     *
     * @param graph the graph to extend in place
     */
    private void addRequiredEdges(Graph graph) {
        List<Node> nodes = graph.getNodes();
        int count = nodes.size();
        for (int from = 0; from < count; ++from) {
            Node tail = nodes.get(from);
            for (int to = 0; to < count; ++to) {
                if (from == to) {
                    continue;
                }
                Node head = nodes.get(to);
                boolean required = this.knowledge.edgeRequired(tail.getName(), head.getName());
                if (required && !graph.isAdjacentTo(tail, head)) {
                    graph.addDirectedEdge(tail, head);
                }
            }
        }
    }

    /**
     * Greedily adds directed edges while the model's p-value is below {@code alpha},
     * for at most three rounds. Each round tries every permitted directed edge
     * between non-adjacent nodes, and keeps the one giving the lowest FML provided
     * it improves on the best FML seen so far.
     *
     * <p>The best edge is tracked per round. The original kept it across rounds, so
     * a round that found no improvement re-added and re-reported the previous
     * round's edge instead of stopping.
     *
     * @param alpha p-value threshold below which the search keeps adding edges
     * @param graph the graph to extend in place
     */
    private void addEdges(double alpha, Graph graph) {
        final int maxRounds = 3;
        List<Node> nodes = graph.getNodes();
        Score initial = this.scoreGraph(graph);
        double bestFml = initial.getFml();
        double bestPValue = initial.getPValue();

        for (int round = 0; round < maxRounds && bestPValue < alpha; ++round) {
            // Reset every round so a stale winner is never re-applied.
            Edge bestEdge = null;

            for (int i = 0; i < nodes.size(); ++i) {
                for (int j = 0; j < nodes.size(); ++j) {
                    if (i == j
                            || graph.isAdjacentTo(nodes.get(i), nodes.get(j))
                            || this.knowledge.edgeForbidden(nodes.get(i).getName(), nodes.get(j).getName())) {
                        continue;
                    }

                    // Try the edge, score, and always take it back out again.
                    Edge edge = Edges.directedEdge(nodes.get(i), nodes.get(j));
                    graph.addEdge(edge);
                    Score trial = this.scoreGraph(graph);
                    graph.removeEdge(edge);

                    double newFml = trial.getFml();
                    if (newFml < bestFml) {
                        System.out.println("Ratio = " + newFml / bestFml);
                        bestFml = newFml;
                        bestPValue = trial.getPValue();
                        bestEdge = edge;
                    }
                }
            }

            if (bestEdge == null) {
                // No candidate improved FML this round; nothing more to gain.
                return;
            }
            graph.addEdge(bestEdge);
            System.out.println("Adding edge " + bestEdge + " p = " + bestPValue);
        }
    }

    /**
     * For each node with at least two neighbours, tries every combination of
     * directions for the edges touching that node and keeps the permitted
     * combination with the highest p-value, provided that p-value exceeds
     * {@code alpha}. If no combination qualifies, the node's original edges are
     * restored.
     *
     * <p>Changes from the decompiled original: the comparison
     * {@code bestComb[i] == false} (an int against a boolean, which does not
     * compile) is now a test against 0; combinations forbidden by knowledge are
     * rejected before the graph is touched rather than half-way through rewriting
     * it; an unused string buffer was dropped.
     *
     * @param alpha minimum p-value a combination must exceed to be accepted
     * @param graph the graph to re-orient in place
     */
    private void adjustOrientations(double alpha, Graph graph) {
        for (Node node : graph.getNodes()) {
            System.out.println("Node " + node);
            List<Node> adj = graph.getAdjacentNodes(node);
            if (adj.size() < 2) {
                continue;
            }

            // Remember the current edges so they can be put back if nothing qualifies.
            List<Edge> originalEdges = new ArrayList<Edge>();
            for (Node neighbor : adj) {
                originalEdges.add(graph.getEdge(node, neighbor));
            }

            // One binary choice per neighbour: 0 = node -> neighbour, otherwise neighbour -> node.
            int[] dims = new int[adj.size()];
            for (int i = 0; i < dims.length; ++i) {
                dims[i] = 2;
            }
            CombinationGenerator gen = new CombinationGenerator(dims);

            int[] bestComb = null;
            double bestScore = 0.0;
            int[] comb;
            while ((comb = gen.next()) != null) {
                if (!this.isOrientationAllowed(node, adj, comb)) {
                    continue;
                }
                this.orientAround(graph, node, adj, comb);
                double pValue = this.scoreGraph(graph).getPValue();
                if (pValue > bestScore && pValue > alpha) {
                    bestComb = comb.clone();
                    bestScore = pValue;
                }
            }

            if (bestComb == null) {
                for (Node neighbor : adj) {
                    graph.removeEdge(node, neighbor);
                }
                for (Edge edge : originalEdges) {
                    graph.addEdge(edge);
                }
                continue;
            }

            this.orientAround(graph, node, adj, bestComb);
            for (Node neighbor : adj) {
                System.out.print(graph.getEdge(node, neighbor) + " ");
            }
            System.out.print(" chosen... " + this.scoreGraph(graph).getPValue() + "\t");
            System.out.println(graph);
        }
    }

    /** Returns whether knowledge permits every directed edge that {@code comb} would create around {@code node}. */
    private boolean isOrientationAllowed(Node node, List<Node> adj, int[] comb) {
        for (int i = 0; i < comb.length; ++i) {
            Node neighbor = adj.get(i);
            boolean forbidden = comb[i] == 0
                    ? this.knowledge.edgeForbidden(node.getName(), neighbor.getName())
                    : this.knowledge.edgeForbidden(neighbor.getName(), node.getName());
            if (forbidden) {
                return false;
            }
        }
        return true;
    }

    /** Replaces every edge between {@code node} and its neighbours with the directions encoded in {@code comb}. */
    private void orientAround(Graph graph, Node node, List<Node> adj, int[] comb) {
        for (Node neighbor : adj) {
            graph.removeEdge(node, neighbor);
        }
        for (int i = 0; i < comb.length; ++i) {
            Node neighbor = adj.get(i);
            if (comb[i] == 0) {
                graph.addDirectedEdge(node, neighbor);
            } else {
                graph.addDirectedEdge(neighbor, node);
            }
        }
    }

    /**
     * Tries every combination of directions over all edges of {@code graph} at
     * once and leaves the graph in the combination with the highest p-value
     * (ties go to the later combination).
     *
     * <p>Changes from the decompiled original: each combination is now actually
     * scored (the p-value was left at 0.0, so the last combination always won);
     * {@code bestComb[i] == false}, which does not compile, is a test against 0;
     * and when no combination is accepted the original edges are restored instead
     * of dereferencing a null array.
     *
     * @param alpha unused by this method; retained so the signature is unchanged
     * @param graph the graph to re-orient in place
     */
    private void adjustOrientations2(double alpha, Graph graph) {
        // Work from a private copy so removing edges cannot disturb the iteration.
        List<Edge> edges = new ArrayList<Edge>(graph.getEdges());

        // One binary choice per edge: 0 = node1 -> node2, otherwise node2 -> node1.
        int[] dims = new int[edges.size()];
        for (int i = 0; i < dims.length; ++i) {
            dims[i] = 2;
        }
        CombinationGenerator gen = new CombinationGenerator(dims);

        int[] bestComb = null;
        double bestScore = 0.0;
        int[] comb;
        while ((comb = gen.next()) != null) {
            for (Edge edge : edges) {
                graph.removeEdge(edge.getNode1(), edge.getNode2());
            }
            for (int i = 0; i < comb.length; ++i) {
                Edge edge = edges.get(i);
                if (comb[i] == 0) {
                    graph.addDirectedEdge(edge.getNode1(), edge.getNode2());
                } else {
                    graph.addDirectedEdge(edge.getNode2(), edge.getNode1());
                }
            }

            double pValue = this.scoreGraph(graph).getPValue();
            if (pValue >= bestScore) {
                bestComb = comb.clone();
                bestScore = pValue;
            }
        }

        for (Edge edge : edges) {
            graph.removeEdge(edge.getNode1(), edge.getNode2());
        }
        if (bestComb == null) {
            // Nothing was accepted (no combinations, or no usable p-value): undo.
            for (Edge edge : edges) {
                graph.addEdge(edge);
            }
        } else {
            for (int i = 0; i < bestComb.length; ++i) {
                Edge edge = edges.get(i);
                if (bestComb[i] == 0) {
                    graph.addDirectedEdge(edge.getNode1(), edge.getNode2());
                } else {
                    graph.addDirectedEdge(edge.getNode2(), edge.getNode1());
                }
            }
        }

        System.out.print(" chosen... " + this.scoreGraph(graph).getPValue() + "\t");
        System.out.println(graph);
    }

    /**
     * Tries removing each edge that knowledge does not require, and keeps the
     * removal only when the re-estimated model has a strictly lower FML than the
     * best seen so far. Progress and the final p-value are printed to standard output.
     *
     * @param graph the graph to prune in place
     */
    private void removeEdgesToLowerFml(Graph graph) {
        System.out.println("Removing edges to lower FML.");
        double bestFml = this.scoreGraph(graph).getFml();

        for (Edge candidate : graph.getEdges()) {
            String name1 = candidate.getNode1().getName();
            String name2 = candidate.getNode2().getName();
            if (!this.knowledge.edgeRequired(name1, name2)) {
                graph.removeEdge(candidate);
                Score trial = this.scoreGraph(graph);
                double trialFml = trial.getFml();
                System.out.println(candidate + ": " + trial.getPValue());

                if (trialFml < bestFml) {
                    System.out.println("Removed: " + candidate + " p = " + trial.getPValue());
                    bestFml = trialFml;
                } else {
                    // No improvement: put the edge back.
                    graph.addEdge(candidate);
                }
            }
        }

        System.out.println(this.scoreGraph(graph).getPValue());
    }

    /**
     * Tries removing each edge that knowledge does not require, keeping the
     * removal when the re-estimated model's p-value exceeds {@code alpha}. After
     * each attempt the current graph's p-value is printed.
     *
     * @param alpha   p-value a model must exceed for a removal to stand
     * @param graph   the graph to prune in place
     * @param changed the caller's current "changed" flag
     * @return {@code true} if any edge was removed, otherwise the incoming {@code changed} value
     */
    private boolean removeEdges(double alpha, Graph graph, boolean changed) {
        boolean anyRemoved = changed;
        for (Edge candidate : graph.getEdges()) {
            String name1 = candidate.getNode1().getName();
            String name2 = candidate.getNode2().getName();
            if (!this.knowledge.edgeRequired(name1, name2)) {
                graph.removeEdge(candidate);
                Score trial = this.scoreGraph(graph);
                System.out.println(candidate + ": " + trial.getPValue());

                if (trial.getPValue() > alpha) {
                    System.out.println("Removed: " + candidate);
                    anyRemoved = true;
                } else {
                    graph.addEdge(candidate);
                }
                System.out.println(this.scoreGraph(graph).getPValue());
            }
        }
        return anyRemoved;
    }

    /**
     * Estimates a SEM over {@code graph} against the data set with a
     * {@link SemOptimizerEm} optimizer and bundles the result.
     *
     * @param graph the structure to estimate
     * @return the estimated model together with its p-value and FML
     */
    private Score scoreGraph(Graph graph) {
        SemEstimator estimator =
                new SemEstimator(this.dataSet, new SemPm(graph), (SemOptimizer) new SemOptimizerEm());
        estimator.estimate();
        SemIm fitted = estimator.getEstimatedSem();
        double pValue = fitted.getPValue();
        double fml = fitted.getFml();
        return new Score(fitted, pValue, fml);
    }

    /**
     * Scores the edge-free model over {@code nodes}: takes the subgraph of
     * {@code graph} on those nodes, strips all of its edges, and estimates it.
     *
     * @param graph   the graph the subgraph is taken from
     * @param nodes   the nodes to keep
     * @param trueDag not read by this method
     * @return the estimated model together with its p-value and FML
     */
    private Score scoreGraph(Graph graph, List<Node> nodes, Graph trueDag) {
        Graph edgeless = graph.subgraph(nodes);
        edgeless.removeEdges(edgeless.getEdges());
        // Estimation is the same as for a whole graph, so reuse that overload.
        return this.scoreGraph(edgeless);
    }

    /**
     * Approximates the second partial derivative of the SEM's FML with respect to
     * two of its free parameters, using a finite difference with step 0.005 around
     * the model's current free-parameter values.
     *
     * <p>Evaluating the fitting function writes perturbed values into
     * {@code semIm}. The original left the model at the last perturbed point, so
     * successive calls drifted away from the estimate; the values read on entry
     * are now written back before returning.
     *
     * @param semIm the model to differentiate; its free-parameter values are the same on return as on entry
     * @param p1    the first parameter
     * @param p2    the second parameter (may equal {@code p1})
     * @return the approximate second partial derivative
     * @throws IllegalArgumentException if either parameter is not among the model's free parameters
     */
    public double secondDerivative(SemIm semIm, Parameter p1, Parameter p2) {
        final double delta = 0.005;

        List<Parameter> freeParameters = semIm.getFreeParameters();
        int i = freeParameters.indexOf(p1);
        int j = freeParameters.indexOf(p2);
        if (i < 0 || j < 0) {
            throw new IllegalArgumentException(
                    "Not a free parameter of the given SEM: " + (i < 0 ? p1 : p2));
        }

        // Private snapshot, in case the model hands out its internal array.
        double[] saved = semIm.getFreeParamValues().clone();
        try {
            return this.secondPartialDerivative(new SemFittingFunction(semIm), i, j, saved, delta);
        } finally {
            semIm.setFreeParamValues(saved);
        }
    }

    /**
     * Replaces the background knowledge consulted for required and forbidden edges.
     *
     * @param knowledge the knowledge to use from now on
     */
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = knowledge;
    }

    /**
     * Returns the reference DAG.
     *
     * @return the graph last passed to {@link #setTrueDag(Graph)}, or {@code null} if none was set
     */
    public Graph getTrueDag() {
        return trueDag;
    }

    /**
     * Sets the reference DAG that {@link #search2()} reads.
     *
     * @param trueDag the reference graph
     */
    public void setTrueDag(Graph trueDag) {
        this.trueDag = trueDag;
    }

    /**
     * Approximates the second partial derivative of {@code f} with respect to
     * coordinates {@code i} and {@code j} at point {@code p} using the four-point
     * stencil {@code (f(++) - f(+-) - f(-+) + f(--)) / (4 * delt^2)}, where the
     * signs are the offsets of {@code delt} applied to coordinates {@code i} and
     * {@code j}. When {@code i == j} both offsets land on the same coordinate.
     *
     * @param f    the function to differentiate
     * @param i    index of the first coordinate
     * @param j    index of the second coordinate
     * @param p    the point of evaluation; copied, not modified
     * @param delt the step size
     * @return the finite-difference estimate
     */
    public double secondPartialDerivative(FittingFunction f, int i, int j, double[] p, double delt) {
        // Walk one working copy through the four stencil points in turn.
        double[] point = p.clone();

        point[i] += delt;
        point[j] += delt;
        double plusPlus = f.evaluate(point);

        point[j] -= 2.0 * delt;
        double plusMinus = f.evaluate(point);

        point[i] -= 2.0 * delt;
        point[j] += 2.0 * delt;
        double minusPlus = f.evaluate(point);

        point[j] -= 2.0 * delt;
        double minusMinus = f.evaluate(point);

        double stencil = plusPlus - plusMinus - minusPlus + minusMinus;
        return stencil / (4.0 * delt * delt);
    }

    /**
     * Adapts a {@link SemIm} to {@link FittingFunction}: the function's arguments
     * are the model's free-parameter values and its value is the model's FML.
     */
    static class SemFittingFunction implements FittingFunction {
        /** The model that is re-parameterized on every evaluation. */
        private final SemIm model;

        /**
         * @param sem the model to evaluate; it is modified by {@link #evaluate(double[])}
         */
        public SemFittingFunction(SemIm sem) {
            this.model = sem;
        }

        /**
         * Writes {@code parameters} into the model as its free-parameter values
         * and returns the resulting FML.
         */
        @Override
        public double evaluate(double[] parameters) {
            model.setFreeParamValues(parameters);
            return model.getFml();
        }

        /** Returns the model's number of free parameters. */
        @Override
        public int getNumParameters() {
            return model.getNumFreeParams();
        }
    }

    /** A real-valued function of a vector of parameters. */
    static interface FittingFunction {
        /**
         * Evaluates the function.
         *
         * @param parameters the point at which to evaluate
         * @return the function value at that point
         */
        double evaluate(double[] parameters);

        /**
         * @return the number of parameters the function takes
         */
        int getNumParameters();
    }

    /** Immutable result of estimating one graph: the fitted model, its p-value and its FML. */
    private static class Score {
        private final SemIm estimatedSem;
        private final double pValue;
        private final double fml;

        /**
         * @param estimatedSem the fitted model
         * @param score        the model's p-value
         * @param fml          the model's FML
         */
        public Score(SemIm estimatedSem, double score, double fml) {
            this.estimatedSem = estimatedSem;
            this.fml = fml;
            this.pValue = score;
        }

        /** Returns the fitted model. */
        public SemIm getEstimatedSem() {
            return estimatedSem;
        }

        /** Returns the fitted model's p-value. */
        public double getPValue() {
            return pValue;
        }

        /** Returns the fitted model's FML. */
        public double getFml() {
            return fml;
        }
    }
}

