#PACKAGES -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------

using Optim, Plots, DelimitedFiles, LinearAlgebra, Random, StatsBase, FiniteDifferences, LaTeXStrings , EasyFit, Printf, FFTW, Pkg, Noise, Clustering, Dierckx, BSplineKit, MultivariateStats, Flux, Combinatorics, Bigsimr, DataFrames, JLD, Base.Threads

#FUNCTIONS ---------------------------------------------------------------------------------

# Helper functions (optimisation routines, expected-cost evaluation, model predictions).
# The include path is relative, so move into the directory that contains this script first.
cd(@__DIR__)
include("../../MyFunctions/myFunctions_NeurIPS_2025.jl")

#Loading data ------------------------------------------------------------------------------

# Data folder, resolved relative to the script directory.
cd(@__DIR__)
folder_location = "../../Data_for_plots/NSC_3D_Reaching_Task"

# File holding the saved results dictionary.
file_name_Data_dict = "NSC_3D_Reaching_Task.jld2"
file_path_Data_dict = "$(folder_location)/$(file_name_Data_dict)"
Data_loaded = load(file_path_Data_dict)

# All results live under the "Data_dict" key. Every entry of that dictionary becomes a
# global variable with the same name, so the plotting code below can refer to it directly.
my_dict_Data = Data_loaded["Data_dict"]

for (entry_name, entry_value) in my_dict_Data
    @eval $(Symbol(entry_name)) = $entry_value
end

#OPTIONS TO SAVE PLOTS ---------------------------------------------------------------------

# Suffix appended to the file name of every saved figure (empty: no suffix).
name_plots = ""

# Folder prefix for every saved figure. The empty string saves into the current
# directory, which is the script directory after the `cd` below.
cd(@__DIR__)
full_folder_save_plots = ""

# Raster resolution applied to all subsequent plots: higher dpi means a sharper image
# but a larger file.
dpi_value = 500
default(; dpi = dpi_value)

#------------------- (1) Suboptimality of EC: comparison between Todorov, EC and LSC while varying the internal noise --------

# Expected cost E[C] versus the internal-noise level σ_η for the TOD, EC and NSC
# controllers, on a logarithmic y axis.
plt_costs = let σ_levels = σ_η_internal_noise_levels[:], n_levels = length(σ_η_internal_noise_levels)
    fig = plot(σ_levels, expected_cost_mom_prop_TOD_internal_noise[:];
               label = L"$\mathrm{TOD}$", title = "", titlefont = font(12, "Computer Modern"),
               xlabel = L"$\sigma_{\eta}$", ylabel = L"$\mathbb{E}[C]$", lw = 3, grid = false,
               color = cgrad(:Blues_5, n_levels)[6], alpha = 0.8)
    plot!(fig, σ_levels, expected_cost_mom_prop_Lag_Mul_whole_internal_noise[:];
          label = L"$\mathrm{EC}$", lw = 3, color = cgrad(:Reds_5, n_levels)[6], alpha = 0.8,
          yscale = :log)
    plot!(fig, σ_levels, expected_cost_mom_prop_LSC_internal_noise[:];
          label = L"$\mathrm{NSC}$", lw = 3, color = cgrad(:BuPu_5, n_levels)[6], alpha = 0.8,
          yscale = :log)
    # Final styling pass: fonts, figure size, transparent legend and a fixed y range.
    plot(fig; xtickfontsize = 12, ytickfontsize = 12, xguidefontsize = 18, yguidefontsize = 18,
         thickness_scaling = 2.2, size = (1300, 600), legend_background_color = :transparent,
         legend_foreground_color = :transparent, ylims = (10^(1.6), 10^2.2))
end

name_and_path_plot = string(full_folder_save_plots, "Cost_3D_Reaching_Task", name_plots, ".pdf")
savefig(plt_costs, name_and_path_plot)

#------------------- (2) Convergence of NSC algorithm --------------------------------------

