

using Plots, NPZ
using Flux, Statistics, LaTeXStrings
using ForwardDiff, Plots
using Flux.Losses
using Zygote
using NPZ


# ---- Run configuration -------------------------------------------------------
# lg_method is passed to compute_Lg in the sweeps below; loss_method and method
# only appear in the output file names at the end of this script (and in a
# commented-out compute_L call).
lg_method="kl_t"

# loss_method = "kl"
# loss_method = "l2"
# loss_method = "kl_t_one_dis"
loss_method = "kl_t"
# loss_method = "kl_l2_t"

method = "phys_inf"


# KDE / KL-integration settings. They are not referenced by the visible code in
# this file; presumably they are read as globals by the included helper files
# (kde_G.jl, loss_functions_2d.jl) — NOTE(review): confirm before renaming.
r = 5.0;			#number of smoothing (r*hkde) lengths for determining bounds of integration in KL
h_kde = 0.9;
n_int = 200;
window = 0.2;
nb = "all";

# Ground-truth SPH parameters. α, β, h, θ, c and cdt match the values encoded
# in the ground-truth data file names loaded below. θ is also the value passed
# to vel_verlet for the rollouts at the training setting.
α = 1.0; β = 2*α; h = 0.2; g = 7; θ = 0.8e-1; c = 12.0; c_gt = c;
cdt = 0.4;

# T: number of steps passed to vel_verlet / obtain_pred_dists in the baseline
# rollout and the θ sweep.
# t_start: first ground-truth row kept; coarse_mult: stride used to subsample
# the ground-truth time series.
T = 5;
t_start = 3205;
coarse_mult = 3;

# ---- Data locations ----------------------------------------------------------
# ic_data_dir is only used by the commented-out absolute paths below.
# phys_inf_data_dir holds the learned-parameter histories read further down;
# the commented-out assignments point at earlier training runs.
ic_data_dir = "/home/adele/sph_learning/analytic_gradient_method/physics_informed_sph_learning"
# phys_inf_data_dir = "./phys_inf_data3/output_data_kl_fsa_4p_cabg_s_hit_itr6002_lr0.05_T5_D2_N1024_c12.0_α1.0_β2.0_h0.2_nball_nint90_ts3205_coarse3"
# phys_inf_data_dir = "./phys_inf_dataA1/output_data_fsa_4p_cabg_s_hit_itr402_lr0.05_T15_D2_N1024_c12.0_α1.0_β2.0_h0.2_nball_nint140_ts3205_coarse1"

# phys_inf_data_dir = "./phys_inf_dataA1/output_data_fsa_4p_cabg_s_hit_itr7007_lr0.05_T10_D2_N1024_c12.0_α1.0_β2.0_h0.2_nball_nint150_ts3205_coarse1"
# NOTE(review): the active run directory ends in "coarse1" while coarse_mult = 3
# in this script — confirm the learned parameters are meant to be evaluated at
# a different snapshot stride than they were trained with.
phys_inf_data_dir = "./phys_inf_dataA2/output_data_fsa_4p_cabg_woθ_kl_t_s_hit_itr3007_lr0.05_T6_D2_N1024_c12.0_α1.0_β2.0_h0.2_nball_nint140_ts3205_coarse1"


# traj_gt = npzread("$(ic_data_dir)/data/traj_N1024_T6001_ts5801_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ0.08_AV_neg_rel_ke.npy")
# vels_gt = npzread("$(ic_data_dir)/data/vels_N1024_T6001_ts5801_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ0.08_AV_neg_rel_ke.npy")
# rhos_gt = npzread("$(ic_data_dir)/data/rhos_N1024_T6001_ts5801_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ0.08_AV_neg_rel_ke.npy")

# Ground-truth positions, velocities and densities, loaded relative to the
# working directory. N and D are later read as size(traj_gt)[2] and
# size(traj_gt)[3]. So traj/vels are indexed [time, particle, dim] and rhos
# [time, particle].
traj_gt = npzread("../data/traj_N1024_T6001_ts5801_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ0.08_AV_neg_rel_ke.npy")
vels_gt = npzread("../data/vels_N1024_T6001_ts5801_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ0.08_AV_neg_rel_ke.npy")
rhos_gt = npzread("../data/rhos_N1024_T6001_ts5801_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ0.08_AV_neg_rel_ke.npy")

