using Base.Threads: @threads
using ProgressMeter, Flux
using LogExpFunctions
include("util.jl")

##### HamFlow struct
"""
    HamFlow

Hamiltonian ergodic flow on position `z` and momentum `ρ`: each flow step runs
`n_lfrg` leapfrog steps and then a momentum refreshment supplied by the caller
(e.g. `ref_coord`).

`n_ref` (the number of flow steps) is deliberately not a field, so callers such as
`flow_fwd` can vary it per call.
"""
struct HamFlow <: mixflow
    # n_ref is intentionally not stored here so it can be changed per call (see flow_fwd)
    d::Int64 # dimension of z and ρ (loop bound over coordinates in the refreshments)
    n_lfrg::Int64 # number of leapfrog steps between two momentum refreshments
    # target 
    logp::Function # log target density, called as logp(z)
    ∇logp::Function # gradient of logp w.r.t. z, used for the leapfrog momentum updates
    # VI distribution
    q_sampler::Function # called as q_sampler(d); result is scaled by a.D and shifted by a.μ
    logq0::Function # initial log density, called as logq0(z, μ, D)
    # momentum 
    # NOTE(review): called as ρ_sampler(H, d) in Sampler! but as ρ_sampler(d) in the
    # ELBO routines — both call forms must be supported; confirm intended signature
    ρ_sampler::Function
    lpdf_mom::Function # log density of the full momentum vector (scalar result)
    ∇logp_mom::Function # gradient of lpdf_mom; position moves along -∇logp_mom(ρ)
    cdf_mom::Function # per-coordinate momentum CDF, used by the refreshments
    invcdf_mom::Function # per-coordinate inverse of cdf_mom
    pdf_mom::Function # per-coordinate momentum density (not referenced in this file's code)
end

"""
leapfrog that combines 2 consecutive steps into 1
"""
function leapfrog(∇logp::Function, ∇logm::Function, n_lfrg::Int, ϵ, z, ρ)
    ρ += 0.5 .* ϵ .* ∇logp(z)
    for i in 1:(n_lfrg - 1)
        z -= ϵ .* ∇logm(ρ)
        ρ += ϵ .* ∇logp(z)
    end
    z -= ϵ .* ∇logm(ρ)
    ρ += 0.5 .* ϵ .* ∇logp(z)
    return z, ρ
end
"""
    leapfrog(o::HamFlow, ϵ, z, ρ)

Run `o.n_lfrg` leapfrog steps of size `ϵ`, using the target gradient `o.∇logp` and
the momentum gradient `o.∇logp_mom` stored in the flow.
"""
function leapfrog(o::HamFlow, ϵ, z, ρ)
    return leapfrog(o.∇logp, o.∇logp_mom, o.n_lfrg, ϵ, z, ρ)
end

"""
    ref_coord(o::HamFlow, z, ρ, u)

Refresh the momentum one coordinate at a time, in increasing order. Each step does
the following:

1. Map `ρ[i]` to `[0, 1]` with `o.cdf_mom`.
2. Add `stream(z[i], u)` modulo 1.
3. Map back with `o.invcdf_mom`.
4. Advance `u` with `time_shift`.

Returns the refreshed momentum (a copy; `ρ` is left untouched) and the final `u`.
`inv_ref_coord` is written to undo this map. That relies on `inv_timeshift` inverting
`time_shift` and on `stream` lying in `[0, 1)` — TODO confirm in util.jl.
"""
function ref_coord(o::HamFlow, z, ρ, u)
    ρ_new = copy(ρ)
    for k in 1:(o.d)
        shifted = o.cdf_mom(ρ_new[k]) + stream(z[k], u)
        ρ_new[k] = o.invcdf_mom(shifted % 1.0)
        u = time_shift(u)
    end
    return ρ_new, u
