using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalCore: Zero
using LinearAlgebra
using Printf

"""
    UPFAGIteration(; <keyword-arguments>)

Iterator implementing the unified problem-parameter free accelerated gradient algorithm [1].

This iterator solves nonconvex optimization problems of the form

    minimize f(x) + g(x),

where `f` is Hölder smooth and `g` is convex.

See also: [`UPFAG`](@ref).

# Arguments
- `x0`: initial point.
- `f=Zero()`: smooth objective term.
- `g=Zero()`: proximable objective term.
- `mf=0`: convexity modulus of `f`.
- `Lf=nothing`: Lipschitz constant of the gradient of `f`.
- `gamma=nothing`: stepsize, defaults to `1/Lf` if `Lf` is set, and `nothing` otherwise.
- `adaptive=true`: makes `gamma` adaptively adjust during the iterations; by default this is `gamma === nothing`.
- `minimum_gamma=1e-7`: lower bound to `gamma` in case `adaptive == true`.
- `max_backtracks=20`: maximum number of backtracking steps performed by each of the two line searches of an iteration.
- `gamma_1=0.5`: factor by which the stepsize of the accelerated step is reduced at each backtracking step.
- `gamma_2=0.5`: factor by which the stepsize of the proximal-gradient step is reduced at each backtracking step.
- `gamma_3=0.005`: coefficient of the sufficient-decrease term in the line search of the proximal-gradient step.
- `δ=1e-3`: tolerance constant appearing in the line-search condition of the accelerated step.

# References
1. Saeed Ghadimi, Guanghui Lan, and Hongchao Zhang. Generalized Uniformly Optimal Methods for Nonlinear
Programming. Journal of Scientific Computing, 79(3):1854–1881, June 2019.
"""
Base.@kwdef struct UPFAGIteration{R,Tx,Tf,Tg,TLf,Tgamma}
    # Smooth term: the iteration evaluates both its value and its gradient.
    f::Tf = Zero()
    # Nonsmooth term: only accessed through its proximal mapping.
    g::Tg = Zero()
    # Starting point; it is copied, never mutated, by the iteration.
    x0::Tx
    # Convexity modulus of `f`. Not referenced by the iteration code in this file.
    mf::R = real(eltype(x0))(0)
    # Lipschitz constant of the gradient of `f`; only used to derive the default `gamma`.
    Lf::TLf = nothing
    # Stepsize of the forward-backward step used for the residual; it also
    # initializes the trial stepsizes of both line searches.
    gamma::Tgamma = Lf === nothing ? nothing : (1 / Lf)
    # Defaults to `true` when no stepsize is available. Not referenced by the
    # iteration code in this file.
    adaptive::Bool = gamma === nothing
    # Lower bound for `gamma`. Not referenced by the iteration code in this file.
    minimum_gamma::R = real(eltype(x0))(1e-7)
    # Maximum number of backtracking steps in each of the two line searches.
    max_backtracks::Int = 20
    # Reduction factor applied to `η` at each backtracking step of the accelerated step.
    gamma_1::R = real(eltype(x0))(0.5)
    # Reduction factor applied to `β` at each backtracking step of the proximal-gradient step.
    gamma_2::R = real(eltype(x0))(0.5)
    # Coefficient of the squared-distance term in the sufficient-decrease test
    # of the proximal-gradient step.
    gamma_3::R = real(eltype(x0))(0.005)
    # Constant appearing (multiplied by `α`) in the line-search condition of the
    # accelerated step.
    δ::R = real(eltype(x0))(1e-3)
end

# The iteration never terminates on its own: stopping is decided by whoever
# consumes the iterator.
function Base.IteratorSize(::Type{<:UPFAGIteration})
    return Base.IsInfinite()
end

