clc;
% Only workspace variables need resetting; "clear all" would additionally
% purge loaded functions, MEX files and breakpoints, which just slows MATLAB down.
clear;

% Sparse plus low-rank transition matrices estimation by AltPGD
% Curves of squared estimation error vs. number of measurements under different ambient dimensions
% Curves of squared estimation error vs. number of measurements rescaled by Gaussian width under different ambient dimensions

% fullfile inserts the platform's file separator, so the toolbox path also
% resolves on Linux/macOS (a hard-coded '..\...' only works on Windows).
addpath(genpath(fullfile('..', 'Software_Linear_Convergence_VAR')));

% Number of sample sizes tested per ambient dimension.
nlength = 10;

%%%%%Different numbers of measurements
% Row j of n lists the nlength sample sizes used together with ambient dimension d(j).
n = ceil([linspace(1000, 6000, nlength); linspace(1000, 8000, nlength); linspace(1000, 10000, nlength)]);

%%%%%Different ambient dimensions
d = [50, 75, 100];
dLength = length(d);
% n is indexed as n(j, in) together with d(j), so the row counts must agree.
assert(size(n, 1) == dLength, 'n must have one row of sample sizes per ambient dimension in d.');

%%%%%Non-zero entries of sparse matrix
s = 300;
%%%%%Rank of low-rank matrix
r = 3;

%%%%%Maximum iteration
kmax = 100;
%%%%%Repeated trials
T  = 100;

%%%%%Initialization
% squaredErrorTensor(t, j, in): squared estimation error of trial t for
% ambient dimension d(j) and sample size n(j, in).
squaredErrorTensor = zeros(T, dLength, nlength);

% Every row of the sparse component gets s / d(j) non-zeros (s in total), and
% randperm needs an integer count, so s must be a multiple of each d(j).
assert(all(mod(s, d) == 0), 's must be divisible by every ambient dimension in d.');

%%%%%Parallel processing
parfor t = 1 : T
%%%%%Ordinary processing
% for t = 1 : T
    for in = 1 : nlength
        for j = 1 : dLength
            dj = d(j);          % current ambient dimension
            N  = n(j, in);      % current number of samples
            nnzPerRow = s / dj; % non-zeros per row of the sparse component

            %%%Construct Sparse matrix: nnzPerRow Gaussian entries at random columns of each row
            S = zeros(dj, dj);
            for i = 1 : dj
                temp = zeros(dj, 1);
                temp(randperm(dj, nnzPerRow)) = randn(nnzPerRow, 1);
                S(i, :) = temp;
            end

            %%%Construct low-rank matrix (rank at most r) from two Gaussian factors
            LL = randn(dj, r);
            LR = randn(dj, r);
            L = LL * LR';

            %%%Rescaling
            % Weight S by 5 and L by 0.5, then normalize so that the spectral
            % norm ||S + L||_2 equals 0.9; this also bounds the spectral radius
            % of the transition matrix S + L by 0.9. The normalizer must be
            % computed ONCE from the unscaled matrices: recomputing it after S
            % has been overwritten divides L by a different constant and
            % breaks the ||S + L||_2 = 0.9 target.
            scale = norm(5 * S + 0.5 * L);
            S = 0.9 * 5 * S / scale;
            L = 0.9 * 0.5 * L / scale;

            %%%Construct AR process
            % i.i.d. standard Gaussian noise, one column per sample.
            E = randn(dj, N);

            x0 = zeros(dj, 1);
            % NOTE(review): ARProduce is external; presumably it simulates
            % x_{k+1} = (S + L) x_k + e_k from x0 and returns the N states as
            % columns of Y plus the state preceding Y(:, 1) -- confirm.
            [Y, xStart] = ARProduce(x0, S + L, E, N);

            % Regressor matrix (N-by-dj): row 1 is xStart', rows 2..N are the
            % first N-1 columns of Y transposed, i.e. Y shifted by one step.
            X = [xStart'; Y(:, 1 : (N - 1))'];

            %%%%%Initial point (used for both the sparse and the low-rank estimate)
            A0 = zeros(dj, dj);

            % NOTE(review): AltPGD_SparsePlusLowrank_ErrorScale is external;
            % it is given the ground truth S and L and is assumed to return a
            % scalar squared estimation error after kmax iterations -- confirm.
            squaredErrorTensor(t, j, in) = AltPGD_SparsePlusLowrank_ErrorScale(X, Y, A0, A0, kmax, S, L);
        end
    end