# Number of leading entries of each convergence trace that are shown.
maximum_index_iteration_plot = 3 * N_iterations_LSC * N_iterations_LSC_each_dir + 1

# Expected cost versus iteration (log x axis), one curve per internal-noise level, each
# coloured by its position in the BuPu_5 gradient.
plt_convergence_LSC = let n_levels = length(σ_η_internal_noise_levels), shown = 1:maximum_index_iteration_plot
    shades = cgrad(:BuPu_5, n_levels)
    fig = plot(convergence_cost_LSC_mom_prop[1, shown]; lw = 3, color = shades[1], xscale = :log,
               xlabel = L"$\mathrm{Iteration}$", ylabel = L"$\mathbb{E}[C]$")
    foreach(2:n_levels) do level
        plot!(fig, convergence_cost_LSC_mom_prop[level, shown]; lw = 3, color = shades[level])
    end
    plot(fig; xtickfontsize = 12, ytickfontsize = 12, xguidefontsize = 18, yguidefontsize = 18,
         thickness_scaling = 2.2, size = (1300, 600), title = "",
         titlefont = font(12, "Computer Modern"), grid = false, leg = :false)
end

name_and_path_plot = string(full_folder_save_plots, "Convergence_LSC_3D_Reaching_Task", name_plots, ".pdf")
savefig(plt_convergence_LSC, name_and_path_plot)

#------------------- (3) Studying the solutions -------------------------------------------

# Colours (categorical :matter palette) used for the L, M and K matrices in this section.
color_L, color_M, color_K = let matter_colors = cgrad(:matter, 5, categorical = true)
    (matter_colors[2], matter_colors[3], matter_colors[4])
end

# Determinant of each NSC matrix at every time step t = 1, ..., T-1 (rows: noise levels,
# columns: time steps), then its average over time for each noise level.
det_M = Float64[det(M_matrix_LSC_internal_noise[i, :, :, t]) for i in 1:length(σ_η_internal_noise_levels), t in 1:T-1]
det_L = Float64[det(L_matrix_LSC_internal_noise[i, :, :, t]) for i in 1:length(σ_η_internal_noise_levels), t in 1:T-1]
det_K = Float64[det(K_matrix_LSC_internal_noise[i, :, :, t]) for i in 1:length(σ_η_internal_noise_levels), t in 1:T-1]
mean_det_L = Float64[mean(det_L[i, :]) for i in 1:length(σ_η_internal_noise_levels)]
mean_det_M = Float64[mean(det_M[i, :]) for i in 1:length(σ_η_internal_noise_levels)]
mean_det_K = Float64[mean(det_K[i, :]) for i in 1:length(σ_η_internal_noise_levels)]

# Time-averaged determinants versus σ_η. K, M and L are shown under the labels P, W and L.
plt_det = let levels = σ_η_internal_noise_levels
    fig = plot(levels, mean_det_K; label = L"$P$", title = "",
               titlefont = font(12, "Computer Modern"), xlabel = L"$\sigma_{\eta}$",
               ylabel = L"$\langle \mathrm{det}(\ \cdot \ )\rangle$", lw = 3, grid = false,
               color = color_K)
    plot!(fig, levels, mean_det_M; label = L"$W$", lw = 3, color = color_M, yscale = :log)
    plot!(fig, levels, mean_det_L; label = L"$L$", lw = 3, color = color_L, yscale = :log)
    plot(fig; xtickfontsize = 12, ytickfontsize = 12, xguidefontsize = 18, yguidefontsize = 18,
         thickness_scaling = 2.2, size = (1300, 600), legend_background_color = :transparent,
         legend_foreground_color = :transparent)
end

name_and_path_plot = string(full_folder_save_plots, "Determinant_KML", name_plots, ".pdf")
savefig(plt_det, name_and_path_plot)

