/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.regression;

import edu.cmu.tetrad.data.CorrelationMatrix;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.Regression;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.ProbUtils;
import edu.cmu.tetrad.util.Vector;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.util.FastMath;

/**
 * Linear regression computed from a covariance matrix instead of raw data.
 *
 * <p>The covariance matrix supplied at construction is converted to a correlation matrix, and
 * the square roots of its diagonal entries are kept as per-variable standard deviations.
 * {@link #regress(Node, List)} solves for standardized coefficients on the correlation scale
 * and then rescales them by the ratio of standard deviations.
 *
 * <p>Instances are mutable: {@link #setAlpha(double)} changes the significance level and every
 * call to {@code regress} replaces the graph returned by {@link #getGraph()}.
 */
public class RegressionCovariance implements Regression {
    /** Correlation matrix derived from the covariance matrix given to the constructor. */
    private final CorrelationMatrix correlations;

    /** Standard deviation of each variable, indexed like the correlation matrix. */
    private final Vector sd;

    /** Mean of each variable, indexed like the correlation matrix; may be {@code null}. */
    private final Vector means;

    /** Significance level used when building the result graph and passed to the result. */
    private double alpha = 0.05;

    /** Graph produced by the most recent call to {@code regress}; {@code null} before that. */
    private Graph graph;

    /**
     * Creates a regression over the given covariance matrix, using the vector returned by
     * {@code zeroMeans} as the variable means.
     *
     * @param covariances the covariance matrix to regress over
     */
    public RegressionCovariance(ICovarianceMatrix covariances) {
        this(covariances, RegressionCovariance.zeroMeans(covariances.getDimension()));
    }

    /**
     * Creates a regression over the given covariance matrix with explicit variable means.
     *
     * @param covariances the covariance matrix to regress over
     * @param means       per-variable means, or {@code null} if no intercept should be computed
     */
    private RegressionCovariance(ICovarianceMatrix covariances, Vector means) {
        this(new CorrelationMatrix(covariances), RegressionCovariance.standardDeviations(covariances), means);
    }

    /**
     * Creates a regression from an already standardized representation.
     *
     * @param correlations       the correlation matrix; must not be {@code null}
     * @param standardDeviations one standard deviation per variable of {@code correlations}
     * @param means              one mean per variable of {@code correlations}, or {@code null}
     * @throws NullPointerException     if {@code correlations} is {@code null}
     * @throws IllegalArgumentException if a vector is missing or its size does not match the
     *                                  dimension of {@code correlations}
     */
    private RegressionCovariance(CorrelationMatrix correlations, Vector standardDeviations, Vector means) {
        if (correlations == null) {
            throw new NullPointerException("correlations must not be null");
        }
        int dimension = correlations.getDimension();
        if (standardDeviations == null || standardDeviations.size() != dimension) {
            throw new IllegalArgumentException(
                    "standardDeviations must be non-null and have size " + dimension);
        }
        if (means != null && means.size() != dimension) {
            throw new IllegalArgumentException("means must have size " + dimension);
        }
        this.correlations = correlations;
        this.sd = standardDeviations;
        this.means = means;
    }

    /**
     * Returns a freshly constructed vector of the given size to be used as the means.
     *
     * @param numVars the number of variables
     * @return a new vector of size {@code numVars}
     */
    private static Vector zeroMeans(int numVars) {
        // NOTE(review): presumably new Vector(n) is zero-filled; confirm against Vector.
        return new Vector(numVars);
    }

    /**
     * Extracts the standard deviations from the diagonal of a covariance matrix.
     *
     * @param covariances the covariance matrix
     * @return a vector whose i-th entry is the square root of {@code covariances.getValue(i, i)}
     */
    private static Vector standardDeviations(ICovarianceMatrix covariances) {
        int dimension = covariances.getDimension();
        Vector standardDeviations = new Vector(dimension);
        for (int i = 0; i < dimension; ++i) {
            standardDeviations.set(i, FastMath.sqrt(covariances.getValue(i, i)));
        }
        return standardDeviations;
    }

    /**
     * Sets the significance level used by subsequent calls to {@code regress}.
     *
     * @param alpha the significance level
     */
    @Override
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /**
     * Returns the graph built by the most recent call to {@code regress}.
     *
     * @return the graph, or {@code null} if {@code regress} has not been called yet
     */
    @Override
    public Graph getGraph() {
        return this.graph;
    }

