/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.regression;

import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.regression.LogisticRegressionResult;
import edu.cmu.tetrad.util.NumberFormatUtil;
import java.io.PrintStream;
import java.text.DecimalFormat;
import java.text.NumberFormat;

/**
 * Logistic regression of a binary (0 / non-zero) target on a set of numeric regressors.
 *
 * <p>Typical use: call {@link #setRegressors(double[][])}, then
 * {@link #setVariableNames(String[])}, then {@link #regress(int[], String)}. The fit is obtained by
 * Newton's method on standardized regressors; each step solves the normal equations with an
 * in-place Gauss-Jordan sweep (no pivoting). After a fit, the coefficients, z-scores, p-values, a
 * result object and a graph containing an edge from every significant regressor to the target are
 * available through the getters.
 *
 * <p>Instances hold mutable state and the getters expose the internal arrays directly; nothing in
 * this class synchronizes access.
 */
public class LogisticRegression {
    /** Iteration stops when -2 log-likelihood changes by no more than this between two passes. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-7;

    /** Beyond +/- this linear-predictor value the logistic terms are replaced by approximations. */
    private static final double LINEAR_PREDICTOR_CUTOFF = 15.0;

    private final NumberFormat nf;
    private double alpha = 0.05;
    private double[][] regressors;
    private int sampleSize;
    private String[] variableNames;
    private Graph outGraph;
    private LogisticRegressionResult result;
    private double[] coefficients;
    private double[] pValues;
    private double[] zScores;
    private PrintStream out = System.out;
    private PrintStream err = System.err;

    /**
     * Creates a regression with significance level 0.05.
     *
     * <p>Note: this installs an eight-decimal {@link DecimalFormat} on the
     * {@code NumberFormatUtil} instance and then reads the format back from it, so constructing
     * this class also changes the format held by that utility (presumably shared application-wide;
     * confirm against {@code NumberFormatUtil}).
     */
    public LogisticRegression() {
        NumberFormatUtil.getInstance().setNumberFormat(new DecimalFormat("0.00000000"));
        this.nf = NumberFormatUtil.getInstance().getNumberFormat();
    }

    /**
     * @return the regressor data exactly as passed to {@link #setRegressors(double[][])} (not a
     *     copy), or {@code null} if none has been set.
     */
    public double[][] getRegressors() {
        return this.regressors;
    }

    /**
     * Sets the regressor data, one array per regressor, each holding one value per case. Any
     * previously set variable names are cleared. The array is stored by reference.
     *
     * @param regressors regressor columns; must be non-empty, with non-null columns of equal length
     * @throws NullPointerException if {@code regressors} or any column is null
     * @throws IllegalArgumentException if there are no columns or the columns differ in length
     */
    public void setRegressors(double[][] regressors) {
        if (regressors == null) {
            throw new NullPointerException("Regressor data must not be null.");
        }
        if (regressors.length == 0) {
            throw new IllegalArgumentException("Regressor data must contain at least one regressor column.");
        }
        if (regressors[0] == null) {
            throw new NullPointerException("All regressor columns must be non-null.");
        }
        int columnLength = regressors[0].length;
        for (double[] regressor : regressors) {
            if (regressor == null) {
                throw new NullPointerException("All regressor columns must be non-null.");
            }
            if (regressor.length != columnLength) {
                throw new IllegalArgumentException("Regressor data must all be the same length.");
            }
        }
        // Commit state only after validation succeeded, so a rejected argument leaves this object untouched.
        this.sampleSize = columnLength;
        this.regressors = regressors;
        this.variableNames = null;
    }

    /**
     * @return the variable names last set (not a copy), or {@code null} if none are set.
     */
    public String[] getVariableNames() {
        return this.variableNames;
    }

