% Toy example: L2 convergence rates of plain SGD vs. regularized SGD (reg-SGD)

% Reset the workspace and close any open figure windows.
clear;
close all;

% Problem dimension (number of unknowns).
n = 2;

% Least-squares loss f(x) = 1/2*(A*x - 1)^2 for the single row A = [1 1].
A = [1 1];
Q = A*A.';                              % Gram matrix A*A' (1x1 here)
f = @(x) 0.5*(A*x - 1)'*(A*x - 1);

% Exact (noise-free) gradient: grad f(x) = A'*(A*x - 1).
gradf = @(x) A.'*(A*x - 1);

% Minimum-norm solution of A*x = 1, i.e. the minimizer of f closest to 0.
xast = lsqminnorm(A, 1);

% number of SGD / reg-SGD iterations per run
final_iterate = 10^4;

% loss and residual are stored for the initial point (slot 1) and after
% every `steps`-th iteration (slot k/steps+1). Both the storage size
% final_iterate/steps+1 and the plotting grid steps:steps:final_iterate
% require `steps` to divide `final_iterate` evenly.
steps = 10;
assert(steps >= 1 && mod(final_iterate, steps) == 0, ...
    'final_iterate (%d) must be a positive multiple of steps (%d).', ...
    final_iterate, steps);

% number of independent Monte Carlo runs
MC_runs = 100;

% storage for loss and residual: one row per run, one column per recorded iterate
loss_SGD = zeros(MC_runs,final_iterate/steps+1);
res_SGD = zeros(MC_runs,final_iterate/steps+1);
loss_regSGD = zeros(MC_runs,final_iterate/steps+1);
res_regSGD = zeros(MC_runs,final_iterate/steps+1);

% learning rate and regularization exponents
xi = 1;            % viscosity rate
p = 1/(4*xi+3);    % regularization decay: lambda_k = k^(-p)
q = (1+p)/2;       % learning rate decay: step size k^(-q)

beta = 2*q-1;      % expected rate (algebraically equal to p, since q = (1+p)/2)

% Monte Carlo simulation: run plain SGD and reg-SGD side by side on the
% same initial point and the same gradient-noise realization, recording
% the loss f(X_k) and the squared residual ||X_k - xast||^2.
tic;
for m = 1:MC_runs

    % same random initialization for SGD and reg-SGD
    x0 = randn(n,1);
    x_SGD = x0;
    x_regSGD = x0;

    loss_SGD_path = zeros(1,final_iterate/steps+1);
    res_SGD_path = zeros(1,final_iterate/steps+1);
    loss_regSGD_path = zeros(1,final_iterate/steps+1);
    res_regSGD_path = zeros(1,final_iterate/steps+1);

    % slot 1 holds the initial point. The residual is the SQUARED distance
    % to xast, consistent with the values recorded inside the loop below
    % and with the plotted quantity E[||X_k - x_*||^2].
    loss_SGD_path(1) = f(x0);
    res_SGD_path(1) = norm(x0-xast)^2;
    loss_regSGD_path(1) = f(x0);
    res_regSGD_path(1) = norm(x0-xast)^2;

    for k = 1:final_iterate

        % decaying regularization
        lambdak = 1/k^p;

        % learning rate for reg-SGD
        step_sizek = 1/k^q;

        % learning rate for SGD without regularization
        step_sizek_literature = 1/k^(1/2);

        % synthetic gradient noise; the same realization is used for
        % SGD and reg-SGD
        noisygrad = 0.1*randn(n,1);

        % one step of SGD and of reg-SGD (gradient of f plus the
        % Tikhonov term lambdak*x)
        x_SGD = x_SGD-step_sizek_literature*(gradf(x_SGD)+noisygrad);
        x_regSGD = x_regSGD-step_sizek*(gradf(x_regSGD)+noisygrad)-step_sizek*lambdak*x_regSGD;

        % record loss and squared residual after every `steps`-th iteration
        if mod(k, steps) == 0
            idx = k/steps + 1;
            loss_SGD_path(idx) = f(x_SGD);
            res_SGD_path(idx) = norm(x_SGD-xast)^2;
            loss_regSGD_path(idx) = f(x_regSGD);
            res_regSGD_path(idx) = norm(x_regSGD-xast)^2;
        end
    end

    loss_SGD(m,:) = loss_SGD_path;
    res_SGD(m,:) = res_SGD_path;
    loss_regSGD(m,:) = loss_regSGD_path;
    res_regSGD(m,:) = res_regSGD_path;
