/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.regression;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.TetradSerializable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.apache.commons.math3.util.FastMath;

/**
 * Binary logistic regression, fitted by Newton-Raphson iteration, on the columns of a Tetrad
 * {@link DataSet}.
 *
 * <p>The target must be a binary {@link DiscreteVariable}; regressors must be continuous or
 * binary discrete. Regressors are standardized internally before fitting, and the estimates are
 * transformed back to the original scale before they are reported in a {@link Result}.
 */
public class LogisticRegression implements TetradSerializable {
    static final long serialVersionUID = 23L;

    /** Convergence tolerance on the change in -2 log-likelihood between Newton iterations. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-7;

    /** Upper bound on Newton iterations, so a non-convergent fit cannot loop forever. */
    private static final int MAX_ITERATIONS = 500;

    /** |linear predictor| beyond which first-order tail approximations replace exact logistic terms. */
    private static final double LINEAR_PREDICTOR_CUTOFF = 15.0;

    /** |z| beyond which {@link #norm(double)} switches to the asymptotic normal-tail series. */
    private static final double ASYMPTOTIC_Z = 7.0;

    /** The data set that regressions are run against. */
    private final DataSet dataSet;

    /** Significance level used to flag coefficients in the {@link Result} report. */
    private double alpha = 0.05;

    /** The data set's values by column: {@code dataCols[column][row]}. */
    private final double[][] dataCols;

    /** Indices of the data set rows used in each regression (all rows by default). */
    private int[] rows;

    /**
     * Creates a regression object over all rows of the given data set.
     *
     * @param dataSet the data to regress on; its values are copied column-wise at construction.
     */
    public LogisticRegression(DataSet dataSet) {
        this.dataSet = dataSet;
        this.dataCols = dataSet.getDoubleData().transpose().toArray();
        int[] allRows = new int[dataSet.getNumRows()];
        for (int i = 0; i < allRows.length; ++i) {
            allRows[i] = i;
        }
        // Assign directly: calling the overridable setRows(...) from a constructor is unsafe.
        this.rows = allRows;
    }

    /**
     * Generates a simple exemplar of this class to test serialization.
     *
     * @return an instance built over {@link BoxDataSet#serializableInstance()}.
     */
    public static LogisticRegression serializableInstance() {
        return new LogisticRegression(BoxDataSet.serializableInstance());
    }

    /**
     * Regresses the binary target {@code x} on {@code regressors}, using the currently selected rows.
     *
     * <p>Target values are read with {@code DataSet.getInt}: a value of 0 is one outcome class,
     * and any other value is treated as the other class.
     *
     * @param x          the binary target variable.
     * @param regressors the continuous or binary regressors.
     * @return the fitted model summary.
     * @throws IllegalArgumentException if the target is not binary, a regressor is neither
     *                                  continuous nor binary, or a variable is not in the data set.
     */
    public Result regress(DiscreteVariable x, List<Node> regressors) {
        if (!this.binary(x)) {
            throw new IllegalArgumentException("Target must be binary.");
        }
        for (Node node : regressors) {
            if (!(node instanceof ContinuousVariable) && !this.binary(node)) {
                throw new IllegalArgumentException("Regressors must be continuous or binary.");
            }
        }

        int[] rowIndices = this.getRows();

        double[][] regressorData = new double[regressors.size()][rowIndices.length];
        for (int j = 0; j < regressors.size(); ++j) {
            Node regressor = regressors.get(j);
            int col = this.dataSet.getColumn(regressor);
            if (col < 0) {
                throw new IllegalArgumentException("Regressor is not in the data set: " + regressor.getName());
            }
            double[] dataCol = this.dataCols[col];
            for (int i = 0; i < rowIndices.length; ++i) {
                regressorData[j][i] = dataCol[rowIndices[i]];
            }
        }

        int targetCol = this.dataSet.getColumn(this.dataSet.getVariable(x.getName()));
        if (targetCol < 0) {
            throw new IllegalArgumentException("Target is not in the data set: " + x.getName());
        }
        int[] target = new int[rowIndices.length];
        for (int i = 0; i < rowIndices.length; ++i) {
            target[i] = this.dataSet.getInt(rowIndices[i], targetCol);
        }

        List<String> regressorNames = new ArrayList<String>();
        for (Node node : regressors) {
            regressorNames.add(node.getName());
        }
        return this.regress(target, x.getName(), regressorData, regressorNames);
    }

