

"""
2D FSA with ForwardDiff AD

learning less informed --> physics informed SPH
"""


using NPZ, Plots, Flux, QuadGK
using ForwardDiff, Statistics
using Flux.Optimise: update!
using Flux.Optimise: Optimiser
using Flux.Losses
#import Pkg; Pkg.add("BSON")
using BSON: @save


const T = 10				# number of time steps in integration (also sets the test-set slice 1:(T+1))
const coarse_mult = 1;  # coarse graining in time (number of dts to skip); also the stride used when slicing the data arrays
const n_itrs = 4002			# number of training iterations
const vis_rate = 1;		# diagnostics are recorded every `vis_rate` iterations
const lr = 5e-2 			# initial learning rate handed to ADAM
const mag = 1.0			# amplitude of external forcing
# τ = T;			# fixed tau method
const r = 5.0;			# number of smoothing lengths (r*h_kde) used to determine the bounds of integration in KL
const h_kde = 0.9;		# not referenced in this file; presumably the KDE bandwidth used by the included files (TODO confirm)
const nb = "all";				# number of samples in batching
const n_int = 140;		# not referenced in this file; meaning is set by the included files (TODO confirm)
const t_start = 3205;		# first time index kept when slicing the loaded data arrays
const window = 0.15;		# passed to comparing_Gu and animate_Gu_fixt
const height = 9;		# only printed in this file (TODO confirm its role in the included model files)
const t_decay = round(Int, 0.9*n_itrs);			# iteration at which decay begins (only used by the commented-out ExpDecay optimiser in this file)

println("height = ", height)

# Sensitivity method: selects the branch taken inside training_algorithm
#   "forward" -> simultaneous_integration + compute_∇L
#   "adjoint" -> dual_adjoint_integration + compute_adjoint_∇L
sens_method = "forward"
# sens_method = "adjoint"

# Loss selection (forwarded to compute_∇L); the alternatives are kept for reference.
# loss_method = "l2"
# loss_method = "kl"
# loss_method = "lf"
# loss_method = "kl_t"
# loss_method = "kl_l2_t"
# loss_method = "kl_one_dist"
# loss_method = "kl_t_one_dist"
loss_method = "kl_lf_t"


# Model selection: decides which sensitivities_*.jl file is included further down.
# method = "nn_sum"
# method = "nn_sum2"
method = "rot_inv"
# method = "node"
# method = "eos_nn"
# method = "grad_p"
# method = "Wnn"


# Echo the chosen configuration to the log.
println(sens_method, " ", loss_method, " ", method)
println("coarse_mult = ", coarse_mult)


# Initial-condition label. It is not referenced elsewhere in this file;
# "s_hit" matches the tag that appears in the data file names.
# IC = "Vrandn"
IC = "s_hit"

using Dates
println(Dates.now())  # log the wall-clock start time

#--- KDE smoothing kernel for density estimation (both for gt and pred data)
# Standard Gaussian kernel (best results so far).
K(x) = inv(sqrt(2 * pi)) * exp(-x^2 / 2)
# Alternative kernels that were tried:
# K(x) = maximum([1 - abs(x), 0])     # triangle (produces frequency polygon)
# K(x) = 3/4*maximum([1 - x^2, 0])   # Epanechnikov
# Derivative of the kernel by forward-mode AD (used in computing ∇L).
K_prime(x) = ForwardDiff.derivative(K, x)


#----- Load data
# Ground-truth simulation parameters. The values match the tags embedded in the
# data file names below (c12.0, h0.2, α1.0, β2.0, θ0.08, cdt0.4).
const c_gt = 12.0;
const h = 0.2;
const α = 1.0;
const β = 2.0;
const g = 7.0;   # passed to Pres and vel_verlet_NN in this file
const θ = 8e-2;

const cdt = 0.4;
const dt = coarse_mult * cdt * h / c_gt;   # time step, scaled by the coarse-graining factor

