/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.matrix.DoubleMatrix2D;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.search.DagIterator3;
import edu.cmu.tetrad.sem.ParamType;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.SemEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

public class MaxPValueSearch {
/**
 * Brute-force search over DAGs on the variables of a continuous data set.
 *
 * <p>For each candidate DAG produced by {@link DagIterator3}, a {@link SemPm} is built from the
 * DAG and estimated against the data with a {@link SemEstimator}. The candidate is then scored by
 * the {@link SemIm#getPValue()} of the estimated model. Candidates are enumerated by edge count,
 * from 0 up to {@code maxEdges}.
 *
 * <p>Progress and diagnostics are written to {@code System.out}. You can optionally supply a
 * "true" DAG ({@link #setTrueDag}) and SEM ({@link #setTrueIm}) to print comparisons against them.
 */

    /** Minimum p-value a DAG needs to be accepted by {@link #search()}. */
    private static final double SEARCH_MIN_P_VALUE = 0.01;

    /** {@link #search2()} keeps DAGs whose p-value is within this distance of the best p-value. */
    private static final double TOP_DAG_P_TOLERANCE = 0.01;

    private final DataSet data;
    private final List<Node> nodes;
    private final double alpha;
    private final int maxEdges;
    private Dag trueDag;
    private SemIm trueIm;

    /**
     * Creates a search over the variables of {@code data}.
     *
     * @param data a continuous data set; its variables are the nodes of every candidate DAG
     * @param alpha minimum p-value a DAG needs to be kept by {@link #search2()}; must be >= 0
     * @param maxEdges largest number of edges considered; must be >= 0
     * @throws IllegalArgumentException if {@code data} is null or not continuous, if {@code alpha}
     *     is negative or NaN, or if {@code maxEdges} is negative
     */
    public MaxPValueSearch(DataSet data, double alpha, int maxEdges) {
        if (data == null || !data.isContinuous()) {
            throw new IllegalArgumentException("Please provide a continuous dataset.");
        }
        // Written as !(alpha >= 0) so that NaN is rejected as well.
        if (!(alpha >= 0.0)) {
            throw new IllegalArgumentException("Alpha must be >= 0: " + alpha);
        }
        if (maxEdges < 0) {
            throw new IllegalArgumentException("Max edges must be >= 0: " + maxEdges);
        }
        this.data = data;
        this.nodes = data.getVariables();
        // Previously hard-coded to 1.0E-10, which silently discarded the validated argument.
        this.alpha = alpha;
        this.maxEdges = maxEdges;
    }

    /** Sets the DAG that diagnostics compare against; {@code null} disables those diagnostics. */
    public void setTrueDag(Dag dag) {
        this.trueDag = dag;
    }

    /** Sets the SEM that {@link #search2()} compares each accepted estimate against. */
    public void setTrueIm(SemIm trueIm) {
        this.trueIm = trueIm;
    }

    /**
     * Returns the DAGs of the smallest edge count that contains at least one acceptable DAG.
     *
     * <p>Edge counts are tried in increasing order, from 0 to {@code maxEdges}. At each count,
     * every DAG produced by the iterator is scored. DAGs whose p-value is at least 0.01 are
     * accepted; NaN p-values are rejected. As soon as one edge count yields an accepted DAG, the
     * accepted DAGs for that count are returned.
     *
     * <p>NOTE(review): this method uses a fixed threshold of 0.01, not the {@code alpha} passed to
     * the constructor. Only {@link #search2()} uses {@code alpha}. Confirm this is intended.
     *
     * @return the accepted DAGs and their p-values, sorted by decreasing p-value; empty if no edge
     *     count up to {@code maxEdges} yields an accepted DAG
     */
    public Result search() {
        List<Dag> dags = new ArrayList<Dag>();
        List<Double> pValues = new ArrayList<Double>();

        if (this.trueDag != null) {
            printDagEstimate("INDEPENDENT DAG ", this.trueDag);
        }

        for (int numEdges = 0; numEdges <= this.maxEdges; ++numEdges) {
            System.out.println("######## Num Edges = " + numEdges);
            // The iterator is built with min = max = numEdges. Presumably it yields DAGs with
            // exactly numEdges edges — TODO confirm against DagIterator3.
            DagIterator3 iterator = new DagIterator3(this.nodes, numEdges, numEdges);
            int count = 0;

            while (iterator.hasNext()) {
                if (++count % 100 == 0) {
                    System.out.println(count);
                }
                Dag dag = iterator.next();

                if (this.trueDag != null && dag.equals(this.trueDag)) {
                    printDagEstimate("INDEPENDENT DAG FOUND BY GENERATOR", dag);
                }

                double p = estimate(dag).getPValue();
                // Negated comparison so that a NaN p-value is rejected rather than accepted.
                if (!(p >= SEARCH_MIN_P_VALUE)) {
                    continue;
                }
                dags.add(dag);
                pValues.add(p);
            }

            if (dags.isEmpty()) {
                continue;
            }

            if (this.trueDag != null) {
                printDagEstimate("INDEPENDENT DAG", this.trueDag);
                if (!dags.contains(this.trueDag)) {
                    System.out.println("WARNING!!!! INDEPENDENT DAG NOT IN THE LIST!!!");
                }
            }
            return new Result(dags, pValues);
        }
        return new Result(new ArrayList<Dag>(), new ArrayList<Double>());
    }

