using Optimisers

# Leaf case where every supplied gradient is an `Optimisers.Zero`: there is nothing
# to accumulate, so `dict` is left untouched and `nothing` is returned.
function grads!(dict::IdDict, ℓ::Optimisers.Leaf, x, ::Optimisers.Zero...)
    return nothing
end
# Generic (non-leaf) case where every supplied gradient is an `Optimisers.Zero`:
# the subtree `t` is not descended into, `dict` is left untouched, and `nothing`
# is returned.
function grads!(dict::IdDict, t, x, ::Optimisers.Zero...)
    return nothing
end
# Record the gradients `x̄s` for leaf `ℓ` in `dict`, keyed by the leaf's identity.
#
# - First time `ℓ` is seen: the tuple `x̄s` itself is stored (no copy is made), so
#   the objects handed in by the caller become the accumulators.
# - Every later time: each incoming gradient is added in place (`.+=`) onto the
#   stored one, and the incoming gradients are then passed to `maybe_free!`.
#
# `x` is not read by this method; every `grads!` method carries it in the same
# position. Always returns `nothing`.
function grads!(dict::IdDict, ℓ::Optimisers.Leaf, x, x̄s...)
    if haskey(dict, ℓ)
        x̄s₀ = dict[ℓ]
        # Accumulate pairwise into the stored gradients; all additions are finished
        # before anything is handed to `maybe_free!`.
        foreach((x̄₀, x̄) -> (x̄₀ .+= x̄), x̄s₀, x̄s)
        # An incoming gradient can be the very same object that is already stored
        # for this leaf (e.g. a gradient tree whose ties mirror the model's). Such
        # an array is still referenced from `dict`, so it must not be released.
        # NOTE(review): `maybe_free!` is defined outside this block; presumably it
        # releases the buffer of its argument - confirm.
        foreach(x̄s) do x̄
            any(x̄₀ -> x̄₀ === x̄, x̄s₀) || maybe_free!(x̄)
        end
    else
        dict[ℓ] = x̄s
    end
    return nothing
end
# Recursive case: walk the state tree `tree`, the model node `x` and the gradients
# `x̄s` in lockstep, calling `grads!` on every triple of corresponding entries.
#
# Each gradient is first passed through `Optimisers.base` and then destructured with
# `Optimisers.functor` using `typeof(x)` (the model node's type, not the gradient's
# own type); `x` is destructured the same way. Only the first element of each
# `functor` result is used.
#
# Returns whatever `Optimisers.valueforeach` returns.
function grads!(dict::IdDict, tree, x, x̄s...)
    T = typeof(x)
    x̄parts = map(x̄s) do x̄
        first(Optimisers.functor(T, Optimisers.base(x̄)))
    end
    xparts = first(Optimisers.functor(T, x))
    # Per-entry visitor: forwards to whichever `grads!` method matches the entry.
    visit(tᵢ, xᵢ, x̄sᵢ...) = grads!(dict, tᵢ, xᵢ, x̄sᵢ...)
    return Optimisers.valueforeach(visit, tree, xparts, x̄parts...)
end

# Call `f` on every individual gradient held in `dict`.
#
# Each value of `dict` is treated as an iterable collection of gradients, and `f` is
# applied to each element in turn. The results of `f` are discarded; the function
# always returns `nothing`.
function foreachgrad(f, dict::IdDict)
    foreach(values(dict)) do x̄s
        foreach(f, x̄s)
    end
    return nothing
end
