clc;
clear all;

% Performance comparison between AltPGD and FNSL for estimating sparse plus
% low-rank transition matrices of a simulated AR process.
% Performance metrics:
% TPR  - Average true positive rate
% FAR  - Average false alarm rate
% EE   - Average relative estimation error
% Time - Total calculation time (summed over all trials, not averaged)

% Make the solver package visible. fullfile inserts the separator of the
% current platform instead of a hard-coded Windows backslash.
addpath(genpath(fullfile('..', 'Software_Linear_Convergence_VAR')));

%%%%%Ambient dimension
d = 100;
%%%%%Number of measurements (time samples generated per trial)
n = 2500;

%%%%%Non-zero entries of sparse transition matrix (spread as s/d per row)
s = 3500;
%%%%%Rank of low-rank transition matrix
r = 2;

% Every row of the sparse matrix receives exactly s/d non-zeros, so that
% quotient is used as a count and must be a whole number.
assert(mod(s, d) == 0, 's must be a multiple of d');

%%%%%Repeated trials
T = 100;

%%%%%Initialization of the per-trial accumulators (parfor reduction variables)
TPR_Alt_sum = 0;
FAR_Alt_sum = 0;
EE_Alt_sum = 0;
Time_Alt_sum = 0;

TPR_FNSL_sum = 0;
FAR_FNSL_sum = 0;
EE_FNSL_sum = 0;
Time_FNSL_sum = 0;


%%%%%Parallel processing over T independent trials.
% Each trial draws a fresh sparse S and low-rank L, simulates an AR process
% with transition matrix S + L via ARProduce, and runs both estimators on
% the same data. The *_sum variables are parfor reduction variables.
% (For ordinary serial processing, replace "parfor" with "for".)

%%%%%Maximum iteration (loop-invariant, so set once outside the loop)
kmax = 200;

%%%%%Initial point, used for both the sparse and the low-rank component
A0 = zeros(d, d);

%%%%%Number of non-zeros placed in each row of the sparse matrix
nnzPerRow = s / d;

%%%%%Scale of the driving noise (the original code used 1*eye(d))
sigma = 1;

parfor t = 1 : T
    %%%Construct sparse matrix: each row gets nnzPerRow Gaussian entries
    %%%at uniformly random column positions
    S = zeros(d, d);
    for i = 1 : d
        temp = zeros(d, 1);
        temp(randperm(d, nnzPerRow)) = randn(nnzPerRow, 1);
        S(i, :) = temp;
    end

    %%%Construct low-rank matrix (rank at most r)
    LL = randn(d, r);
    LR = randn(d, r);
    L = LL * LR';

    %%%Rescaling: divide both parts by the SAME constant so that
    %%%S + L = 0.9 * (5*S + 0.5*L) / norm(5*S + 0.5*L), i.e. the combined
    %%%transition matrix has spectral norm 0.9. The constant must be taken
    %%%before S is overwritten; recomputing it afterwards scales L wrongly.
    scale = norm(5 * S + 0.5 * L);
    S = 0.9 * 5 * S / scale;
    L = 0.9 * 0.5 * L / scale;

    %%%Construct AR process driven by i.i.d. Gaussian noise
    E = sigma * randn(d, n);

    x0 = zeros(d, 1);
    [Y, xStart] = ARProduce(x0, S + L, E, n);

    % Regressors: the starting state followed by the first n-1 samples of Y
    X = [xStart'; Y(:, 1: (n-1))'];

    [TPR_FNSL, FAR_FNSL, EE_FNSL, Time_FNSL] = FNSL_sparsepluslowrank_Summary(X, Y, A0, A0, kmax, S, L);
    [TPR_Alt, FAR_Alt, EE_Alt, Time_Alt] = AltPGD_SparsePlusLowrank_Summary(X, Y, A0, A0, kmax, S, L);

    TPR_FNSL_sum = TPR_FNSL_sum + TPR_FNSL;
    FAR_FNSL_sum = FAR_FNSL_sum + FAR_FNSL;
    EE_FNSL_sum = EE_FNSL_sum + EE_FNSL;
    Time_FNSL_sum = Time_FNSL_sum + Time_FNSL;

    TPR_Alt_sum = TPR_Alt_sum + TPR_Alt;
    FAR_Alt_sum = FAR_Alt_sum + FAR_Alt;
    EE_Alt_sum = EE_Alt_sum + EE_Alt;
    Time_Alt_sum = Time_Alt_sum + Time_Alt;
end

% ---- Turn the accumulated sums into averages over the T trials ----
% AltPGD
TPR_Alt_avg = TPR_Alt_sum ./ T;
FAR_Alt_avg = FAR_Alt_sum ./ T;
EE_Alt_avg  = EE_Alt_sum  ./ T;

% FNSL
TPR_FNSL_avg = TPR_FNSL_sum ./ T;
FAR_FNSL_avg = FAR_FNSL_sum ./ T;
EE_FNSL_avg  = EE_FNSL_sum  ./ T;

% ---- Report; Time is the total over all trials, not an average ----
fprintf('FNSL: \n TPR: %f \t FAR: %f \t EE: %f \t Time: %f \n', ...
    TPR_FNSL_avg, FAR_FNSL_avg, EE_FNSL_avg, Time_FNSL_sum)
fprintf('AltPGD: \n TPR: %f \t FAR: %f \t EE: %f \t Time: %f \n', ...
    TPR_Alt_avg, FAR_Alt_avg, EE_Alt_avg, Time_Alt_sum)


