

using Plots, NPZ, LaTeXStrings
using StatsPlots
# Set the GR backend's workstation type before any plot is drawn; "100" is the value
# used on the UAHPC cluster (presumably a headless/no-display setting — confirm).
ENV["GKSwstype"]="100" #set env variable for UAHPC


#-----Load data

"""
    load_losses(method)

Read the two loss arrays stored for `method` in `./learned_generalization/`.
The first element of the returned tuple comes from the `loss_t_` file and
the second from the `loss_g_t_` file, both with the `losskl_lf_kl_t` suffix.
"""
function load_losses(method)
    folder = "./learned_generalization/"
    tail = "_$(method)_losskl_lf_kl_t.npy"
    loss = npzread(folder * "loss_t" * tail)
    loss_g = npzread(folder * "loss_g_t" * tail)
    return loss, loss_g
end

"""
    load_losses_theta(method)

Read the two θ-indexed loss arrays stored for `method` in
`./learned_generalization/`. The first element of the returned tuple comes
from the `Lθ_` file and the second from the `Lgθ_` file, both with the
`kl_lf_kl_t` suffix.
"""
function load_losses_theta(method)
    folder = "./learned_generalization/"
    tail = "_$(method)_kl_lf_kl_t.npy"
    loss = npzread(folder * "Lθ" * tail)
    loss_g = npzread(folder * "Lgθ" * tail)
    return loss, loss_g
end

#---Plotting

"""
    plot_rot_error(θ_idx)

Draw a log-scale bar plot of the rotational-invariance error
`||F(RX) - RF(X)||_2` of every method at index `θ_idx`, display it, and save
it to `./gen_error_figures/bar_plot_rot_err_coarse<coarse>.png`.

Reads the globals `rot_rfθ_node`, `rot_rfθ_nnsum`, `rot_rfθ_nnsum2`,
`rot_rfθ_gradp`, `rot_rfθ_rotinv`, `rot_rfθ_eos`, `rot_rfθ_phy` and `coarse`.
None of these are defined in this script, so they must exist before calling.
"""
function plot_rot_error(θ_idx)
    gr(size=(800, 800))
    println("*************** plot rot_error ******************")

    # One value per method; the order must match `methods` below.
    rot_errs = [rot_rfθ_node[θ_idx], rot_rfθ_nnsum[θ_idx], rot_rfθ_nnsum2[θ_idx],
                rot_rfθ_gradp[θ_idx], rot_rfθ_rotinv[θ_idx], rot_rfθ_eos[θ_idx],
                rot_rfθ_phy[θ_idx]]

    methods = [L"NODE", L"\sum_j NN_{ij}", L"\sum_j NN2_{ij}", L"\nabla P_{NN}",
               L"RI_{NN}", L"EoS_{NN}", L"Phys"]

    plt_rot = bar(methods, rot_errs, yaxis=:log, color="blue", label=false)
    title!(L"\textrm{Rotational - Invariance - Errors }", titlefont=16)
    xlabel!(L"\textrm{Method}", xtickfontsize=10, xguidefontsize=16)
    ylabel!(L"||F(RX) - RF(X)||_2", ytickfontsize=10, yguidefontsize=14)
    display(plt_rot)

    path_rot = "./gen_error_figures/bar_plot_rot_err_coarse$(coarse).png"
    savefig(plt_rot, path_rot)
end



