using MAT
using AdaProx

# Kind of random problem instance produced by `generate_quadratics_data`.
@enum QUADRATICS_TYPE begin
    # Gaussian-generated Q, q, A, b with box bounds lb = -5, ub = 5
    quadratics_gaussian = 1
end

export quadratics_gaussian

"""
    generate_quadratics_data(m, n, quadratics_type::QUADRATICS_TYPE)

Draw a random `QUADRATICS` instance with `n` variables and `m` linear equality rows.

For `quadratics_gaussian`:
- `Q = Σ' * D * Σ`, where `D` is diagonal with entries `50 * randn` (both signs, so `Q` is
  in general indefinite) and `Σ` is a Gaussian matrix scaled to unit Frobenius norm;
- `q = 2 * randn(n)`;
- `A` is Gaussian `m × n` and `b = A * x̂` for a Gaussian `x̂`, so `A * x = b` is consistent;
- box bounds `lb = -5`, `ub = 5` on every coordinate.

Throws an `ErrorException` for an unrecognized `quadratics_type`.
"""
function generate_quadratics_data(m, n, quadratics_type :: QUADRATICS_TYPE)
    quadratics_type === quadratics_gaussian || error("QUADRATICS_TYPE not recognized.")

    # Random draws happen in a fixed order: diagonal, Σ, q, A, then the point defining b.
    D = LA.diagm(50 * Random.randn(n))
    Σ = Random.randn(n, n)
    Σ = Σ / LA.norm(Σ)
    Q = Σ' * D * Σ

    q = 2 * Random.randn(n)

    A = Random.randn(m, n)
    b = A * randn(n)

    lb = -5 * ones(n)
    ub = 5 * ones(n)
    return QUADRATICS(Q, q, A, b, lb, ub)
end

"""
    QUADRATICS{TQ, Tq}

Data of the box- and equality-constrained quadratic program

    minimize   0.5 * x' * Q * x + q' * x
    subject to A * x = b,  lb <= x <= ub

Calling an instance evaluates the quadratic objective, `ProximalOperators.gradient!`
writes `Q * x + q`, `ProximalOperators.prox!` projects onto the box `[lb, ub]`, and
`A_eval` returns the constraint residual `A * x - b`.
`Q` and `A` must share the matrix type `TQ`; `q`, `b`, `lb`, `ub` share the vector type `Tq`.
"""
struct QUADRATICS{TQ, Tq}
    Q :: TQ    # quadratic term of the objective
    q :: Tq    # linear term of the objective
    A :: TQ    # equality-constraint matrix
    b :: Tq    # equality-constraint right-hand side
    lb :: Tq   # lower box bound
    ub :: Tq   # upper box bound
end

# QUADRATICS(Q :: TQ, q :: Tq, A :: TQ, b :: Tq, lb :: Tq, ub :: Tq) where {TQ, Tq} = QUADRATICS{TQ, Tq}(Q, q, A, b, lb, ub)

"""
    A_eval(quadratics::QUADRATICS, x)

Return the equality-constraint residual `A * x - b` at `x`.
"""
function A_eval(quadratics :: QUADRATICS, x)
    Ax = quadratics.A * x
    return Ax - quadratics.b
end

"""
    JA_eval(quadratics::QUADRATICS, x)

Return the tuple `(A, A * x - b)`: the (constant) Jacobian of the linear constraint map
together with its residual at `x`.
"""
function JA_eval(quadratics :: QUADRATICS, x)
    jacobian = quadratics.A
    residual = jacobian * x - quadratics.b
    return jacobian, residual
end