    /**
     * Scores every DAG with 0 to {@code maxEdges} edges and returns the near-best ones.
     *
     * <p>Every DAG whose p-value is at least {@code alpha} is printed; NaN p-values are rejected.
     * If a true SEM has been set, each printed DAG also gets a comparison against that SEM. Of
     * the accepted DAGs, those whose p-value is within 0.01 of the largest accepted p-value are
     * printed again as the "TOP DAGs" and returned.
     *
     * @return the near-best DAGs in enumeration order, or an empty list if no DAG was accepted
     */
    public List<Dag> search2() {
        List<Dag> dags = new ArrayList<Dag>();
        List<Double> pValues = new ArrayList<Double>();

        for (int numEdges = 0; numEdges <= this.maxEdges; ++numEdges) {
            System.out.println("######## Num Edges = " + numEdges);
            DagIterator3 iterator = new DagIterator3(this.nodes, numEdges, numEdges);
            int count = 0;

            while (iterator.hasNext()) {
                if (++count % 100 == 0) {
                    System.out.println(count);
                }
                Dag dag = iterator.next();

                if (this.trueDag != null && dag.equals(this.trueDag)) {
                    printDagEstimate("INDEPENDENT DAG FOUND BY GENERATOR", dag);
                }

                SemIm im = estimate(dag);
                double p = im.getPValue();
                // Negated comparison so that a NaN p-value is rejected rather than accepted.
                if (!(p >= this.alpha)) {
                    continue;
                }
                dags.add(dag);
                pValues.add(p);
                System.out.println(dag);
                System.out.println("P = " + p);

                if (this.trueIm != null) {
                    printComparisonToTrue(im);
                }
            }
        }

        if (dags.isEmpty()) {
            return new ArrayList<Dag>();
        }

        // Accepted p-values are >= alpha >= 0 and never NaN, so starting from 0.0 is safe.
        double maxP = 0.0;
        for (double p : pValues) {
            maxP = Math.max(maxP, p);
        }

        // Single filtering pass, replacing the indexed LinkedList removals.
        List<Dag> topDags = new ArrayList<Dag>();
        List<Double> topPValues = new ArrayList<Double>();
        for (int j = 0; j < dags.size(); ++j) {
            double p = pValues.get(j);
            if (p >= maxP - TOP_DAG_P_TOLERANCE) {
                topDags.add(dags.get(j));
                topPValues.add(p);
            }
        }

        printTopDags(topDags, topPValues);

        if (this.trueDag != null) {
            printDagEstimate("INDEPENDENT DAG", this.trueDag);
            if (!topDags.contains(this.trueDag)) {
                System.out.println("WARNING!!!! INDEPENDENT DAG NOT IN THE LIST!!!");
            }
        }
        return topDags;
    }

    /** Builds a SemPm for {@code dag}, estimates it against the data and returns the result. */
    private SemIm estimate(Dag dag) {
        SemPm pm = new SemPm(dag);
        SemEstimator estimator = new SemEstimator(this.data, pm);
        estimator.estimate();
        return estimator.getEstimatedSem();
    }