"""
    plot_gen_error(θ_idx, t_idx)

Draw two log-scale bar plots of the per-method generalization error (KL):

- one from index `θ_idx` of the `Lgθ_*` arrays;
- one from index `t_idx` of the `Lgt_*` arrays.

Both plots are displayed and saved under `./gen_error_figures/`, with file
names suffixed by the global `coarse`.

The `Lgθ_*` and `Lgt_*` arrays and `coarse` are read from global scope. They
are not defined in this script, so they must exist before calling.
"""
function plot_gen_error(θ_idx, t_idx)
    gr(size=(800, 800))
    println("*************** plot gen_error ******************")

    # Both vectors must follow the order of `methods` (one bar per label).
    gen_errsθ = [Lgθ_node[θ_idx], Lgθ_nnsum[θ_idx], Lgθ_nnsum2[θ_idx],
                 Lgθ_gradp[θ_idx], Lgθ_rotinv[θ_idx], Lgθ_eos[θ_idx], Lgθ_phy[θ_idx]]
    # nnsum precedes nnsum2 so the bars line up with their labels.
    gen_errs_t = [Lgt_node[t_idx], Lgt_nnsum[t_idx], Lgt_nnsum2[t_idx],
                  Lgt_gradp[t_idx], Lgt_rotinv[t_idx], Lgt_eos[t_idx], Lgt_phy[t_idx]]

    methods = [L"NODE", L"\sum_j NN_{ij}", L"\sum_j NN2_{ij}", L"\nabla P_{NN}",
               L"RI_{NN}", L"EoS_{NN}", L"Phys"]

    plt_θ = bar(methods, gen_errsθ, yaxis=:log, color="blue", label=false)
    title!(L"\textrm{Comparing Generalization Error Over } \theta", titlefont=16)
    xlabel!(L"\textrm{Method}", xtickfontsize=10, xguidefontsize=16)
    ylabel!(L"\textrm{KL}", ytickfontsize=10, yguidefontsize=14)
    display(plt_θ)

    plt_t = bar(methods, gen_errs_t, yaxis=:log, color="blue", label=false)
    title!(L"\textrm{Comparing Generalization Error Over } t", titlefont=16)
    xlabel!(L"\textrm{Method}", xtickfontsize=10, xguidefontsize=16)
    ylabel!(L"\textrm{KL}", ytickfontsize=10, yguidefontsize=14)
    display(plt_t)

    path_θ = "./gen_error_figures/bar_plot_gen_errθ_coarse$(coarse).png"
    path_t = "./gen_error_figures/bar_plot_gen_errt_coarse$(coarse).png"
    savefig(plt_θ, path_θ)
    savefig(plt_t, path_t)
end




"""
    obtain_box_plot_t(Lgs; path_lg="./learned_figures/Lg_box_plot_over_t.png")

Draw a log-scale box plot (outliers hidden) of `Lgs`, display it, and save
it to `path_lg`.

The columns of `Lgs` are labelled, in order: NODE, sum_j NN_ij, RI_NN,
grad P_NN, EoS_NN, Phys. The caller must stack its data in that order.
"""
function obtain_box_plot_t(Lgs; path_lg="./learned_figures/Lg_box_plot_over_t.png")
    gr(size=(700, 700))

    # Kept as a 1×6 row (space-separated) to pair one label with each column of `Lgs`.
    methods = [L"NODE" L"\sum_j NN_{ij}" L"RI_{NN}" L"\nabla P_{NN}" L"EoS_{NN}" L"Phys"]

    plt = boxplot(methods, Lgs, yaxis=:log, legend=false, outliers=false)
    title!(L"\textrm{Generalization-Error-Over: t} ", titlefont=20)
    xlabel!(L"\textrm{Method}", xtickfontsize=14, xguidefontsize=20)
    ylabel!(L"L_{kl} + L_{f}", ytickfontsize=14, yguidefontsize=20)

    display(plt)
    savefig(plt, path_lg)
end

"""
    obtain_box_plot_theta(Lgs; path_lg="./learned_figures/Lg_box_plot_over_theta.png")

Draw a log-scale box plot (outliers hidden) of `Lgs` over θ, display it, and
save it to `path_lg`.

The columns of `Lgs` are labelled, in order: NODE, sum_j NN_ij, RI_NN,
grad P_NN, EoS_NN, Phys. The caller must stack its data in that order.
"""
function obtain_box_plot_theta(Lgs;
                               path_lg="./learned_figures/Lg_box_plot_over_theta.png")
    gr(size=(700, 700))

    # Kept as a 1×6 row (space-separated) to pair one label with each column of `Lgs`.
    methods = [L"NODE" L"\sum_j NN_{ij}" L"RI_{NN}" L"\nabla P_{NN}" L"EoS_{NN}" L"Phys"]

    plt = boxplot(methods, Lgs, yaxis=:log, legend=false, outliers=false)
    title!(L"\textrm{Generalization-Error-Over: } \theta ", titlefont=20)
    xlabel!(L"\textrm{Method}", xtickfontsize=14, xguidefontsize=20)
    ylabel!(L"L_{kl} + L_{f}", ytickfontsize=14, yguidefontsize=20)

    display(plt)
    savefig(plt, path_lg)
end






#---------Plotting

# plot_gen_error(3, 5)
# plot_rot_error(2)

# Load the saved loss arrays for each method. Each call returns a pair: the first
# element comes from the `loss_t_*` file, the second from the `loss_g_t_*` file.
L_node_t, Lg_node_t = load_losses("node")
L_nnsum_t, Lg_nnsum_t = load_losses("nnsum")
L_rot_inv_t, Lg_rot_inv_t = load_losses("rot_inv")
L_grad_p_t, Lg_grad_p_t = load_losses("grad_p")
L_eos_nn_t, Lg_eos_nn_t = load_losses("eos_nn")
L_phys_inf_t, Lg_phys_inf_t = load_losses("phys_inf")