end
"""
    inv_ref_coord(o::HamFlow, z, ρ, u)

Reverse of `ref_coord`. It visits coordinates from last to first. For each one it
first steps `u` back with `inv_timeshift`, then subtracts `stream(z[i], u)` (modulo 1)
on the CDF scale.

Returns the recovered momentum (a copy; `ρ` is left untouched) and the rewound `u`.
"""
function inv_ref_coord(o::HamFlow, z, ρ, u)
    ρ_old = copy(ρ)
    for k in reverse(1:(o.d))
        u = inv_timeshift(u)
        # the + 1.0 keeps the sum non-negative before taking it modulo 1
        unshifted = o.cdf_mom(ρ_old[k]) + 1.0 - stream(z[k], u)
        ρ_old[k] = o.invcdf_mom(unshifted % 1.0)
    end
    return ρ_old, u
end

"""
    one_fwd(o::HamFlow, ϵ, refresh, z, ρ, u)

One forward flow step: leapfrog integration followed by the momentum refreshment
`refresh(o, z, ρ, u)`.

Returns `(z, ρ, u, logj)` where
`logj = lpdf_mom(refreshed ρ) - lpdf_mom(ρ before refresh)`. No term is added for
the leapfrog part.
"""
function one_fwd(o::HamFlow, ϵ, refresh, z, ρ, u)
    z_next, ρ_mid = leapfrog(o, ϵ, z, ρ)
    lp_mid = o.lpdf_mom(ρ_mid)          # evaluated before refresh, as refresh may reuse ρ_mid
    ρ_next, u_next = refresh(o, z_next, ρ_mid, u)
    logj = o.lpdf_mom(ρ_next) - lp_mid
    return z_next, ρ_next, u_next, logj
end
"""
    one_bwd(o::HamFlow, ϵ, inv_ref, z, ρ, u)

One backward flow step, mirroring `one_fwd` in reverse order. It first undoes the
momentum refreshment with `inv_ref(o, z, ρ, u)`, then runs the leapfrog integrator
with step size `-ϵ`.

Returns `(z, ρ, u, logj)` where
`logj = lpdf_mom(ρ before inv_ref) - lpdf_mom(ρ after inv_ref)`.
"""
function one_bwd(o::HamFlow, ϵ, inv_ref, z, ρ, u)
    lp_refreshed = o.lpdf_mom(ρ)
    ρ_prev, u_prev = inv_ref(o, z, ρ, u)
    logj = lp_refreshed - o.lpdf_mom(ρ_prev)
    z_prev, ρ_prev = leapfrog(o, -ϵ, z, ρ_prev)
    return z_prev, ρ_prev, u_prev, logj
end
"""
    flow_fwd(o, ϵ, refresh, z, ρ, u, n_ref)

Apply `n_ref - 1` forward steps (`one_fwd`) to `(z, ρ, u)`.

Returns the final state together with the sum of the per-step `logj` values.
"""
function flow_fwd(
    o::HamFlow, ϵ::Vector{T}, refresh, z, ρ, u, n_ref::Int64
) where {T<:Real}
    total_logj = zero(T)
    for _ in 2:n_ref
        z, ρ, u, step_logj = one_fwd(o, ϵ, refresh, z, ρ, u)
        total_logj += step_logj
    end
    return z, ρ, u, total_logj
end

"""
    flow_bwd(o, ϵ, inv_ref, z, ρ, u, n_ref)

Apply `n_ref - 1` backward steps (`one_bwd`) to `(z, ρ, u)`.

Returns the final state together with the sum of the per-step `logj` values.
"""
function flow_bwd(
    o::HamFlow, ϵ::Vector{T}, inv_ref, z, ρ, u, n_ref::Int64
) where {T<:Real}
    total_logj = zero(T)
    for _ in 2:n_ref
        z, ρ, u, step_logj = one_bwd(o, ϵ, inv_ref, z, ρ, u)
        total_logj += step_logj
    end
    return z, ρ, u, total_logj
end

