

using Flux, Statistics, LaTeXStrings
using ForwardDiff, Plots
using Flux.Losses
using Zygote
using NPZ

using BSON: @load

# Loss used for the generalization metric passed to compute_Lg.
lg_method = "kl_t"

# Loss the evaluated model was trained with; it selects the checkpoint below and
# appears in every output file name.
# Other values used: "kl", "l2", "kl_t_one_dis", "kl_t", "kl_l2_t".
loss_method = "kl_lf"

# Model family to evaluate; one of "node", "nn_sum", "rot_inv", "grad_p", "eos_nn", "Wnn".
method = "node"

# KDE / KL settings.
r = 5.0             # number of smoothing lengths (r * h_kde) setting the KL integration bounds
h_kde = 0.9
n_int = 90
window = 0.2
nb = "all"

# Simulation parameters (β is tied to α; the ground truth uses the same c).
α = 1.0
β = 2 * α
h = 0.335
g = 7
θ = 5e-1
c = 10.0
c_gt = c
cdt = 0.4


ic_data_dir = "./data"

# Time window: first ground-truth frame kept, frame stride, rollout horizon.
t_start = 20
coarse_mult = 1
T = 10
# ---- Select and load the trained model for `method`.
# Each branch only records the network `height`, the checkpoint directory
# `nn_data_dir`, the saved-parameter file and the method-specific
# sensitivities script; the loading steps shared by every method follow the
# branch. Methods whose checkpoint is configured only for loss_method == "kl_t"
# raise an error for other loss methods rather than failing later with an
# UndefVarError on `nn_data_dir`, and an unrecognized `method` is rejected here
# instead of leaving NN / p_fin undefined for the code below.
if method == "node"
	# NOTE(review): the checkpoint directory name says height5 while height = 3 — confirm.
	height = 3;
	nn_data_dir = "./semi_inf_data/output_data_forward_node_kl_lf_Vrandn_itr801_lr0.005_T10_D3_N4096_c10.0_α1.0_β2.0_h0.335_nball_nint200_ts20_coarse1_height5_klswitch0"
	nn_params_file = "params_fin.npy"
	sensitivities_script = "sensitivities_3d_node.jl"
elseif method == "nn_sum"
	height = 3;
	# Other kl_t checkpoints used previously:
	# nn_data_dir = "./semi_inf_data4/output_data_kl_t_nn_sum_s_hit_itr2000_lr0.05_T3_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt3_coarse5"
	# nn_data_dir = "./semi_inf_data3/output_data_kl_l2_t_nn_sum_s_hit_itr2000_lr0.05_T3_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt3_coarse5"
	if loss_method == "kl_t"
		nn_data_dir = "./semi_inf_data/output_data_kl_t_nn_sum_s_hit_itr2000_lr0.05_T3_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt3_coarse5"
	else
		throw(ArgumentError("no trained nn_sum checkpoint configured for loss_method = $(loss_method)"))
	end
	nn_params_file = "params_intermediate.npy"
	sensitivities_script = "sensitivities_nnsum.jl"
elseif method == "rot_inv"
	height = 3;
	if loss_method == "kl_t"
		nn_data_dir = "./semi_inf_data3/output_data_kl_t_rot_inv_s_hit_itr2000_lr0.05_T3_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt3_coarse5"
	else
		throw(ArgumentError("no trained rot_inv checkpoint configured for loss_method = $(loss_method)"))
	end
	# coarse_mult = 5
	nn_params_file = "params_intermediate.npy"
	sensitivities_script = "sensitivities_rot_nn.jl"
elseif method == "grad_p"
	# Other kl_t checkpoint used previously (height = 20):
	# nn_data_dir = "./semi_inf_data3/output_data_adjoint_kl_t_one_dist_grad_p_s_hit_itr2000_lr0.05_T5_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt20_coarse3"
	if loss_method == "kl_t"
		nn_data_dir = "./semi_inf_data4/output_data_kl_t_grad_p_s_hit_itr2000_lr0.05_T3_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt3_coarse5"
		height = 3;
	else
		throw(ArgumentError("no trained grad_p checkpoint configured for loss_method = $(loss_method)"))
	end
	nn_params_file = "params_intermediate.npy"
	sensitivities_script = "sensitivities_grad_p.jl"
elseif method == "eos_nn"
	height = 8;
	# nn_data_dir = "./semi_inf_data2/output_data_l2_eos_nn_s_hit_itr2000_lr0.05_T3_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt8_coarse5"
	nn_data_dir = "./semi_inf_data3/output_data_l2_eos_nn_s_hit_itr2000_lr0.05_T10_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt8_coarse1"
	nn_params_file = "params_intermediate.npy"
	sensitivities_script = "sensitivities_eos_nn.jl"
