

using Plots, NPZ, LaTeXStrings
# NOTE(review): presumably selects a non-interactive GR output device so plotting works
# without a display on the cluster — confirm against the GR documentation.
ENV["GKSwstype"]="100" #set env variable for UAHPC


# Common suffix shared by every output file name below, assembled from run parameters.
# None of the interpolated names (IC, n_itrs, lr, T, D, N, c_gt, α, β, h, nb, n_int,
# t_start, height) are defined in this part of the file, so they must already exist
# when this line is evaluated.
file_title = "$(IC)_itr$(n_itrs)_lr$(lr)_T$(T)_D$(D)_N$(N)_c$(c_gt)_α$(α)_β$(β)_h$(h)_nb$(nb)_nint$(n_int)_ts$(t_start)_hgt$(height)"


"""
    save_output_data(data, path)

Write `data` to the file at `path` by forwarding to `npzwrite(path, data)`.
Note the argument order is swapped relative to `npzwrite`.
"""
save_output_data(data, path) = npzwrite(path, data)


"""

=========================== Prob Loss Plots =========================================

"""


"""
    comparing_Gu(G_u, Vel_inc_pred, path, θ_in, width=0.42)

Plot the reference curve `G_u` against a KDE of predicted velocity increments over
`[-width, width]`, display the figure and save it as a PNG under `./figures/`.

# Arguments
- `G_u`: callable evaluated as `G_u(x)`.
- `Vel_inc_pred`: array indexed as `Vel_inc_pred[T, :, 1]`; that slice is the sample
  set handed to `kde`. `T` is a global defined outside this function.
- `path`: either `"train"` or `"test"`; selects the `Gtrain`/`Gtest` part of the
  output file name.
- `θ_in`: value interpolated into the output file name.
- `width`: half-width of the plotted interval.

# Throws
- `ArgumentError` if `path` is neither `"train"` nor `"test"`.

Relies on `kde`, `T`, `method` and `file_title` being defined elsewhere.
"""
function comparing_Gu(G_u, Vel_inc_pred, path, θ_in, width=0.42)
    # Validate up front: an unknown label used to leave `out_path` unassigned, so the call
    # failed with an UndefVarError only after the figure had been built and displayed.
    path in ("train", "test") || throw(ArgumentError(
        "path must be \"train\" or \"test\", got $(repr(path))"))
    out_path = "./figures/$(method)_G$(path)_kde_$(file_title)_θ$(θ_in).png"

    gr(size=(500,500))
    x_s = -width
    x_e = width

    # Slice once instead of re-copying the samples for every point the plotter evaluates.
    samples = Vel_inc_pred[T, :, 1]
    G_pred(x) = kde(x, samples)

    plt = plot(x->G_u(x), x_s, x_e, label=L"G(\tau, d)",
                color="indigo", linewidth = 2.5)
    plot!(G_pred, x_s, x_e, marker=:x, markersize=4, color="forestgreen",
          markercolor = :black, label=L"G_{\theta}(\tau, d)",
          linestyle=:dash, linewidth = 2.5)

    title!(L"\textrm{Comparing - } G(\tau,z) \textrm{ - and - } \hat{G}_{\theta}(\tau,z)", titlefont=18)
    xlabel!(L"\textrm{Velocity - increment}", xtickfontsize=12, xguidefontsize=20)
    ylabel!(L"G_u \textrm{ - distribution}", ytickfontsize=12, yguidefontsize=20)

    display(plt)
    return savefig(plt, out_path)
end



