/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.list.DoubleArrayList;
import cern.jet.stat.Descriptive;
import edu.cmu.tetrad.data.AndersonDarlingTest;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.Regression;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.search.MeekRules;
import edu.cmu.tetrad.search.SearchGraphUtils;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.TetradLogger;
import edu.cmu.tetrad.util.Vector;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

public class LingamPattern2 {
    private final Graph cpdag;
    private final List<DataSet> dataSets;
    private double[] pValues;
    private double alpha = 0.05;
    private final List<Regression> regressions;
    private final List<Node> variables;
    private final ArrayList<Matrix> data;

    public LingamPattern2(Graph cpdag, List<DataSet> dataSets) throws IllegalArgumentException {
        if (dataSets == null) {
            throw new IllegalArgumentException("Data set must be specified.");
        }
        if (cpdag == null) {
            throw new IllegalArgumentException("CPDAG must be specified.");
        }
        this.cpdag = cpdag;
        this.dataSets = dataSets;
        this.variables = dataSets.get(0).getVariables();
        this.data = new ArrayList();
        for (DataSet dataSet : this.getDataSets()) {
            Matrix _data = dataSet.getDoubleData();
            this.data.add(_data);
        }
        this.regressions = new ArrayList<Regression>();
        for (Matrix _data : this.data) {
            this.regressions.add(new RegressionDataset(_data, this.variables));
        }
    }

    public Graph search() {
        Graph _cpdag = GraphUtils.bidirectedToUndirected(this.getCpdag());
        TetradLogger.getInstance().log("info", "Making list of all dags in CPDAG...");
        List<Graph> dags = SearchGraphUtils.getAllGraphsByDirectingUndirectedEdges(_cpdag);
        TetradLogger.getInstance().log("normalityTests", "Anderson Darling P value for Variables\n");
        if (dags.isEmpty()) {
            return null;
        }
        ArrayList<Score> scores = new ArrayList<Score>();
        for (Graph dag : dags) {
            scores.add(this.getScore(dag, this.data, this.variables));
        }
        double maxScore = 0.0;
        int maxj = -1;
        for (int j = 0; j < dags.size(); ++j) {
            double _score = ((Score)scores.get((int)j)).score;
            if (!(_score > maxScore)) continue;
            maxScore = _score;
            maxj = j;
        }
        Graph dag = dags.get(maxj);
        this.pValues = ((Score)scores.get((int)maxj)).pvals;
        Graph ngDagCPDAG = SearchGraphUtils.cpdagFromDag(dag);
        List<Node> nodes = ngDagCPDAG.getNodes();
        for (Edge edge : ngDagCPDAG.getEdges()) {
            boolean node2Nongaussian;
            Node node1 = edge.getNode1();
            Node node2 = edge.getNode2();
            double p1 = this.getPValues()[nodes.indexOf(node1)];
            double p2 = this.getPValues()[nodes.indexOf(node2)];
            boolean node1Nongaussian = p1 < this.getAlpha();
            boolean bl = node2Nongaussian = p2 < this.getAlpha();
            if (!node1Nongaussian && !node2Nongaussian || !Edges.isUndirectedEdge(edge)) continue;
            ngDagCPDAG.removeEdge(edge);
            ngDagCPDAG.addEdge(dag.getEdge(node1, node2));
            if (node1Nongaussian) {
                TetradLogger.getInstance().log("edgeOrientations", node1 + " nongaussian ");
            }
            if (node2Nongaussian) {
                TetradLogger.getInstance().log("edgeOrientations", node2 + " nongaussian ");
            }
            TetradLogger.getInstance().log("nongaussianOrientations", "Nongaussian orientation: " + dag.getEdge(node1, node2));
        }
        System.out.println();
        MeekRules meekRules = new MeekRules();
        meekRules.setAggressivelyPreventCycles(true);
        meekRules.orientImplied(ngDagCPDAG);
        TetradLogger.getInstance().log("graph", "Returning: " + ngDagCPDAG);
        return ngDagCPDAG;
    }

    private Score getScore(Graph dag, List<Matrix> data, List<Node> variables) {
        int j;
        int totalSampleSize = 0;
        for (Matrix _data : data) {
            totalSampleSize += _data.rows();
        }
        int numCols = data.get(0).columns();
        List<Node> nodes = dag.getNodes();
        double score = 0.0;
        double[] pValues = new double[nodes.size()];
        Matrix residuals = new Matrix(totalSampleSize, numCols);
        for (j = 0; j < nodes.size(); ++j) {
            ArrayList<Double> _residuals = new ArrayList<Double>();
            Node _target = nodes.get(j);
            List<Node> _regressors = dag.getParents(_target);
            Node target = this.getVariable(variables, _target.getName());
            ArrayList<Node> regressors = new ArrayList<Node>();
            for (Node _regressor : _regressors) {
                Node variable = this.getVariable(variables, _regressor.getName());
                regressors.add(variable);
            }
            for (int m = 0; m < data.size(); ++m) {
                RegressionResult result = this.regressions.get(m).regress(target, regressors);
                Vector residualsSingleDataset = result.getResiduals();
                DoubleArrayList _residualsSingleDataset = new DoubleArrayList(residualsSingleDataset.toArray());
                double mean = Descriptive.mean(_residualsSingleDataset);
                double std = Descriptive.standardDeviation(Descriptive.variance(_residualsSingleDataset.size(), Descriptive.sum(_residualsSingleDataset), Descriptive.sumOfSquares(_residualsSingleDataset)));
                for (int i2 = 0; i2 < _residualsSingleDataset.size(); ++i2) {
                    _residualsSingleDataset.set(i2, (_residualsSingleDataset.get(i2) - mean) / std);
                }
                for (int k = 0; k < _residualsSingleDataset.size(); ++k) {
                    _residuals.add(_residualsSingleDataset.get(k));
                }
                DoubleArrayList f = new DoubleArrayList(_residualsSingleDataset.elements());
                for (int k = 0; k < f.size(); ++k) {
                    f.set(k, FastMath.abs(f.get(k)));
                }
                double _mean = Descriptive.mean(f);
                double diff = _mean - FastMath.sqrt(0.6366197723675814);
                score += diff * diff;
            }
            for (int k = 0; k < _residuals.size(); ++k) {
                residuals.set(k, j, (Double)_residuals.get(k));
            }
        }
        for (j = 0; j < residuals.columns(); ++j) {
            double p;
            double[] x = residuals.getColumn(j).toArray();
            pValues[j] = p = new AndersonDarlingTest(x).getP();
        }
        return new Score(score, pValues);
    }

    public double[] getPValues() {
        return this.pValues;
    }

    public double getAlpha() {
        return this.alpha;
    }

    public void setAlpha(double alpha) {
        if (alpha < 0.0 || alpha > 1.0) {
            throw new IllegalArgumentException("Alpha is in range [0, 1]");
        }
        this.alpha = alpha;
    }

    private Graph getCpdag() {
        return this.cpdag;
    }

    private List<DataSet> getDataSets() {
        return this.dataSets;
    }

    private Node getVariable(List<Node> variables, String name) {
        for (Node node : variables) {
            if (!name.equals(node.getName())) continue;
            return node;
        }
        return null;
    }

    private static class Score {
        double score;
        double[] pvals;

        public Score(double score, double[] pvals) {
            this.score = score;
            this.pvals = pvals;
        }
    }
}

