"""
    Solver

Immutable bundle of the problem data and algorithmic parameters consumed by the ALM
routines. Construct it with keyword arguments (generated by `Base.@kwdef`) or
positionally in field order.
"""
Base.@kwdef struct Solver{TP, Tf, Tg, TA, TJA, Tx, Ty, R, Tβ}
    # Problem object; not read by the code in this section.
    problem::TP
    # Smooth objective: evaluated as `f(x)` and differentiated via
    # `ProximalOperators.gradient(f, x)`.
    f::Tf
    # Not read in this section; presumably the nonsmooth term for the inner prox solver.
    g::Tg
    # Constraint map `x -> A(x)`; the dual update is driven by `A(x)`.
    A::TA
    # Returns the tuple `(Jacobian of A at x, A(x))`.
    JA::TJA
    # Initial primal point.
    x0::Tx
    # Initial dual point (multipliers).
    y0::Ty
    # Base dual step size, scaled by the schedule in `update_dual_stepsize!`.
    σ0::R
    # Initial value of `state.inner_tol`.
    initial_inner_tol::R
    # Numerator of the inner tolerance schedule `λ / β(k)`.
    λ::R
    # Penalty parameter as a function of the outer iteration index `k`.
    β::Tβ
end

"""
    ALMStateLogs

Mutable container for the history recorded during an ALM run. The fields are left
untyped, so any value (typically a preallocated numeric buffer) can be stored and
later replaced.
"""
Base.@kwdef mutable struct ALMStateLogs
    # Objective-value history (presumably one entry per outer iteration).
    objective
    # Feasibility history (presumably one entry per outer iteration).
    feasibility
    # Iteration-count history (presumably one entry per outer iteration).
    nit
end

"""
    ALMState

Mutable state of one `run_alm` call. It holds the current primal/dual iterates, step
sizes and tolerances, the configuration forwarded to the inner solver, and the logs.
Fields must stay in this order because `run_alm` may construct it positionally.
"""
Base.@kwdef mutable struct ALMState{Tx, Ty, R, TS, IR, I}
    # Current outer iteration index.
    k::I
    # Primal iterate.
    x::Tx
    # Dual iterate (multipliers paired with A(x)).
    y::Ty
    # Current dual step size.
    σ::R
    # Penalty exponent; the penalty term is β(k) / (q + 1) * ‖A(x)‖^(q + 1).
    q::R
    # Tolerance for the primal subproblem solve.
    inner_tol::R
    # Initialized to a copy of x0; presumably a residual buffer — TODO confirm.
    res::Tx
    # Cached A(x) at the current primal iterate.
    Ax::Ty
    # Initialized to A(x0); not updated in this section — verify its use.
    Ax1::Ty
    # Inner solver (run_alm's default is ProximalAlgorithms.FastForwardBackward).
    inner_solver::TS
    # Iteration budget for the inner solver.
    inner_max_it::IR
    # Presumably a lower bound on the inner step size γ — confirm.
    minimum_gamma::R
    # User callbacks (untyped); run_alm defaults both to `(_, _) -> nothing`.
    γ_rule
    ρ_rule
    # Presumably toggles adaptive γ in the inner solver — confirm.
    adaptive_gamma::Bool
    # Inner iterations accumulated over all outer iterations.
    nit::I
    verbose::Bool
    logging::Bool
    # History buffers (run_alm allocates maxit + 1 entries per field).
    hist::ALMStateLogs
    # Not read in this section; presumably consumed by solve_primal! methods.
    triple_loop::Bool
    # Not read in this section; presumably an iteration cap for an inner proximal loop.
    ippm_max_it::I
end

"""
    al(solver, state::ALMState, x, y, k)

Evaluate the augmented Lagrangian at primal point `x` and multipliers `y` for outer
iteration `k`:

    f(x) + ⟨y, A(x)⟩ + β(k) / (q + 1) * ‖A(x)‖^(q + 1),   with q = state.q.

The constraint map `A` is evaluated only once and reused by both terms.
"""
function al(solver, state::ALMState, x, y, k)
    Ax = solver.A(x)  # A may be expensive (nonlinear constraints); evaluate it once
    penalty = solver.β(k) / (state.q + 1) * LA.norm(Ax)^(state.q + 1)
    return solver.f(x) + LA.dot(y, Ax) + penalty
