#PACKAGES -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------

using Optim, Plots, DelimitedFiles, LinearAlgebra, Random, StatsBase, FiniteDifferences, LaTeXStrings , EasyFit, Printf, FFTW, Pkg, Noise, Clustering, Dierckx, BSplineKit, MultivariateStats, Flux, Combinatorics, Bigsimr, DataFrames, JLD, Base.Threads, KernelDensity, Distributions, Statistics

#FUNCTIONS -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------

#Project helpers for optimizing the controller and computing the expected cost function.
#The include path below is relative, so first move into the directory that contains this script.
cd(@__DIR__)
include("../MyFunctions/myFunctions_NeurIPS_2025.jl")

#TASK PARAMETERS -- Control of Neural Population Activity -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------

println("\n------- NEW RUN ---------------------------------------------------- \n")

#Dimensions: control, observation and latent (NSC) spaces all equal the state dimension in this configuration
dimension_of_state = 100
dimension_of_control = dimension_of_state
dimension_of_observation = dimension_of_state
dimension_of_latent_space_ext_LSC = dimension_of_state

T = 50 #duration of the task (number of time steps)

#Noise levels (only σ_ξ and the internal noise are non-zero in this configuration)
σ_ξ = 0.5
σ_ω_add_sensory_noise = 0.0
σ_ρ_mult_sensory_noise = 0.0
σ_ϵ_mult_control_noise = 0.0
σ_η_internal_noise = 0.2
σ_η_internal_noise_vec = fill(σ_η_internal_noise, dimension_of_state)

#Diagonal noise matrices built from the σ values above.
#NOTE(review): whether the helper functions read these as standard deviations or variances is defined in the included file -- confirm.
Ω_ξ = Diagonal(σ_ξ .* ones(dimension_of_state))
Ω_ω = Diagonal(σ_ω_add_sensory_noise .* ones(dimension_of_state))
Ω_η = Diagonal(σ_η_internal_noise_vec)

#Initial conditions (mean state and state estimate, and their covariances) -- initial state and
#state estimate are treated as uncorrelated at t=1 (initial time step)
scale_target = 10.0
x_1_mean = scale_target .* randn(dimension_of_state) #random initial displacement of each unit, scaled by scale_target
z_1_start_scaling = σ_η_internal_noise #we set it equal to the internal noise level!
z_1_mean = z_1_start_scaling * randn(dimension_of_latent_space_ext_LSC) #small random initial latent state

#State covariance
Σ_1_x = Diagonal(zeros(dimension_of_state)) #initial covariance of the state
#State estimate covariance --- NOTE THAT THESE VARIABLES HAVE TO BE ZERO (OTHERWISE CHANGE THE NSC APPROACH ACCORDINGLY!)
Σ_1_z = Diagonal(zeros(dimension_of_latent_space_ext_LSC)) #initial covariance of the state estimate

#Matrices of the problem
#Dynamical system: random Gaussian coupling matrix, i.i.d. entries with std g_A/sqrt(dimension_of_state)
g_A = 1.1
A = randn(dimension_of_state, dimension_of_state) .* (g_A / sqrt(dimension_of_state))
# ρ = maximum(abs.(eigvals(A)))
# println("Spectral radius of A: ", ρ)

#Observation and control-input matrices are (Boolean) identities
H = Matrix{Bool}(I, dimension_of_observation, dimension_of_state)
B = Matrix{Bool}(I, dimension_of_state, dimension_of_control)

#Multiplicative-noise scaling matrices
C = σ_ϵ_mult_control_noise .* B
D = σ_ρ_mult_sensory_noise .* H

#Control cost: r_val on the diagonal for t = 1..T-1; slice T stays zero (no control is applied at the last step)
r_val = 0.001
R_matrix = zeros(dimension_of_control, dimension_of_control, T)
for t in 1:T-1, j in 1:dimension_of_control
    R_matrix[j, j, t] = r_val
end

#State cost: zero at t = 1, q_val_trial during the trial, q_val_final at the last step
q_val_final = 0.1
q_val_trial = 0.001
Q_matrix = zeros(dimension_of_state, dimension_of_state, T)
for t in 2:T-1, j in 1:dimension_of_state
    Q_matrix[j, j, t] = q_val_trial
end
for j in 1:dimension_of_state
    Q_matrix[j, j, T] = q_val_final
end

println("\n------- Neural Space Control (NSC), p = $dimension_of_latent_space_ext_LSC ---------------------------\n")