# Keep every coarse_mult-th snapshot, starting at row t_start.
traj_gt = traj_gt[t_start:coarse_mult:end, :, :]
vels_gt = vels_gt[t_start:coarse_mult:end, :, :]
rhos_gt = rhos_gt[t_start:coarse_mult:end, :]


# Problem size inferred from the [time, particle, dim] ground-truth array.
N = size(traj_gt)[2];
D = size(traj_gt)[3];
m = (2. * pi)^D / N; #so that ρ₀ = 1;
# cdt is re-assigned the same value (0.4) it was given in the configuration.
cdt = 0.4; #dt = cdt * h / c;
# Base step cdt*h/c_gt scaled by the subsampling stride coarse_mult, i.e. the
# time between retained ground-truth snapshots. Presumably read by the
# included simulator — NOTE(review): confirm.
dt = coarse_mult * cdt * h / c_gt;



# Learned-parameter histories written by the training run. The last entry of
# each is taken as the learned value. θ is not loaded here, so the configured θ
# is used in the rollouts. h is re-set to the same 0.2 as in the configuration.
c_out = npzread("$(phys_inf_data_dir)/c_out.npy");
α_out = npzread("$(phys_inf_data_dir)/alpha_out.npy");
β_out = npzread("$(phys_inf_data_dir)/beta_out.npy");
g_out = npzread("$(phys_inf_data_dir)/g_out.npy");


c_hat = c_out[end]; α_hat = α_out[end]; β_hat = β_out[end]; g_hat = g_out[end]; h = 0.2;


"""
    create_dir(dir_name)

Ensure that the directory `dir_name` exists.

If it already exists, print a notice and leave it untouched. Otherwise create
it, including any missing parent directories. Throws if `dir_name` exists but
is not a directory.
"""
function create_dir(dir_name)
	if isdir(dir_name)
		println("directory already exists")
	else
		# mkpath (unlike mkdir) also creates missing parent directories.
		mkpath(dir_name)
	end
end

# Output directories, relative to the working directory. Only
# generalization_data_coarse<coarse_mult> is written by the visible code
# (see the save_output_data calls at the end).
create_dir("generalization_data_coarse$(coarse_mult)"); create_dir("sims_learned"); create_dir("figures");


# Project helpers. The functions called below but not defined in this file
# (kde, obtain_h_kde, vel_verlet, obtain_pred_dists, obtain_gt_dists,
# compute_Lg, rotational_metric_phys_inf, obtain_sph_av_A) presumably come
# from these includes — NOTE(review): confirm which file defines which.
include("./kde_G.jl")
include("./sph_simulator.jl")
include("./loss_functions_2d.jl")



# Baseline rollout of the learned model for T steps from the ground-truth data,
# at the training θ. These globals are not used by the visible code below; they
# appear only in commented-out compute_L calls.
traj, vels, rhos = vel_verlet(traj_gt, vels_gt, α_hat, β_hat, h, c_hat, g_hat, θ, T)
Diff_pred, Vel_inc_pred = obtain_pred_dists(traj, vels, traj_gt[1,:,:], vels_gt[1,:,:], T);


#------- over θ
# shift, α_test, β_test and c_test are not referenced by the visible code.
shift = 3002;
α_test = 1.0; β_test = 2.0; c_test = 12.0;

