#PACKAGES -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------

using Optim, Plots, DelimitedFiles, LinearAlgebra, Random, StatsBase, FiniteDifferences, LaTeXStrings , EasyFit, Printf, FFTW, Pkg, Noise, Clustering, Dierckx, BSplineKit, MultivariateStats, Flux, Combinatorics, Bigsimr, DataFrames, JLD, Base.Threads

#FUNCTIONS -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Load the project helper functions (optimization routines and the computation of the
# expected cost). The working directory is first reset to this script's folder.
cd(@__DIR__)
include(joinpath("..", "..", "MyFunctions", "myFunctions_NeurIPS_2025.jl"))

#Loading data ----------------------------------------------------------------------------------------------------------------------------------

# Folder holding the saved results, relative to this script's directory
# (the working directory is reset to that directory first).
cd(@__DIR__)
folder_location = "../../Data_for_plots/EC_1D_Reaching_Task"

# Assemble the path of the .jld2 results file and load everything stored in it.
file_name_Data_dict = "EC_1D_Reaching_Task.jld2"
file_path_Data_dict = string(folder_location, "/", file_name_Data_dict)
Data_loaded = load(file_path_Data_dict)

# Unpack the dictionary stored under "Data_dict": every entry becomes a global variable
# named after its key, so the plotting code below can refer to the stored quantities by name.
my_dict_Data = Data_loaded["Data_dict"]

for (key, value) in my_dict_Data
    # Global variable name for this entry (keys are assumed to be Strings or Symbols).
    symbol_key = Symbol(key)
    # Wrap the value in a QuoteNode so it is assigned as data. Splicing it directly
    # into the expression would make `eval` evaluate Symbol/Expr values as code
    # (e.g. a stored `:solid` would be looked up as a variable) instead of storing them.
    eval(:($symbol_key = $(QuoteNode(value))))
end

#OPTIONS TO SAVE PLOTS ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Output settings shared by the figures saved below.
# `name_plots` is a suffix appended to every file name; `full_folder_save_plots` is the
# output location (empty = the current working directory, reset here to the script folder).
cd(@__DIR__)
name_plots = ""
full_folder_save_plots = ""

# Global default resolution for Plots. A higher dpi gives a sharper image but may increase file size.
dpi_value = 500
default(; dpi = dpi_value)

#------------------- (1) Convergence of the algorithm ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Plot the expected cost E[C] against the iteration index on a logarithmic x-axis.
# N_iterations_Lagrange iterations are expected to give N_iterations_Lagrange + 1 stored
# costs (presumably the initial cost plus one per iteration; NOTE(review): confirm).
# Clamp to the stored length so a run with fewer recorded costs still plots instead of
# throwing a BoundsError.
n_expected_costs = N_iterations_Lagrange + 1
maximum_index_iteration_plot = min(n_expected_costs, length(cost_mom_prop_Lag_Mul_whole))
if maximum_index_iteration_plot < n_expected_costs
    @warn "cost_mom_prop_Lag_Mul_whole has fewer than N_iterations_Lagrange + 1 entries; plotting the $(maximum_index_iteration_plot) stored values"
end

plt_convergence_EC = plot(cost_mom_prop_Lag_Mul_whole[1:maximum_index_iteration_plot],
                          lw = 3, color = cgrad(:BuGn_5)[3], xscale = :log,
                          xlabel = L"$\mathrm{Iteration}$", ylabel = L"$\mathbb{E}[C]$")

plt_convergence_EC = plot(plt_convergence_EC, xtickfontsize = 12, ytickfontsize = 12,
                          xguidefontsize = 18, yguidefontsize = 18, thickness_scaling = 2.2,
                          size = (1300, 600), title = "",
                          titlefont = font(12, "Computer Modern"), grid = false,
                          legend = false)

# joinpath works whether or not the output folder ends with a separator ("" = current dir).
name_and_path_plot = joinpath(full_folder_save_plots,
                              "Convergence_EC_1D_Reaching_Task" * name_plots * ".pdf")
savefig(plt_convergence_EC, name_and_path_plot)

#------------------- (2) Comparison between solutions ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Line style per state component, cycled when there are more components than styles.
line_styles = [:dashdot, :solid, :dash, :dot]

"""
    add_L_component_curves!(plt, L_array, n_components, color, styles; lw = 3, alpha = 0.8)

Add the time course `L_array[1, i, :]` of every component `i in 1:n_components` to `plt`.

Component `i` is drawn with line style `styles[mod1(i, length(styles))]`. Each curve is
plotted without a label; an extra `[NaN]` series (which draws nothing) carries the legend
entry `L_{i,t}`, so the legend sample is drawn with a thinner line (`lw = 0.5`) than the
curve itself. Returns `plt`.
"""
function add_L_component_curves!(plt, L_array, n_components, color, styles;
                                 lw = 3, alpha = 0.8)
    for comp_idx in 1:n_components
        style = styles[mod1(comp_idx, length(styles))]
        plot!(plt, L_array[1, comp_idx, :], color = color, linestyle = style,
              lw = lw, alpha = alpha, label = "")
        # Legend-only series: NaN data produces no visible line.
        plot!(plt, [NaN], [NaN], lw = 0.5, linestyle = style, color = color,
              label = latexstring("\$ L_{$comp_idx,t}\$"))
    end
    return plt
end

# Styling shared by both panels.
solution_panel_style = (xtickfontsize = 12, ytickfontsize = 12, xguidefontsize = 18,
                        yguidefontsize = 18, thickness_scaling = 2.2, size = (1300, 600),
                        title = "", titlefont = font(12, "Computer Modern"), grid = false,
                        legend_background_color = :transparent,
                        legend_foreground_color = :transparent)

# Left panel: Lagrange-multiplier solution (reds), carrying the y-axis label.
plt_Lag_solutions = plot(xlabel = L"$t$", ylabel = L"$L_{i,t}$")
add_L_component_curves!(plt_Lag_solutions, L_matrix_Lag_Mul_whole, dimension_of_state,
                        cgrad(:Reds_5, 7)[6], line_styles)
plt_Lag_solutions = plot(plt_Lag_solutions; solution_panel_style...)

# Right panel: gradient-descent solution (greens), without y label/ticks.
plt_GD_solutions = plot(xlabel = L"$t$", ylabel = "", yticks = [])
add_L_component_curves!(plt_GD_solutions, L_GD, dimension_of_state,
                        cgrad(:BuGn_5, 7)[6], line_styles)
plt_GD_solutions = plot(plt_GD_solutions; solution_panel_style...)

# The right panel hides its y ticks, so link the y axes: both panels then share one
# scale and the side-by-side comparison is not misleading.
plt_sol_comparison = plot(plt_Lag_solutions, plt_GD_solutions, layout = (1, 2), link = :y,
                          size = (1300, 600), title = "",
                          titlefont = font(12, "Computer Modern"))

# joinpath works whether or not the output folder ends with a separator ("" = current dir).
name_and_path_plot = joinpath(full_folder_save_plots,
                              "Solution_L_EC_GD" * name_plots * ".pdf")
savefig(plt_sol_comparison, name_and_path_plot)