clc;
clear all;

% Performance comparison between PGD and FNSL for sparse transition
% matrix estimation.
%
% Each trial draws a random row-sparse d-by-d matrix S, rescales it,
% generates data through ARProduce, and runs both summary routines on the
% same data. The per-trial metrics are accumulated and reported as:
%   TPR  - average true positive rate
%   FAR  - average false alarm rate
%   EE   - average relative estimation error
%   Time - total calculation time (summed over all trials, not averaged)

% Build the path with fullfile so the separator is correct on every
% platform (a literal '\' is only a path separator on Windows).
addpath(genpath(fullfile('..', 'Software_Linear_Convergence_VAR')));

%%%%% Ambient dimension
d = 100;
%%%%% Number of measurements
n = 1000;
%%%%% Non-zero entries (spread evenly over the d rows)
s = 3500;

%%%%% Repeated trials
T = 100;

%%%%% Non-zero entries per row
% Validate up front: randperm(d, k) needs an integer k with 0 <= k <= d,
% and a failure here is far clearer than one raised inside the parfor.
nnzPerRow = s / d;
assert(nnzPerRow == floor(nnzPerRow) && nnzPerRow >= 0 && nnzPerRow <= d, ...
    's = %d must be a non-negative multiple of d = %d with s/d <= d.', s, d);

%%%%% Maximum iteration (identical for every trial)
kmax = 200;

%%%%% Initial point (identical for every trial)
A0 = zeros(d, d);

%%%%% Initialization of the accumulators (parfor reduction variables)
TPR_PGD_sum = 0;
FAR_PGD_sum = 0;
EE_PGD_sum = 0;
Time_PGD_sum = 0;

TPR_FNSL_sum = 0;
FAR_FNSL_sum = 0;
EE_FNSL_sum = 0;
Time_FNSL_sum = 0;

%%%%% Parallel processing
parfor t = 1 : T
%%%%% Ordinary processing
% for t = 1 : T

    %%%%% Construct sparse matrix: nnzPerRow Gaussian entries at random
    %%%%% positions in every row
    S = zeros(d, d);

    for i = 1 : d
        rowS = zeros(d, 1);
        rowS(randperm(d, nnzPerRow)) = randn(nnzPerRow, 1);
        S(i, :) = rowS;
    end

    %%%%% Rescaling so that the 2-norm (largest singular value) of S is 0.9
    S = 0.9 * S / norm(S);

    %%%%% Construct AR process
    % Standard Gaussian d-by-n matrix passed to ARProduce; presumably the
    % innovation noise of the process - confirm against ARProduce.
    E = randn(d, n);

    x0 = zeros(d, 1);
    [Y, xStart] = ARProduce(x0, S, E, n);

    % Stack xStart' on top of the transposed first n-1 columns of Y
    % (presumably the one-step-lagged regressors for Y - verify against
    % ARProduce and the solvers).
    X = [xStart'; Y(:, 1 : (n - 1))'];

    % Run both methods on the same data, starting point and iteration cap.
    [TPR_FNSL, FAR_FNSL, EE_FNSL, Time_FNSL] = FNSL_sparse_Summary(X, Y, A0, kmax, S);
    [TPR_PGD, FAR_PGD, EE_PGD, Time_PGD] = PGD_Sparse_Summary(X, Y, A0, kmax, S);

    % Accumulate per-trial metrics.
    TPR_FNSL_sum = TPR_FNSL_sum + TPR_FNSL;
    FAR_FNSL_sum = FAR_FNSL_sum + FAR_FNSL;
    EE_FNSL_sum = EE_FNSL_sum + EE_FNSL;
    Time_FNSL_sum = Time_FNSL_sum + Time_FNSL;

    TPR_PGD_sum = TPR_PGD_sum + TPR_PGD;
    FAR_PGD_sum = FAR_PGD_sum + FAR_PGD;
    EE_PGD_sum = EE_PGD_sum + EE_PGD;
    Time_PGD_sum = Time_PGD_sum + Time_PGD;

end

%%%%% Averages over the T trials (time is reported as a total)
TPR_FNSL_avg = TPR_FNSL_sum / T;
FAR_FNSL_avg = FAR_FNSL_sum / T;
EE_FNSL_avg = EE_FNSL_sum / T;

TPR_PGD_avg = TPR_PGD_sum / T;
FAR_PGD_avg = FAR_PGD_sum / T;
EE_PGD_avg = EE_PGD_sum / T;

fprintf('FNSL: \n TPR: %f \t FAR: %f \t EE: %f \t Time: %f \n', TPR_FNSL_avg, FAR_FNSL_avg, EE_FNSL_avg, Time_FNSL_sum)
fprintf('AltPGD: \n TPR: %f \t FAR: %f \t EE: %f \t Time: %f \n', TPR_PGD_avg, FAR_PGD_avg, EE_PGD_avg, Time_PGD_sum)


