/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.list.DoubleArrayList;
import cern.jet.stat.Descriptive;
import edu.cmu.tetrad.data.AndersonDarlingTest;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.search.MeekRules;
import edu.cmu.tetrad.search.SearchGraphUtils;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.TetradLogger;
import edu.cmu.tetrad.util.Vector;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

/**
 * Orients a CPDAG using the non-Gaussianity of regression residuals, in the spirit of LiNGAM.
 *
 * <p>Every DAG obtained by directing the undirected edges of the CPDAG is scored by regressing
 * each node on its parents and measuring, per node, the squared distance between the mean absolute
 * standardized residual and {@code sqrt(2/pi)} (the mean of |Z| for a standard normal Z). The DAG
 * with the largest total score is selected. Undirected edges in that DAG's CPDAG are then oriented
 * as in the selected DAG whenever at least one endpoint's residual Anderson-Darling p-value is below
 * {@link #getAlpha()}, and Meek rules propagate the implied orientations.
 */
public class LingamPattern {
    /** sqrt(2/pi) (the literal 0.6366... is 2/pi): the mean of |Z| for a standard normal Z. */
    private static final double GAUSSIAN_MEAN_ABS = FastMath.sqrt(0.6366197723675814);

    private final Graph cpdag;
    private final DataSet dataSet;
    /** Residual p-values of the selected DAG, indexed like that DAG's getNodes(); null before search(). */
    private double[] pValues;
    private double alpha = 0.05;

    /**
     * Creates a search over the given CPDAG and data.
     *
     * @param cpdag the CPDAG whose undirected edges are to be oriented; must not be null
     * @param dataSet the data used for the regressions; must not be null
     * @throws IllegalArgumentException if either argument is null
     */
    public LingamPattern(Graph cpdag, DataSet dataSet) throws IllegalArgumentException {
        if (cpdag == null) {
            throw new IllegalArgumentException("CPDAG must be specified.");
        }
        if (dataSet == null) {
            throw new IllegalArgumentException("Data set must be specified.");
        }
        this.cpdag = cpdag;
        this.dataSet = dataSet;
    }

    /**
     * Runs the search and records the residual p-values of the selected DAG (see
     * {@link #getPValues()}).
     *
     * @return the partially re-oriented CPDAG, or {@code null} if enumerating the DAGs of the CPDAG
     *     yields no DAG
     * @throws IllegalArgumentException if the data set has no rows, or a graph node has no variable of
     *     the same name in the data set
     */
    public Graph search() {
        Graph _cpdag = GraphUtils.bidirectedToUndirected(this.getCpdag());
        TetradLogger.getInstance().log("info", "Making list of all dags in CPDAG...");
        List<Graph> dags = SearchGraphUtils.getAllGraphsByDirectingUndirectedEdges(_cpdag);
        TetradLogger.getInstance().log("normalityTests", "Anderson Darling P value for Variables\n");
        DecimalFormat nf = new DecimalFormat("0.0000");
        if (dags.isEmpty()) {
            return null;
        }
        Matrix data = this.getDataSet().getDoubleData();
        List<Node> variables = this.getDataSet().getVariables();
        // Reject an empty data set before running any regressions.
        if (data.rows() == 0) {
            throw new IllegalArgumentException("The data set is empty.");
        }

        List<Score> scores = new ArrayList<>(dags.size());
        for (Graph candidate : dags) {
            scores.add(this.getScore(candidate, data, variables));
        }
        int best = indexOfMaxScore(scores);
        Graph dag = dags.get(best);
        this.pValues = scores.get(best).pvals;
        TetradLogger.getInstance().log("graph", "winning dag = " + dag);
        TetradLogger.getInstance().log("normalityTests", "Anderson Darling P value for Residuals\n");

        // The p-values are indexed like dag.getNodes() (see getScore), not like the data set's
        // columns. NOTE(review): assumes getNodes() returns nodes in a stable order across calls.
        List<Node> dagNodes = dag.getNodes();
        for (int j = 0; j < dagNodes.size(); ++j) {
            TetradLogger.getInstance().log("normalityTests", dagNodes.get(j) + ": " + nf.format(this.pValues[j]));
        }

        Graph ngDagCPDAG = SearchGraphUtils.cpdagFromDag(dag);
        this.orientNongaussianEdges(ngDagCPDAG, dag, dagNodes);
        new MeekRules().orientImplied(ngDagCPDAG);
        TetradLogger.getInstance().log("graph", "Returning: " + ngDagCPDAG);
        return ngDagCPDAG;
    }

