# Plot a learned SPH quantity against its ground truth: the smoothing kernel W(r) or the equation of state P(ρ).
using Flux, Statistics, LaTeXStrings
using ForwardDiff, Plots
using Flux.Losses
using Zygote
using NPZ

using BSON: @load

# Which learned model to inspect: "Wnn" (learned smoothing kernel W(r)) or
# "eos_nn" (learned equation of state P(ρ)).
method = "Wnn"

# Per-method settings. `height` is not read by the code below; it is kept as a global
# because it was used in output file names (see the commented-out savefig calls).
if method == "eos_nn"
    height = 8
    # nn_data_dir = "./semi_inf_data2/output_data_l2_eos_nn_s_hit_itr2000_lr0.05_T3_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt8_coarse5"
    nn_data_dir = "./semi_inf_data3/output_data_l2_eos_nn_s_hit_itr2000_lr0.05_T10_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt8_coarse1"
    # NOTE(review): compare_eos needs `Pres`, `c_gt` and `g`, which this file does not
    # define; presumably they come from this include — confirm before running eos_nn.
    # include("sensitivities_eos_nn.jl")
elseif method == "Wnn"
    height = 10
    # nn_data_dir = "./semi_inf_data2/output_data_l2_eos_nn_s_hit_itr2000_lr0.05_T3_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint90_ts3205_hgt8_coarse5"
    nn_data_dir = "./output_data_forward_kl_Wnn_s_hit_itr3004_lr0.05_T3_D2_N1024_c12.0_α1.0_β2.0_h0.2_nint140_ts3205_hgt10_coarse1"
    # include("sensitivities_Wnn.jl")
else
    # Fail loudly instead of silently skipping the load and every plot below.
    throw(ArgumentError("unknown method \"$(method)\"; expected \"eos_nn\" or \"Wnn\""))
end

# Load the trained network (identical for both methods), flatten it, and read the
# saved parameter vector that the plotting functions rebuild the model from.
@load "$(nn_data_dir)/NN_model.bson" NN
println(NN)
p_, re = Flux.destructure(NN)   # p_: flat parameters; re: rebuilds a model from a flat vector
n_params = length(p_)
params_path = "$(nn_data_dir)/params_intermediate.npy"
p_fin = npzread(params_path)


"""
    compare_eos(p_hat; savepath=nothing)

Plot the learned equation of state against the ground-truth one for ρ in 0.90:0.02:1.10.

The network is rebuilt from the flat parameter vector `p_hat` with the global `re`
(returned by `Flux.destructure`) and evaluated as a scalar map ρ -> P. The reference
curve is `Pres.(ρ, c_gt, g)`. The ground truth is drawn as a line and the learned
values as scatter points.

NOTE(review): `Pres`, `c_gt` and `g` are not defined in this file; presumably they come
from `sensitivities_eos_nn.jl`, whose `include` is commented out in the config section —
confirm, otherwise this function throws `UndefVarError`.

# Keywords
- `savepath`: if not `nothing`, the figure is also written to this path with `savefig`.

Returns `nothing`.
"""
function compare_eos(p_hat; savepath=nothing)
    gr(size=(500, 500))

    # Rebuild the network once instead of once per evaluated density.
    model = re(p_hat)
    pnn(ρ) = model([ρ])[1]

    rho_data = 0.90:0.02:1.10
    P_gt = Pres.(rho_data, c_gt, g)
    P_nn = pnn.(rho_data)

    plt = plot(rho_data, P_gt, label="P_gt", color="blue", linewidth=2.5)
    Plots.scatter!(plt, rho_data, P_nn, label="P_nn", color="black", linewidth=2)

    title!(plt, L"\textrm{Learning EoS WCSPH}", titlefont=18)
    xlabel!(plt, L"\rho", xtickfontsize=12, xguidefontsize=20)
    ylabel!(plt, L"P(\rho)", ytickfontsize=12, yguidefontsize=20)

    display(plt)
    isnothing(savepath) || savefig(plt, savepath)
    return nothing
end


"""
    plot_Wnn(p_hat; h=0.2, savepath=nothing)

Plot the ground-truth SPH smoothing kernel W(r, h) against the learned kernel for
r in [0, 2h].

The reference kernel is piecewise cubic in q = r/h, with compact support q <= 2. It is
normalised by σ = 10 / (7π h²); the original code notes 1 / (π h³) as the 3D factor.
The learned kernel is the network rebuilt from `p_hat` with the global `re` (returned by
`Flux.destructure`), evaluated as a scalar map r -> W. The network takes only `r`, so `h`
changes just the reference curve and the plotted range. The default 0.2 matches the
`h0.2` tag in the current `Wnn` data directory name.

# Keywords
- `h`: smoothing length of the reference kernel (default 0.2, the previously hard-coded value).
- `savepath`: if not `nothing`, the figure is also written to this path with `savefig`.

Returns `nothing`.
"""
function plot_Wnn(p_hat; h=0.2, savepath=nothing)
    gr(size=(800, 800))
    println("*************** plot Wnn ******************")

    # Reference kernel. The normalisation is computed from the `h` passed in, so W stays
    # consistent if it is ever evaluated at a different smoothing length.
    function W(r, h)
        sigma = 10.0 / (7.0 * pi * h * h)   # 2D normalising factor
        q = r / h
        q > 2.0 && return 0.0
        q > 1.0 && return sigma * (2.0 - q)^3 / 4.0
        return sigma * (1.0 - 1.5 * q * q * (1.0 - q / 2.0))
    end

    # Rebuild the network once rather than at every r the plot samples.
    model = re(p_hat)
    W_nn(r) = model([r])[1]

    plt = plot(r -> W(r, h), 0, 2 * h, label="ground-truth W")
    plot!(plt, W_nn, 0, 2 * h, label="learned W")

    title!(plt, L"\textrm{Gt W vs Learned W}", titlefont=16)
    xlabel!(plt, L"\textrm{r}", xtickfontsize=10, xguidefontsize=16)
    ylabel!(plt, L"W(r)", ytickfontsize=10, yguidefontsize=14)

    display(plt)
    isnothing(savepath) || savefig(plt, savepath)
    return nothing
end

# Produce the plot that matches the selected model; the two methods are mutually exclusive.
if method == "Wnn"
    plot_Wnn(p_fin)
elseif method == "eos_nn"
    compare_eos(p_fin)
end
