/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.regression;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.math.Functions;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.Regression;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.NumberFormatUtil;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.List;

/**
 * Ordinary-least-squares regression of a target column on a set of regressor columns plus an
 * intercept, solving the normal equations with {@link MatrixUtils#ginverse} applied to X'X
 * instead of an ordinary inverse (presumably so that rank-deficient designs still yield a
 * solution — confirm against {@code MatrixUtils}).
 *
 * <p>Only coefficients, predictions, residuals, RSS and R^2 are computed. Standard errors,
 * t-statistics and p-values are passed to {@link RegressionResult} as empty arrays.
 */
public class RegressionDatasetGeneralized
implements Regression {
    /** Formatter used when rendering a textual summary in {@link #createResultString}. */
    private final NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
    /** Data matrix; column {@code i} is assumed to hold the values of {@code variables.get(i)}. */
    private final DoubleMatrix2D data;
    /** Variables naming the columns of {@link #data}; used to map nodes to column indices. */
    private final List<Node> variables;
    /** Significance level stored in results and used to flag coefficients. */
    private double alpha = 0.05;
    // NOTE(review): never assigned anywhere in this class, so getGraph() always returns null.
    private Graph graph = null;

    /**
     * Builds a regression over the double-valued matrix and variable list of a data set.
     *
     * @param data the data set whose {@code getDoubleData()} matrix and variables are used
     */
    public RegressionDatasetGeneralized(DataSet data) {
        this.data = data.getDoubleData();
        this.variables = data.getVariables();
    }

    /**
     * Builds a regression over an explicit data matrix.
     *
     * @param data      the data matrix, one column per variable
     * @param variables the variables, in the same order as the columns of {@code data}
     */
    public RegressionDatasetGeneralized(DoubleMatrix2D data, List<Node> variables) {
        this.data = data;
        this.variables = variables;
    }

    /**
     * Sets the significance level that is stored in every subsequent {@link RegressionResult}.
     *
     * @param alpha the significance level
     */
    @Override
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /**
     * Returns the output graph field. The field is never assigned in this class, so this
     * currently always returns {@code null}.
     *
     * @return the stored graph (always {@code null} here)
     */
    @Override
    public Graph getGraph() {
        return this.graph;
    }

    /**
     * Regresses {@code target} on {@code regressors} plus an intercept.
     *
     * <p>The coefficients are {@code b = G X'y}, where {@code X} is the design matrix (a leading
     * column of ones followed by the regressor columns) and {@code G} is
     * {@code MatrixUtils.ginverse(X'X)}.
     *
     * @param target     the dependent variable; expected to be one of this instance's variables
     * @param regressors the independent variables, in the order their coefficients are reported
     * @return a result whose coefficient array holds the intercept first and then one coefficient
     *         per regressor, together with R^2, RSS, the predictions and the residuals
     *         {@code y - yHat}; standard errors, t-values and p-values are empty arrays
     */
    @Override
    public RegressionResult regress(Node target, List<Node> regressors) {
        int n = this.data.rows();
        int numRegressors = regressors.size();

        int targetColumn = this.variables.indexOf(target);
        int[] regressorColumns = new int[numRegressors];
        for (int i = 0; i < numRegressors; ++i) {
            regressorColumns[i] = this.variables.indexOf(regressors.get(i));
        }
        // NOTE(review): indexOf returns -1 for a node that is not in `variables`; the Colt view
        // methods below are then expected to reject that index (presumably IndexOutOfBoundsException).

        // Design matrix: column 0 is the intercept, columns 1..numRegressors are the regressors.
        // A null row selection tells Colt to keep every row.
        DoubleMatrix2D xSub = this.data.viewSelection(null, regressorColumns);
        DoubleMatrix2D X = new DenseDoubleMatrix2D(n, numRegressors + 1);
        for (int i = 0; i < n; ++i) {
            X.set(i, 0, 1.0);
            for (int j = 0; j < numRegressors; ++j) {
                X.set(i, j + 1, xSub.get(i, j));
            }
        }

        DoubleMatrix1D y = this.data.viewColumn(targetColumn);

        // b = G (X'y): forming X'y first avoids materializing the (k x n) product G X'.
        Algebra algebra = new Algebra();
        DoubleMatrix2D Xt = algebra.transpose(X);
        DoubleMatrix2D G = MatrixUtils.ginverse(algebra.mult(Xt, X));
        DoubleMatrix1D b = algebra.mult(G, algebra.mult(Xt, y));

        DoubleMatrix1D yPred = algebra.mult(X, b);
        // Residuals use the standard convention e = y - yHat. y is a view into `data`, so copy it
        // before the in-place assign to avoid overwriting the data matrix.
        DoubleMatrix1D residuals = y.copy().assign(yPred, Functions.minus);

        double rss = this.rss(X, y, b);
        double tss = this.tss(y);
        // NaN/-Infinity when the target is constant (tss == 0).
        double r2 = 1.0 - rss / tss;

        String[] vNames = new String[numRegressors];
        for (int i = 0; i < numRegressors; ++i) {
            vNames[i] = regressors.get(i).getName();
        }
        return new RegressionResult(false, vNames, n, b.toArray(), new double[0], new double[0], new double[0], r2, rss, this.alpha, yPred, residuals);
    }

    /**
     * Varargs convenience overload of {@link #regress(Node, List)}.
     *
     * @param target     the dependent variable
     * @param regressors the independent variables
     * @return the regression result
     */
    @Override
    public RegressionResult regress(Node target, Node ... regressors) {
        List<Node> regressorList = Arrays.asList(regressors);
        return this.regress(target, regressorList);
    }

    /**
     * Builds a graph with an edge {@code predictor --> target} for every design-matrix column
     * whose p-value is below {@code alpha}. Column 0 is named {@code "const"}.
     * NOTE(review): not called anywhere in this class.
     */
    private Graph createOutputGraph(String target, DoubleMatrix2D x, List<Node> regressors, DoubleMatrix1D p) {
        GraphNode targetNode = new GraphNode(target);
        EdgeListGraph outputGraph = new EdgeListGraph();
        outputGraph.addNode(targetNode);
        for (int i = 0; i < x.columns(); ++i) {
            String variableName = i > 0 ? regressors.get(i - 1).getName() : "const";
            // Written as !(p < alpha) so that a NaN p-value counts as not significant.
            if (!(p.get(i) < this.alpha)) {
                continue;
            }
            GraphNode predictorNode = new GraphNode(variableName);
            outputGraph.addNode(predictorNode);
            Edge newEdge = new Edge(predictorNode, targetNode, Endpoint.TAIL, Endpoint.ARROW);
            outputGraph.addEdge(newEdge);
        }
        return outputGraph;
    }

    /**
     * Renders a tab-separated text summary of a regression: sample size, k, alpha, SSE, R^2 and
     * one row per design-matrix column (coefficient, SE, t, p, significance flag).
     * NOTE(review): not called anywhere in this class.
     */
    private String createResultString(int n, int k, double rss, double r2, DoubleMatrix2D x, List<Node> regressors, DoubleMatrix2D b, DoubleMatrix1D se, DoubleMatrix1D t, DoubleMatrix1D p) {
        StringBuilder summary = new StringBuilder("\n REGRESSION RESULT");
        summary.append("\n n = ").append(n).append(", k = ").append(k).append(", alpha = ").append(this.alpha).append("\n");
        summary.append(" SSE = ").append(this.nf.format(rss)).append("\n");
        summary.append(" R^2 = ").append(this.nf.format(r2)).append("\n\n");
        summary.append(" VAR\tCOEF\tSE\tT\tP\n");
        for (int i = 0; i < x.columns(); ++i) {
            String variableName = i > 0 ? regressors.get(i - 1).getName() : "const";
            summary.append(" ").append(variableName)
                    .append("\t").append(this.nf.format(b.get(i, 0)))
                    .append("\t").append(this.nf.format(se.get(i)))
                    .append("\t").append(this.nf.format(t.get(i)))
                    .append("\t").append(this.nf.format(p.get(i)))
                    .append("\t").append(p.get(i) < this.alpha ? "significant " : "")
                    .append("\n");
        }
        return summary.toString();
    }

    /**
     * Residual sum of squares: sum over rows of {@code (y_i - x_i . b)^2}.
     *
     * @param x design matrix (rows are observations)
     * @param y observed target values, one per row of {@code x}
     * @param b coefficient vector, one entry per column used from {@code x}
     * @return the residual sum of squares (0 for an empty design)
     */
    private double rss(DoubleMatrix2D x, DoubleMatrix1D y, DoubleMatrix1D b) {
        double rss = 0.0;
        for (int i = 0; i < x.rows(); ++i) {
            double yH = 0.0;
            for (int j = 0; j < b.size(); ++j) {
                yH += b.get(j) * x.get(i, j);
            }
            double d = y.get(i) - yH;
            rss += d * d;
        }
        return rss;
    }

    /**
     * Total sum of squares of {@code y} about its mean.
     *
     * @param y observed values
     * @return the sum of squared deviations from the mean (0 for an empty vector)
     */
    private double tss(DoubleMatrix1D y) {
        double mean = 0.0;
        for (int i = 0; i < y.size(); ++i) {
            mean += y.get(i);
        }
        mean /= (double)y.size();
        double ssm = 0.0;
        for (int i = 0; i < y.size(); ++i) {
            double d = mean - y.get(i);
            ssm += d * d;
        }
        return ssm;
    }
}