    /**
     * Returns the index of the first score that is strictly largest.
     *
     * <p>NaN scores (e.g. from a constant residual column) never win; if no score beats negative
     * infinity, index 0 is returned so the caller always gets a valid index for a non-empty list.
     */
    private static int indexOfMaxScore(List<Score> scores) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < scores.size(); ++j) {
            double candidate = scores.get(j).score;
            if (candidate > bestScore) {
                bestScore = candidate;
                best = j;
            }
        }
        return best;
    }

    /**
     * Replaces every undirected edge of {@code ngDagCPDAG} with the corresponding edge of {@code dag}
     * when at least one endpoint's residual p-value is below alpha (i.e. looks non-Gaussian).
     */
    private void orientNongaussianEdges(Graph ngDagCPDAG, Graph dag, List<Node> dagNodes) {
        // Iterate over a snapshot because the loop body removes and adds edges of the same graph.
        for (Edge edge : new ArrayList<>(ngDagCPDAG.getEdges())) {
            if (!Edges.isUndirectedEdge(edge)) {
                continue;
            }
            Node node1 = edge.getNode1();
            Node node2 = edge.getNode2();
            boolean node1Nongaussian = this.pValueOf(node1, dagNodes) < this.getAlpha();
            boolean node2Nongaussian = this.pValueOf(node2, dagNodes) < this.getAlpha();
            if (!node1Nongaussian && !node2Nongaussian) {
                continue;
            }
            Edge oriented = dag.getEdge(node1, node2);
            ngDagCPDAG.removeEdge(edge);
            ngDagCPDAG.addEdge(oriented);
            if (node1Nongaussian) {
                TetradLogger.getInstance().log("edgeOrientations", node1 + " nongaussian ");
            }
            if (node2Nongaussian) {
                TetradLogger.getInstance().log("edgeOrientations", node2 + " nongaussian ");
            }
            TetradLogger.getInstance().log("nongaussianOrientations", "Nongaussian orientation: " + oriented);
        }
    }

    /**
     * Looks up the residual p-value of {@code node} by name among the selected DAG's nodes.
     *
     * @throws IllegalStateException if no DAG node has that name
     */
    private double pValueOf(Node node, List<Node> dagNodes) {
        String name = node.getName();
        for (int i = 0; i < dagNodes.size(); ++i) {
            if (name.equals(dagNodes.get(i).getName())) {
                return this.pValues[i];
            }
        }
        throw new IllegalStateException("Node " + name + " is not in the selected DAG.");
    }

    /**
     * Scores one DAG by regressing each node on its parents.
     *
     * <p>For each node, the residuals are standardized with their mean and the standard deviation
     * from Colt's {@code Descriptive}; the absolute values are averaged, and the squared difference
     * from {@link #GAUSSIAN_MEAN_ABS} is added to the score. Each node's residual Anderson-Darling
     * p-value is recorded at the node's index in {@code dag.getNodes()}.
     */
    private Score getScore(Graph dag, Matrix data, List<Node> variables) {
        RegressionDataset regression = new RegressionDataset(data, variables);
        List<Node> nodes = dag.getNodes();
        double score = 0.0;
        double[] residualPValues = new double[nodes.size()];
        for (int i = 0; i < nodes.size(); ++i) {
            Node graphTarget = nodes.get(i);
            Node target = this.requireVariable(variables, graphTarget.getName());
            List<Node> regressors = new ArrayList<>();
            for (Node parent : dag.getParents(graphTarget)) {
                regressors.add(this.requireVariable(variables, parent.getName()));
            }
            RegressionResult result = regression.regress(target, regressors);
            Vector residualsColumn = result.getResiduals();

            // Separate arrays: DoubleArrayList wraps (does not copy) its array and is modified below.
            residualPValues[i] = new AndersonDarlingTest(residualsColumn.toArray()).getP();

            DoubleArrayList absStandardized = new DoubleArrayList(residualsColumn.toArray());
            double mean = Descriptive.mean(absStandardized);
            double std = Descriptive.standardDeviation(Descriptive.variance(absStandardized.size(), Descriptive.sum(absStandardized), Descriptive.sumOfSquares(absStandardized)));
            for (int k = 0; k < absStandardized.size(); ++k) {
                absStandardized.set(k, FastMath.abs((absStandardized.get(k) - mean) / std));
            }
            double diff = Descriptive.mean(absStandardized) - GAUSSIAN_MEAN_ABS;
            score += diff * diff;
        }
        return new Score(score, residualPValues);
    }

    /**
     * Returns the residual p-values of the DAG selected by the last {@link #search()}, indexed like
     * that DAG's nodes, or null if search() has not produced a DAG.
     */
    public double[] getPValues() {
        return this.pValues;
    }

    /** Returns the significance level below which a residual p-value counts as non-Gaussian. */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Sets the significance level used to decide non-Gaussianity.
     *
     * @throws IllegalArgumentException if alpha is outside [0, 1]
     */
    public void setAlpha(double alpha) {
        if (alpha < 0.0 || alpha > 1.0) {
            throw new IllegalArgumentException("Alpha is in range [0, 1]");
        }
        this.alpha = alpha;
    }

    private Graph getCpdag() {
        return this.cpdag;
    }

    private DataSet getDataSet() {
        return this.dataSet;
    }

    /** Returns the first variable whose name equals {@code name}, or null if there is none. */
    private Node getVariable(List<Node> variables, String name) {
        for (Node node : variables) {
            if (!name.equals(node.getName())) continue;
            return node;
        }
        return null;
    }

    /**
     * Like {@link #getVariable}, but fails fast instead of returning null.
     *
     * @throws IllegalArgumentException if no variable has the given name
     */
    private Node requireVariable(List<Node> variables, String name) {
        Node variable = this.getVariable(variables, name);
        if (variable == null) {
            throw new IllegalArgumentException("Variable " + name + " is not in the data set.");
        }
        return variable;
    }

    /** A DAG's score together with its per-node residual p-values. */
    private static class Score {
        final double score;
        final double[] pvals;

        public Score(double score, double[] pvals) {
            this.score = score;
            this.pvals = pvals;
        }
    }
}