"""
    animate_Gu_fixt(n_itrs, Vel_inc_pred_k, width=0.42)

Build an animation with one frame per iteration `i in 1:n_itrs`, each comparing the
global `G_u` with a KDE of `Vel_inc_pred_k[i, T, :]` over `[-width, width]`, and write
it to `./anim/`.

# Arguments
- `n_itrs`: number of frames; the frame rate is `ceil(Int, n_itrs/10)`.
- `Vel_inc_pred_k`: array indexed as `Vel_inc_pred_k[i, T, :]`, where `T` is a global.
- `width`: half-width of the plotted interval.

Relies on `G_u`, `kde`, `T`, `method` and `file_title` being defined elsewhere.
"""
function animate_Gu_fixt(n_itrs, Vel_inc_pred_k, width=0.42)
    file_out = "./anim/$(method)_Gu_$(file_title).mp4"
    x_s = -width
    x_e = width
    gr(size=(600,600))
    println("**************** Animating ***************")
    anim = @animate for i ∈ 1 : n_itrs
        # Slice once per frame rather than once per evaluated point.
        frame_samples = Vel_inc_pred_k[i, T, :]
        plot(x->G_u(x), x_s, x_e, label=L"G(\tau, d)",
             color="indigo", linewidth = 2.5)
        plot!(x->kde(x, frame_samples), x_s, x_e, marker=:x, markersize=4,
              color="forestgreen", markercolor = :black,
              label=L"G_{\theta}(\tau, d)", linestyle=:dash, linewidth = 2.5)
        # Title and axis labels must be set inside the loop: `@animate` captures a frame
        # per iteration, so decorating after the loop never reached any saved frame.
        title!(L"\textrm{Comparing - } G(\tau,z) \textrm{ - and - } \hat{G}_{\theta}(\tau,z)", titlefont=18)
        xlabel!(L"\textrm{Velocity - increment}", xtickfontsize=12, xguidefontsize=20)
        ylabel!(L"G_u \textrm{ - distribution}", ytickfontsize=12, yguidefontsize=20)
    end

    # NOTE(review): the target has an .mp4 extension but is written through `gif`;
    # Plots also provides `mp4` — confirm which encoder path is intended.
    gif(anim, file_out, fps = ceil(Int, n_itrs/10))
    println("**************** Animation Complete ***************")
end



"""

=========================== Loss Plots =========================================

"""


"""
    plot_loss_itr()

Plot the global `L_out` on a log-scaled y axis, display the figure and save it as a PNG
under `./figures/`.

The x positions start at 1 and advance by the global `vis_rate`, one per entry along the
first dimension of `L_out`. Also reads the globals `N`, `T`, `method` and `file_title`.
"""
function plot_loss_itr()
    println(" ===== Plotting Loss =====")
    gr(size=(500,500))

    n_recorded = size(L_out, 1)
    iterations = 1:vis_rate:(vis_rate * n_recorded)

    fig = plot(iterations, L_out; label="KL", color="blue", yaxis=:log, linewidth=2.5)
    title!(fig, L"\textrm{WC SPH AV:  N = } %$N \textrm{, time steps = } %$T"; titlefont=16)
    xlabel!(fig, L"\textrm{Iterations}"; xtickfontsize=10, xguidefontsize=16)
    ylabel!(fig, L"\textrm{Loss}"; ytickfontsize=10, yguidefontsize=16)

    display(fig)
    return savefig(fig, "./figures/loss_$(method)_$(file_title).png")
end


"""
    plot_Lg_itr()

Plot the global `Lg_out` on a log-scaled y axis, display the figure and save it as a PNG
under `./figures/`.

The x positions start at 1 and advance by the global `vis_rate`, one per entry along the
first dimension of `Lg_out`. Also reads the globals `N`, `T`, `method` and `file_title`.
"""
function plot_Lg_itr()
    println(" ===== Plotting Generalization Loss =====")
    gr(size=(500,500))

    n_recorded = size(Lg_out, 1)
    iterations = 1:vis_rate:(vis_rate * n_recorded)

    fig = plot(iterations, Lg_out; label="KL", color="blue", yaxis=:log, linewidth=2.5)
    title!(fig, L"\textrm{WC SPH AV: N = } %$N \textrm{, time steps = } %$T"; titlefont=16)
    xlabel!(fig, L"\textrm{Iterations}"; xtickfontsize=10, xguidefontsize=16)
    ylabel!(fig, L"\textrm{Generalization - Error}"; ytickfontsize=10, yguidefontsize=16)

    display(fig)
    return savefig(fig, "./figures/gen_loss_$(method)_$(file_title).png")
end

