using Base.Threads: @threads
using StatsBase, LinearAlgebra
using Flux
include("../util/train.jl")
# we do not consider sparsity for the flow
"""
    sb_HF

Container for a Hamiltonian ("Ham") flow.

The number of refreshments `n_ref` is deliberately not stored here so that it can be
varied from call to call; the per-refresh standardization parameters (`μs`, `L_invs`)
are likewise passed to the flow functions as separate arguments.

# Fields
- `d`: dimension used when drawing samples and sizing buffers.
- `n_lfrg`: number of leapfrog steps between two refreshments.
- `logp`: log density of the target.
- `∇logp`: gradient of `logp`.
- `q_sampler`: sampler of the reference (VI) distribution.
- `logq0`: log density of the reference distribution.
- `lpdf_mom`: log density of the momentum distribution.
- `∇lpdf_mom`: gradient of `lpdf_mom`.
"""
mutable struct sb_HF
    # problem size and integrator length
    d::Int64
    n_lfrg::Int64
    # target distribution
    logp::Function
    ∇logp::Function
    # reference (VI) distribution
    q_sampler::Function
    logq0::Function
    # momentum distribution
    lpdf_mom::Function
    ∇lpdf_mom::Function
end

# function oneleapfrog!(∇logp::Function, ∇logm::Function, ϵ::Vector{T}, z, ρ) where T<:Real
#     ρ .+= 0.5 .* ϵ .* ∇logp(z) 
#     z .-= ϵ .* ∇logm(ρ)
#     ρ .+= 0.5 .* ϵ .* ∇logp(z) 
# end

# function oneleapfrog(∇logp::Function, ∇logm::Function, ϵ::Vector{T}, z, ρ) where T<:Real
#     ρ += 0.5 .* ϵ .* ∇logp(z) 
#     z -= ϵ .* ∇logm(ρ)
#     ρ += 0.5 .* ϵ .* ∇logp(z) 
#     return z, ρ
# end

"""
    leapfrog!(∇logp, ∇logm, ϵ, z, ρ, L)

Run `L` leapfrog steps with per-coordinate step sizes `ϵ`, overwriting `z` and `ρ`
in place.

The momentum `ρ` receives a half update `0.5 .* ϵ .* ∇logp(z)` at the start and at
the end, and a full update `ϵ .* ∇logp(z)` between consecutive position updates.
The position update is `z .-= ϵ .* ∇logm(ρ)`. Returns the mutated `ρ`.
"""
function leapfrog!(∇logp::Function, ∇logm::Function, ϵ::Vector{T}, z, ρ, L::Int64) where T<:Real
    half_ϵ = 0.5 .* ϵ
    # opening half update of the momentum
    ρ .+= half_ϵ .* ∇logp(z)
    # L-1 interior (position update, full momentum update) pairs
    for _ in 2:L
        z .-= ϵ .* ∇logm(ρ)
        ρ .+= ϵ .* ∇logp(z)
    end
    # last position update followed by the closing half update
    z .-= ϵ .* ∇logm(ρ)
    ρ .+= half_ϵ .* ∇logp(z)
    return ρ
end

"""
    leapfrog(∇logp, ∇logm, ϵ, z, ρ, L)

Non-mutating counterpart of `leapfrog!`: run `L` leapfrog steps with per-coordinate
step sizes `ϵ` and return the new `(z, ρ)`. The arguments `z` and `ρ` are only
rebound locally, never written to.
"""
function leapfrog(∇logp::Function, ∇logm::Function, ϵ::Vector{T}, z, ρ, L::Int64) where T<:Real
    half_ϵ = 0.5 .* ϵ
    # opening half update of the momentum
    ρ = ρ + half_ϵ .* ∇logp(z)
    # L-1 interior (position update, full momentum update) pairs
    for _ in 2:L
        z = z - ϵ .* ∇logm(ρ)
        ρ = ρ + ϵ .* ∇logp(z)
    end
    # last position update followed by the closing half update
    z = z - ϵ .* ∇logm(ρ)
    ρ = ρ + half_ϵ .* ∇logp(z)
    return z, ρ