#singular value spectrum
#compute the number of singular values larger than a threshold (effective dimensionality)
"""
    compute_effective_rank_relative(L_matrix, threshold_fraction)

Count, for every noise level, the singular values of each time slice
`L_matrix[i, :, :, t]` that are strictly larger than `threshold_fraction` times the largest
singular value of that slice. Return that count averaged over all time slices.

# Arguments
- `L_matrix::AbstractArray{<:Real,4}`: stacked matrices indexed `[level, row, column, time]`.
  Every slice along the 4th dimension is used. The gain arrays in this script store one
  slice per control step, i.e. `T-1` slices.
- `threshold_fraction::Real`: relative cutoff, e.g. `0.01` for 1% of the largest singular
  value of the slice.

# Returns
A `Vector{Float64}` of length `size(L_matrix, 1)` with the time-averaged counts. An all-zero
(or empty) slice contributes 0, and an empty time dimension yields `NaN`.
"""
function compute_effective_rank_relative(L_matrix::AbstractArray{<:Real, 4}, threshold_fraction::Real)
    n_levels = size(L_matrix, 1)
    time_slices = axes(L_matrix, 4)
    effective_ranks = zeros(n_levels)

    for i in 1:n_levels
        ranks = zeros(length(time_slices))
        # Iterate over *all* stored slices. The arrays passed in already hold one slice per
        # control step, so stopping one short of size(L_matrix, 4) would drop the last step.
        for (k, t) in enumerate(time_slices)
            s = svdvals(L_matrix[i, :, :, t])
            ranks[k] = isempty(s) ? 0 : count(>(threshold_fraction * maximum(s)), s)
        end
        effective_ranks[i] = mean(ranks)
    end
    return effective_ranks
end

# Relative cutoff for compute_effective_rank_relative: singular values at or below 1% of
# the largest singular value of a slice are not counted.
threshold_fraction = 0.01
effective_ranks_L_LSC = compute_effective_rank_relative(L_matrix_LSC_internal_noise, threshold_fraction)
effective_ranks_M_LSC = compute_effective_rank_relative(M_matrix_LSC_internal_noise, threshold_fraction)
effective_ranks_K_LSC = compute_effective_rank_relative(K_matrix_LSC_internal_noise, threshold_fraction)

# Each matrix uses the same colour as in the determinant figure (P = K -> color_K,
# W = M -> color_M, L -> color_L), so a given label has one colour across all figures.
plt_eff_rank = scatter(σ_η_internal_noise_levels, effective_ranks_K_LSC, label = L"$P$", xlabel = L"$\sigma_{\eta}$", ylabel = L"$SV_{thr} \ \ \mathrm{Count}$", marker = :circle, ms = 5, lw = 2, color = color_K, markerstrokewidth = 0, grid = false, size = (600, 500))
scatter!(plt_eff_rank, σ_η_internal_noise_levels, effective_ranks_M_LSC, label = L"$W$", marker = :circle, ms = 5, lw = 2, color = color_M, markerstrokewidth = 0)
scatter!(plt_eff_rank, σ_η_internal_noise_levels, effective_ranks_L_LSC, label = L"$L$", marker = :circle, ms = 5, lw = 2, color = color_L, markerstrokewidth = 0)

# Final styling. NOTE(review): the fixed ylims = (5.5, 6) assume counts close to 6; widen
# it if the matrices have a different size.
plt_eff_rank = plot(plt_eff_rank, ylims = (5.5, 6), xtickfontsize = 12, ytickfontsize = 12, xguidefontsize = 18, yguidefontsize = 18, thickness_scaling = 2.2, size = (1300, 600), title = "", titlefont = font(12, "Computer Modern"), grid = false, legend_background_color = :transparent, legend_foreground_color = :transparent)

name_and_path_plot = full_folder_save_plots * "Effective_rank" * name_plots * ".pdf"
savefig(plt_eff_rank, name_and_path_plot)
   
#Difference between M and A+B*L-KH

# For every internal-noise level: time-averaged Frobenius norm of the difference between
# the NSC matrix M_t (labelled W_t in the figure) and the matrix A + B*L_t - K_t*H built
# from the same NSC gains L_t and K_t.
model_diff = zeros(length(σ_η_internal_noise_levels))