    /**
     * Sets the names of the regressors, in the same order as the regressor columns.
     *
     * @param variableNames one name per regressor column
     * @throws NullPointerException if {@code variableNames} is null
     * @throws IllegalArgumentException if no regressors are set yet or the counts differ
     */
    public void setVariableNames(String[] variableNames) {
        if (variableNames == null) {
            throw new NullPointerException("The variable names array must not be null.");
        }
        if (this.regressors == null) {
            throw new IllegalArgumentException("Please set the regressor data before setting the variable names; otherwise, I don't know whether you have the correct number of variable names.");
        }
        if (variableNames.length != this.regressors.length) {
            throw new IllegalArgumentException("The number of variable names must match the number of regressors: " + variableNames.length + " != " + this.regressors.length);
        }
        this.variableNames = variableNames;
    }

    /**
     * Fits the logistic model of {@code target} on the current regressors and returns a text report.
     *
     * <p>A target value of 0 is treated as class 0; every other value is treated as class 1. The
     * regressor data set on this object is not modified: standardization is done on private copies.
     * As side effects, the coefficients (index 0 is the intercept), z-scores, p-values, result
     * object and output graph of this instance are replaced.
     *
     * <p>If all cases fall in a single class, the starting intercept {@code log(ny1 / ny0)} is
     * infinite; no special handling is applied for that case or for a regressor with zero spread.
     *
     * @param target one class label per case; its length must equal the regressor sample size
     * @param targetName name of the target variable, used in the report and for the graph node
     * @return the report text (case counts, per-variable mean/SD, overall fit, coefficient table)
     * @throws NullPointerException if {@code target} is null, or regressors or variable names have
     *     not been set
     * @throws IllegalArgumentException if {@code target.length} differs from the sample size
     */
    public String regress(int[] target, String targetName) {
        if (this.regressors == null) {
            throw new NullPointerException("Regressor data must be set before calling regress.");
        }
        if (this.variableNames == null) {
            throw new NullPointerException("Variable names must be set before calling regress.");
        }
        if (target.length != this.sampleSize) {
            throw new IllegalArgumentException("Target sample size must match regressor sample size.");
        }

        int numRegressors = this.regressors.length;
        int numCases = this.regressors[0].length;

        // Column-wise design matrix: x[0] is the intercept column of ones, x[j] (j >= 1) is a private
        // COPY of regressor j - 1. The copy matters because the columns are standardized in place
        // below; sharing the arrays would overwrite the caller's data and corrupt any later fit.
        double[][] x = new double[numRegressors + 1][];
        x[0] = new double[numCases];
        for (int i = 0; i < numCases; ++i) {
            x[0][i] = 1.0;
        }
        for (int j = 1; j <= numRegressors; ++j) {
            x[j] = this.regressors[j - 1].clone();
        }

        // y0 / y1 are 0-1 indicators of class membership; xMeans / xStdDevs first accumulate
        // sum(x) and sum(x^2) and are turned into mean and standard deviation afterwards.
        double[] xMeans = new double[numRegressors + 1];
        double[] xStdDevs = new double[numRegressors + 1];
        double[] y0 = new double[numCases];
        double[] y1 = new double[numCases];
        int ny0 = 0;
        int ny1 = 0;
        for (int i = 0; i < numCases; ++i) {
            if (target[i] == 0) {
                y0[i] = 1.0;
                ++ny0;
            } else {
                y1[i] = 1.0;
                ++ny1;
            }
            for (int j = 1; j <= numRegressors; ++j) {
                xMeans[j] = xMeans[j] + x[j][i];
                xStdDevs[j] = xStdDevs[j] + x[j][i] * x[j][i];
            }
        }

        StringBuilder report = new StringBuilder();
        report.append(ny0).append(" cases have ").append(targetName).append(" = 0; ")
                .append(ny1).append(" cases have ").append(targetName).append(" = 1.\n");
        report.append("\tVariable\tAvg\tSD\n");
        for (int j = 1; j <= numRegressors; ++j) {
            xMeans[j] = xMeans[j] / (double) numCases;
            xStdDevs[j] = xStdDevs[j] / (double) numCases;
            // abs() guards against a tiny negative variance caused by rounding.
            xStdDevs[j] = Math.sqrt(Math.abs(xStdDevs[j] - xMeans[j] * xMeans[j]));
            report.append("\t").append(this.variableNames[j - 1])
                    .append("\t").append(this.nf.format(xMeans[j]))
                    .append("\t").append(this.nf.format(xStdDevs[j])).append("\n");
        }
        xMeans[0] = 0.0;
        xStdDevs[0] = 1.0;

        // Standardize the copied regressor columns (the intercept column is left as ones).
        for (int i = 0; i < numCases; ++i) {
            for (int j = 1; j <= numRegressors; ++j) {
                x[j][i] = (x[j][i] - xMeans[j]) / xStdDevs[j];
            }
        }

        // Start from the intercept-only model: intercept = log odds of class 1, slopes = 0.
        double[] par = new double[numRegressors + 1];
        double[] parStdErr = new double[numRegressors + 1];
        this.coefficients = new double[numRegressors + 1];
        par[0] = Math.log((double) ny1 / (double) ny0);

        // arr is the augmented system [information matrix | score vector]; after the Gauss-Jordan
        // sweep it holds [inverse | Newton step].
        double[][] arr = new double[numRegressors + 1][numRegressors + 2];
        double previousDeviance = 2.0E10;
        double deviance = 1.0E10;
        double nullDeviance = 0.0;
        boolean firstIteration = true;
        while (Math.abs(previousDeviance - deviance) > CONVERGENCE_TOLERANCE) {
            previousDeviance = deviance;
            deviance = 0.0;
            for (int j = 0; j <= numRegressors; ++j) {
                for (int k = j; k <= numRegressors + 1; ++k) {
                    arr[j][k] = 0.0;
                }
            }

            for (int i = 0; i < numCases; ++i) {
                double v = par[0];
                for (int j = 1; j <= numRegressors; ++j) {
                    v += par[j] * x[j][i];
                }
                // v becomes the fitted probability, lnV / ln1mV its log and the log of its
                // complement, q = v * (1 - v). The outer branches use approximations for large |v|.
                double lnV;
                double ln1mV;
                double q;
                if (v > LINEAR_PREDICTOR_CUTOFF) {
                    lnV = -Math.exp(-v);
                    ln1mV = -v;
                    q = Math.exp(-v);
                    v = Math.exp(lnV);
                } else if (v < -LINEAR_PREDICTOR_CUTOFF) {
                    lnV = v;
                    ln1mV = -Math.exp(v);
                    q = Math.exp(v);
                    v = Math.exp(lnV);
                } else {
                    v = 1.0 / (1.0 + Math.exp(-v));
                    lnV = Math.log(v);
                    ln1mV = Math.log(1.0 - v);
                    q = v * (1.0 - v);
                }
                deviance = deviance - 2.0 * y1[i] * lnV - 2.0 * y0[i] * ln1mV;
                for (int j = 0; j <= numRegressors; ++j) {
                    double xij = x[j][i];
                    // Last column: score (gradient) contribution of this case.
                    arr[j][numRegressors + 1] = arr[j][numRegressors + 1] + xij * (y1[i] * (1.0 - v) + y0[i] * -v);
                    // Upper triangle: information-matrix contribution of this case.
                    for (int k = j; k <= numRegressors; ++k) {
                        arr[j][k] = arr[j][k] + xij * x[k][i] * q * (y0[i] + y1[i]);
                    }
                }
            }

            // The first pass evaluates the intercept-only starting point; keep its deviance as the
            // baseline for the overall chi-square.
            if (firstIteration) {
                nullDeviance = deviance;
                firstIteration = false;
            }

            // Mirror the upper triangle into the lower triangle.
            for (int j = 1; j <= numRegressors; ++j) {
                for (int k = 0; k < j; ++k) {
                    arr[j][k] = arr[k][j];
                }
            }

            // In-place Gauss-Jordan sweep over every pivot (no row exchanges).
            for (int i = 0; i <= numRegressors; ++i) {
                double s = arr[i][i];
                arr[i][i] = 1.0;
                for (int k = 0; k <= numRegressors + 1; ++k) {
                    arr[i][k] = arr[i][k] / s;
                }
                for (int j = 0; j <= numRegressors; ++j) {
                    if (i == j) {
                        continue;
                    }
                    s = arr[j][i];
                    arr[j][i] = 0.0;
                    for (int k = 0; k <= numRegressors + 1; ++k) {
                        arr[j][k] = arr[j][k] - s * arr[i][k];
                    }
                }
            }

            // Apply the Newton step.
            for (int j = 0; j <= numRegressors; ++j) {
                par[j] = par[j] + arr[j][numRegressors + 1];
            }
        }

        EdgeListGraph outgraph = new EdgeListGraph();
        GraphNode targNode = new GraphNode(targetName);
        outgraph.addNode(targNode);

        double chiSq = nullDeviance - deviance;
        report.append("Overall Model Fit...\n");
        report.append("  Chi Square = ").append(this.nf.format(chiSq))
                .append("; df = ").append(numRegressors)
                .append("; ").append("p = ").append(this.nf.format(this.chiSquare(chiSq, numRegressors)))
                .append("\n");
        report.append("\nCoefficients and Standard Errors...\n");
        report.append(" Variable\tCoeff.\tStdErr\tprob.\tsig.\n");

        this.pValues = new double[numRegressors + 1];
        this.zScores = new double[numRegressors + 1];
        for (int j = 1; j <= numRegressors; ++j) {
            // Convert slope and standard error back from standardized to original units and fold
            // the mean shift into the intercept.
            par[j] = par[j] / xStdDevs[j];
            parStdErr[j] = Math.sqrt(arr[j][j]) / xStdDevs[j];
            par[0] = par[0] - par[j] * xMeans[j];

            double zScore = par[j] / parStdErr[j];
            double prob = this.norm(Math.abs(zScore));
            this.pValues[j] = prob;
            this.zScores[j] = zScore;

            String sigMarker;
            if (prob < this.alpha) {
                // Significant regressors get a directed edge regressor --> target in the graph.
                sigMarker = "*";
                GraphNode predNode = new GraphNode(this.variableNames[j - 1]);
                outgraph.addNode(predNode);
                Edge newEdge = new Edge(predNode, targNode, Endpoint.TAIL, Endpoint.ARROW);
                outgraph.addEdge(newEdge);
            } else {
                sigMarker = "";
            }
            report.append(this.variableNames[j - 1])
                    .append("\t").append(this.nf.format(par[j]))
                    .append("\t").append(this.nf.format(parStdErr[j]))
                    .append("\t").append(this.nf.format(prob))
                    .append("\t").append(sigMarker).append("\n");
        }

        // NOTE(review): the intercept's standard error is taken from the standardized fit while
        // par[0] has already been shifted to original units -- confirm this is intended.
        parStdErr[0] = Math.sqrt(arr[0][0]);
        double interceptZ = par[0] / parStdErr[0];
        this.pValues[0] = this.norm(Math.abs(interceptZ));
        this.zScores[0] = interceptZ;

        double intercept = par[0];
        report.append("\nIntercept = ").append(this.nf.format(intercept)).append("\n");

        String reportText = report.toString();
        this.setOutGraph(outgraph);
        this.setCoefficients(par);
        this.result = new LogisticRegressionResult(targetName, this.variableNames, xMeans, xStdDevs, numRegressors, ny0, ny1, this.coefficients, parStdErr, this.pValues, intercept, reportText);
        return reportText;
    }