end

% take running time
runningtime = toc;

% uncomment if you want to save results
% save('results_L2.mat')


% Plot styling, one entry per curve (indexed by the plotting counter i):
% RGB line colors, RGBA shadow colors with alpha 0.5, and line styles.
col = num2cell([0.7 0   0  ; ...
                0   0.5 0  ; ...
                0   0   0.7; ...
                0.7 0.2 0.7; ...
                0.7 0.7 0.2], 2).';
col_shadow = num2cell([0.9 0   0   0.5; ...
                       0   0.7 0   0.5; ...
                       0   0   0.9 0.5; ...
                       0.9 0.2 0.9 0.5; ...
                       0.9 0.9 0.2 0.5], 2).';
linest = {'-', '-.', '--', ':', ':'};



% plotting the loss (f^* = 0 for this consistent system A*x = 1)
fig1 = figure(1);
clf(fig1)
set(fig1, 'Units', 'normalized', 'Position', [0.1, 0.5, 0.25, 0.25]);

% recorded data: slot 1 (initial point) is drawn at k = 1, the remaining
% slots at k = steps, 2*steps, ..., final_iterate
i=1;
plot_mean1 = loglog([1,steps:steps:final_iterate],mean(loss_SGD,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName','SGD: $\mathbf{E}[f(X_k)-f^\ast]$');hold on

i=2;
plot_mean2 = loglog([1,steps:steps:final_iterate],mean(loss_regSGD,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName','reg-SGD: $\mathbf{E}[f(X_k)-f^\ast]$');hold on

% reference slopes are asymptotic, so draw them on a monotone grid that
% starts at k = 100*steps
iterate_indexes = 100*steps:steps:final_iterate;

% reg-SGD loss rate k^(-min(q-p,p)); the label is derived from the same exponent
i=3;
rate_loss = min(q-p,p);
rate_plot = loglog(iterate_indexes,0.1*loss_regSGD(1)*iterate_indexes.^(-rate_loss),'linestyle',linest{i},'Color',[0 0 0],'LineWidth',2,'DisplayName',['theoretical $k^{-' strtrim(rats(rate_loss)) '}$']);hold on

% reference k^(-1/2) slope (exponent of the SGD learning rate used in the loop)
i=4;
rate_plot2 = loglog(iterate_indexes,0.001*loss_regSGD(1)*iterate_indexes.^(-1/2),'linestyle',linest{i},'Color',[0 0 0],'LineWidth',2,'DisplayName','theoretical $k^{-1/2}$');hold on

% legend(subset, ...) shows the DisplayName of each listed handle
legend([plot_mean1,plot_mean2,rate_plot,rate_plot2],'Interpreter','latex','FontSize',14,'Location','SouthWest','BackgroundAlpha',0.8)
grid on;

% plotting the squared residuals ||X_k - x_*||^2
fig2 = figure(2);
clf(fig2)
% offset horizontally so this window does not completely cover figure 1
set(fig2, 'Units', 'normalized', 'Position', [0.4, 0.5, 0.25, 0.25]);

% recorded data: slot 1 (initial point) is drawn at k = 1, the remaining
% slots at k = steps, 2*steps, ..., final_iterate
i=1;
mean_plot1 = loglog([1,steps:steps:final_iterate],mean(res_SGD,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName','SGD: $\mathbf{E}[\|X_k-x_\ast\|^2]$');hold on

i=2;
mean_plot2 = loglog([1,steps:steps:final_iterate],mean(res_regSGD,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName','reg-SGD: $\mathbf{E}[\|X_k-x_\ast\|^2]$');hold on

% reference slope on a monotone grid starting at k = 100*steps
iterate_indexes = 100*steps:steps:final_iterate;

% reg-SGD residual rate k^(-2*xi/(4*xi+3)); the label is derived from the same exponent
i=4;
rate_res = (2*xi)/(4*xi+3);
rate_plot = loglog(iterate_indexes,0.1*res_regSGD(1)*iterate_indexes.^(-rate_res),'linestyle',linest{i},'Color',[0 0 0],'LineWidth',2,'DisplayName',['theoretical $k^{-' strtrim(rats(rate_res)) '}$']);hold on

% legend(subset, ...) shows the DisplayName of each listed handle
legend([mean_plot1,mean_plot2,rate_plot],'Interpreter','latex','FontSize',14,'Location','SouthWest','BackgroundAlpha',0.8)
grid on;