# Mutable working state of `UPFAGIteration`. All array fields are buffers that
# are overwritten in place at every iteration.
Base.@kwdef mutable struct UPFAGState{R,Tx}
    x::Tx             # iterate produced by the proximal step of the accelerated update
    f_x::R            # value of f returned by the most recent stored gradient evaluation
    grad_f_x::Tx      # gradient buffer; overwritten at several points within one iteration
    gamma::R          # stepsize of the forward-backward step used to compute `z` and `res`
    λhat::R           # initial trial value of η (and hence λ) at every iteration
    βhat::R           # initial trial value of β at every iteration
    y::Tx             # forward (gradient-step) point; scratch input of the prox calls
    z::Tx             # forward-backward point computed from `x_ag` with stepsize `gamma`
    g_z::R            # value returned by the most recent prox call
    res::Tx           # fixed-point residual tested by the default stopping criterion
    x_prev::Tx = copy(x)      # copy of `x` taken at the start of the iteration
    xag_prev::Tx = copy(x)    # previous aggregated point; base point of the xmd, x_tilde and xbar updates
    x_ag::Tx = copy(x)        # aggregated point; this is what `default_solution` returns
    xmd::Tx = copy(x)         # convex combination of `xag_prev` and `x_prev` where the gradient is evaluated
    x_tilde::Tx = copy(x)     # convex combination of `xag_prev` and `x` (accelerated candidate)
    xbar::Tx = copy(x)        # proximal-gradient step from `xag_prev` with stepsize β (second candidate)
    λ::R = zero(gamma)        # stepsize of the accelerated step, computed from η and Λ
    η::R = zero(gamma)        # trial parameter of the accelerated line search
    β::R = zero(gamma)        # trial stepsize of the proximal-gradient line search
    tau1::R = zero(gamma)     # not read or written by the iteration code
    tau2::R = zero(gamma)     # not read or written by the iteration code
    Λ::R = zero(gamma)        # running sum of the accepted λ values
    α::R = zero(gamma)        # averaging weight, λ / Λ
    it::Int64 = 1             # iteration counter; incremented once per iteration
end

# First iteration: builds the initial state from `iter.x0`.
#
# Evaluates the gradient and value of `f` at the starting point, takes one
# forward-backward step with stepsize `iter.gamma`, and uses `iter.gamma` as
# the initial trial stepsize of both line searches (`λhat` and `βhat`).
#
# Throws an `ArgumentError` when no stepsize is available, i.e. when neither
# `gamma` nor `Lf` was given to the constructor. Without this check the
# broadcast `gamma .* grad_f_x` below fails with an opaque `MethodError`
# on `nothing`.
function Base.iterate(iter::UPFAGIteration)
    if iter.gamma === nothing
        throw(ArgumentError(
            "UPFAGIteration needs an initial stepsize: provide either `gamma` or `Lf`",
        ))
    end

    x = copy(iter.x0)
    grad_f_x, f_x = ProximalAlgorithms.gradient(iter.f, x)
    gamma = iter.gamma

    # Forward-backward step from the starting point.
    y = x - gamma .* grad_f_x
    z, g_z = ProximalAlgorithms.prox(iter.g, y, gamma)

    state = UPFAGState(
        x=x, f_x=f_x, grad_f_x=grad_f_x, gamma=gamma, λhat=gamma, βhat=gamma,
        y=y, z=z, g_z=g_z, res=x - z,
    )
    return state, state
end

# Computes one trial of the accelerated step for the current value of `state.η`.
#
# Updates `λ`, `Λ`, `α`, `xmd`, `grad_f_x` (gradient at `xmd`), `y`, `x`, `g_z`
# and `x_tilde` in place, and returns `(f(xmd), f(x_tilde))`. The caller is
# responsible for subtracting a rejected `λ` from `Λ` before retrying.
function _upfag_accelerated_trial!(iter::UPFAGIteration, state::UPFAGState)
    state.λ = 0.5 * (state.η + sqrt(state.η^2 + 4 * state.η * state.Λ))
    state.Λ += state.λ
    state.α = state.λ / state.Λ

    state.xmd .= (1 - state.α) .* state.xag_prev .+ state.α .* state.x_prev

    f_md = ProximalAlgorithms.gradient!(state.grad_f_x, iter.f, state.xmd)
    state.y .= state.x_prev .- state.λ .* state.grad_f_x
    state.g_z = ProximalAlgorithms.prox!(state.x, iter.g, state.y, state.λ)

    state.x_tilde .= (1 - state.α) .* state.xag_prev .+ state.α .* state.x
    f_tilde = iter.f(state.x_tilde)

    return f_md, f_tilde