# Ground-truth arrays, sliced along the first axis starting at t_start with stride coarse_mult.
const traj_gt = npzread("../data/traj_N1024_T6001_ts5801_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ0.08_AV_neg_rel_ke.npy")[t_start:coarse_mult:end, :, :];
const vels_gt = npzread("../data/vels_N1024_T6001_ts5801_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ0.08_AV_neg_rel_ke.npy")[t_start:coarse_mult:end, :, :];
const rhos_gt = npzread("../data/rhos_N1024_T6001_ts5801_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ0.08_AV_neg_rel_ke.npy")[t_start:coarse_mult:end, :];

const N = size(traj_gt)[2];   # extent of the second axis of traj_gt
const D = size(traj_gt)[3];   # extent of the third axis of traj_gt
const m = (2. * pi)^D / N; #so that ρ₀ = 1;


#--- Load files

# Pull in the sensitivity definitions that belong to the selected `method`.
# The options are mutually exclusive, so a single chain is sufficient; an
# unrecognised `method` includes nothing here.
if method == "nn_sum2"
    include("./sensitivities_nnsum2.jl")
elseif method == "nn_sum"
    include("./sensitivities_nnsum.jl")
elseif method == "rot_inv"
    include("./sensitivities_rot_nn.jl")
elseif method == "node"
    include("./sensitivities_node_sort.jl")
elseif method == "eos_nn"
    include("./sensitivities_eos_nn.jl")
elseif method == "grad_p"
    include("./sensitivities_grad_p.jl")
elseif method == "Wnn"
    include("./sensitivities_Wnn.jl")
end

# Shared helpers (output generation, loss functions, KDE, integrators).
include("./gen_outputs_2d.jl")
include("./loss_functions_2d.jl")
include("./kde_G.jl")
include("./integrators_utils.jl")

# n_params is not defined in this file; it is expected to come from one of the included files.
println("n_params = ", n_params)


#------- Load test set data: Larger Reynolds number flow
# Same file naming as the training data but with the θ0.16 tag; sliced the same way.

const traj_test_set = npzread("../data/traj_N1024_T6001_ts5801_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ0.16_AV_neg_rel_ke.npy")[t_start:coarse_mult:end, :, :];
const vels_test_set = npzread("../data/vels_N1024_T6001_ts5801_h0.2_s_hit_cdt0.4_c12.0_α1.0_β2.0_θ0.16_AV_neg_rel_ke.npy")[t_start:coarse_mult:end, :, :];

# NOTE(review): only θ_test is used later in this file; α_test, β_test and c_test
# are not referenced here (the test rollouts are called with α, β and c).
const θ_test = 1.6e-1;
const α_test = 1.0;
const β_test = 2.0;
const c_test = 12.0;

include("./sph_simulator.jl")

# Reference quantities for the test set, computed from its first T+1 frames.
Diff_test, Vel_inc_test =
obtain_pred_dists(traj_test_set[1:(T+1),:,:], vels_test_set[1:(T+1),:,:], traj_test_set[1,:,:], vels_test_set[1,:,:]);

# KDEs (and their bandwidths) built from Vel_inc_test at index T, one per component (third index 1 and 2).
Gu_test(x) = kde(x, Vel_inc_test[T, :, 1]); hu_kde_test = obtain_h_kde(Vel_inc_test[T, :, 1])
Gv_test(x) = kde(x, Vel_inc_test[T, :, 2]); hv_kde_test = obtain_h_kde(Vel_inc_test[T, :, 2])


# Only Vf_gt is named meaningfully; d1..d4 are not used anywhere else in this file.
Vf_gt,d1,d2,d3,d4 = obtain_interpolated_velocity_over_τ(traj_gt, vels_gt, rhos_gt, T)


#------------------- Training algorithm

