include("mf_setting.jl")

# #########################
# flow err scaling
# #########################
"""
    _trace_dist(zs, ρs, us, zz, ρρ, uu)

Return, as a vector, the per-step distance between two (z, ρ, u) flow traces:
`sqrt.(Σ_k (zs - zz)² + Σ_k (ρs - ρρ)² + (us - uu)²)`, where the sums over `k`
run along dimension 2 of the `z`/`ρ` trace arrays.
"""
function _trace_dist(zs, ρs, us, zz, ρρ, uu)
    sq = sum(abs2, zs .- zz; dims=2) .+ sum(abs2, ρs .- ρρ; dims=2) .+ (us .- uu) .^ 2
    return sqrt.(vec(sq))
end

Random.seed!(1)
nsample = 100
n_ref = 500
E_fwd = zeros(nsample, n_ref)
E_bwd = zeros(nsample, n_ref)

# Draw every initial state serially, *before* the threaded loop. Random draws made
# inside `@threads` come from per-task RNG streams, so the samples would depend on
# `Threads.nthreads()` and `Random.seed!(1)` alone would not pin them down.
flow_inits = [(o.q_sampler(d) .* D .+ μ, randn(d), rand()) for _ in 1:nsample]

@threads for i in 1:nsample
    z0, ρ0, u0 = flow_inits[i]

    # forward traces: one run on `ft`-converted inputs, one on the inputs as drawn
    zs, ρs, us = MixFlow.flow_trace_fwd(
        o, ft.(ϵ), MixFlow.ref_coord, ft.(z0), ft.(ρ0), ft.(u0), n_ref
    )
    zz, ρρ, uu = MixFlow.flow_trace_fwd(o, ϵ, MixFlow.ref_coord, z0, ρ0, u0, n_ref)
    E_fwd[i, :] .= _trace_dist(zs, ρs, us, zz, ρρ, uu)

    # backward traces, same pairing
    zbs, ρbs, ubs = MixFlow.flow_trace_bwd(
        o, ft.(ϵ), MixFlow.ref_coord, ft.(z0), ft.(ρ0), ft.(u0), n_ref
    )
    zzb, ρρb, uub = MixFlow.flow_trace_bwd(o, ϵ, MixFlow.ref_coord, z0, ρ0, u0, n_ref)
    E_bwd[i, :] .= _trace_dist(zbs, ρbs, ubs, zzb, ρρb, uub)
end
JLD.save("result/flow_err.jld", "fwd_err", E_fwd, "bwd_err", E_bwd, "Ns", collect(1:n_ref))

################
# sampling err scaling
###############
Random.seed!(1)
nsample = 20
n_ref = 650
f1(x) = abs.(x)
f2(x) = sin.(x) .+ 1
f3(x) = 1 ./ (1 .+ exp.(-x))

F1 = zeros(nsample, n_ref)
F2 = zeros(nsample, n_ref)
F3 = zeros(nsample, n_ref)
Fs1 = zeros(nsample, n_ref)
Fs2 = zeros(nsample, n_ref)
Fs3 = zeros(nsample, n_ref)

# Draw all initial states serially before entering `@threads`; draws made inside the
# threaded loop come from per-task RNG streams and would vary with the thread count.
sampling_inits = [(o.q_sampler(d) .* D .+ μ, randn(d), rand()) for _ in 1:nsample]

prog_bar = ProgressMeter.Progress(
    nsample; dt=0.5, barglyphs=ProgressMeter.BarGlyphs("[=> ]"), barlen=50, color=:yellow
)
@threads for i in 1:nsample
    z0, ρ0, u0 = sampling_inits[i]
    zs, ρs, us = MixFlow.flow_trace_fwd(
        o, ft.(ϵ), MixFlow.ref_coord, ft.(z0), ft.(ρ0), ft.(u0), n_ref
    )
    zz, ρρ, uu = MixFlow.flow_trace_fwd(o, ϵ, MixFlow.ref_coord, z0, ρ0, u0, n_ref)

    # One (row of F, row of Fs) pair per test function, filled from MCf's two outputs.
    for (F, Fs, f) in ((F1, Fs1, f1), (F2, Fs2, f2), (F3, Fs3, f3))
        mf, mfs = MCf(f, zs, ρs, us, zz, ρρ, uu)
        F[i, :] = mf
        Fs[i, :] = mfs
    end

    # update progress bar
    ProgressMeter.next!(prog_bar)
end
JLD.save(
    "result/sampling_err_rel.jld",
    "absx",
    (F1, Fs1),
    "sinx",
    (F2, Fs2),
    "sigmoid",
    (F3, Fs3),
    "Ns",
    collect(1:n_ref),
)

################
# lpdf err scaling
###############
 # Rebuild the HamFlow with `n_lfrg` leapfrog-count argument 200 and `randn` samplers.
# NOTE(review): this rebinds the global `o`; everything that runs after this section
# (the ELBO sweep further down the file) uses this flow rather than the one from
# mf_setting.jl — confirm that is intended.
Random.seed!(1)
X = [-20.001:1:20;]
Y = [-15.001:1:20;]
n_lfrg = 200
o = MixFlow.HamFlow(
    d, n_lfrg, logp, ∇logp, randn, logq, randn,
    MixFlow.lpdf_normal, MixFlow.∇lpdf_normal, MixFlow.cdf_normal,
    MixFlow.invcdf_normal, MixFlow.pdf_normal,
)

m0 = randn(2)
u0 = rand()
n_ref = 500

# Grid sizes; Ds holds the results with `a`, Dd those with `a_big` on `ft` inputs.
n1, n2 = length(X), length(Y)
Ds = zeros(n1, n2, n_ref)
Dd = zeros(n1, n2, n_ref)

prog_bar = ProgressMeter.Progress(
    n1 * n2; dt=0.5, barglyphs=ProgressMeter.BarGlyphs("[=> ]"), barlen=50, color=:yellow
)
@threads for i in 1:n1
    xi = X[i]
    for j in 1:n2
        yj = Y[j]
        # each call gets its own freshly built point vector
        Ds[i, j, :] = MixFlow.log_density_cum(
            [xi, yj], m0, u0, o, a, MixFlow.inv_ref_coord, n_ref
        )
        Dd[i, j, :] = MixFlow.log_density_cum(
            ft.([xi, yj]), ft.(m0), ft.(u0), o, a_big, MixFlow.inv_ref_coord, n_ref
        )
        ProgressMeter.next!(prog_bar)
    end
end
JLD.save("result/lpdfs_err.jld", "lpdfs", Ds, "lpdfs_big", Dd)

##########################
# ELBO err scaling
##########################
Random.seed!(1)
Ns = [10, 20, 50, 100, 200, 300, 500]
el_size = 200

# Sweep the ELBO over `Ns` once with `a` and once with `a_big`. `map` over a
# 2-tuple evaluates the closure left to right, so the RNG is consumed in the same
# order as two sequential calls. `o` is whichever flow is bound globally here.
EL, EL_big = map((a, a_big)) do alg
    MixFlow.elbo_sweep_multiple(
        o, alg, MixFlow.ref_coord, MixFlow.inv_ref_coord, Ns; elbo_size=el_size
    )
end

JLD.save("result/elbo_err.jld", "Ns", Ns, "EL", EL, "EL_big", EL_big)
