
"""
    grad_elbo!(rng, logdensityprob, λ, M, φ, ψ⁻¹, ϕ, unflatten, param_type,
               estimator_type, ad_type, sample_batch, diffresult)

Evaluate the negated `elbo` at `λ` together with its gradient with respect to `λ`.

A distribution object is first built from the current `λ` via `contruct_q`, outside
of the differentiated closure, and handed to `elbo` alongside the differentiated
parameters. When `sample_batch` is not `nothing`, it is applied to `logdensityprob`
and the returned problem is the one passed to `elbo`; the dimension is queried from
the problem as it was passed in, before `sample_batch` is applied.

Differentiation is delegated to `grad!(f, ad_type, λ, diffresult)`; the value and
gradient are read back from the object it returns through `DiffResults`.

# Returns
A tuple `(nelbo, g, q_stop)`: the negated ELBO value, its gradient as stored in the
diff result, and the distribution built from `λ` before differentiation.
"""
function grad_elbo!(rng, logdensityprob, λ::AbstractVector,
                    M::Int, φ, ψ⁻¹, ϕ,
                    unflatten, param_type, estimator_type, ad_type,
                    sample_batch, diffresult)
    dim      = LogDensityProblems.dimension(logdensityprob)
    q_frozen = contruct_q(param_type, λ, ϕ, unflatten)

    # Optionally swap in a batched version of the problem for this evaluation.
    prob = isnothing(sample_batch) ? logdensityprob : sample_batch(logdensityprob)

    # Objective handed to the AD backend: the negated ELBO as a function of λ′ only.
    negative_elbo(λ′) = -elbo(rng, estimator_type, prob, λ′, q_frozen, ψ⁻¹, ϕ, φ,
                              dim, M, param_type, unflatten)

    result = grad!(negative_elbo, ad_type, λ, diffresult)
    return DiffResults.value(result), DiffResults.gradient(result), q_frozen
end

"""
    bbvi(logdensityprob, M, T, m₀, C₀; kwargs...)

Run `T` iterations of a stochastic-gradient loop on the negated ELBO returned by
`grad_elbo!`, starting from variational parameters flattened from `m₀` and `C₀`.

# Arguments
- `logdensityprob`: target problem; forwarded to `get_flatten_utils` and `grad_elbo!`.
- `M`: forwarded to `grad_elbo!` (presumably the number of Monte Carlo samples —
  confirm against `elbo`).
- `T`: number of iterations.
- `m₀`, `C₀`: initial values. How `C₀` is split depends on `param_type`:
  `:cholesky` uses `diag(C₀)` and `tril(C₀, -1)`, `:squareroot` uses `C₀` as is,
  `:meanfield` uses only `diag(C₀)`.

# Keywords
- `rng`, `ψ⁻¹`, `ϕ`, `φ`, `ad_type`, `estimator_type`, `sample_batch`: forwarded to
  `grad_elbo!`; `ϕ` is also used to build the returned distribution.
- `optimizer`: rule used with `Optimisers.init` / `Optimisers.apply!`. If
  `is_proximal(optimizer)` is true, `prox_scale` is applied after every step.
- `show_progress`: enables the progress meter.
- `callback!`: if not `nothing`, called as `callback!(t, stats, λ, q_stop, elbo, g)`
  after each update. A returned named tuple is merged into that iteration's stats
  entry; a `nothing` return leaves the entry unchanged.
- `param_type`: one of `:cholesky`, `:squareroot`, `:meanfield`.
- `projection`: if not `nothing`, `λ = projection(λ, flatten, unflatten)` is applied
  after the callback in every iteration.
- `terminate`: if not `nothing`, `terminate(t, λ, q_stop, stats)` returning `true`
  stops the loop; remaining stats entries are filled with `elbo = -Inf`.

# Returns
`(q, λ, stats)` where `q` is built from the final `λ` by `contruct_q`. If the
negated ELBO, the gradient, or `λ` becomes non-finite, an error is logged, the
remaining stats entries are filled with `elbo = -Inf`, and
`(nothing, nothing, stats)` is returned.

# Throws
- `ArgumentError` if `param_type` is not one of the supported symbols.
"""
function bbvi(logdensityprob, M, T, m₀, C₀;
              rng            = Random.GLOBAL_RNG,
              ψ⁻¹            = Bijectors.identity,
              ϕ              = exp,
              φ              = Normal(),
              optimizer      = Optimisers.ADAM(),
              show_progress  = true,
              callback!      = nothing,
              param_type     = :cholesky,
              ad_type        = ForwardDiffAD,
              sample_batch   = nothing,
              projection     = (λ, flatten, unflatten) -> λ,
              terminate      = nothing,
              estimator_type = ClosedFormEntropy{is_proximal(optimizer)}())
    flatten, unflatten = get_flatten_utils(Val(param_type), logdensityprob)

    λ = if param_type == :cholesky
        flatten(m₀, diag(C₀), tril(C₀, -1))
    elseif param_type == :squareroot
        flatten(m₀, nothing, C₀)
    elseif param_type == :meanfield
        flatten(m₀, diag(C₀), nothing)
    else
        # Fail early instead of carrying `λ === nothing` into the optimiser setup.
        throw(ArgumentError(
            "unsupported param_type $(repr(param_type)); " *
            "expected :cholesky, :squareroot or :meanfield"))
    end

    prog     = Progress(T; enabled=show_progress, showspeed=true)
    stats    = Vector{NamedTuple}(undef, T)
    grad_buf = DiffResults.GradientResult(λ)

    optstate = Optimisers.init(optimizer, λ)
    for t = 1:T
        nelbo, g, q_stop = grad_elbo!(
            rng, logdensityprob, λ, M, φ, ψ⁻¹, ϕ,
            unflatten, param_type, estimator_type, ad_type, sample_batch, grad_buf)

        # Abort on numerical blow-up; predicate form avoids per-iteration temporaries.
        if !(isfinite(nelbo) && all(isfinite, g) && all(isfinite, λ))
            @error("ELBO is not finite!\n", nelbo, norm(g), norm(λ))
            stats[t:end] = [(t=t′, elbo=-Inf) for t′ = t:T]
            return nothing, nothing, stats
        end

        optstate, dλ = Optimisers.apply!(optimizer, optstate, λ, g)
        Optimisers.subtract!(λ, dλ)

        if is_proximal(optimizer)
            λ = prox_scale(optimizer, optstate, λ, unflatten, flatten)
        end

        stat     = (t=t, elbo=-nelbo,)
        # Store the base entry first so the callback can already see `stats[t]`.
        stats[t] = stat
        if !isnothing(callback!)
            stat′ = callback!(t, stats, λ, q_stop, -nelbo, g)
            # A callback used purely for side effects may return `nothing`.
            if !isnothing(stat′)
                stat = merge(stat, stat′)
            end
        end

        if !isnothing(projection)
            λ = projection(λ, flatten, unflatten)
        end

        stats[t] = stat

        if !isnothing(terminate) && terminate(t, λ, q_stop, stats)
            # Pad the iterations that will never run so `stats` has no undefined slots.
            stats[t+1:end] = [(t=t′, elbo=-Inf) for t′ = t+1:T]
            break
        end

        pm_next!(prog, stat)
    end
    q = contruct_q(param_type, λ, ϕ, unflatten)
    return q, λ, stats
end