"""
    training_algorithm(n_itrs, vis_rate, T, p_h, c, h, α, β, θ)

Run `n_itrs` optimiser steps on the parameter vector `p_h` and record
diagnostics every `vis_rate` iterations.

Each iteration integrates the model over `T` steps with the routine selected by
the global `sens_method` ("forward" or "adjoint"), obtains the loss gradient,
and hands it to `update!` together with an `ADAM(lr)` optimiser. On logging
iterations the model is also rolled out on the test set (with `θ_test`), the
rotational metric and the losses are stored, and the current parameters are
written to `params_intermediate.npy` under `data_out_path`.

Besides its arguments the function reads many script-level globals
(`sens_method`, `loss_method`, `method`, `lr`, `N`, `g`, `c_gt`, `window`, the
ground-truth and test-set arrays, the KDE closures and bandwidths,
`data_out_path`, ...), so it must be called after those are defined.

# Returns
`(L_itr, Lg_itr, rot_QF, rot_RF, Vel_inc_pred_k, p_h)`, where the first four
are vectors with `round(Int, n_itrs / vis_rate)` slots, `Vel_inc_pred_k` is an
`n_itrs × T × N` array holding `Vel_inc_pred[:, :, 1]` for every iteration, and
`p_h` is the parameter object that was passed to `update!`.

# Throws
`ArgumentError` if `sens_method` is neither "forward" nor "adjoint".
"""
function training_algorithm(n_itrs, vis_rate, T, p_h, c, h, α, β, θ)
    # Number of diagnostic slots (same sizing rule as before, computed once).
    n_out = round(Int, n_itrs / vis_rate)
    L_itr = zeros(n_out)
    Lg_itr = zeros(n_out)
    rot_QF = zeros(n_out)
    rot_RF = zeros(n_out)
    Vel_inc_pred_k = zeros(n_itrs, T, N)

    rho_data = 0.9:0.005:1.1
    P_nn = zeros(n_out, length(rho_data))
    # P_nn and P_gt are only consumed by the disabled animate_learning call below.
    P_gt = Pres.(rho_data, c_gt, g)

    ii = 1  # index of the next diagnostic slot
    opt = ADAM(lr)  # optimizer for gradient descent
    # opt = Optimiser(ExpDecay(lr, 0.1, t_decay, 1e-4), ADAM(lr))

    for k in 1:n_itrs
        if sens_method == "forward"
            traj_pred, vels_pred, rhos_pred, HT =
                simultaneous_integration(p_h, c, h, α, β, θ, T)
            Diff_pred, Vel_inc_pred =
                obtain_pred_dists(traj_pred, vels_pred, traj_gt[1, :, :], vels_gt[1, :, :])
            ∇L = compute_∇L(loss_method, Vel_inc_gt, Vel_inc_pred, traj_pred, traj_gt,
                            vels_pred, vels_gt, rhos_pred, HT)
        elseif sens_method == "adjoint"
            traj_pred, vels_pred, rhos_pred, λT, ∂F_pT =
                dual_adjoint_integration(p_h, c, h, α, β, θ, T)
            Diff_pred, Vel_inc_pred =
                obtain_pred_dists(traj_pred, vels_pred, traj_gt[1, :, :], vels_gt[1, :, :])
            ∇L = compute_adjoint_∇L(λT, ∂F_pT)
        else
            # Previously this fell through and failed later with an UndefVarError on ∇L.
            throw(ArgumentError(
                "unknown sens_method: \"$(sens_method)\" (expected \"forward\" or \"adjoint\")"))
        end

        update!(opt, p_h, ∇L)
        Vel_inc_pred_k[k, :, :] = Vel_inc_pred[:, :, 1]

        if mod(k, vis_rate) == 0
            if method == "eos_nn"
                # Rebuild the network once per logging step rather than once per density sample.
                eos_model = re(p_h)
                # Row index is the diagnostic slot `ii`; indexing with `k` overran
                # P_nn whenever vis_rate > 1.
                P_nn[ii, :] = [eos_model([ρ])[1] for ρ in rho_data]
                compare_eos(p_h)
            end

            # Roll the current model out on the test set (θ_test).
            traj_model_test_th, vels_model_test_th, rhos_model_test_th =
                vel_verlet_NN(traj_test_set[1, :, :], vels_test_set[1, :, :], p_h, α, β, h, c,
                              g, θ_test, T)
            Diff_model_test_th, Vel_inc_model_test_th =
                obtain_pred_dists(traj_model_test_th, vels_model_test_th,
                                  traj_test_set[1, :, :], vels_test_set[1, :, :])

            rot_QF[ii], rot_RF[ii] = rotational_metric(traj_gt[1, :, :], vels_gt[1, :, :], p_h,
                                                       c, h, α, β, θ, obtain_sph_AV_A)
            L_itr[ii], Lg_itr[ii] = compute_L(Vel_inc_gt, Vel_inc_pred, G_u, G_v, hu_kde_gt,
                                              hv_kde_gt, Vel_inc_test, Vel_inc_model_test_th,
                                              Gu_test, Gv_test, hu_kde_test, hv_kde_test,
                                              traj_gt, traj_pred, vels_pred, rhos_pred)
            # if abs(L_itr[ii]) < 5e-8
            #     opt = Optimiser(ExpDecay(lr, 0.1, 500, 1e-4), ADAM(lr))
            # end
            println("Itr  = ", k, "   Loss = ", L_itr[ii])
            ii += 1
            save_output_data(p_h, "./$(data_out_path)/params_intermediate.npy")
        end

        # Called on every iteration, not only on logging iterations.
        comparing_Gu(G_u, Vel_inc_pred, "train", θ, window)
    end
    # animate_learning(n_itrs, rho_data, P_gt, P_nn)
    return L_itr, Lg_itr, rot_QF, rot_RF, Vel_inc_pred_k, p_h
