/*
 * Decompiled with CFR 0.152.
 */
package edu.pitt.csb.mgm;

import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.math.Functions;
import edu.pitt.csb.mgm.ConvexProximal;
import org.apache.commons.math3.util.FastMath;

/**
 * Accelerated (momentum) proximal-gradient solver for composite objectives {@code f(x) + g(x)},
 * where {@code f} is the smooth part and {@code g} the non-smooth part described by a
 * {@link ConvexProximal}.
 *
 * <p>The gradient step size is {@code 1/L}, where {@code L} is a running estimate of the Lipschitz
 * constant of the gradient of {@code f}. At the start of every outer iteration {@code L} is shrunk
 * by {@code alpha}. Inside the iteration it is increased (by at least a factor {@code 1/beta})
 * until a local curvature test passes. Momentum is reset whenever the objective value increases.
 *
 * <p>Progress and convergence messages are written to {@code System.out}.
 */
public class ProximalGradient {
    /** Interval (in outer iterations) at which the progress line is printed. */
    private static final int PRINT_ITER = 100;
    /**
     * Relative tolerance: once |f(Y) - f(X)| falls below this fraction of max(|f(X)|, |f(Y)|),
     * the curvature estimate switches from the function-value test to the gradient-difference test.
     */
    private static final double BACKTRACK_TOL = 1.0E-10;

    /** Backtracking factor in (0,1); a failed curvature test raises L to at least L / beta. */
    private final double beta;
    /** Shrink factor in (0,1) applied to L at the start of every outer iteration. */
    private final double alpha;
    /** Colt helper; its {@code norm2} returns x·x, i.e. the SQUARED Euclidean norm. */
    private final Algebra alg = new Algebra();
    /**
     * When true, convergence is declared once the sparsity pattern (set of nonzero entries) has
     * been unchanged for {@link #noEdgeChangeTol} consecutive iterations. The relative-change
     * ({@code epsilon}) test is then disabled.
     */
    private final boolean edgeConverge;
    /**
     * Number of consecutive iterations with an unchanged sparsity pattern required for edge
     * convergence. A negative value {@code -k} instead stops the solver as soon as at most
     * {@code k} entries change between zero and nonzero in one iteration, whether or not
     * {@code edgeConverge} is set.
     */
    private int noEdgeChangeTol = 3;

    /**
     * Creates a solver with explicit line-search parameters.
     *
     * @param beta         backtracking factor, must lie strictly in (0,1)
     * @param alpha        per-iteration shrink factor for L, must lie strictly in (0,1)
     * @param edgeConverge whether to use sparsity-pattern convergence instead of the
     *                     relative-change test
     * @throws IllegalArgumentException if {@code beta} or {@code alpha} is outside (0,1)
     */
    public ProximalGradient(double beta, double alpha, boolean edgeConverge) {
        if (beta <= 0.0 || beta >= 1.0) {
            throw new IllegalArgumentException("beta must be (0,1): " + beta);
        }
        if (alpha <= 0.0 || alpha >= 1.0) {
            throw new IllegalArgumentException("alpha must be (0,1): " + alpha);
        }
        this.beta = beta;
        this.alpha = alpha;
        this.edgeConverge = edgeConverge;
    }

    /** Creates a solver with {@code beta = 0.5}, {@code alpha = 0.9} and edge convergence off. */
    public ProximalGradient() {
        this(0.5, 0.9, false);
    }

    /**
     * Sets the edge-convergence tolerance; see the field documentation for the meaning of
     * negative values.
     *
     * @param t new tolerance
     */
    public void setEdgeChangeTol(int t) {
        this.noEdgeChangeTol = t;
    }

