include("mf_setting.jl")

###############################
# Experiment configuration: the result file name and the N values to evaluate.
file_name = "inv_2048.jld"
Ns = [50, 100, 200, 500, 800]

# Run the inversion-error experiment with 10 samples per N.
# NOTE(review): results are presumably written to result/<file_name>, which is
# where the plotting code reads them back from — confirm in `inversion_err`.
inversion_err(o, a_big, Ns; res_name=file_name, nsample=10)

# plot the inversion error
res = JLD.load(joinpath("result", file_name))
Ns = res["Ns"]
# The tiny `1e-50` floor keeps exact-zero errors from breaking the log10
# y-axis configured below.
EE_fwd = res["fwd_err"] .+ 1e-50
EE_bwd = res["bwd_err"] .+ 1e-50

# Plot the median error over samples, with a percentile ribbon from `get_percentiles`.
# NOTE(review): `dims=2` assumes rows index N and columns index samples — confirm
# against the layout produced by `inversion_err`.
p1 = plot(
    Ns,
    median(EE_fwd; dims=2);
    ribbon=get_percentiles(EE_fwd),
    label="Fwd err.",
    legend=:topleft,
    title="MixFlow inversion error",
    xlabel="N",
    ylabel="Error",
)
plot!(Ns, median(EE_bwd; dims=2); ribbon=get_percentiles(EE_bwd), label="Bwd err.")
plot!(;
    yaxis=:log10,
    size=(800, 500),
    xtickfontsize=30,
    ytickfontsize=30,
    margin=10Plots.mm,
    guidefontsize=30,
    legendfontsize=20,
    titlefontsize=30,
    # Ticks every 200 up to the largest N in the loaded results, instead of a
    # hard-coded 800 that goes stale if the experiment's Ns change.
    xtick=collect(0:200:maximum(Ns)),
)

# Make sure the output directory exists before writing the figure.
fig_path = joinpath("figure", "inversion_err_2048.png")
mkpath(dirname(fig_path))
savefig(p1, fig_path)