"""
    flow_trace_fwd(o, ϵ, refresh, z, ρ, u, n_ref)

Record the forward trajectory of `(z, ρ, u)` under `one_fwd`.

Returns `(zs, ρs, us)`, where row `k` of `zs`/`ρs` and entry `k` of `us` hold the
state after `k - 1` steps (row 1 is the input state). All outputs use
`eltype(z)`.
"""
function flow_trace_fwd(
    o::HamFlow, ϵ::Vector{T}, refresh, z, ρ, u, n_ref::Int64
) where {T<:Real}
    S = eltype(z)
    zs, ρs = Matrix{S}(undef, n_ref, o.d), Matrix{S}(undef, n_ref, o.d)
    us = Vector{S}(undef, n_ref)
    zs[1, :] .= z
    ρs[1, :] .= ρ
    us[1] = u
    for k in 2:n_ref
        z, ρ, u, _ = one_fwd(o, ϵ, refresh, z, ρ, u)
        zs[k, :] .= z
        ρs[k, :] .= ρ
        us[k] = u
    end
    return zs, ρs, us
end

"""
    flow_trace_bwd(o, ϵ, inv_ref, z, ρ, u, n_ref)

Record the backward trajectory of `(z, ρ, u)` under `one_bwd`.

Returns `(zs, ρs, us)`, where row `k` of `zs`/`ρs` and entry `k` of `us` hold the
state after `k - 1` backward steps (row 1 is the input state). All outputs use
`eltype(z)`.
"""
function flow_trace_bwd(
    o::HamFlow, ϵ::Vector{T}, inv_ref, z, ρ, u, n_ref::Int64
) where {T<:Real}
    S = eltype(z)
    zs, ρs = Matrix{S}(undef, n_ref, o.d), Matrix{S}(undef, n_ref, o.d)
    us = Vector{S}(undef, n_ref)
    zs[1, :] .= z
    ρs[1, :] .= ρ
    us[1] = u
    for k in 2:n_ref
        z, ρ, u, _ = one_bwd(o, ϵ, inv_ref, z, ρ, u)
        zs[k, :] .= z
        ρs[k, :] .= ρ
        us[k] = u
    end
    return zs, ρs, us
end

"""
    Sampler!(T, M, U, o, a, refresh, n_ref; nBurn=0, nsample=1000)

Fill the `i`-th row of `T` (positions) and `M` (momenta), and entry `U[i]`, with one
draw from the flow for each `i` in `1:nsample`. Each draw works as follows:

1. Sample `z0 = a.D .* o.q_sampler(d) .+ a.μ`, `ρ0 = o.ρ_sampler(H, d)` and
   `u0 = rand(H)`.
2. Draw a step count uniformly from `(nBurn + 1):n_ref`.
3. Push the initial state forward with `flow_fwd` using that step count.

Draws run in parallel with `@threads`. Each iteration writes only its own row/entry.
Returns `nothing`.
"""
function Sampler!(
    T::Matrix{H},
    M::Matrix{H},
    U::Vector{H},
    o::HamFlow,
    a::HF_params,
    refresh,
    n_ref::Int;
    nBurn::Int64=0,
    nsample::Int=1000,
) where {H<:Real}
    dim = o.d
    @info "ErgFlow Sampling"
    prog_bar = ProgressMeter.Progress(
        nsample;
        dt=0.5,
        barglyphs=ProgressMeter.BarGlyphs("[=> ]"),
        barlen=50,
        color=:yellow,
    )
    step_choices = (nBurn + 1):n_ref
    @threads for i in 1:nsample
        n_step = rand(step_choices)
        z_init = a.D .* o.q_sampler(dim) .+ a.μ
        ρ_init = o.ρ_sampler(H, dim)
        u_init = rand(H)
        z, ρ, u, _ = flow_fwd(
            o, a.leapfrog_stepsize, refresh, z_init, ρ_init, u_init, n_step
        )
        T[i, :] .= z
        M[i, :] .= ρ
        U[i] = u
        ProgressMeter.next!(prog_bar)
    end
    return nothing
end

