
"""
    is_proximal(rule) -> Bool

Report whether an optimisation rule is flagged as proximal. This fallback
covers every `Optimisers.AbstractRule` and answers `false`; individual rules
opt in by adding a more specific method that returns `true`.
"""
function is_proximal(::Optimisers.AbstractRule)
    return false
end

struct ProxGenAdam{T} <: Optimisers.AbstractRule
    #=
    Adam-style optimisation rule that is paired with a proximal step
    (see `prox_scale`). All hyperparameters share the numeric type `T`.

    Reference:
    Yun, Jihun, Aurélie C. Lozano, and Eunho Yang.
    "Adaptive proximal gradient methods for structured neural networks."
    Advances in Neural Information Processing Systems 34 (2021): 24365-24378.
    =#

  # Step size; scales the update returned by `Optimisers.apply!`.
  eta::T
  # Moment coefficients: `beta[1]` seeds the first-moment coefficient kept in
  # the optimiser state, `beta[2]` is the second-moment decay.
  beta::Tuple{T, T}
  # Constant added to `sqrt(vt)` in the denominator of the update.
  epsilon::T
end
# Positional convenience constructor. The type parameter is taken from `η`,
# and `β` and `ϵ` are handed to the parametric constructor alongside it.
# When omitted, `ϵ` defaults to `eps(eltype(η))`.
function ProxGenAdam(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(eltype(η)))
    T = typeof(η)
    return ProxGenAdam{T}(η, β, ϵ)
end

# `ProxGenAdam` opts in to the proximal flag.
function is_proximal(::ProxGenAdam)
    return true
end

# Initial optimiser state for one array, as the tuple
# (first moment, second moment, current first-moment coefficient, step counter).
# Both moments start as `zero(x)`, the coefficient starts at `o.beta[1]`, and
# the counter starts at 1.
function Optimisers.init(o::ProxGenAdam, x::AbstractArray)
    first_moment = zero(x)
    second_moment = zero(x)
    return (first_moment, second_moment, o.beta[1], 1)
end

function Optimisers.apply!(o::ProxGenAdam, state, x, dx)
    #=
    One ProxGenAdam step for a single array.

    `state` is the tuple built by `Optimisers.init` for this rule:
    (first moment `mt`, second moment `vt`, current first-moment coefficient
    `β₁ₜ`, step counter `t`). `dx` is the gradient; `x` is not read here.

    Returns the updated state together with the lazily broadcast update
    `mt / (sqrt(vt) + ϵ) * η`.
    =#
    η, β, ϵ = o.eta, o.beta, o.epsilon
    mt, vt, β₁ₜ, t = state

    Optimisers.@.. mt = β₁ₜ * mt + (1 - β₁ₜ) * dx

    # On the first step the second moment is set directly to the squared
    # gradient; afterwards it is an exponential moving average. Branching here
    # avoids computing an average that would be overwritten immediately.
    if t == 1
        Optimisers.@.. vt = abs2(dx)
    else
        Optimisers.@.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
    end

    dx′ = Optimisers.@lazy mt / (sqrt(vt) + ϵ) * η

    # The stored first-moment coefficient is multiplied by β[1] every step.
    return (mt, vt, β₁ₜ * β[1], t + 1), dx′
end

function prox_scale(optimizer::ProxGenAdam, optstate, λ::AbstractArray, unflatten, flatten)
    #=
    Proximal update of the `s` component of `λ`, with a per-element step size
    derived from the optimiser's second-moment estimate.

    References:
    Yun, Jihun, Aurélie C. Lozano, and Eunho Yang.
    "Adaptive proximal gradient methods for structured neural networks."
    Advances in Neural Information Processing Systems 34 (2021): 24365-24378.

    Domke, Justin.
    "Provable smoothness guarantees for black-box variational inference."
    International Conference on Machine Learning. PMLR, 2020.

    Arguments:
    - `optimizer`: supplies `eta` and `epsilon`.
    - `optstate`:  optimiser state; only its second entry is used.
    - `λ`:         flat parameter array.
    - `unflatten`: splits a flat array into three components; the second one
                   is the one updated here.
    - `flatten`:   packs three components back together.

    Returns `flatten(m, s_new, L_low)`; `m` and `L_low` pass through unchanged.
    =#
    step, damping = optimizer.eta, optimizer.epsilon
    _, second_moment, _ = optstate

    m, s, L_low = unflatten(λ)
    # The second-moment array is split the same way as `λ` so that its middle
    # component lines up element-for-element with `s`.
    _, v_s, _ = unflatten(second_moment)

    # Element-wise step size.
    γ = step ./ (sqrt.(v_s) .+ damping)

    # Closed-form update s + (sqrt(s^2 + 4γ) - s) / 2.
    s_new = s .+ (sqrt.(s .^ 2 .+ 4 .* γ) .- s) ./ 2

    return flatten(m, s_new, L_low)
