%% Data-driven discovery of linear dynamical systems with process and measurement noise

%% Example 1: Discrete-time linear dynamical model.
% Start from a clean session: wipe workspace variables, the command window
% and any open figures.
clear; clc; close all;
% Put the helper functions in ./tools (recursively) on the search path.
addpath(genpath('tools'));
% Fix the global random stream so every simulated dataset below is reproducible.
rng(100, 'twister');
% Noise covariances. Naming convention used throughout this script:
%   R is the PROCESS-noise covariance (it enters the state equation), and
%   Q is the MEASUREMENT-noise covariance (it enters the output equation).
% Each matrix is then overwritten by its lower-triangular Cholesky factor L
% (L*L' equals the covariance), so L * randn(k,1) is a zero-mean Gaussian
% sample with that covariance.
R = [0.49 0.25; 0.25 0.49];
Q = [0.49 0.25; 0.25 0.49];
R = chol(R)';
Q = chol(Q)';
% Initial state iv: the nominal value xi perturbed by Gaussian noise with
% standard deviation rr.
xi = [1, 1]';
rr = 1;
iv = xi + rr * randn(numel(xi), 1);
% True system simulated below:
%   x_t = A x_{t-1} + B u_t + R w_t,   y_t = C x_t + D u_t + Q v_t
% with w_t, v_t standard normal and x_0 = iv.
A = [0 0.9; 0.9 0];
B = 2 * eye(2);
C = 2 * eye(2);
D = 1.5 * eye(2);
n = size(A,1);   % state dimension
mm = size(C,1);  % output dimension
% Simulation horizon: the first 2/3 of the samples are used for training and
% the remaining 1/3 for testing. round() keeps T_train an integer for any
% T_total (3000 gives exactly 2000 / 1000, as before).
T_total = 3000;
T_train = round(T_total*2/3);
T_test = T_total - T_train;
% Training dataset: simulate the true system from the initial state iv with
% inputs drawn uniformly from (0, 2) and Gaussian process/measurement noise.
% The per-step call order rand -> randn(n) -> randn(mm) is kept fixed so the
% generated data are unchanged for a given seed.
nu = size(B, 2);                  % input dimension
X_train = zeros(n, T_train);
Y_train = zeros(mm, T_train);
U_train = zeros(nu, T_train);     % preallocated (previously grown in the loop)
x_prev = iv;                      % state at the previous step (x_0 = iv)
for i = 1:T_train
    U_train(:,i) = 2 * rand(nu, 1);
    X_train(:,i) = A * x_prev + B * U_train(:,i) + R * randn(n,1);
    Y_train(:,i) = C * X_train(:,i) + D * U_train(:,i) + Q * randn(mm,1);
    x_prev = X_train(:,i);
end
% Test dataset: a second trajectory of T_test samples, restarted from the
% same initial state iv, with fresh inputs and noise draws.
X_test = zeros(2, T_test);
Y_test = zeros(2, T_test);
U_test = zeros(2, T_test);
x_state = iv;                               % state carried between steps
for t = 1:T_test
    u_t = 2 * rand(2, 1);                   % control input
    x_state = A * x_state + B * u_t + R * randn(n,1);
    U_test(:,t) = u_t;
    X_test(:,t) = x_state;
    Y_test(:,t) = C * x_state + D * u_t + Q * randn(mm,1);
end
%% Our method
% The model order n is a modelling choice. The output/input dimensions are
% read from the training data instead of being hard-coded (both are 2 here).
n = 2;                       % state order of the identified model
mm = size(Y_train, 1);       % output dimension
p = size(U_train, 1);        % input dimension
Iter_Max = 400;              % iteration cap passed to the estimator
[A_our, B_our, C_our, D_our, R_our, Q_our] = Our_Method(Y_train, U_train, n, mm, p, T_train, Iter_Max, xi, rr);
% Mean relative error between Y_test and the identified model's outputs
% produced by state_space_init_predict on the test inputs.
[X_our, Y_our] = state_space_init_predict(A_our, B_our, C_our, D_our, U_test, Y_test, T_test);
norm_relative_errors_our = calculate_relative_errors(Y_test, Y_our);
average_norm_relative_error_our = mean(norm_relative_errors_our);
%% Maximum Likelihood Estimation (MLE) method
% Baseline estimator, called with exactly the same arguments as Our_Method.
[A_mle, B_mle, C_mle, D_mle, R_mle, Q_mle] = ...
    MLE(Y_train, U_train, n, mm, p, T_train, Iter_Max, xi, rr);
% Test-set evaluation: per-sample relative errors, then their mean.
[X_mle, Y_mle] = state_space_init_predict( ...
    A_mle, B_mle, C_mle, D_mle, U_test, Y_test, T_test);
norm_relative_errors_mle = calculate_relative_errors(Y_test, Y_mle);
average_norm_relative_error_mle = mean(norm_relative_errors_mle);
%% N4SID
% Subspace-identification baseline. Ts = 1 matches the unit-step
% discrete-time simulation above; iddata expects one sample per row, hence
% the transposes.
Ts = 1;
data = iddata(Y_train', U_train', Ts);
% n4sid fixes D = 0 unless feedthrough is requested. The true system has a
% nonzero D (1.5*I), so D must be estimated for a fair comparison.
n4sid_sys = n4sid(data, n, 'Feedthrough', true);
% mean relative error
[X_n4sid_sys, Y_n4sid_sys] = state_space_init_predict(n4sid_sys.A, n4sid_sys.B, n4sid_sys.C, n4sid_sys.D, U_test, Y_test, T_test);
norm_relative_errors_n4sid = calculate_relative_errors(Y_test, Y_n4sid_sys);
average_norm_relative_error_n4sid = mean(norm_relative_errors_n4sid);
%% PEM
% Prediction-error estimation starting from an identity-matrix initial model.
data = iddata(Y_train', U_train', 1);
A_in = eye(2); B_in = eye(2); C_in = eye(2); D_in = eye(2); K_in = eye(2); x0 = [1,1]';
% Give the initial model the data's sample time. Without an explicit Ts,
% idss does not create a discrete-time model with this sample time, and pem
% would not return discrete-time A,B,C,D matrices matching the discrete-time
% predictor used below.
sys_in = idss(A_in, B_in, C_in, D_in, K_in, x0, data.Ts);
pem_sys = pem(data, sys_in, 'Focus', 'prediction');
% mean relative error
[X_pem, Y_pem] = state_space_init_predict(pem_sys.A, pem_sys.B, pem_sys.C, pem_sys.D, U_test, Y_test, T_test);
norm_relative_errors_pem = calculate_relative_errors(Y_test, Y_pem);
average_norm_relative_error_pem = mean(norm_relative_errors_pem);
%% Display results
% Print each method's label followed by its average norm relative error,
% driven by a label/value table instead of repeated disp pairs.
result_labels = {'Average Norm Relative Error (Our Method):', ...
                 'Average Norm Relative Error (MLE):', ...
                 'Average Norm Relative Error (n4sid):', ...
                 'Average Norm Relative Error (PEM):'};
result_values = {average_norm_relative_error_our, ...
                 average_norm_relative_error_mle, ...
                 average_norm_relative_error_n4sid, ...
                 average_norm_relative_error_pem};
for r = 1:numel(result_labels)
    disp(result_labels{r});
    disp(result_values{r});
end
% Drop the temporaries so the script leaves the same workspace as before.
clear result_labels result_values r