"""
    obtain_gen_errs_over_θ(θ_range = 0.4e-1:0.02:1.6e-1; ts_test = 900)

Evaluate the learned parameters (`α_hat`, `β_hat`, `c_hat`, `g_hat`) on held-out
ground-truth data generated with different θ values.

For each `θ_test` in `θ_range`:
1. Load the test trajectory/velocity files whose names embed `θ_test`.
2. Keep every `coarse_mult`-th snapshot, starting at row `ts_test`.
3. Roll the learned model out for `T` steps with `θ_test`.
4. Record:
   - `Lg[i]`: `compute_Lg(lg_method, ...)`, comparing the test-set velocity
     increments (with their KDEs and bandwidths) against the model's;
   - `rot_errθ[i]`: `rotational_metric_phys_inf` on test-set row `T`.

Returns `(L, Lg, rot_errθ)`, each of length `length(θ_range)`. `L` is never
filled (the `compute_L` call is disabled), so it is all zeros.

Reads the globals `coarse_mult`, `T`, `h`, `α_hat`, `β_hat`, `c_hat`, `g_hat` and
`lg_method`.
"""
function obtain_gen_errs_over_θ(θ_range = 0.4e-1 : 0.02 : 1.6e-1; ts_test = 900)
	n_θ = length(θ_range)
	L = zeros(n_θ); Lg = zeros(n_θ); rot_errθ = zeros(n_θ)
	for (ii, θ_test) in enumerate(θ_range)
		# θ_test is interpolated into the file name, so its string form must
		# match the saved files (e.g. "0.08").
		# NOTE(review): these files are read from ./data while the training
		# ground truth is read from ../data — confirm both locations are intended.
		data_tag = "N1024_T1000_ts3205_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ$(θ_test)_AV_neg_rel_ke"
		traj_test_set = npzread("./data/traj_$(data_tag).npy")[ts_test:coarse_mult:end, :, :]
		vels_test_set = npzread("./data/vels_$(data_tag).npy")[ts_test:coarse_mult:end, :, :]

		# Reference: velocity-increment distributions of the test set itself.
		_, Vel_inc_test = obtain_pred_dists(traj_test_set[1:(T+1),:,:], vels_test_set[1:(T+1),:,:],
			traj_test_set[1,:,:], vels_test_set[1,:,:], T)

		Gt_utest(x, t) = kde(x, Vel_inc_test[t, :, 1]); hu_kde_test(t) = obtain_h_kde(Vel_inc_test[t, :, 1])
		Gt_vtest(x, t) = kde(x, Vel_inc_test[t, :, 2]); hv_kde_test(t) = obtain_h_kde(Vel_inc_test[t, :, 2])

		# Model: learned-parameter rollout from the same data with θ_test.
		traj_model, vels_model, _ = vel_verlet(traj_test_set, vels_test_set, α_hat, β_hat, h, c_hat, g_hat, θ_test, T)
		_, Vel_inc_model = obtain_pred_dists(traj_model, vels_model, traj_test_set[1,:,:], vels_test_set[1,:,:], T)

		Lg[ii] = compute_Lg(lg_method, Vel_inc_test, Vel_inc_model, hu_kde_test, hv_kde_test, Gt_utest, Gt_vtest, T)

		rot_errθ[ii] = rotational_metric_phys_inf(traj_test_set[T,:,:], vels_test_set[T,:,:], c_hat, h, α_hat, β_hat, θ_test, obtain_sph_av_A)

		println("L = ", L[ii], "  Lg = ", Lg[ii], "  rot = ", rot_errθ[ii])
	end
	return L, Lg, rot_errθ
end


# θ sweep over the default range 0.04:0.02:0.16. Lθ is all zeros because the
# sweep never fills its first output.
Lθ, Lgθ, rot_errθ = obtain_gen_errs_over_θ()




#------- vary T, θ = 0.8e-1; (θ from training set)
# The test constants below are not referenced by the visible code; the T sweep
# uses the global θ and the learned parameters directly.
# θ_test = θ;
α_test = 1.0;
β_test = 2.0;
c_test = 12.0;
# dt = cdt * h / c;

