include("mf_setting.jl")
import PlotlyJS as pjs

##########################
# visualize forward and backward MixFlow trajectories
##########################
Random.seed!(1)
# One start state, shared by every trace below.
z0 = o.q_sampler(d) .* D .+ μ   # q_sampler draw, scaled by D and shifted by μ
ρ0 = randn(d)                    # standard-normal draw (presumably the momentum — confirm)
u0 = rand()                      # uniform draw on [0, 1)
n_ref = 100                      # length argument handed to flow_trace_fwd/_bwd

# Each direction is traced twice from the same start state: once with every input
# broadcast through `ft` (plotted as "exact" below) and once with the inputs as
# drawn (plotted as "numerical"). Fresh `ft` copies are built for every call, and
# the calls run in the order fwd-ft, fwd, bwd-ft, bwd.
zs, ρs, us, zz, ρρ, uu = let lift = v -> ft.(v)
    z_hi, ρ_hi, u_hi = MixFlow.flow_trace_fwd(
        o, lift(ϵ), MixFlow.ref_coord, lift(z0), lift(ρ0), lift(u0), n_ref
    )
    z_lo, ρ_lo, u_lo = MixFlow.flow_trace_fwd(
        o, ϵ, MixFlow.ref_coord, z0, ρ0, u0, n_ref
    )
    (z_hi, ρ_hi, u_hi, z_lo, ρ_lo, u_lo)
end

zbs, ρbs, ubs, zzb, ρρb, uub = let lift = v -> ft.(v)
    z_hi, ρ_hi, u_hi = MixFlow.flow_trace_bwd(
        o, lift(ϵ), MixFlow.ref_coord, lift(z0), lift(ρ0), lift(u0), n_ref
    )
    z_lo, ρ_lo, u_lo = MixFlow.flow_trace_bwd(
        o, ϵ, MixFlow.ref_coord, z0, ρ0, u0, n_ref
    )
    (z_hi, ρ_hi, u_hi, z_lo, ρ_lo, u_lo)
end

# plot trajectories
x = -20:0.1:20
y = -15:0.1:30
f = (x, y) -> exp(logp([x, y]))

"""
    plot_orbit(xs, ys, dens, z_exact, z_num, title, path; idx=1:2:30)

Draw a contour of `dens` over `xs` × `ys`, and overlay the rows `idx` of
`z_exact` (red, labelled "exact") and `z_num` (blue, labelled "numerical"),
using columns 1 and 2 as coordinates. Apply the shared figure styling with
`title`, save the figure to `path` (creating its directory if needed) and
return the plot.
"""
function plot_orbit(xs, ys, dens, z_exact, z_num, title, path; idx=1:2:30)
    plt = contour(xs, ys, dens; colorbar=false, color=:viridis, levels=10)
    scatter!(
        plt, z_exact[idx, 1], z_exact[idx, 2];
        label="exact", ms=6, msw=1, color=:red, alpha=0.8,
    )
    scatter!(
        plt, z_num[idx, 1], z_num[idx, 2];
        label="numerical", color=:blue, ms=6, msw=1, alpha=0.6,
    )
    plot!(
        plt;
        size=(800, 500),
        xtickfontsize=30,
        ytickfontsize=30,
        margin=10Plots.mm,
        guidefontsize=30,
        titlefontsize=30,
        legend=:top,
        legendfontsize=20,
        title=title,
    )
    # savefig does not create missing directories; make sure the target exists.
    dir = dirname(path)
    isempty(dir) || mkpath(dir)
    savefig(plt, path)
    return plt
end

p1 = plot_orbit(x, y, f, zs, zz, "MixFlow Fwd Orbit (N = 30)", "figure/fwd_traj.png")
p2 = plot_orbit(x, y, f, zbs, zzb, "MixFlow Bwd Orbit (N = 30)", "figure/bwd_traj.png")

##################################
# scatter plot
##################################
Random.seed!(1)
nsample = 1000
# Draw `nsample` points with MixFlow.Sampler; T holds them row-wise (columns 1 and 2
# are plotted below). The meaning of the positional 500 is defined by
# MixFlow.Sampler — presumably a flow length; confirm against that definition.
T, M, U = MixFlow.Sampler(o, a, MixFlow.ref_coord, 500; nsample)

