clc;
clear;   % plain 'clear' resets the workspace; 'clear all' also flushes compiled functions and breakpoints

% Compare the convergence rates of AltPGD and FNSL for estimating
% sparse-plus-low-rank transition matrices.

% Make the solver package available. fullfile inserts the platform's path
% separator, whereas a hard-coded '\' only works on Windows.
addpath(genpath(fullfile('..', 'Software_Linear_Convergence_VAR')));

%%%%%Ambient dimension
d = 100;
%%%%%Number of measurements
n = 8000;

%%%%%Non-zero entries of sparse matrix (placed as s/d non-zeros per row)
s = 3500;
%%%%%Rank of low-rank matrix
r = 3;

%%%%%Repeated trials
T = 100;

% The sparse generator draws s/d non-zero positions per row via randperm,
% which requires an integer count, so s must be a multiple of d.
assert(mod(s, d) == 0, 's must be a multiple of d');

%%%%%Initialization
errorFNSL = [];
errorAltPGD = [];


%%%%%Parallel processing
parfor tT = 1 : T
    %%%%%Ordinary processing
    % for tT = 1 : T

    %%%Construct sparse matrix: every row gets s/d Gaussian non-zeros at
    %%%random column positions (requires s to be a multiple of d).
    S = zeros(d, d);
    nnzPerRow = s / d;
    for i = 1 : d
        temp = zeros(d, 1);
        temp(randperm(d, nnzPerRow)) = randn(nnzPerRow, 1);
        S(i, :) = temp;
    end

    %%%Construct low-rank matrix (rank at most r)
    LL = randn(d, r);
    LR = randn(d, r);
    L = LL * LR';

    %%%%%Rescaling: divide BOTH components by the same factor, computed
    %%%%%before either is modified, so that S + L = 0.9*(5*S + 0.5*L)/scale
    %%%%%has spectral norm exactly 0.9 (< 1, which bounds the spectral
    %%%%%radius of the transition matrix). Previously L was scaled using the
    %%%%%already-rescaled S, so the two parts got different factors.
    scale = norm(5 * S + 0.5 * L);
    S = 0.9 * 5 * S / scale;
    L = 0.9 * 0.5 * L / scale;

    %%%Construct AR process driven by standard Gaussian noise
    E = randn(d, n);

    x0 = zeros(d, 1);
    [Y, xStart] = ARProduce(x0, S + L, E, n);

    % Regressors: the starting state followed by the first n-1 outputs
    % (presumably Y is d-by-n with one time step per column -- confirm
    % against ARProduce).
    X = [xStart'; Y(:, 1 : (n - 1))'];

    %%%%%Maximum iteration
    kmax = 150;

    %%%%%Initial points
    S0 = zeros(d, d);

    D = norm(svd(L), 1);   % nuclear norm of L (sum of singular values)
    L0 = D * eye(d, d);    %%%%%Satisfies the requirement of AltPGD from the second step on

    relativeError_FNSL = FNSL_sparsepluslowrank(X, Y, S0, L0, kmax, S, L);
    relativeError_AltPGD = AltPGD_SparsePlusLowrank(X, Y, S0, L0, kmax, S, L);

    %%%%%Relative error at the (shared) initial point, prepended to each
    %%%%%error trajectory. Row order across trials is not deterministic
    %%%%%under parfor, which is fine for the mean/std taken later.
    initialError = norm(S0 - S + L0 - L, 'fro') / norm(S + L, 'fro');
    errorFNSL = [errorFNSL; [initialError, relativeError_FNSL]];
    errorAltPGD = [errorAltPGD; [initialError, relativeError_AltPGD]];
end


figure(1)
% Column 1 of each error matrix is the initial point, so index columns as
% iterations 0, 1, 2, ... rather than starting at 1.
ite = 0 : (size(errorFNSL, 2) - 1);
curveFNSL = shadedErrorBar(ite, errorFNSL, {@mean,@std}, 'lineprops', '-b','patchSaturation',0.075);
curveFNSL.mainLine.LineWidth = 1.5;
hold on;
curveAltPGD = shadedErrorBar(ite, errorAltPGD, {@mean,@std}, 'lineprops', '-r','patchSaturation',0.075);
curveAltPGD.mainLine.LineWidth = 1.5;

set(gca, 'YScale', 'log')
% x-range follows the data length instead of a hard-coded iteration count.
axis([ite(1), ite(end), 0.15, 10])
xlabel('Iteration', 'fontsize',15)
ylabel('Relative error of \boldmath$S+L$','interpreter','latex', 'fontsize',15)
% Pass the mean-line handles explicitly so the labels cannot attach to the
% shaded patches or edge lines that shadedErrorBar also creates.
legend([curveFNSL.mainLine, curveAltPGD.mainLine], {'FNSL', 'AltPGD'}, 'fontsize',14);
grid on