    /** Returns true iff {@code x} is a discrete variable with exactly two categories. */
    private boolean binary(Node x) {
        return x instanceof DiscreteVariable && ((DiscreteVariable) x).getNumCategories() == 2;
    }

    /**
     * Fits the logistic model by Newton-Raphson on standardized regressors and reports the estimates
     * on the original scale.
     *
     * <p>NOTE: the arrays in {@code regressors} are standardized in place. The public caller passes
     * freshly built arrays, so this is not observable from outside the class.
     *
     * @param target         outcome per case; 0 is one class, and any other value is the other class.
     * @param targetName     name of the target, for the report.
     * @param regressors     regressor values, {@code regressors[j][case]}.
     * @param regressorNames names of the regressors, parallel to {@code regressors}.
     * @return the fitted model summary.
     */
    private Result regress(int[] target, String targetName, double[][] regressors, List<String> regressorNames) {
        int numRegressors = regressors.length;
        int numCases = target.length;

        // Design matrix by column: x[0] is the constant intercept column, x[1..p] are the regressors.
        double[][] x = new double[numRegressors + 1][];
        x[0] = new double[numCases];
        for (int i = 0; i < numCases; ++i) {
            x[0][i] = 1.0;
        }
        System.arraycopy(regressors, 0, x, 1, numRegressors);

        // Indicator coding of the outcome: exactly one of y0[i], y1[i] is 1 for every case.
        double[] y0 = new double[numCases];
        double[] y1 = new double[numCases];
        int ny0 = 0;
        int ny1 = 0;
        for (int i = 0; i < numCases; ++i) {
            if (target[i] == 0) {
                y0[i] = 1.0;
                ++ny0;
            } else {
                y1[i] = 1.0;
                ++ny1;
            }
        }

        // Standardize each regressor with its mean and population standard deviation.
        // Index 0 (the intercept column) keeps the identity transform (mean 0, sd 1).
        double[] xMeans = new double[numRegressors + 1];
        double[] xStdDevs = new double[numRegressors + 1];
        xMeans[0] = 0.0;
        xStdDevs[0] = 1.0;
        for (int j = 1; j <= numRegressors; ++j) {
            double sum = 0.0;
            double sumSq = 0.0;
            for (int i = 0; i < numCases; ++i) {
                sum += x[j][i];
                sumSq += x[j][i] * x[j][i];
            }
            xMeans[j] = sum / numCases;
            double meanSq = sumSq / numCases;
            xStdDevs[j] = FastMath.sqrt(FastMath.abs(meanSq - xMeans[j] * xMeans[j]));
            for (int i = 0; i < numCases; ++i) {
                x[j][i] = (x[j][i] - xMeans[j]) / xStdDevs[j];
            }
        }

        // Start from the intercept-only (null) model: logit of the observed outcome rate.
        double[] par = new double[numRegressors + 1];
        double[] parStdErr = new double[numRegressors + 1];
        par[0] = FastMath.log((double) ny1 / (double) ny0);

        // arr holds the information matrix with the score vector appended as its last column.
        // After Gauss-Jordan elimination, the square part is the inverse information matrix
        // (the covariance of the standardized estimates) and the last column is the Newton step.
        int scoreCol = numRegressors + 1;
        double[][] arr = new double[numRegressors + 1][numRegressors + 2];

        // ll is -2 log-likelihood. The sentinels guarantee that the first iteration runs.
        double llPrev = 2.0E10;
        double ll = 1.0E10;
        double llNull = 0.0;
        int iteration = 0;

        while (FastMath.abs(llPrev - ll) > CONVERGENCE_TOLERANCE && iteration < MAX_ITERATIONS) {
            llPrev = ll;
            ll = 0.0;

            for (int j = 0; j <= numRegressors; ++j) {
                for (int k = j; k <= scoreCol; ++k) {
                    arr[j][k] = 0.0;
                }
            }

            for (int i = 0; i < numCases; ++i) {
                double eta = par[0];
                for (int j = 1; j <= numRegressors; ++j) {
                    eta += par[j] * x[j][i];
                }

                // p = fitted probability, lnP = log p, ln1mP = log(1 - p), w = p(1 - p).
                double p;
                double lnP;
                double ln1mP;
                double w;
                if (eta > LINEAR_PREDICTOR_CUTOFF) {
                    // p ~ 1: first-order expansions avoid log(1 - p) losing all precision.
                    lnP = -FastMath.exp(-eta);
                    ln1mP = -eta;
                    w = FastMath.exp(-eta);
                    p = FastMath.exp(lnP);
                } else if (eta < -LINEAR_PREDICTOR_CUTOFF) {
                    // p ~ 0: symmetric expansions.
                    lnP = eta;
                    ln1mP = -FastMath.exp(eta);
                    w = FastMath.exp(eta);
                    p = FastMath.exp(lnP);
                } else {
                    p = 1.0 / (1.0 + FastMath.exp(-eta));
                    lnP = FastMath.log(p);
                    ln1mP = FastMath.log(1.0 - p);
                    w = p * (1.0 - p);
                }

                ll = ll - 2.0 * y1[i] * lnP - 2.0 * y0[i] * ln1mP;

                // Accumulate the score (x_ij * (y_i - p_i)) and the upper triangle of the information matrix.
                for (int j = 0; j <= numRegressors; ++j) {
                    double xij = x[j][i];
                    arr[j][scoreCol] += xij * (y1[i] * (1.0 - p) + y0[i] * -p);
                    for (int k = j; k <= numRegressors; ++k) {
                        arr[j][k] += xij * x[k][i] * w;
                    }
                }
            }

            if (iteration == 0) {
                // The first pass evaluates the starting (intercept-only) model.
                llNull = ll;
            }

            // Mirror the upper triangle into the lower one.
            for (int j = 1; j <= numRegressors; ++j) {
                for (int k = 0; k < j; ++k) {
                    arr[j][k] = arr[k][j];
                }
            }

            // In-place Gauss-Jordan inversion (no pivoting), carrying the score column along.
            for (int pivot = 0; pivot <= numRegressors; ++pivot) {
                double s = arr[pivot][pivot];
                arr[pivot][pivot] = 1.0;
                for (int k = 0; k <= scoreCol; ++k) {
                    arr[pivot][k] /= s;
                }
                for (int row = 0; row <= numRegressors; ++row) {
                    if (row == pivot) {
                        continue;
                    }
                    double factor = arr[row][pivot];
                    arr[row][pivot] = 0.0;
                    for (int k = 0; k <= scoreCol; ++k) {
                        arr[row][k] -= factor * arr[pivot][k];
                    }
                }
            }

            for (int j = 0; j <= numRegressors; ++j) {
                par[j] += arr[j][scoreCol];
            }
            ++iteration;
        }

        double chiSq = llNull - ll;

        // Intercept variance on the original scale, by the delta method:
        // b0 = a0 - sum_j a_j * m_j / s_j, so Var(b0) = g' Cov(a) g with g = (1, -m_1/s_1, ..., -m_p/s_p).
        // With no regressors this reduces to arr[0][0].
        double[] g = new double[numRegressors + 1];
        g[0] = 1.0;
        for (int j = 1; j <= numRegressors; ++j) {
            g[j] = -xMeans[j] / xStdDevs[j];
        }
        double interceptVariance = 0.0;
        for (int j = 0; j <= numRegressors; ++j) {
            for (int k = 0; k <= numRegressors; ++k) {
                interceptVariance += g[j] * arr[j][k] * g[k];
            }
        }

        // Transform the slopes (and their standard errors) back to the original regressor scale.
        double[] pValues = new double[numRegressors + 1];
        for (int j = 1; j <= numRegressors; ++j) {
            par[j] = par[j] / xStdDevs[j];
            parStdErr[j] = FastMath.sqrt(arr[j][j]) / xStdDevs[j];
            par[0] = par[0] - par[j] * xMeans[j];
            pValues[j] = norm(FastMath.abs(par[j] / parStdErr[j]));
        }
        parStdErr[0] = FastMath.sqrt(interceptVariance);
        pValues[0] = norm(par[0] / parStdErr[0]);

        double intercept = par[0];
        return new Result(targetName, regressorNames, xMeans, xStdDevs, numRegressors, ny0, ny1,
                par, parStdErr, pValues, intercept, ll, chiSq, this.alpha);
    }