"""
    Sampler(o, a, refresh, n_ref; nBurn=0, nsample=1000)

Allocating wrapper around `Sampler!`.

Returns `(T, M, U)`: an `nsample × o.d` matrix of positions, an `nsample × o.d`
matrix of momenta and a length-`nsample` vector of `u` values. All use the element
type of `a.μ`.
"""
function Sampler(
    o::HamFlow, a::HF_params, refresh, n_ref::Int; nBurn::Int64=0, nsample::Int=1000
)
    F = eltype(a.μ)
    positions = Matrix{F}(undef, nsample, o.d)
    momenta = Matrix{F}(undef, nsample, o.d)
    shifts = Vector{F}(undef, nsample)
    Sampler!(
        positions, momenta, shifts, o, a, refresh, n_ref; nBurn=nBurn, nsample=nsample
    )
    return positions, momenta, shifts
end

"""
    log_density_eval(z, ρ, u, o, a, inv_ref, n_ref; nBurn=0)

Log density of the flow mixture at `(z, ρ, u)`. The function collects `n_ref` terms:

- Term 1 is `lpdf_mom(ρ) + logq0(z, μ, D)` at the input point.
- Term `k + 1` is `logq0 + lpdf_mom` at the state reached after `k` backward
  steps (`one_bwd`), plus the `logj` accumulated over those `k` steps.

Returns the log-mean-exp of terms `nBurn + 1` through `n_ref`.
"""
function log_density_eval(
    z, ρ, u, o::HamFlow, a::HF_params, inv_ref::Function, n_ref::Int; nBurn::Int=0
)
    ϵ, μ, D = a.leapfrog_stepsize, a.μ, a.D
    terms = Vector{eltype(z)}(undef, n_ref)
    terms[1] = o.lpdf_mom(ρ) + o.logq0(z, μ, D)
    cum_logj = zero(eltype(z))
    for k in 2:n_ref
        z, ρ, u, step_logj = one_bwd(o, ϵ, inv_ref, z, ρ, u)
        cum_logj += step_logj
        terms[k] = o.logq0(z, μ, D) + o.lpdf_mom(ρ) + cum_logj
    end
    return logmeanexp(@view(terms[(nBurn + 1):end]))
end

"""
    log_density_cum(z, ρ, u, o, a, inv_ref, n_ref)

Build the same `n_ref` backward-trajectory terms as `log_density_eval`.

Returns, for every `N` in `1:n_ref`, the log-mean-exp of the first `N` terms. The
result is computed as a running log-sum-exp minus `log(N)`.
"""
function log_density_cum(z, ρ, u, o::HamFlow, a::HF_params, inv_ref::Function, n_ref::Int)
    ϵ, μ, D = a.leapfrog_stepsize, a.μ, a.D
    terms = Vector{eltype(z)}(undef, n_ref)
    terms[1] = o.lpdf_mom(ρ) + o.logq0(z, μ, D)
    cum_logj = zero(eltype(z))
    for k in 2:n_ref
        z, ρ, u, step_logj = one_bwd(o, ϵ, inv_ref, z, ρ, u)
        cum_logj += step_logj
        terms[k] = o.logq0(z, μ, D) + o.lpdf_mom(ρ) + cum_logj
    end
    return cumlogsumexp(terms) .- log.(1:n_ref)
end

