#PACKAGES 

using Optim, Plots, DelimitedFiles, LinearAlgebra, Random, StatsBase, FiniteDifferences, LaTeXStrings , EasyFit, Printf, FFTW, Pkg, Noise, Clustering, Dierckx, BSplineKit, MultivariateStats, Flux, Combinatorics, Bigsimr, DataFrames, JLD, Base.Threads

#FUNCTIONS 

#(A): Functions to compute expected cost ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

#Compute the expected cost using Todorov's algorithm. Note that this should be correct only for Σ_1_z = 0 [to be checked]
#NOTE that we consider c=d=1 for the sake of simplicity for all the codes [only one scaling matrix C and D!] 
"""
    expected_cost_using_Todorov(L, K, A, B, H, C, D, T, dimension_of_state,
                                initial_x_state_mean, Σ_1_x, Q_matrix, R_matrix,
                                Ω_ξ, Ω_ω, Ω_η)

Return the expected cost computed by a backward recursion over the matrices
`Sx`, `Se` and a scalar offset `s`, for given control gains `L` and filter gains `K`.

# Arguments
- `L`: control gains, read as `L[:, :, t]` for `t = 1:T-1`. The sign is flipped
  internally (`L_tod_convention = -L`) before entering the recursion.
- `K`: filter gains, read as `K[:, :, t]` for `t = 1:T-1`.
- `A`, `B`, `H`, `D`: system matrices used in the recursion.
- `T`: number of time steps; `Q_matrix[:, :, T]` is the terminal condition for `Sx`.
- `dimension_of_state`: size of the square matrices `Sx` and `Se`.
- `initial_x_state_mean`, `Σ_1_x`: mean and covariance of the initial state.
- `Q_matrix`: state cost matrices, read as `Q_matrix[:, :, t]`.
- `Ω_ξ`, `Ω_ω`, `Ω_η`: noise covariance matrices entering the scalar offset.
- `C`, `R_matrix`: accepted for signature compatibility; not read by this function.

# Returns
`s + x̄' * Sx * x̄ + tr((Sx + Se) * Σ_1_x)`, where `x̄ = initial_x_state_mean` and
`Sx`, `Se`, `s` are the values at the end of the backward recursion. The final
expression contains no `Σ_1_z` term (the original note says it assumes `Σ_1_z = 0`).
"""
function expected_cost_using_Todorov(L, K, A, B, H, C, D, T, dimension_of_state, initial_x_state_mean, Σ_1_x, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

    # The control dimension is taken from `L` itself rather than from an
    # undeclared global `dimension_of_control`, so the function is self-contained.
    L_tod_convention = zeros(size(L, 1), dimension_of_state, (T-1))
    L_tod_convention[:,:,:] .= .-L[:,:,:]

    # Boundary conditions at the final time step
    Sx_tp1 = Q_matrix[:,:,T]
    Se_tp1 = zeros(dimension_of_state,dimension_of_state)
    s_small_t = 0

    for idx in (T-1):-1:1

        K_t = K[:,:,idx]
        L_t = L_tod_convention[:,:,idx]
        A_minus_KH = A .- K_t*H

        # Both updates are built from the values at idx+1 before either is overwritten
        cSx_tp1 = Q_matrix[:,:,idx] .+ transpose(A)*Sx_tp1*(A .- B*L_t) .+ transpose(D)*transpose(K_t)*Se_tp1*K_t*D
        cSe_tp1 = transpose(A)*Sx_tp1*B*L_t .+ transpose(A_minus_KH)*Se_tp1*A_minus_KH

        # The offset also uses the idx+1 values of Sx and Se
        s_small_t = tr(Sx_tp1*Ω_ξ .+ Se_tp1*(Ω_ξ .+ Ω_η .+ K_t*Ω_ω*transpose(K_t))) + s_small_t

        Sx_tp1 = cSx_tp1
        Se_tp1 = cSe_tp1

    end

    expected_cost = s_small_t + transpose(initial_x_state_mean)*Sx_tp1*initial_x_state_mean + tr((Sx_tp1 .+ Se_tp1)*Σ_1_x) #Note that we use Σ_1_z = 0

    return expected_cost

end

#Use this function to compute the expected cost with moments propagation
"""
    expected_cost_raw_moments_propagation(T, dimension_of_state, dimension_of_control,
        dimension_of_observation, K_matrix, L_matrix, A, B, H, C, D, x_1_mean, z_1_mean,
        Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

Propagate the raw second moments `S_xx`, `S_zz`, `S_xz`, `S_zx` forward from `t = 1`
to `t = T` for fixed gains `K_matrix[:, :, t]` and `L_matrix[:, :, t]` (`t = 1:T-1`),
and accumulate the expected cost

    sum over t = 1:T-1 of tr(Q_t * S_xx_t) + tr(L_t' * R_t * L_t * S_zz_t), plus tr(Q_T * S_xx_T).

The initial moments are `Σ_1_x + x̄ x̄'`, `Σ_1_z + z̄ z̄'` and `x̄ z̄'`, with
`x̄ = x_1_mean` and `z̄ = z_1_mean`. `dimension_of_control` and
`dimension_of_observation` are accepted but not read.
"""
function expected_cost_raw_moments_propagation(T, dimension_of_state, dimension_of_control, dimension_of_observation, K_matrix, L_matrix, A, B, H, C, D, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

    n = dimension_of_state

    # Moment history, one n-by-n slice per time step
    Sxx = zeros(n, n, T)
    Szz = zeros(n, n, T)
    Sxz = zeros(n, n, T)
    Szx = zeros(n, n, T)

    # Moments at the first time step (the cross moment contains only the product of the means)
    Sxx[:, :, 1] .= Σ_1_x[:, :] .+ x_1_mean*transpose(x_1_mean)
    Szz[:, :, 1] .= Σ_1_z[:, :] .+ z_1_mean*transpose(z_1_mean)
    Sxz[:, :, 1] .= x_1_mean*transpose(z_1_mean)
    Szx[:, :, 1] .= transpose(Sxz[:, :, 1])

    At = transpose(A)

    for t in 2:T
        prev = t - 1

        K_prev = K_matrix[:, :, prev]
        L_prev = L_matrix[:, :, prev]
        xx = Sxx[:, :, prev]
        zz = Szz[:, :, prev]
        xz = Sxz[:, :, prev]
        zx = Szx[:, :, prev]

        # Products that recur in the three updates below
        BL = B*L_prev
        CL = C*L_prev
        KH = K_prev*H
        KD = K_prev*D
        M = A .+ BL .- KH

        Sxx[:, :, t] .= A*xx*At .+ A*xz*transpose(BL) .+ BL*zx*At .+ BL*zz*transpose(BL) .+ CL*zz*transpose(CL) .+ Ω_ξ
        Szx[:, :, t] .= KH*xx*At .+ KH*xz*transpose(BL) .+ M*zx*At .+ M*zz*transpose(BL)
        Sxz[:, :, t] .= transpose(Szx[:, :, t])
        Szz[:, :, t] .= KH*xx*transpose(KH) .+ KD*xx*transpose(KD) .+ KH*xz*transpose(M) .+ M*zx*transpose(KH) .+ M*zz*transpose(M) .+ K_prev*Ω_ω*transpose(K_prev) .+ Ω_η
    end

    # Running cost over t = 1:T-1
    total = 0
    for t in 1:T-1
        L_t = L_matrix[:, :, t]
        total = total + tr(Q_matrix[:, :, t]*Sxx[:, :, t]) + tr(transpose(L_t)*R_matrix[:, :, t]*L_t*Szz[:, :, t])
    end

    # Terminal step: state cost only
    total = total + tr(Q_matrix[:, :, T]*Sxx[:, :, T])

    return total

end

#Compute expected cost using new trace formula (formula proved by induction); 
"""
    expected_cost_using_trace_formula_induction(L_matrix, K_matrix, A, B, H, C, D, T,
        dimension_of_state, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix,
        Ω_ξ, Ω_ω, Ω_η)

Return the expected cost as

    tr(λ_1 * S_xx + ω_1 * S_zz + ν_1 * S_xz) + d_1

where `S_xx`, `S_zz`, `S_xz` are the raw second moments at the first time step and
the multipliers `λ_t`, `ω_t`, `ν_t`, `d_t` are obtained by a backward recursion from
the boundary values `λ_T = Q_matrix[:, :, T]`, `ω_T = ν_T = 0`, `d_T = 0`, at fixed
gains `K_matrix[:, :, t]` and `L_matrix[:, :, t]` (`t = 1:T-1`).
"""
function expected_cost_using_trace_formula_induction(L_matrix, K_matrix, A, B, H, C, D, T, dimension_of_state, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

    n = dimension_of_state

    # Raw second moments at the first time step
    Sxx_1 = zeros(n, n)
    Szz_1 = zeros(n, n)
    Sxz_1 = zeros(n, n)
    Sxx_1 .= Σ_1_x[:, :] .+ x_1_mean*transpose(x_1_mean)
    Szz_1 .= Σ_1_z[:, :] .+ z_1_mean*transpose(z_1_mean)
    Sxz_1 .= x_1_mean*transpose(z_1_mean)

    # Multipliers for every time step; slices at T other than lam stay at zero
    lam = zeros(n, n, T)
    om = zeros(n, n, T)
    nu = zeros(n, n, T)
    offset = zeros(T)
    lam[:, :, T] .= Q_matrix[:, :, T]

    # Backward sweep at fixed K and L
    for t in (T-1):-1:1
        K_t = K_matrix[:, :, t]
        L_t = L_matrix[:, :, t]
        lam_next = lam[:, :, t+1]
        om_next = om[:, :, t+1]
        nu_next = nu[:, :, t+1]

        KH = K_t*H
        KD = K_t*D
        BL = B*L_t
        CL = C*L_t
        M = A .+ BL .- KH

        lam[:, :, t] .= Q_matrix[:, :, t] .+ transpose(A)*lam_next*A .+ transpose(KH)*om_next*KH .+ transpose(KH)*nu_next*A .+ transpose(KD)*om_next*KD
        om[:, :, t] .= transpose(L_t)*R_matrix[:, :, t]*L_t .+ transpose(BL)*lam_next*BL .+ transpose(CL)*lam_next*CL .+ transpose(M)*om_next*M .+ transpose(M)*nu_next*BL
        nu[:, :, t] .= transpose(BL)*(lam_next .+ transpose(lam_next))*A .+ transpose(M)*(om_next .+ transpose(om_next))*KH .+ transpose(M)*nu_next*A .+ transpose(BL)*transpose(nu_next)*KH
        offset[t] = offset[t+1] + tr(lam_next*Ω_ξ .+ om_next*Ω_η .+ om_next*K_t*Ω_ω*transpose(K_t))
    end

    # Expected cost evaluated at the first time step
    return tr(lam[:, :, 1]*Sxx_1 .+ om[:, :, 1]*Szz_1 .+ nu[:, :, 1]*Sxz_1) + offset[1]

end

#(B): Functions for optimization ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

#(B.1): classic K,L approach (Estimation-Control approach) ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

#(B.1.2): Todorov's optimization algorithm, for the multidimensional case (working for multiplicative and additive, including internal, noise) ----------------------------------------------------------------------------------------------------------------------------------------------------------
"""
    Todorov_optimization_multidimensional_case(A, B, H, C, D, T, dimension_of_state,
        dimension_of_control, dimension_of_observation, initial_x_state_mean, Σ_1_x, Σ_1_z,
        Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η, N_iter = 500)

Alternate `N_iter` times between a backward pass that recomputes the control gains
`L[:, :, t]` for the current filter gains, and a forward pass that recomputes the
filter gains `K[:, :, t]` for the current control gains.

Both gain arrays start at zero and have `T-1` slices. The forward pass only updates
`K[:, :, t]` for `t = 2:T-1`, so `K[:, :, 1]` is returned as zero.

# Returns
The tuple `(-L, K)`: the control gains are returned with their sign flipped.
"""
function Todorov_optimization_multidimensional_case(A, B, H, C, D, T, dimension_of_state, dimension_of_control, dimension_of_observation, initial_x_state_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η, N_iter = 500)

    n = dimension_of_state

    # Starting values of the three forward moments, rebuilt from the arguments once
    Σe_init = zeros(n, n)
    Σxe_init = zeros(n, n)
    Σx_init = zeros(n, n)
    Σe_init .= Σ_1_x[:, :] .+ Σ_1_z[:, :]
    Σxe_init .= Σ_1_z
    Σx_init .= Σ_1_z .+ initial_x_state_mean*transpose(initial_x_state_mean)

    L = zeros(dimension_of_control, n, (T-1))
    K = zeros(n, dimension_of_observation, (T-1))

    # One forward step of the moments at time index `step`, using the gains stored
    # in K and L. All three outputs are built from the incoming (old) moments.
    function advance_moments(step, Σe_in, Σx_in, Σxe_in)
        gain_KH = K[:, :, step]*H
        closed_loop = A .- B*L[:, :, step]
        noise_CL = C*L[:, :, step]
        Σe_out = Ω_ξ .+ Ω_η .+ (A .- gain_KH)*Σe_in*transpose(A) .+ noise_CL*Σx_in*transpose(noise_CL)
        Σx_out = Ω_η .+ gain_KH*Σe_in*transpose(A) .+ closed_loop*Σx_in*transpose(closed_loop) .+ closed_loop*Σxe_in*transpose(gain_KH) .+ gain_KH*transpose(Σxe_in)*transpose(closed_loop)
        Σxe_out = closed_loop*Σxe_in*transpose(A .- gain_KH) .- Ω_η
        return Σe_out, Σx_out, Σxe_out
    end

    for sweep in 1:N_iter

        # Controller: backward in time from the terminal condition
        Sx = Q_matrix[:, :, T]
        Se = zeros(n, n)

        for idx in (T-1):-1:1
            K_now = K[:, :, idx]
            L_now = inv(R_matrix[:, :, idx] .+ transpose(B)*Sx*B .+ transpose(C)*(Sx .+ Se)*C)*transpose(B)*Sx*A
            L[:, :, idx] = L_now

            A_KH = A .- K_now*H
            KD = K_now*D
            # Both updates use the Sx and Se of step idx+1
            Sx_new = Q_matrix[:, :, idx] .+ transpose(A)*Sx*(A .- B*L_now) .+ transpose(KD)*Se*KD
            Se_new = transpose(A)*Sx*B*L_now .+ transpose(A_KH)*Se*A_KH
            Sx, Se = Sx_new, Se_new
        end

        # Filter: forward in time. The first step keeps K[:, :, 1] as it is.
        Σe, Σx, Σxe = advance_moments(1, Σe_init, Σx_init, Σxe_init)

        for idx in 2:(T-1)
            K[:, :, idx] = A*Σe*transpose(H)*inv(H*Σe*transpose(H) .+ Ω_ω .+ D*(Σe .+ Σx .+ Σxe .+ transpose(Σxe))*transpose(D))
            Σe, Σx, Σxe = advance_moments(idx, Σe, Σx, Σxe)
        end

    end

    return .-L,K

end

#(B.1.3): Numerical GD procedure (working for multiplicative and additive, including internal noise) ----------------------------------------------------------------------------------------------------------------------------------------------------------
# Numerical GD algorithm - Estimation-Control optimization

#function for GD optimization - OPTIMAL ESTIMATOR
"""
    expected_cost_using_mom_prop_for_GD_ESTIMATION_optimization_raw_moments(x, L_matrix,
        A, B, H, C, D, T, dimension_of_state, dimension_of_control,
        dimension_of_observation, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix,
        Ω_ξ, Ω_ω, Ω_η)

Objective for numerical optimization of the filter gains at fixed control gains.

The first `dimension_of_state * dimension_of_observation * (T - 1)` entries of `x` are
reshaped into `K_matrix[:, :, t]`, `t = 1:T-1`. The raw second moments are propagated
forward and the cost `tr(Q_t * S_xx_t) + tr(L_t' * R_t * L_t * S_zz_t)` is summed over
`t = 1:T-1`, plus the terminal term `tr(Q_T * S_xx_T)`.

`R_matrix` is only read for `t = 1:T-1`, so an `R_matrix` with `T-1` slices is
accepted; one with `T` slices gives the same value, because there is no control gain
at step `T`. The moments are carried as plain matrices rather than written into
preallocated history arrays.
"""
function expected_cost_using_mom_prop_for_GD_ESTIMATION_optimization_raw_moments(x, L_matrix, A, B, H, C, D, T, dimension_of_state, dimension_of_control, dimension_of_observation, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

    K_matrix = reshape(x[1:dimension_of_state*dimension_of_observation*(T - 1)], dimension_of_state, dimension_of_observation, T - 1)

    # Raw second moments at the first time step
    # (the cross moment contains only the product of the means)
    S_xx = Σ_1_x[:, :] .+ x_1_mean*transpose(x_1_mean)
    S_zz = Σ_1_z[:, :] .+ z_1_mean*transpose(z_1_mean)
    S_xz = x_1_mean*transpose(z_1_mean)
    S_zx = copy(transpose(S_xz))

    L_first = L_matrix[:, :, 1]
    cost = tr(Q_matrix[:, :, 1]*S_xx) + tr(transpose(L_first)*R_matrix[:, :, 1]*L_first*S_zz)

    At = transpose(A)

    for t in 2:T
        K_prev = K_matrix[:, :, t-1]
        L_prev = L_matrix[:, :, t-1]

        BL = B*L_prev
        CL = C*L_prev
        KH = K_prev*H
        KD = K_prev*D
        M = A .+ BL .- KH

        # All updates use the moments of step t-1
        S_xx_next = A*S_xx*At .+ A*S_xz*transpose(BL) .+ BL*S_zx*At .+ BL*S_zz*transpose(BL) .+ CL*S_zz*transpose(CL) .+ Ω_ξ
        S_zx_next = KH*S_xx*At .+ KH*S_xz*transpose(BL) .+ M*S_zx*At .+ M*S_zz*transpose(BL)
        S_zz_next = KH*S_xx*transpose(KH) .+ KD*S_xx*transpose(KD) .+ KH*S_xz*transpose(M) .+ M*S_zx*transpose(KH) .+ M*S_zz*transpose(M) .+ K_prev*Ω_ω*transpose(K_prev) .+ Ω_η

        S_xx = S_xx_next
        S_zx = S_zx_next
        S_xz = copy(transpose(S_zx_next))
        S_zz = S_zz_next

        cost = cost + tr(Q_matrix[:, :, t]*S_xx)
        # There is no control gain at the terminal step, so R_matrix[:, :, T] is never needed
        if t < T
            L_t = L_matrix[:, :, t]
            cost = cost + tr(transpose(L_t)*R_matrix[:, :, t]*L_t*S_zz)
        end
    end

    return cost

end

#function for GD optimization - OPTIMAL LINEAR CONTROLLER
"""
    expected_cost_using_mom_prop_for_GD_CONTROL_optimization_raw_moments(x, K_matrix,
        A, B, H, C, D, T, dimension_of_state, dimension_of_control,
        dimension_of_observation, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix,
        Ω_ξ, Ω_ω, Ω_η)

Objective for numerical optimization of the control gains at fixed filter gains.

The first `dimension_of_control * dimension_of_state * (T - 1)` entries of `x` are
reshaped into `L_matrix[:, :, t]`, `t = 1:T-1`. The raw second moments are propagated
forward and the cost `tr(Q_t * S_xx_t) + tr(L_t' * R_t * L_t * S_zz_t)` is summed over
`t = 1:T-1`, plus the terminal term `tr(Q_T * S_xx_T)`.

`R_matrix` is only read for `t = 1:T-1`, so an `R_matrix` with `T-1` slices is
accepted; one with `T` slices gives the same value, because there is no control gain
at step `T`. The moments are carried as plain matrices rather than written into
preallocated history arrays.
"""
function expected_cost_using_mom_prop_for_GD_CONTROL_optimization_raw_moments(x, K_matrix, A, B, H, C, D, T, dimension_of_state, dimension_of_control, dimension_of_observation, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

    L_matrix = reshape(x[1:dimension_of_control*dimension_of_state*(T - 1)], dimension_of_control, dimension_of_state, T - 1)

    # Raw second moments at the first time step
    # (the cross moment contains only the product of the means)
    S_xx = Σ_1_x[:, :] .+ x_1_mean*transpose(x_1_mean)
    S_zz = Σ_1_z[:, :] .+ z_1_mean*transpose(z_1_mean)
    S_xz = x_1_mean*transpose(z_1_mean)
    S_zx = copy(transpose(S_xz))

    L_first = L_matrix[:, :, 1]
    cost = tr(Q_matrix[:, :, 1]*S_xx) + tr(transpose(L_first)*R_matrix[:, :, 1]*L_first*S_zz)

    At = transpose(A)

    for t in 2:T
        K_prev = K_matrix[:, :, t-1]
        L_prev = L_matrix[:, :, t-1]

        BL = B*L_prev
        CL = C*L_prev
        KH = K_prev*H
        KD = K_prev*D
        M = A .+ BL .- KH

        # All updates use the moments of step t-1
        S_xx_next = A*S_xx*At .+ A*S_xz*transpose(BL) .+ BL*S_zx*At .+ BL*S_zz*transpose(BL) .+ CL*S_zz*transpose(CL) .+ Ω_ξ
        S_zx_next = KH*S_xx*At .+ KH*S_xz*transpose(BL) .+ M*S_zx*At .+ M*S_zz*transpose(BL)
        S_zz_next = KH*S_xx*transpose(KH) .+ KD*S_xx*transpose(KD) .+ KH*S_xz*transpose(M) .+ M*S_zx*transpose(KH) .+ M*S_zz*transpose(M) .+ K_prev*Ω_ω*transpose(K_prev) .+ Ω_η

        S_xx = S_xx_next
        S_zx = S_zx_next
        S_xz = copy(transpose(S_zx_next))
        S_zz = S_zz_next

        cost = cost + tr(Q_matrix[:, :, t]*S_xx)
        # There is no control gain at the terminal step, so R_matrix[:, :, T] is never needed
        if t < T
            L_t = L_matrix[:, :, t]
            cost = cost + tr(transpose(L_t)*R_matrix[:, :, t]*L_t*S_zz)
        end
    end

    return cost

end

#function for GD optimization - OPTIMAL ESTIMATOR AND LINEAR CONTROLLER [joint optimization]
"""
    expected_cost_using_mom_prop_for_GD_WHOLE_optimization_raw_moments(x, A, B, H, C, D,
        T, dimension_of_state, dimension_of_control, dimension_of_observation, x_1_mean,
        z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

Objective for joint numerical optimization of control and filter gains.

`x` is split in two consecutive parts: the first
`dimension_of_control * dimension_of_state * (T - 1)` entries are reshaped into
`L_matrix[:, :, t]`, and the following
`dimension_of_state * dimension_of_observation * (T - 1)` entries into
`K_matrix[:, :, t]`, `t = 1:T-1`. The raw second moments are propagated forward and
the cost `tr(Q_t * S_xx_t) + tr(L_t' * R_t * L_t * S_zz_t)` is summed over `t = 1:T-1`,
plus the terminal term `tr(Q_T * S_xx_T)`.

`R_matrix` is only read for `t = 1:T-1`, so an `R_matrix` with `T-1` slices is
accepted; one with `T` slices gives the same value, because there is no control gain
at step `T`. The moments are carried as plain matrices rather than written into
preallocated history arrays.
"""
function expected_cost_using_mom_prop_for_GD_WHOLE_optimization_raw_moments(x, A, B, H, C, D, T, dimension_of_state, dimension_of_control, dimension_of_observation, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

    n_L = dimension_of_control*dimension_of_state*(T - 1)
    n_K = dimension_of_state*dimension_of_observation*(T - 1)
    L_matrix = reshape(x[1:n_L], dimension_of_control, dimension_of_state, T - 1)
    K_matrix = reshape(x[n_L+1:n_L+n_K], dimension_of_state, dimension_of_observation, T - 1)

    # Raw second moments at the first time step
    # (the cross moment contains only the product of the means)
    S_xx = Σ_1_x[:, :] .+ x_1_mean*transpose(x_1_mean)
    S_zz = Σ_1_z[:, :] .+ z_1_mean*transpose(z_1_mean)
    S_xz = x_1_mean*transpose(z_1_mean)
    S_zx = copy(transpose(S_xz))

    L_first = L_matrix[:, :, 1]
    cost = tr(Q_matrix[:, :, 1]*S_xx) + tr(transpose(L_first)*R_matrix[:, :, 1]*L_first*S_zz)

    At = transpose(A)

    for t in 2:T
        K_prev = K_matrix[:, :, t-1]
        L_prev = L_matrix[:, :, t-1]

        BL = B*L_prev
        CL = C*L_prev
        KH = K_prev*H
        KD = K_prev*D
        M = A .+ BL .- KH

        # All updates use the moments of step t-1
        S_xx_next = A*S_xx*At .+ A*S_xz*transpose(BL) .+ BL*S_zx*At .+ BL*S_zz*transpose(BL) .+ CL*S_zz*transpose(CL) .+ Ω_ξ
        S_zx_next = KH*S_xx*At .+ KH*S_xz*transpose(BL) .+ M*S_zx*At .+ M*S_zz*transpose(BL)
        S_zz_next = KH*S_xx*transpose(KH) .+ KD*S_xx*transpose(KD) .+ KH*S_xz*transpose(M) .+ M*S_zx*transpose(KH) .+ M*S_zz*transpose(M) .+ K_prev*Ω_ω*transpose(K_prev) .+ Ω_η

        S_xx = S_xx_next
        S_zx = S_zx_next
        S_xz = copy(transpose(S_zx_next))
        S_zz = S_zz_next

        cost = cost + tr(Q_matrix[:, :, t]*S_xx)
        # There is no control gain at the terminal step, so R_matrix[:, :, T] is never needed
        if t < T
            L_t = L_matrix[:, :, t]
            cost = cost + tr(transpose(L_t)*R_matrix[:, :, t]*L_t*S_zz)
        end
    end

    return cost

end

# Example of usage (for Control optimization)
# #Objective function -  NOTE WE ONLY OPTIMIZE CONTROL AT FIXED FILTERS
# #Numerical GD algorithm - optimization params 
# algorithm_GD = GradientDescent()
# iterations_GD = 5000
# # Specify options for the optimization algorithm
# options_GD = Optim.Options(
#     # Step size for gradient descent
#     iterations = iterations_GD,  # Number of iterations
#     store_trace = false   # Show optimization trace
# )
# cost_function_optimization_GD(x) = expected_cost_using_mom_prop_for_GD_CONTROL_optimization_raw_moments(x, K, A, B, H, C, D, T, dimension_of_state, dimension_of_control, dimension_of_observation, x_state_mean_0, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)   
# x_0 = zeros(dimension_of_control*dimension_of_state*(T - 1))
# x_0[1:dimension_of_control*dimension_of_state*(T - 1)] = L_TOD[:] #initial condition for the control gains

# result_GD = optimize(cost_function_optimization_GD, x_0, algorithm_GD, options_GD)

# x_opt_GD = result_GD.minimizer
# expected_cost_GD_optim = result_GD.minimum

# L_GD = reshape(x_opt_GD[1:dimension_of_control*dimension_of_state*(T - 1)], dimension_of_control, dimension_of_state, T - 1)

#(B.1.4): Lagrange multiplier solution --  New analytical algorithm for ESTIMATION-CONTROL APPROACH -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------

#One iteration for estimation optimization 
"""
    Estimation_Optimization_one_iteration_with_Lagrange_Multipliers(dimension_of_state,
        dimension_of_control, dimension_of_observation, K_initial, L_initial, A, B, H, C, D,
        Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

Perform one update sweep of the filter gains at fixed control gains.

The multipliers `λ_t`, `ω_t`, `ν_t` are first computed backwards in time from the
initial gains (`λ_T = Q_matrix[:, :, T]`, `ω_T = ν_T = 0`). Then, going forward for
`t = 1:T-2`, each gain is set to `K_t = -pinv(F) * V * pinv(G)`, and the raw second
moments are advanced with the new `K_t` before moving to the next step. The
pseudoinverse is used so that singular `F` or `G` do not raise; it equals the inverse
whenever the inverse exists.

The horizon is inferred as `T = size(K_initial, 3) + 1`, so the function does not
depend on a global `T`. `K_initial` and `L_initial` are expected to hold `T-1` slices.
`K_matrix[:, :, T-1]` is not updated by the sweep and keeps its initial value.

# Returns
A new array with the updated filter gains; `K_initial` is not modified.
"""
function Estimation_Optimization_one_iteration_with_Lagrange_Multipliers(dimension_of_state, dimension_of_control, dimension_of_observation, K_initial, L_initial, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

    # Horizon taken from the gain array (T-1 slices) instead of a global variable
    T = size(K_initial, 3) + 1
    n = dimension_of_state

    # Multipliers of the trace formula
    # <c_t> = Tr[lambda_t*S_t^{xx} + omega_t*S_t^{zz} + nu_t*S_t^{xz}] + d_t.
    # The scalar d_t does not enter the gain update, so it is not computed here.
    lambda_multiplier = zeros(n, n, T)
    omega_multiplier = zeros(n, n, T)
    nu_multiplier = zeros(n, n, T)

    # Working copies of the gains
    K_matrix = zeros(dimension_of_state, dimension_of_observation, T-1)
    L_matrix = zeros(dimension_of_control, dimension_of_state, T-1)
    K_matrix[:,:,:] .= K_initial[:,:,:]
    L_matrix[:,:,:] .= L_initial[:,:,:]

    # Raw second moments at the first time step
    # (the cross moment contains only the product of the means)
    S_xx = zeros(n, n)
    S_zz = zeros(n, n)
    S_xz = zeros(n, n)
    S_zx = zeros(n, n)
    S_xx .= Σ_1_x[:,:] .+ x_1_mean*transpose(x_1_mean)
    S_zz .= Σ_1_z[:,:] .+ z_1_mean*transpose(z_1_mean)
    S_xz .= x_1_mean*transpose(z_1_mean)
    S_zx .= transpose(S_xz)

    # Boundary condition at the last time step (omega and nu stay at zero)
    lambda_multiplier[:,:,T] .= Q_matrix[:,:,T]

    # Multipliers backwards in time, at the initial K and L
    for t in (T-1):-1:1
        K_t = K_matrix[:,:,t]
        L_t = L_matrix[:,:,t]
        lam_next = lambda_multiplier[:,:,t+1]
        om_next = omega_multiplier[:,:,t+1]
        nu_next = nu_multiplier[:,:,t+1]

        KH = K_t*H
        KD = K_t*D
        BL = B*L_t
        CL = C*L_t
        M = A .+ BL .- KH

        lambda_multiplier[:,:,t] .= Q_matrix[:,:,t] .+ transpose(A)*lam_next*A .+ transpose(KH)*om_next*KH .+ transpose(KH)*nu_next*A .+ transpose(KD)*om_next*KD
        omega_multiplier[:,:,t] .= transpose(L_t)*R_matrix[:,:,t]*L_t .+ transpose(BL)*lam_next*BL .+ transpose(CL)*lam_next*CL .+ transpose(M)*om_next*M .+ transpose(M)*nu_next*BL
        nu_multiplier[:,:,t] .= transpose(BL)*(lam_next .+ transpose(lam_next))*A .+ transpose(M)*(om_next .+ transpose(om_next))*KH .+ transpose(M)*nu_next*A .+ transpose(BL)*transpose(nu_next)*KH
    end

    At = transpose(A)

    # K optimization, forward in time
    for t in 1:(T-2)
        L_t = L_matrix[:,:,t]
        om_next = omega_multiplier[:,:,t+1]
        nu_next = nu_multiplier[:,:,t+1]
        BL = B*L_t
        CL = C*L_t

        # Auxiliary matrices of the stationarity condition for K_t
        F = om_next .+ transpose(om_next)
        V = F*(A .+ BL)*(S_zx .- S_zz)*transpose(H) .+ nu_next*(A*(S_xx .- S_xz) .+ BL*(S_zx .- S_zz))*transpose(H)
        G = H*(S_xx .- S_xz)*transpose(H) .+ H*(S_zz .- S_zx)*transpose(H) .+ D*S_xx*transpose(D) .+ Ω_ω

        K_matrix[:,:,t] .= .-(pinv(F)*V*pinv(G))

        # Advance the moments with the freshly updated K_t
        K_t = K_matrix[:,:,t]
        KH = K_t*H
        KD = K_t*D
        M = A .+ BL .- KH

        S_xx_next = A*S_xx*At .+ A*S_xz*transpose(BL) .+ BL*S_zx*At .+ BL*S_zz*transpose(BL) .+ CL*S_zz*transpose(CL) .+ Ω_ξ
        S_zx_next = KH*S_xx*At .+ KH*S_xz*transpose(BL) .+ M*S_zx*At .+ M*S_zz*transpose(BL)
        S_zz_next = KH*S_xx*transpose(KH) .+ KD*S_xx*transpose(KD) .+ KH*S_xz*transpose(M) .+ M*S_zx*transpose(KH) .+ M*S_zz*transpose(M) .+ K_t*Ω_ω*transpose(K_t) .+ Ω_η

        S_xx .= S_xx_next
        S_zx .= S_zx_next
        S_xz .= transpose(S_zx_next)
        S_zz .= S_zz_next
    end

    return K_matrix

end

#One iteration for control optimization
"""
    Control_Optimization_one_iteration_with_Lagrange_Multipliers(dimension_of_state,
        dimension_of_control, dimension_of_observation, K_initial, L_initial, A, B, H, C, D,
        Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

Perform one update sweep of the control gains at fixed filter gains.

The multipliers `λ_t`, `ω_t`, `ν_t` are first computed backwards in time from the
initial gains (`λ_T = Q_matrix[:, :, T]`, `ω_T = ν_T = 0`). Then, going forward for
`t = 1:T-1`, each gain is set to
`L_t = -pinv(J) * (N * S_xz + P * S_zz) * pinv(S_zz)`, and the raw second moments are
advanced with the new `L_t` before moving to the next step. The pseudoinverse is used
so that singular `J` or `S_zz` do not raise; it equals the inverse whenever the
inverse exists.

The horizon is inferred as `T = size(L_initial, 3) + 1`, so the function does not
depend on a global `T`. `K_initial` and `L_initial` are expected to hold `T-1` slices.

# Returns
A new array with the updated control gains; `L_initial` is not modified.
"""
function Control_Optimization_one_iteration_with_Lagrange_Multipliers(dimension_of_state, dimension_of_control, dimension_of_observation, K_initial, L_initial, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

    # Horizon taken from the gain array (T-1 slices) instead of a global variable
    T = size(L_initial, 3) + 1
    n = dimension_of_state

    # Multipliers of the trace formula
    # <c_t> = Tr[lambda_t*S_t^{xx} + omega_t*S_t^{zz} + nu_t*S_t^{xz}] + d_t.
    # The scalar d_t does not enter the gain update, so it is not computed here.
    lambda_multiplier = zeros(n, n, T)
    omega_multiplier = zeros(n, n, T)
    nu_multiplier = zeros(n, n, T)

    # Working copies of the gains
    K_matrix = zeros(dimension_of_state, dimension_of_observation, T-1)
    L_matrix = zeros(dimension_of_control, dimension_of_state, T-1)
    K_matrix[:,:,:] .= K_initial[:,:,:]
    L_matrix[:,:,:] .= L_initial[:,:,:]

    # Raw second moments at the first time step
    # (the cross moment contains only the product of the means)
    S_xx = zeros(n, n)
    S_zz = zeros(n, n)
    S_xz = zeros(n, n)
    S_zx = zeros(n, n)
    S_xx .= Σ_1_x[:,:] .+ x_1_mean*transpose(x_1_mean)
    S_zz .= Σ_1_z[:,:] .+ z_1_mean*transpose(z_1_mean)
    S_xz .= x_1_mean*transpose(z_1_mean)
    S_zx .= transpose(S_xz)

    # Boundary condition at the last time step (omega and nu stay at zero)
    lambda_multiplier[:,:,T] .= Q_matrix[:,:,T]

    # Multipliers backwards in time, at the initial K and L
    for t in (T-1):-1:1
        K_t = K_matrix[:,:,t]
        L_t = L_matrix[:,:,t]
        lam_next = lambda_multiplier[:,:,t+1]
        om_next = omega_multiplier[:,:,t+1]
        nu_next = nu_multiplier[:,:,t+1]

        KH = K_t*H
        KD = K_t*D
        BL = B*L_t
        CL = C*L_t
        M = A .+ BL .- KH

        lambda_multiplier[:,:,t] .= Q_matrix[:,:,t] .+ transpose(A)*lam_next*A .+ transpose(KH)*om_next*KH .+ transpose(KH)*nu_next*A .+ transpose(KD)*om_next*KD
        omega_multiplier[:,:,t] .= transpose(L_t)*R_matrix[:,:,t]*L_t .+ transpose(BL)*lam_next*BL .+ transpose(CL)*lam_next*CL .+ transpose(M)*om_next*M .+ transpose(M)*nu_next*BL
        nu_multiplier[:,:,t] .= transpose(BL)*(lam_next .+ transpose(lam_next))*A .+ transpose(M)*(om_next .+ transpose(om_next))*KH .+ transpose(M)*nu_next*A .+ transpose(BL)*transpose(nu_next)*KH
    end

    At = transpose(A)
    Bt = transpose(B)

    # L optimization, forward in time
    for t in 1:(T-1)
        K_t = K_matrix[:,:,t]
        KH = K_t*H
        KD = K_t*D
        lam_next = lambda_multiplier[:,:,t+1]
        om_next = omega_multiplier[:,:,t+1]
        nu_next = nu_multiplier[:,:,t+1]

        lam_sym = lam_next .+ transpose(lam_next)
        om_sym = om_next .+ transpose(om_next)
        nu_sym = nu_next .+ transpose(nu_next)

        # Auxiliary matrices of the stationarity condition for L_t
        J = 2 .*R_matrix[:,:,t] .+ Bt*lam_sym*B .+ transpose(C)*lam_sym*C .+ Bt*om_sym*B .+ Bt*nu_sym*B
        N = Bt*lam_sym*A .+ Bt*om_sym*KH .+ Bt*nu_next*A .+ Bt*transpose(nu_next)*KH
        P = Bt*om_sym*(A .- KH) .+ Bt*transpose(nu_next)*(A .- KH)

        L_matrix[:,:,t] .= .-(pinv(J)*(N*S_xz .+ P*S_zz)*pinv(S_zz))

        # Advance the moments with the freshly updated L_t
        L_t = L_matrix[:,:,t]
        BL = B*L_t
        CL = C*L_t
        M = A .+ BL .- KH

        S_xx_next = A*S_xx*At .+ A*S_xz*transpose(BL) .+ BL*S_zx*At .+ BL*S_zz*transpose(BL) .+ CL*S_zz*transpose(CL) .+ Ω_ξ
        S_zx_next = KH*S_xx*At .+ KH*S_xz*transpose(BL) .+ M*S_zx*At .+ M*S_zz*transpose(BL)
        S_zz_next = KH*S_xx*transpose(KH) .+ KD*S_xx*transpose(KD) .+ KH*S_xz*transpose(M) .+ M*S_zx*transpose(KH) .+ M*S_zz*transpose(M) .+ K_t*Ω_ω*transpose(K_t) .+ Ω_η

        S_xx .= S_xx_next
        S_zx .= S_zx_next
        S_xz .= transpose(S_zx_next)
        S_zz .= S_zz_next
    end

    return L_matrix

end

#Joint optimization of K and L using Lagrange multipliers.
#Starting from K_initial and L_initial, each of the N_iter rounds calls first the one-iteration estimation optimization (updates K at the 
#current L) and then the one-iteration control optimization (updates L at the freshly updated K).
#Returns (K_matrix, L_matrix, cost_mom_prop), where cost_mom_prop has length N_iter+1: cost_mom_prop[1] is the expected cost at the
#initial condition and cost_mom_prop[i+1] is the expected cost (raw moments propagation) after the i-th round.
#NOTE: the time horizon T is not an argument of this function; it is read from the enclosing (global) scope.
function Optimal_Control_Estimation_with_Lagrange_Multipliers(N_iter, dimension_of_state, dimension_of_control, dimension_of_observation, K_initial, L_initial, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

    #working copies of the gains: K_initial and L_initial are never modified
    K_matrix = zeros(dimension_of_state, dimension_of_observation, T-1)
    L_matrix = zeros(dimension_of_control, dimension_of_state, T-1)

    #expected cost before the optimization (entry 1) and after each round (entries 2,...,N_iter+1)
    cost_mom_prop = zeros(N_iter+1)

    #initial condition for K and L [for the optimization procedure]
    K_matrix[:,:,:] .= K_initial[:,:,:]
    L_matrix[:,:,:] .= L_initial[:,:,:]

    cost_mom_prop[1] = expected_cost_raw_moments_propagation(T, dimension_of_state, dimension_of_control, dimension_of_observation, K_matrix, L_matrix, A, B, H, C, D, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

    for i in 1:N_iter

        #Estimation optimization [K is updated at fixed L]
        K_matrix[:,:,:] .= Estimation_Optimization_one_iteration_with_Lagrange_Multipliers(dimension_of_state, dimension_of_control, dimension_of_observation, K_matrix, L_matrix, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)
        #Control optimization [L is updated at fixed K, using the K just computed above]
        L_matrix[:,:,:] .= Control_Optimization_one_iteration_with_Lagrange_Multipliers(dimension_of_state, dimension_of_control, dimension_of_observation, K_matrix, L_matrix, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

        #compute expected cost to see whether it is decreasing using moment propagation 
        cost_mom_prop[i+1] = expected_cost_raw_moments_propagation(T, dimension_of_state, dimension_of_control, dimension_of_observation, K_matrix, L_matrix, A, B, H, C, D, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

    end
    
    return K_matrix, L_matrix, cost_mom_prop

end

#Optimal Estimator: optimizes the filter gains K_t at fixed control gains L (= L_initial), using Lagrange multipliers.
#Each of the N_iter iterations (i) recomputes the multipliers backwards in time at the current K and L, and (ii) sweeps forward 
#in time, updating K_t for t = 1,...,T-2 and propagating the non-central second moments with the updated gain. 
#K_{T-1} is never updated and keeps the value given in K_initial.
#Returns (K_matrix, final_cost, cost_formula_trace_induction, cost_mom_prop). The two cost arrays have size (N_iter, T-1): entry [iter, t]
#is the expected cost recorded right after the update of K_t (t <= T-2), and entry [iter, T-1] is recorded at the end of the sweep.
#NOTE: the time horizon T is not an argument of this function; it is read from the enclosing (global) scope.
function Optimal_Estimation_with_Lagrange_Multipliers(N_iter, dimension_of_state, dimension_of_control, dimension_of_observation, K_initial, L_initial, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

    # Variables used to compute the expected accumulated cost from time t to the final time T <c_t> = Tr[lambda_multiplier[:,:,t]*S_t^{xx} + omega_multiplier[:,:,t]*S_t^{zz} + nu_multiplier[:,:,t]*S_t^{xz}] + d_multiplier[t]
    # where S_t^{xx} = E[x_t*x_t^T], S_t^{zz} = E[z_t*z_t^T], S_t^{xz} = E[x_t*z_t^T]
    lambda_multiplier = zeros(dimension_of_state, dimension_of_state, T)
    omega_multiplier = zeros(dimension_of_state, dimension_of_state, T)
    nu_multiplier = zeros(dimension_of_state, dimension_of_state, T)
    d_multiplier = zeros(T)

    # Non-central second order moments 
    S_t_xx = zeros(dimension_of_state, dimension_of_state)
    S_t_zz = zeros(dimension_of_state, dimension_of_state)
    S_t_xz = zeros(dimension_of_state, dimension_of_state)
    S_t_zx = zeros(dimension_of_state, dimension_of_state)
    #auxiliary variables for the propagation of the non-central moments
    S_t_xx_old = zeros(dimension_of_state, dimension_of_state)
    S_t_zz_old = zeros(dimension_of_state, dimension_of_state)
    S_t_xz_old = zeros(dimension_of_state, dimension_of_state)
    S_t_zx_old = zeros(dimension_of_state, dimension_of_state)

    #working copies of the gains: K_initial and L_initial are never modified
    K_matrix = zeros(dimension_of_state, dimension_of_observation, T-1)
    L_matrix = zeros(dimension_of_control, dimension_of_state, T-1)

    #initial condition for K and L [for the optimization procedure]
    K_matrix[:,:,:] .= K_initial[:,:,:]
    L_matrix[:,:,:] .= L_initial[:,:,:]

    #cost during optimization procedure 
    cost_formula_trace_induction = zeros(N_iter, T-1)
    cost_mom_prop = zeros(N_iter, T-1)

    #optimization, N_iter is the number of iterations of the whole "coordinate descent" optimization
    #NOTE: We use the pseudoinverse to avoid issues with singular matrices (e.g., from initial conditions). 
    #It matches the true inverse when it exists, preserving the analytical formula.
    for iter in 1:N_iter

        # Initial conditions for non-central moments
        S_t_xx_old[:,:] .= Σ_1_x[:,:] .+ x_1_mean*transpose(x_1_mean)
        S_t_zz_old[:,:] .= Σ_1_z[:,:] .+ z_1_mean*transpose(z_1_mean)
        S_t_xz_old[:,:] .= x_1_mean*transpose(z_1_mean) #initial values of state and state estimate are assumed to be uncorrelated 
        S_t_zx_old[:,:] .= transpose(S_t_xz_old[:,:])

        #multipliers at the last time step [boundary conditions]
        #omega and nu are (dimension_of_state x dimension_of_state) slices and are zeroed in place. Broadcasting a block of size
        #(dimension_of_state x dimension_of_observation) into them throws a DimensionMismatch whenever dimension_of_observation 
        #is neither dimension_of_state nor 1.
        lambda_multiplier[:,:,T] .= Q_matrix[:,:,T]
        omega_multiplier[:,:,T] .= 0
        nu_multiplier[:,:,T] .= 0
        d_multiplier[T] = 0

        #computing the multipliers backwards in time [at fixed K and L]
        for t in (T-1):-1:1
            lambda_multiplier[:,:,t] .= Q_matrix[:,:,t] .+ transpose(A)*lambda_multiplier[:,:,t+1]*A .+ transpose(H)*transpose(K_matrix[:,:,t])*omega_multiplier[:,:,t+1]*K_matrix[:,:,t]*H .+ transpose(H)*transpose(K_matrix[:,:,t])*nu_multiplier[:,:,t+1]*A .+ transpose(D)*transpose(K_matrix[:,:,t])*omega_multiplier[:,:,t+1]*K_matrix[:,:,t]*D
            omega_multiplier[:,:,t] .= transpose(L_matrix[:,:,t])*R_matrix[:,:,t]*L_matrix[:,:,t] .+ transpose(L_matrix[:,:,t])*transpose(B)*lambda_multiplier[:,:,t+1]*B*L_matrix[:,:,t] .+ transpose(L_matrix[:,:,t])*transpose(C)*lambda_multiplier[:,:,t+1]*C*L_matrix[:,:,t] .+ transpose(A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H)*omega_multiplier[:,:,t+1]*(A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H) .+ transpose(A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H)*nu_multiplier[:,:,t+1]*B*L_matrix[:,:,t]
            nu_multiplier[:,:,t] .= transpose(L_matrix[:,:,t])*transpose(B)*(lambda_multiplier[:,:,t+1] .+ transpose(lambda_multiplier[:,:,t+1]))*A .+ transpose(A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H)*(omega_multiplier[:,:,t+1] .+ transpose(omega_multiplier[:,:,t+1]))*K_matrix[:,:,t]*H .+ transpose(A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H)*nu_multiplier[:,:,t+1]*A .+ transpose(L_matrix[:,:,t])*transpose(B)*transpose(nu_multiplier[:,:,t+1])*K_matrix[:,:,t]*H
            d_multiplier[t] = d_multiplier[t+1] + tr(lambda_multiplier[:,:,t+1]*Ω_ξ .+ omega_multiplier[:,:,t+1]*Ω_η .+ omega_multiplier[:,:,t+1]*K_matrix[:,:,t]*Ω_ω*transpose(K_matrix[:,:,t]))
        end 

        #K optimization [forward in time: K_t is computed from the multipliers at t+1 and the moments at t]
        for t in 1:(T-2)

            #auxiliary variables 
            F = omega_multiplier[:,:,t+1] .+  transpose(omega_multiplier[:,:,t+1])
            V = (omega_multiplier[:,:,t+1] .+  transpose(omega_multiplier[:,:,t+1]))*(A .+ B*L_matrix[:,:,t])*(S_t_zx_old[:,:] .- S_t_zz_old[:,:])*transpose(H) .+ nu_multiplier[:,:,t+1]*(A*(S_t_xx_old[:,:] .- S_t_xz_old[:,:]) .+ B*L_matrix[:,:,t]*(S_t_zx_old[:,:] .- S_t_zz_old[:,:]))*transpose(H)
            G = H*(S_t_xx_old[:,:] .- S_t_xz_old[:,:])*transpose(H) .+ H*(S_t_zz_old[:,:] .- S_t_zx_old[:,:])*transpose(H) .+ D*S_t_xx_old[:,:]*transpose(D) .+ Ω_ω
            #optimal K_t
            K_matrix[:,:,t] .= .-(pinv(F)*V*pinv(G))

            #compute expected cost to see whether it is decreasing
            cost_formula_trace_induction[iter, t] = expected_cost_using_trace_formula_induction(L_matrix, K_matrix, A, B, H, C, D, T, dimension_of_state, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)
            cost_mom_prop[iter,t] = expected_cost_raw_moments_propagation(T, dimension_of_state, dimension_of_control, dimension_of_observation, K_matrix, L_matrix, A, B, H, C, D, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)
            
            #update moments [propagated from t to t+1 with the K_t just computed]
            S_t_xx[:,:] .= A*S_t_xx_old[:,:]*transpose(A) .+ A*S_t_xz_old[:,:]*transpose(B*L_matrix[:,:,t]) .+ B*L_matrix[:,:,t]*S_t_zx_old[:,:]*transpose(A) .+ B*L_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(B*L_matrix[:,:,t]) .+ C*L_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(C*L_matrix[:,:,t]) .+ Ω_ξ
            S_t_zx[:,:] .= K_matrix[:,:,t]*H*S_t_xx_old[:,:]*transpose(A) .+ K_matrix[:,:,t]*H*S_t_xz_old[:,:]*transpose(B*L_matrix[:,:,t]) .+ (A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H)*S_t_zx_old[:,:]*transpose(A) .+ (A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H)*S_t_zz_old[:,:]*transpose(B*L_matrix[:,:,t])
            S_t_xz[:,:] .= transpose(S_t_zx[:,:])
            S_t_zz[:,:] .= K_matrix[:,:,t]*H*S_t_xx_old[:,:]*transpose(K_matrix[:,:,t]*H) .+ K_matrix[:,:,t]*D*S_t_xx_old[:,:]*transpose(K_matrix[:,:,t]*D) .+ K_matrix[:,:,t]*H*S_t_xz_old[:,:]*transpose(A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H) .+ (A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H)*S_t_zx_old[:,:]*transpose(K_matrix[:,:,t]*H) .+ (A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H)*S_t_zz_old[:,:]*transpose(A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H) .+ K_matrix[:,:,t]*Ω_ω*transpose(K_matrix[:,:,t]) .+ Ω_η
            #update old moments for next optimization 
            S_t_xx_old[:,:] .= S_t_xx[:,:]
            S_t_xz_old[:,:] .= S_t_xz[:,:]
            S_t_zx_old[:,:] .= S_t_zx[:,:]
            S_t_zz_old[:,:] .= S_t_zz[:,:]

        end
        #last time step
        #K_matrix[:,:,T-1] .= zeros(dimension_of_state, dimension_of_observation) #Final filter gain is zero since at t = T - 1, there is no need to estimate future states—no control is applied at t = T
            
        #compute expected cost to see whether it is decreasing
        cost_formula_trace_induction[iter, T-1] = expected_cost_using_trace_formula_induction(L_matrix, K_matrix, A, B, H, C, D, T, dimension_of_state, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)
        cost_mom_prop[iter,T-1] = expected_cost_raw_moments_propagation(T, dimension_of_state, dimension_of_control, dimension_of_observation, K_matrix, L_matrix, A, B, H, C, D, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)
        
    end

    final_cost = expected_cost_using_trace_formula_induction(L_matrix, K_matrix, A, B, H, C, D, T, dimension_of_state, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

    return K_matrix, final_cost, cost_formula_trace_induction, cost_mom_prop

end

#Optimal Controller: optimizes the control gains L_t at fixed filter gains K (= K_initial), using Lagrange multipliers.
#Each of the N_iter iterations (i) recomputes the multipliers backwards in time at the current K and L, and (ii) sweeps forward 
#in time, updating L_t for t = 1,...,T-1 and propagating the non-central second moments with the updated gain.
#Returns (L_matrix, initial_cost_trace, initial_cost_mom_prop, final_cost_trace, cost_formula_trace_induction, cost_mom_prop). 
#The two cost arrays have size (N_iter, T-1): entry [iter, t] is the expected cost recorded right after the update of L_t.
#NOTE: the time horizon T is not an argument of this function; it is read from the enclosing (global) scope.
function Optimal_Control_with_Lagrange_Multipliers(N_iter, dimension_of_state, dimension_of_control, dimension_of_observation, K_initial, L_initial, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

    # Variables used to compute the expected accumulated cost from time t to the final time T <c_t> = Tr[lambda_multiplier[:,:,t]*S_t^{xx} + omega_multiplier[:,:,t]*S_t^{zz} + nu_multiplier[:,:,t]*S_t^{xz}] + d_multiplier[t]
    # where S_t^{xx} = E[x_t*x_t^T], S_t^{zz} = E[z_t*z_t^T], S_t^{xz} = E[x_t*z_t^T]
    lambda_multiplier = zeros(dimension_of_state, dimension_of_state, T)
    omega_multiplier = zeros(dimension_of_state, dimension_of_state, T)
    nu_multiplier = zeros(dimension_of_state, dimension_of_state, T)
    d_multiplier = zeros(T)

    # Non-central second order moments 
    S_t_xx = zeros(dimension_of_state, dimension_of_state)
    S_t_zz = zeros(dimension_of_state, dimension_of_state)
    S_t_xz = zeros(dimension_of_state, dimension_of_state)
    S_t_zx = zeros(dimension_of_state, dimension_of_state)
    #auxiliary variables for the propagation of the non-central moments
    S_t_xx_old = zeros(dimension_of_state, dimension_of_state)
    S_t_zz_old = zeros(dimension_of_state, dimension_of_state)
    S_t_xz_old = zeros(dimension_of_state, dimension_of_state)
    S_t_zx_old = zeros(dimension_of_state, dimension_of_state)

    #working copies of the gains: K_initial and L_initial are never modified
    K_matrix = zeros(dimension_of_state, dimension_of_observation, T-1)
    L_matrix = zeros(dimension_of_control, dimension_of_state, T-1)

    #initial condition for K and L [for the optimization procedure]
    K_matrix[:,:,:] .= K_initial[:,:,:]
    L_matrix[:,:,:] .= L_initial[:,:,:]

    #cost during optimization procedure 
    cost_formula_trace_induction = zeros(N_iter, T-1)
    cost_mom_prop = zeros(N_iter, T-1)

    #expected cost at the initial condition, used as the reference to see whether the cost is decreasing
    initial_cost_trace = expected_cost_using_trace_formula_induction(L_matrix, K_matrix, A, B, H, C, D, T, dimension_of_state, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)
    initial_cost_mom_prop = expected_cost_raw_moments_propagation(T, dimension_of_state, dimension_of_control, dimension_of_observation, K_matrix, L_matrix, A, B, H, C, D, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)
    
    #optimization, N_iter is the number of iterations of the whole "coordinate descent" optimization
    #NOTE: We use the pseudoinverse to avoid issues with singular matrices (e.g., from initial conditions). 
    #It matches the true inverse when it exists, preserving the analytical formula.
    for iter in 1:N_iter

        # Initial conditions for non-central moments
        S_t_xx_old[:,:] .= Σ_1_x[:,:] .+ x_1_mean*transpose(x_1_mean)
        S_t_zz_old[:,:] .= Σ_1_z[:,:] .+ z_1_mean*transpose(z_1_mean)
        S_t_xz_old[:,:] .= x_1_mean*transpose(z_1_mean) #initial values of state and state estimate are assumed to be uncorrelated 
        S_t_zx_old[:,:] .= transpose(S_t_xz_old[:,:])

        #multipliers at the last time step [boundary conditions]
        #omega and nu are (dimension_of_state x dimension_of_state) slices and are zeroed in place. Broadcasting a block of size
        #(dimension_of_state x dimension_of_observation) into them throws a DimensionMismatch whenever dimension_of_observation 
        #is neither dimension_of_state nor 1.
        lambda_multiplier[:,:,T] .= Q_matrix[:,:,T]
        omega_multiplier[:,:,T] .= 0
        nu_multiplier[:,:,T] .= 0
        d_multiplier[T] = 0

        #computing the multipliers backwards in time [at fixed K and L]
        for t in (T-1):-1:1
            lambda_multiplier[:,:,t] .= Q_matrix[:,:,t] .+ transpose(A)*lambda_multiplier[:,:,t+1]*A .+ transpose(H)*transpose(K_matrix[:,:,t])*omega_multiplier[:,:,t+1]*K_matrix[:,:,t]*H .+ transpose(H)*transpose(K_matrix[:,:,t])*nu_multiplier[:,:,t+1]*A .+ transpose(D)*transpose(K_matrix[:,:,t])*omega_multiplier[:,:,t+1]*K_matrix[:,:,t]*D
            omega_multiplier[:,:,t] .= transpose(L_matrix[:,:,t])*R_matrix[:,:,t]*L_matrix[:,:,t] .+ transpose(L_matrix[:,:,t])*transpose(B)*lambda_multiplier[:,:,t+1]*B*L_matrix[:,:,t] .+ transpose(L_matrix[:,:,t])*transpose(C)*lambda_multiplier[:,:,t+1]*C*L_matrix[:,:,t] .+ transpose(A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H)*omega_multiplier[:,:,t+1]*(A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H) .+ transpose(A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H)*nu_multiplier[:,:,t+1]*B*L_matrix[:,:,t]
            nu_multiplier[:,:,t] .= transpose(L_matrix[:,:,t])*transpose(B)*(lambda_multiplier[:,:,t+1] .+ transpose(lambda_multiplier[:,:,t+1]))*A .+ transpose(A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H)*(omega_multiplier[:,:,t+1] .+ transpose(omega_multiplier[:,:,t+1]))*K_matrix[:,:,t]*H .+ transpose(A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H)*nu_multiplier[:,:,t+1]*A .+ transpose(L_matrix[:,:,t])*transpose(B)*transpose(nu_multiplier[:,:,t+1])*K_matrix[:,:,t]*H
            d_multiplier[t] = d_multiplier[t+1] + tr(lambda_multiplier[:,:,t+1]*Ω_ξ .+ omega_multiplier[:,:,t+1]*Ω_η .+ omega_multiplier[:,:,t+1]*K_matrix[:,:,t]*Ω_ω*transpose(K_matrix[:,:,t]))
        end 

        #L optimization [forward in time: L_t is computed from the multipliers at t+1 and the moments at t]
        for t in 1:(T-1)

            #auxiliary variables
            J = 2 .*R_matrix[:,:,t] .+ transpose(B)*(lambda_multiplier[:,:,t+1] .+ transpose(lambda_multiplier[:,:,t+1]))*B .+ transpose(C)*(lambda_multiplier[:,:,t+1] .+ transpose(lambda_multiplier[:,:,t+1]))*C .+ transpose(B)*(omega_multiplier[:,:,t+1] .+ transpose(omega_multiplier[:,:,t+1]))*B .+ transpose(B)*(nu_multiplier[:,:,t+1] .+ transpose(nu_multiplier[:,:,t+1]))*B
            N = transpose(B)*(lambda_multiplier[:,:,t+1] .+ transpose(lambda_multiplier[:,:,t+1]))*A .+ transpose(B)*(omega_multiplier[:,:,t+1] .+ transpose(omega_multiplier[:,:,t+1]))*K_matrix[:,:,t]*H .+ transpose(B)*nu_multiplier[:,:,t+1]*A .+ transpose(B)*transpose(nu_multiplier[:,:,t+1])*K_matrix[:,:,t]*H
            P = transpose(B)*(omega_multiplier[:,:,t+1] .+ transpose(omega_multiplier[:,:,t+1]))*(A .- K_matrix[:,:,t]*H) .+ transpose(B)*transpose(nu_multiplier[:,:,t+1])*(A .- K_matrix[:,:,t]*H)
            #optimal L_t
            L_matrix[:,:,t] .= .-(pinv(J)*(N*S_t_xz_old[:,:] .+ P*S_t_zz_old[:,:])*pinv(S_t_zz_old[:,:]))
            #Alternative form, identical to the line above only when S_t_zz_old is invertible (kept for reference):
            #L_matrix[:,:,t] = .-pinv(J)*(N * S_t_xz_old[:,:]*pinv(S_t_zz_old[:,:]) .+ P)

            #compute expected cost to see whether it is decreasing
            cost_formula_trace_induction[iter, t] = expected_cost_using_trace_formula_induction(L_matrix, K_matrix, A, B, H, C, D, T, dimension_of_state, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)
            cost_mom_prop[iter,t] = expected_cost_raw_moments_propagation(T, dimension_of_state, dimension_of_control, dimension_of_observation, K_matrix, L_matrix, A, B, H, C, D, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)
            
            #update moments [propagated from t to t+1 with the L_t just computed]
            S_t_xx[:,:] .= A*S_t_xx_old[:,:]*transpose(A) .+ A*S_t_xz_old[:,:]*transpose(B*L_matrix[:,:,t]) .+ B*L_matrix[:,:,t]*S_t_zx_old[:,:]*transpose(A) .+ B*L_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(B*L_matrix[:,:,t]) .+ C*L_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(C*L_matrix[:,:,t]) .+ Ω_ξ
            S_t_zx[:,:] .= K_matrix[:,:,t]*H*S_t_xx_old[:,:]*transpose(A) .+ K_matrix[:,:,t]*H*S_t_xz_old[:,:]*transpose(B*L_matrix[:,:,t]) .+ (A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H)*S_t_zx_old[:,:]*transpose(A) .+ (A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H)*S_t_zz_old[:,:]*transpose(B*L_matrix[:,:,t])
            S_t_xz[:,:] .= transpose(S_t_zx[:,:])
            S_t_zz[:,:] .= K_matrix[:,:,t]*H*S_t_xx_old[:,:]*transpose(K_matrix[:,:,t]*H) .+ K_matrix[:,:,t]*D*S_t_xx_old[:,:]*transpose(K_matrix[:,:,t]*D) .+ K_matrix[:,:,t]*H*S_t_xz_old[:,:]*transpose(A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H) .+ (A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H)*S_t_zx_old[:,:]*transpose(K_matrix[:,:,t]*H) .+ (A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H)*S_t_zz_old[:,:]*transpose(A .+ B*L_matrix[:,:,t] .- K_matrix[:,:,t]*H) .+ K_matrix[:,:,t]*Ω_ω*transpose(K_matrix[:,:,t]) .+ Ω_η
            #update old moments for next optimization 
            S_t_xx_old[:,:] .= S_t_xx[:,:]
            S_t_xz_old[:,:] .= S_t_xz[:,:]
            S_t_zx_old[:,:] .= S_t_zx[:,:]
            S_t_zz_old[:,:] .= S_t_zz[:,:]

        end

    end

    final_cost_trace = expected_cost_using_trace_formula_induction(L_matrix, K_matrix, A, B, H, C, D, T, dimension_of_state, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

    return L_matrix, initial_cost_trace, initial_cost_mom_prop, final_cost_trace, cost_formula_trace_induction, cost_mom_prop

end

#(B.2): NEURAL SPACE CONTROL (NSC) ------------------------------------------------------------------------------------------------

# Note that for the 3D reaching task, we start optimizing from the second time step (t=2) 
# instead of t=1, in order to exactly match the initial conditions of the EC approaches 
# (Todorov and Damiani) and have a fair comparison. For the neural population steering task 
# we start the optimization at t=1.

#One iteration for optimization of K_t [weights of sensory feedback -- input to the latent dynamics] - optimization starts at t=2
#The multipliers are first computed backwards in time at the given K, M and L; K_t is then updated forward in time for t = 2,...,T-2, 
#while M and L are held fixed. K_1 and K_{T-1} keep the values given in K_initial.
#Returns the updated K_matrix (a separate array: K_initial is not modified).
#NOTE: the time horizon T is not an argument of this function; it is read from the enclosing (global) scope.
function Input_Weights_Optimization_one_iteration_with_Lagrange_Multipliers_EXTENDED_Latent_Space_Control(dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control, K_initial, M_initial, L_initial, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

    # Variables used to compute the expected accumulated cost from time t to the final time T <c_t> = Tr[lambda_multiplier[:,:,t]*S_t^{xx} + omega_multiplier[:,:,t]*S_t^{zz} + nu_multiplier[:,:,t]*S_t^{xz}] + d_multiplier[t]
    # where S_t^{xx} = E[x_t*x_t^T], S_t^{zz} = E[z_t*z_t^T], S_t^{xz} = E[x_t*z_t^T]
    lambda_multiplier = zeros(dimension_of_state, dimension_of_state, T)
    omega_multiplier = zeros(dimension_of_latent_space, dimension_of_latent_space, T)
    nu_multiplier = zeros(dimension_of_latent_space, dimension_of_state, T)
    d_multiplier = zeros(T)

    # Non-central second order moments 
    S_t_xx = zeros(dimension_of_state, dimension_of_state)
    S_t_zz = zeros(dimension_of_latent_space, dimension_of_latent_space)
    S_t_xz = zeros(dimension_of_state, dimension_of_latent_space)
    S_t_zx = zeros(dimension_of_latent_space, dimension_of_state)
    #auxiliary variables for the propagation of the non-central moments
    S_t_xx_old = zeros(dimension_of_state, dimension_of_state)
    S_t_zz_old = zeros(dimension_of_latent_space, dimension_of_latent_space)
    S_t_xz_old = zeros(dimension_of_state, dimension_of_latent_space)
    S_t_zx_old = zeros(dimension_of_latent_space, dimension_of_state)

    #working copies of the three sets of weights
    K_matrix = zeros(dimension_of_latent_space, dimension_of_observation, T-1)
    M_matrix = zeros(dimension_of_latent_space, dimension_of_latent_space, T-1)
    L_matrix = zeros(dimension_of_control, dimension_of_latent_space, T-1)

    #initial condition for K, M and L [for the optimization procedure]
    K_matrix[:,:,:] .= K_initial[:,:,:]
    M_matrix[:,:,:] .= M_initial[:,:,:]
    L_matrix[:,:,:] .= L_initial[:,:,:]

    #optimization [a single sweep: this function performs one iteration of the whole "coordinate descent" optimization]
    #NOTE: We use the pseudoinverse to avoid issues with singular matrices (e.g., from initial conditions). 
    #It matches the true inverse when it exists, preserving the analytical formula.

    # Initial conditions for non-central moments
    S_t_xx_old[:,:] .= Σ_1_x[:,:] .+ x_1_mean*transpose(x_1_mean)
    S_t_zz_old[:,:] .= Σ_1_z[:,:] .+ z_1_mean*transpose(z_1_mean)
    S_t_xz_old[:,:] .= x_1_mean*transpose(z_1_mean) #initial values of state and latent state are assumed to be uncorrelated 
    S_t_zx_old[:,:] .= transpose(S_t_xz_old[:,:])

    #multipliers at the last time step [boundary conditions]
    lambda_multiplier[:,:,T] .= Q_matrix[:,:,T]
    omega_multiplier[:,:,T] .= zeros(dimension_of_latent_space, dimension_of_latent_space)
    nu_multiplier[:,:,T] .= zeros(dimension_of_latent_space, dimension_of_state)
    d_multiplier[T] = 0

    #computing the multipliers backwards in time [at fixed K, M and L]
    #NOTE: d_multiplier is accumulated here but is not read by the K update below
    for t in (T-1):-1:1
        lambda_multiplier[:,:,t] .= Q_matrix[:,:,t] .+ transpose(A)*lambda_multiplier[:,:,t+1]*A .+ transpose(H)*transpose(K_matrix[:,:,t])*omega_multiplier[:,:,t+1]*K_matrix[:,:,t]*H .+ transpose(H)*transpose(K_matrix[:,:,t])*nu_multiplier[:,:,t+1]*A .+ transpose(D)*transpose(K_matrix[:,:,t])*omega_multiplier[:,:,t+1]*K_matrix[:,:,t]*D
        
        omega_multiplier[:,:,t] .= transpose(L_matrix[:,:,t])*R_matrix[:,:,t]*L_matrix[:,:,t] .+ transpose(L_matrix[:,:,t])*transpose(B)*lambda_multiplier[:,:,t+1]*B*L_matrix[:,:,t] .+ transpose(L_matrix[:,:,t])*transpose(C)*lambda_multiplier[:,:,t+1]*C*L_matrix[:,:,t] .+ transpose(M_matrix[:,:,t])*omega_multiplier[:,:,t+1]*M_matrix[:,:,t] .+ transpose(M_matrix[:,:,t])*nu_multiplier[:,:,t+1]*B*L_matrix[:,:,t]

        nu_multiplier[:,:,t] .= transpose(L_matrix[:,:,t])*transpose(B)*(lambda_multiplier[:,:,t+1] .+ transpose(lambda_multiplier[:,:,t+1]))*A .+ transpose(M_matrix[:,:,t])*(omega_multiplier[:,:,t+1] .+ transpose(omega_multiplier[:,:,t+1]))*K_matrix[:,:,t]*H .+ transpose(M_matrix[:,:,t])*nu_multiplier[:,:,t+1]*A .+ transpose(L_matrix[:,:,t])*transpose(B)*transpose(nu_multiplier[:,:,t+1])*K_matrix[:,:,t]*H
        d_multiplier[t] = d_multiplier[t+1] + tr(lambda_multiplier[:,:,t+1]*Ω_ξ) + tr(omega_multiplier[:,:,t+1]*Ω_η .+ omega_multiplier[:,:,t+1]*K_matrix[:,:,t]*Ω_ω*transpose(K_matrix[:,:,t]))
    end 

    #update moments: first time step, propagated with the given (non-optimized) K_1, M_1 and L_1
    t=1
    S_t_xx[:,:] .= A*S_t_xx_old[:,:]*transpose(A) .+ A*S_t_xz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(B) .+ B*L_matrix[:,:,t]*S_t_zx_old[:,:]*transpose(A) .+ B*L_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(B) .+ C*L_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(C) .+ Ω_ξ
    S_t_zx[:,:] .= K_matrix[:,:,t]*H*S_t_xx_old[:,:]*transpose(A) .+ K_matrix[:,:,t]*H*S_t_xz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(B) .+ M_matrix[:,:,t]*S_t_zx_old[:,:]*transpose(A) .+ M_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(B)
    S_t_xz[:,:] .= transpose(S_t_zx[:,:])
    S_t_zz[:,:] .= K_matrix[:,:,t]*H*S_t_xx_old[:,:]*transpose(K_matrix[:,:,t]*H) .+ K_matrix[:,:,t]*D*S_t_xx_old[:,:]*transpose(K_matrix[:,:,t]*D) .+ K_matrix[:,:,t]*H*S_t_xz_old[:,:]*transpose(M_matrix[:,:,t]) .+ M_matrix[:,:,t]*S_t_zx_old[:,:]*transpose(K_matrix[:,:,t]*H) .+ M_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(M_matrix[:,:,t]) .+ K_matrix[:,:,t]*Ω_ω*transpose(K_matrix[:,:,t]) .+ Ω_η
    #update old moments for next optimization 
    S_t_xx_old[:,:] .= S_t_xx[:,:]
    S_t_xz_old[:,:] .= S_t_xz[:,:]
    S_t_zx_old[:,:] .= S_t_zx[:,:]
    S_t_zz_old[:,:] .= S_t_zz[:,:]

    #K optimization [forward in time: K_t is computed from the multipliers at t+1 and the moments at t]
    for t in 2:(T-2)

        #auxiliary variables 
        F = omega_multiplier[:,:,t+1] .+  transpose(omega_multiplier[:,:,t+1])
        V = (omega_multiplier[:,:,t+1] .+  transpose(omega_multiplier[:,:,t+1]))*M_matrix[:,:,t]*S_t_zx_old[:,:]*transpose(H) .+ nu_multiplier[:,:,t+1]*(A*S_t_xx_old[:,:] .+ B*L_matrix[:,:,t]*S_t_zx_old[:,:])*transpose(H)
        G = H*S_t_xx_old[:,:]*transpose(H) .+ D*S_t_xx_old[:,:]*transpose(D) .+ Ω_ω
        
        #optimal K_t
        #Variant with the true inverse (kept for reference; fails when F or G is singular):
        #K_matrix[:,:,t] .= .-(inv(F)*V*inv(G))
        K_matrix[:,:,t] .= .-(pinv(F)*V*pinv(G))
        
        #Alternative way of writing the update (kept for reference):
        #K_matrix[:,:,t] .= .-(M_matrix[:,:,t]*S_t_zx_old[:,:]*transpose(H) .+ pinv(F)*nu_multiplier[:,:,t+1]*(A*S_t_xx_old[:,:] .+ B*L_matrix[:,:,t]*S_t_zx_old[:,:])*transpose(H))*pinv(G)
        
        #update moments [propagated from t to t+1 with the K_t just computed]
        S_t_xx[:,:] .= A*S_t_xx_old[:,:]*transpose(A) .+ A*S_t_xz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(B) .+ B*L_matrix[:,:,t]*S_t_zx_old[:,:]*transpose(A) .+ B*L_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(B) .+ C*L_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(C) .+ Ω_ξ
        S_t_zx[:,:] .= K_matrix[:,:,t]*H*S_t_xx_old[:,:]*transpose(A) .+ K_matrix[:,:,t]*H*S_t_xz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(B) .+ M_matrix[:,:,t]*S_t_zx_old[:,:]*transpose(A) .+ M_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(B)
        S_t_xz[:,:] .= transpose(S_t_zx[:,:])
        S_t_zz[:,:] .= K_matrix[:,:,t]*H*S_t_xx_old[:,:]*transpose(K_matrix[:,:,t]*H) .+ K_matrix[:,:,t]*D*S_t_xx_old[:,:]*transpose(K_matrix[:,:,t]*D) .+ K_matrix[:,:,t]*H*S_t_xz_old[:,:]*transpose(M_matrix[:,:,t]) .+ M_matrix[:,:,t]*S_t_zx_old[:,:]*transpose(K_matrix[:,:,t]*H) .+ M_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(M_matrix[:,:,t]) .+ K_matrix[:,:,t]*Ω_ω*transpose(K_matrix[:,:,t]) .+ Ω_η
        #update old moments for next optimization 
        S_t_xx_old[:,:] .= S_t_xx[:,:]
        S_t_xz_old[:,:] .= S_t_xz[:,:]
        S_t_zx_old[:,:] .= S_t_zx[:,:]
        S_t_zz_old[:,:] .= S_t_zz[:,:]

    end

    return K_matrix

end

#One iteration for optimization of M_t [internal connectivity weights of the internal dynamics] --  referred to as "W" in the paper - optimization starts at t=2
#The multipliers are first computed backwards in time at the given K, M and L; M_t is then updated forward in time for t = 2,...,T-2, 
#while K and L are held fixed. M_1 and M_{T-1} keep the values given in M_initial.
#Returns the updated M_matrix (a separate array: M_initial is not modified).
#NOTE: the time horizon T is not an argument of this function; it is read from the enclosing (global) scope.
function Latent_Space_Optimization_one_iteration_with_Lagrange_Multipliers_EXTENDED_Latent_Space_Control(dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control, K_initial, M_initial, L_initial, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

    # Variables used to compute the expected accumulated cost from time t to the final time T <c_t> = Tr[lambda_multiplier[:,:,t]*S_t^{xx} + omega_multiplier[:,:,t]*S_t^{zz} + nu_multiplier[:,:,t]*S_t^{xz}] + d_multiplier[t]
    # where S_t^{xx} = E[x_t*x_t^T], S_t^{zz} = E[z_t*z_t^T], S_t^{xz} = E[x_t*z_t^T]
    lambda_multiplier = zeros(dimension_of_state, dimension_of_state, T)
    omega_multiplier = zeros(dimension_of_latent_space, dimension_of_latent_space, T)
    nu_multiplier = zeros(dimension_of_latent_space, dimension_of_state, T)
    d_multiplier = zeros(T)

    # Non-central second order moments 
    S_t_xx = zeros(dimension_of_state, dimension_of_state)
    S_t_zz = zeros(dimension_of_latent_space, dimension_of_latent_space)
    S_t_xz = zeros(dimension_of_state, dimension_of_latent_space)
    S_t_zx = zeros(dimension_of_latent_space, dimension_of_state)
    #auxiliary variables for the propagation of the non-central moments
    S_t_xx_old = zeros(dimension_of_state, dimension_of_state)
    S_t_zz_old = zeros(dimension_of_latent_space, dimension_of_latent_space)
    S_t_xz_old = zeros(dimension_of_state, dimension_of_latent_space)
    S_t_zx_old = zeros(dimension_of_latent_space, dimension_of_state)

    #working copies of the three sets of weights
    K_matrix = zeros(dimension_of_latent_space, dimension_of_observation, T-1)
    M_matrix = zeros(dimension_of_latent_space, dimension_of_latent_space, T-1)
    L_matrix = zeros(dimension_of_control, dimension_of_latent_space, T-1)

    #initial condition for K, M and L [for the optimization procedure]
    K_matrix[:,:,:] .= K_initial[:,:,:]
    M_matrix[:,:,:] .= M_initial[:,:,:]
    L_matrix[:,:,:] .= L_initial[:,:,:]

    #optimization [a single sweep: this function performs one iteration of the whole "coordinate descent" optimization]
    #NOTE: We use the pseudoinverse to avoid issues with singular matrices (e.g., from initial conditions). 
    #It matches the true inverse when it exists, preserving the analytical formula.

    # Initial conditions for non-central moments
    S_t_xx_old[:,:] .= Σ_1_x[:,:] .+ x_1_mean*transpose(x_1_mean)
    S_t_zz_old[:,:] .= Σ_1_z[:,:] .+ z_1_mean*transpose(z_1_mean)
    S_t_xz_old[:,:] .= x_1_mean*transpose(z_1_mean) #initial values of state and latent state are assumed to be uncorrelated 
    S_t_zx_old[:,:] .= transpose(S_t_xz_old[:,:])

    #multipliers at the last time step [boundary conditions]
    lambda_multiplier[:,:,T] .= Q_matrix[:,:,T]
    omega_multiplier[:,:,T] .= zeros(dimension_of_latent_space, dimension_of_latent_space)
    nu_multiplier[:,:,T] .= zeros(dimension_of_latent_space, dimension_of_state)
    d_multiplier[T] = 0

    #computing the multipliers backwards in time [at fixed K, M and L]
    #NOTE: d_multiplier is accumulated here but is not read by the M update below
    for t in (T-1):-1:1
        lambda_multiplier[:,:,t] .= Q_matrix[:,:,t] .+ transpose(A)*lambda_multiplier[:,:,t+1]*A .+ transpose(H)*transpose(K_matrix[:,:,t])*omega_multiplier[:,:,t+1]*K_matrix[:,:,t]*H .+ transpose(H)*transpose(K_matrix[:,:,t])*nu_multiplier[:,:,t+1]*A .+ transpose(D)*transpose(K_matrix[:,:,t])*omega_multiplier[:,:,t+1]*K_matrix[:,:,t]*D
        
        omega_multiplier[:,:,t] .= transpose(L_matrix[:,:,t])*R_matrix[:,:,t]*L_matrix[:,:,t] .+ transpose(L_matrix[:,:,t])*transpose(B)*lambda_multiplier[:,:,t+1]*B*L_matrix[:,:,t] .+ transpose(L_matrix[:,:,t])*transpose(C)*lambda_multiplier[:,:,t+1]*C*L_matrix[:,:,t] .+ transpose(M_matrix[:,:,t])*omega_multiplier[:,:,t+1]*M_matrix[:,:,t] .+ transpose(M_matrix[:,:,t])*nu_multiplier[:,:,t+1]*B*L_matrix[:,:,t]

        nu_multiplier[:,:,t] .= transpose(L_matrix[:,:,t])*transpose(B)*(lambda_multiplier[:,:,t+1] .+ transpose(lambda_multiplier[:,:,t+1]))*A .+ transpose(M_matrix[:,:,t])*(omega_multiplier[:,:,t+1] .+ transpose(omega_multiplier[:,:,t+1]))*K_matrix[:,:,t]*H .+ transpose(M_matrix[:,:,t])*nu_multiplier[:,:,t+1]*A .+ transpose(L_matrix[:,:,t])*transpose(B)*transpose(nu_multiplier[:,:,t+1])*K_matrix[:,:,t]*H
        d_multiplier[t] = d_multiplier[t+1] + tr(lambda_multiplier[:,:,t+1]*Ω_ξ) + tr(omega_multiplier[:,:,t+1]*Ω_η .+ omega_multiplier[:,:,t+1]*K_matrix[:,:,t]*Ω_ω*transpose(K_matrix[:,:,t]))
    end 

    #M optimization
    #first time step: M_1 is not optimized, the moments are only propagated with the given K_1, M_1 and L_1
    t=1
    #update moments
    S_t_xx[:,:] .= A*S_t_xx_old[:,:]*transpose(A) .+ A*S_t_xz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(B) .+ B*L_matrix[:,:,t]*S_t_zx_old[:,:]*transpose(A) .+ B*L_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(B) .+ C*L_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(C) .+ Ω_ξ
    S_t_zx[:,:] .= K_matrix[:,:,t]*H*S_t_xx_old[:,:]*transpose(A) .+ K_matrix[:,:,t]*H*S_t_xz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(B) .+ M_matrix[:,:,t]*S_t_zx_old[:,:]*transpose(A) .+ M_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(B)
    S_t_xz[:,:] .= transpose(S_t_zx[:,:])
    S_t_zz[:,:] .= K_matrix[:,:,t]*H*S_t_xx_old[:,:]*transpose(K_matrix[:,:,t]*H) .+ K_matrix[:,:,t]*D*S_t_xx_old[:,:]*transpose(K_matrix[:,:,t]*D) .+ K_matrix[:,:,t]*H*S_t_xz_old[:,:]*transpose(M_matrix[:,:,t]) .+ M_matrix[:,:,t]*S_t_zx_old[:,:]*transpose(K_matrix[:,:,t]*H) .+ M_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(M_matrix[:,:,t]) .+ K_matrix[:,:,t]*Ω_ω*transpose(K_matrix[:,:,t]) .+ Ω_η
    #update old moments for next optimization 
    S_t_xx_old[:,:] .= S_t_xx[:,:]
    S_t_xz_old[:,:] .= S_t_xz[:,:]
    S_t_zx_old[:,:] .= S_t_zx[:,:]
    S_t_zz_old[:,:] .= S_t_zz[:,:]

    #forward in time: M_t is computed from the multipliers at t+1 and the moments at t
    for t in 2:(T-2)

        #auxiliary variables
        J = omega_multiplier[:,:,t+1] .+  transpose(omega_multiplier[:,:,t+1])
        N = (omega_multiplier[:,:,t+1] .+  transpose(omega_multiplier[:,:,t+1]))*K_matrix[:,:,t]*H*S_t_xz_old[:,:] .+ nu_multiplier[:,:,t+1]*(A*S_t_xz_old[:,:] .+ B*L_matrix[:,:,t]*S_t_zz_old[:,:])
        #optimal M_t
        M_matrix[:,:,t] .= .-(pinv(J)*N*pinv(S_t_zz_old[:,:]))

        #update moments [propagated from t to t+1 with the M_t just computed]
        S_t_xx[:,:] .= A*S_t_xx_old[:,:]*transpose(A) .+ A*S_t_xz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(B) .+ B*L_matrix[:,:,t]*S_t_zx_old[:,:]*transpose(A) .+ B*L_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(B) .+ C*L_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(C) .+ Ω_ξ
        S_t_zx[:,:] .= K_matrix[:,:,t]*H*S_t_xx_old[:,:]*transpose(A) .+ K_matrix[:,:,t]*H*S_t_xz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(B) .+ M_matrix[:,:,t]*S_t_zx_old[:,:]*transpose(A) .+ M_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(L_matrix[:,:,t])*transpose(B)
        S_t_xz[:,:] .= transpose(S_t_zx[:,:])
        S_t_zz[:,:] .= K_matrix[:,:,t]*H*S_t_xx_old[:,:]*transpose(K_matrix[:,:,t]*H) .+ K_matrix[:,:,t]*D*S_t_xx_old[:,:]*transpose(K_matrix[:,:,t]*D) .+ K_matrix[:,:,t]*H*S_t_xz_old[:,:]*transpose(M_matrix[:,:,t]) .+ M_matrix[:,:,t]*S_t_zx_old[:,:]*transpose(K_matrix[:,:,t]*H) .+ M_matrix[:,:,t]*S_t_zz_old[:,:]*transpose(M_matrix[:,:,t]) .+ K_matrix[:,:,t]*Ω_ω*transpose(K_matrix[:,:,t]) .+ Ω_η
        #update old moments for next optimization 
        S_t_xx_old[:,:] .= S_t_xx[:,:]
        S_t_xz_old[:,:] .= S_t_xz[:,:]
        S_t_zx_old[:,:] .= S_t_zx[:,:]
        S_t_zz_old[:,:] .= S_t_zz[:,:]

    end

    return M_matrix

end

#One iteration for optimization of L_t [output weights of latent state -- together with B produce the physical control from the latent dynamics] - optimization starts at t=2
#
#A single sweep is performed:
#   (1) the multipliers lambda, omega, nu are computed backwards in time with all the gains fixed at their initial values;
#   (2) the non-central second order moments are propagated forwards in time and, for t = 2, ..., T-1,
#       L_t is replaced by -pinv(J)*N*pinv(S_t^{zz}) before the moments are propagated to t+1.
#L_1 is left at its initial value.
#
#Arguments:
#   dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control: sizes used to allocate the working arrays
#   K_initial, M_initial, L_initial: gain sequences, one slice per time step along the third dimension (T-1 slices)
#   A, B, H, C, D: model matrices entering the moment recursions
#   Ω_ξ, Ω_ω, Ω_η: additive terms of the moment recursions (Ω_ξ in S^{xx}; K*Ω_ω*K' and Ω_η in S^{zz})
#   Q_matrix, R_matrix: cost matrices indexed by time along the third dimension (Q_matrix up to T, R_matrix up to T-1)
#   Σ_1_x, Σ_1_z, x_1_mean, z_1_mean: used to build the moments at t=1
#Returns a new (dimension_of_control, dimension_of_latent_space, T-1) array with the updated L_t; the inputs are not modified.
#
#NOTE: the time horizon T is inferred from the gains (T = size(K_initial, 3) + 1) instead of being read from a global variable,
#so that the result only depends on the arguments.
function Output_Weights_Optimization_one_iteration_with_Lagrange_Multipliers_EXTENDED_Latent_Space_Control(dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control, K_initial, M_initial, L_initial, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

    #time horizon: the gains have one slice per transition t -> t+1, i.e. T-1 slices
    T = size(K_initial, 3) + 1

    #working copies of the gains [only L_matrix is modified and returned]
    K_matrix = zeros(dimension_of_latent_space, dimension_of_observation, T-1)
    M_matrix = zeros(dimension_of_latent_space, dimension_of_latent_space, T-1)
    L_matrix = zeros(dimension_of_control, dimension_of_latent_space, T-1)
    K_matrix[:,:,:] .= K_initial[:,:,:]
    M_matrix[:,:,:] .= M_initial[:,:,:]
    L_matrix[:,:,:] .= L_initial[:,:,:]

    # Variables used to compute the expected accumulated cost from time t to the final time T <c_t> = Tr[lambda_multiplier[:,:,t]*S_t^{xx} + omega_multiplier[:,:,t]*S_t^{zz} + nu_multiplier[:,:,t]*S_t^{xz}] + d_multiplier[t]
    # where S_t^{xx} = E[x_t*x_t^T], S_t^{zz} = E[z_t*z_t^T], S_t^{xz} = E[x_t*z_t^T]
    # [d_multiplier does not enter the update of L_t, hence it is not computed here]
    lambda_multiplier = zeros(dimension_of_state, dimension_of_state, T)
    omega_multiplier = zeros(dimension_of_latent_space, dimension_of_latent_space, T)
    nu_multiplier = zeros(dimension_of_latent_space, dimension_of_state, T)

    #multipliers at the last time step [boundary conditions]: omega and nu stay at zero
    lambda_multiplier[:,:,T] .= Q_matrix[:,:,T]

    #computing the multipliers backwards in time [at fixed K, M and L]
    for t in (T-1):-1:1
        K_t = K_matrix[:,:,t]
        M_t = M_matrix[:,:,t]
        L_t = L_matrix[:,:,t]
        lambda_next = lambda_multiplier[:,:,t+1]
        omega_next = omega_multiplier[:,:,t+1]
        nu_next = nu_multiplier[:,:,t+1]

        lambda_multiplier[:,:,t] .= Q_matrix[:,:,t] .+ transpose(A)*lambda_next*A .+ transpose(H)*transpose(K_t)*omega_next*K_t*H .+ transpose(H)*transpose(K_t)*nu_next*A .+ transpose(D)*transpose(K_t)*omega_next*K_t*D

        omega_multiplier[:,:,t] .= transpose(L_t)*R_matrix[:,:,t]*L_t .+ transpose(L_t)*transpose(B)*lambda_next*B*L_t .+ transpose(L_t)*transpose(C)*lambda_next*C*L_t .+ transpose(M_t)*omega_next*M_t .+ transpose(M_t)*nu_next*B*L_t

        nu_multiplier[:,:,t] .= transpose(L_t)*transpose(B)*(lambda_next .+ transpose(lambda_next))*A .+ transpose(M_t)*(omega_next .+ transpose(omega_next))*K_t*H .+ transpose(M_t)*nu_next*A .+ transpose(L_t)*transpose(B)*transpose(nu_next)*K_t*H
    end

    # Non-central second order moments at the current time step
    S_xx = zeros(dimension_of_state, dimension_of_state)
    S_zz = zeros(dimension_of_latent_space, dimension_of_latent_space)
    S_xz = zeros(dimension_of_state, dimension_of_latent_space)
    S_zx = zeros(dimension_of_latent_space, dimension_of_state)
    #buffers for the moments at the next time step
    S_xx_new = zeros(dimension_of_state, dimension_of_state)
    S_zz_new = zeros(dimension_of_latent_space, dimension_of_latent_space)
    S_xz_new = zeros(dimension_of_state, dimension_of_latent_space)
    S_zx_new = zeros(dimension_of_latent_space, dimension_of_state)

    # Initial conditions for non-central moments
    S_xx .= Σ_1_x[:,:] .+ x_1_mean*transpose(x_1_mean)
    S_zz .= Σ_1_z[:,:] .+ z_1_mean*transpose(z_1_mean)
    S_xz .= x_1_mean*transpose(z_1_mean) #initial values of state and state estimate are assumed to be uncorrelated
    S_zx .= transpose(S_xz)

    #L optimization
    #NOTE: We use the pseudoinverse to avoid issues with singular matrices (e.g., from initial conditions).
    #It matches the true inverse when it exists, preserving the analytical formula.
    for t in 1:(T-1)

        K_t = K_matrix[:,:,t]
        M_t = M_matrix[:,:,t]

        #the optimization starts at t=2: at t=1 the moments are only propagated with the initial L_1
        if t >= 2
            lambda_sym = lambda_multiplier[:,:,t+1] .+ transpose(lambda_multiplier[:,:,t+1])
            nu_next = nu_multiplier[:,:,t+1]

            #auxiliary variables
            J = 2 .* R_matrix[:,:,t] .+ transpose(B)*lambda_sym*B .+ transpose(C)*lambda_sym*C
            N = transpose(B)*transpose(nu_next)*M_t*S_zz .+ transpose(B)*lambda_sym*A*S_xz .+ transpose(B)*transpose(nu_next)*K_t*H*S_xz

            #optimal L_t
            L_matrix[:,:,t] .= .-(pinv(J)*N*pinv(S_zz))
        end

        #gain used for the propagation [already updated when t >= 2]
        L_t = L_matrix[:,:,t]

        #update moments
        S_xx_new .= A*S_xx*transpose(A) .+ A*S_xz*transpose(L_t)*transpose(B) .+ B*L_t*S_zx*transpose(A) .+ B*L_t*S_zz*transpose(L_t)*transpose(B) .+ C*L_t*S_zz*transpose(L_t)*transpose(C) .+ Ω_ξ
        S_zx_new .= K_t*H*S_xx*transpose(A) .+ K_t*H*S_xz*transpose(L_t)*transpose(B) .+ M_t*S_zx*transpose(A) .+ M_t*S_zz*transpose(L_t)*transpose(B)
        S_xz_new .= transpose(S_zx_new)
        S_zz_new .= K_t*H*S_xx*transpose(K_t*H) .+ K_t*D*S_xx*transpose(K_t*D) .+ K_t*H*S_xz*transpose(M_t) .+ M_t*S_zx*transpose(K_t*H) .+ M_t*S_zz*transpose(M_t) .+ K_t*Ω_ω*transpose(K_t) .+ Ω_η

        #the new moments become the current ones for the next time step
        S_xx .= S_xx_new
        S_xz .= S_xz_new
        S_zx .= S_zx_new
        S_zz .= S_zz_new

    end

    return L_matrix

end

#Whole optimization (K_t, L_t and M_t iteratively in a coordinate descent fashion)
#
#Each of the N_iter_tot outer iterations runs N_iter_each_dir sweeps of the K optimization, then of the M optimization,
#then of the L optimization. After every sweep the expected cost of the current gains is evaluated both with the
#moment propagation and with the trace formula and stored, so that one can check whether the cost is decreasing.
#Returns (K_matrix, M_matrix, L_matrix, cost_mom_prop, cost_trace_formula); entry 1 of the two cost vectors is the
#cost of the initial gains.
#NOTE: the time horizon T is not an argument, it is read from the enclosing scope.
function Optimal_EXTENDED_Latent_Space_and_Input_with_Lagrange_Multipliers_COORDINATE_DESCENT(N_iter_tot, N_iter_each_dir, dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control, K_initial, M_initial, L_initial, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

    #current gains [updated in place by the sweeps below]
    K_matrix = zeros(dimension_of_latent_space, dimension_of_observation, T-1)
    M_matrix = zeros(dimension_of_latent_space, dimension_of_latent_space, T-1)
    L_matrix = zeros(dimension_of_control, dimension_of_latent_space, T-1)

    #one cost record for the initial gains plus one per sweep (3 directions)
    n_cost_records = 3*N_iter_tot*N_iter_each_dir+1
    cost_mom_prop = zeros(n_cost_records)
    cost_trace_formula = zeros(n_cost_records)

    #initial condition for K, M and L [for the optimization procedure]
    K_matrix[:,:,:] .= K_initial[:,:,:]
    M_matrix[:,:,:] .= M_initial[:,:,:]
    L_matrix[:,:,:] .= L_initial[:,:,:]

    #expected cost of the current gains, computed with the two available formulas
    current_cost_mom_prop() = expected_cost_raw_moments_propagation_EXTENDED_Latent_Space(T, dimension_of_state, dimension_of_latent_space, dimension_of_observation, K_matrix, M_matrix, L_matrix, A, B, H, C, D, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)
    current_cost_trace_formula() = expected_cost_using_trace_formula_induction_EXTENDED_LSC(M_matrix, K_matrix, L_matrix, A, B, H, C, D, T, dimension_of_state, dimension_of_latent_space, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

    #cost of the initial gains
    cost_mom_prop[1] = current_cost_mom_prop()
    cost_trace_formula[1] = current_cost_trace_formula()

    #index of the last cost record written
    record = 1
    for _ in 1:N_iter_tot

        #K optimization
        for _ in 1:N_iter_each_dir
            K_matrix[:,:,:] .= Input_Weights_Optimization_one_iteration_with_Lagrange_Multipliers_EXTENDED_Latent_Space_Control(dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control, K_matrix, M_matrix, L_matrix, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)
            record += 1
            cost_mom_prop[record] = current_cost_mom_prop()
            cost_trace_formula[record] = current_cost_trace_formula()
        end

        #M optimization
        for _ in 1:N_iter_each_dir
            M_matrix[:,:,:] .= Latent_Space_Optimization_one_iteration_with_Lagrange_Multipliers_EXTENDED_Latent_Space_Control(dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control, K_matrix, M_matrix, L_matrix, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)
            record += 1
            cost_mom_prop[record] = current_cost_mom_prop()
            cost_trace_formula[record] = current_cost_trace_formula()
        end

        #L optimization
        for _ in 1:N_iter_each_dir
            L_matrix[:,:,:] .= Output_Weights_Optimization_one_iteration_with_Lagrange_Multipliers_EXTENDED_Latent_Space_Control(dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control, K_matrix, M_matrix, L_matrix, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)
            record += 1
            cost_mom_prop[record] = current_cost_mom_prop()
            cost_trace_formula[record] = current_cost_trace_formula()
        end

    end

    return K_matrix, M_matrix, L_matrix, cost_mom_prop, cost_trace_formula

end

#Expected cost using non-central moments propagation adapted to the neural space control approach
#
#The moments S_t^{xx} = E[x_t*x_t^T], S_t^{zz} = E[z_t*z_t^T], S_t^{xz} = E[x_t*z_t^T] (and its transpose S_t^{zx})
#are propagated forwards in time from t=1 to t=T with the gains K_matrix, M_matrix, L_matrix [slice t is used for the
#step t -> t+1]. The returned cost is
#   sum_{t=1}^{T-1} ( Tr[Q_t*S_t^{xx}] + Tr[L_t^T*R_t*L_t*S_t^{zz}] ) + Tr[Q_T*S_T^{xx}].
#dimension_of_observation is accepted for signature compatibility but is not used.
function expected_cost_raw_moments_propagation_EXTENDED_Latent_Space(T, dimension_of_state, dimension_of_latent_space, dimension_of_observation, K_matrix, M_matrix, L_matrix, A, B, H, C, D, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

    #moments at the current time step, initialized at t=1
    Sxx_now = zeros(dimension_of_state, dimension_of_state)
    Szz_now = zeros(dimension_of_latent_space, dimension_of_latent_space)
    Sxz_now = zeros(dimension_of_state, dimension_of_latent_space)
    Szx_now = zeros(dimension_of_latent_space, dimension_of_state)
    Sxx_now .= Σ_1_x[:,:] .+ x_1_mean*transpose(x_1_mean)
    Szz_now .= Σ_1_z[:,:] .+ z_1_mean*transpose(z_1_mean)
    Sxz_now .= x_1_mean*transpose(z_1_mean) #initial values of state and state estimate are assumed to be uncorrelated
    Szx_now .= transpose(Sxz_now)

    #buffers for the moments at the next time step
    Sxx_next = zeros(dimension_of_state, dimension_of_state)
    Szz_next = zeros(dimension_of_latent_space, dimension_of_latent_space)
    Sxz_next = zeros(dimension_of_state, dimension_of_latent_space)
    Szx_next = zeros(dimension_of_latent_space, dimension_of_state)

    cost = 0
    for t in 1:(T-1)

        Kt = K_matrix[:,:,t]
        Mt = M_matrix[:,:,t]
        Lt = L_matrix[:,:,t]

        #state and control cost at time t
        cost = cost + tr(Q_matrix[:,:,t]*Sxx_now) + tr(transpose(Lt)*R_matrix[:,:,t]*Lt*Szz_now)

        #moments at time t+1
        Sxx_next .= A*Sxx_now*transpose(A) .+ A*Sxz_now*transpose(Lt)*transpose(B) .+ B*Lt*Szx_now*transpose(A) .+ B*Lt*Szz_now*transpose(Lt)*transpose(B) .+ C*Lt*Szz_now*transpose(Lt)*transpose(C) .+ Ω_ξ
        Szx_next .= Kt*H*Sxx_now*transpose(A) .+ Kt*H*Sxz_now*transpose(Lt)*transpose(B) .+ Mt*Szx_now*transpose(A) .+ Mt*Szz_now*transpose(Lt)*transpose(B)
        Sxz_next .= transpose(Szx_next)
        Szz_next .= Kt*H*Sxx_now*transpose(Kt*H) .+ Kt*D*Sxx_now*transpose(Kt*D) .+ Kt*H*Sxz_now*transpose(Mt) .+ Mt*Szx_now*transpose(Kt*H) .+ Mt*Szz_now*transpose(Mt) .+ Kt*Ω_ω*transpose(Kt) .+ Ω_η

        #advance one time step
        Sxx_now .= Sxx_next
        Szz_now .= Szz_next
        Sxz_now .= Sxz_next
        Szx_now .= Szx_next

    end

    #final time step [only the state cost]
    cost = cost + tr(Q_matrix[:,:,T]*Sxx_now)

    return cost

end

#Expected cost using trace formula
#
#The expected accumulated cost from time t to the final time T is written as
#   <c_t> = Tr[lambda_t*S_t^{xx} + omega_t*S_t^{zz} + nu_t*S_t^{xz}] + d_t
#where S_t^{xx} = E[x_t*x_t^T], S_t^{zz} = E[z_t*z_t^T], S_t^{xz} = E[x_t*z_t^T].
#The multipliers (lambda, omega, nu, d) are computed backwards in time from t=T to t=1 at fixed K, M and L, and the
#formula is then evaluated at t=1 with the initial moments. Only the multipliers of the current time step are stored.
function expected_cost_using_trace_formula_induction_EXTENDED_LSC(M_matrix, K_matrix, L_matrix, A, B, H, C, D, T, dimension_of_state, dimension_of_latent_space, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

    #non-central second order moments at t=1
    Sxx_1 = zeros(dimension_of_state, dimension_of_state)
    Szz_1 = zeros(dimension_of_latent_space, dimension_of_latent_space)
    Sxz_1 = zeros(dimension_of_state, dimension_of_latent_space)
    Sxx_1 .= Σ_1_x[:,:] .+ x_1_mean*transpose(x_1_mean)
    Szz_1 .= Σ_1_z[:,:] .+ z_1_mean*transpose(z_1_mean)
    Sxz_1 .= x_1_mean*transpose(z_1_mean) #initial values of state and state estimate are assumed to be uncorrelated

    #multipliers at the last time step [boundary conditions]
    lam = zeros(dimension_of_state, dimension_of_state)
    lam .= Q_matrix[:,:,T]
    om = zeros(dimension_of_latent_space, dimension_of_latent_space)
    nu = zeros(dimension_of_latent_space, dimension_of_state)
    d = 0.0

    #buffers for the multipliers at the previous time step
    lam_prev = zeros(dimension_of_state, dimension_of_state)
    om_prev = zeros(dimension_of_latent_space, dimension_of_latent_space)
    nu_prev = zeros(dimension_of_latent_space, dimension_of_state)

    #computing the multipliers backwards in time [at fixed K, M and L]
    for t in (T-1):-1:1

        Kt = K_matrix[:,:,t]
        Mt = M_matrix[:,:,t]
        Lt = L_matrix[:,:,t]

        #all the right-hand sides use the multipliers at time t+1 (lam, om, nu, d)
        lam_prev .= Q_matrix[:,:,t] .+ transpose(A)*lam*A .+ transpose(H)*transpose(Kt)*om*Kt*H .+ transpose(H)*transpose(Kt)*nu*A .+ transpose(D)*transpose(Kt)*om*Kt*D

        om_prev .= transpose(Lt)*R_matrix[:,:,t]*Lt .+ transpose(Lt)*transpose(B)*lam*B*Lt .+ transpose(Lt)*transpose(C)*lam*C*Lt .+ transpose(Mt)*om*Mt .+ transpose(Mt)*nu*B*Lt

        nu_prev .= transpose(Lt)*transpose(B)*(lam .+ transpose(lam))*A .+ transpose(Mt)*(om .+ transpose(om))*Kt*H .+ transpose(Mt)*nu*A .+ transpose(Lt)*transpose(B)*transpose(nu)*Kt*H

        d = d + tr(lam*Ω_ξ) + tr(om*Ω_η .+ om*Kt*Ω_ω*transpose(Kt))

        #step back in time: the multipliers at time t become the current ones
        lam .= lam_prev
        om .= om_prev
        nu .= nu_prev

    end

    #Computing the expected cost at the first time step
    cost = tr(lam*Sxx_1) + tr(om*Szz_1) + tr(nu*Sxz_1) + d

    return cost

end

#Functions for task involving control of neural activity ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

#One iteration for optimization of K_t [weights of sensory feedback -- input to the latent dynamics] - optimization starts at t=1
#
#A single sweep is performed:
#   (1) the multipliers lambda, omega, nu are computed backwards in time with all the gains fixed at their initial values;
#   (2) the non-central second order moments are propagated forwards in time and, for t = 1, ..., T-2,
#       K_t is replaced by -pinv(F)*V*pinv(G) before the moments are propagated to t+1.
#K_{T-1} is left at its initial value [the loop stops at T-2].
#
#Arguments:
#   dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control: sizes used to allocate the working arrays
#   K_initial, M_initial, L_initial: gain sequences, one slice per time step along the third dimension (T-1 slices)
#   A, B, H, C, D: model matrices entering the moment recursions
#   Ω_ξ, Ω_ω, Ω_η: additive terms of the moment recursions (Ω_ξ in S^{xx}; K*Ω_ω*K' and Ω_η in S^{zz})
#   Q_matrix, R_matrix: cost matrices indexed by time along the third dimension (Q_matrix up to T, R_matrix up to T-1)
#   Σ_1_x, Σ_1_z, x_1_mean, z_1_mean: used to build the moments at t=1
#Returns a new (dimension_of_latent_space, dimension_of_observation, T-1) array with the updated K_t; the inputs are not modified.
#
#NOTE: the time horizon T is inferred from the gains (T = size(K_initial, 3) + 1) instead of being read from a global variable,
#so that the result only depends on the arguments.
function Input_Weights_Optimization_one_iteration_with_Lagrange_Multipliers_EXTENDED_Latent_Space_Control_optimization_starts_from_1(dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control, K_initial, M_initial, L_initial, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

    #time horizon: the gains have one slice per transition t -> t+1, i.e. T-1 slices
    T = size(K_initial, 3) + 1

    #working copies of the gains [only K_matrix is modified and returned]
    K_matrix = zeros(dimension_of_latent_space, dimension_of_observation, T-1)
    M_matrix = zeros(dimension_of_latent_space, dimension_of_latent_space, T-1)
    L_matrix = zeros(dimension_of_control, dimension_of_latent_space, T-1)
    K_matrix[:,:,:] .= K_initial[:,:,:]
    M_matrix[:,:,:] .= M_initial[:,:,:]
    L_matrix[:,:,:] .= L_initial[:,:,:]

    # Variables used to compute the expected accumulated cost from time t to the final time T <c_t> = Tr[lambda_multiplier[:,:,t]*S_t^{xx} + omega_multiplier[:,:,t]*S_t^{zz} + nu_multiplier[:,:,t]*S_t^{xz}] + d_multiplier[t]
    # where S_t^{xx} = E[x_t*x_t^T], S_t^{zz} = E[z_t*z_t^T], S_t^{xz} = E[x_t*z_t^T]
    # [d_multiplier does not enter the update of K_t, hence it is not computed here]
    lambda_multiplier = zeros(dimension_of_state, dimension_of_state, T)
    omega_multiplier = zeros(dimension_of_latent_space, dimension_of_latent_space, T)
    nu_multiplier = zeros(dimension_of_latent_space, dimension_of_state, T)

    #multipliers at the last time step [boundary conditions]: omega and nu stay at zero
    lambda_multiplier[:,:,T] .= Q_matrix[:,:,T]

    #computing the multipliers backwards in time [at fixed K, M and L]
    for t in (T-1):-1:1
        K_t = K_matrix[:,:,t]
        M_t = M_matrix[:,:,t]
        L_t = L_matrix[:,:,t]
        lambda_next = lambda_multiplier[:,:,t+1]
        omega_next = omega_multiplier[:,:,t+1]
        nu_next = nu_multiplier[:,:,t+1]

        lambda_multiplier[:,:,t] .= Q_matrix[:,:,t] .+ transpose(A)*lambda_next*A .+ transpose(H)*transpose(K_t)*omega_next*K_t*H .+ transpose(H)*transpose(K_t)*nu_next*A .+ transpose(D)*transpose(K_t)*omega_next*K_t*D

        omega_multiplier[:,:,t] .= transpose(L_t)*R_matrix[:,:,t]*L_t .+ transpose(L_t)*transpose(B)*lambda_next*B*L_t .+ transpose(L_t)*transpose(C)*lambda_next*C*L_t .+ transpose(M_t)*omega_next*M_t .+ transpose(M_t)*nu_next*B*L_t

        nu_multiplier[:,:,t] .= transpose(L_t)*transpose(B)*(lambda_next .+ transpose(lambda_next))*A .+ transpose(M_t)*(omega_next .+ transpose(omega_next))*K_t*H .+ transpose(M_t)*nu_next*A .+ transpose(L_t)*transpose(B)*transpose(nu_next)*K_t*H
    end

    # Non-central second order moments at the current time step
    S_xx = zeros(dimension_of_state, dimension_of_state)
    S_zz = zeros(dimension_of_latent_space, dimension_of_latent_space)
    S_xz = zeros(dimension_of_state, dimension_of_latent_space)
    S_zx = zeros(dimension_of_latent_space, dimension_of_state)
    #buffers for the moments at the next time step
    S_xx_new = zeros(dimension_of_state, dimension_of_state)
    S_zz_new = zeros(dimension_of_latent_space, dimension_of_latent_space)
    S_xz_new = zeros(dimension_of_state, dimension_of_latent_space)
    S_zx_new = zeros(dimension_of_latent_space, dimension_of_state)

    # Initial conditions for non-central moments
    S_xx .= Σ_1_x[:,:] .+ x_1_mean*transpose(x_1_mean)
    S_zz .= Σ_1_z[:,:] .+ z_1_mean*transpose(z_1_mean)
    S_xz .= x_1_mean*transpose(z_1_mean) #initial values of state and state estimate are assumed to be uncorrelated
    S_zx .= transpose(S_xz)

    #K optimization [starting directly from t=1, with the moments at t=1]
    #NOTE: We use the pseudoinverse to avoid issues with singular matrices (e.g., from initial conditions).
    #It matches the true inverse when it exists, preserving the analytical formula.
    for t in 1:(T-2)

        M_t = M_matrix[:,:,t]
        L_t = L_matrix[:,:,t]
        nu_next = nu_multiplier[:,:,t+1]

        #auxiliary variables
        F = omega_multiplier[:,:,t+1] .+ transpose(omega_multiplier[:,:,t+1])
        V = F*M_t*S_zx*transpose(H) .+ nu_next*(A*S_xx .+ B*L_t*S_zx)*transpose(H)
        G = H*S_xx*transpose(H) .+ D*S_xx*transpose(D) .+ Ω_ω

        #optimal K_t
        K_matrix[:,:,t] .= .-(pinv(F)*V*pinv(G))

        #updated gain, used for the propagation
        K_t = K_matrix[:,:,t]

        #update moments
        S_xx_new .= A*S_xx*transpose(A) .+ A*S_xz*transpose(L_t)*transpose(B) .+ B*L_t*S_zx*transpose(A) .+ B*L_t*S_zz*transpose(L_t)*transpose(B) .+ C*L_t*S_zz*transpose(L_t)*transpose(C) .+ Ω_ξ
        S_zx_new .= K_t*H*S_xx*transpose(A) .+ K_t*H*S_xz*transpose(L_t)*transpose(B) .+ M_t*S_zx*transpose(A) .+ M_t*S_zz*transpose(L_t)*transpose(B)
        S_xz_new .= transpose(S_zx_new)
        S_zz_new .= K_t*H*S_xx*transpose(K_t*H) .+ K_t*D*S_xx*transpose(K_t*D) .+ K_t*H*S_xz*transpose(M_t) .+ M_t*S_zx*transpose(K_t*H) .+ M_t*S_zz*transpose(M_t) .+ K_t*Ω_ω*transpose(K_t) .+ Ω_η

        #the new moments become the current ones for the next time step
        S_xx .= S_xx_new
        S_xz .= S_xz_new
        S_zx .= S_zx_new
        S_zz .= S_zz_new

    end

    return K_matrix

end

#One iteration for optimization of L_t [output weights of latent state -- together with B produce the physical control from the latent dynamics] - optimization starts at t=1
#
#A single sweep is performed:
#   (1) the multipliers lambda, omega, nu are computed backwards in time with all the gains fixed at their initial values;
#   (2) the non-central second order moments are propagated forwards in time and, for t = 1, ..., T-1,
#       L_t is replaced by -pinv(J)*N*pinv(S_t^{zz}) before the moments are propagated to t+1.
#
#Arguments:
#   dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control: sizes used to allocate the working arrays
#   K_initial, M_initial, L_initial: gain sequences, one slice per time step along the third dimension (T-1 slices)
#   A, B, H, C, D: model matrices entering the moment recursions
#   Ω_ξ, Ω_ω, Ω_η: additive terms of the moment recursions (Ω_ξ in S^{xx}; K*Ω_ω*K' and Ω_η in S^{zz})
#   Q_matrix, R_matrix: cost matrices indexed by time along the third dimension (Q_matrix up to T, R_matrix up to T-1)
#   Σ_1_x, Σ_1_z, x_1_mean, z_1_mean: used to build the moments at t=1
#Returns a new (dimension_of_control, dimension_of_latent_space, T-1) array with the updated L_t; the inputs are not modified.
#
#NOTE: the time horizon T is inferred from the gains (T = size(K_initial, 3) + 1) instead of being read from a global variable,
#so that the result only depends on the arguments.
function Output_Weights_Optimization_one_iteration_with_Lagrange_Multipliers_EXTENDED_Latent_Space_Control_optimization_starts_from_1(dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control, K_initial, M_initial, L_initial, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

    #time horizon: the gains have one slice per transition t -> t+1, i.e. T-1 slices
    T = size(K_initial, 3) + 1

    #working copies of the gains [only L_matrix is modified and returned]
    K_matrix = zeros(dimension_of_latent_space, dimension_of_observation, T-1)
    M_matrix = zeros(dimension_of_latent_space, dimension_of_latent_space, T-1)
    L_matrix = zeros(dimension_of_control, dimension_of_latent_space, T-1)
    K_matrix[:,:,:] .= K_initial[:,:,:]
    M_matrix[:,:,:] .= M_initial[:,:,:]
    L_matrix[:,:,:] .= L_initial[:,:,:]

    # Variables used to compute the expected accumulated cost from time t to the final time T <c_t> = Tr[lambda_multiplier[:,:,t]*S_t^{xx} + omega_multiplier[:,:,t]*S_t^{zz} + nu_multiplier[:,:,t]*S_t^{xz}] + d_multiplier[t]
    # where S_t^{xx} = E[x_t*x_t^T], S_t^{zz} = E[z_t*z_t^T], S_t^{xz} = E[x_t*z_t^T]
    # [d_multiplier does not enter the update of L_t, hence it is not computed here]
    lambda_multiplier = zeros(dimension_of_state, dimension_of_state, T)
    omega_multiplier = zeros(dimension_of_latent_space, dimension_of_latent_space, T)
    nu_multiplier = zeros(dimension_of_latent_space, dimension_of_state, T)

    #multipliers at the last time step [boundary conditions]: omega and nu stay at zero
    lambda_multiplier[:,:,T] .= Q_matrix[:,:,T]

    #computing the multipliers backwards in time [at fixed K, M and L]
    for t in (T-1):-1:1
        K_t = K_matrix[:,:,t]
        M_t = M_matrix[:,:,t]
        L_t = L_matrix[:,:,t]
        lambda_next = lambda_multiplier[:,:,t+1]
        omega_next = omega_multiplier[:,:,t+1]
        nu_next = nu_multiplier[:,:,t+1]

        lambda_multiplier[:,:,t] .= Q_matrix[:,:,t] .+ transpose(A)*lambda_next*A .+ transpose(H)*transpose(K_t)*omega_next*K_t*H .+ transpose(H)*transpose(K_t)*nu_next*A .+ transpose(D)*transpose(K_t)*omega_next*K_t*D

        omega_multiplier[:,:,t] .= transpose(L_t)*R_matrix[:,:,t]*L_t .+ transpose(L_t)*transpose(B)*lambda_next*B*L_t .+ transpose(L_t)*transpose(C)*lambda_next*C*L_t .+ transpose(M_t)*omega_next*M_t .+ transpose(M_t)*nu_next*B*L_t

        nu_multiplier[:,:,t] .= transpose(L_t)*transpose(B)*(lambda_next .+ transpose(lambda_next))*A .+ transpose(M_t)*(omega_next .+ transpose(omega_next))*K_t*H .+ transpose(M_t)*nu_next*A .+ transpose(L_t)*transpose(B)*transpose(nu_next)*K_t*H
    end

    # Non-central second order moments at the current time step
    S_xx = zeros(dimension_of_state, dimension_of_state)
    S_zz = zeros(dimension_of_latent_space, dimension_of_latent_space)
    S_xz = zeros(dimension_of_state, dimension_of_latent_space)
    S_zx = zeros(dimension_of_latent_space, dimension_of_state)
    #buffers for the moments at the next time step
    S_xx_new = zeros(dimension_of_state, dimension_of_state)
    S_zz_new = zeros(dimension_of_latent_space, dimension_of_latent_space)
    S_xz_new = zeros(dimension_of_state, dimension_of_latent_space)
    S_zx_new = zeros(dimension_of_latent_space, dimension_of_state)

    # Initial conditions for non-central moments
    S_xx .= Σ_1_x[:,:] .+ x_1_mean*transpose(x_1_mean)
    S_zz .= Σ_1_z[:,:] .+ z_1_mean*transpose(z_1_mean)
    S_xz .= x_1_mean*transpose(z_1_mean) #initial values of state and state estimate are assumed to be uncorrelated
    S_zx .= transpose(S_xz)

    #L optimization [starting directly from t=1, with the moments at t=1]
    #NOTE: We use the pseudoinverse to avoid issues with singular matrices (e.g., from initial conditions).
    #It matches the true inverse when it exists, preserving the analytical formula.
    for t in 1:(T-1)

        K_t = K_matrix[:,:,t]
        M_t = M_matrix[:,:,t]
        lambda_sym = lambda_multiplier[:,:,t+1] .+ transpose(lambda_multiplier[:,:,t+1])
        nu_next = nu_multiplier[:,:,t+1]

        #auxiliary variables
        J = 2 .* R_matrix[:,:,t] .+ transpose(B)*lambda_sym*B .+ transpose(C)*lambda_sym*C
        N = transpose(B)*transpose(nu_next)*M_t*S_zz .+ transpose(B)*lambda_sym*A*S_xz .+ transpose(B)*transpose(nu_next)*K_t*H*S_xz

        #optimal L_t
        L_matrix[:,:,t] .= .-(pinv(J)*N*pinv(S_zz))

        #updated gain, used for the propagation
        L_t = L_matrix[:,:,t]

        #update moments
        S_xx_new .= A*S_xx*transpose(A) .+ A*S_xz*transpose(L_t)*transpose(B) .+ B*L_t*S_zx*transpose(A) .+ B*L_t*S_zz*transpose(L_t)*transpose(B) .+ C*L_t*S_zz*transpose(L_t)*transpose(C) .+ Ω_ξ
        S_zx_new .= K_t*H*S_xx*transpose(A) .+ K_t*H*S_xz*transpose(L_t)*transpose(B) .+ M_t*S_zx*transpose(A) .+ M_t*S_zz*transpose(L_t)*transpose(B)
        S_xz_new .= transpose(S_zx_new)
        S_zz_new .= K_t*H*S_xx*transpose(K_t*H) .+ K_t*D*S_xx*transpose(K_t*D) .+ K_t*H*S_xz*transpose(M_t) .+ M_t*S_zx*transpose(K_t*H) .+ M_t*S_zz*transpose(M_t) .+ K_t*Ω_ω*transpose(K_t) .+ Ω_η

        #the new moments become the current ones for the next time step
        S_xx .= S_xx_new
        S_xz .= S_xz_new
        S_zx .= S_zx_new
        S_zz .= S_zz_new

    end

    return L_matrix

end

#L_t optimization: only the output weights L_t are iterated; K_t and M_t are copied from
#K_initial / M_initial and returned unchanged.
#Returns (K, M, L, cost_mom_prop, cost_trace_formula). Both cost vectors have N_iter_tot+1 entries:
#entry 1 is the expected cost of the initial guess, entry n+1 the expected cost after the n-th
#L update, evaluated with moment propagation and with the trace formula respectively.
#NOTE: the horizon `T` is not an argument of this function; it is read from the enclosing (global) scope.
function Optimal_EXTENDED_Latent_Space_ONLY_OUTPUT_WEIGHTS(N_iter_tot, dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control, K_initial, M_initial, L_initial, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

    #number of consecutive L updates per outer iteration (fixed to one in this variant)
    sweeps_per_direction = 1
    n_updates = N_iter_tot*sweeps_per_direction

    #working copies of the parameters, initialised with the user-provided guesses
    K_current = zeros(dimension_of_latent_space, dimension_of_observation, T-1)
    M_current = zeros(dimension_of_latent_space, dimension_of_latent_space, T-1)
    L_current = zeros(dimension_of_control, dimension_of_latent_space, T-1)
    K_current .= K_initial[:,:,:]
    M_current .= M_initial[:,:,:]
    L_current .= L_initial[:,:,:]

    #one slot for the initial guess plus one slot per L update
    cost_mom_prop = zeros(n_updates + 1)
    cost_trace_formula = zeros(n_updates + 1)

    #expected cost of the current (K, M, L), evaluated in the two available ways
    cost_by_moment_propagation() = expected_cost_raw_moments_propagation_EXTENDED_Latent_Space(T, dimension_of_state, dimension_of_latent_space, dimension_of_observation, K_current, M_current, L_current, A, B, H, C, D, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)
    cost_by_trace_formula() = expected_cost_using_trace_formula_induction_EXTENDED_LSC(M_current, K_current, L_current, A, B, H, C, D, T, dimension_of_state, dimension_of_latent_space, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)

    #cost of the starting point
    cost_mom_prop[1] = cost_by_moment_propagation()
    cost_trace_formula[1] = cost_by_trace_formula()

    #The K and M optimization steps are disabled in this variant: only L is updated below.
    for slot in 2:(n_updates + 1)

        #L optimization (one iteration), written in place into the working copy
        L_current .= Output_Weights_Optimization_one_iteration_with_Lagrange_Multipliers_EXTENDED_Latent_Space_Control_optimization_starts_from_1(dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control, K_current, M_current, L_current, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

        #track the expected cost to check whether it is decreasing
        cost_mom_prop[slot] = cost_by_moment_propagation()
        cost_trace_formula[slot] = cost_by_trace_formula()

    end

    return K_current, M_current, L_current, cost_mom_prop, cost_trace_formula

end

#K_t [referred to as "P_t" in the paper] and L_t optimization: each outer iteration performs
#N_iter_each_dir updates of the input weights K_t followed by N_iter_each_dir updates of the output
#weights L_t; M_t is copied from M_initial and returned unchanged.
#Returns (K, M, L, cost_mom_prop, cost_trace_formula). Both cost vectors have
#2*N_iter_tot*N_iter_each_dir entries, one per individual K or L update, in the order performed
#(the cost of the initial guess is NOT stored here).
#NOTE: the horizon `T` is not an argument of this function; it is read from the enclosing (global) scope.
function Optimal_EXTENDED_Latent_Space_ONLY_INPUT_and_OUTPUT_WEIGHTS(N_iter_tot, N_iter_each_dir, dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control, K_initial, M_initial, L_initial, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)

    #working copies of the parameters, initialised with the user-provided guesses
    K_current = zeros(dimension_of_latent_space, dimension_of_observation, T-1)
    M_current = zeros(dimension_of_latent_space, dimension_of_latent_space, T-1)
    L_current = zeros(dimension_of_control, dimension_of_latent_space, T-1)
    K_current .= K_initial[:,:,:]
    M_current .= M_initial[:,:,:]
    L_current .= L_initial[:,:,:]

    #one record per K update and one per L update
    n_records = 2*N_iter_tot*N_iter_each_dir
    cost_mom_prop = zeros(n_records)
    cost_trace_formula = zeros(n_records)

    #store the expected cost of the current (K, M, L) at position `idx`, computed both with
    #moment propagation and with the trace formula
    function record_costs!(idx)
        cost_mom_prop[idx] = expected_cost_raw_moments_propagation_EXTENDED_Latent_Space(T, dimension_of_state, dimension_of_latent_space, dimension_of_observation, K_current, M_current, L_current, A, B, H, C, D, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)
        cost_trace_formula[idx] = expected_cost_using_trace_formula_induction_EXTENDED_LSC(M_current, K_current, L_current, A, B, H, C, D, T, dimension_of_state, dimension_of_latent_space, x_1_mean, z_1_mean, Σ_1_x, Σ_1_z, Q_matrix, R_matrix, Ω_ξ, Ω_ω, Ω_η)
        return nothing
    end

    n_done = 0
    for outer in 1:N_iter_tot

        #K optimization
        for sweep in 1:N_iter_each_dir
            K_current .= Input_Weights_Optimization_one_iteration_with_Lagrange_Multipliers_EXTENDED_Latent_Space_Control_optimization_starts_from_1(dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control, K_current, M_current, L_current, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)
            n_done += 1
            record_costs!(n_done)
        end

        #The M optimization step is disabled in this variant.

        #L optimization
        for sweep in 1:N_iter_each_dir
            L_current .= Output_Weights_Optimization_one_iteration_with_Lagrange_Multipliers_EXTENDED_Latent_Space_Control_optimization_starts_from_1(dimension_of_state, dimension_of_latent_space, dimension_of_observation, dimension_of_control, K_current, M_current, L_current, A, B, H, C, D, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix, Σ_1_x, Σ_1_z, x_1_mean, z_1_mean)
            n_done += 1
            record_costs!(n_done)
        end

    end

    return K_current, M_current, L_current, cost_mom_prop, cost_trace_formula

end

#(C): USEFUL FUNCTIONS -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

"""
    get_model_predictions_NSC(random_seed, N_trials_MC, dimension_of_state, dimension_of_latent_space_ext_LSC, dimension_of_observation, T, x_1_mean, z_1_mean, Σ_x_1, Σ_z_1, A, B, C, H, D, K, M, L, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix)

Monte Carlo simulation of the closed-loop system for NSC using the given parameters `K`, `M`, `L`
(can also be used for the classic K, L approach by properly defining M = A + BL - KH).

For every trial the state `x` and the latent variable `z` are propagated for `t = 2, ..., T` as

    x(t) = A x(t-1) + B L(t-1) z(t-1) + sqrt.(Ω_ξ) ξ + ε C L(t-1) z(t-1)
    z(t) = M(t-1) z(t-1) + K(t-1) H x(t-1) + K(t-1) sqrt.(Ω_ω) ω + ρ K(t-1) D x(t-1) + sqrt.(Ω_η) η

where ξ, ω, η are standard normal vectors and ε, ρ are standard normal scalars, all freshly drawn at
every step. The initial conditions are `x_1_mean + sqrt.(Σ_x_1) * randn(...)` and
`z_1_mean + sqrt.(Σ_z_1) * randn(...)`. The instantaneous cost is `x' Q(t) x + z' L(t)' R(t) L(t) z`,
with `L(T)` taken to be zero.

# Arguments
- `random_seed`: passed to `Random.seed!` before simulating.
- `N_trials_MC`: number of simulated trials.
- `dimension_of_state`, `dimension_of_latent_space_ext_LSC`, `dimension_of_observation`: sizes of
  `x`, `z` and of the observation-noise vector.
- `T`: number of time steps.
- `K`, `M`, `L`: time-indexed parameters, indexed as `[:, :, t]` for `t = 1, ..., T-1`. The control
  dimension is read from `size(L, 1)`.
- `Q_matrix`, `R_matrix`: cost matrices, indexed as `[:, :, t]` for `t = 1, ..., T`.

# Returns
- `x_vec`: array of size `(dimension_of_state, N_trials_MC, T)`.
- `z_vec`: array of size `(dimension_of_latent_space_ext_LSC, N_trials_MC, T)`.
- `cost_over_time`: array of size `(N_trials_MC, T)` with the instantaneous costs.
- `mean_total_cost`: mean over trials of the cost summed over time.
- `sem_total_cost`: standard deviation of the per-trial total cost divided by `sqrt(N_trials_MC)`.
"""
function get_model_predictions_NSC(random_seed, N_trials_MC, dimension_of_state, dimension_of_latent_space_ext_LSC, dimension_of_observation, T, x_1_mean, z_1_mean, Σ_x_1, Σ_z_1, A, B, C, H, D, K, M, L, Ω_ξ, Ω_ω, Ω_η, Q_matrix, R_matrix)
    
    Random.seed!(random_seed) # Set a random seed for reproducibility
    #NOTE(review): the random draws below happen inside `@threads`; the exact sample stream presumably
    #depends on how the iterations are split across threads -- confirm if bit-for-bit reproducibility is needed.

    #NOTE(review): `sqrt.` is the ELEMENT-WISE square root. It coincides with a matrix square root
    #only for diagonal covariance matrices -- confirm that all covariances passed here are diagonal.
    Ω_ξ_sqrt = sqrt.(Ω_ξ)
    Ω_ω_sqrt = sqrt.(Ω_ω)
    Ω_η_sqrt = sqrt.(Ω_η)
    #loop-invariant, so computed once instead of once per trial
    Σ_x_1_sqrt = sqrt.(Σ_x_1)
    Σ_z_1_sqrt = sqrt.(Σ_z_1)

    x_vec = zeros(dimension_of_state, N_trials_MC, T)
    z_vec = zeros(dimension_of_latent_space_ext_LSC, N_trials_MC, T)

    cost_over_time = zeros(N_trials_MC, T)

    #L padded with a zero slice at t = T, so that the control cost can be evaluated at every time step.
    #The control dimension is taken from L itself instead of relying on a global variable.
    L_matrix_tmp = zeros(size(L, 1), dimension_of_latent_space_ext_LSC, T)
    L_matrix_tmp[:,:,1:T-1] .= L[:,:,:] 

    @threads for i in 1:N_trials_MC
        
        #initial condition
        x_vec[:,i,1] .= x_1_mean + Σ_x_1_sqrt * randn(dimension_of_state)
        #the latent noise must have the dimension of the latent space (not of the state),
        #otherwise the product with Σ_z_1 fails whenever the two dimensions differ
        z_vec[:,i,1] .= z_1_mean + Σ_z_1_sqrt * randn(dimension_of_latent_space_ext_LSC)
        
        cost_over_time[i,1] = transpose(x_vec[:,i,1])*Q_matrix[:,:,1]*x_vec[:,i,1] + transpose(z_vec[:,i,1])*transpose(L_matrix_tmp[:,:,1])*R_matrix[:,:,1]*L_matrix_tmp[:,:,1]*z_vec[:,i,1]

        for t in 2:T

            #state update: additive noise plus control-dependent (multiplicative) noise through C
            x_vec[:,i,t] .= A * x_vec[:,i,t-1] + B * L[:,:,t-1] * z_vec[:,i,t-1] + Ω_ξ_sqrt * randn(dimension_of_state) + randn() * C * L[:,:,t-1] * z_vec[:,i,t-1]
            #latent update: observation noise, state-dependent (multiplicative) noise through D, internal noise
            z_vec[:,i,t] .= M[:,:,t-1] * z_vec[:,i,t-1]  .+ K[:,:,t-1] * H * x_vec[:,i,t-1] .+ K[:,:,t-1] * Ω_ω_sqrt * randn(dimension_of_observation) .+ randn() * K[:,:,t-1] * D * x_vec[:,i,t-1] .+ Ω_η_sqrt * randn(dimension_of_latent_space_ext_LSC)
            
            cost_over_time[i,t] = transpose(x_vec[:,i,t])*Q_matrix[:,:,t]*x_vec[:,i,t] + transpose(z_vec[:,i,t])*transpose(L_matrix_tmp[:,:,t])*R_matrix[:,:,t]*L_matrix_tmp[:,:,t]*z_vec[:,i,t]

        end

    end

    #Sum over time for each trial
    trial_totals = sum(cost_over_time, dims=2)

    #Compute mean total cost
    mean_total_cost = mean(trial_totals)

    #Compute standard error of the mean (SEM)
    sem_total_cost = std(trial_totals)/sqrt(N_trials_MC)

    return x_vec, z_vec, cost_over_time, mean_total_cost, sem_total_cost

end

"""
    analyze_distribution(data; nbins=50, smooth_points=300)

Given a 1D data array, returns the histogram (normalized), KDE, and Gaussian fit components.

# Arguments
- `data`: Vector{<:Real} — the data to analyze
- `nbins`: Int — number of bins in the histogram (default 50)
- `smooth_points`: Int — number of points on which the fitted Gaussian curve is evaluated (default 300)

# Returns
- `x_centers`, `y_density`: histogram centers and normalized counts
- `x_kde`, `y_kde`: smoothed KDE curve
- `x_gauss`, `y_gauss`: fitted Gaussian curve
- `μ`, `σ`: mean and std of the fitted Gaussian
"""
function analyze_distribution(data::Vector{<:Real}; nbins=50, smooth_points=300)
    # Histogram: bin centers and counts rescaled by (number of samples * width of the first bin)
    histogram = fit(Histogram, data; nbins=nbins, closed=:left)
    edges = histogram.edges[1]
    x_centers = (edges[1:end-1] .+ edges[2:end]) ./ 2
    y_density = histogram.weights ./ (length(data) * (edges[2] - edges[1]))

    # Smoothed curve: grid and density values as returned by `kde`
    density_estimate = kde(data)

    # Gaussian with the sample mean and standard deviation, evaluated on an evenly spaced
    # grid of `smooth_points` points spanning the range of the data
    μ = mean(data)
    σ = std(data)
    x_gauss = range(minimum(data), stop=maximum(data), length=smooth_points)
    y_gauss = pdf.(Normal(μ, σ), x_gauss)

    return x_centers, y_density, density_estimate.x, density_estimate.density, x_gauss, y_gauss, μ, σ
end

# Least-squares straight line through the points (x, y).
# Returns (slope, intercept) with slope = cov(x, y) / var(x) and
# intercept = mean(y) - slope * mean(x).
function simple_linear_fit(x::Vector, y::Vector)
    β = cov(x, y) / var(x)
    ȳ = mean(y)
    x̄ = mean(x)
    α = ȳ - β * x̄
    return β, α
end