# plot scatters: samples over a contour of the target density exp(logp)
x, y = -20:0.1:20, -15:0.1:30
f = (x, y) -> exp(logp([x, y]))
p = contour(
    x, y, f; colorbar=false, title="MixFlow sample scatters", color=:viridis, levels=10
)
scatter!(p, T[:, 1], T[:, 2]; label="", ms=4, msw=1, alpha=0.5)
plot!(
    p;
    size=(800, 500),
    xtickfontsize=30,
    ytickfontsize=30,
    margin=10Plots.mm,
    guidefontsize=30,
    titlefontsize=30,
    legend=:topleft,
    legendfontsize=20,
    xlim=(-20, 20),
    ylim=(-15, 30),
)
savefig(p, "figure/scatter.png")

################
# log-density (lpdf) visualization
###############
# Evaluation grid: X indexes the first matrix dimension, Y the second.
X = collect(-20.001:0.5:20)
Y = collect(-15.001:0.5:20)
# Result buffers (one distinct matrix each):
#   Ds     — MixFlow log-density estimate in the working precision
#   Ds_big — the same estimate with inputs lifted through `ft` and `a_big`
#   Dd     — the target log-density logp
Ds, Ds_big, Dd = (zeros(length(X), length(Y)) for _ in 1:3)

m0 = zeros(2)   # second state argument to log_density_eval (presumably momentum ρ — confirm)
u0 = rand()     # uniform draw on [0, 1), shared by every grid point
n_ref = 500
nBurn = 5
# lpdf_est, Error
n1, n2 = size(X, 1), size(Y, 1)
# Rows are processed sequentially (with a progress bar); the columns of each row are
# split across threads. Every (i, j) writes only its own matrix entries.
@showprogress for i in 1:n1
    xi = X[i]
    @threads for j in 1:n2
        yj = Y[j]
        Ds[i, j] = MixFlow.log_density_eval(
            [xi, yj], m0, u0, o, a, MixFlow.inv_ref_coord, n_ref; nBurn
        )
        Dd[i, j] = logp([xi, yj])
        Ds_big[i, j] = MixFlow.log_density_eval(
            ft.([xi, yj]), ft.(m0), ft(u0), o, a_big, MixFlow.inv_ref_coord, n_ref;
            nBurn,
        )
    end
end

# Persist the grid evaluations, then reload them so the plotting below can also be
# run on its own from a previously saved result file.
mkpath("result")
JLD.save("result/lpdfs_mix.jld", "lpdfs", Ds, "true", Dd, "lpdfs_big", Ds_big)

res = JLD.load("result/lpdfs_mix.jld")
Ds = res["lpdfs"]
Dd = res["true"]
Ds_big = res["lpdfs_big"]

# Shared 3D scene: all axes hidden, no margins. The colour scale is deliberately not
# set here: in plotly `layout.colorscale` only accepts an object of
# sequential/diverging scales, so it is set per surface trace instead.
layout = pjs.Layout(;
    width=500,
    height=500,
    scene=pjs.attr(;
        xaxis=pjs.attr(; showticklabels=false, visible=false),
        yaxis=pjs.attr(; showticklabels=false, visible=false),
        zaxis=pjs.attr(; showticklabels=false, visible=false),
    ),
    margin=pjs.attr(; l=0, r=0, b=0, t=0, pad=0),
)

"""
    lpdf_surface(z, xs, ys, lay, fname)

Render `z` over the `xs` × `ys` grid as a Viridis surface (colour bar hidden)
with layout `lay`, save it as `figure/fname` (creating `figure/` if needed) and
return the plot.
"""
function lpdf_surface(z, xs, ys, lay, fname)
    trace = pjs.surface(; z=z, x=xs, y=ys, showscale=false, colorscale="Viridis")
    plt = pjs.plot(trace, lay)
    mkpath("figure")
    pjs.savefig(plt, joinpath("figure/", fname))
    return plt
end

p_est = lpdf_surface(Ds, X, Y, layout, "lpdf_mixflow.png")
p_target = lpdf_surface(Dd, X, Y, layout, "lpdf.png")
p_big = lpdf_surface(Ds_big, X, Y, layout, "lpdf_big.png")
