using Plots, NPZ, LaTeXStrings, LinearAlgebra, Flux
using Flux.Losses, Statistics
using BSON: @load

ENV["GKSwstype"] = "100"  # set env variable for UAHPC

T = 2;
t_start = 3205;

# KDE / KL-integration settings. Each name is bound exactly once: assigning a
# plain global over an earlier `const` of the same name is a constant
# redefinition (a warning on older Julia releases, an error on newer ones).
h_kde = 0.9
const r = 5.0       # number of smoothing (r*h_kde) lengths for determining bounds of integration in KL
const n_int = 200


# traj_gt = npzread("./data/traj_N1024_T6001_ts5801_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ0.08_AV_neg_rel_ke.npy")
# vels_gt = npzread("./data/vels_N1024_T6001_ts5801_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ0.08_AV_neg_rel_ke.npy")
# rhos_gt = npzread("./data/rhos_N1024_T6001_ts5801_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ0.08_AV_neg_rel_ke.npy")
#
# traj_gt = traj_gt[t_start:end, :, :]
# vels_gt = vels_gt[t_start:end, :, :]
# rhos_gt = rhos_gt[t_start:end, :]

# Ground-truth trajectories, velocities and densities read from .npy files.
traj_gt_path = "./learned_data/traj_N4096_T110_ts1_h0.335_Vrandn_c10.0_α1.0_β2.0_θ0.5_truth.npy"
vels_gt_path = "./learned_data/vels_N4096_T110_ts1_h0.335_Vrandn_c10.0_α1.0_β2.0_θ0.5_truth.npy"
rhos_gt_path = "./learned_data/rhos_N4096_T110_ts1_h0.335_Vrandn_c10.0_α1.0_β2.0_θ0.5_truth.npy"
traj_gt = npzread(traj_gt_path)
vels_gt = npzread(vels_gt_path)
rhos_gt = npzread(rhos_gt_path)


# The three axes of `traj_gt` are unpacked as (t_, N, D).
t_, N, D = size(traj_gt)
m = (2.0 * pi)^D / N  # so that ρ₀ = 1
α = 1.0
β = 2.0 * α  # usual value of params for α and β, but these depend on the problem
θ = 5e-1
c = 10.0
g = 7.0
h = 0.335
cdt = 0.4
dt = cdt * h / c

#---KDE smoothing kernel for density estimation (both for gt and pred data)
K(x) = 1 / sqrt(2 * pi) * exp(-x^2 / 2)  # gaussian kernel (best results so far)
# K(x) = maximum([1 - abs(x), 0])     #triangle (produces frequency polygon)
# K(x) = 3/4*maximum([1 - x^2, 0])   #Epanechnikov

# Derivative of the gaussian kernel in closed form: d/dx K(x) = -x * K(x)
# (used in computing ∇L). Written analytically so it does not rely on an
# automatic-differentiation package being loaded by this script.
# NOTE: if a different kernel is enabled above, update this derivative to match.
K_prime(x) = -x * K(x)


"""
    create_dir(dir_name)

Ensure that the directory `dir_name` exists.

If it is already present, print `"directory already exists"` and return
`nothing`. Otherwise create it, including any missing parent directories,
and return the created path.
"""
function create_dir(dir_name)
    if isdir(dir_name)
        println("directory already exists")
        return nothing
    end
    # `mkpath` (unlike `mkdir`) also creates missing intermediate directories,
    # so nested targets such as "./a/b" work on a fresh checkout.
    return mkpath(dir_name)
end

# create_dir(out_file_title);
# Make sure the figure output directory exists.
create_dir("./learned_generalization/figures")


#----- Load data
# Active method tag and its display label; the commented pairs below are the
# alternatives (enable one pair at a time).
method = "node"
latex_method = "NODE"
# method = "nnsum"; latex_method = "NN Sum"
# method = "rot_inv"; latex_method = "Rot Inv"
# method = "eos_nn"; latex_method = "EoS NN"
# method = "grad_p"; latex_method = "Grad P"
# method = "Wnn"; latex_method = "W NN"
# method = "phys_inf"; latex_method = "Phys Inf"
# method = "truth"; latex_method = "Truth"

"""
    load_losses(method)

Read the two loss arrays stored for `method` under `./learned_generalization/`
and return them as a tuple: first the contents of the `loss_t_…` file, then the
contents of the `loss_g_t_…` file.
"""
function load_losses(method)
    loss_t = npzread("./learned_generalization/loss_t_$(method)_kl_lf_kl_t.npy")
    loss_g_t = npzread("./learned_generalization/loss_g_t_$(method)_kl_lf_kl_t.npy")
    return (loss_t, loss_g_t)