    /**
     * Approximates the upper-tail probability of a chi-square statistic; the value is reported as
     * the p-value of the overall model fit and, via {@link #norm(double)}, of each z-score.
     *
     * <p>For {@code x > 1000} or {@code n > 1000} a cube-root normal approximation is used;
     * otherwise a series is summed until its terms fall below 1e-15 of the running total.
     *
     * @param x the chi-square statistic
     * @param n degrees of freedom
     * @return the approximated tail probability
     */
    private double chiSquare(double x, int n) {
        if (x > 1000.0 || n > 1000) {
            double q = this.norm((Math.pow(x / (double) n, 0.3333333333333333) + 2.0 / (9.0 * (double) n) - 1.0) / Math.sqrt(2.0 / (9.0 * (double) n))) / 2.0;
            return x > (double) n ? q : 1.0 - q;
        }
        double p = Math.exp(-0.5 * x);
        if (n % 2 == 1) {
            p *= Math.sqrt(2.0 * x / Math.PI);
        }
        for (int k = n; k >= 2; k -= 2) {
            p = p * x / (double) k;
        }
        double t = p;
        int a = n;
        while (t > p * 1.0E-15) {
            a += 2;
            t = t * x / (double) a;
            p += t;
        }
        return 1.0 - p;
    }

    /**
     * Tail probability associated with a z-score: {@code chiSquare(z * z, 1)} when {@code |z| <= 7},
     * otherwise a closed-form asymptotic expression in {@code z}.
     *
     * @param z the z-score
     * @return the tail probability used as the coefficient's p-value
     */
    private double norm(double z) {
        double q = z * z;
        double piOver2 = 1.5707963267948966;
        if (Math.abs(z) > 7.0) {
            return (1.0 - 1.0 / q + 3.0 / (q * q)) * Math.exp(-q / 2.0) / (Math.abs(z) * Math.sqrt(piOver2));
        }
        return this.chiSquare(q, 1);
    }