    /** Prints the covariance and parameter discrepancy diagnostics for {@code im} vs. the true SEM. */
    private void printComparisonToTrue(SemIm im) {
        double sum = sumSquaredCovarianceResiduals(im);
        System.out.println("Sum of diffs of sample covar vs. estimated covar elements = " + sum);
        double sum2 = sumParameterDifferencesFromTrue(im, this.trueIm);
        System.out.println("Sum of diffs of true params from corresponding estimated params = " + sum2);
        System.out.println("FML = " + im.getFml());
    }

    /**
     * Returns the sum, over every matrix cell (i, j), of the squared difference between the model's
     * implied covariance ({@code getImplCovar}) and its sample covariance ({@code getSampleCovar}).
     * The true SEM is not involved.
     */
    private double sumSquaredCovarianceResiduals(SemIm im) {
        DoubleMatrix2D implCovar = im.getImplCovar();
        DoubleMatrix2D sampleCovar = im.getSampleCovar();
        double sum = 0.0;
        for (int i = 0; i < implCovar.rows(); ++i) {
            for (int j = 0; j < implCovar.columns(); ++j) {
                double diff = implCovar.get(i, j) - sampleCovar.get(i, j);
                sum += diff * diff;
            }
        }
        return sum;
    }

    /**
     * Sums squared differences between the free parameters of {@code trueIm} and the matching
     * values in {@code im}. Nodes are matched by their {@code toString()} name.
     *
     * <p>COEF parameters compare edge coefficients; a NaN estimated coefficient counts as 0.
     * COVAR parameters compare the error variance of node A. All other parameter types are
     * ignored.
     */
    private double sumParameterDifferencesFromTrue(SemIm im, SemIm trueIm) {
        SemGraph estGraph = im.getSemPm().getGraph();
        double sum = 0.0;

        for (Parameter trueParam : trueIm.getFreeParameters()) {
            Node trueA = trueParam.getNodeA();
            Node estA = estGraph.getNode(trueA.toString());
            double diff;

            if (trueParam.getType() == ParamType.COEF) {
                Node trueB = trueParam.getNodeB();
                Node estB = estGraph.getNode(trueB.toString());
                double estValue = im.getEdgeCoef(estA, estB);
                if (Double.isNaN(estValue)) {
                    // Treated as "edge absent in the estimated model".
                    estValue = 0.0;
                }
                diff = trueIm.getEdgeCoef(trueA, trueB) - estValue;
            } else if (trueParam.getType() == ParamType.COVAR) {
                // NOTE(review): this branch fires for COVAR parameters but compares only the error
                // *variance* of node A. If error variances have their own ParamType, that type may
                // have been intended here — confirm.
                diff = trueIm.getErrVar(trueA) - im.getErrVar(estA);
            } else {
                continue;
            }
            sum += diff * diff;
        }
        return sum;
    }

    /**
     * Estimates {@code dag} and prints {@code label + dag}, its p-value and a blank line.
     *
     * @throws IllegalStateException if no true DAG has been set (every caller guards against this)
     */
    private void printDagEstimate(String label, Dag dag) {
        if (this.trueDag == null) {
            throw new IllegalStateException("No true DAG has been set.");
        }
        double p = estimate(dag).getPValue();
        System.out.println(label + dag);
        System.out.println("P = " + p);
        System.out.println();
    }

    /** Prints the given DAGs with their p-values, numbered from 1; the lists must be parallel. */
    private void printTopDags(List<Dag> dags, List<Double> pValues) {
        if (dags.size() != pValues.size()) {
            throw new IllegalArgumentException();
        }
        System.out.println("TOP DAGs");
        for (int i = 0; i < dags.size(); ++i) {
            System.out.println("\n#" + (i + 1) + ": P = " + pValues.get(i) + "\n" + dags.get(i));
        }
    }

    /** Sort key for p-values: NaN and null rank below every real p-value. */
    private static double pValueSortKey(Double p) {
        return (p == null || p.isNaN()) ? Double.NEGATIVE_INFINITY : p;
    }

