/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.sem.ParamConstraint;
import edu.cmu.tetrad.sem.ParamConstraintType;
import edu.cmu.tetrad.sem.ParamType;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.RandomUtil;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.text.NumberFormat;
import java.util.List;

/**
 * Sampler for the free parameters of a {@link SemPm}, starting from a given {@link SemIm}.
 *
 * <p>In each iteration every free parameter is visited in turn. The negative log posterior is
 * minimized along that one parameter with Brent's method. A proposal variance is derived from
 * the rise of the objective a small step away from that minimum. A candidate value is then drawn
 * and accepted or rejected with a Metropolis-Hastings style log-ratio test. Every 50th iteration
 * the current parameter values are stored as one column of a parameters-by-draws matrix,
 * returned by {@link #getDataSet()}.
 *
 * <p>Only the flat prior is implemented. With {@code flatPrior == false}, {@link #estimate()}
 * prints a message and returns without estimating.
 *
 * <p>NOTE(review): the conditional objective is always evaluated on {@code startIm}, whose
 * parameter values are restored after every evaluation. Draws are written to a separate
 * posterior copy. The conditionals therefore do not appear to depend on the latest draws of the
 * other parameters, as a textbook Gibbs sweep would. Confirm whether this is intended.
 */
public final class SemEstimatorGibbs {
    static final long serialVersionUID = 23L;

    /** Number of iterations between two recorded draws. */
    private static final int SUBSAMPLE_STRIDE = 50;
    /** Once the curvature probe step exceeds this, the minimum denominator is used. */
    private static final double GAP_THRESHOLD = 1.0;
    /** Lower clamp for the curvature denominator used to build the proposal variance. */
    private static final double MIN_DENOM = 0.01;
    /** Upper cap applied to the acceptance log-ratio. */
    private static final double REJECTION_CAP = 5.0;
    /** Half-width of the default Brent bracket. */
    private static final double BRACKET_HALF_WIDTH = 500.0;
    /** Maximum number of Brent iterations. */
    private static final int BRENT_MAX_ITERATIONS = 100;
    /** Golden-section ratio used by Brent's method. */
    private static final double BRENT_CGOLD = 0.381966;
    /** Small absolute tolerance guarding against a minimum at exactly zero. */
    private static final double BRENT_ZEPS = 1.0E-10;

    private double[][] sampleCovars;
    private int numIterations;
    /** Multiplier applied to the proposal variance. */
    private double stretch1;
    /** Enters the proposal log-density as {@code -log(stretch2)}; fixed at 1.0 by the constructor. */
    private double stretch2;
    /** Relative tolerance passed to Brent's method. */
    private double tolerance;
    private double priorVariance;
    private SemPm semPm;
    private double[] parameterMeans;
    private ParamConstraint[] paramConstraints;
    private SemIm startIm;
    private DoubleMatrix2D priorCov;
    private SemIm estimatedSem;
    private boolean flatPrior;
    /** Recorded draws: one row per parameter of the SemPm, one column per recorded iteration. */
    private DoubleMatrix2D dataSet;

    /**
     * Creates a sampler.
     *
     * @param semPm         the parametric model whose parameters are sampled
     * @param startIm       the instantiated model used as the starting point; its parameter values
     *                      are temporarily modified (and restored) while evaluating the likelihood
     * @param sampleCovars  stored but not read by any method in this class
     * @param flatPrior     whether to use a flat prior; only {@code true} is implemented
     * @param stretch       multiplier for the proposal variance
     * @param numIterations number of sweeps over the free parameters
     */
    public SemEstimatorGibbs(SemPm semPm, SemIm startIm, double[][] sampleCovars, boolean flatPrior, double stretch, int numIterations) {
        this.sampleCovars = sampleCovars;
        this.semPm = semPm;
        this.startIm = startIm;
        this.flatPrior = flatPrior;
        this.stretch1 = stretch;
        this.stretch2 = 1.0;
        this.numIterations = numIterations;
        this.tolerance = 1.0E-4;
        this.priorVariance = 16.0;
    }