    /** Copies {@code c} into the already-allocated coefficient array of this instance. */
    private void setCoefficients(double[] c) {
        System.arraycopy(c, 0, this.coefficients, 0, c.length);
    }

    /**
     * @return the coefficients of the last fit (index 0 is the intercept), not a copy; {@code null}
     *     before the first call to {@link #regress(int[], String)}.
     */
    public double[] getCoefficients() {
        return this.coefficients;
    }

    private void setOutGraph(Graph g) {
        this.outGraph = g;
    }

    /**
     * @return the result object built by the last fit, or {@code null} if none has run.
     */
    public LogisticRegressionResult getResult() {
        return this.result;
    }

    /**
     * @return the graph built by the last fit (target node plus one node and directed edge per
     *     significant regressor), or {@code null} if none has run.
     */
    public Graph getOutGraph() {
        return this.outGraph;
    }

    /**
     * @return the significance level below which a regressor is marked significant.
     */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * @param alpha the significance level to compare coefficient p-values against
     */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /**
     * @return the number of cases per regressor column, as recorded by
     *     {@link #setRegressors(double[][])}; 0 if no regressors have been set.
     */
    public int getSampleSize() {
        return this.sampleSize;
    }

    /**
     * @return the configured output stream ({@code System.out} by default).
     */
    public PrintStream getOut() {
        return this.out;
    }

    /**
     * @param out the output stream to store on this instance
     */
    public void setOut(PrintStream out) {
        this.out = out;
    }

    /**
     * @return the configured error stream ({@code System.err} by default).
     */
    public PrintStream getErr() {
        return this.err;
    }

    /**
     * @return the p-values of the last fit (index 0 is the intercept), not a copy; {@code null}
     *     before the first fit.
     */
    public double[] getpValues() {
        return this.pValues;
    }

    /**
     * @return the z-scores of the last fit (index 0 is the intercept), not a copy; {@code null}
     *     before the first fit.
     */
    public double[] getZScores() {
        return this.zScores;
    }
}