# Same per-method pairs, read from the θ-indexed `Lθ_*` / `Lgθ_*` files.
Lθ_node_t, Lgθ_node_t = load_losses_theta("node")
Lθ_nnsum_t, Lgθ_nnsum_t = load_losses_theta("nnsum")
Lθ_rot_inv_t, Lgθ_rot_inv_t = load_losses_theta("rot_inv")
Lθ_grad_p_t, Lgθ_grad_p_t = load_losses_theta("grad_p")
Lθ_eos_nn_t, Lgθ_eos_nn_t = load_losses_theta("eos_nn")
Lθ_phys_inf_t, Lgθ_phys_inf_t = load_losses_theta("phys_inf")


# Stack one column per method, in the label order used by the box-plot functions
# (NODE, sum_j NN_ij, RI_NN, grad P, EoS_NN, Phys).
# NOTE(review): despite the `Lg` in their names, `Lgts`/`Lgθs` are built from the
# first returned arrays (`L_*` / `Lθ_*`), not the `Lg_*` / `Lgθ_*` ones — confirm
# this is intended.
Lgts = hcat(L_node_t, L_nnsum_t, L_rot_inv_t, L_grad_p_t,
            L_eos_nn_t, L_phys_inf_t);

Lgθs = hcat(Lθ_node_t, Lθ_nnsum_t, Lθ_rot_inv_t, Lθ_grad_p_t,
			Lθ_eos_nn_t, Lθ_phys_inf_t);

# Lgθs_kl = hcat(Lgθ_node_t, Lgθ_nnsum_t, Lgθ_rot_inv_t, Lgθ_grad_p_t,
# 			Lgθ_eos_nn_t, Lgθ_phys_inf_t);

# rot_errts = hcat(rot_rft_node[1:16], rot_rft_nnsum[1:16], rot_rft_nnsum2[1:16], rot_rft_gradp[1:16],
#             rot_rft_rotinv[1:16], rot_rft_eos[1:16], rot_rft_phy[1:16]);
#
#
# obtain_box_plot_t(Lgts, rot_errts)
# Write the box plots of the stacked losses over t and over θ.
obtain_box_plot_t(Lgts)
obtain_box_plot_theta(Lgθs)


# obtain_box_plot_theta(Lgθs_kl)


# rot_errθs = hcat(rot_rfθ_node[1:6], rot_rfθ_nnsum[1:6], rot_rfθ_nnsum2[1:6], rot_rfθ_gradp[1:6],
#             rot_rfθ_rotinv[1:6], rot_rfθ_eos[1:6], rot_rfθ_phy[1:6]);
#
# Lgθs = hcat(Lgθ_node[1:6], Lgθ_nnsum[1:6], Lgθ_nnsum2[1:6], Lgθ_gradp[1:6],
#             Lgθ_rotinv[1:6], Lgθ_eos[1:6], Lgθ_phy[1:6]);
#
# obtain_box_plot_theta(Lgθs, rot_errθs)










#
#
#
# #------ Ploting functions
#
#
# function plotting_Lg()
#     gr(size=(500,500))
#     x_s = -width
#     x_e = width
#     G_pred(x) = kde(x, Vel_inc_pred[T, :, 1])
#     plt = plot(x->G_u(x), x_s, x_e, label=L"G(\tau, d)",
#                 color="indigo", linewidth = 2.5)
#     plot!(x->G_pred(x), x_s, x_e, marker=:x, markersize=4, color="forestgreen",
#           markercolor = :black, label=L"G_{\theta}(\tau, d)",
#           linestyle=:dash, linewidth = 2.5)
#
#     title!(L"\textrm{Comparing - } G(\tau,z) \textrm{ - and - } \hat{G}_{\theta}(\tau,z)", titlefont=18)
#     xlabel!(L"\textrm{Velocity - increment}", xtickfontsize=12, xguidefontsize=20)
#     ylabel!(L"G_u \textrm{ - distribution}", ytickfontsize=12, yguidefontsize=20)
#
#     display(plt)
#     if path == "train"
#         out_path = "./figures/$(method)_Gtrain_kde_$(file_title)_θ$(θ_in).png"
#     elseif path == "test"
#         out_path = "./figures/$(method)_Gtest_kde_$(file_title)_θ$(θ_in).png"
#     end
#     savefig(plt, out_path)
# end
#
#
# function plot_lg_comp()
#     gr(size=(600,500))
#     println("*************** plot gen_error ******************")
#     plt = bar(methods, gen_errs, yaxis=:log, color="blue", label=false)
#
#     title!(L"\textrm{Comparing - Generalization - Error}", titlefont=16)
#     xlabel!(L"\textrm{Method}", xtickfontsize=14, xguidefontsize=16)
#     ylabel!(L"\textrm{KL}", ytickfontsize=10, yguidefontsize=16)
#
#     display(plt)
#     savefig(plt, "./$(out_file_title)/figures/$(out_file_title)_gen_err.png")
# end