    /**
     * Runs the sampler.
     *
     * <p>Afterwards, {@link #getEstimatedSem()} holds a copy of the start model carrying the last
     * accepted values. {@link #getDataSet()} holds every 50th draw. With an informative prior this
     * method prints a message and returns without setting either result.
     */
    public void estimate() {
        List<Parameter> parameters = this.semPm.getParameters();
        int numParameters = parameters.size();
        double[][] parameterCovariances = new double[numParameters][numParameters];
        this.parameterMeans = new double[numParameters];
        this.paramConstraints = new ParamConstraint[numParameters];
        DenseDoubleMatrix2D draws = new DenseDoubleMatrix2D(numParameters, this.numIterations / SUBSAMPLE_STRIDE);

        if (!this.flatPrior) {
            // Informative (user-specified) priors are not implemented.
            System.out.println("Informative Prior. Exiting.");
            return;
        }

        // Flat prior: diagonal prior covariance with priorVariance on free parameters and 0 on
        // fixed ones. Variances are constrained to be positive; other parameters are unconstrained.
        // NOTE(review): the prior mean of a free parameter is set to priorVariance. It is only
        // read by the informative-prior path (negchi2), which is currently unreachable.
        for (int i = 0; i < numParameters; ++i) {
            Parameter param = parameters.get(i);
            this.parameterMeans[i] = param.isFixed() ? 0.0 : this.priorVariance;
            if (param.getType() == ParamType.VAR) {
                this.paramConstraints[i] = new ParamConstraint(this.startIm, param, ParamConstraintType.GT, 0.0);
            } else {
                this.paramConstraints[i] = new ParamConstraint(this.startIm, param, ParamConstraintType.NONE, 0.0);
            }
            for (int j = 0; j < numParameters; ++j) {
                parameterCovariances[i][j] = i == j && !param.isFixed() ? this.priorVariance : 0.0;
            }
        }
        this.priorCov = new DenseDoubleMatrix2D(parameterCovariances);

        SemIm posteriorIm = new SemIm(this.startIm);
        List<Parameter> postFreeParams = posteriorIm.getFreeParameters();
        System.out.println("entering main loop");

        for (int iter = 1; iter <= this.numIterations; ++iter) {
            System.out.println(iter);

            // `param` indexes the full parameter list (as paramConstraints and negloglike expect).
            // `freeIndex` tracks the matching position in the posterior's free-parameter list.
            // Assumes the free list is the non-fixed parameters in SemPm order; TODO confirm.
            int freeIndex = -1;
            for (int param = 0; param < numParameters; ++param) {
                if (parameters.get(param).isFixed()) {
                    continue;
                }
                ++freeIndex;
                if (freeIndex >= postFreeParams.size()) {
                    break;
                }
                double cand = this.sampleParameter(param, parameters);
                Parameter ppost = postFreeParams.get(freeIndex);
                if (ppost.isFixed()) {
                    posteriorIm.setFixedParamValue(ppost, cand);
                } else {
                    posteriorIm.setParamValue(ppost, cand);
                }
            }

            if (iter % SUBSAMPLE_STRIDE != 0) {
                continue;
            }
            int column = iter / SUBSAMPLE_STRIDE - 1;
            List<Parameter> postParams = posteriorIm.getSemPm().getParameters();
            for (int i = 0; i < numParameters; ++i) {
                draws.set(i, column, posteriorIm.getParamValue(postParams.get(i)));
            }
        }
        this.dataSet = draws;
        this.estimatedSem = posteriorIm;
    }