"""
    obtain_gen_loss_t(t_s, t_end, t_skip)

Measure how the learned model's error changes with rollout length at the
training θ.

For every horizon `T` in `range(t_s, t_end, step=t_skip)`:
- roll the learned model out from the ground truth (`traj_gt`, `vels_gt`) for
  `T` steps;
- compare its velocity-increment distributions with the ground-truth ones via
  `compute_Lg(lg_method, ...)`;
- evaluate `rotational_metric_phys_inf` on ground-truth row `T`.

Returns `(Lt, Lgt, rot_errt)`, one entry per horizon. `Lt` is never filled and
stays all zeros.

Reads the globals `traj_gt`, `vels_gt`, `θ`, `h`, `α_hat`, `β_hat`, `c_hat`,
`g_hat`, `lg_method` and `coarse_mult` (the last only for the progress print).
"""
function obtain_gen_loss_t(t_s, t_end, t_skip)
	t_range = range(t_s, t_end, step=t_skip)
	num_t_idx = length(t_range)
	println("num_t_idx  = ", num_t_idx)
	Lt = zeros(num_t_idx); Lgt = zeros(num_t_idx); rot_errt = zeros(num_t_idx)
	for (ii, T) in enumerate(t_range)
		# Reference: ground-truth increments. The KDEs and their bandwidths
		# must both be built from this same data set.
		_, Vel_inc_gt = obtain_gt_dists(traj_gt, vels_gt, T)
		Gt_utest(x, t) = kde(x, Vel_inc_gt[t, :, 1]); hu_kde_test(t) = obtain_h_kde(Vel_inc_gt[t, :, 1])
		Gt_vtest(x, t) = kde(x, Vel_inc_gt[t, :, 2]); hv_kde_test(t) = obtain_h_kde(Vel_inc_gt[t, :, 2])

		# Model: learned-parameter rollout at the training θ.
		traj, vels, _ = vel_verlet(traj_gt, vels_gt, α_hat, β_hat, h, c_hat, g_hat, θ, T)
		_, Vel_inc_pred = obtain_pred_dists(traj, vels, traj_gt[1,:,:], vels_gt[1,:,:], T)

		rot_errt[ii] = rotational_metric_phys_inf(traj_gt[T,:,:], vels_gt[T,:,:], c_hat, h, α_hat, β_hat, θ, obtain_sph_av_A)

		# Same argument roles as in obtain_gen_errs_over_θ: reference increments,
		# model increments, then the reference bandwidths and KDEs.
		Lgt[ii] = compute_Lg(lg_method, Vel_inc_gt, Vel_inc_pred, hu_kde_test, hv_kde_test, Gt_utest, Gt_vtest, T)

		println("T = ", T*coarse_mult, "   L = ", Lt[ii], "   Lg = ", Lgt[ii],  "  rot = ", rot_errt[ii])
	end
	return Lt, Lgt, rot_errt
end


# Horizon sweep over T = 1, 2, ..., 16. The progress print reports
# T*coarse_mult. The commented-out calls are earlier sweep settings.
# Lt, Lgt, rot_errt = obtain_gen_loss_t(5, 155, 10)
# Lt, Lgt, rot_errt = obtain_gen_loss_t(2, 48, 3) #coarse = 3;
Lt, Lgt, rot_errt = obtain_gen_loss_t(1, 16, 1) #coarse = 3;



#-----------Output data

"""
    save_output_data(data, path)

Write the array `data` to the `.npy` file at `path`.

This is a thin wrapper over `npzwrite(path, data)` that takes the data first and
the destination second.
"""
save_output_data(data, path) = npzwrite(path, data)


# Save both sweeps under ./generalization_data_coarse<coarse_mult>/.
# Lθ and Lt are never filled by the sweep functions, so they are all zeros.
# The Lg* arrays are computed with lg_method, even though the file names are
# tagged with loss_method.
save_output_data(Lθ, "./generalization_data_coarse$(coarse_mult)/loss_theta_$(method)_$(loss_method).npy")
save_output_data(Lgθ, "./generalization_data_coarse$(coarse_mult)/loss_g_theta_$(method)_$(loss_method).npy")
save_output_data(rot_errθ, "./generalization_data_coarse$(coarse_mult)/rot_error_theta_$(method)_$(loss_method).npy")

save_output_data(Lt, "./generalization_data_coarse$(coarse_mult)/loss_t_$(method)_$(loss_method).npy")
save_output_data(Lgt, "./generalization_data_coarse$(coarse_mult)/loss_g_t_$(method)_$(loss_method).npy")
save_output_data(rot_errt, "./generalization_data_coarse$(coarse_mult)/rot_error_t_$(method)_$(loss_method).npy")