"""
    plot_Lg2_itr()

Plot the global `Lg2_out` on a log-scaled y axis, display the figure and save it as a PNG
under `./figures/`.

The x positions start at 1 and advance by the global `vis_rate`, one per entry along the
first dimension of `Lg2_out`. Also reads the globals `N`, `T`, `method` and `file_title`.
"""
function plot_Lg2_itr()
    println(" ===== Plotting Generalization Loss =====")
    gr(size=(500,500))

    n_recorded = size(Lg2_out, 1)
    iterations = 1:vis_rate:(vis_rate * n_recorded)

    fig = plot(iterations, Lg2_out; label="KL", color="blue", yaxis=:log, linewidth=2.5)
    title!(fig, L"\textrm{WC SPH AV: N = } %$N \textrm{, time steps = } %$T"; titlefont=16)
    xlabel!(fig, L"\textrm{Iterations}"; xtickfontsize=10, xguidefontsize=16)
    ylabel!(fig, L"\textrm{Generalization Error}"; ytickfontsize=10, yguidefontsize=16)

    display(fig)
    return savefig(fig, "./figures/gen2_loss_$(method)_$(file_title).png")
end



"""
    plot_rot_itr()

Plot the globals `rot_QF` and `rot_RF` as two separate log-scaled figures, display both
and save them as `Qf_error_...` and `Rf_error_...` PNGs under `./figures/`.

Both figures share x positions that start at 1 and advance by the global `vis_rate`,
one per entry along the first dimension of `rot_QF`. Also reads `method` and
`file_title`.
"""
function plot_rot_itr()
    println(" ===== Plotting rotation error =====")
    gr(size=(500,500))

    n_recorded = size(rot_QF, 1)
    iterations = 1:vis_rate:(vis_rate * n_recorded)

    # First figure: rot_QF series.
    fig_q = plot(iterations, rot_QF; label="||QF - F(QY)||", color="blue", yaxis=:log,
                 linewidth=2.25)
    title!(fig_q, L"\textrm{WCSPH: Rotational - error}"; titlefont=16)
    xlabel!(fig_q, L"\textrm{Iterations}"; xtickfontsize=10, xguidefontsize=16)
    ylabel!(fig_q, L"||QF - F(QY)||_2"; ytickfontsize=10, yguidefontsize=16)

    # Second figure: rot_RF series.
    fig_r = plot(iterations, rot_RF; label="||RF - F(RY)||", color="blue", yaxis=:log,
                 linewidth=2.25)
    title!(fig_r, L"\textrm{WCSPH: Rotational - error}"; titlefont=16)
    xlabel!(fig_r, L"\textrm{Iterations}"; xtickfontsize=10, xguidefontsize=16)
    ylabel!(fig_r, L"||RF(Y) - F(RY)||_2"; ytickfontsize=10, yguidefontsize=16)

    display(fig_q)
    display(fig_r)
    savefig(fig_q, "./figures/Qf_error_$(method)_$(file_title).png")
    return savefig(fig_r, "./figures/Rf_error_$(method)_$(file_title).png")
end

"""
    plot_T()

Plot the globals `Tx` and `Ty` as two separate log-scaled figures, display both and save
them as `Tx_error_...` and `Ty_error_...` PNGs under `./figures/`.

Both figures share x positions that start at 1 and advance by the global `vis_rate`,
one per entry along the first dimension of `Tx`. Also reads `method` and `file_title`.
"""
function plot_T()
    println(" ===== Plotting Translation error =====")
    gr(size=(500,500))
    xs = 1 : vis_rate : (vis_rate * size(Tx, 1))

    # NOTE(review): these legend labels match the ones in plot_rot_itr and look
    # copy-pasted; confirm what the Tx / Ty series should be called.
    plt = plot(xs, Tx, label="||QF - F(QY)||", color="blue", yaxis=:log, linewidth = 2.25)
    plt2 = plot(xs, Ty, label="||RF - F(RY)||", color="blue", yaxis=:log, linewidth = 2.25)

    # The bare `title!`/`xlabel!`/`ylabel!` forms only touch the current (most recently
    # created) plot, which left the Tx figure without a title or axis labels. Decorate
    # each figure explicitly instead.
    for fig in (plt, plt2)
        title!(fig, L"\textrm{WCSPH: Translational - error}", titlefont=16)
        xlabel!(fig, L"\textrm{Iterations}", xtickfontsize=10, xguidefontsize=16)
        ylabel!(fig, L"||F(X-s) - F(X)||_2", ytickfontsize=10, yguidefontsize=16)
    end

    display(plt)
    display(plt2)
    savefig(plt, "./figures/Tx_error_$(method)_$(file_title).png")
    return savefig(plt2, "./figures/Ty_error_$(method)_$(file_title).png")
end