#Time-varying NSC parameters, one slice per control step t = 1..T-1:
#  K_matrix : latent × observation,  M_matrix : latent × latent,  L_initial_NSC : control × latent
K_matrix = zeros(dimension_of_latent_space_ext_LSC, dimension_of_observation, T-1)
M_matrix = zeros(dimension_of_latent_space_ext_LSC, dimension_of_latent_space_ext_LSC, T-1)
L_initial_NSC = zeros(dimension_of_control, dimension_of_latent_space_ext_LSC, T-1)

#Initial conditions for the parameters to be optimized -- same as optimal condition only if dimension_of_latent_space_ext_LSC = dimension_of_control
#K and M: random Gaussian matrices (entry std g/sqrt(dimension_of_state)), identical for every time slice.
#They are sized from the latent/observation dimensions so the script also works when those differ from the state dimension.
g_K = 0.3
g_M = 0.9
K = g_K / sqrt(dimension_of_state) * randn(dimension_of_latent_space_ext_LSC, dimension_of_observation)
M = g_M / sqrt(dimension_of_state) * randn(dimension_of_latent_space_ext_LSC, dimension_of_latent_space_ext_LSC)

#L starts as l_0 times the (possibly rectangular) identity in every slice
l_0 = -0.1

for i in 1:T-1
    K_matrix[:,:,i] .= K
    M_matrix[:,:,i] .= M
    L_initial_NSC[:,:,i] .= l_0 .* Matrix(I, dimension_of_control, dimension_of_latent_space_ext_LSC)
end

N_iterations_LSC = 20

