/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.regression;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.linalg.Algebra;
import edu.cmu.tetrad.data.CorrelationMatrix;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.Regression;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.ProbUtils;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.List;

/**
 * Linear regression computed from summary statistics instead of raw data.
 *
 * <p>The model is described by a correlation matrix, one standard deviation per variable
 * and, optionally, one mean per variable. Slopes come from the standardized normal
 * equations {@code beta* = R_xx^-1 r_xy}, rescaled by {@code sd(y) / sd(x_i)}. The intercept
 * is reconstructed from the means when they are supplied and is {@code NaN} otherwise.
 * Standard errors, t statistics and two-sided p-values use the sample size stored on the
 * correlation matrix.
 *
 * <p>{@link #regress(Node, List)} stores the graph of significant regressors in an
 * unsynchronized field. Do not share an instance between threads without external
 * synchronization.
 */
public class RegressionCovariance
implements Regression {
    private final NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
    private final CorrelationMatrix correlations;
    private final DoubleMatrix1D sd;
    // May be null; the intercept is then reported as NaN.
    private final DoubleMatrix1D means;
    private double alpha = 0.05;
    // Graph produced by the most recent regress() call; null until then.
    private Graph graph = null;

    /**
     * Creates a regression over {@code covariances}, treating every variable's mean as zero.
     * With zero means, the intercept computed by {@link #regress(Node, List)} is 0.
     *
     * @param covariances the covariance matrix of the variables
     */
    public RegressionCovariance(CovarianceMatrix covariances) {
        this(covariances, RegressionCovariance.zeroMeans(covariances.getDimension()));
    }

    /**
     * Creates a regression over {@code covariances} with the given variable means.
     *
     * <p>The covariances are converted to a {@link CorrelationMatrix}. The standard
     * deviations are the square roots of the covariance diagonal.
     *
     * @param covariances the covariance matrix of the variables
     * @param means one mean per variable, or {@code null} to report the intercept as NaN
     */
    public RegressionCovariance(CovarianceMatrix covariances, DoubleMatrix1D means) {
        this(new CorrelationMatrix(covariances), RegressionCovariance.standardDeviations(covariances), means);
    }

    /**
     * Creates a regression from a correlation matrix and per-variable summary statistics.
     *
     * @param correlations the correlation matrix, which also supplies variables and sample size
     * @param standardDeviations one standard deviation per variable, in matrix order
     * @param means one mean per variable in matrix order, or {@code null}
     * @throws NullPointerException if {@code correlations} is null
     * @throws IllegalArgumentException if {@code standardDeviations} is null or either vector
     *         does not have one entry per variable
     */
    public RegressionCovariance(CorrelationMatrix correlations, DoubleMatrix1D standardDeviations, DoubleMatrix1D means) {
        if (correlations == null) {
            throw new NullPointerException("correlations");
        }
        int dimension = correlations.getDimension();
        if (standardDeviations == null || standardDeviations.size() != dimension) {
            throw new IllegalArgumentException(
                    "standardDeviations must be non-null with one entry per variable (" + dimension + ")");
        }
        if (means != null && means.size() != dimension) {
            throw new IllegalArgumentException(
                    "means must be null or have one entry per variable (" + dimension + ")");
        }
        this.correlations = correlations;
        this.sd = standardDeviations;
        this.means = means;
    }

    /**
     * Sets the significance level used by subsequent {@code regress} calls, both for the
     * output graph and for the returned {@link RegressionResult}.
     */
    @Override
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /**
     * Returns the graph built by the most recent {@code regress} call, or {@code null} if
     * none has been made. The graph has an edge into the target from each regressor whose
     * p-value is below alpha.
     */
    @Override
    public Graph getGraph() {
        return this.graph;
    }

    /**
     * Regresses {@code target} on {@code regressors} using only the stored summary statistics.
     *
     * <p>Index 0 of the coefficient, standard error, t and p arrays is the intercept. Its
     * standard error, t and p are always NaN. Index {@code i + 1} corresponds to
     * {@code regressors.get(i)}. As a side effect, the graph returned by
     * {@link #getGraph()} is replaced.
     *
     * @param target the dependent variable; must be one of the correlation matrix's variables
     * @param regressors the independent variables; each must be one of those variables
     * @return the regression result
     * @throws IllegalArgumentException if the target or a regressor is not among the variables
     */
    @Override
    public RegressionResult regress(Node target, List<Node> regressors) {
        if (regressors == null) {
            throw new NullPointerException("regressors");
        }
        DoubleMatrix2D allCorrelations = this.correlations.getMatrix();
        List<Node> variables = this.correlations.getVariables();

        int yIndex = RegressionCovariance.indexOfVariable(variables, target, "Target");
        int[] xIndices = new int[regressors.size()];
        for (int i = 0; i < xIndices.length; ++i) {
            xIndices[i] = RegressionCovariance.indexOfVariable(variables, regressors.get(i), "Regressor");
        }

        Algebra algebra = new Algebra();

        // Standardized slopes beta* = R_xx^-1 r_xy. The inverse is reused below for the
        // regressors' variance inflation factors.
        DoubleMatrix2D rX = allCorrelations.viewSelection(xIndices, xIndices);
        DoubleMatrix2D rY = allCorrelations.viewSelection(xIndices, new int[]{yIndex});
        DoubleMatrix2D rxInv = algebra.inverse(rX);
        DoubleMatrix2D bStar = algebra.mult(rxInv, rY.copy());

        // Unstandardize the slopes: b_i = beta*_i * sd(y) / sd(x_i).
        double sdY = this.sd.get(yIndex);
        DenseDoubleMatrix1D b = new DenseDoubleMatrix1D(xIndices.length + 1);
        for (int i = 0; i < xIndices.length; ++i) {
            b.set(i + 1, bStar.get(i, 0) * (sdY / this.sd.get(xIndices[i])));
        }

        // Intercept b0 = mean(y) - sum_i b_i * mean(x_i); unknown without means.
        b.set(0, Double.NaN);
        if (this.means != null) {
            double b0 = this.means.get(yIndex);
            for (int i = 0; i < xIndices.length; ++i) {
                b0 -= b.get(i + 1) * this.means.get(xIndices[i]);
            }
            b.set(0, b0);
        }

        // Target first, then the regressors in order.
        int[] allIndices = new int[xIndices.length + 1];
        allIndices[0] = yIndex;
        System.arraycopy(xIndices, 0, allIndices, 1, xIndices.length);

        DoubleMatrix2D rInv = algebra.inverse(allCorrelations.viewSelection(allIndices, allIndices));
        int n = this.correlations.getSampleSize();
        int k = xIndices.length + 1;
        int df = n - k;
        // (R^-1)_yy = 1 / (1 - R^2) for the target regressed on all regressors.
        double r2 = 1.0 - 1.0 / rInv.get(0, 0);
        double tss = n * sdY * sdY;
        double rss = tss * (1.0 - r2);
        // NOTE(review): when n <= k the residual degrees of freedom are non-positive, so seY
        // becomes Infinity/NaN; ProbUtils.tCdf's behavior for df <= 0 is unverified.
        double seY = Math.sqrt(rss / df);

        DenseDoubleMatrix1D se = new DenseDoubleMatrix1D(allIndices.length);
        DenseDoubleMatrix1D t = new DenseDoubleMatrix1D(allIndices.length);
        DenseDoubleMatrix1D p = new DenseDoubleMatrix1D(allIndices.length);
        se.set(0, Double.NaN);
        t.set(0, Double.NaN);
        p.set(0, Double.NaN);
        for (int i = 0; i < xIndices.length; ++i) {
            double sdX = this.sd.get(xIndices[i]);
            // The diagonal of R_xx^-1 is 1 / (1 - R_i^2), where R_i^2 comes from regressing
            // x_i on the other regressors.
            double r2X = 1.0 - 1.0 / rxInv.get(i, i);
            double tssX = n * sdX * sdX;
            double seI = seY / Math.sqrt(tssX * (1.0 - r2X));
            double tI = b.get(i + 1) / seI;
            double pI = 2.0 * (1.0 - ProbUtils.tCdf(Math.abs(tI), df));
            se.set(i + 1, seI);
            t.set(i + 1, tI);
            p.set(i + 1, pI);
        }

        this.graph = this.createGraph(target, allIndices, regressors, p);
        String[] vNames = this.createVarNamesArray(regressors);
        return new RegressionResult(false, vNames, n, b.toArray(), t.toArray(), p.toArray(), se.toArray(), r2, rss, this.alpha, null, null);
    }

    /**
     * Varargs convenience overload of {@link #regress(Node, List)}.
     */
    @Override
    public RegressionResult regress(Node target, Node ... regressors) {
        return this.regress(target, Arrays.asList(regressors));
    }

    /**
     * Returns the position of {@code node} in {@code variables}.
     *
     * @throws IllegalArgumentException if the node is not present; {@code role} is used in the message
     */
    private static int indexOfVariable(List<Node> variables, Node node, String role) {
        int index = variables.indexOf(node);
        if (index < 0) {
            throw new IllegalArgumentException(
                    role + " variable " + node + " is not among the variables of the correlation matrix.");
        }
        return index;
    }

    /**
     * Returns the regressor names in order. The intercept is not included.
     */
    private String[] createVarNamesArray(List<Node> regressors) {
        String[] vNames = new String[regressors.size()];
        for (int i = 0; i < vNames.length; ++i) {
            vNames[i] = regressors.get(i).getName();
        }
        return vNames;
    }

    /**
     * Builds a graph with the target plus one new {@link GraphNode} per significant row.
     *
     * <p>Each such node is named after its regressor and gets a directed edge into the
     * target. Row 0 is the intercept ("const"). {@code regress} sets its p-value to NaN, so
     * it never passes the {@code p < alpha} test.
     */
    private Graph createGraph(Node target, int[] allIndices, List<Node> regressors, DoubleMatrix1D p) {
        EdgeListGraph graph = new EdgeListGraph();
        graph.addNode(target);
        for (int i = 0; i < allIndices.length; ++i) {
            // Written as !(p < alpha) so that NaN p-values are skipped.
            if (!(p.get(i) < this.alpha)) {
                continue;
            }
            String variableName = i > 0 ? regressors.get(i - 1).getName() : "const";
            GraphNode predictorNode = new GraphNode(variableName);
            graph.addNode(predictorNode);
            graph.addEdge(new Edge(predictorNode, target, Endpoint.TAIL, Endpoint.ARROW));
        }
        return graph;
    }

    /**
     * Formats a tab-separated, human-readable summary of a regression.
     *
     * <p>NOTE(review): this method has no callers in this class.
     */
    private String createSummary(int n, int k, double rss, double r2, int[] allIndices, List<Node> regressors, DoubleMatrix1D b, DoubleMatrix1D se, DoubleMatrix1D t, DoubleMatrix1D p) {
        StringBuilder summary = new StringBuilder("\n REGRESSION RESULT");
        summary.append("\n n = ").append(n).append(", k = ").append(k).append(", alpha = ").append(this.alpha).append("\n");
        summary.append(" SSE = ").append(this.nf.format(rss)).append("\n");
        summary.append(" R^2 = ").append(this.nf.format(r2)).append("\n\n");
        summary.append(" VAR\tCOEF\tSE\tT\tP\n");
        for (int i = 0; i < allIndices.length; ++i) {
            String variableName = i > 0 ? regressors.get(i - 1).getName() : "const";
            summary.append(" ").append(variableName)
                    .append("\t").append(this.nf.format(b.get(i)))
                    .append("\t").append(this.nf.format(se.get(i)))
                    .append("\t").append(this.nf.format(t.get(i)))
                    .append("\t").append(this.nf.format(p.get(i)))
                    .append("\t").append(p.get(i) < this.alpha ? "significant " : "")
                    .append("\n");
        }
        return summary.toString();
    }

    /**
     * Returns a vector of {@code numVars} zeros.
     */
    private static DoubleMatrix1D zeroMeans(int numVars) {
        return new DenseDoubleMatrix1D(numVars);
    }

    /**
     * Returns the square root of each diagonal entry of {@code covariances}.
     */
    private static DoubleMatrix1D standardDeviations(CovarianceMatrix covariances) {
        int dimension = covariances.getDimension();
        DenseDoubleMatrix1D standardDeviations = new DenseDoubleMatrix1D(dimension);
        for (int i = 0; i < dimension; ++i) {
            standardDeviations.set(i, Math.sqrt(covariances.getValue(i, i)));
        }
        return standardDeviations;
    }
}