    /**
     * Returns the two-sided standard-normal tail probability P(|Z| >= |z|), i.e. the p-value of a
     * Wald z statistic.
     *
     * <p>For |z| <= 7 this is the upper tail of a chi-square(1) distribution at z^2. Beyond that,
     * the asymptotic series 2*phi(z)/|z| * (1 - 1/z^2 + 3/z^4) is used, because 1 - CDF would lose
     * relative precision there.
     */
    private static double norm(double z) {
        double q = z * z;
        if (FastMath.abs(z) > ASYMPTOTIC_Z) {
            return (1.0 - 1.0 / q + 3.0 / (q * q)) * FastMath.exp(-q / 2.0)
                    / (FastMath.abs(z) * FastMath.sqrt(Math.PI / 2.0));
        }
        return 1.0 - new ChiSquaredDistribution(1.0).cumulativeProbability(q);
    }

    /** @return the significance level used to flag coefficients in reports. */
    public double getAlpha() {
        return this.alpha;
    }

    /** @param alpha the significance level used to flag coefficients in reports. */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /** @return the indices of the data set rows used in each regression. */
    private int[] getRows() {
        return this.rows;
    }

    /** @param rows the indices of the data set rows to use in subsequent regressions. */
    public void setRows(int[] rows) {
        this.rows = rows;
    }

