% This code depends on MATLAB, which is non-open-source (proprietary) software.

clear all

%% Experiment parameters
% Problem size:
%   p     - dimension of the data (number of features)
%   n     - number of SGD iterations (one sample row of X is used per iteration)
%   nrep  - number of independent repetitions averaged in the final plot
%   nmeth - number of methods run per repetition (only method 1, averaged SGD,
%           is implemented in the loop below)
p     = 300;
n     = 10^6;
nrep  = 10;
nmeth = 1;

% Regularity exponents:
%   a - alpha of the capacity condition (sets the eigenvalue decay of the
%       data covariance)
%   b - beta of the source condition (sets the decay of the optimum w0)
a = 3/4;
b = 1;

%% To reproduce experiment 1 of the paper, use instead:
% a = 1/2
% b = 0
%%
 
% Evaluation schedule: log-spaced iterations at which the risk is recorded.
% The exponents 0:.02:log10(n) give 50 points per decade between 1 and n;
% after rounding to integers, duplicates (frequent for small i) are removed.
sampling_test_moments = unique(round(10.^(0:.02:log10(n))));

% indices_test_points(i) is the storage slot for iteration i, or 0 when the
% risk is not recorded at iteration i.
indices_test_points = zeros(n, 1);
indices_test_points(sampling_test_moments) = 1:numel(sampling_test_moments);
% Iteration 1 always maps to slot 1 (already implied since 10^0 = 1 is the
% first sample; kept as a guard).
indices_test_points(1) = 1;

% Risk storage, one row per scheduled iteration,
% size numel(sampling_test_moments) x nmeth x nrep:
%   testw    - risk of the last SGD iterate
%   testwave - risk of the averaged SGD iterate
testw    = zeros(numel(sampling_test_moments), nmeth, nrep);
testwave = zeros(numel(sampling_test_moments), nmeth, nrep);

%% Synthetic least-squares problem
% cova is a p x p matrix with random singular vectors (from the SVD of a
% Gaussian matrix) and singular values i^(-.5/(1-a)). Data rows are later
% drawn as randn(1,p)*cova, so H = cova'*cova has eigenvalues
% 1/i^(1/(1-a)), i = 1..p.
cova = randn(p, p);
[u, ~, v] = svd(cova);
cova = u * diag(1 ./ (1:p).^(.5/(1-a))) * v';

% Smallest eigenvalue of H (index i = p).
mu = 1/p^(1/(1-a));

% Optimum w0 satisfying the source condition: its coordinates in the
% eigenbasis v of H are i^(-(.5 + .5*b/(1-a))). It is then rescaled so that
% w0'*H*w0 = 2, i.e. the starting risk .5*||cova*w0||^2 of w = 0 equals 1.
w0 = ones(1, p) * diag(1 ./ (1:p).^(.5 + .5*b/(1-a))) * v';
w0 = w0';
w0 = sqrt(2) * w0 / sqrt(w0' * cova' * cova * w0);

% SGD starts at the origin with step size gamma = 1/(2*Tr(H)),
% where R^2 = trace(cova'*cova) = Tr(H).
w_ini = zeros(p, 1);
R = sqrt(trace(cova' * cova));
gamma = 1/R^2/2;

%% Independent repetitions (results are averaged over irep when plotting)
for irep = 1:nrep
    % Fresh dataset for this repetition. There is no noise: y = X*w0 exactly.
    % NOTE: X stores n*p doubles (about 2.4 GB for n = 10^6, p = 300).
    X = randn(n,p) * cova;
    y = X*w0;

    for imethod = 1:nmeth
        % Progress display (printed on purpose: no semicolon).
        [irep, imethod]

        switch imethod
            case 1
                % Averaged SGD: a single pass over the n rows of X.
                w    = w_ini;   % last iterate
                wave = w_ini;   % running mean of the iterates so far

                for i = 1:n
                    % When iteration i is on the schedule, record the risk
                    % .5*||cova*(. - w0)||^2 of both iterates. This is taken
                    % BEFORE the i-th update, i.e. after i-1 SGD steps.
                    slot = indices_test_points(i);
                    if slot
                        testw(slot, imethod, irep)    = .5 * norm(cova*(w - w0))^2;
                        testwave(slot, imethod, irep) = .5 * norm(cova*(wave - w0))^2;
                    end

                    % Stochastic gradient of .5*(x_i*w - y_i)^2, x_i = X(i,:).
                    xi   = X(i,:);
                    grad = xi' * (xi*w - y(i));
                    w    = w - gamma * grad;

                    % Incremental mean: afterwards wave is the average of the
                    % i+1 iterates w_0, ..., w_i.
                    wave = wave + 1/(1+i)*(w - wave);
                end
        end
    end
end
%% Plot the risk, averaged over the nrep repetitions, against the iteration count
figure
% Last iterate.
loglog(sampling_test_moments, mean(testw,3), 'linewidth', 6); hold on;
% Averaged iterate.
loglog(sampling_test_moments, mean(testwave,3), 'linewidth', 6);
% Reference curve 30/t^(1+min(a,b)) to compare the convergence rate against.
loglog(sampling_test_moments, 30./(sampling_test_moments).^(1+min(a,b)), '--', 'linewidth', 5);

set(gca,'FontSize',20)

legend({ 'Last iterate', 'Averaged iterate'},'Location','Southwest', 'FontSize',20)

% Build the title from the actual parameters so it stays correct when the
% experiment settings (n, p, a, b) are changed. Concatenation is used instead
% of sprintf, which would interpret "\a" and "\b" in \alpha / \beta as
% control characters.
if n == 10^round(log10(n))
    n_str = ['10^{' num2str(round(log10(n))) '}'];
else
    n_str = num2str(n);
end
title(['$n=' n_str '$ $d = ' num2str(p) '$ $\alpha=' num2str(a) '$ $\beta=' num2str(b) '$'], ...
    'Interpreter', 'latex', 'FontSize', 30)

% The x-axis is logarithmic, so its lower limit must be positive. The
% recorded iterations run from 1 to n.
axis([1 n 10^(-13) 10])

% Dashed vertical line at t = 1/(gamma*mu) (gamma: step size, mu: smallest
% eigenvalue of H) to show the regime of linear convergence.
yL = get(gca,'YLim');
line([1/(gamma*mu) 1/(gamma*mu)],yL,'Color','black','LineStyle','--','linewidth',1,'HandleVisibility','off');
xlabel('Number of iterations', 'Interpreter','latex','FontSize',30)
ylabel('$\mathcal{R}(\theta)$', 'Interpreter','latex','FontSize',30)