for i in 1:length(σ_η_internal_noise_levels)
    mismatch_sum = 0.0

    for t in 1:T-1
        M_t = Matrix(M_matrix_LSC_internal_noise[i, :, :, t])
        L_t = Matrix(L_matrix_LSC_internal_noise[i, :, :, t])
        K_t = Matrix(K_matrix_LSC_internal_noise[i, :, :, t])

        A_eff = A + B * L_t - K_t * H

        # `norm` of a matrix is the Frobenius norm (2-norm over all entries).
        mismatch_sum += norm(M_t - A_eff)
    end
    # Average over the T-1 time steps actually summed; dividing by T biased the mean low
    # and disagreed with the other time averages (det, effective rank) in this script.
    model_diff[i] = mismatch_sum / (T - 1)
end

plt_model_diff = plot(σ_η_internal_noise_levels, model_diff, lw = 3, color = cgrad(:BuPu_5, length(σ_η_internal_noise_levels))[6], alpha = 0.8, ylabel=L"$\langle ‖W_t − \tilde{W}_t‖ \rangle $", xlabel = L"$\sigma_{\eta}$")
plt_model_diff = plot(plt_model_diff, xtickfontsize=12, ytickfontsize=12, xguidefontsize=18, yguidefontsize=18, thickness_scaling=2.2, size = (1300, 600), title = "", titlefont = font(12,"Computer Modern"), grid = false, leg = :false)

name_and_path_plot = full_folder_save_plots * "Model_diff" * name_plots * ".pdf"
savefig(plt_model_diff, name_and_path_plot)

#------------------- (4) Studying the predictions  ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
#PREDICTIONS
# For every internal-noise level, simulate N_trials_MC trials of the NSC controller and of
# the EC controller through get_model_predictions_NSC, and store the trajectories of x and z
# together with the mean and SEM of the total cost.

# Number of Monte Carlo trials per noise level.
N_trials_MC = 10000

# Work buffers for the matrices of one noise level (time is the 3rd index, T-1 steps).
# They are overwritten for every noise level in the loop below. Leaving them at zero (by
# skipping the copies) would test the behaviour without control (instability).
L_matrix = zeros(dimension_of_control, dimension_of_latent_space_ext_LSC, T-1)
M_matrix = zeros(dimension_of_latent_space_ext_LSC, dimension_of_latent_space_ext_LSC, T-1)
K_matrix = zeros(dimension_of_latent_space_ext_LSC, dimension_of_latent_space_ext_LSC, T-1)

L_matrix_EC = zeros(dimension_of_control, dimension_of_latent_space_ext_LSC, T-1)
M_matrix_EC = zeros(dimension_of_latent_space_ext_LSC, dimension_of_latent_space_ext_LSC, T-1)
K_matrix_EC = zeros(dimension_of_latent_space_ext_LSC, dimension_of_latent_space_ext_LSC, T-1)

# Results, indexed [noise level, component, trial, time] for trajectories and
# [noise level] for the cost summaries.
x_vec_LSC_internal_noise = zeros(length(σ_η_internal_noise_levels), dimension_of_state, N_trials_MC, T)
z_vec_LSC_internal_noise = zeros(length(σ_η_internal_noise_levels), dimension_of_latent_space_ext_LSC, N_trials_MC, T)
mean_total_cost_LSC_internal_noise = zeros(length(σ_η_internal_noise_levels))
sem_total_cost_LSC_internal_noise = zeros(length(σ_η_internal_noise_levels))

x_vec_internal_noise_EC = zeros(length(σ_η_internal_noise_levels), dimension_of_state, N_trials_MC, T)
z_vec_internal_noise_EC = zeros(length(σ_η_internal_noise_levels), dimension_of_latent_space_ext_LSC, N_trials_MC, T)
mean_total_cost_internal_noise_EC = zeros(length(σ_η_internal_noise_levels))
sem_total_cost_internal_noise_EC = zeros(length(σ_η_internal_noise_levels))