end

"""
    normalize(o, i, ρ, μs, L_invs)

Standardize the momentum `ρ` with the parameters of the `i`-th refreshment.

`μs` and `L_invs` are flat vectors that store `o.d` entries per refreshment; entries
`o.d*(i-1)+1 : o.d*i` are used. Returns `L_inv .* (ρ .- μ)` as a new array.
"""
function normalize(o::sb_HF, i::Int64, ρ, μs, L_invs)
    # the dimension comes from the flow object (a bare `d` is not defined in this scope)
    block = o.d*(i-1)+1 : o.d*i
    μ, L_inv = @view(μs[block]), @view(L_invs[block])
    return L_inv .* (ρ .- μ)
end

"""
    inv_normalize(o, i, ρ, μs, L_invs)

Undo the standardization of the `i`-th refreshment.

`μs` and `L_invs` are flat vectors that store `o.d` entries per refreshment; entries
`o.d*(i-1)+1 : o.d*i` are used. Returns `ρ ./ L_inv .+ μ` as a new array.
"""
function inv_normalize(o::sb_HF, i::Int64, ρ, μs, L_invs)
    # the dimension comes from the flow object (a bare `d` is not defined in this scope)
    block = o.d*(i-1)+1 : o.d*i
    μ, L_inv = @view(μs[block]), @view(L_invs[block])
    return ρ ./ L_inv .+ μ
end