elseif method == "Wnn"
	height = 10;
	nn_data_dir = "./output_data_forward_kl_Wnn_s_hit_itr3004_lr0.05_T3_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint140_ts3205_hgt10_coarse1"
	nn_params_file = "params_intermediate.npy"
	sensitivities_script = "sensitivities_Wnn.jl"
else
	throw(ArgumentError("unknown method = $(method)"))
end

# Shared loading steps: the saved Flux model, its flattened parameters and the
# trained parameter vector used by the simulator (p_fin).
@load "$(nn_data_dir)/NN_model.bson" NN
println(NN)
p_, re = Flux.destructure(NN)   #flatten nn params
n_params = size(p_)[1]
if method == "node"
	# n_list is derived from the column count of the first parameter array;
	# presumably consumed by the node sensitivities script — TODO confirm.
	p = params(NN); n_list = floor(Int, size(p[1])[2]/4)
end
params_path = "$(nn_data_dir)/$(nn_params_file)"
p_fin = npzread(params_path)
include(sensitivities_script)


# Ground-truth positions, velocities and densities, keeping frames from
# `t_start` onward with stride `coarse_mult`.
traj_gt = npzread("./data/traj_N4096_T50_ts1_h0.335_Vrandn_cdt0.4_c10.0_α1.0_β2.0_θ0.5_AV_neg_rel.npy")[t_start:coarse_mult:end, :, :]
vels_gt = npzread("./data/vels_N4096_T50_ts1_h0.335_Vrandn_cdt0.4_c10.0_α1.0_β2.0_θ0.5_AV_neg_rel.npy")[t_start:coarse_mult:end, :, :]
rhos_gt = npzread("./data/rhos_N4096_T50_ts1_h0.335_Vrandn_cdt0.4_c10.0_α1.0_β2.0_θ0.5_AV_neg_rel.npy")[t_start:coarse_mult:end, :]


# Problem size taken from the data: particle count and spatial dimension.
N = size(traj_gt, 2)
D = size(traj_gt, 3)
# Particle mass: total mass (2π)^D shared by N particles, so that ρ₀ = 1.
m = (2π)^D / N
# Time step of the (possibly coarsened) ground-truth frames.
dt = coarse_mult * cdt * h / c_gt



#--------------------Computing errors
"""
    create_dir(dir_name)

Create the directory `dir_name` with `mkdir` (returning its path); if it already
exists, print a notice instead and return `nothing`.
"""
function create_dir(dir_name)
	isdir(dir_name) && return println("directory already exists")
	return mkdir(dir_name)
end

# Output directories for the error data, learned-model animations and figures.
for dir_name in ("generalization_data_coarse$(coarse_mult)", "sims_learned", "figures")
	create_dir(dir_name)
end


include("./kde_G.jl")
include("./sph_simulator.jl")
include("./loss_functions_2d.jl")



#------- over θ
# NOTE(review): these test parameters are not read by the sweeps below, which use the global α, β, c.
α_test = 1.0
β_test = 2.0
c_test = 10.0