for i in 1:length(σ_η_internal_noise_levels)

    # The same seed is passed to both simulations below, presumably so NSC and EC see the
    # same noise realisations (depends on get_model_predictions_NSC).
    random_seed = i

    # NOTE(review): this vector has length dimension_of_state while z has length
    # dimension_of_latent_space_ext_LSC. Confirm which size get_model_predictions_NSC expects for Ω_η.
    σ_η_internal_noise = σ_η_internal_noise_levels[i]
    σ_η_internal_noise_vec = σ_η_internal_noise .*ones(dimension_of_state) #internal noise acting on the state estimate (position, velocity, force, filtered control)
    Ω_η = Diagonal(σ_η_internal_noise_vec) #Internal noise is acting only on the observed variables (position, velocity: visual feedbacks; force: proprioception)

    # NSC: use the optimised L, M and K matrices for this noise level as they are.
    L_matrix[:,:,1:T-1] .= L_matrix_LSC_internal_noise[i,:,:,:]
    M_matrix[:,:,:] .= M_matrix_LSC_internal_noise[i,:,:,:]
    K_matrix[:,:,:] .= K_matrix_LSC_internal_noise[i,:,:,:]

    # EC: take L and K from the EC solution and build its recurrent matrix explicitly as
    # M_t = A + B*L_t - K_t*H, so EC can be simulated with the same NSC routine.
    L_matrix_EC[:,:,1:T-1] .= L_matrix_Lag_Mul_whole_internal_noise[i,:,:,:]
    for t in 1:T-1
        M_matrix_EC[:,:,t] .= A .+ B * L_matrix_Lag_Mul_whole_internal_noise[i,:,:,t] .- K_matrix_Lag_Mul_whole_internal_noise[i,:,:,t] * H
    end
    K_matrix_EC[:,:,:] .= K_matrix_Lag_Mul_whole_internal_noise[i,:,:,:]
    
    # cost_over_time / cost_over_time_EC are returned but not stored.
    x_vec, z_vec, cost_over_time, mean_total_cost, sem_total_cost = get_model_predictions_NSC(random_seed, N_trials_MC, dimension_of_state, dimension_of_latent_space_ext_LSC, dimension_of_observation, T, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, A, B, C, H, D, K_matrix, M_matrix, L_matrix, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix)
    
    x_vec_EC, z_vec_EC, cost_over_time_EC, mean_total_cost_EC, sem_total_cost_EC = get_model_predictions_NSC(random_seed, N_trials_MC, dimension_of_state, dimension_of_latent_space_ext_LSC, dimension_of_observation, T, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, A, B, C, H, D, K_matrix_EC, M_matrix_EC, L_matrix_EC, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix)

    # Store this noise level's results.
    x_vec_LSC_internal_noise[i,:,:,:] = x_vec
    z_vec_LSC_internal_noise[i,:,:,:] = z_vec
    mean_total_cost_LSC_internal_noise[i] = mean_total_cost
    sem_total_cost_LSC_internal_noise[i] = sem_total_cost

    x_vec_internal_noise_EC[i,:,:,:] = x_vec_EC
    z_vec_internal_noise_EC[i,:,:,:] = z_vec_EC
    mean_total_cost_internal_noise_EC[i] = mean_total_cost_EC
    sem_total_cost_internal_noise_EC[i] = sem_total_cost_EC

end

# Monte Carlo statistics over trials (dimension 3 of the simulation arrays): the mean and
# the standard error of the mean (std / sqrt(N_trials_MC)) of x and z, for NSC and EC.
mean_x_LSC = mean(x_vec_LSC_internal_noise, dims=3)
std_x_LSC = std(x_vec_LSC_internal_noise, dims=3)
std_x_LSC /= sqrt(N_trials_MC)