"""
    single_elbo_traj(o, a, refresh, inv_ref, n_ref, z, ρ, u; nBurn=0)

Single-trajectory ELBO estimate for the `n_ref`-component flow mixture, started at
`(z, ρ, u)`.

The state is flowed `n_ref - 1` steps backward (`one_bwd`) and `n_ref - 1` steps
forward (`one_fwd`). This records `logq0 + lpdf_mom` for all `2n_ref - 1` trajectory
states and the `logj` of every step.

For the start state and each of its forward images, the mixture log density is the
log-mean-exp over that state and its `n_ref - 1` predecessors. Each predecessor
contributes its `logq0 + lpdf_mom` plus the accumulated `logj`.

Returns the mean of `logp(z) + lpdf_mom(ρ)` over those `n_ref` states, minus the
mean of their mixture log densities.

`nBurn` is accepted for interface compatibility but currently unused.
"""
function single_elbo_traj(
    o::HamFlow, a::HF_params, refresh, inv_ref, n_ref::Int, z, ρ, u; nBurn::Int=0
)
    ϵ, μ, D = a.leapfrog_stepsize, a.μ, a.D
    ft = eltype(z)
    z0, ρ0, u0 = copy(z), copy(ρ), copy(u)

    # trajectory states are indexed 1:(2n_ref - 1); state n_ref is the start
    # logjs[i]  : logj of the step linking states i and i + 1
    # logq0s[i] : logq0 + lpdf_mom of state i
    # logqns[j] : mixture log density at the j-th state of the forward part
    logjs = zeros(ft, 2 * n_ref - 2)
    logq0s = zeros(ft, 2 * n_ref - 1)
    logqns = zeros(ft, n_ref)

    # flow bwd n_ref - 1 steps, filling states n_ref - 1, ..., 1
    @inbounds for i in (n_ref - 1):-1:1
        z0, ρ0, u0, logj = one_bwd(o, ϵ, inv_ref, z0, ρ0, u0)
        logjs[i] = logj
        logq0s[i] = o.logq0(z0, μ, D) + o.lpdf_mom(ρ0)
    end
    logq0s[n_ref] = o.lpdf_mom(ρ) + o.logq0(z, μ, D)
    logp = o.logp(z) + o.lpdf_mom(ρ)
    logqns[1] = logmeanexp(
        @view(logq0s[n_ref:-1:1]) .+
        cumsum(vcat([zero(ft)], @view(logjs[(n_ref - 1):-1:1]))),
    )

    # flow fwd n_ref - 1 steps, filling states n_ref + 1, ..., 2n_ref - 1
    @inbounds for i in n_ref:(2 * n_ref - 2)
        z, ρ, u, logj = one_fwd(o, ϵ, refresh, z, ρ, u)
        logjs[i] = logj
        logq0s[i + 1] = o.logq0(z, μ, D) + o.lpdf_mom(ρ)
        # mixture log density at state i + 1 from its n_ref most recent states;
        # zero(ft) (not a Float64 literal) keeps non-Float64 trajectories unpromoted
        logqns[i - n_ref + 2] = logmeanexp(
            @view(logq0s[(i + 1):-1:(i - n_ref + 2)]) .+
            cumsum(vcat([zero(ft)], @view(logjs[i:-1:(i - n_ref + 2)]))),
        )
        logp += o.logp(z) + o.lpdf_mom(ρ)
    end
    logp /= n_ref
    return logp - mean(logqns)
end

"""
    elbo_multiple(o, a, refresh, inv_ref, n_ref; elbo_size=1, nBurn=0)

Compute `elbo_size` independent single-trajectory ELBO estimates in parallel
(`@threads`).

Each trajectory starts from `z = a.D .* o.q_sampler(d) .+ a.μ`,
`ρ = o.ρ_sampler(d)` and `u = rand()`, converted to the element type of `a.μ`, and is
evaluated with `single_elbo_traj`.

Returns the vector of estimates (length `elbo_size`).
"""
function elbo_multiple(
    o::HamFlow, a::HF_params, refresh, inv_ref, n_ref::Int; elbo_size::Int=1, nBurn::Int=0
)
    d = o.d
    μ, D = a.μ, a.D
    ft = eltype(μ)
    el = zeros(ft, elbo_size)
    prog_bar = ProgressMeter.Progress(
        elbo_size;
        dt=0.5,
        barglyphs=ProgressMeter.BarGlyphs("[=> ]"),
        barlen=50,
        color=:yellow,
    )
    @threads for i in 1:elbo_size
        z = D .* o.q_sampler(d) .+ μ
        ρ, u = o.ρ_sampler(d), rand()
        el[i] = single_elbo_traj(
            o, a, refresh, inv_ref, n_ref, ft.(z), ft.(ρ), ft(u); nBurn=nBurn
        )
        ProgressMeter.next!(prog_bar)
    end
    # explicit return: the `@threads` loop itself evaluates to `nothing`, which would
    # make callers such as `ELBO` fail on `mean(el)`
    return el
