include("mf_setting.jl")

##########################
# flow err scaling
##########################
Random.seed!(1)
nsample = 20
n_ref = 500
E_fwd = zeros(nsample, n_ref)
E_bwd = zeros(nsample, n_ref)

# Make sure the output folder exists before the expensive loop, so a missing
# directory cannot discard the results at save time.
mkpath("result")

"""
    _trace_err(zs, ρs, us, zz, ρρ, uu)

Return, for every row of the two traces, the Euclidean distance between the joint
states `(zs, ρs, us)` and `(zz, ρρ, uu)`.
"""
function _trace_err(zs, ρs, us, zz, ρρ, uu)
    sq_dist =
        sum(abs2, zs .- zz; dims=2) .+ sum(abs2, ρs .- ρρ; dims=2) .+ (us .- uu) .^ 2
    return sqrt.(vec(sq_dist))
end

# Draw all initial states sequentially on the task that was just seeded. Julia's
# default RNG is task-local (thread-local before 1.7), so draws made inside
# `@threads` depend on how iterations are split across tasks and are not pinned
# by `Random.seed!` alone.
init_states = [(o.q_sampler(d) .* D .+ μ, randn(d), rand()) for _ in 1:nsample]

@threads for i in 1:nsample
    z0, ρ0, u0 = init_states[i]

    # Forward trace, once with inputs converted through `ft` and once with the
    # plain inputs. NOTE(review): `ft` is presumably a higher-precision float
    # type defined in mf_setting.jl — confirm.
    zs, ρs, us = MixFlow.flow_trace_fwd(
        o, ft.(ϵ), MixFlow.ref_coord, ft.(z0), ft.(ρ0), ft.(u0), n_ref
    )
    zz, ρρ, uu = MixFlow.flow_trace_fwd(o, ϵ, MixFlow.ref_coord, z0, ρ0, u0, n_ref)
    E_fwd[i, :] .= _trace_err(zs, ρs, us, zz, ρρ, uu)

    # Same comparison for the backward trace.
    zbs, ρbs, ubs = MixFlow.flow_trace_bwd(
        o, ft.(ϵ), MixFlow.ref_coord, ft.(z0), ft.(ρ0), ft.(u0), n_ref
    )
    zzb, ρρb, uub = MixFlow.flow_trace_bwd(o, ϵ, MixFlow.ref_coord, z0, ρ0, u0, n_ref)
    E_bwd[i, :] .= _trace_err(zbs, ρbs, ubs, zzb, ρρb, uub)
end
JLD.save("result/flow_err.jld", "fwd_err", E_fwd, "bwd_err", E_bwd, "Ns", [1:n_ref;])

################
# sampling err scaling
###############
# Fix the RNG seed and set the number of runs and the trace length used by the
# sampling-error experiment below.
Random.seed!(1)
nsample = 20
n_ref = 800

"""
    f1(x)

Elementwise absolute value of `x`.
"""
function f1(x)
    return broadcast(abs, x)
end

"""
    f2(x)

Elementwise `sin(x) + 1`.
"""
function f2(x)
    return @. sin(x) + 1
end

"""
    f3(x)

Elementwise logistic sigmoid `1 / (1 + exp(-x))`.
"""
function f3(x)
    return @. 1 / (1 + exp(-x))
end

# One row per run, one column per trace step. `F*` and `Fs*` hold the two
# outputs of `MCf` for the test functions f1, f2 and f3 respectively.
F1 = zeros(nsample, n_ref)
F2 = zeros(nsample, n_ref)
F3 = zeros(nsample, n_ref)
Fs1 = zeros(nsample, n_ref)
Fs2 = zeros(nsample, n_ref)
Fs3 = zeros(nsample, n_ref)

# Make sure the output folder exists before the expensive loop, so a missing
# directory cannot discard the results at save time.
mkpath("result")

prog_bar = ProgressMeter.Progress(
    nsample; dt=0.5, barglyphs=ProgressMeter.BarGlyphs("[=> ]"), barlen=50, color=:yellow
)

