
# Convert a vector to a strictly lower-triangular matrix.
#
# The entries of `v` fill the positions below the main diagonal in column-major
# order; the diagonal and everything above it are zero. The matrix size `n` is
# inferred from `length(v) == n * (n - 1) / 2`, so an empty `v` yields a 1×1
# zero matrix.
#
# Throws an `ArgumentError` when `length(v)` is not of the form `n * (n - 1) / 2`.
function vec_to_lowtril(v::AbstractVector{T}) where {T}
    n_vec = length(v)
    # n = (1 + sqrt(1 + 8 * n_vec)) / 2 is an integer exactly when 1 + 8 * n_vec
    # is a perfect square; check that with integer arithmetic instead of floats.
    discriminant = 1 + 8 * n_vec
    root = isqrt(discriminant)
    if root * root != discriminant
        msg = "length(v) = $n_vec is not of the form n * (n - 1) / 2; " *
              "cannot build a strictly lower-triangular matrix"
        throw(ArgumentError(msg))
    end
    n_tril = div(root + 1, 2) # `root` is odd here, so the division is exact
    L = zeros(T, n_tril, n_tril)
    # Boolean mask selecting the entries strictly below the diagonal.
    L[tril!(trues(n_tril, n_tril), -1)] = v
    return L
end

# Reverse-mode rule for `vec_to_lowtril`.
#
# Returns the primal matrix together with a pullback. The pullback yields
# `NoTangent()` in the slot for the function itself and, for `v`, the strictly
# lower-triangular entries of the (unthunked) incoming cotangent `Δ`, gathered
# by `lowtril_to_vec`.
function ChainRulesCore.rrule(::typeof(vec_to_lowtril), v::AbstractVector{T}) where {T}
    primal = vec_to_lowtril(v)
    function pullback_vec_to_lowtril(Δ)
        Δ_mat = unthunk(Δ)
        return NoTangent(), lowtril_to_vec(Δ_mat)
    end
    return primal, pullback_vec_to_lowtril
end

# Collect the entries strictly below the diagonal of a square matrix into a
# vector (the diagonal and upper triangle are dropped). Entries come out in the
# order produced by indexing `X` with a Boolean mask.
# Errors if `X` is not square.
# Adapted from https://stackoverflow.com/questions/50651781/extract-lower-triangle-portion-of-a-matrix
function lowtril_to_vec(X::AbstractMatrix{T}) where {T}
    nrows, ncols = size(X)
    if nrows != ncols
        error("Matrix needs to be square")
    end
    # Mask that is `true` only below the main diagonal.
    below_diag = trues(nrows, ncols)
    tril!(below_diag, -1)
    return X[below_diag]
end

# Build the `(flatten, unflatten)` pair for the `:squareroot` variant.
#
# `flatten(m, s, C)` concatenates `m` with all entries of `C`; `s` is ignored.
# `unflatten(λ)` takes the first `d` entries as `m`, reshapes the remainder to
# `d × d` and wraps it in `Hermitian`, returning `(m, nothing, C)`.
# `d` is `LogDensityProblems.dimension(problem)`.
function get_flatten_utils(::Val{:squareroot}, problem)
    d = LogDensityProblems.dimension(problem)

    flatten(m, s, C) = vcat(m, reshape(C, :))

    function unflatten(λ)
        m_part = λ[1:d]
        C_entries = λ[(d + 1):end]
        C = Hermitian(reshape(C_entries, (d, d)))
        return m_part, nothing, C
    end

    return flatten, unflatten
end

# Build the `(flatten, unflatten)` pair for the `:cholesky` variant.
#
# `flatten(m, s, L_low)` concatenates `m`, `s`, and the strictly lower-triangular
# entries of `L_low` (via `lowtril_to_vec`).
# `unflatten(λ)` splits `λ` into the first `d` entries, the next `d` entries, and
# a matrix rebuilt from the remainder with `vec_to_lowtril`.
# `d` is `LogDensityProblems.dimension(problem)`.
function get_flatten_utils(::Val{:cholesky}, problem)
    d = LogDensityProblems.dimension(problem)

    flatten(m, s, L_low) = vcat(m, s, lowtril_to_vec(L_low))

    function unflatten(λ)
        m_part = λ[1:d]
        s_part = λ[(d + 1):(2 * d)]
        L_low = vec_to_lowtril(λ[(2 * d + 1):end])
        return m_part, s_part, L_low
    end

    return flatten, unflatten
end

# Build the `(flatten, unflatten)` pair for the `:meanfield` variant.
#
# `flatten(m, s, _)` concatenates `m` and `s`; the third argument is ignored.
# `unflatten(λ)` returns the first `d` entries, the next `d` entries, and
# `nothing` in the third slot. `d` is `LogDensityProblems.dimension(problem)`.
function get_flatten_utils(::Val{:meanfield}, problem)
    d = LogDensityProblems.dimension(problem)

    flatten(m, s, ::Any) = vcat(m, s)

    function unflatten(λ)
        m_part = λ[1:d]
        s_part = λ[(d + 1):(2 * d)]
        return m_part, s_part, nothing
    end

    return flatten, unflatten
end
