include("mf_setting.jl")
import PlotlyJS as pjs
using JLD, LinearAlgebra, StatsBase, StatsPlots

##########################
# flow error scaling
##########################
# Forward/backward flow error versus the number of transformations, drawn on a log10
# y-axis. The curve is the median and the ribbon comes from `get_percentiles` (defined
# outside this section). The 1e-16 offset is presumably there so that zero errors remain
# drawable on the log axis.
res = JLD.load("result/flow_err.jld")
plot(
    res["Ns"],
    vec(median(res["fwd_err"]'; dims=2)) .+ 1e-16;
    label="Fwd",
    lw=3,
    ribbon=get_percentiles(res["fwd_err"]' .+ 1e-16),
)
plot!(
    res["Ns"],
    vec(median(res["bwd_err"]'; dims=2)) .+ 1e-16;
    label="Bwd",
    lw=3,
    ribbon=get_percentiles(res["bwd_err"]' .+ 1e-16),
)
# Labels, axis scaling, layout and font sizes are applied to the current figure in one go.
plot!(;
    title="MixFlow numerical error",
    xlabel="#transformations",
    ylabel="Error",
    yaxis=:log10,
    xlim=(0, 100),
    legend=:topleft,
    size=(800, 500),
    margin=10Plots.mm,
    titlefontsize=30,
    guidefontsize=30,
    xtickfontsize=30,
    ytickfontsize=30,
    legendfontsize=20,
)
savefig("figure/flow_err_log.png")

# Forward/backward flow error versus the number of transformations, drawn with the
# default axis scaling. The curve is the median; the ribbon comes from `get_percentiles`
# (defined outside this section).
res = JLD.load("result/flow_err.jld")
plot(
    res["Ns"],
    vec(median(res["fwd_err"]'; dims=2)) .+ 1e-16;
    label="Fwd",
    lw=3,
    ribbon=get_percentiles(res["fwd_err"]' .+ 1e-16),
)
plot!(
    res["Ns"],
    vec(median(res["bwd_err"]'; dims=2)) .+ 1e-16;
    label="Bwd",
    lw=3,
    ribbon=get_percentiles(res["bwd_err"]' .+ 1e-16),
)
# Labels, layout and font sizes are applied to the current figure in one go.
plot!(;
    title="MixFlow numerical error",
    xlabel="#transformations",
    ylabel="Error",
    legend=:topleft,
    size=(800, 500),
    margin=10Plots.mm,
    titlefontsize=30,
    guidefontsize=30,
    xtickfontsize=30,
    ytickfontsize=30,
    legendfontsize=20,
)
savefig("figure/flow_err.png")

################
# sampling err scaling
###############
# Relative error |avg(F) - avg(Fs)| / |avg(Fs)| for three test functions, plotted against
# the number of transformations on a log10 y-axis.
res = JLD.load("result/sampling_err_rel.jld")
Ns = res["Ns"]
F1, Fs1 = res["absx"]
F2, Fs2 = res["sinx"]
F3, Fs3 = res["sigmoid"]
# `average_run2` is defined outside this section. The 1e-16 offset is presumably there to
# keep the values strictly positive for the log axis.
r1, r2, r3 = map(((F1, Fs1), (F2, Fs2), (F3, Fs3))) do (F, Fs)
    abs.(average_run2(F) .- average_run2(Fs)) ./ abs.(average_run2(Fs)) .+ 1e-16
end

# First curve creates the figure; the remaining two are added to it.
p1 = plot(Ns, vec(median(r1'; dims=2)); label="|x|", lw=3, ribbon=get_percentiles(r1'))
for (r, lbl) in ((r2, "sinx+1"), (r3, "sigmoid"))
    plot!(Ns, vec(median(r'; dims=2)); label=lbl, lw=3, ribbon=get_percentiles(r'))
end
# Labels, axis scaling, layout and font sizes for the current figure.
plot!(;
    title="MixFlow sampling error",
    xlabel="#transformations",
    ylabel="Rel. err.",
    xrotation=20,
    legend=:bottomright,
    yaxis=:log10,
    ylim=(1e-7, 1),
    size=(800, 500),
    margin=10Plots.mm,
    titlefontsize=30,
    guidefontsize=30,
    xtickfontsize=30,
    ytickfontsize=30,
    legendfontsize=20,
)

savefig("figure/sampling_err_log_rel.png")

######################
# lpdfs error 
######################
# Absolute difference between the stored "lpdfs" and "lpdfs_big" arrays, plotted on the
# default (linear) y-axis for at most the first 800 transformations.
res = JLD.load("result/lpdfs_err.jld")
# Reshape Ds and Dd to 2-D arrays: the first two dimensions are merged and the last
# dimension is the original third dimension.
Ds = res["lpdfs"]
Dd = res["lpdfs_big"]
Ds = reshape(Ds, size(Ds, 1) * size(Ds, 2), size(Ds, 3))
Dd = reshape(Dd, size(Dd, 1) * size(Dd, 2), size(Dd, 3))
# After the adjoint, each row of Es corresponds to one entry of Ns.
Es = abs.(Ds .- Dd)'
Ns = [1:size(Ds, 2);]

# Only a leading window of the curve is shown. The median and the ribbon are computed from
# the same rows so that x-values, curve and ribbon all have the same length; the window is
# clamped so that shorter results do not raise a BoundsError.
# NOTE(review): assumes `get_percentiles` (defined elsewhere) works row by row — confirm.
p = let nshow = min(800, length(Ns)), Es_shown = Es[1:nshow, :]
    plot(
        Ns[1:nshow],
        vec(median(Es_shown; dims=2));
        ribbon=get_percentiles(Es_shown),
        lw=3,
        label="",
    )
end
plot!(;
    xlabel="#transformations",
    ylabel="Error",
    title="MixFlow log-density error",
    xrotation=20,
    legend=:topright,
)
plot!(;
    size=(800, 500),
    xtickfontsize=30,
    ytickfontsize=30,
    margin=10Plots.mm,
    guidefontsize=30,
    legendfontsize=20,
    titlefontsize=30,
)

savefig("figure/lpdfs_err.png")

# Absolute difference between the stored "lpdfs" and "lpdfs_big" arrays over all
# transformations, plotted on a log10 y-axis.
res = JLD.load("result/lpdfs_err.jld")
Ds = res["lpdfs"]
Dd = res["lpdfs_big"]
# Collapse the first two dimensions into one; the last dimension stays the original third
# dimension.
Ds = reshape(Ds, :, size(Ds, 3))
Dd = reshape(Dd, :, size(Dd, 3))
# The 1e-16 offset is presumably there to keep zero errors drawable on the log axis.
Es = abs.(Ds .- Dd)' .+ 1e-16
Ns = collect(1:size(Ds, 2))

p = plot(Ns, vec(median(Es; dims=2)); label="", lw=3, ribbon=get_percentiles(Es))
# Labels, axis scaling, layout and font sizes for the current figure.
plot!(;
    title="MixFlow log-density error",
    xlabel="#transformations",
    ylabel="Error",
    xrotation=20,
    legend=:topright,
    yaxis=:log10,
    size=(800, 500),
    margin=10Plots.mm,
    titlefontsize=30,
    guidefontsize=30,
    xtickfontsize=30,
    ytickfontsize=30,
    legendfontsize=20,
)

savefig("figure/lpdfs_err_log.png")

# Relative difference |lpdfs - lpdfs_big| / |lpdfs_big| over all transformations, plotted
# with the default (linear) y-axis.
res = JLD.load("result/lpdfs_err.jld")
Ds = res["lpdfs"]
Dd = res["lpdfs_big"]
# Collapse the first two dimensions into one; the last dimension stays the original third
# dimension.
Ds = reshape(Ds, :, size(Ds, 3))
Dd = reshape(Dd, :, size(Dd, 3))
# Relative error with a 1e-16 offset added to every entry.
Es = abs.(Ds .- Dd)' ./ abs.(Dd)' .+ 1e-16
Ns = collect(1:size(Ds, 2))

p = plot(Ns, vec(median(Es; dims=2)); label="", lw=3, ribbon=get_percentiles(Es))
# Labels, layout and font sizes for the current figure.
plot!(;
    title="MixFlow log-density error",
    xlabel="#transformations",
    ylabel="Rel. err.",
    xrotation=20,
    legend=:topright,
    size=(800, 500),
    margin=10Plots.mm,
    titlefontsize=30,
    guidefontsize=30,
    xtickfontsize=30,
    ytickfontsize=30,
    legendfontsize=20,
)

savefig("figure/lpdfs_err_rel.png")

# Log-scale version of the relative log-density error figure. `Es` and `Ns` are the values
# computed earlier in this file. `Es` already carries the 1e-16 offset, so none is added
# to the median here; adding it a second time would raise the plotted curve by another
# 1e-16.
p = plot(Ns, vec(median(Es; dims=2)); ribbon=get_percentiles(Es), lw=3, label="")
plot!(;
    xlabel="#transformations",
    ylabel="Rel. err.",
    title="MixFlow log-density error",
    xrotation=20,
    legend=:topright,
)
# Set the y tick positions explicitly so the log10 axis shows enough tick labels.
plot!(;
    yaxis=:log10,
    yticks=[1e-15, 1e-12, 1e-8, 1e-5, 1e-2, 1],
    size=(800, 500),
    xtickfontsize=30,
    ytickfontsize=30,
    margin=10Plots.mm,
    guidefontsize=30,
    legendfontsize=20,
    titlefontsize=30,
)

savefig("figure/lpdfs_err_log_rel.png")

######################
# elbo
######################
# Mean of the stored "EL" and "EL_big" arrays (taken along the first dimension), plotted
# against the number of transformations and labelled "numerical" and "exact".
res = JLD.load("result/elbo_err.jld")
Ns = res["Ns"]
EL = res["EL"]
EL_big = res["EL_big"]

p1 = plot(Ns, vec(mean(EL; dims=1)); label="numerical", lw=3)
plot!(Ns, vec(mean(EL_big; dims=1)); label="exact", lw=3)
# Labels, layout and font sizes for the current figure.
plot!(;
    title="MixFlow ELBO est.",
    xlabel="#transformations",
    ylabel="ELBO",
    xrotation=20,
    legend=:bottomright,
    size=(800, 500),
    margin=10Plots.mm,
    titlefontsize=30,
    guidefontsize=30,
    xtickfontsize=25,
    ytickfontsize=30,
    legendfontsize=20,
)

savefig("figure/elbos.png")

###################
# one step error
######################
# box plot E_fwd and E_bwd
# Box plot of the stored "E_fwd" and "E_bwd" values on a log10 y-axis.
# The result file is read once and both arrays are taken from the same loaded dictionary.
res = JLD.load("result/delta.jld")
# The 1e-16 offset is presumably there to keep zero errors drawable on the log axis.
E_fwd = vec(res["E_fwd"]) .+ 1e-16
E_bwd = vec(res["E_bwd"]) .+ 1e-16
p1 = boxplot(
    ["Fwd err." "Bwd err."], [E_fwd E_bwd]; legend=false, title="MixFlow single map err."
)
plot!(p1; xlabel="", ylabel="Error", yaxis=:log10)
plot!(
    p1;
    size=(800, 500),
    xtickfontsize=30,
    ytickfontsize=30,
    margin=10Plots.mm,
    guidefontsize=30,
    legendfontsize=20,
    titlefontsize=30,
)
savefig(p1, joinpath("figure/", "delta.png"))

#######################
# shadowing window
#######################

# Load the stored window results. W_* are plotted here; T_* are kept as globals for the
# timing figure that follows.
res = JLD.load("result/windows.jld")
Ns = res["Ns"]
W_fwd = res["W_fwd"]
W_bwd = res["W_bwd"]
T_fwd = res["T_fwd"]
T_bwd = res["T_bwd"]

# Median curve with a ribbon from `get_percentiles` (defined outside this section).
p = plot(Ns, vec(median(W_fwd'; dims=2)); label="Fwd", lw=3, ribbon=get_percentiles(W_fwd'))
plot!(Ns, vec(median(W_bwd'; dims=2)); label="Bwd", lw=3, ribbon=get_percentiles(W_bwd'))
# Labels, layout and font sizes for the current figure.
plot!(;
    title="MixFlow ϵ size",
    xlabel="#transformations",
    ylabel="",
    xrotation=20,
    legend=:topleft,
    size=(900, 500),
    margin=10Plots.mm,
    titlefontsize=30,
    guidefontsize=30,
    xtickfontsize=30,
    ytickfontsize=30,
    legendfontsize=20,
)

savefig("figure/shadowing.png")

# Timing figure: median of T_fwd / T_bwd against Ns, with a ribbon from `get_percentiles`.
# `Ns`, `T_fwd` and `T_bwd` are the values loaded earlier in this file.
p = plot(Ns, vec(median(T_fwd'; dims=2)); label="Fwd", lw=3, ribbon=get_percentiles(T_fwd'))
plot!(Ns, vec(median(T_bwd'; dims=2)); label="Bwd", lw=3, ribbon=get_percentiles(T_bwd'))
# Labels, layout and font sizes for the current figure.
plot!(;
    title="ϵ computation time",
    xlabel="#transformations",
    ylabel="Wall time in sec",
    xrotation=20,
    legend=:topleft,
    size=(800, 500),
    margin=10Plots.mm,
    titlefontsize=30,
    guidefontsize=30,
    xtickfontsize=30,
    ytickfontsize=30,
    legendfontsize=20,
)

savefig("figure/shadowing_time.png")