end

# Loss arrays for every method: `L_*_t` is read from the `loss_t_*` file and
# `Lg_*_t` from the `loss_g_t_*` file (see `load_losses`).
(L_node_t, Lg_node_t) = load_losses("node")
(L_nnsum_t, Lg_nnsum_t) = load_losses("nnsum")
(L_rot_inv_t, Lg_rot_inv_t) = load_losses("rot_inv")
(L_grad_p_t, Lg_grad_p_t) = load_losses("grad_p")
(L_eos_nn_t, Lg_eos_nn_t) = load_losses("eos_nn")
(L_phys_inf_t, Lg_phys_inf_t) = load_losses("phys_inf")

#
# #----------Semi-informed data;
# traj_test_set = npzread("./data/traj_N1024_T6001_ts5801_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ0.16_AV_neg_rel_ke.npy")
# vels_test_set = npzread("./data/vels_N1024_T6001_ts5801_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ0.16_AV_neg_rel_ke.npy")
# traj_test_set = traj_test_set[t_start:end, :, :]
# vels_test_set = vels_test_set[t_start:end, :, :]
# t_, N, D = size(traj_test_set);
#
#
# include("kde_G.jl")
# include("loss_functions_2d.jl")
#
#
# const θ_test = 1.6e-1;
# const α_test = 1.0;
# const β_test = 2.0;
# const c_test = 12.0;
#
# Q,R = qr(randn(D,D)); Q = Q*Diagonal(sign.(diag(R))); #random orthogonal matrix
# R_90 = [0.0 -1.0; 1.0 0.0]
# function rotational_metric(X, V, p, c, h, α, β, θ, model_A)
# 	Fqy, rh_ = model_A((Q*X')', (Q*V')', p, c, h, α, β, θ)
# 	Fry, rh_ = model_A((R_90*X')', (R_90*V')', p, c, h, α, β, θ)
# 	F, rh_ = model_A(X, V, p, c, h, α, β, θ)
# 	QF = (Q*F')'
# 	RF = (R_90 * F')'
# 	return mse(Fqy, QF), mse(Fry, RF)
# end
#

#---- generalization errors