end

# Subsequent iterations of UPFAG.
#
# Each iteration builds two candidates from the previous aggregated point
# `xag_prev`:
#   1. `x_tilde`, from an accelerated step with backtracking on `η`;
#   2. `xbar`, from a plain proximal-gradient step with backtracking on `β`;
# and then sets `x_ag` to whichever of `x_tilde`, `xbar`, `xag_prev` has the
# smallest value of `f` (the value of `g` is not included in the comparison).
# Finally a forward-backward step with stepsize `state.gamma` is taken from
# `x_ag` to obtain the residual used by the stopping criterion.
#
# `λhat` and `βhat` are never updated here, so both line searches restart from
# their initial trial values at every iteration.
#
# Returns `(state, state)`; `state` is mutated in place.
function Base.iterate(iter::UPFAGIteration{R}, state::UPFAGState{R,Tx}) where {R,Tx}
    # Snapshot the points of the previous iteration. `xag_prev` must be
    # refreshed here, before it is used as base point below; refreshing it only
    # at the end of the iteration (before `x_ag` is updated) would make every
    # iteration work with an aggregated point that is one step stale.
    copyto!(state.x_prev, state.x)
    copyto!(state.xag_prev, state.x_ag)

    # Step 1: compute x_tilde (accelerated step with backtracking on η)
    state.η = state.λhat
    f_md, f_tilde = _upfag_accelerated_trial!(iter, state)

    for k in 1:iter.max_backtracks
        # Line-search condition:
        #   f(x̃) <= f(xmd) + α <∇f(xmd), x - x_prev> + α/(2λ) ‖x - x_prev‖² + δ α
        dx = state.x .- state.x_prev
        upper_bound = f_md + state.α * dot(state.grad_f_x, dx) +
            state.α / (2 * state.λ) * norm(dx)^2 + iter.δ * state.α
        if f_tilde <= upper_bound
            break
        end

        # Reject the trial: undo its contribution to Λ and shrink η.
        # On the last allowed backtrack η is set to zero.
        state.Λ -= state.λ
        state.η = k >= iter.max_backtracks ? R(0) : state.η * iter.gamma_1
        f_md, f_tilde = _upfag_accelerated_trial!(iter, state)
    end

    psi_xtilde = f_tilde

    # Step 2: compute xbar (proximal-gradient step with backtracking on β)
    state.β = state.βhat

    state.f_x = ProximalAlgorithms.gradient!(state.grad_f_x, iter.f, state.xag_prev)
    state.y .= state.xag_prev .- state.β .* state.grad_f_x
    state.g_z = ProximalAlgorithms.prox!(state.xbar, iter.g, state.y, state.β)

    psi_ag_old = iter.f(state.xag_prev)
    psi_xbar = iter.f(state.xbar)

    # The sufficient-decrease test is relaxed by a slack that shrinks as 1/it.
    state.it += 1
    tol = 1. / state.it

    for k in 1:iter.max_backtracks
        if psi_xbar <= psi_ag_old - iter.gamma_3 / (2 * state.β) * norm(state.xbar .- state.xag_prev)^2 + tol
            break
        end

        # On the last allowed backtrack β is set to zero.
        state.β = k >= iter.max_backtracks ? R(0) : state.β * iter.gamma_2

        # grad_f_x still holds the gradient at xag_prev, only the stepsize changed.
        state.y .= state.xag_prev .- state.β .* state.grad_f_x
        state.g_z = ProximalAlgorithms.prox!(state.xbar, iter.g, state.y, state.β)

        psi_xbar = iter.f(state.xbar)
    end

    # Step 3: update x_ag with the best of x_tilde, xbar and the previous x_ag
    if psi_xtilde <= psi_ag_old && psi_xtilde <= psi_xbar
        copyto!(state.x_ag, state.x_tilde)
    elseif psi_xbar <= psi_ag_old && psi_xbar <= psi_xtilde
        copyto!(state.x_ag, state.xbar)
    end
    # Otherwise neither candidate is accepted and x_ag keeps its value
    # (which equals xag_prev at this point).

    # Forward-backward step from x_ag. The residual is measured at x_ag, the
    # point from which z was computed and the one returned as solution.
    state.f_x = ProximalAlgorithms.gradient!(state.grad_f_x, iter.f, state.x_ag)
    state.y .= state.x_ag .- state.gamma .* state.grad_f_x
    state.g_z = ProximalAlgorithms.prox!(state.z, iter.g, state.y, state.gamma)
    state.res .= state.x_ag .- state.z

    return state, state