end

"""
    ELBO(o, a, refresh, inv_ref, n_ref; elbo_size=1, nBurn=0)

Monte Carlo ELBO estimate: the mean of the `elbo_size` per-trajectory estimates
produced by `elbo_multiple`.
"""
function ELBO(
    o::HamFlow, a::HF_params, refresh, inv_ref, n_ref::Int; elbo_size::Int=1, nBurn::Int=0
)
    per_traj = elbo_multiple(
        o, a, refresh, inv_ref, n_ref; elbo_size=elbo_size, nBurn=nBurn
    )
    return mean(per_traj)
end

"""
    DensityTripleSweep(z, ρ, u, o, a, refresh, inv_ref, n_mcmc)

Flow `(z, ρ, u)` `n_mcmc - 1` steps backward (`one_bwd`) and `n_mcmc - 1` steps
forward (`one_fwd`). Returns `(logq0s, logjs, logps)`:

- `logq0s` (length `2n_mcmc - 1`): `logq0 + lpdf_mom` of every trajectory state,
  ordered from the furthest backward state to the furthest forward state. Entry
  `n_mcmc` is the starting state.
- `logjs` (length `2n_mcmc - 2`): `logj` of the step linking states `i` and `i + 1`.
- `logps` (length `n_mcmc`): `logp(z) + lpdf_mom(ρ)` of the starting state and
  its `n_mcmc - 1` forward images.

All three buffers use the input element type `T`, so any `T<:Real` is stored
without conversion to `Float64`.
"""
function DensityTripleSweep(
    z::Vector{T},
    ρ::Vector{T},
    u::T,
    o::HamFlow,
    a::HF_params,
    refresh::Function,
    inv_ref::Function,
    n_mcmc::Int,
) where {T<:Real}
    ϵ, μ, D = a.leapfrog_stepsize, a.μ, a.D
    z0, ρ0, u0 = copy(z), copy(ρ), copy(u)
    logjs = zeros(T, 2 * n_mcmc - 2)
    logq0s = zeros(T, 2 * n_mcmc - 1)
    logps = zeros(T, n_mcmc)

    # logq0 of the starting state
    logq0s[n_mcmc] = o.lpdf_mom(ρ0) + o.logq0(z0, μ, D)
    # flow bwd n_mcmc - 1 steps, filling states n_mcmc - 1, ..., 1
    @inbounds for i in (n_mcmc - 1):-1:1
        z0, ρ0, u0, logj = one_bwd(o, ϵ, inv_ref, z0, ρ0, u0)
        logjs[i] = logj
        logq0s[i] = o.logq0(z0, μ, D) + o.lpdf_mom(ρ0)
    end

    # flow fwd n_mcmc - 1 steps, filling states n_mcmc + 1, ..., 2n_mcmc - 1
    logps[1] = o.logp(z) + o.lpdf_mom(ρ)
    @inbounds for i in n_mcmc:(2 * n_mcmc - 2)
        z, ρ, u, logj = one_fwd(o, ϵ, refresh, z, ρ, u)
        logjs[i] = logj
        logq0s[i + 1] = o.logq0(z, μ, D) + o.lpdf_mom(ρ)
        logps[i - n_mcmc + 2] = o.logp(z) + o.lpdf_mom(ρ)
    end
    return logq0s, logjs, logps
end