    /**
     * Draws a new value for one free parameter.
     *
     * @param param      index of the parameter in {@code parameters}
     * @param parameters all parameters of the SemPm
     * @return the accepted candidate value
     */
    private double sampleParameter(int param, List<Parameter> parameters) {
        ParamConstraint constraint = this.paramConstraints[param];
        double number = constraint.getParam2() == null
                ? constraint.getNumber()
                : this.startIm.getParamValue(constraint.getParam2());

        // Bracket (ax, bx, cx) for Brent's method, shaped by the constraint type.
        double ax;
        double bx;
        double cx;
        ParamConstraintType type = constraint.getType();
        if (type == ParamConstraintType.GT) {
            ax = number;
            cx = number + BRACKET_HALF_WIDTH;
            bx = (ax + cx) / 2.0;
        } else if (type == ParamConstraintType.LT) {
            cx = number;
            ax = number - BRACKET_HALF_WIDTH;
            bx = (ax + cx) / 2.0;
        } else if (type == ParamConstraintType.EQ) {
            bx = number;
            ax = number - BRACKET_HALF_WIDTH;
            cx = number + BRACKET_HALF_WIDTH;
        } else {
            // NONE and any other type.
            ax = -BRACKET_HALF_WIDTH;
            bx = 0.0;
            cx = BRACKET_HALF_WIDTH;
        }

        double[] mean = new double[1];
        double dmean = -this.brent(param, ax, bx, cx, this.tolerance, mean, parameters);

        // Probe the objective a step `gap` above the minimum to estimate curvature.
        // NOTE(review): denom is clamped to MIN_DENOM > 0, so the loop condition never holds and
        // the body runs exactly once (gap == 0.01).
        double gap = 0.005;
        double denom;
        do {
            gap = 2.0 * gap;
            if (gap > GAP_THRESHOLD) {
                denom = MIN_DENOM;
                break;
            }
            double dmeanplus = this.neglogpost(param, mean[0] + gap, parameters);
            denom = dmean + dmeanplus;
            if (denom < MIN_DENOM) {
                denom = MIN_DENOM;
            }
        } while (denom < 0.0);

        // Quadratic approximation: f(m + g) - f(m) ~ g^2 / (2 * var).
        double vr = this.stretch1 * 0.5 * gap * gap / denom;

        boolean realdraw = false;
        double rj = 0.0;
        double accept = 0.0;
        double cand = 0.0;
        while (!realdraw || rj <= accept) {
            // NOTE(review): Math.max(..., 0) restricts proposals to values >= mean[0] (a
            // half-normal proposal); confirm this is intended.
            cand = mean[0] + Math.max(RandomUtil.getInstance().nextNormal(0.0, 1.0) * Math.sqrt(vr), 0.0);
            realdraw = constraint.wouldBeSatisfied(cand);
            if (!realdraw) {
                continue;
            }
            double dcand = -1.0 * this.neglogpost(param, cand, parameters);
            double numer = dcand - dmean;
            // Gaussian proposal log-density (up to a constant): -(cand - mean)^2 / (2 * vr).
            double offset = cand - mean[0];
            double logProposal = -(offset * offset) / (2.0 * vr) - Math.log(this.stretch2);
            rj = numer - logProposal;
            accept = Math.log(RandomUtil.getInstance().nextDouble());
            if (rj > REJECTION_CAP) {
                rj = REJECTION_CAP;
            }
        }
        return cand;
    }