"""
    obtain_gen_errs_over_θ(T)

For each θ in `0.05:0.25:1.5`, load the corresponding test trajectory (from frame
900 with stride `coarse_mult`), roll the learned model out for `T` steps with
`vel_verlet_NN` at that θ, and compare the model's velocity increments with the
test set's.

Returns `(L, Lg, rot_errθ)`, one entry per θ:
- `L` stays zero (its `compute_L` evaluation is commented out);
- `Lg` is `compute_Lg(lg_method, …)` with the test increments as reference;
- `rot_errθ` is `rotational_metric_semi_inf` on the test state at index `T`.

NOTE(review): the test files are N1024 / h0.2 / c12.0 data, while the rollout
uses the global `h` and `c` — confirm the mismatch is intended.
"""
function obtain_gen_errs_over_θ(T)
	θ_range = 0.5e-1 : 0.25 : 1.5
	n_θ = length(θ_range)
	L, Lg, rot_errθ = zeros(n_θ), zeros(n_θ), zeros(n_θ)
	ts_test = 900
	for (ii, θ_test) in enumerate(θ_range)
		traj_test_set = npzread("./data/traj_N1024_T1000_ts3205_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ$(θ_test)_AV_neg_rel_ke.npy")[ts_test:coarse_mult:end, :, :]
		vels_test_set = npzread("./data/vels_N1024_T1000_ts3205_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ$(θ_test)_AV_neg_rel_ke.npy")[ts_test:coarse_mult:end, :, :]
		x0_test = traj_test_set[1, :, :]
		v0_test = vels_test_set[1, :, :]

		# Reference increments from the first T+1 test frames, with per-time KDEs and bandwidths.
		_, Vel_inc_test = obtain_pred_dists(traj_test_set[1:(T+1), :, :], vels_test_set[1:(T+1), :, :], x0_test, v0_test, T)
		Gt_utest = (x, t) -> kde(x, Vel_inc_test[t, :, 1])
		Gt_vtest = (x, t) -> kde(x, Vel_inc_test[t, :, 2])
		hu_kde_test = t -> obtain_h_kde(Vel_inc_test[t, :, 1])
		hv_kde_test = t -> obtain_h_kde(Vel_inc_test[t, :, 2])

		# Learned-model rollout at θ_test, measured from the same initial state.
		traj_model, vels_model, _ = vel_verlet_NN(traj_test_set, vels_test_set, p_fin, α, β, h, c, g, θ_test, T)
		_, Vel_inc_model = obtain_pred_dists(traj_model, vels_model, x0_test, v0_test, T)

		# L[ii] = compute_L(loss_method, Vel_inc_gt, Vel_inc_pred, hut_kde_test, hvt_kde_test, Gt_utest, Gt_vtest, T)
		Lg[ii] = compute_Lg(lg_method, Vel_inc_test, Vel_inc_model, hu_kde_test, hv_kde_test, Gt_utest, Gt_vtest, T)
		rot_errθ[ii] = rotational_metric_semi_inf(traj_test_set[T, :, :], vels_test_set[T, :, :], p_fin, c, h, α, β, θ_test, obtain_sph_AV_A)

		println("L = ", L[ii], "  Lg = ", Lg[ii], "  rot = ", rot_errθ[ii])
	end
	return L, Lg, rot_errθ
end


Lθ, Lgθ, rot_errθ = obtain_gen_errs_over_θ(T)


#------- vary T at the training value of θ (θ = 5e-1 from the configuration)

θ_test = θ
α_test, β_test, c_test = 1.0, 2.0, 10.0

"""
    obtain_gen_loss_t(t_s, t_end, t_skip)

Evaluate the learned model for rollout horizons `T` in `t_s:t_skip:t_end` at the
configured θ, against the ground-truth arrays `traj_gt` / `vels_gt`.

For each horizon:
- the ground-truth velocity increments (`obtain_gt_dists`) provide the reference
  samples, per-time KDEs and KDE bandwidths;
- the model rollout (`vel_verlet_NN`) provides the predicted increments;
- `compute_Lg` is called with the same argument convention as
  `obtain_gen_errs_over_θ`: reference increments, model increments, reference
  bandwidths, reference KDEs.

Returns `(Lt, Lgt, rot_errt)`, one entry per horizon:
- `Lt` stays zero while the `compute_L` evaluation is disabled;
- `rot_errt` is `rotational_metric_semi_inf` on the ground-truth state at index `T`.
"""
function obtain_gen_loss_t(t_s, t_end, t_skip)
	t_range = range(t_s, t_end, step=t_skip);
	num_t_idx = length(t_range);
	println("num_t_idx  = ", num_t_idx);
	Lt = zeros(num_t_idx); Lgt = zeros(num_t_idx); rot_errt = zeros(num_t_idx);
	for (ii, T) in enumerate(t_range)
		# Reference statistics from the ground truth.
		_, Vel_inc_gt = obtain_gt_dists(traj_gt, vels_gt, T)
		Gt_u = (x, t) -> kde(x, Vel_inc_gt[t, :, 1]); hut_kde_gt = t -> obtain_h_kde(Vel_inc_gt[t, :, 1])
		Gt_v = (x, t) -> kde(x, Vel_inc_gt[t, :, 2]); hvt_kde_gt = t -> obtain_h_kde(Vel_inc_gt[t, :, 2])

		# Learned-model rollout and its increments relative to the ground-truth initial state.
		traj, vels, _ = vel_verlet_NN(traj_gt, vels_gt, p_fin, α, β, h, c, g, θ, T);
		_, Vel_inc_pred = obtain_pred_dists(traj, vels, traj_gt[1,:,:], vels_gt[1,:,:], T);

		# Lt[ii] = compute_L(loss_method, Vel_inc_gt, Vel_inc_pred, hut_kde_gt(T), hvt_kde_gt(T), Gt_u, Gt_v, T)
		# The reference arguments (samples and bandwidths) must come from the ground
		# truth; passing the model increments here would compare the rollout with itself.
		Lgt[ii] = compute_Lg(lg_method, Vel_inc_gt, Vel_inc_pred, hut_kde_gt, hvt_kde_gt, Gt_u, Gt_v, T)

		rot_errt[ii] = rotational_metric_semi_inf(traj_gt[T,:,:], vels_gt[T,:,:], p_fin, c, h, α, β, θ, obtain_sph_AV_A)

		println("T = ", T*coarse_mult, "   L = ", Lt[ii], "   Lg = ", Lgt[ii],  "  rot = ", rot_errt[ii]);
	end
	return Lt, Lgt, rot_errt