"""
    batch_normalize!(μp, L_inv, ps)

Standardize a batch of momenta in place: every row of `ps` has `μp` subtracted and
is then scaled elementwise by `L_inv`. Returns the mutated `ps`.
"""
function batch_normalize!(μp, L_inv, ps)
    # shift then scale, fused into one pass over `ps`
    ps .= (ps .- μp') .* L_inv'
    return ps
end


"""
    one_fwd(o, ϵ, z, ρ, i, μs, L_invs)

One forward layer of the flow: `o.n_lfrg` leapfrog steps with step sizes `ϵ`,
followed by the `i`-th momentum standardization. Returns the new `(z, ρ)`.
"""
function one_fwd(o::sb_HF, ϵ::Vector{T}, z, ρ, i::Int64, μs, L_invs) where T<:Real
    z_next, ρ_raw = leapfrog(o.∇logp, o.∇lpdf_mom, ϵ, z, ρ, o.n_lfrg)
    return z_next, normalize(o, i, ρ_raw, μs, L_invs)
end
"""
    one_bwd(o, ϵ, z, ρ, i, μs, L_invs)

One backward layer of the flow: undo the `i`-th momentum standardization, then run
`o.n_lfrg` leapfrog steps with the negated step sizes `-ϵ`. Returns the new `(z, ρ)`.
"""
function one_bwd(o::sb_HF, ϵ::Vector{T}, z, ρ, i::Int64, μs, L_invs) where T<:Real
    ρ_raw = inv_normalize(o, i, ρ, μs, L_invs)
    return leapfrog(o.∇logp, o.∇lpdf_mom, -ϵ, z, ρ_raw, o.n_lfrg)
end

"""
    flow(o, ϵ, z, ρ, n_ref, direction, μs, L_invs)

Push `(z, ρ)` through `n_ref` refreshments of the flow.

# Arguments
- `ϵ`: per-coordinate leapfrog step sizes.
- `n_ref`: number of refreshments to apply.
- `direction`: `"fwd"` applies `one_fwd` for `i = 1, …, n_ref`; `"bwd"` applies
  `one_bwd` for `i = n_ref, …, 1`. Any other value raises an error.
- `μs`, `L_invs`: flat standardization parameters, `o.d` entries per refreshment.

# Returns
`(z, ρ, logJ)` where `logJ` is minus the sum of `log|L_invs|` over the first
`o.d*n_ref` entries, i.e. only over the refreshments that are actually applied.
The same `logJ` is returned for both directions.
"""
function flow(o::sb_HF, ϵ::Vector{T}, z, ρ, n_ref::Int64, direction::String, μs, L_invs) where T<:Real
    # Only the refreshments that are applied contribute; `L_invs` may hold more
    # than `n_ref` blocks when the caller evaluates a truncated flow.
    logJ = -sum(log.(abs.(@view(L_invs[1:o.d*n_ref]))))
    if direction == "fwd"
        for i in 1:n_ref
            z, ρ = one_fwd(o, ϵ, z, ρ, i, μs, L_invs)
        end
    elseif direction == "bwd"
        for i in n_ref:-1:1
            z, ρ = one_bwd(o, ϵ, z, ρ, i, μs, L_invs)
        end
    else
        error("direction should be fwd or bwd")
    end
    return z, ρ, logJ
end

"""
    log_density(o, ϵ, z, ρ, n_ref, μs, L_invs)

Log density of the flow at `(z, ρ)`: the point is pulled back with the `"bwd"` flow
to `(z0, ρ0)`, and the result is `o.logq0(z0) + o.lpdf_mom(ρ0) + logJ`, with `logJ`
as returned by `flow`.
"""
function log_density(o::sb_HF, ϵ::Vector{T}, z, ρ, n_ref::Int64, μs, L_invs) where T<:Real
    z_init, ρ_init, logJ = flow(o, ϵ, z, ρ, n_ref, "bwd", μs, L_invs)
    return o.logq0(z_init) + o.lpdf_mom(ρ_init) + logJ
end

"""
    single_elbo(o, ϵ, z, ρ, n_ref, μs, L_invs)

Single-sample ELBO term for the starting point `(z, ρ)`.

The point is pushed forward with the `"fwd"` flow to `(zn, ρn)`; the result is
`o.logp(zn) + o.lpdf_mom(ρn)` minus `o.logq0(z) + o.lpdf_mom(ρ)` minus the `logJ`
returned by `flow`.
"""
function single_elbo(o::sb_HF, ϵ::Vector{T}, z, ρ, n_ref, μs, L_invs) where T<:Real
    z_end, ρ_end, logJ = flow(o, ϵ, z, ρ, n_ref, "fwd", μs, L_invs)
    # reference log density at the starting point
    log_ref = o.logq0(z) + o.lpdf_mom(ρ)
    # joint target log density at the end point
    log_target = o.logp(z_end) + o.lpdf_mom(ρ_end)
    return log_target - log_ref - logJ
end

"""
    ELBO_shf(o, ϵ, n_ref, μs, L_invs; elbo_size = 10)

Monte Carlo ELBO estimate: the average of `single_elbo` over `elbo_size` draws, each
starting from `z = o.q_sampler(o.d)` and `ρ = randn(o.d)`.
"""
function ELBO_shf(o::sb_HF, ϵ, n_ref, μs, L_invs; elbo_size = 10)
    total = zero(eltype(ϵ))
    for _ in 1:elbo_size
        # position is drawn before momentum
        z_start = o.q_sampler(o.d)
        ρ_start = randn(o.d)
        total += single_elbo(o, ϵ, z_start, ρ_start, n_ref, μs, L_invs)
    end
    return total / elbo_size
end

"""
    sampler_shf(o, ϵ, μs, L_invs, nsamples)

Draw `nsamples` samples from the flow, in parallel over threads.

The number of refreshments is `size(μs, 1) / o.d`. Every sample starts from
`z = o.q_sampler(o.d)` and `ρ = randn(o.d)` and is pushed through the `"fwd"` flow.
Returns `(zs, ms)`, two `nsamples × o.d` matrices holding the final positions and
momenta row by row.
"""
function sampler_shf(o::sb_HF, ϵ, μs, L_invs, nsamples)
    n_ref = Int(size(μs, 1) / o.d)
    zs = zeros(nsamples, o.d)
    ms = zeros(nsamples, o.d)
    @threads for k in 1:nsamples
        z_start = o.q_sampler(o.d)
        ρ_start = randn(o.d)
        z_end, ρ_end = flow(o, ϵ, z_start, ρ_start, n_ref, "fwd", μs, L_invs)
        # each iteration writes only its own row
        zs[k, :] .= z_end
        ms[k, :] .= ρ_end
    end
    return zs, ms
end


"""
    flow_save_all(o, ϵ, z, ρ, n_ref, direction[, μs, L_invs])

Run the flow like `flow`, recording the state after every refreshment.

# Arguments
- `ϵ`: per-coordinate leapfrog step sizes.
- `z`, `ρ`: starting position and momentum (stored in row 1 of the outputs).
- `n_ref`: number of refreshments to apply.
- `direction`: `"fwd"` or `"bwd"`; anything else raises an error.
- `μs`, `L_invs`: flat standardization parameters, `o.d` entries per refreshment.
  The body always needed these but they were missing from the signature. They are
  optional only so that existing 6-argument calls keep resolving them from the
  same-named global variables; new code should pass them explicitly.

# Returns
`(Zs, Ps, logJ)`. `Zs` and `Ps` are `(n_ref+1) × o.d` matrices; the state obtained
after processing refreshment `i` is stored in row `i+1`. `logJ` is minus the sum of
`log|L_invs|` over the first `o.d*n_ref` entries.
"""
function flow_save_all(o::sb_HF, ϵ::Vector{T}, z, ρ, n_ref::Int64, direction::String,
                       μs = μs, L_invs = L_invs) where T<:Real
    Zs = zeros(n_ref+1, o.d)
    Ps = zeros(n_ref+1, o.d)

    # only the refreshments that are applied contribute to the log-Jacobian
    logJ = -sum(log.(abs.(@view(L_invs[1:o.d*n_ref]))))
    Zs[1,:] .= z
    Ps[1,:] .= ρ
    if direction == "fwd"
        for i in 1:n_ref
            z, ρ = one_fwd(o, ϵ, z, ρ, i, μs, L_invs)
            Zs[i+1,:] .= z
            Ps[i+1,:] .= ρ
        end
    elseif direction == "bwd"
        # NOTE(review): rows are keyed by the refreshment index, so the state after
        # undoing refreshment i lands in row i+1 and rows are not in traversal
        # order for this direction — confirm this layout is intended.
        for i in n_ref:-1:1
            z, ρ = one_bwd(o, ϵ, z, ρ, i, μs, L_invs)
            Zs[i+1,:] .= z
            Ps[i+1,:] .= ρ
        end
    else
        error("direction should be fwd or bwd")
    end
    return Zs, Ps, logJ
end

"""
    flow_save(o, ϵ, μs, L_invs, direction; Ns = [10, 20, 30, 40, 50], z, ρ)

Run the flow for `maximum(Ns)` refreshments and record the state at the refreshment
indices listed in `Ns`.

# Arguments
- `ϵ`: per-coordinate leapfrog step sizes.
- `μs`, `L_invs`: flat standardization parameters, `o.d` entries per refreshment.
- `direction`: `"fwd"` or `"bwd"`; anything else raises an error.

# Keywords
- `Ns`: refreshment indices at which the state is recorded. An empty `Ns` runs no
  refreshment and returns only the starting state.
- `z`, `ρ`: starting position and momentum. They default to a fresh draw
  `o.q_sampler(o.d)` and `randn(o.d)`.

# Returns
`(Zs, Ps)`, two `(length(Ns)+1) × o.d` matrices. Row 1 holds the starting state;
row `k+1` holds the state right after refreshment `Ns[k]` has been processed
(applied for `"fwd"`, undone for `"bwd"`), for both `Zs` and `Ps`.
"""
function flow_save(o, ϵ, μs, L_invs, direction::String;
                   Ns = [10, 20, 30, 40, 50],
                   z = o.q_sampler(o.d), ρ = randn(o.d))
    Zs = zeros(length(Ns)+1, o.d)
    Ps = zeros(length(Ns)+1, o.d)

    # run far enough to reach every requested index, whatever the order of `Ns`
    n_ref = isempty(Ns) ? 0 : maximum(Ns)
    Zs[1,:] .= z
    Ps[1,:] .= ρ
    if direction == "fwd"
        for i in 1:n_ref
            z, ρ = one_fwd(o, ϵ, z, ρ, i, μs, L_invs)
            k = findfirst(==(i), Ns)
            if k !== nothing
                Zs[k+1,:] .= z
                Ps[k+1,:] .= ρ
            end
        end
    elseif direction == "bwd"
        for i in n_ref:-1:1
            z, ρ = one_bwd(o, ϵ, z, ρ, i, μs, L_invs)
            # same row for position and momentum, keyed by the refreshment index
            k = findfirst(==(i), Ns)
            if k !== nothing
                Zs[k+1,:] .= z
                Ps[k+1,:] .= ρ
            end
        end
    else
        error("direction should be fwd or bwd")
    end
    return Zs, Ps
end

"""
    leapfrog_err(o, ϵ, L, nsamples)

Round-trip error of `L` leapfrog steps on `nsamples` random starting points, drawn
as `o.q_sampler(nsamples, o.d)` for positions and `randn(nsamples, o.d)` for momenta.

Returns `(ferr, berr)`: for every sample, `ferr` is the Euclidean distance between
the start and the result of integrating with `ϵ` and then `-ϵ`, and `berr` is the
same with the two step-size signs swapped.
"""
function leapfrog_err(o::sb_HF, ϵ, L::Int64, nsamples)
    Zs = o.q_sampler(nsamples, o.d)
    Ps = randn(nsamples, o.d)
    ferr = zeros(nsamples)
    berr = zeros(nsamples)
    ∇U, ∇K = o.∇logp, o.∇lpdf_mom

    @threads for k in axes(Zs, 1)
        z = @view(Zs[k, :])
        ρ = @view(Ps[k, :])

        # +ϵ followed by -ϵ
        z_out, ρ_out = leapfrog(∇U, ∇K, ϵ, z, ρ, L)
        z_back, ρ_back = leapfrog(∇U, ∇K, -ϵ, z_out, ρ_out, L)
        ferr[k] = sqrt(sum(abs2, z .- z_back) + sum(abs2, ρ .- ρ_back))

        # -ϵ followed by +ϵ
        z_rev, ρ_rev = leapfrog(∇U, ∇K, -ϵ, z, ρ, L)
        z_ret, ρ_ret = leapfrog(∇U, ∇K, ϵ, z_rev, ρ_rev, L)
        berr[k] = sqrt(sum(abs2, z .- z_ret) + sum(abs2, ρ .- ρ_ret))
    end
    return ferr, berr
end

"""
    fwd_inv_err(o, ϵ, n_ref, μs, L_invs, z, ρ)

Error of T⁻¹∘T at `(z, ρ)`: push the point through the `"fwd"` flow, pull the result
back with the `"bwd"` flow, and return the Euclidean distance to the starting point.
"""
function fwd_inv_err(o::sb_HF, ϵ, n_ref, μs, L_invs, z, ρ)
    z_out, ρ_out = flow(o, ϵ, z, ρ, n_ref, "fwd", μs, L_invs)
    z_back, ρ_back = flow(o, ϵ, z_out, ρ_out, n_ref, "bwd", μs, L_invs)
    return sqrt(sum(abs2, z .- z_back) + sum(abs2, ρ .- ρ_back))
end

"""
    bwd_inv_err(o, ϵ, n_ref, μs, L_invs, z, ρ)

Error of T∘T⁻¹ at `(z, ρ)`: pull the point back with the `"bwd"` flow, push the
result through the `"fwd"` flow, and return the Euclidean distance to the starting
point.
"""
function bwd_inv_err(o::sb_HF, ϵ, n_ref, μs, L_invs, z, ρ)
    z_pre, ρ_pre = flow(o, ϵ, z, ρ, n_ref, "bwd", μs, L_invs)
    z_ret, ρ_ret = flow(o, ϵ, z_pre, ρ_pre, n_ref, "fwd", μs, L_invs)
    return sqrt(sum(abs2, z .- z_ret) + sum(abs2, ρ .- ρ_ret))
end

"""
    inversion_err(o, ϵ, n_ref, μs, L_invs, z, ρ)

Both inversion errors of the flow with `n_ref` refreshments at `(z, ρ)`. Returns
`(err_fwd, err_bwd)`, the results of `fwd_inv_err` (T⁻¹∘T) and `bwd_inv_err`
(T∘T⁻¹), evaluated in that order.
"""
function inversion_err(o::sb_HF, ϵ, n_ref, μs, L_invs, z, ρ)
    forward_then_back = fwd_inv_err(o, ϵ, n_ref, μs, L_invs, z, ρ)
    back_then_forward = bwd_inv_err(o, ϵ, n_ref, μs, L_invs, z, ρ)
    return forward_then_back, back_then_forward
end

"""
    batch_inversion_err(o, ϵ, n_ref, μs, L_invs; nsamples = 1000)

Inversion errors of the flow with `n_ref` refreshments on `nsamples` random points,
drawn as `o.q_sampler(nsamples, o.d)` and `randn(nsamples, o.d)` and processed in
parallel over threads. Returns `(fwd_err, bwd_err)`, one entry per sample.
"""
function batch_inversion_err(o::sb_HF, ϵ, n_ref, μs, L_invs; nsamples = 1000)
    Zs = o.q_sampler(nsamples, o.d)
    Ps = randn(nsamples, o.d)
    fwd_err = zeros(nsamples)
    bwd_err = zeros(nsamples)
    @threads for k in axes(Zs, 1)
        # each iteration writes only its own slot of the two outputs
        fwd_err[k], bwd_err[k] = inversion_err(o, ϵ, n_ref, μs, L_invs, Zs[k, :], Ps[k, :])
    end
    return fwd_err, bwd_err
end

"""
    inversion_err(o, ϵ, μs, L_invs, z, ρ; Ns = [10, 20, 30, 40, 50])

Inversion errors at `(z, ρ)` for several flow lengths. For every `n_ref` in `Ns` the
first `o.d*n_ref` entries of `μs` and `L_invs` are used. Returns `(Ns, EF, EB)`,
where `EF[k]` and `EB[k]` are the T⁻¹∘T and T∘T⁻¹ errors for `Ns[k]`.
"""
function inversion_err(o::sb_HF, ϵ, μs, L_invs, z, ρ; Ns = [10, 20, 30, 40, 50])
    EF = zeros(length(Ns))
    EB = zeros(length(Ns))
    for (k, n_ref) in enumerate(Ns)
        # truncate the standardization parameters to the first n_ref refreshments
        last_entry = o.d * n_ref
        μ_head = @view(μs[1:last_entry])
        L_inv_head = @view(L_invs[1:last_entry])
        EF[k], EB[k] = inversion_err(o, ϵ, n_ref, μ_head, L_inv_head, z, ρ)
    end
    return Ns, EF, EB
end
"""
    batch_inversion_err(o, ϵ, μs, L_invs; nsamples = 1000, Ns = [10, 20, 30, 40, 50])

Inversion errors for several flow lengths on `nsamples` random points, drawn as
`o.q_sampler(nsamples, o.d)` and `randn(nsamples, o.d)` and processed in parallel
over threads. Returns `(Ns, fwd_err, bwd_err)`; both error arrays are
`nsamples × length(Ns)`, one row per sample and one column per entry of `Ns`.
"""
function batch_inversion_err(o::sb_HF, ϵ, μs, L_invs; nsamples = 1000, Ns = [10, 20, 30, 40, 50])
    Zs = o.q_sampler(nsamples, o.d)
    Ps = randn(nsamples, o.d)
    fwd_err = zeros(nsamples, length(Ns))
    bwd_err = zeros(nsamples, length(Ns))
    @threads for k in axes(Zs, 1)
        _, err_f, err_b = inversion_err(o, ϵ, μs, L_invs, Zs[k, :], Ps[k, :]; Ns = Ns)
        # each iteration writes only its own row
        fwd_err[k, :] = err_f
        bwd_err[k, :] = err_b
    end
    return Ns, fwd_err, bwd_err
end


# estimate the standardization parameters (momentum mean and inverse std per refresh)
# and return them so the caller can push these normalization params into Params
"""
    warm_start!(o, ϵ, sample_size, n_ref)

Estimate the momentum standardization parameters of `n_ref` refreshments from a
batch of `sample_size` simulated trajectories.

The batch starts from `o.q_sampler(sample_size, o.d)` and `randn(sample_size, o.d)`.
For every refreshment, each trajectory is advanced in place by `o.n_lfrg` leapfrog
steps (in parallel over threads); the column-wise mean and the reciprocal of the
column-wise standard deviation of the momenta are recorded, and the momenta are
then standardized in place with them. A progress bar is advanced once per
refreshment.

Returns `(μps, L_invs)`, two flat vectors with `o.d` entries per refreshment.
"""
function warm_start!(o, ϵ, sample_size::Int64, n_ref::Int64)
    d, n_lfrg = o.d, o.n_lfrg
    ∇logp, ∇lpdf_mom = o.∇logp, o.∇lpdf_mom

    # trajectories used to estimate the refresh parameters
    zs_ref = o.q_sampler(sample_size, d)
    ps_ref = randn(sample_size, d)

    μps = zeros(d * n_ref)
    L_invs = zeros(d * n_ref)

    prog_bar = ProgressMeter.Progress(n_ref, dt=0.5, barglyphs=ProgressMeter.BarGlyphs("[=> ]"), barlen=50, color=:yellow)
    for n in 1:n_ref
        # advance every trajectory in place; rows are independent
        @threads for i in 1:sample_size
            leapfrog!(∇logp, ∇lpdf_mom, ϵ, @view(zs_ref[i, :]), @view(ps_ref[i, :]), n_lfrg)
        end

        # column-wise momentum statistics of the current batch
        μp = vec(mean(ps_ref, dims = 1))
        L_inv = inv.(vec(std(ps_ref, dims = 1)))

        block = d*(n-1)+1 : d*n
        μps[block] .= μp
        L_invs[block] .= L_inv

        # standardize the momenta before the next refreshment
        batch_normalize!(μp, L_inv, ps_ref)
        ProgressMeter.next!(prog_bar)
    end
    return μps, L_invs
end

##############################
# training SHF
##############################
"""
    sb_HamFlow(o, ϵ0, n_warpup, n_ref; elbo_size = 10, niters = 100000,
               optimizer = Flux.ADAM(1e-3), kwargs...)

Train the flow by minimizing the negative ELBO.

The standardization parameters are first estimated by `warm_start!` with `n_warpup`
samples and the initial step sizes `ϵ0`. The step sizes are optimized on the log
scale, together with `μs` and `L_invs`, by `vi_train!` for `niters` iterations;
remaining keyword arguments are forwarded to `vi_train!`.

Returns `(-ls_trace, ps_trace)`, the negated loss trace and the parameter trace
produced by `vi_train!`.
"""
function sb_HamFlow(o, ϵ0, n_warpup, n_ref; 
                    elbo_size = 10, niters = 100000, optimizer = Flux.ADAM(1e-3), kwargs...)

    @info "warm start with $(n_warpup) samples"
    μs, L_invs = warm_start!(o, ϵ0, n_warpup, n_ref)

    # trainable parameters; step sizes live on the log scale
    logϵ = log.(ϵ0)
    ps = Flux.params(logϵ, μs, L_invs)

    # negative ELBO evaluated at the current parameter values
    loss = () -> -ELBO_shf(o, exp.(logϵ), n_ref, μs, L_invs; elbo_size = elbo_size)

    @info "start training"
    ls_trace, ps_trace = vi_train!(niters, loss, ps, optimizer; kwargs...)
    return -ls_trace, ps_trace
end


