/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.regression;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.math.PlusMult;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.Regression;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.ProbUtils;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.List;

/**
 * Ordinary least-squares regression of one column of a data matrix on other columns.
 *
 * <p>Columns of the data matrix are identified by the {@link Node}s in {@code variables}, in
 * the same order (a node's position in that list is used as its column index). When at least
 * one regressor is supplied, a column of ones is placed first in the design matrix and its
 * coefficient is reported under the name "const". When no regressors are supplied, the design
 * matrix has no columns at all, and the {@link RegressionResult} is built with its first
 * constructor argument set to {@code true} (presumably the "zero intercept assumed" flag;
 * confirm against RegressionResult).
 *
 * <p>Every call to {@code regress} overwrites the graph returned by {@link #getGraph()}
 * without any synchronization, so an instance should not be used for concurrent regressions.
 */
public class RegressionDataset implements Regression {
    /** Number format used by {@link #createResultString}. */
    private final NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();

    /** The data matrix; rows are cases, columns are indexed in the order of {@link #variables}. */
    private final DoubleMatrix2D data;

    /** Variables naming the columns of {@link #data}, in column order. */
    private final List<Node> variables;

    /** Significance level below which a coefficient gets an edge in the output graph. */
    private double alpha = 0.05;

    /** Graph built by the most recent call to {@code regress}; {@code null} before any call. */
    private Graph graph = null;

    /**
     * Creates a regression over the double-valued data and the variable list of a data set.
     *
     * @param data the data set supplying both the matrix and its column variables
     */
    public RegressionDataset(DataSet data) {
        this.data = data.getDoubleData();
        this.variables = data.getVariables();
    }

    /**
     * Creates a regression over an explicit data matrix.
     *
     * @param data the data matrix; rows are cases
     * @param variables the variables naming the columns of {@code data}, in column order
     */
    public RegressionDataset(DoubleMatrix2D data, List<Node> variables) {
        this.data = data;
        this.variables = variables;
    }

    @Override
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    @Override
    public Graph getGraph() {
        return this.graph;
    }

    /**
     * Regresses {@code target} on {@code regressors} by ordinary least squares.
     *
     * <p>Coefficients are {@code b = (X'X)^-1 X'y}. The standard error of coefficient
     * {@code i} is {@code sqrt(s^2 * (X'X)^-1[i][i])} with {@code s^2 = RSS / (n - k)}, where
     * {@code k} is the number of regressors plus one. P-values are two-sided, from a t
     * distribution with {@code n - k} degrees of freedom. As a side effect, the graph
     * returned by {@link #getGraph()} is replaced by one with an edge into {@code target}
     * from each coefficient whose p-value is below alpha.
     *
     * @param target the dependent variable; must be one of this object's variables
     * @param regressors the independent variables, each one of this object's variables; may
     *     be empty
     * @return the regression result; fitted values and residuals are passed as column views
     * @throws IllegalArgumentException if {@code target} or any regressor is not one of this
     *     object's variables
     */
    @Override
    public RegressionResult regress(Node target, List<Node> regressors) {
        int n = this.data.rows();
        // Always counts an intercept, even in the zero-regressor case where none is fitted.
        int k = regressors.size() + 1;

        int targetColumn = columnOf(target, "Target");
        int[] regressorColumns = new int[regressors.size()];
        for (int i = 0; i < regressorColumns.length; i++) {
            regressorColumns[i] = columnOf(regressors.get(i), "Regressor");
        }

        // A null row-index array makes Colt select every row.
        DoubleMatrix2D y = this.data.viewSelection(null, new int[] {targetColumn}).copy();
        DoubleMatrix2D x = designMatrix(this.data.viewSelection(null, regressorColumns));

        Algebra algebra = new Algebra();
        DoubleMatrix2D xT = x.viewDice();
        DoubleMatrix2D xTxInv = algebra.inverse(algebra.mult(xT, x));
        DoubleMatrix2D b = algebra.mult(xTxInv, algebra.mult(xT, y));
        DoubleMatrix2D yHat = algebra.mult(x, b);
        DoubleMatrix2D res = y.copy().assign(yHat, PlusMult.plusMult(-1.0));

        double rss = rss(x, y, b);
        double se = Math.sqrt(rss / (n - k));
        double r2 = 1.0 - rss / tss(y);

        int numCoefs = x.columns();
        DoubleMatrix1D coefSe = new DenseDoubleMatrix1D(numCoefs);
        DoubleMatrix1D t = new DenseDoubleMatrix1D(numCoefs);
        DoubleMatrix1D p = new DenseDoubleMatrix1D(numCoefs);
        for (int i = 0; i < numCoefs; i++) {
            double seI = Math.sqrt(se * se * xTxInv.get(i, i));
            double tI = b.get(i, 0) / seI;
            coefSe.set(i, seI);
            t.set(i, tI);
            p.set(i, 2.0 * (1.0 - ProbUtils.tCdf(Math.abs(tI), n - k)));
        }

        this.graph = createOutputGraph(target.getName(), x, regressors, p);

        String[] vNames = new String[regressors.size()];
        for (int i = 0; i < vNames.length; i++) {
            vNames[i] = regressors.get(i).getName();
        }

        return new RegressionResult(regressors.size() == 0, vNames, n,
                b.viewColumn(0).toArray(), t.toArray(), p.toArray(), coefSe.toArray(),
                r2, rss, this.alpha, yHat.viewColumn(0), res.viewColumn(0));
    }