    /**
     * Minimizes {@code f(x) + g(x)} with backtracking accelerated proximal gradient steps.
     *
     * @param cp        problem definition. It supplies f and its gradient ({@code smooth},
     *                  {@code smoothValue}, {@code smoothGradient}), g ({@code nonSmooth}) and the
     *                  proximal operator ({@code proximalOperator})
     * @param Xin       starting point. It is not modified; a copy is passed through
     *                  {@code cp.proximalOperator(1.0, ...)}
     * @param epsilon   stop when ||X - Xold|| / max(1, ||X||) falls below this value (only when
     *                  edge convergence is disabled)
     * @param iterLimit maximum number of outer iterations (the loop body always runs at least once)
     * @return the final iterate
     * @throws IllegalStateException if the Lipschitz estimate becomes NaN (non-finite objective or
     *     gradient values). Without this check the backtracking loop would never terminate.
     */
    public DoubleMatrix1D learnBackTrack(ConvexProximal cp, DoubleMatrix1D Xin, double epsilon, int iterLimit) {
        DoubleMatrix1D X = cp.proximalOperator(1.0, Xin.copy());
        DoubleMatrix1D Y = X.copy();
        DoubleMatrix1D Z = X.copy();
        // NOTE(review): cp.smooth(v, grad) appears to return f(v) and write its gradient into grad.
        DoubleMatrix1D GrY = cp.smoothGradient(Y);
        DoubleMatrix1D GrX = cp.smoothGradient(X);
        int iterCount = 0;
        int noEdgeChangeCount = 0;
        // theta = +inf makes the first momentum weight evaluate to exactly 1 (no extrapolation).
        double theta = Double.POSITIVE_INFINITY;
        double L = 1.0;
        boolean backtrackSwitch = true;
        double Fx = Double.POSITIVE_INFINITY;
        double Gx = Double.POSITIVE_INFINITY;
        do {
            double Lold = L;
            L *= this.alpha;
            double thetaOld = theta;
            DoubleMatrix1D Xold = X.copy();
            double obj = Fx + Gx;
            // Backtracking: repeat the prox step until the curvature estimate localL fits within L.
            while (true) {
                theta = 2.0 / (1.0 + FastMath.sqrt(1.0 + 4.0 * L / (Lold * FastMath.pow(thetaOld, 2))));
                if (theta < 1.0) {
                    // Extrapolation point Y = (1 - theta) * Xold + theta * Z.
                    Y.assign(Xold.copy().assign(Functions.mult(1.0 - theta)));
                    Y.assign(Z.copy().assign(Functions.mult(theta)), Functions.plus);
                }
                double Fy = cp.smooth(Y, GrY);
                // Gradient step from Y with step size 1/L, then prox of g written into X.
                DoubleMatrix1D temp = Y.copy().assign(GrY.copy().assign(Functions.mult(1.0 / L)), Functions.minus);
                Gx = cp.nonSmooth(1.0 / L, temp, X);
                Fx = backtrackSwitch ? cp.smoothValue(X) : cp.smooth(X, GrX);
                DoubleMatrix1D XmY = X.copy().assign(Y, Functions.minus);
                double normXY = this.alg.norm2(XmY); // squared norm ||X - Y||^2
                if (normXY == 0.0) {
                    break;
                }
                double localL;
                if (backtrackSwitch) {
                    // Quadratic upper model of f around Y evaluated at X.
                    double Qx = Fy + this.alg.mult(XmY, GrY) + L / 2.0 * normXY;
                    localL = L + 2.0 * FastMath.max(Fx - Qx, 0.0) / normXY;
                    // Once f(X) and f(Y) agree to within round-off, the function-value test is
                    // unreliable; switch to the gradient-difference estimate.
                    backtrackSwitch = FastMath.abs(Fy - Fx)
                            >= BACKTRACK_TOL * FastMath.max(FastMath.abs(Fx), FastMath.abs(Fy));
                } else {
                    // Work on a copy so the gradient buffer GrX is not clobbered.
                    DoubleMatrix1D gradDiff = GrX.copy().assign(GrY, Functions.minus);
                    localL = 2.0 * this.alg.mult(XmY, gradDiff) / normXY;
                }
                if (Double.isNaN(localL)) {
                    throw new IllegalStateException("Lipschitz estimate became NaN at iter " + iterCount
                            + " (L = " + L + "); objective or gradient values are not finite");
                }
                if (localL <= L) {
                    break;
                }
                if (localL != Double.POSITIVE_INFINITY) {
                    L = localL;
                } else {
                    localL = L;
                }
                L = FastMath.max(localL, L / this.beta);
            }

            int diffEdges = countSupportChanges(X, Xold);
            double normX = norm2(X);
            double dx = norm2(X.copy().assign(Xold, Functions.minus)) / FastMath.max(1.0, normX);

            if (diffEdges == 0 && this.edgeConverge) {
                if (++noEdgeChangeCount >= this.noEdgeChangeTol) {
                    System.out.println("Edges converged at iter: " + iterCount + " with |dx|/|x|: " + dx);
                    printStatus(iterCount, dx, normX, Fx, Gx, diffEdges, L);
                    return X;
                }
            } else {
                if (this.noEdgeChangeTol < 0 && diffEdges <= FastMath.abs(this.noEdgeChangeTol)) {
                    System.out.println("Edges converged at iter: " + iterCount + " with |dx|/|x|: " + dx);
                    printStatus(iterCount, dx, normX, Fx, Gx, diffEdges, L);
                    return X;
                }
                noEdgeChangeCount = 0;
            }
            if (dx < epsilon && !this.edgeConverge) {
                System.out.println("Converged at iter: " + iterCount + " with |dx|/|x|: " + dx + " < epsilon: " + epsilon);
                printStatus(iterCount, dx, normX, Fx, Gx, diffEdges, L);
                return X;
            }

            if (Fx + Gx > obj) {
                // Objective went up: drop momentum and restart from the current iterate.
                theta = Double.POSITIVE_INFINITY;
                Y.assign(X.copy());
                Z.assign(X.copy());
            } else if (theta == 1.0) {
                Z.assign(X.copy());
            } else {
                // Z = Xold + (X - Xold) / theta
                Z.assign(X.copy().assign(Functions.mult(1.0 / theta)));
                Z.assign(Xold.copy().assign(Functions.mult(1.0 - 1.0 / theta)), Functions.plus);
            }
            if (iterCount % PRINT_ITER == 0) {
                printStatus(iterCount, dx, normX, Fx, Gx, diffEdges, L);
            }
        } while (++iterCount < iterLimit);
        System.out.println("Iter limit reached");
        return X;
    }

    /** Counts entries that switched between zero and nonzero from {@code previous} to {@code current}. */
    private static int countSupportChanges(DoubleMatrix1D current, DoubleMatrix1D previous) {
        int changes = 0;
        for (int i = 0; i < current.size(); ++i) {
            if ((current.get(i) != 0.0) != (previous.get(i) != 0.0)) {
                ++changes;
            }
        }
        return changes;
    }

    /** Prints the one-line progress/status report used throughout {@link #learnBackTrack}. */
    private static void printStatus(int iter, double dx, double normX, double fx, double gx, int diffEdges, double L) {
        System.out.println("Iter: " + iter + " |dx|/|x|: " + dx + " normX: " + normX + " nll: " + fx + " reg: " + gx + " DiffEdges: " + diffEdges + " L: " + L);
    }

    /**
     * Returns the Euclidean norm of {@code vec}. Colt's {@code Algebra.norm2} yields x·x, hence
     * the square root.
     *
     * @param vec input vector
     * @return sqrt(sum of squared entries)
     */
    public static double norm2(DoubleMatrix1D vec) {
        return FastMath.sqrt(new Algebra().norm2(vec));
    }
}