end

# Stop as soon as the infinity norm of the residual, scaled by the stepsize,
# falls to `tol` or below.
function default_stopping_criterion(tol, ::UPFAGIteration, state::UPFAGState)
    scaled_residual = norm(state.res, Inf) / state.gamma
    return scaled_residual <= tol
end
# The solution reported to the caller is the aggregated point.
function default_solution(::UPFAGIteration, state::UPFAGState)
    return state.x_ag
end
# Print one summary line: iteration count, stepsize, and the scaled residual
# norm that the default stopping criterion tests.
function default_display(it, ::UPFAGIteration, state::UPFAGState)
    scaled_residual = norm(state.res, Inf) / state.gamma
    return @printf("%5d | %.3e | %.3e\n", it, state.gamma, scaled_residual)
end

"""
    UPFAG(; <keyword-arguments>)

Constructs the unified problem-parameter free accelerated gradient algorithm [1].

This algorithm solves nonconvex optimization problems of the form

    minimize f(x) + g(x),

where `f` is Hölder smooth and `g` is convex.

The returned object has type `IterativeAlgorithm{UPFAGIteration}`,
and can be called with the problem's arguments to trigger its solution.

See also: [`UPFAGIteration`](@ref), [`IterativeAlgorithm`](@ref).

# Arguments
- `maxit::Int=10_000`: maximum number of iterations
- `tol::Float64=1e-8`: tolerance for the default stopping criterion
- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when the iteration should stop
- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution
- `verbose::Bool=false`: whether the algorithm state should be displayed
- `freq::Int=100`: every how many iterations to display the algorithm state
- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
- `kwargs...`: additional keyword arguments to pass on to the `UPFAGIteration` constructor upon call

# References
1. Saeed Ghadimi, Guanghui Lan, and Hongchao Zhang. Generalized Uniformly Optimal Methods for Nonlinear
Programming. Journal of Scientific Computing, 79(3):1854–1881, June 2019.
"""
function UPFAG(;
    maxit=10_000,
    tol=1e-8,
    stop=(iter, state) -> default_stopping_criterion(tol, iter, state),
    solution=default_solution,
    verbose=false,
    freq=100,
    display=default_display,
    kwargs...
)
    # Solver-level options are passed by name; everything else in `kwargs`
    # is forwarded untouched to the `UPFAGIteration` constructor.
    return ProximalAlgorithms.IterativeAlgorithm(
        UPFAGIteration;
        maxit=maxit,
        stop=stop,
        solution=solution,
        verbose=verbose,
        freq=freq,
        display=display,
        kwargs...,
    )
end

# Aliases
#
# NOTE(review): these alias names do not mention UPFAG and look like they were
# carried over from another solver file. If any other file included in the same
# module already defines `FastProximalGradientIteration` / `FastProximalGradient`,
# these `const` definitions will conflict with it. Confirm they are intended.

const FastProximalGradientIteration = UPFAGIteration
const FastProximalGradient = UPFAG

export UPFAG