end


# Horizon sweep: obtain_gen_loss_t(t_s, t_end, t_skip).
# Settings used with other data sets / coarse_mult values:
# Lt, Lgt, rot_errt = obtain_gen_loss_t(5, 155, 10) #coarse = 1;
# Lt, Lgt, rot_errt = obtain_gen_loss_t(2, 48, 3) #coarse = 3;
# Lt, Lgt, rot_errt = obtain_gen_loss_t(5, 35, 2) #coarse = 1;
# Lt, Lgt, rot_errt = obtain_gen_loss_t(1, 16, 1) #coarse = 3;
# NOTE(review): this call was annotated "coarse = 5", but coarse_mult is 1 here;
# confirm t_end = 31 fits the ground-truth frames left after slicing at t_start.
Lt, Lgt, rot_errt = obtain_gen_loss_t(1, 31, 2)



#-----------Output data

"""
    save_output_data(data, path)

Write the array `data` to the `.npy` file at `path` (via `npzwrite`).
"""
save_output_data(data, path) = npzwrite(path, data)

# Persist the θ-sweep results, then the horizon-sweep results, as .npy files.
for (result, out_path) in (
	(Lθ, "./generalization_data_coarse$(coarse_mult)/loss_theta_$(method)_$(loss_method).npy"),
	(Lgθ, "./generalization_data_coarse$(coarse_mult)/loss_g_theta_$(method)_$(loss_method).npy"),
	(rot_errθ, "./generalization_data_coarse$(coarse_mult)/rot_error_theta_$(method)_$(loss_method).npy"),
	(Lt, "./generalization_data_coarse$(coarse_mult)/loss_t_$(method)_$(loss_method).npy"),
	(Lgt, "./generalization_data_coarse$(coarse_mult)/loss_g_t_$(method)_$(loss_method).npy"),
	(rot_errt, "./generalization_data_coarse$(coarse_mult)/rot_error_t_$(method)_$(loss_method).npy"),
)
	save_output_data(result, out_path)
end










sim_path = "./sims_learned/traj_N$(N)_T$(T)_h$(h)_cdt$(cdt)_c$(c)_α$(α)_β$(β)_θ$(θ)_$(method)_$(loss_method).mp4"


"""
    simulate(pos, sim_time=10)

Render the particle positions `pos` (indexed `pos[frame, particle, coord]`;
coordinates 1 and 2 are plotted) as a scatter animation and write it to the
global `sim_path` using the GR backend.

Details:
- The first half of the particles is drawn in the default color and the second
  half in red.
- Every other frame in `1:(T+1)` is rendered (global `T`), capped at the number
  of frames actually present in `pos`.
- The frame rate is `round(Int, T / sim_time)`, clamped to at least 1.
"""
function simulate(pos, sim_time=10)
    gr(size=(1000,800))
    println("**************** Simulating the particle flow ***************")
    #theme(:juno)
    # Never index past the frames `pos` actually contains.
    n_frames = min(T + 1, size(pos, 1))
    # Split by the particle count of `pos` itself (equals N for the full data).
    n_2 = round(Int, size(pos, 2) / 2)
    anim = @animate for i ∈ 1:2:n_frames
         Plots.scatter(pos[i, 1:n_2, 1], pos[i, 1:n_2, 2],
         title = "WCSPH_$(method): N=$(N), h=$(h), c=$(c)", xlims = [0, 2*pi], ylims = [0,2*pi], legend = false)
         Plots.scatter!(pos[i, (n_2+1):end, 1], pos[i, (n_2+1):end, 2], color = "red")
    end
    # round(Int, T/sim_time) is 0 whenever sim_time > 2T, which is not a valid frame rate.
    fps = max(1, round(Int, T / sim_time))
    gif(anim, sim_path, fps = fps)
    println("****************  Simulation COMPLETE  *************")
end

# Optional rendering of a learned trajectory (disabled).
# NOTE(review): `traj` is only assigned inside obtain_gen_loss_t, so it is not defined at top level here.
# simulate(traj, 4)
# simulate_same_color(traj, 8)


println(method, " complete")