    /**
     * Minimizes {@link #neglogpost} along one parameter with Brent's method (after Numerical
     * Recipes {@code brent}).
     *
     * @param param      index of the parameter varied
     * @param ax         one end of the bracket
     * @param bx         starting abscissa
     * @param cx         other end of the bracket
     * @param tol        fractional tolerance
     * @param xmin       output: {@code xmin[0]} receives the abscissa of the minimum found
     * @param parameters all parameters of the SemPm
     * @return the objective value at {@code xmin[0]}
     */
    private double brent(int param, double ax, double bx, double cx, double tol, double[] xmin, List<Parameter> parameters) {
        double a = ax < cx ? ax : cx;
        double b = ax > cx ? ax : cx;
        double x = bx;
        double w = bx;
        double v = bx;
        double fx = this.neglogpost(param, x, parameters);
        double fw = fx;
        double fv = fx;
        double d = 0.0;
        double e = 0.0;

        for (int iter = 1; iter <= BRENT_MAX_ITERATIONS; ++iter) {
            double xm = 0.5 * (a + b);
            double tol1 = tol * Math.abs(x) + BRENT_ZEPS;
            double tol2 = 2.0 * tol1;
            if (Math.abs(x - xm) <= tol2 - 0.5 * (b - a)) {
                xmin[0] = x;
                return fx;
            }
            if (Math.abs(e) > tol1) {
                // Trial parabolic fit through x, v, w.
                double r = (x - w) * (fx - fv);
                double q = (x - v) * (fx - fw);
                double p = (x - v) * q - (x - w) * r;
                q = 2.0 * (q - r);
                if (q > 0.0) {
                    p = -p;
                }
                q = Math.abs(q);
                double etemp = e;
                e = d;
                if (Math.abs(p) >= Math.abs(0.5 * q * etemp) || p <= q * (a - x) || p >= q * (b - x)) {
                    // Parabolic step unacceptable: take a golden-section step.
                    e = x >= xm ? a - x : b - x;
                    d = BRENT_CGOLD * e;
                } else {
                    d = p / q;
                    double trial = x + d;
                    if (trial - a < tol2 || b - trial < tol2) {
                        d = xm - x >= 0.0 ? tol1 : -tol1;
                    }
                }
            } else {
                e = x >= xm ? a - x : b - x;
                d = BRENT_CGOLD * e;
            }
            // Never step by less than tol1; the minimum step takes the sign of d (NR's SIGN(tol1, d)).
            double u = Math.abs(d) >= tol1 ? x + d : x + (d >= 0.0 ? tol1 : -tol1);
            double fu = this.neglogpost(param, u, parameters);
            if (fu <= fx) {
                if (u >= x) {
                    a = x;
                } else {
                    b = x;
                }
                v = w;
                fv = fw;
                w = x;
                fw = fx;
                x = u;
                fx = fu;
            } else {
                if (u < x) {
                    a = u;
                } else {
                    b = u;
                }
                if (fu <= fw || w == x) {
                    v = w;
                    fv = fw;
                    w = u;
                    fw = fu;
                } else if (fu <= fv || v == x || v == w) {
                    v = u;
                    fv = fu;
                }
            }
        }
        xmin[0] = x;
        return fx;
    }

    /**
     * Returns the negative log posterior (up to a constant) with parameter {@code param} set to
     * {@code x}. Under a flat prior this is the negative log likelihood alone.
     */
    private double neglogpost(int param, double x, List<Parameter> parameters) {
        double likelihoodTerm = this.negloglike(param, x);
        double priorTerm = this.flatPrior ? 0.0 : this.neglogprior(param, x, parameters);
        return likelihoodTerm + priorTerm;
    }

    /**
     * Returns {@code -startIm.getTruncLL()} with parameter {@code param} temporarily set to
     * {@code x}. The value is only substituted for VAR/COEF parameters whose constraint {@code x}
     * satisfies. The original value is always restored, even if the evaluation throws.
     */
    private double negloglike(int param, double x) {
        Parameter p = this.semPm.getParameters().get(param);
        double saved = this.startIm.getParamValue(p);
        try {
            boolean substitutable = p.getType() == ParamType.VAR || p.getType() == ParamType.COEF;
            if (substitutable && this.paramConstraints[param].wouldBeSatisfied(x)) {
                this.startIm.setParamValue(p, x);
            }
            return -this.startIm.getTruncLL();
        } finally {
            this.startIm.setParamValue(p, saved);
        }
    }

    /**
     * Returns {@code -(dev' * inv(Sigma_free) * dev)}.
     *
     * <p>{@code dev} holds the deviations of the free parameters from their prior means, with
     * parameter {@code param} taken at {@code x}. {@code Sigma_free} is the prior covariance
     * restricted to the free parameters. Fixed parameters have zero prior variance, so the full
     * prior covariance would be singular.
     */
    private double negchi2(int param, double x, List<Parameter> parameters) {
        int numParameters = parameters.size();
        int[] freeIndices = new int[numParameters];
        double[] deviations = new double[numParameters];
        int numFree = 0;
        for (int i = 0; i < numParameters; ++i) {
            Parameter p = parameters.get(i);
            if (p.isFixed()) {
                continue;
            }
            double value = i == param ? x : this.startIm.getParamValue(p);
            deviations[numFree] = value - this.parameterMeans[i];
            freeIndices[numFree] = i;
            ++numFree;
        }
        if (numFree == 0) {
            return 0.0;
        }
        int[] selected = new int[numFree];
        System.arraycopy(freeIndices, 0, selected, 0, numFree);
        DoubleMatrix2D invPrior = new Algebra().inverse(this.priorCov.viewSelection(selected, selected));

        double answer = 0.0;
        for (int col = 0; col < numFree; ++col) {
            double rowTimesColumn = 0.0;
            for (int k = 0; k < numFree; ++k) {
                rowTimesColumn += deviations[k] * invPrior.get(k, col);
            }
            answer += rowTimesColumn * deviations[col];
        }
        return -answer;
    }