    /** Restores this object using default Java deserialization only. */
    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
    }

    /**
     * Summary of a fitted logistic regression.
     *
     * <p>In the coefficient, standard-error and p-value arrays, index 0 is the intercept and index
     * {@code i + 1} corresponds to {@code regressorNames.get(i)}. Arrays are stored and returned
     * without copying.
     */
    public static class Result implements TetradSerializable {
        static final long serialVersionUID = 23L;
        private final double chiSq;
        private final double alpha;
        private final List<String> regressorNames;
        private final String target;
        private final int ny0;
        private final int ny1;
        private final int numRegressors;
        private final double[] coefs;
        private final double[] stdErrs;
        private final double[] probs;
        private final double[] xMeans;
        private final double[] xStdDevs;
        private final double intercept;
        private final double logLikelihood;

        /**
         * Creates a result.
         *
         * @throws IllegalArgumentException if any array length disagrees with {@code numRegressors + 1},
         *                                  or the name count disagrees with {@code numRegressors}.
         * @throws NullPointerException     if {@code target} is null.
         */
        public Result(String target, List<String> regressorNames, double[] xMeans, double[] xStdDevs,
                      int numRegressors, int ny0, int ny1, double[] coefs, double[] stdErrs,
                      double[] probs, double intercept, double logLikelihood, double chiSq, double alpha) {
            if (regressorNames.size() != numRegressors) {
                throw new IllegalArgumentException("Expected " + numRegressors + " regressor names, got "
                        + regressorNames.size());
            }
            checkLength("coefs", coefs, numRegressors);
            checkLength("stdErrs", stdErrs, numRegressors);
            checkLength("probs", probs, numRegressors);
            checkLength("xMeans", xMeans, numRegressors);
            checkLength("xStdDevs", xStdDevs, numRegressors);
            if (target == null) {
                throw new NullPointerException("target");
            }
            this.intercept = intercept;
            this.target = target;
            this.xMeans = xMeans;
            this.xStdDevs = xStdDevs;
            this.regressorNames = regressorNames;
            this.numRegressors = numRegressors;
            this.ny0 = ny0;
            this.ny1 = ny1;
            this.coefs = coefs;
            this.stdErrs = stdErrs;
            this.probs = probs;
            this.logLikelihood = logLikelihood;
            this.chiSq = chiSq;
            this.alpha = alpha;
        }

        /** Throws IllegalArgumentException unless {@code values} has {@code numRegressors + 1} entries. */
        private static void checkLength(String name, double[] values, int numRegressors) {
            if (values.length != numRegressors + 1) {
                throw new IllegalArgumentException(name + " must have length " + (numRegressors + 1)
                        + ", got " + values.length);
            }
        }

        /**
         * Generates a simple exemplar of this class to test serialization.
         *
         * @return a result with no regressors.
         */
        public static Result serializableInstance() {
            return new Result("X1", new ArrayList<String>(), new double[1], new double[1], 0, 0, 0,
                    new double[1], new double[1], new double[1], 1.5, 0.0, 0.0, 0.05);
        }

        /** @return the regressor names, in coefficient order (offset by one for the intercept). */
        public List<String> getRegressorNames() {
            return this.regressorNames;
        }

        /** @return the target variable's name. */
        public String getTarget() {
            return this.target;
        }

        /** @return the number of cases whose target value is 0. */
        public int getNy0() {
            return this.ny0;
        }

        /** @return the number of cases whose target value is not 0. */
        public int getNy1() {
            return this.ny1;
        }

        /** @return the number of regressors (excluding the intercept). */
        public int getNumRegressors() {
            return this.numRegressors;
        }

        /** @return the coefficients; index 0 is the intercept. */
        public double[] getCoefs() {
            return this.coefs;
        }

        /** @return the coefficient standard errors; index 0 is the intercept. */
        public double[] getStdErrs() {
            return this.stdErrs;
        }

        /** @return the two-sided coefficient p-values; index 0 is the intercept. */
        public double[] getProbs() {
            return this.probs;
        }

        /** @return the regressor means (index 0 is the intercept placeholder). */
        public double[] getxMeans() {
            return this.xMeans;
        }

        /** @return the regressor standard deviations (index 0 is the intercept placeholder). */
        public double[] getxStdDevs() {
            return this.xStdDevs;
        }

        /** @return the intercept on the original regressor scale. */
        public double getIntercept() {
            return this.intercept;
        }

        /**
         * @return the stored likelihood statistic; as produced by {@link LogisticRegression} this is
         * -2 times the log-likelihood of the fitted model.
         */
        public double getLogLikelihood() {
            return this.logLikelihood;
        }

        /**
         * Formats a report of the class counts, the overall likelihood-ratio test, and the coefficient
         * table. Coefficients whose p-value is below alpha are flagged with "*".
         */
        @Override
        public String toString() {
            DecimalFormat nf = new DecimalFormat("0.0000");
            StringBuilder report = new StringBuilder();
            report.append(this.ny0).append(" cases have ").append(this.target).append(" = 0; ")
                    .append(this.ny1).append(" cases have ").append(this.target).append(" = 1.\n");
            report.append("Overall Model Fit...\n");
            // The p-value is the upper tail of chi-square(df = numRegressors). ChiSquaredDistribution
            // rejects df = 0, and no test is defined there, so that case is reported as n/a.
            String overallP = this.numRegressors > 0
                    ? nf.format(1.0 - new ChiSquaredDistribution(this.numRegressors).cumulativeProbability(this.chiSq))
                    : "n/a";
            report.append("  Chi Square = ").append(nf.format(this.chiSq)).append("; df = ")
                    .append(this.numRegressors).append("; ").append("p = ").append(overallP).append("\n");
            report.append("\nCoefficients and Standard Errors...\n");
            report.append("\tCoeff.\tStdErr\tprob.\tsig.");
            report.append("\n");
            for (int i = 0; i < this.regressorNames.size(); ++i) {
                report.append("\n").append(this.regressorNames.get(i))
                        .append("\t").append(nf.format(this.coefs[i + 1]))
                        .append("\t").append(nf.format(this.stdErrs[i + 1]))
                        .append("\t").append(nf.format(this.probs[i + 1]))
                        .append("\t").append(this.probs[i + 1] < this.alpha ? "*" : "");
            }
            report.append("\n\nIntercept = ").append(nf.format(this.intercept)).append("\n");
            return report.toString();
        }
    }
}