# Compute the final training loss, generalization loss and rotational error
# for one trained model.
#
# Arguments
#   path          : directory containing `NN_model.bson` and `params_intermediate.npy`
#   method        : tag used to pick the `./acceleration_<method>.jl` file to include
#   traj_test_set : test-set trajectories, indexed here as [time, particle, dim]
#   vels_test_set : test-set velocities, indexed the same way
#   θ_test        : θ value passed to the simulator for the test-set rollout
#   T             : passed to `vel_verlet_NN` and used as the time index/window for the KDE inputs
#
# Returns the tuple (L_final, Lg_final, rot_err_final).
#
# Besides its arguments the body reads the globals α, β, h, c, g, θ, traj_gt and
# vels_gt assigned at the top of this file.
# NOTE(review): `Vel_inc_gt`, `G_u`, `G_v`, `hu_kde_gt`, `hv_kde_gt`,
# `obtain_pred_dists`, `kde`, `obtain_h_kde`, `vel_verlet_NN`, `Ikl_fixed_τ`,
# `rotational_metric` and `obtain_sph_AV_A` are not defined in the active code of
# this file (`rotational_metric` only appears in a commented-out section) —
# presumably they come from the included files; confirm.
function generalization_errors(path, method, traj_test_set, vels_test_set, θ_test, T)
	# Load the saved network `NN` and flatten it; only the size of the flattened
	# parameter vector is printed — `p_` and `re` are not used afterwards.
	@load "$(path)/NN_model.bson" NN
	p_, re = Flux.destructure(NN)
	println(size(p_))
	# `include` evaluates these files in the module's global scope.
	# NOTE(review): functions they define and that are called later in this same
	# invocation may hit a world-age MethodError — confirm, or include the files
	# at top level before calling this function.
	include("./acceleration_$(method).jl")
	include("./sph_simulator.jl")
	# Parameters actually used for the rollouts below.
	p_fin = npzread("$(path)/params_intermediate.npy")
	println(size(p_fin))
	c_gt = c;  # copy of the global `c`; not used below

	# Reference distributions from the first T+1 frames of the test set.
	Diff_test, Vel_inc_test =
	obtain_pred_dists(traj_test_set[1:(T+1),:,:], vels_test_set[1:(T+1),:,:], traj_test_set[1,:,:], vels_test_set[1,:,:]);

	# KDE closures and bandwidths built from components 1 and 2 of the
	# velocity increments at time index T.
	Gu_test(x) = kde(x, Vel_inc_test[T, :, 1]); hu_kde_test = obtain_h_kde(Vel_inc_test[T, :, 1])
	Gv_test(x) = kde(x, Vel_inc_test[T, :, 2]); hv_kde_test = obtain_h_kde(Vel_inc_test[T, :, 2])

	# Model rollout on the test set (uses θ_test).
	traj_model_test, vels_model_test, rhos_model_test =
	vel_verlet_NN(traj_test_set, vels_test_set, p_fin, α, β, h, c, g, θ_test, T);

	Diff_model_test, Vel_inc_model_test =
	obtain_pred_dists(traj_model_test, vels_model_test, traj_test_set[1,:,:], vels_test_set[1,:,:]);
	# comparing_Gu(Gu_test, Vel_inc_model_test, "test", θ_test, 0.05);

	# Model rollout on the ground-truth data (uses the global θ).
	traj_pred, vels_pred, rhos_pred = vel_verlet_NN(traj_gt, vels_gt, p_fin, α, β, h, c, g, θ, T);
	Diff_pred, Vel_inc_pred = obtain_pred_dists(traj_pred, vels_pred, traj_gt[1,:,:], vels_gt[1,:,:]);

	# L_final compares the ground-truth rollout against the *_gt references;
	# Lg_final compares the test-set rollout against the test-set references.
	L_final = Ikl_fixed_τ(Vel_inc_gt, Vel_inc_pred, G_u, G_v, hu_kde_gt, hv_kde_gt);
	Lg_final = Ikl_fixed_τ(Vel_inc_test, Vel_inc_model_test, Gu_test, Gv_test, hu_kde_test, hv_kde_test);

	# Only the second value returned by `rotational_metric` is kept.
	rot_qrr, rot_err_final = rotational_metric(traj_gt[1,:,:], vels_gt[1,:,:], p_fin, c, h, α, β, θ, obtain_sph_AV_A);

	return L_final, Lg_final, rot_err_final
end



# NOTE(review): `node_data_dir`, `nnsum_data_dir`, `nnsum2_data_dir`,
# `traj_test_set`, `vels_test_set` and `θ_test` are only assigned inside
# commented-out code in this file, so they must already be defined when this
# section runs — confirm where they are expected to come from.

#---- node

# L_node_final = npzread("$(node_data_dir)/loss.npy")[end];
# Lg_node_final = npzread("$(node_data_dir)/gen_loss.npy")[end];
# rot_err_node_final = npzread("$(node_data_dir)/rot_error_rf.npy")[end];

(L_node_final, Lg_node_final, rot_err_node_final) = generalization_errors(
    node_data_dir, "node", traj_test_set, vels_test_set, θ_test, T
)




#
#---- nnsum

(L_nnsum_final, Lg_nnsum_final, rot_err_nnsum_final) = generalization_errors(
    nnsum_data_dir, "nnsum", traj_test_set, vels_test_set, θ_test, T
)




#----- nnsum 2

(L_nnsum2_final, Lg_nnsum2_final, rot_err_nnsum2_final) = generalization_errors(
    nnsum2_data_dir, "nnsum2", traj_test_set, vels_test_set, θ_test, T
)