end

"""
    al_gradx!(grad, solver, state::ALMState, x, y, k)

Write the gradient of the augmented Lagrangian with respect to `x` into `grad`, and
return the augmented Lagrangian value at `(x, y)`.

With `q = state.q`, `Ax = A(x)` and `J` the Jacobian of `A` at `x` (both obtained from
`solver.JA(x)`), the gradient is

    ∇f(x) + Jᵀ (y + β(k) ‖Ax‖^(q - 1) Ax).

When `‖Ax‖ == 0` the penalty contribution is dropped. At that point `‖Ax‖^(q - 1)` would
be infinite for `q < 1`, while the penalty term itself is zero.
"""
function al_gradx!(grad, solver, state::ALMState, x, y, k)
    grad_fx, fx = ProximalOperators.gradient(solver.f, x)
    JA, Ax = solver.JA(x)
    βk = solver.β(k)
    nrm = LA.norm(Ax)

    # Fold the multiplier and the penalty term into one vector. This needs a single
    # transposed-Jacobian product, instead of two products plus a scaled copy of JA'.
    w = iszero(nrm) ? y : y .+ (βk * nrm^(state.q - 1)) .* Ax
    grad[:] .= grad_fx .+ JA' * w

    return fx + LA.dot(y, Ax) + βk / (state.q + 1) * nrm^(state.q + 1)
end

"""
    update_inner_tolerance!(solver, state::ALMState, k)

Set `state.inner_tol` for outer iteration `k` to `λ / β(k)`, floored at `1e-8`, and
return the new tolerance.
"""
function update_inner_tolerance!(solver, state::ALMState, k)
    scheduled = solver.λ / solver.β(k)
    tol = max(scheduled, 1e-8)  # never request more than 1e-8 accuracy from the inner solve
    state.inner_tol = tol
    return tol
end

"""
    solve_primal!(solver, state::ALMState, k)

Generic fallback for the primal subproblem of outer iteration `k`; it always throws.

Specialized methods (presumably defined elsewhere for concrete `solver` types) must
return the number of inner iterations performed. `perform_alm_iteration!` returns that
count, and `run_alm` adds it to `state.nit`. They are presumably also expected to update
`state.x`, because `A(state.x)` is recomputed right after this call.
"""
solve_primal!(solver, state::ALMState, k) = throw(ErrorException("Not implemented"))

"""
    update_dual_stepsize!(solver, state::ALMState, k)

Set the dual step size for outer iteration `k` using the schedule tuned after Sahin et al.:

    σ = σ0 * min(1 / √(k + 1), 100 / (‖Ax‖^q (k + 1) log10(k + 2)^2)),

where `Ax = state.Ax` and `q = state.q`. Return the new step size.

If `‖Ax‖^q` is zero, the second candidate is `Inf` and the first one is selected.
"""
function update_dual_stepsize!(solver, state::ALMState, k)
    decaying = 1. / sqrt(k + 1)
    damped = 100 / (LA.norm(state.Ax)^state.q * (k + 1) * log(10, k + 2)^2)
    σ = solver.σ0 * min(decaying, damped)
    state.σ = σ
    return σ
end

"""
    update_dual!(solver, state::ALMState)

Perform the dual ascent step `y ← y + σ ‖Ax‖^(q - 1) Ax` in place. It uses the cached
`state.Ax`, the step size `state.σ` and the exponent `state.q`. Return `state.y`.

When `‖Ax‖ == 0` the increment is mathematically zero, so `y` is left untouched. This
avoids `‖Ax‖^(q - 1) = Inf` for `q < 1`, whose product with the zero vector would
otherwise turn `y` into NaN. `al_gradx!` applies the same zero-norm guard.
"""
function update_dual!(solver, state::ALMState)
    nrm = LA.norm(state.Ax)
    iszero(nrm) && return state.y

    # Fused broadcast: no temporary array for the scaled constraint residual.
    state.y .+= (state.σ * nrm^(state.q - 1)) .* state.Ax
    return state.y
end

