/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.list.DoubleArrayList;
import cern.jet.stat.Descriptive;
import edu.cmu.tetrad.data.AndersonDarlingTest;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataTransforms;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.search.Fask;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.TetradLogger;
import edu.cmu.tetrad.util.Vector;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

/**
 * Orients the undirected edges of a given CPDAG pairwise, using
 * {@link Fask#faskLeftRightV2} on a standardized copy of the supplied data set.
 * Edges that are already directed in the CPDAG are left unchanged.
 */
public class BossLingam {
    /** sqrt(2 / pi), the expected value of |Z| for a standard normal Z. */
    private static final double EXPECTED_ABS_STANDARD_NORMAL = FastMath.sqrt(2.0 / Math.PI);

    private final Graph cpdag;
    private final DataSet dataSet;
    // NOTE(review): nothing in this class assigns this field, so getPValues() returns null.
    private double[] pValues;
    // NOTE(review): stored by setAlpha but not read anywhere in this class.
    private double alpha = 0.05;

    /**
     * Creates a new orienter.
     *
     * @param cpdag   the CPDAG whose undirected edges should be oriented; must not be null
     * @param dataSet the data whose variables correspond (by name) to the CPDAG's nodes;
     *                must not be null
     * @throws IllegalArgumentException if either argument is null
     */
    public BossLingam(Graph cpdag, DataSet dataSet) throws IllegalArgumentException {
        if (cpdag == null) {
            throw new IllegalArgumentException("CPDAG must be specified.");
        }
        if (dataSet == null) {
            throw new IllegalArgumentException("Data set must be specified.");
        }
        this.cpdag = cpdag;
        this.dataSet = dataSet;
    }

    /**
     * Orients every undirected edge of the CPDAG.
     *
     * <p>For an undirected edge X -- Y, a positive result from
     * {@code Fask.faskLeftRightV2(x, y, true, 0.0)} yields X --> Y. Any other result,
     * including ties at 0, yields Y --> X.
     *
     * @return a copy of the CPDAG, over the data set's variables, with all undirected
     *         edges oriented
     * @throws IllegalArgumentException if a node of an undirected CPDAG edge has no
     *                                  variable with the same name in the data set
     */
    public Graph search() {
        EdgeListGraph toOrient = new EdgeListGraph(this.cpdag);
        DataSet standardized = DataTransforms.standardizeData(this.dataSet);
        // After the transpose, row k holds all samples of variable k.
        double[][] columns = standardized.getDoubleData().transpose().toArray();
        List<Node> variables = standardized.getVariables();

        // Replace the CPDAG's nodes by the data set's variables (matched by name). Every
        // node used below is then both a member of toOrient and indexable into the data.
        GraphUtils.replaceNodes(toOrient, variables);

        for (Edge edge : this.cpdag.getEdges()) {
            if (!Edges.isUndirectedEdge(edge)) {
                continue;
            }

            // Resolve by name: the CPDAG's node objects are not the ones now in toOrient.
            Node x = requireVariable(variables, edge.getNode1().getName());
            Node y = requireVariable(variables, edge.getNode2().getName());
            int i = variables.indexOf(x);
            int j = variables.indexOf(y);

            // NOTE(review): meaning of the (true, 0.0) arguments is defined in Fask — confirm there.
            double lr = Fask.faskLeftRightV2(columns[i], columns[j], true, 0.0);

            toOrient.removeEdge(toOrient.getEdge(x, y));
            if (lr > 0.0) {
                toOrient.addDirectedEdge(x, y);
            } else {
                toOrient.addDirectedEdge(y, x);
            }
        }

        TetradLogger.getInstance().log("graph", "Returning: " + toOrient);
        return toOrient;
    }

    /**
     * Returns the stored p-values.
     *
     * @return the p-values array; no code in this class currently assigns it, so this is null
     */
    public double[] getPValues() {
        return this.pValues;
    }