# Draw all initial states sequentially on the seeded task. Julia's default RNG
# is task-local (thread-local before 1.7), so draws made inside `@threads`
# depend on how iterations are split across tasks and are not pinned by
# `Random.seed!` alone.
init_states = [(o.q_sampler(d) .* D .+ μ, randn(d), rand()) for _ in 1:nsample]

@threads for i in 1:nsample
    z0, ρ0, u0 = init_states[i]

    # Forward trace, once with inputs converted through `ft` and once with the
    # plain inputs.
    zs, ρs, us = MixFlow.flow_trace_fwd(
        o, ft.(ϵ), MixFlow.ref_coord, ft.(z0), ft.(ρ0), ft.(u0), n_ref
    )
    zz, ρρ, uu = MixFlow.flow_trace_fwd(o, ϵ, MixFlow.ref_coord, z0, ρ0, u0, n_ref)

    # Evaluate each test function on both traces and store row `i` of its outputs.
    for (f, F, Fs) in ((f1, F1, Fs1), (f2, F2, Fs2), (f3, F3, Fs3))
        mf, mfs = MCf(f, zs, ρs, us, zz, ρρ, uu)
        F[i, :] = mf
        Fs[i, :] = mfs
    end

    # update progress bar
    ProgressMeter.next!(prog_bar)
end
JLD.save(
    "result/sampling_err_rel.jld",
    "absx",
    (F1, Fs1),
    "sinx",
    (F2, Fs2),
    "sigmoid",
    (F3, Fs3),
    "Ns",
    [1:n_ref;],
)

# ###############
# lpdf err scaling
# ##############
Random.seed!(1)
# Evaluation grid: every pair (X[i], Y[j]) is a 2-D point.
X = [-5:1.0:5;]
Y = [-5:1.0:5;]
# These draws happen before the threaded loop, so they are fixed by the seed above.
m0 = randn(2)
u0 = rand()
n_ref = 800
Ds = zeros(length(X), length(Y), n_ref)
Dd = zeros(length(X), length(Y), n_ref)

# Make sure the output folder exists before the expensive loop, so a missing
# directory cannot discard the results at save time.
mkpath("result")

# lpdf_est, Error
n1, n2 = length(X), length(Y)
prog_bar = ProgressMeter.Progress(
    n1 * n2; dt=0.5, barglyphs=ProgressMeter.BarGlyphs("[=> ]"), barlen=50, color=:yellow
)
# NOTE(review): `m0` is shared by every iteration and thread; this assumes
# `log_density_cum` does not modify its arguments in place — confirm.
@threads for i in 1:n1
    for j in 1:n2
        # Cumulative log-density at the grid point with the plain inputs and `a` ...
        Ds[i, j, :] = MixFlow.log_density_cum(
            [X[i], Y[j]], m0, u0, o, a, MixFlow.inv_ref_coord, n_ref
        )
        # ... and with inputs converted through `ft` and `a_big`.
        Dd[i, j, :] = MixFlow.log_density_cum(
            ft.([X[i], Y[j]]), ft.(m0), ft.(u0), o, a_big, MixFlow.inv_ref_coord, n_ref
        )
        # update progress bar
        ProgressMeter.next!(prog_bar)
    end
end
JLD.save("result/lpdfs_err.jld", "lpdfs", Ds, "lpdfs_big", Dd)

##########################
# ELBO 
##########################
Random.seed!(1)
# Values passed as `Ns` to the ELBO sweep, and the `elbo_size` keyword value.
Ns = [10, 50, 100, 200, 300, 500, 800]
el_size = 200

# Make sure the output folder exists before the sweeps run, so a missing
# directory cannot discard the results at save time.
mkpath("result")

# Same sweep twice: once with `a`, once with `a_big`.
EL = MixFlow.elbo_sweep_multiple(
    o, a, MixFlow.ref_coord, MixFlow.inv_ref_coord, Ns; elbo_size=el_size
)
EL_big = MixFlow.elbo_sweep_multiple(
    o, a_big, MixFlow.ref_coord, MixFlow.inv_ref_coord, Ns; elbo_size=el_size
)
JLD.save("result/elbo_err.jld", "Ns", Ns, "EL", EL, "EL_big", EL_big)