#----symmetry errors

"""
    load_symmetry_errors(method; loss_method="kl_lf_kl_t")

Read the rotational (`rot_RF_*`) and translational (`gal_inv_*`) symmetry
error arrays saved for `method` in `./learned_generalization/`, and return
them as `(rot, gal)`.

`loss_method` selects the loss-specific file-name suffix. It used to come
from a global `loss_method` that this script never defines, so every call
raised `UndefVarError`. The default mirrors the suffix hard-coded in
`load_losses` (`..._losskl_lf_kl_t.npy`). TODO: confirm it matches the
saved symmetry-error files.
"""
function load_symmetry_errors(method; loss_method="kl_lf_kl_t")
    tail = "_$(method)_loss$(loss_method).npy"
    rot = npzread("./learned_generalization/rot_RF" * tail)
    gal = npzread("./learned_generalization/gal_inv" * tail)
    return rot, gal
end

# Load the rotational (`rot_RF_*`) and translational (`gal_inv_*`) symmetry errors
# for every method.
rot_node, gal_node = load_symmetry_errors("node")
rot_nnsum, gal_nnsum = load_symmetry_errors("nnsum")
rot_rot_inv, gal_rot_inv = load_symmetry_errors("rot_inv")
rot_grad_p, gal_grad_p = load_symmetry_errors("grad_p")
rot_eos_nn, gal_eos_nn = load_symmetry_errors("eos_nn")
rot_phys_inf, gal_phys_inf = load_symmetry_errors("phys_inf")

# Concatenate in the label order used by `plot_rot_gal`
# (NODE, sum_j NN_ij, RI_NN, grad P_NN, EoS_NN, Phys).
rots = vcat(rot_node, rot_nnsum, rot_rot_inv, rot_grad_p,
            rot_eos_nn, rot_phys_inf);

gals = vcat(gal_node, gal_nnsum, gal_rot_inv, gal_grad_p,
			gal_eos_nn, gal_phys_inf);


"""
    plot_rot_gal(rot_errs=rots, gal_errs=gals)

Draw two log-scale bar plots with one bar per method:

- the rotational symmetry error `||F(RX) - RF(X)||_2` from `rot_errs`;
- the translational symmetry error `||F(X) - F(X - s)||_2` from `gal_errs`.

Each plot is displayed and saved under `./learned_figures/`.

The defaults are the script-level globals `rots` and `gals`. Values must
follow the label order NODE, sum_j NN_ij, RI_NN, grad P_NN, EoS_NN, Phys.
"""
function plot_rot_gal(rot_errs=rots, gal_errs=gals)
    gr(size=(700, 700))
    println("*************** plot rot/gal errors ******************")
    methods = [L"NODE", L"\sum_j NN_{ij}", L"RI_{NN}", L"\nabla P_{NN}",
               L"EoS_{NN}", L"Phys"]

    plt = bar(methods, rot_errs, yaxis=:log, color="blue", label=false)
    title!(L"\textrm{Rotational - Symmetry - Error}", titlefont=20)
    xlabel!(L"\textrm{Method}", xtickfontsize=14, xguidefontsize=20)
    ylabel!(L"||F(RX) - RF(X)||_2", ytickfontsize=14, yguidefontsize=20)
    display(plt)
    savefig(plt, "./learned_figures/rots_bar_err.png")

    # `title!`/`xlabel!`/`ylabel!` act on the current plot, which is now `plt2`.
    plt2 = bar(methods, gal_errs, yaxis=:log, color="blue", label=false)
    title!(L"\textrm{Translational - Symmetry - Error}", titlefont=20)
    xlabel!(L"\textrm{Method}", xtickfontsize=14, xguidefontsize=20)
    ylabel!(L"||F(X) - F(X - s)||_2", ytickfontsize=14, yguidefontsize=20)
    display(plt2)
    savefig(plt2, "./learned_figures/gals_bar_err.png")
end


# Write the rotational and translational symmetry-error bar plots for all methods.
plot_rot_gal()