#Lagrange multipliers optimization.
#NOTE(review): judging by the function name, only the output weights L_t are optimized here -- confirm in myFunctions_NeurIPS_2025.jl.
K_opt, M_opt, L_opt, cost_mom_prop_NSC, cost_trace_formula_NSC = Optimal_EXTENDED_Latent_Space_ONLY_OUTPUT_WEIGHTS(N_iterations_LSC, dimension_of_state, dimension_of_latent_space_ext_LSC, dimension_of_observation, dimension_of_control, K_matrix, M_matrix, L_initial_NSC, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

#Alternative (disabled): here we optimize both K and L!
# N_iterations_LSC_each_dir = 1
# K_opt, M_opt, L_opt, cost_mom_prop_LSC_high_lat_space, cost_trace_formula_LSC_high_lat_space = Optimal_EXTENDED_Latent_Space_ONLY_INPUT_and_OUTPUT_WEIGHTS(N_iterations_LSC, N_iterations_LSC_each_dir, dimension_of_state, dimension_of_latent_space_ext_LSC, dimension_of_observation, dimension_of_control, K_matrix, M_matrix, L_initial_NSC, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

#Last entry of each returned cost history (presumably the final optimization iteration), for both evaluation methods
expected_cost_mom_prop_NSC, expected_cost_using_trace_formula_NSC =
    last(cost_mom_prop_NSC), last(cost_trace_formula_NSC)

println("- Expected cost using moment propagation (NSC): $expected_cost_mom_prop_NSC\n")
println("- Expected cost using trace formula (NSC):$expected_cost_using_trace_formula_NSC\n")

#Alternative (disabled): cost without control!
# L_matrix = zeros(dimension_of_control, dimension_of_latent_space_ext_LSC, T-1)
# #L_matrix[:,:,:] .= L_initial_NSC[:,:,:]
# cost_mom_prop_no_control = expected_cost_raw_moments_propagation_EXTENDED_Latent_Space(T, dimension_of_state, dimension_of_latent_space_ext_LSC, dimension_of_observation, K_opt, M_opt, L_matrix, A, B, H, C, D, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)
# cost_trace_formula_no_control = expected_cost_using_trace_formula_induction_EXTENDED_LSC(M_opt, K_opt, L_matrix, A, B, H, C, D, T, dimension_of_state, dimension_of_latent_space_ext_LSC, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

#PREDICTIONS ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

N_trials_MC = 1000
random_seed = 1234

#Feedback gains used in the Monte Carlo simulation (one control × latent slice per step t = 1..T-1)
L_matrix = zeros(dimension_of_control, dimension_of_latent_space_ext_LSC, T-1)

L_matrix[:,:,1:T-1] .= L_opt #in case we want to test the behaviour without control (instability) comment this line

x_vec, z_vec, cost_over_time, mean_total_cost, sem_total_cost = get_model_predictions_NSC(random_seed, N_trials_MC, dimension_of_state, dimension_of_latent_space_ext_LSC, dimension_of_observation, T, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, A, B, C, H, D, K_opt, M_opt, L_matrix, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix)

# println("Mean total cost: ", mean_total_cost)
# println("Standard error: ", sem_total_cost)

#Per-trial norms of x, z and u over time (trials × time) and the controls themselves (unit × trial × time).
#x_vec / z_vec are indexed as (unit, trial, time) below.
#The control u_t = L_t z_t uses L_matrix -- the gains actually passed to the simulation -- so the
#no-control test (commenting out the L_matrix assignment above) is reflected here as well.
#No control is applied at t = T, so the last time column of the control arrays stays zero.
x_norm_value = zeros(N_trials_MC, T)
z_norm_value = zeros(N_trials_MC, T)
control_norm_over_time = zeros(N_trials_MC, T)
control_over_time = zeros(dimension_of_control, N_trials_MC, T)

#time is the outer loop (last array index); views avoid copying a slice for every (trial, t)
@views for t in 1:T, i in 1:N_trials_MC
    x_norm_value[i,t] = norm(x_vec[:,i,t])
    z_norm_value[i,t] = norm(z_vec[:,i,t])
    t == T && continue
    u_it = L_matrix[:,:,t] * z_vec[:,i,t]
    control_over_time[:,i,t] .= u_it
    control_norm_over_time[i,t] = norm(u_it)
end

#Means across trials (1 × T rows; control_mean is dimension_of_control × 1 × T)
x_mean_norm = mean(x_norm_value, dims=1)
z_mean_norm = mean(z_norm_value, dims=1)
control_norm_mean = mean(control_norm_over_time, dims=1)
control_mean = mean(control_over_time, dims=2)

# Compute standard deviations across trials
x_std = std(x_norm_value, dims=1)
z_std = std(z_norm_value, dims=1)
control_std = std(control_norm_over_time, dims=1)

#PLOTTING ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
dpi_value = 500
default(dpi=dpi_value)

#Time axis (0-based step index). Named `time_steps` rather than `time`: assigning the global `time`
#in Main fails if Base.time was already resolved there (e.g. used by an included helper).
time_steps = 0:T-1

# Flatten mean and std rows (1 × T) to vectors
x_mean = vec(x_mean_norm)
z_mean = vec(z_mean_norm)
control_norm_mean_vec = vec(control_norm_mean) #separate name so the per-unit control_mean is not overwritten

x_err = vec(x_std)
z_err = vec(z_std)
control_err = vec(control_std)

#X,Z AND CONTROL OVER TIME (norms): mean ± 1 std across trials
alpha_lev = 0.8
plt_norms = plot(time_steps, x_mean, ribbon = x_err, label=L"$x_t$", linewidth=3, xlabel="Time", ylabel="Mean Norm", color = cgrad(:BuPu_5, 7)[6], alpha = alpha_lev)
plot!(plt_norms, time_steps, z_mean, ribbon = z_err, label=L"$z_t$", linewidth=3, color = :gray, alpha = alpha_lev, linestyle = :dash)
plot!(plt_norms, time_steps, control_norm_mean_vec, ribbon = control_err, label=L"$u_t$", linewidth=3, alpha = alpha_lev, color = cgrad(:Accent_3)[3])

plt_norms = plot!(plt_norms, xlabel=L"$t$", ylabel=L"$\mathbb{E}[‖\ \cdot \ ‖]$", title="", legend=:topright, grid=false,legend_background_color=:transparent, legend_foreground_color=:transparent, xtickfontsize=12, ytickfontsize=12, xguidefontsize=18, yguidefontsize=18, thickness_scaling=2.2, size = (1300, 600))

cd(@__DIR__)#to go back to the directory of the script
full_folder_save_plots = "../Plots/NSC_Neural_Populations"
mkpath(full_folder_save_plots) #savefig does not create missing directories

name_and_path_plot = full_folder_save_plots * "/x_z_u_norms.pdf"
savefig(plt_norms, name_and_path_plot)

#SHOW THAT CONTROL IS DISTRIBUTED BUT EFFECTIVELY ACTS ONLY ON UNITS FAR FROM THE "target"
# Fixed initial condition (same across trials): Σ_1_x is zero, so every trial presumably starts at x_1_mean
abs_x_1 = abs.(x_1_mean)

# Per-unit mean absolute control, averaged over trials AND over the T-1 control steps
# (matching the E[⟨|u|⟩] y-axis label of the sorted-control plot), plus every control entry for the histogram.
# Gains come from L_matrix, i.e. those actually used in the Monte Carlo simulation.
avg_abs_control_per_unit = zeros(dimension_of_state)
controls_for_hist = Float64[]
sizehint!(controls_for_hist, dimension_of_state * N_trials_MC * (T-1))

for t in 1:T-1
    BL_t = B * L_matrix[:,:,t] #same control map for every trial at step t, so compute it once
    for k in 1:N_trials_MC
        u_tmp = BL_t * view(z_vec, :, k, t)
        avg_abs_control_per_unit .+= abs.(u_tmp)  # accumulate absolute value
        append!(controls_for_hist, u_tmp)
    end
end

avg_abs_control_per_unit ./= (N_trials_MC * (T-1)) #average over trials and time steps

x_centers, y_density, x_kde, y_kde, x_gauss, y_gauss, μ, σ = analyze_distribution(controls_for_hist; nbins=1000, smooth_points=300)

#distribution of control with gaussian fit
plt_cont_distr = plot(
    x_centers, y_density,
    label = L"$u_t$",
    xlabel = L"$u_t$",
    ylabel = L"p(u_t)",
    title = "",
    fill = (0, cgrad(:Accent_3)[3], 0.7),  # fill down to y=0
    linewidth = 2,
    color = cgrad(:Accent_3)[3],
    #framestyle = :box,
    grid = false,
    legend = :topright,
    legend_background_color = :transparent,
    foreground_color_legend = nothing
)
x_min_plt = -8
x_max_plt = 8
plot!(plt_cont_distr, x_gauss, y_gauss, lw = 2, linestyle = :dash, color = :black, label = L"\mathcal{N}(\mu_u, \sigma_u^2)", xlims=(x_min_plt,x_max_plt), xticks = [-6,0,6])
# Optional KDE overlay
#plot!(x_kde, y_kde, lw=2, linestyle=:dash, color=:black, label="KDE")
# Optional Gaussian overlay
# plot!(x_gauss, y_gauss, lw=2, linestyle=:dot, color=:gray, label="Gaussian fit")
plt_cont_distr = plot!(plt_cont_distr, legend=:topright, grid=false,legend_background_color=:transparent, legend_foreground_color=:transparent, xtickfontsize=12, ytickfontsize=12, xguidefontsize=18, yguidefontsize=18, thickness_scaling=2.2, size = (1300, 600))

#Order units by decreasing |x_1| and carry each unit's average |u| along with it
sorted_indices = sortperm(abs_x_1; rev = true)
sorted_abs_x_1, sorted_avg_u = abs_x_1[sorted_indices], avg_abs_control_per_unit[sorted_indices]

#Scatter of average control magnitude against initial distance from the target (markers only, no line)
plt_sorted_control = plot(
    sorted_abs_x_1,
    sorted_avg_u,
    marker = (:circle, 4, 1.0),
    markersize = 5,
    markerstrokewidth = 0,
    linealpha = 0.0,
    color = cgrad(:Accent_3)[3],
    xlabel = L"$|x_1|$",
    ylabel = L"$\mathbb{E}[ \langle |u| \rangle ]$",
    title = "",
    legend = true,
    grid = false,
    #framestyle = :box,  # minimal frame with just x and y axes
    #tick_direction = :out,
    legend_background_color = :transparent,
    legend_foreground_color = :transparent,
    #dpi = 300,
    label = ""
)

#Linear fit |u| ≈ a|x_1| + b, sampled on 200 evenly spaced points across the observed |x_1| range
a, b = simple_linear_fit(sorted_abs_x_1, sorted_avg_u)
x_fit = range(extrema(sorted_abs_x_1)...; length = 200)
y_fit = @. a * x_fit + b

plt_sorted_control = plot!(plt_sorted_control, legend=:topright, grid=false,legend_background_color=:transparent, legend_foreground_color=:transparent, xtickfontsize=12, ytickfontsize=12, xguidefontsize=18, yguidefontsize=18, thickness_scaling=2.0, size = (1300, 600), leg = :false)
#Fit overlay (disabled):
#plot!(plt_sorted_control, x_fit, y_fit, color = :gray, linestyle = :dash, lw = 1.5, label = L"$ |u|=a|x_1|+b $")

#Side-by-side figure: control distribution (left) and per-unit control vs |x_1| (right)
plt_joined = plot(plt_cont_distr, plt_sorted_control, layout = (1, 2), size = (1300, 600), title = "", titlefont = font(12,"Computer Modern"), legend_background_color=:transparent, legend_foreground_color=:transparent)

cd(@__DIR__)#to go back to the directory of the script
full_folder_save_plots = "../Plots/NSC_Neural_Populations"
mkpath(full_folder_save_plots) #savefig does not create missing directories

name_and_path_plot = full_folder_save_plots * "/distr_and_control_sorted.pdf"
savefig(plt_joined, name_and_path_plot)

#single units plot: targeted control
#In one trial, compare the unit that starts with the largest (most positive) value with the unit that starts closest to 0
trial_num = 1
x_1 = x_vec[:, trial_num, 1]

# Select units
idx_far_pos = argmax(x_1)                   #largest initial value
idx_near_1 = partialsortperm(abs.(x_1), 1)  #smallest |initial value|

#Legend labels are built from the selected indices; the indices depend on the random draw of x_1_mean,
#so they must not be hard-coded
label_far_pos = latexstring("\$\\mathrm{unit \\ ", idx_far_pos, "}\$")
label_near_1 = latexstring("\$\\mathrm{unit \\ ", idx_near_1, "}\$")

# Plot setup
plt_single_units = plot(title = "",
     xlabel = L"$t$",
     ylabel = L"$x_{1,t}$",
     #legend = :topright,
     #framestyle = :box,
     grid = false,
     #dpi = 300,
     background_color = :white,
     foreground_color_legend = nothing,
     legend_background_color = :transparent,
)

plot!(plt_single_units, x_vec[idx_far_pos, trial_num, :], lw = 2.2, color = cgrad(:BuPu_5, 7)[6], alpha = 0.8, label = label_far_pos)
plot!(plt_single_units, x_vec[idx_near_1, trial_num, :], lw = 2.2, color = cgrad(:BuPu_5, 7)[6], alpha = 0.6, linestyle = :dash, label = label_near_1)

plt_single_units = plot!(plt_single_units, grid=false,legend_background_color=:transparent, legend_foreground_color=:transparent, xtickfontsize=12, ytickfontsize=12, xguidefontsize=18, yguidefontsize=18, thickness_scaling=2.2, size = (1300, 600))

cd(@__DIR__)#to go back to the directory of the script
full_folder_save_plots = "../Plots/NSC_Neural_Populations"
mkpath(full_folder_save_plots) #savefig does not create missing directories

name_and_path_plot = full_folder_save_plots * "/single_units_x.pdf"
savefig(plt_single_units, name_and_path_plot)

#structure of L_t

# Select timepoints: first control step and (rounded) middle control step
t1 = 1
t2 = round(Int, size(L_opt, 3) / 2)

L1 = L_opt[:, :, t1]
L2 = L_opt[:, :, t2]

# Common symmetric color scale (only used by the commented-out `clim` options)
vmax = max(maximum(abs, L1), maximum(abs, L2))

# Plot setup
plt1 = heatmap(L1,
    title = L"$L_{t=1}$",
    #clim = (-vmax, vmax),
    #colorbar_ticks = ([-2, 0, 2], ["−2", "0", "2"]),
    color = :balance,
    xlabel = L"$\mathrm{col}$",
    ylabel = L"$\mathrm{row}$",
    framestyle = :box,
    yflip = false,
    #legend = false,
    #colorbar = false,
    xticks = [],
    yticks = []
    )

plt1 = plot(plt1, grid=false,legend_background_color=:transparent, legend_foreground_color=:transparent, xtickfontsize=10, ytickfontsize=10, xguidefontsize=18, yguidefontsize=18, thickness_scaling=2.0, size = (750, 600))

cd(@__DIR__)#to go back to the directory of the script
full_folder_save_plots = "../Plots/NSC_Neural_Populations"
mkpath(full_folder_save_plots) #savefig does not create missing directories

name_and_path_plot = full_folder_save_plots * "/L_t_structure_start_trial.svg"
savefig(plt1, name_and_path_plot)

# Plot setup: heatmap of the optimized gains at the middle control step
plt2 = heatmap(L2,
    title = L"$L_{t=T/2}$",
    #clim = (-vmax, vmax),
    #colorbar_ticks = ([-2, 0, 2], ["−2", "0", "2"]),
    color = :balance,
    xlabel = L"$\mathrm{col}$",
    ylabel = L"$\mathrm{row}$",
    framestyle = :box,
    yflip = false,
    #legend = false,
    #colorbar = false,
    xticks = [],
    yticks = [],
    )

plt2 = plot(plt2, grid=false,legend_background_color=:transparent, legend_foreground_color=:transparent, xtickfontsize=10, ytickfontsize=10, xguidefontsize=18, yguidefontsize=18, thickness_scaling=2.0, size = (750, 600))

cd(@__DIR__)#to go back to the directory of the script
full_folder_save_plots = "../Plots/NSC_Neural_Populations"
mkpath(full_folder_save_plots) #savefig does not create missing directories

name_and_path_plot = full_folder_save_plots * "/L_t_structure_middle_trial.svg"
savefig(plt2, name_and_path_plot)