    /**
     * Merges DAGs that share one adjacency structure into a single graph.
     *
     * <p>Each edge keeps its direction if every DAG orients it the same way (judged by the head
     * node of the directed edge); otherwise it becomes undirected.
     *
     * @param dags the DAGs to merge
     * @return the merged graph, or {@code null} if {@code dags} is null or empty, or if the DAGs do
     *     not all have the same adjacencies (a message is printed in that last case)
     */
    public static Graph convertToPattern(List<Dag> dags) {
        if (dags == null || dags.isEmpty()) {
            return null;
        }
        EdgeListGraph pattern = new EdgeListGraph(dags.get(0));

        for (int i = 1; i < dags.size(); ++i) {
            Dag dag = dags.get(i);

            // Adjacencies must agree in both directions.
            for (Edge edge : pattern.getEdges()) {
                if (!dag.isAdjacentTo(edge.getNode1(), edge.getNode2())) {
                    System.out.println("Not all DAGs have the same adjacencies");
                    return null;
                }
            }
            for (Edge edge : dag.getEdges()) {
                if (!pattern.isAdjacentTo(edge.getNode1(), edge.getNode2())) {
                    System.out.println("Not all DAGs have the same adjacencies");
                    return null;
                }
            }

            // Iterate over a snapshot because the loop removes and adds edges on `pattern`. This
            // is safe whether or not getEdges() returns a live view.
            for (Edge patternEdge : new ArrayList<Edge>(pattern.getEdges())) {
                if (Edges.isUndirectedEdge(patternEdge)) {
                    continue;
                }
                Node node1 = patternEdge.getNode1();
                Node node2 = patternEdge.getNode2();
                Edge dagEdge = dag.getEdge(node1, node2);
                if (Edges.getDirectedEdgeHead(patternEdge) == Edges.getDirectedEdgeHead(dagEdge)) {
                    continue;
                }
                pattern.removeEdge(patternEdge);
                pattern.addUndirectedEdge(node1, node2);
            }
        }
        return pattern;
    }

    /** DAGs paired with their p-values, held in order of decreasing p-value. */
    public class Result {
        private List<Dag> dags;
        private List<Double> pValues;

        /**
         * Creates a result and sorts it by decreasing p-value. The argument lists are not modified.
         *
         * @param dags the DAGs; {@code dags.get(i)} has p-value {@code pValues.get(i)}
         * @param pValues the p-values, parallel to {@code dags}
         * @throws IllegalArgumentException if either list is null or their sizes differ
         */
        public Result(List<Dag> dags, List<Double> pValues) {
            if (dags == null || pValues == null || dags.size() != pValues.size()) {
                throw new IllegalArgumentException();
            }
            // Assign directly rather than through the overridable public setters.
            this.dags = dags;
            this.pValues = pValues;
            this.sort();
        }

        /**
         * Replaces both lists with new lists ordered by decreasing p-value. Equal p-values keep
         * their input order, and NaN or null p-values go last. The previous list objects are not
         * modified.
         */
        private void sort() {
            final List<Dag> dagCopy = new ArrayList<Dag>(this.dags);
            final List<Double> pCopy = new ArrayList<Double>(this.pValues);

            List<Integer> order = new ArrayList<Integer>(pCopy.size());
            for (int i = 0; i < pCopy.size(); ++i) {
                order.add(i);
            }
            // Collections.sort is a stable sort, so ties keep their original relative order.
            Collections.sort(order, new Comparator<Integer>() {
                @Override
                public int compare(Integer a, Integer b) {
                    return Double.compare(pValueSortKey(pCopy.get(b)), pValueSortKey(pCopy.get(a)));
                }
            });

            List<Dag> sortedDags = new ArrayList<Dag>(order.size());
            List<Double> sortedPValues = new ArrayList<Double>(order.size());
            for (int index : order) {
                sortedDags.add(dagCopy.get(index));
                sortedPValues.add(pCopy.get(index));
            }
            this.dags = sortedDags;
            this.pValues = sortedPValues;
        }

        /** Returns the DAGs, parallel to {@link #getPValues()}. */
        public List<Dag> getDags() {
            return this.dags;
        }

        /** Replaces the DAG list; no re-sorting or size check is performed. */
        public void setDags(List<Dag> dags) {
            this.dags = dags;
        }

        /** Returns the p-values, parallel to {@link #getDags()}. */
        public List<Double> getPValues() {
            return this.pValues;
        }

        /** Replaces the p-value list; no re-sorting or size check is performed. */
        public void setPValues(List<Double> pValues) {
            this.pValues = pValues;
        }
    }
}