"""
    single_elbo_sweep(o, a, refresh, inv_ref, z, ρ, u, Ns)

Single-trajectory ELBO estimates for every mixture size in `Ns`, computed from one
shared trajectory of length `maximum(Ns)` (see `DensityTripleSweep`). Inputs are
converted to the element type of `a.μ`.

For mixture size `N = Ns[k]`, the estimate is the mean of `logp + lpdf_mom` over the
first `N` forward states, minus the mean of their `N`-term mixture log densities.
Those are obtained from `logsumexp_sweep` minus `log(N)`.

Returns a vector with one estimate per entry of `Ns`.
"""
function single_elbo_sweep(
    o::HamFlow, a::HF_params, refresh, inv_ref, z, ρ, u, Ns::Vector{Int64}
)
    ft = eltype(a.μ)
    z, ρ, u = ft.(z), ft.(ρ), ft(u)

    # trajectory values: state n_mcmc of logq0s is the start
    n_mcmc = maximum(Ns)
    logq0s, logjs, logps = DensityTripleSweep(z, ρ, u, o, a, refresh, inv_ref, n_mcmc)

    K = length(Ns)
    # logqns[k, j]: logsumexp_sweep output for Ns[k] at the j-th forward state, built
    # from that state and its predecessors (most recent first) plus accumulated logj
    logqns = zeros(ft, K, n_mcmc)
    f = zeros(ft, K)
    g = zeros(ft, K)
    @inbounds begin
        for j in 1:n_mcmc
            i = n_mcmc + j - 2   # logjs index of the step leading into forward state j
            logqns[:, j] .= logsumexp_sweep(
                @view(logq0s[(i + 1):-1:(i - n_mcmc + 2)]) .+
                cumsum(vcat([zero(ft)], @view(logjs[i:-1:(i - n_mcmc + 2)]))),
                Ns,
            )
        end
        # compute elbo for each mixture size
        for k in 1:K
            f[k] = sum(@view(logps[1:Ns[k]])) / Ns[k]
            g[k] = mean(@view(logqns[k, 1:Ns[k]])) - log(Ns[k])
        end
    end
    return f .- g
end

"""
    elbo_sweep_multiple(o, a, refresh, inv_ref, Ns; elbo_size=1)

Compute `elbo_size` independent `single_elbo_sweep` estimates in parallel
(`@threads`).

Each trajectory starts from `z = a.D .* o.q_sampler(d) .+ a.μ`,
`ρ = o.ρ_sampler(d)` and `u = rand()`.

Returns an `elbo_size × length(Ns)` matrix: row `i` holds trajectory `i`'s estimates
for every entry of `Ns`.
"""
function elbo_sweep_multiple(
    o::HamFlow, a::HF_params, refresh, inv_ref, Ns::Vector{<:Int}; elbo_size::Int=1
)
    dim = o.d
    _, μ, D = a.leapfrog_stepsize, a.μ, a.D
    ft = eltype(μ)
    estimates = zeros(ft, elbo_size, length(Ns))
    prog_bar = ProgressMeter.Progress(
        elbo_size;
        dt=0.5,
        barglyphs=ProgressMeter.BarGlyphs("[=> ]"),
        barlen=50,
        color=:yellow,
    )
    @threads for i in 1:elbo_size
        z_init = D .* o.q_sampler(dim) .+ μ
        ρ_init, u_init = o.ρ_sampler(dim), rand()
        estimates[i, :] .= single_elbo_sweep(
            o, a, refresh, inv_ref, ft.(z_init), ft.(ρ_init), ft(u_init), Ns
        )
        ProgressMeter.next!(prog_bar)
    end
    return estimates
end

"""
    ELBO_sweep(o, a, refresh, inv_ref, Ns; elbo_size=1)

Monte Carlo ELBO estimate for every mixture size in `Ns`: the column means of the
`elbo_size × length(Ns)` matrix returned by `elbo_sweep_multiple`.
"""
function ELBO_sweep(
    o::HamFlow, a::HF_params, refresh, inv_ref, Ns::Vector{<:Int}; elbo_size::Int=1
)
    per_traj = elbo_sweep_multiple(o, a, refresh, inv_ref, Ns; elbo_size=elbo_size)
    column_means = mean(per_traj; dims=1)   # 1 × length(Ns)
    return vec(column_means)
end