mean_z_LSC = mean(z_vec_LSC_internal_noise, dims=3)
std_z_LSC = std(z_vec_LSC_internal_noise, dims=3)
std_z_LSC /= sqrt(N_trials_MC)

mean_x_EC = mean(x_vec_internal_noise_EC, dims=3)
std_x_EC = std(x_vec_internal_noise_EC, dims=3)
std_x_EC /= sqrt(N_trials_MC)

mean_z_EC = mean(z_vec_internal_noise_EC, dims=3)
std_z_EC = std(z_vec_internal_noise_EC, dims=3)
std_z_EC /= sqrt(N_trials_MC)

# Noise level and vector component shown in the figure.
# NOTE(review): the y label hard-codes component 2; keep it in sync with vec_component.
noise_lev_idx = 2
vec_component = 2
alpha_lev = 0.8

# Each curve's ribbon is the SEM of that same quantity: the z curves use std_z_*, not std_x_*.
plt_xz_LSC = plot(mean_x_LSC[noise_lev_idx,vec_component,:,:][:], ribbon = std_x_LSC[noise_lev_idx,vec_component,:,:][:], label = L"$x \ \mathrm{NSC}$", xlabel = L"$t$", ylabel = L"$\mathbb{E}[x_{2,t}], \mathbb{E}[z_{2,t}]$", lw = 3, grid = false, color = cgrad(:BuPu_5, length(σ_η_internal_noise_levels))[6], alpha = alpha_lev)
plot!(plt_xz_LSC, mean_z_LSC[noise_lev_idx,vec_component,:,:][:], ribbon = std_z_LSC[noise_lev_idx,vec_component,:,:][:], label = L"$z \ \mathrm{NSC}$", lw = 1.5, grid = false, color = :gray, alpha = alpha_lev, ylims = (0,3.5), xticks = [0,50,100], yticks = [0,1.5,3], linestyle = :dash)

plt_xz_EC = plot(mean_x_EC[noise_lev_idx,vec_component,:,:][:], ribbon = std_x_EC[noise_lev_idx,vec_component,:,:][:], label = L"$x \ \mathrm{EC}$", xlabel = L"$t$", ylabel = "", lw = 3, grid = false, color = cgrad(:Reds_5, length(σ_η_internal_noise_levels))[6], alpha = alpha_lev)
plot!(plt_xz_EC, mean_z_EC[noise_lev_idx,vec_component,:,:][:], ribbon = std_z_EC[noise_lev_idx,vec_component,:,:][:], label = L"$z \ \mathrm{EC}$", lw = 1.5, grid = false, color = :gray, alpha = alpha_lev, ylims = (0,1.1), xticks = [0,50,100], yticks = [0,0.5,1], linestyle = :dash)

# Side-by-side panels: NSC (left) and EC (right).
plt_joined_xz = plot(plt_xz_LSC,plt_xz_EC, layout = (1,2), size = (1300, 600), title = "", titlefont = font(12,"Computer Modern"), grid = false, xtickfontsize=12, ytickfontsize=12, xguidefontsize=18, yguidefontsize=18, thickness_scaling=2.2, legend_background_color=:transparent, legend_foreground_color=:transparent)

name_and_path_plot = full_folder_save_plots * "x_z_comparison" * name_plots * ".pdf"
savefig(plt_joined_xz, name_and_path_plot)


#for colorbar

# Heat map of the values 1..100 drawn with the 7-colour BuPu_5 gradient. Ticks, axes and
# frame are turned off so the colour bar is what remains. It is saved on its own,
# presumably so it can be combined with the figures above when assembling the layout.
colorbar_only = heatmap(reshape(1:100, 100, 1);
                        color = cgrad(:BuPu_5, 7), colorbar = true,
                        xticks = false, yticks = false, colorbar_ticks = false,
                        axis = false, framestyle = :none)

name_and_path_plot = string(full_folder_save_plots, "useful_for_colorbar", name_plots, ".pdf")
savefig(colorbar_only, name_and_path_plot)