    /**
     * Sets the significance level. The value is stored but not currently used by
     * {@link #search()}.
     *
     * @param alpha a value in [0, 1]
     * @throws IllegalArgumentException if alpha is outside [0, 1]
     */
    public void setAlpha(double alpha) {
        if (alpha < 0.0 || alpha > 1.0) {
            throw new IllegalArgumentException("Alpha is in range [0, 1]");
        }
        this.alpha = alpha;
    }

    /**
     * Scores a DAG by regressing each node on its parents.
     *
     * <p>For each node's residuals, this standardizes them, takes absolute values, and
     * adds the squared gap between their mean and sqrt(2/pi) (E|Z| for a standard normal)
     * to the score. It also computes an Anderson-Darling p-value for each node's residual
     * column. This method is not currently called from within this class.
     *
     * @param dag       the DAG to score; its nodes are matched to {@code variables} by name
     * @param data      the data matrix used to build the regression
     * @param variables the variables corresponding to the columns of {@code data}
     * @return the accumulated score together with one p-value per DAG node
     */
    private Score getScore(Graph dag, Matrix data, List<Node> variables) {
        RegressionDataset regression = new RegressionDataset(data, variables);
        List<Node> nodes = dag.getNodes();
        double score = 0.0;
        double[] pValues = new double[nodes.size()];
        Matrix residuals = new Matrix(data.getNumRows(), data.getNumColumns());

        for (int i = 0; i < nodes.size(); ++i) {
            Node dagNode = nodes.get(i);
            Node target = getVariable(variables, dagNode.getName());
            List<Node> regressors = new ArrayList<>();
            for (Node parent : dag.getParents(dagNode)) {
                regressors.add(getVariable(variables, parent.getName()));
            }

            RegressionResult result = regression.regress(target, regressors);
            Vector residualsColumn = result.getResiduals();
            residuals.assignColumn(i, residualsColumn);

            DoubleArrayList values = new DoubleArrayList(residualsColumn.toArray());
            double mean = Descriptive.mean(values);
            double std = Descriptive.standardDeviation(Descriptive.variance(
                    values.size(), Descriptive.sum(values), Descriptive.sumOfSquares(values)));
            // Replace each residual by its absolute standardized value. A zero std yields
            // NaN/Infinity here.
            for (int k = 0; k < values.size(); ++k) {
                values.set(k, FastMath.abs((values.get(k) - mean) / std));
            }

            double diff = Descriptive.mean(values) - EXPECTED_ABS_STANDARD_NORMAL;
            score += diff * diff;
        }

        // Bound by the node count so pValues is never overrun when the DAG has fewer nodes
        // than the data has columns.
        for (int j = 0; j < nodes.size(); ++j) {
            pValues[j] = new AndersonDarlingTest(residuals.getColumn(j).toArray()).getP();
        }
        return new Score(score, pValues);
    }

    private Graph getCpdag() {
        return this.cpdag;
    }

    private DataSet getDataSet() {
        return this.dataSet;
    }

    /**
     * Finds the first variable with the given name.
     *
     * @return the matching variable, or null if none matches
     */
    private Node getVariable(List<Node> variables, String name) {
        for (Node node : variables) {
            if (name.equals(node.getName())) {
                return node;
            }
        }
        return null;
    }

    /**
     * Like {@link #getVariable} but fails with a descriptive message when the name is absent.
     *
     * @throws IllegalArgumentException if no variable has the given name
     */
    private Node requireVariable(List<Node> variables, String name) {
        Node variable = getVariable(variables, name);
        if (variable == null) {
            throw new IllegalArgumentException(
                    "Variable " + name + " from the CPDAG is not in the data set.");
        }
        return variable;
    }

    /** Pair of a DAG score and the per-node residual p-values. */
    private static class Score {
        final double score;
        final double[] pvals;

        public Score(double score, double[] pvals) {
            this.score = score;
            this.pvals = pvals;
        }
    }
}