# Evaluate the quadratic objective 0.5 * x' * Q * x + q' * x (constraints are ignored here).
function (f :: QUADRATICS)(x)
    quadratic_part = 0.5 * (x' * f.Q * x)
    linear_part = f.q' * x
    return quadratic_part + linear_part
end

# Write the gradient Q * x + q of the quadratic objective into `y` and return the
# objective value at `x`, following the ProximalOperators `gradient!` convention.
function ProximalOperators.:gradient!(y, f :: QUADRATICS, x)
    grad = f.Q * x + f.q
    y .= grad
    # Same expression as the QUADRATICS functor: 0.5 * (x' * Q * x) + q' * x
    return f(x)
end

# Proximal map of the indicator of the box [lb, ub]: write the projection of `x` into `y`.
# The step size γ is irrelevant for an indicator; the returned value of the indicator
# at the projected point is always 0.
function ProximalOperators.:prox!(y, g :: QUADRATICS, x, γ)
    lower, upper = g.lb, g.ub
    @. y = min(max(x, lower), upper)
    return 0.0
end

# Keyword-constructed solver configuration for a `QUADRATICS` problem.
# By default the same `problem` object is used both as the smooth term `f`
# (via its functor and `ProximalOperators.gradient!`) and as the nonsmooth term `g`
# (via `ProximalOperators.prox!`, the projection onto [lb, ub]).
# `A` and `JA` default to closures over `A_eval` / `JA_eval`.
Base.@kwdef struct QUADRATICSSolver{Tf, Tg, TA, TJA, Tx, Ty, R, Tβ}
    problem :: QUADRATICS
    f :: Tf = problem
    g :: Tg = problem
    A :: TA = x -> A_eval(problem, x)     # x ↦ A * x - b
    JA :: TJA = x -> JA_eval(problem, x)  # x ↦ (A, A * x - b)
    x0 :: Tx                              # NOTE(review): presumably the initial primal point; not read in this chunk
    y0 :: Ty                              # NOTE(review): presumably the initial multiplier; not read in this chunk
    σ0 :: R                               # not read in this chunk
    initial_inner_tol :: R                # not read in this chunk
    λ :: R                                # not read in this chunk
    β :: Tβ                               # called as β(k) and passed to state.γ_rule / state.ρ_rule in solve_primal!
end

# Adapter that exposes the primal subproblem of outer iteration `k` to ProximalOperators /
# AdaProx: calling it evaluates `al(solver, state, x, state.y, k)` and its `gradient!`
# delegates to `al_gradx!` (both defined elsewhere; presumably the augmented Lagrangian).
Base.@kwdef mutable struct ProximalOperatorsQUADRATICS
    solver :: QUADRATICSSolver
    state :: ALMState
    k    # outer iteration index forwarded to al / al_gradx! (untyped field)
end

# Evaluate the wrapped subproblem objective at `x` using the current multiplier `state.y`.
function (f :: ProximalOperatorsQUADRATICS)(x)
    return al(f.solver, f.state, x, f.state.y, f.k)
end

"""
    run_adapg(f, g, x0, should_stop, tol, maxit, gamma_init)

Minimize `f + g` from `x0` with a momentum (accelerated) proximal-gradient method.

The step size `gamma` and the parameter `lambda` adapt to the local Lipschitz estimate
`L = ‖∇f(x) - ∇f(x_new)‖ / ‖x - x_new‖`. The momentum weight is
`(sqrt(1/gamma) - sqrt(lambda)) / (sqrt(1/gamma) + sqrt(lambda))`.

# Arguments
- `f`: smooth term, accessed only through `ProximalOperators.gradient!`.
- `g`: nonsmooth term, accessed only through `ProximalOperators.prox!`.
- `x0`: starting point (not modified).
- `should_stop(y, grad, buf, tol)`: stopping test called after every iteration. It receives
  - the new proximal point `y`;
  - the gradient at the extrapolated point;
  - a scratch buffer shaped like `x0`;
  - `tol`.

  The loop also stops when two consecutive extrapolated points differ by less than
  `1e-12` in the max-norm.
- `tol`: tolerance forwarded to `should_stop`.
- `maxit`: maximum number of loop iterations.
- `gamma_init`: step size of the initial forward-backward step.

# Returns
`(y, nit)`, where `y` is the last proximal point.
- If the loop stops after `k` iterations, `nit = k + 2`, the number of `gradient!` calls.
- If the loop runs out, the counter ends one past the last iteration.
- If `x0` is already a fixed point of the first forward-backward step (hence stationary),
  `(x_new, 2)` is returned without iterating.
"""
function run_adapg(
    f :: ProximalOperators.Regularize,
    g,
    x0,
    should_stop,
    tol,
    maxit,
    gamma_init,
)
    # Setup: one forward-backward step from x0 with the initial step size.
    x = copy(x0)
    grad_fx = copy(x)
    ProximalOperators.gradient!(grad_fx, f, x)
    x_new = copy(x)
    ProximalOperators.prox!(x_new, g, x - gamma_init .* grad_fx, gamma_init)

    grad_fx_new = copy(x)
    ProximalOperators.gradient!(grad_fx_new, f, x_new)

    step_norm = LA.norm(x - x_new)
    if iszero(step_norm)
        # x0 is a fixed point of the forward-backward map, i.e. already stationary.
        # Continuing would give L = 0/0 = NaN, and min(..., NaN) would make every
        # subsequent step size (and iterate) NaN.
        return x_new, 2
    end
    L = LA.norm(grad_fx - grad_fx_new) / step_norm

    # "Previous" step size used by the first growth ratio rho = gamma / gamma_prev.
    if 1 - 2*L > 0.
        gamma_prev = gamma_init * 2 * L / (1 - 2 * L)
    else
        gamma_prev = gamma_init
    end
    gamma = gamma_init
    lambda = gamma_init
    lambda_prev = gamma_prev
    temp = copy(x0)   # scratch buffer handed to should_stop
    y = copy(x0)
    yold = copy(x0)

    k = 1
    while k <= maxit
        rho = gamma / gamma_prev
        theta = lambda / lambda_prev
        # Local Lipschitz estimate from the last two extrapolated points.
        L = LA.norm(grad_fx - grad_fx_new) / LA.norm(x - x_new)

        gamma_prev = gamma
        gamma = min(sqrt(1+rho/2.)*gamma, 1. / (2 * L))

        lambda_prev = lambda
        lambda = min(sqrt(1+theta/2.) * lambda, L / (2))

        beta = (sqrt(1/ gamma) - sqrt(lambda)) / (sqrt(1/gamma)+sqrt(lambda))

        copyto!(x, x_new)
        copyto!(grad_fx, grad_fx_new)
        copyto!(yold, y)

        # Forward-backward step from x, then extrapolation with momentum weight beta.
        ProximalOperators.prox!(y, g, x - gamma .* grad_fx, gamma)
        x_new .= y .+ beta * (y - yold)
        ProximalOperators.gradient!(grad_fx_new, f, x_new)

        # NOTE(review): should_stop receives the gradient at x_new, not at y.
        if should_stop(y, grad_fx_new, temp, tol) || LA.norm(x - x_new, Inf) < 1e-12
            break
        end
        k+=1
    end

    return y, k + 2
end

# Write the gradient of the wrapped subproblem at `x` into `grad` by delegating to
# `al_gradx!` with the current multiplier `state.y`; returns whatever `al_gradx!` returns.
function ProximalOperators.:gradient!(grad, quadratics :: ProximalOperatorsQUADRATICS, x)
    alm_state = quadratics.state
    return al_gradx!(grad, quadratics.solver, alm_state, x, alm_state.y, quadratics.k)
end

# AdaProx hook: return the value of `f` at `x` (as returned by `gradient!`) together with a
# zero-argument closure that yields the gradient computed in the same call.
function AdaProx.eval_with_pullback(f :: ProximalOperatorsQUADRATICS, x)
    grad_buffer = copy(x)
    value = ProximalOperators.gradient!(grad_buffer, f, x)
    pullback = () -> grad_buffer
    return value, pullback
end

# Overwrite `res` with the stationarity residual of a box-constrained problem at `x`:
# on coordinates with lb < ub it is the minimum-norm element of grad + N_[lb, ub](x).
# See https://epubs.siam.org/doi/pdf/10.1137/120897547, Eq. (2.7) and following.
# `grad` must hold the gradient at `x` and must not alias `res`. Returns `res`.
function _normal_cone_residual!(res, grad, x, lb, ub)
    bitmask_lb = (x .<= lb)
    bitmask_ub = (x .>= ub)
    # res contains: proj_{>= 0}(-grad on upper-active coords .+ grad on lower-active coords)
    res .= max.(-1 .* grad .* bitmask_ub .+ grad .* bitmask_lb, 0.)
    # res contains: grad + proj_{>= 0}(...) .* masks, i.e. grad with the normal-cone part removed
    res .= grad + (bitmask_ub .* res) .- (bitmask_lb .* res)
    return res
end

# Build a (non-AdaPG) inner forward-backward solver from the settings in `state`.
# With adaptive steps, it stops on the scaled fixed-point residual. Otherwise it stops on
# the normal-cone residual of `quadratics` and starts from the step size `gamma_fn()`.
# `gamma_fn` is only called in the non-adaptive case.
function _make_inner_ffb(state, quadratics, gamma_fn)
    if state.adaptive_gamma
        return state.inner_solver(
            tol = state.inner_tol,
            maxit = round(state.inner_max_it),
            minimum_gamma = state.minimum_gamma,
            stop=(ffb_iter, ffb_state) -> (LA.norm(ffb_state.res, Inf) / ffb_state.gamma < state.inner_tol),
            adaptive = state.adaptive_gamma,
        )
    end

    lb, ub = quadratics.solver.problem.lb, quadratics.solver.problem.ub
    stop = (ffb_iter, ffb_state) -> begin
        res = copy(ffb_state.x)
        grad_x = copy(ffb_state.x)
        ProximalOperators.gradient!(grad_x, quadratics, ffb_state.x)
        _normal_cone_residual!(res, grad_x, ffb_state.x, lb, ub)
        return LA.norm(res) <= state.inner_tol
    end
    return state.inner_solver(
        tol = state.inner_tol,
        maxit = round(state.inner_max_it),
        minimum_gamma = state.minimum_gamma,
        stop = stop,
        adaptive = state.adaptive_gamma,
        gamma = gamma_fn(),
    )
end

"""
    solve_primal!(solver::QUADRATICSSolver, state::ALMState, k)

Approximately solve the primal subproblem of outer iteration `k` over the box `[lb, ub]`.
The subproblem is the function wrapped by `ProximalOperatorsQUADRATICS(solver, state, k)`,
called ∇L_β in the comments. The solve starts from `state.x` and overwrites it.

Inner solver selection:
- `state.triple_loop`: an inexact proximal-point loop. Each step minimizes
  `Regularize(L_β, 2ρ, x_curr)` with `run_adapg` (when `state.inner_solver` is
  `AdaProx.adaptive_proxgrad`) or with a forward-backward solver built from `state`.
- `state.inner_solver == AdaProx.adaptive_proxgrad`: AdaProx directly on `L_β`.
- otherwise: the forward-backward solver built from `state`.

Side effects:
- `state.res` is set to the normal-cone stationarity residual at the new `state.x`.
- When `state.logging` is set, `state.hist` entries `1` (only for `k == 1`) and `k + 1`
  are filled.
- Diagnostics are printed only when `state.verbose` is set.

Returns the number of inner iterations reported by the inner solver(s).
"""
function solve_primal!(solver :: QUADRATICSSolver, state :: ALMState, k)
    quadratics = ProximalOperatorsQUADRATICS(solver, state, k)
    lb, ub = solver.problem.lb, solver.problem.ub

    # Stationarity residual at the starting point (logged for k == 1).
    x_curr = copy(state.x)
    ProximalOperators.gradient!(x_curr, quadratics, state.x)
    # x_curr contains ∇L_β(x)
    _normal_cone_residual!(state.res, x_curr, state.x, lb, ub)

    if state.logging && k == 1
        state.hist.objective[1] = LA.norm(state.res, Inf)
        state.hist.feasibility[1] = LA.norm(solver.A(state.x))
        state.hist.nit[1] = 0
    end

    if state.triple_loop
        # All quantities below depend on the OUTER iteration k only.
        βk = solver.β(k)
        ρ = state.ρ_rule(βk, state.y)
        # Value passed to the inner solvers as `Lf`; their initial step size is 0.5 / Lf.
        Lf = 2. / state.γ_rule(βk, state.y) + 2 * ρ
        uses_adapg = state.inner_solver == AdaProx.adaptive_proxgrad
        ffb = uses_adapg ? nothing : _make_inner_ffb(state, quadratics, () -> 0.5 / Lf)

        x_curr = copy(state.x)
        x_new = copy(state.x)
        iters = 0.
        state.verbose && println("ρ = $(ρ)")

        # The loop index is deliberately not called `k`: a `for k` loop variable is a new
        # local that would shadow the outer iteration and corrupt `solver.β(k)`.
        for _ in 1:state.ippm_max_it
            F = ProximalOperators.Regularize(
                quadratics,
                2 * ρ,
                x_curr
            )
            if uses_adapg
                x_ffb, iters_ffb = run_adapg(
                    F,
                    solver.g,
                    x_curr,
                    (x, grad_x, temp, tol) -> LA.norm(_normal_cone_residual!(temp, grad_x, x, lb, ub)) <= tol,
                    state.inner_tol / 4.,
                    state.inner_max_it,
                    0.5 / Lf,
                )
            else
                x_ffb, iters_ffb = ffb(x0 = x_curr, f = F, g = solver.g, Lf = Lf, mf = ρ)
            end
            copyto!(x_new, x_ffb)
            iters += iters_ffb

            if state.verbose
                # Residuals of the AL subproblem and of the regularized proximal-point problem.
                grad_al, _ = ProximalOperators.gradient(quadratics, x_new)
                dist_al = _normal_cone_residual!(similar(grad_al), grad_al, x_new, lb, ub)
                println("dist AL = $(LA.norm(dist_al))")

                grad_pp, _ = ProximalOperators.gradient(F, x_new)
                dist_pp = _normal_cone_residual!(similar(grad_pp), grad_pp, x_new, lb, ub)
                println("dist PP = $(LA.norm(dist_pp))")
                println("opt_cond_pp = $(2 * ρ * LA.norm(x_new - x_curr))")
                println("Inner tolerance = $(2 * state.inner_tol)")
            end

            # Stop once the proximal-point step is small relative to the inner tolerance.
            if 2 * ρ * LA.norm(x_new - x_curr) <= 2 * state.inner_tol
                break
            end

            copyto!(x_curr, x_new)
        end

        copyto!(state.x, x_new)
    elseif state.inner_solver == AdaProx.adaptive_proxgrad
        x_curr, iters = state.inner_solver(
            state.x,
            f = quadratics,
            g = solver.g,
            rule = AdaProx.OurRule(gamma = state.γ_rule(solver.β(k), state.y)),
            tol = state.inner_tol,
            maxit = round(state.inner_max_it),
        )
        copyto!(state.x, x_curr)
    else
        ffb = _make_inner_ffb(state, quadratics, () -> state.γ_rule(solver.β(k), state.y))
        x_curr, iters = ffb(x0 = state.x, f = quadratics, g = solver.g)
        copyto!(state.x, x_curr)
    end

    # Stationarity residual at the new iterate (x_curr is reused as the gradient buffer).
    ProximalOperators.gradient!(x_curr, quadratics, state.x)
    # x_curr contains ∇L_β(x)
    _normal_cone_residual!(state.res, x_curr, state.x, lb, ub)

    if state.verbose
        println("-----------------\tk=$(k)")
        println("Inner iterations: \t$(iters)")
        println("Objective: \t\t$(solver.f(state.x))")
        println("Constraint violation: \t$(LA.norm(solver.A(state.x)))")
        println("Suboptimality: \t\t$(LA.norm(state.res, Inf))")
        println()
    end

    if state.logging
        state.hist.objective[k+1] = LA.norm(state.res, Inf) #solver.f(state.x)
        state.hist.feasibility[k+1] = LA.norm(solver.A(state.x))
        state.hist.nit[k+1] = state.nit + iters
    end

    return iters
end

export generate_quadratics_data, QUADRATICS, QUADRATICSSolver