    /** Returns the negative log prior (up to a constant): half the quadratic form of negchi2. */
    private double neglogprior(int param, double x, List<Parameter> parameters) {
        return -this.negchi2(param, x, parameters) / 2.0;
    }

    /** Returns the estimated model, or {@code null} if {@link #estimate()} has not produced one. */
    public SemIm getEstimatedSem() {
        return this.estimatedSem;
    }

    /**
     * Returns a textual summary of the estimated model.
     *
     * <p>NOTE(review): the "fml = " line has no value appended; confirm what was meant to go there.
     */
    public String toString() {
        NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
        StringBuilder buf = new StringBuilder();
        buf.append("\nSemEstimator");
        SemIm sem = this.getEstimatedSem();
        if (sem == null) {
            buf.append("\n\t...SemIm has not been estimated yet.");
            return buf.toString();
        }
        buf.append("\n\n\tfml = ");
        buf.append("\n\n\tnegtruncll = ");
        buf.append(nf.format(-sem.getTruncLL()));
        buf.append("\n\n\tmeasuredNodes:\n\t");
        buf.append(sem.getMeasuredNodes());
        buf.append("\n\n\tedgeCoef:\n");
        buf.append(MatrixUtils.toString(sem.getEdgeCoef().toArray()));
        buf.append("\n\n\terrCovar:\n");
        buf.append(MatrixUtils.toString(sem.getErrCovar().toArray()));
        return buf.toString();
    }

    /** Returns the columns of {@code dataSet} named by the SemPm's measured variables, in that order. */
    private DataSet subset(DataSet dataSet, SemPm semPm) {
        String[] measuredVarNames = semPm.getMeasuredVarNames();
        int[] varIndices = new int[measuredVarNames.length];
        for (int i = 0; i < measuredVarNames.length; ++i) {
            Node variable = dataSet.getVariable(measuredVarNames[i]);
            varIndices[i] = dataSet.getVariables().indexOf(variable);
        }
        return dataSet.subsetColumns(varIndices);
    }

    /**
     * Sets each variable's mean in {@code semIm} to the corresponding column mean of
     * {@code dataSet}, or to 0 when {@code dataSet} is {@code null}. Columns beyond the number of
     * variable nodes are ignored.
     */
    private void setMeans(SemIm semIm, DoubleMatrix2D dataSet) {
        double[] means = new double[semIm.getSemPm().getVariableNodes().size()];
        if (dataSet != null) {
            int numColumns = Math.min(dataSet.columns(), means.length);
            int numRows = dataSet.rows();
            for (int j = 0; j < numColumns; ++j) {
                double sum = 0.0;
                for (int i = 0; i < numRows; ++i) {
                    sum += dataSet.get(i, j);
                }
                means[j] = sum / (double) numRows;
            }
        }
        for (int i = 0; i < semIm.getVariableNodes().size(); ++i) {
            Node node = semIm.getVariableNodes().get(i);
            semIm.setMean(node, means[i]);
        }
    }

    /** Replaces the estimated model. */
    public void setEstimatedSem(SemIm estimatedSem) {
        this.estimatedSem = estimatedSem;
    }

    /** Returns the parametric model being estimated. */
    public SemPm getSemPm() {
        return this.semPm;
    }

    /**
     * Returns the recorded draws (one row per parameter, one column per recorded iteration), or
     * {@code null} before a successful {@link #estimate()}.
     */
    public DoubleMatrix2D getDataSet() {
        return this.dataSet;
    }

    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
    }
}