"""
    perform_alm_iteration!(solver, state::ALMState, k)

Run outer ALM iteration `k` in place on `state` and return the inner iteration count
reported by `solve_primal!`.

The steps, in order, are:
1. set the inner tolerance;
2. approximately solve the primal subproblem;
3. refresh the cached `A(x)`;
4. choose the dual step size;
5. take the dual ascent step.
"""
function perform_alm_iteration!(solver, state::ALMState, k)
    update_inner_tolerance!(solver, state, k)
    inner_iterations = solve_primal!(solver, state, k)

    # The step-size rule and the dual update both read `state.Ax`; refresh it once here.
    state.Ax .= solver.A(state.x)

    update_dual_stepsize!(solver, state, k)
    update_dual!(solver, state)

    return inner_iterations
end

"""
    run_alm(solver, should_stop; kwargs...) -> ALMState

Run the augmented Lagrangian method for at most `maxit` outer iterations and return
the final `ALMState`.

Each outer iteration:
1. calls `perform_alm_iteration!` (inner primal solve, then a dual ascent step);
2. adds the returned inner iteration count to `state.nit`;
3. calls the hooks `update_q(solver, state, k)` and `update_inner_max_it(solver, state, k)`.

The loop ends early as soon as `should_stop(solver, state)` returns `true`.

# Keywords
- `q = 1.`: exponent of the penalty `β(k) / (q + 1) * ‖A(x)‖^(q + 1)`.
- `maxit = 20`: maximum number of outer iterations; each `state.hist` buffer gets
  `maxit + 1` entries.
- `tol = 1e-8`: accepted for interface compatibility; not used by this function.
- `inner_max_it`, `inner_solver`, `minimum_gamma`, `γ_rule`, `ρ_rule`, `adaptive_gamma`,
  `verbose`, `logging`, `triple_loop`, `ippm_max_it`: stored in the state, presumably
  for `solve_primal!` methods defined elsewhere.
- `update_q`, `update_inner_max_it`: per-iteration hooks; no-ops by default.

`σ0`, `initial_inner_tol`, `q` and `minimum_gamma` share one scalar field type in
`ALMState`. They are all converted to `float(typeof(solver.σ0))`, so e.g. `q = 2` or a
`Float32` `σ0` no longer raises a `MethodError`.
"""
function run_alm(
    solver,
    should_stop;
    q = 1.,
    maxit = 20,
    tol = 1e-8,
    inner_max_it = 25,
    inner_solver = ProximalAlgorithms.FastForwardBackward,
    update_q = (_, _, _) -> (),
    update_inner_max_it = (_, _, _) -> (),
    minimum_gamma = 1e-7,
    γ_rule = (_, _) -> nothing,
    ρ_rule = (_, _) -> nothing,
    adaptive_gamma = true,
    verbose = false,
    logging = false,
    triple_loop = false,
    ippm_max_it = 10,
)
    # Common scalar type for the R-typed fields of ALMState (identity for Float64 input).
    R = float(typeof(solver.σ0))

    # Evaluate the constraint map once and give Ax/Ax1 independent copies. An in-place
    # update of state.Ax then can never alias state.Ax1 (or a buffer owned by A).
    Ax0 = solver.A(solver.x0)

    state = ALMState(
        k = 1,
        x = copy(solver.x0),
        y = copy(solver.y0),
        σ = convert(R, solver.σ0),
        q = convert(R, q),
        inner_tol = convert(R, solver.initial_inner_tol),
        res = copy(solver.x0),
        Ax = copy(Ax0),
        Ax1 = copy(Ax0),
        inner_solver = inner_solver,
        inner_max_it = inner_max_it,
        minimum_gamma = convert(R, minimum_gamma),
        γ_rule = γ_rule,
        ρ_rule = ρ_rule,
        adaptive_gamma = adaptive_gamma,
        nit = 0,
        verbose = verbose,
        logging = logging,
        hist = ALMStateLogs(
            objective = zeros(maxit + 1),
            feasibility = zeros(maxit + 1),
            nit = zeros(maxit + 1),
        ),
        triple_loop = triple_loop,
        ippm_max_it = ippm_max_it,
    )

    for k in 1:maxit
        state.k = k

        state.nit += perform_alm_iteration!(solver, state, k)

        update_q(solver, state, k)
        update_inner_max_it(solver, state, k)

        should_stop(solver, state) && break
    end

    return state
end

# Public entry point: the ALM driver is the only name exported in this section.
export run_alm