    /**
     * Varargs convenience form of {@link #regress(Node, List)}.
     */
    @Override
    public RegressionResult regress(Node target, Node... regressors) {
        return regress(target, Arrays.asList(regressors));
    }

    /**
     * Returns the data-matrix column of {@code node}.
     *
     * @param node the variable to look up
     * @param role "Target" or "Regressor", used only in the error message
     * @throws IllegalArgumentException if {@code node} is not one of this object's variables
     */
    private int columnOf(Node node, String role) {
        int index = this.variables.indexOf(node);
        if (index < 0) {
            throw new IllegalArgumentException(role + " variable is not in the data: " + node);
        }
        return index;
    }

    /**
     * Builds the design matrix: a leading column of ones followed by a copy of
     * {@code regressors}, or a matrix with no columns when {@code regressors} has none.
     */
    private static DoubleMatrix2D designMatrix(DoubleMatrix2D regressors) {
        int rows = regressors.rows();
        int cols = regressors.columns();
        if (cols == 0) {
            // No regressors: no intercept column either.
            return new DenseDoubleMatrix2D(rows, 0);
        }
        DoubleMatrix2D x = new DenseDoubleMatrix2D(rows, cols + 1);
        for (int i = 0; i < rows; i++) {
            x.set(i, 0, 1.0);
            for (int j = 0; j < cols; j++) {
                x.set(i, j + 1, regressors.get(i, j));
            }
        }
        return x;
    }

    /**
     * Builds a graph with a node for {@code target} and, for each design-matrix column whose
     * p-value is below alpha, a node for that column ("const" for column 0, otherwise the
     * corresponding regressor's name) with a directed edge into the target. Columns with a
     * NaN p-value get no edge.
     */
    private Graph createOutputGraph(String target, DoubleMatrix2D x, List<Node> regressors, DoubleMatrix1D p) {
        GraphNode targetNode = new GraphNode(target);
        EdgeListGraph output = new EdgeListGraph();
        output.addNode(targetNode);
        for (int i = 0; i < x.columns(); i++) {
            if (p.get(i) < this.alpha) {
                String variableName = i > 0 ? regressors.get(i - 1).getName() : "const";
                GraphNode predictorNode = new GraphNode(variableName);
                output.addNode(predictorNode);
                output.addEdge(new Edge(predictorNode, targetNode, Endpoint.TAIL, Endpoint.ARROW));
            }
        }
        return output;
    }

    /**
     * Formats a tab-separated, human-readable summary of a regression. Not currently called
     * anywhere in this class.
     */
    private String createResultString(int n, int k, double rss, double r2, DoubleMatrix2D x, List<Node> regressors, DoubleMatrix2D b, DoubleMatrix1D se, DoubleMatrix1D t, DoubleMatrix1D p) {
        StringBuilder summary = new StringBuilder("\n REGRESSION RESULT");
        summary.append("\n n = ").append(n)
                .append(", k = ").append(k)
                .append(", alpha = ").append(this.alpha).append("\n");
        summary.append(" SSE = ").append(this.nf.format(rss)).append("\n");
        summary.append(" R^2 = ").append(this.nf.format(r2)).append("\n\n");
        summary.append(" VAR\tCOEF\tSE\tT\tP\n");
        for (int i = 0; i < x.columns(); i++) {
            String variableName = i > 0 ? regressors.get(i - 1).getName() : "const";
            summary.append(" ").append(variableName)
                    .append("\t").append(this.nf.format(b.get(i, 0)))
                    .append("\t").append(this.nf.format(se.get(i)))
                    .append("\t").append(this.nf.format(t.get(i)))
                    .append("\t").append(this.nf.format(p.get(i)))
                    .append("\t").append(p.get(i) < this.alpha ? "significant " : "")
                    .append("\n");
        }
        return summary.toString();
    }

    /**
     * Residual sum of squares of {@code y} against the fit {@code x * b}, where {@code b}
     * is a column vector.
     */
    private static double rss(DoubleMatrix2D x, DoubleMatrix2D y, DoubleMatrix2D b) {
        double rss = 0.0;
        for (int i = 0; i < x.rows(); i++) {
            double yH = 0.0;
            for (int j = 0; j < b.rows(); j++) {
                yH += b.get(j, 0) * x.get(i, j);
            }
            double d = y.get(i, 0) - yH;
            rss += d * d;
        }
        return rss;
    }

    /**
     * Total sum of squares: squared deviations of column 0 of {@code y} from its mean.
     * Yields NaN when {@code y} has no rows.
     */
    private static double tss(DoubleMatrix2D y) {
        double mean = 0.0;
        for (int i = 0; i < y.rows(); i++) {
            mean += y.get(i, 0);
        }
        mean /= y.rows();
        double ssm = 0.0;
        for (int i = 0; i < y.rows(); i++) {
            double d = mean - y.get(i, 0);
            ssm += d * d;
        }
        return ssm;
    }
}

