/*
 * Decompiled with CFR 0.152.
 */
package edu.wisc.game.tools;

import edu.wisc.game.tools.LoglikProblem;
import edu.wisc.game.util.Util;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient;

/**
 * Log-likelihood problem for an outcome series {@code y} in which the probability of the
 * outcome {@code y[t] == 1} at time step {@code t} is modeled as
 * <pre>
 *   p(t) = p0[t] + (1 - p0[t]) * g(t),
 *   g(t) = B / (1 + exp(u)) + C / (1 + exp(-u)),   u = (t - tI) * k,
 * </pre>
 * with the parameter vector {@code point = (B, C, tI, k)}. Any other value of {@code y[t]}
 * is scored with probability {@code 1 - p(t)}. The curve {@code g} moves from {@code B}
 * (as {@code u -> -inf}) to {@code C} (as {@code u -> +inf}) and equals {@code (B + C) / 2}
 * at {@code t = tI}; {@code k} controls how sharp the transition is.
 *
 * <p>Time steps with {@code p0[t] == 1.0} have {@code p(t) == 1} for every parameter value,
 * so they add only a parameter-independent constant to the objective. They are skipped in
 * both the objective and its gradient, and are not counted by {@link #size()}.
 *
 * <p>NOTE(review): {@code p0} is presumably a per-step baseline (e.g. chance-level) success
 * probability; confirm against callers.
 */
class LoglikP0Problem
extends LoglikProblem {
    /** Per-time-step baseline term of p(t); entries equal to exactly 1.0 drop that step from the fit. */
    final double[] p0;
    /** Number of time steps {@code t < y.length} with {@code p0[t] == 1.0}, i.e. skipped steps. */
    final int defect;

    /**
     * @param _y  observed outcomes, handed to the superclass; {@code y[t] == 1} is the
     *            outcome whose probability is p(t)
     * @param _p0 per-step baseline values; must have at least {@code y.length} entries
     *            (entries beyond that are never read)
     * @throws IllegalArgumentException if {@code _p0} is shorter than the outcome series
     */
    LoglikP0Problem(int[] _y, double[] _p0) {
        super(_y);
        int len = this.y.length;
        // Fail here with a clear message instead of an index error deep inside the optimizer.
        if (_p0.length < len) {
            throw new IllegalArgumentException("p0 has " + _p0.length + " entries, but y has " + len);
        }
        this.p0 = _p0;
        // Count only the steps the objective actually visits, so size() matches value().
        int n = 0;
        for (int t = 0; t < len; ++t) {
            if (_p0[t] == 1.0) {
                ++n;
            }
        }
        this.defect = n;
    }

    /** Number of time steps that contribute to the objective (those with {@code p0[t] != 1.0}). */
    @Override
    int size() {
        return this.y.length - this.defect;
    }

    /**
     * Objective: the sum, over non-skipped time steps, of {@code LoglikProblem.regLog} applied
     * to the modeled probability of the observed outcome ({@code p} if {@code y[t] == 1},
     * otherwise {@code 1 - p}). Prints the point and value when {@code LoglikProblem.verbose}
     * is set.
     */
    @Override
    public ObjectiveFunction getObjectiveFunction() {
        return new ObjectiveFunction(new MultivariateFunction(){

            @Override
            public double value(double[] point) {
                double B = point[0];
                double C = point[1];
                double tI = point[2];
                double k = point[3];
                double sum = 0.0;
                for (int t = 0; t < LoglikP0Problem.this.y.length; ++t) {
                    double base = LoglikP0Problem.this.p0[t];
                    // p == 1 regardless of parameters: constant term, skip before doing any exp().
                    if (base == 1.0) continue;
                    double u = ((double)t - tI) * k;
                    double g = B / (1.0 + Math.exp(u)) + C / (1.0 + Math.exp(-u));
                    double p = base + (1.0 - base) * g;
                    sum += LoglikProblem.regLog(LoglikP0Problem.this.y[t] == 1 ? p : 1.0 - p);
                }
                if (LoglikProblem.verbose) {
                    System.out.println("f(" + Util.joinNonBlank(",", LoglikProblem.df, point) + ") = " + LoglikProblem.df.format(sum));
                }
                return sum;
            }
        });
    }

    /**
     * Analytic gradient of the objective with respect to {@code (B, C, tI, k)}, assuming
     * {@code LoglikProblem.regLogDerivative} is the derivative of {@code LoglikProblem.regLog}
     * (NOTE(review): confirm in LoglikProblem).
     *
     * <p>With {@code ex = exp(-u)} and {@code rex = exp(u)}:
     * <pre>
     *   dg/dB = 1 / (1 + rex),   dg/dC = 1 / (1 + ex),
     *   dg/du = (C - B) / (2 + ex + rex),   du/dtI = -k,   du/dk = t - tI,
     * </pre>
     * and each term is multiplied by {@code dp/dg = 1 - p0[t]} and by the derivative of the
     * log term with respect to {@code p}. The last factor is negated for {@code y[t] != 1},
     * because that term depends on {@code 1 - p}.
     */
    @Override
    public ObjectiveFunctionGradient getObjectiveFunctionGradient() {
        return new ObjectiveFunctionGradient(new MultivariateVectorFunction(){

            @Override
            public double[] value(double[] point) {
                double B = point[0];
                double C = point[1];
                double tI = point[2];
                double k = point[3];
                double[] sum = new double[point.length];
                for (int t = 0; t < LoglikP0Problem.this.y.length; ++t) {
                    double base = LoglikP0Problem.this.p0[t];
                    // Constant term (p == 1 for all parameters): zero contribution to the gradient.
                    if (base == 1.0) continue;
                    double u = ((double)t - tI) * k;
                    double ex = Math.exp(-u);
                    double rex = Math.exp(u);
                    double g = B / (1.0 + rex) + C / (1.0 + ex);
                    double op = 1.0 - base;
                    double p = base + op * g;
                    // d(log-term)/dp, with the chain-rule sign flip when the term is in (1 - p).
                    double r = LoglikP0Problem.this.y[t] == 1 ? LoglikProblem.regLogDerivative(p) : -LoglikProblem.regLogDerivative(1.0 - p);
                    sum[0] += op * r / (1.0 + rex);
                    sum[1] += op * r / (1.0 + ex);
                    // Common factor op * r * dg/du; if exp overflows, the denominator is infinite and z == 0.
                    double z = op * r * (C - B) / (2.0 + ex + rex);
                    sum[2] += -k * z;
                    sum[3] += ((double)t - tI) * z;
                }
                if (LoglikProblem.verbose) {
                    System.out.println("gradF(" + Util.joinNonBlank(",", LoglikProblem.df, point) + ") = " + Util.joinNonBlank(",", LoglikProblem.df, sum));
                }
                return sum;
            }
        });
    }
}