end

# Optimisers.init(o::ProxGenAdam, x::AbstractArray) = (zero(x), zero(x), o.beta, o.epsilon)

# function Optimisers.apply!(o::ProxGenAdam, state, x, dx)
#     η, β, ϵ = o.eta, o.beta, o.epsilon
#     mt, vt, βt, ϵt = state

#     Optimisers.@.. mt = β[1] * mt + (1 - β[1]) * dx
#     Optimisers.@.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
#     dx′ = Optimisers.@lazy mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵt) * η

#     return (mt, vt, βt.*β, ϵt*0.999), dx′
# end

# function prox_scale(optimizer::ProxGenAdam, optstate, λ::AbstractArray, unflatten, flatten)
#     #=
#     Yun, Jihun, Aurélie C. Lozano, and Eunho Yang.
#     "Adaptive proximal gradient methods for structured neural networks."
#     Advances in Neural Information Processing Systems 34 (2021): 24365-24378.

#     Domke, Justin.
#     "Provable smoothness guarantees for black-box variational inference."
#     International Conference on Machine Learning. PMLR, 2020.
#     =##

#     η, β, ϵ  = optimizer.eta, optimizer.beta, optimizer.epsilon
#     _, vₜ, βt, ϵt = optstate

#     m, s, L_low = unflatten(λ)
#     _, v_s, _   = unflatten(vₜ)

#     #γ  = @. η / (sqrt(v_s) + ϵ)
#     γ  = @. η / (sqrt(v_s) + ϵt)
#     Δs = @. (sqrt(s^2 + 4*γ) - s)/2

#     s += Δs
#     flatten(m, s, L_low)
# end


# Gradient-descent rule that is paired with a proximal step (see `prox_scale`).
struct ProxDescent{T} <: Optimisers.AbstractRule
    # Step size; multiplies the gradient in `Optimisers.apply!` and sets the
    # step used by `prox_scale`.
    eta::T
end

# Zero-argument convenience constructor using the default step size `1f-3`.
# The one-argument form `ProxDescent(η)` is already supplied by the struct's
# implicit outer constructor and yields `ProxDescent{typeof(η)}`; defining it
# again with an untyped argument would replace that implicit method.
ProxDescent() = ProxDescent(1f-3)

# `ProxDescent` opts in to the proximal flag.
function is_proximal(::ProxDescent)
    return true
end

# `ProxDescent` keeps no per-array state, so the initial state is `nothing`.
function Optimisers.init(::ProxDescent, ::AbstractArray)
    return nothing
end

function Optimisers.apply!(o::ProxDescent, state, x, dx)
    # Descent update: the gradient scaled by the step size. The step size is
    # first converted to the floating-point element type of `x`. The state is
    # handed back untouched and the update is a lazy broadcast.
    T = float(eltype(x))
    step = convert(T, o.eta)
    scaled = Optimisers.@lazy dx * step
    return state, scaled
end

function prox_scale(optimizer::ProxDescent, ::Any, λ::AbstractArray, unflatten, flatten)
    #=
    Proximal update of the `s` component of `λ` with the constant step size
    `optimizer.eta`. The optimiser state (second argument) is ignored.

    References:
    Yun, Jihun, Aurélie C. Lozano, and Eunho Yang.
    "Adaptive proximal gradient methods for structured neural networks."
    Advances in Neural Information Processing Systems 34 (2021): 24365-24378.

    Domke, Justin.
    "Provable smoothness guarantees for black-box variational inference."
    International Conference on Machine Learning. PMLR, 2020.

    Arguments:
    - `optimizer`: supplies `eta`.
    - `λ`:         flat parameter array.
    - `unflatten`: splits a flat array into three components; the second one
                   is the one updated here.
    - `flatten`:   packs three components back together.

    Returns `flatten(m, s, L_low)` with the updated `s`; `m` and `L_low` pass
    through unchanged.
    =#

    # Step size converted to the floating-point element type of `λ`.
    η = convert(float(eltype(λ)), optimizer.eta)
    m, s, L_low = unflatten(λ)

    # Closed-form update s + (sqrt(s^2 + 4η) - s) / 2. Base's `@.` is used
    # directly, as in the ProxGenAdam method, rather than through `Optimisers`.
    s = @. s + (sqrt(s^2 + 4*η) - s)/2

    return flatten(m, s, L_low)
end