    /**
     * Regresses {@code target} on {@code regressors} using only the stored correlations,
     * standard deviations and means, and stores the resulting graph for {@link #getGraph()}.
     *
     * <p>Index 0 of the coefficient, standard error, t and p arrays refers to the intercept;
     * its standard error, t and p are always {@code NaN}, and the intercept itself is
     * {@code NaN} when no means are available.
     *
     * @param target     the response variable; must be a variable of the covariance matrix
     * @param regressors the predictor variables; each must be a variable of the covariance matrix
     * @return the regression result
     * @throws IllegalArgumentException if {@code target} is not a variable of the matrix
     * @throws NullPointerException     if a regressor is not a variable of the matrix
     * @throws RuntimeException         if a correlation sub-matrix is singular; the original
     *                                  {@link SingularMatrixException} is attached as the cause
     */
    @Override
    public RegressionResult regress(Node target, List<Node> regressors) {
        try {
            Matrix allCorrelations = this.correlations.getMatrix();
            List<Node> variables = this.correlations.getVariables();

            int yIndex = variables.indexOf(target);
            if (yIndex == -1) {
                // Fail with a clear message instead of passing -1 on to the matrix selection.
                throw new IllegalArgumentException(
                        "Can't find target variable " + target + " in this list: " + variables);
            }

            int[] xIndices = new int[regressors.size()];
            for (int i = 0; i < regressors.size(); ++i) {
                xIndices[i] = variables.indexOf(regressors.get(i));
                if (xIndices[i] == -1) {
                    // Exception type kept as-is for callers that already catch it.
                    throw new NullPointerException(
                            "Can't find variable " + regressors.get(i) + " in this list: " + variables);
                }
            }

            // Standardized coefficients: inverse(Rxx) * Rxy. The inverse is reused further
            // down for the per-regressor standard errors.
            Matrix rX = allCorrelations.getSelection(xIndices, xIndices);
            Matrix rY = allCorrelations.getSelection(xIndices, new int[]{yIndex});
            Matrix rxInv = rX.inverse();
            Matrix bStar = rxInv.times(rY);

            // Rescale each standardized coefficient by sd(y) / sd(x); slot 0 is the intercept.
            double sdY = this.sd.get(yIndex);
            Vector b = new Vector(bStar.getNumRows() + 1);
            for (int j = 1; j < b.size(); ++j) {
                double sdX = this.sd.get(xIndices[j - 1]);
                b.set(j, bStar.get(j - 1, 0) * (sdY / sdX));
            }

            b.set(0, Double.NaN);
            if (this.means != null) {
                // Intercept: mean(y) minus the sum of coefficient * mean(x).
                double b0 = this.means.get(yIndex);
                for (int i = 0; i < xIndices.length; ++i) {
                    b0 -= b.get(i + 1) * this.means.get(xIndices[i]);
                }
                b.set(0, b0);
            }

            // Target first, then the regressors in the order given.
            int[] allIndices = new int[1 + xIndices.length];
            allIndices[0] = yIndex;
            System.arraycopy(xIndices, 0, allIndices, 1, xIndices.length);

            Matrix r = allCorrelations.getSelection(allIndices, allIndices);
            Matrix rInv = r.inverse();

            int n = this.correlations.getSampleSize();
            int k = regressors.size() + 1;

            // R^2 of the target on the regressors, read off the inverse correlation matrix.
            double vY = rInv.get(0, 0);
            double r2 = 1.0 - 1.0 / vY;
            double tss = (double) n * sdY * sdY;
            double rss = tss * (1.0 - r2);
            double seY = FastMath.sqrt(rss / (double) (n - k));

            Vector sqErr = new Vector(allIndices.length);
            Vector t = new Vector(allIndices.length);
            Vector p = new Vector(allIndices.length);
            sqErr.set(0, Double.NaN);
            t.set(0, Double.NaN);
            p.set(0, Double.NaN);

            for (int i = 0; i < regressors.size(); ++i) {
                // R^2 of regressor i on the remaining regressors.
                double _r2 = 1.0 - 1.0 / rxInv.get(i, i);
                double sdXi = this.sd.get(xIndices[i]);
                double _tss = (double) n * sdXi * sdXi;
                double _se = seY / FastMath.sqrt(_tss * (1.0 - _r2));
                double _t = b.get(i + 1) / _se;
                // NOTE(review): this is the upper-tail probability of |t| (one-sided); a
                // two-sided p-value would be twice this. Confirm which is intended.
                double _p = 1.0 - ProbUtils.tCdf(FastMath.abs(_t), n - k);
                sqErr.set(i + 1, _se);
                t.set(i + 1, _t);
                p.set(i + 1, _p);
            }

            this.graph = this.createGraph(target, allIndices, regressors, p);

            String[] vNames = this.createVarNamesArray(regressors);
            return new RegressionResult(false, vNames, n, b.toArray(), t.toArray(), p.toArray(),
                    sqErr.toArray(), r2, rss, this.alpha, null);
        }
        catch (SingularMatrixException e) {
            // Keep the original exception as the cause so the failing inversion stays traceable.
            throw new RuntimeException(
                    "Singularity encountered when regressing " + LogUtilsSearch.getScoreFact(target, regressors), e);
        }
    }

    /**
     * Varargs convenience overload of {@link #regress(Node, List)}.
     *
     * @param target     the response variable
     * @param regressors the predictor variables
     * @return the regression result
     */
    @Override
    public RegressionResult regress(Node target, Node ... regressors) {
        return this.regress(target, Arrays.asList(regressors));
    }

    /**
     * Collects the names of the regressors, in order.
     *
     * @param regressors the predictor variables
     * @return an array holding {@code getName()} of each regressor
     */
    private String[] createVarNamesArray(List<Node> regressors) {
        String[] vNames = new String[regressors.size()];
        for (int i = 0; i < vNames.length; ++i) {
            vNames[i] = regressors.get(i).getName();
        }
        return vNames;
    }

    /**
     * Builds a graph containing the target and, for each entry whose p-value is below
     * {@code alpha}, a new node named after that predictor with a directed edge into the target.
     *
     * @param target     the response variable
     * @param allIndices index array whose length is the number of coefficients (intercept
     *                   slot plus one per regressor); only its length is used
     * @param regressors the predictor variables
     * @param p          p-values, index 0 being the intercept slot
     * @return the newly built graph
     */
    private Graph createGraph(Node target, int[] allIndices, List<Node> regressors, Vector p) {
        EdgeListGraph graph = new EdgeListGraph();
        graph.addNode(target);
        for (int i = 0; i < allIndices.length; ++i) {
            // A NaN p-value (as set for the intercept slot by regress) fails this comparison,
            // so such entries never get a node.
            if (!(p.get(i) < this.alpha)) {
                continue;
            }
            String variableName = i > 0 ? regressors.get(i - 1).getName() : "const";
            GraphNode predictorNode = new GraphNode(variableName);
            graph.addNode(predictorNode);
            graph.addEdge(new Edge(predictorNode, target, Endpoint.TAIL, Endpoint.ARROW));
        }
        return graph;
    }
}