end



% Average over the trial dimension: squaredError(j, in) is the mean squared
% error for ambient dimension d(j) and sample size n(j, in). reshape with an
% explicit shape (instead of squeeze) keeps the result dLength-by-nlength even
% when only one ambient dimension is used, so squaredError(j, :) stays a row.
squaredError = reshape(mean(squaredErrorTensor, 1), size(squaredErrorTensor, 2), size(squaredErrorTensor, 3));

%%%%%Original squared estimation error vs. number of measurements
figure(1)
% Start from an empty figure: "hold on" persists on an existing figure, so
% without this a re-run of the script would overlay the new curves on old ones.
clf
markers = {'-*', '-o', '-d'};
% One curve per ambient dimension d(iDim), against its own sample sizes n(iDim, :).
for iDim = 1 : size(squaredError, 1)
    plot(n(iDim, :), squaredError(iDim, :), markers{mod(iDim - 1, numel(markers)) + 1}, 'MarkerSize', 6, 'LineWidth', 1.2);
    hold on
end

xlabel('The number of samples $N$','fontsize',14,'interpreter','latex')
ylabel('Squared estimation error for \boldmath${\hat{S}}$ + \boldmath${\hat{L}}$','fontsize',14,'interpreter','latex')
% Labels are derived from d so they cannot drift from the simulated dimensions.
legend(arrayfun(@(dd) sprintf('d = %d', dd), d(1 : size(squaredError, 1)), 'UniformOutput', false), 'fontsize', 14);
axis tight

%%%%%Rescaled number of measurements
% Each sample size n(j, :) is divided by (wS(j) + wL(j))^2, where
%   wS = sqrt(2 s log(e d^2 / s))  (depends on the sparsity level s)
%   wL = sqrt(6 r d)               (depends on the rank r)
% are the Gaussian-width terms for the sparse and low-rank components,
% evaluated for every ambient dimension in d at once.
widthSparse  = sqrt(2 * s * log(exp(1) .* d .* d / s));
widthLowRank = sqrt(6 * r * d);
% Column of squared widths, so row j of n is divided by the value for d(j).
nRescaled = n ./ ((widthSparse + widthLowRank) .^ 2).';

% Per-dimension rows kept under their original names for the plotting code.
n1 = nRescaled(1, :);
n2 = nRescaled(2, :);
n3 = nRescaled(3, :);

%%%%%Squared estimation error vs. rescaled number of measurements
figure(2)
% Start from an empty figure so a re-run does not overlay curves left over
% from a previous run (the figure keeps its "hold on" state).
clf
rescaledN = {n1, n2, n3};
markers = {'-*', '-o', '-d'};
% Curve k: rescaled sample sizes for ambient dimension d(k) vs. its mean error.
for k = 1 : numel(rescaledN)
    plot(rescaledN{k}, squaredError(k, :), markers{k}, 'MarkerSize', 6, 'LineWidth', 1.2);
    hold on
end

xlabel('The rescaled sample size $N / (\omega(C_{S} \cap S_{F}) + \omega(C_{L} \cap S_{F}))^2$','fontsize',14,'interpreter','latex')
ylabel('Squared estimation error for \boldmath${\hat{S}}$ + \boldmath${\hat{L}}$','fontsize',14,'interpreter','latex')

% Labels are derived from d so they stay in sync with the simulated dimensions.
legend(arrayfun(@(dd) sprintf('d = %d', dd), d(1 : numel(rescaledN)), 'UniformOutput', false), 'fontsize', 14);
axis tight