end


# Run the training loop. p_hat (initial parameters) is not defined in this file;
# it is expected to come from the included sensitivities file.
L_out, Lg_out, rot_QF, rot_RF, Vel_inc_pred_k, p_fin  =
training_algorithm(n_itrs, vis_rate, T, p_hat, c_gt, h, α, β, θ)

# Store magnitudes of the recorded losses.
L_out = abs.(L_out)
Lg_out = abs.(Lg_out)

println(Dates.now())  # log the wall-clock end time of training

# Persist the diagnostics and the final parameters under data_out_path
# (data_out_path is defined outside this file).
save_output_data(L_out, "./$(data_out_path)/loss.npy")
save_output_data(Lg_out, "./$(data_out_path)/gen_loss.npy")
save_output_data(rot_QF, "./$(data_out_path)/rot_error_qf.npy")
save_output_data(rot_RF, "./$(data_out_path)/rot_error_rf.npy")
# save_output_data(Vel_inc_pred_k, "./$(data_out_path)/vel_inc_pred_k.npy")
save_output_data(p_fin, "./$(data_out_path)/params_final.npy")


# Save the neural network architecture (NN is defined outside this file):
@save "./$(data_out_path)/NN_model.bson" NN


# Plotting helpers from the included files; they take no arguments here, so they
# presumably read the globals assigned above (TODO confirm).
plot_loss_itr()
plot_Lg_itr()
plot_rot_itr()
animate_Gu_fixt(n_itrs, Vel_inc_pred_k, window)




#---------- Output gen errors
println("************** Generalization error ******************")
println("  ")
println("c = ", c_gt, "  α = ", α, "  β = ", β, "  θ_gt = ", θ, "  θ_t = ",
θ_test, "  T = ", T, "  n_iters = ", n_itrs)

# Global alias used by the test rollout below.
const c = c_gt

# Roll the trained model out on the test set (θ_test) with the final parameters.
traj_model_test, vels_model_test, rhos_model_test =
vel_verlet_NN(traj_test_set[1,:,:], vels_test_set[1,:,:], p_fin, α, β, h, c, g, θ_test, T);

Diff_model_test, Vel_inc_model_test =
obtain_pred_dists(traj_model_test, vels_model_test, traj_test_set[1,:,:], vels_test_set[1,:,:]);

comparing_Gu(Gu_test, Vel_inc_model_test, "test", θ_test, window);

# Generalization error from Ikl_fixed_τ, comparing the test-set reference with the model rollout.
L_test = Ikl_fixed_τ(Vel_inc_test, Vel_inc_model_test, Gu_test, Gv_test, hu_kde_test, hv_kde_test)
println("L_test = ", L_test)

#save txt file containing generalization error
# Opened in append mode ("a"), so repeated runs add to the same file.
open("./$(data_out_path)/gen_error.txt","a") do io
    println(io,"c = ", c_gt, "  α = ", α, "  β = ", β, "  θ_gt = ", θ, "  θ_t = ",
        θ_test, "  T = ", T, "  n_iters = ", n_itrs)
        println(io, " ")
        println(io, "L_test = ", L_test)
end