#
#
#
# #---- nnsum2
#
# # method = "nn_sum"
# # method = "nn_sum2"
# # method = "rot_inv"
# # method = "node"
# # method = "eos_nn"
# # method = "grad_p"
#
# L_nnsum_final, Lg_nnsum_final, rot_err_nnsum_final =
# 	generalization_errors(path, method, traj_test_set, vels_test_set, θ_test, T)
#
#
#
# #---- rotnn
#
#
# L_rotnn_final = npzread("$(rotnn_data_dir)/loss.npy")[end];
# Lg_rotnn_final = npzread("$(rotnn_data_dir)/gen_loss.npy")[end];
# rot_err_rotnn_final = npzread("$(rotnn_data_dir)/rot_error_rf.npy")[end];
#
#
#
#
#
#
# phys_inf_data_dir = "./basic_physics_informed/output_data_fsa_4p_cabg_KL_s_hit_itr3000_lr0.05_T2_D2_N1024_c12.0_α1.0_β2.0_h0.2_nball_nint90_ts3210"
# node_data_dir = "./optimiz_semi_inf/output_data_kl_node_s_hit_itr1500_lr0.04_T2_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt3/"
# #needs gen comp
# nnsum_data_dir = "./optimiz_semi_inf/output_data_kl_nn_sum_s_hit_itr1400_lr0.05_T2_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt3/"
# #needs gen comp
# nnsum2_data_dir = "./optimiz_semi_inf/output_data_kl_nn_sum2_s_hit_itr1400_lr0.05_T2_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt3/"
# #ok
# rotnn_data_dir = "./optimiz_semi_inf/output_data_kl_rot_inv_s_hit_itr1800_lr0.04_T2_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt2/"
# gradp_data_dir = "./optimiz_semi_inf/output_data_kl_grad_p_s_hit_itr1800_lr0.05_T2_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt3/"
# #needs gen comp
# eosnn_data_dir = "./optimiz_semi_inf/output_data_kl_eos_nn_s_hit_itr1800_lr0.05_T2_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt3/"
#
#
#
#
#
#
#
#
#
#
#
#
#
# #-------Generate outputs
#
# # Lg_phys_final = 0.00012897;
# # L_phys_final = 7.241973e-5;
# # rot_err_phys_final = 10^(-32);
# #
# # Lg_node_final = 10^(0.92)
# # L_node_final = 10^(-3.5)
# # rot_err_node_final = 10^(-0.09)
# #
# # L_rotnn_final = 10^(-1.8)
# # Lg_rotnn_final = 10^(0.87)
# # rot_err_rotnn_final = 10^(-27.9)
# #
# # L_nnsum2_final = 10^(-2.1)
# # Lg_nnsum2_final = 10^(-1.9)
# # rot_err_nnsum2_final = 10^(1.97)
# #
# # gen_errs = [Lg_node_final, Lg_nnsum2_final, Lg_rotnn_final, Lg_phys_final];
# # losses = [L_node_final, L_nnsum2_final, L_rotnn_final, L_phys_final]
# # rot_errs = [rot_err_node_final, rot_err_nnsum2_final, rot_err_rotnn_final, rot_err_phys_final]
# #
# # methods = ["node", "nn_sum", "rot_nn", "phys_inf"];
# # # bar(methods, gen_errs, yaxis=:log, color="blue")
# # # title!("Generalization Error")
# #
# # function plot_lg_comp()
# #     gr(size=(600,500))
# #     println("*************** plot gen_error ******************")
# #     plt = bar(methods, gen_errs, yaxis=:log, color="blue", label=false)
# #
# #     title!(L"\textrm{Comparing - Generalization - Error}", titlefont=16)
# #     xlabel!(L"\textrm{Method}", xtickfontsize=14, xguidefontsize=16)
# #     ylabel!(L"\textrm{KL}", ytickfontsize=10, yguidefontsize=16)
# #
# #     display(plt)
# #     savefig(plt, "./$(out_file_title)/figures/$(out_file_title)_gen_err.png")
# # end
# #
# # function plot_l_comp()
# #     gr(size=(600,500))
# #     println("*************** plot loss ******************")
# #     plt = bar(methods, losses, yaxis=:log, color="blue")
# #
# #     title!(L"\textrm{Comparing - Losses}", titlefont=16)
# #     xlabel!(L"\textrm{Method}", xtickfontsize=14, xguidefontsize=16)
# #     ylabel!(L"\textrm{KL}", ytickfontsize=10, yguidefontsize=16)
# #
# #     display(plt)
# #     savefig(plt, "./$(out_file_title)/figures/$(out_file_title)_loss.png")
# # end
# #
# # function plot_rot_comp()
# #     gr(size=(600,500))
# #     println("*************** plot loss ******************")
# #     plt = bar(methods, rot_errs, yaxis=:log, color="blue")
# #
# #     title!(L"\textrm{Comparing - Error - in - Rotational - Symmetry}", titlefont=16)
# #     xlabel!(L"\textrm{Method}", xtickfontsize=14, xguidefontsize=16)
# #     ylabel!(L"\textrm{KL}", ytickfontsize=10, yguidefontsize=16)
# #
# #     display(plt)
# #     savefig(plt, "./$(out_file_title)/figures/$(out_file_title)_rot_err.png")
# # end
# #
# # plot_l_comp()
# # plot_lg_comp()
# # plot_rot_comp()
