% Code for the toy example: comparison of different choices for p and q

% start from an empty workspace without open figures
close all;
clear;

% problem dimension
n = 2;

% quadratic loss f(x) = 1/2*(A*x-1)'*(A*x-1) built from a single row A
A = [1 1];
Q = A*A';
f = @(x) 1/2*(A*x-1)'*(A*x-1);

% gradient of f (no noise)
gradf = @(x) A'*(A*x-1);

% minimum-norm solution of A*x = 1; residuals are measured against it
xast = lsqminnorm(A,1);

% total number of iterations per run
final_iterate = 1e5;

% loss and residual are recorded once every `steps` iterations
steps = 10;

% number of independent Monte-Carlo runs
MC_runs = 10;

% exponents of the four compared configurations:
%   q* is the decay exponent of the learning rate,
%   p* is the decay exponent of the regularization parameter

xi = 1; % viscosity rate

% selection 1
q = 2/3;
p = 1/(6*xi+3);

% selection 2 (p2 = 0, i.e. the regularization does not decay)
q2 = 2/3;
p2 = 0;

% selection 3
q3 = 1/2;
p3 = 2/3;

% selection 4
q4 = 2/7;
p4 = 1/9;

% storage for loss and residual: one row per run, one column per tracked
% iterate (the extra first column holds the initial state)
[loss_regSGD, res_regSGD, ...
 loss_regSGD2, res_regSGD2, ...
 loss_regSGD3, res_regSGD3, ...
 loss_regSGD4, res_regSGD4] = deal(zeros(MC_runs,final_iterate/steps+1));

% Monte-Carlo simulation: MC_runs independent runs of the four regularized
% SGD variants. Within a run all variants share the initial point and the
% gradient-noise realization, so they differ only through the exponents
% (p,q), (p2,q2), (p3,q3), (p4,q4).
tic;
for m=1:MC_runs
    
    % common random initialization of all four variants
    x0 = randn(n,1);
    x_regSGD = x0;
    x_regSGD2 = x0;
    x_regSGD3 = x0;
    x_regSGD4 = x0;

    % trajectories of the current run: entry 1 is the initial state,
    % entry j+1 is the state after j*steps iterations
    loss_regSGD_path = zeros(1,final_iterate/steps+1);
    res_regSGD_path = zeros(1,final_iterate/steps+1);

    loss_regSGD_path2 = zeros(1,final_iterate/steps+1);
    res_regSGD_path2 = zeros(1,final_iterate/steps+1);

    loss_regSGD_path3 = zeros(1,final_iterate/steps+1);
    res_regSGD_path3 = zeros(1,final_iterate/steps+1);

    loss_regSGD_path4 = zeros(1,final_iterate/steps+1);
    res_regSGD_path4 = zeros(1,final_iterate/steps+1);

    % initial loss and initial SQUARED residual; the square keeps the first
    % entry on the same scale as the entries recorded inside the loop
    loss_regSGD_path(1) = f(x0);
    res_regSGD_path(1) = norm(x0-xast)^2;

    loss_regSGD_path2(1) = f(x0);
    res_regSGD_path2(1) = norm(x0-xast)^2;

    loss_regSGD_path3(1) = f(x0);
    res_regSGD_path3(1) = norm(x0-xast)^2;

    loss_regSGD_path4(1) = f(x0);
    res_regSGD_path4(1) = norm(x0-xast)^2;

    for k=1:final_iterate
        % synthetic noise added to the exact gradient; the same realization
        % is used by all four variants (the factor 1 is the noise level)
        noisygrad = 1*randn(n,1);
        
        % decaying regularization parameter k^(-p)
        lambdak = 1/k^p;
        lambdak2 = 1/k^p2;
        lambdak3 = 1/k^p3;
        lambdak4 = 1/k^p4;
        
        % decaying learning rate 0.2*q*k^(-q)
        step_sizek = q*0.2/k^q;
        step_sizek2 = q2*0.2/k^q2;
        step_sizek3 = q3*0.2/k^q3;
        step_sizek4 = q4*0.2/k^q4;
   
        % reg-SGD update: noisy gradient step plus the regularization
        % term lambda_k*x
        x_regSGD = x_regSGD-step_sizek*(gradf(x_regSGD)+noisygrad)-step_sizek*lambdak*x_regSGD;
        x_regSGD2 = x_regSGD2-step_sizek2*(gradf(x_regSGD2)+noisygrad)-step_sizek2*lambdak2*x_regSGD2;
        x_regSGD3 = x_regSGD3-step_sizek3*(gradf(x_regSGD3)+noisygrad)-step_sizek3*lambdak3*x_regSGD3;
        x_regSGD4 = x_regSGD4-step_sizek4*(gradf(x_regSGD4)+noisygrad)-step_sizek4*lambdak4*x_regSGD4;

        % record loss and squared residual every `steps` iterations
        if mod(k,steps) == 0
            idx = k/steps+1; % entry 1 is reserved for the initial state
            
            loss_regSGD_path(idx) = f(x_regSGD);
            res_regSGD_path(idx) = norm(x_regSGD-xast)^2;
            loss_regSGD_path2(idx) = f(x_regSGD2);
            res_regSGD_path2(idx) = norm(x_regSGD2-xast)^2;
            loss_regSGD_path3(idx) = f(x_regSGD3);
            res_regSGD_path3(idx) = norm(x_regSGD3-xast)^2;
            loss_regSGD_path4(idx) = f(x_regSGD4);
            res_regSGD_path4(idx) = norm(x_regSGD4-xast)^2;
        end
    end

    % store the trajectories of run m
    loss_regSGD(m,:) = loss_regSGD_path;
    res_regSGD(m,:) = res_regSGD_path;

    loss_regSGD2(m,:) = loss_regSGD_path2;
    res_regSGD2(m,:) = res_regSGD_path2;

    loss_regSGD3(m,:) = loss_regSGD_path3;
    res_regSGD3(m,:) = res_regSGD_path3;

    loss_regSGD4(m,:) = loss_regSGD_path4;
    res_regSGD4(m,:) = res_regSGD_path4;
end
% elapsed wall-clock time of the whole simulation (seconds)
runningtime = toc; 

% uncomment if you want to save results
%save('results_comparison.mat')

% colors and line styles for the plots: index i selects a matching triple
% (col{i}, col_shadow{i}, linest{i}). col is used for the mean curves and
% col_shadow for the individual runs; the fourth entry of col_shadow
% presumably acts as transparency (RGBA) -- NOTE(review): confirm on the
% MATLAB release in use.
col = {[0.7 0 0],[0 0.5 0],[0 0 0.7],[0.7 0.5 0.05]};
col_shadow = {[0.9 0 0 0.2],[0 0.7 0 0.2],[0 0 0.9 0.2],[0.9290 0.6940 0.1250 0.2]};
linest = {'-','-.','--',':',':'};

% plotting the loss: all individual runs (shadow) and their mean over runs
fig1 = figure(1);
clf(fig1)
set(fig1, 'Units', 'normalized', 'Position', [0.1, 0.5, 0.3, 0.3]);

% iteration numbers of the tracked values; the initial state is drawn at 1
% because a log axis cannot show 0
tracked_iterates = [1,steps:steps:final_iterate];

% selection 4
i=4;
loglog(tracked_iterates,loss_regSGD4,'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',2,'DisplayName',sprintf('(p,q) = (%.2f,%.2f)', p4,q4));
hold on
plot_mean4 = loglog(tracked_iterates,mean(loss_regSGD4,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName',sprintf('(p,q) = (%.2f,%.2f)', p4,q4));

% selection 1
i=2;
loglog(tracked_iterates,loss_regSGD,'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',2,'DisplayName',sprintf('(p,q) = (%.2f,%.2f)', p,q));
plot_mean1 = loglog(tracked_iterates,mean(loss_regSGD,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName',sprintf('(p,q) = (%.2f,%.2f)', p,q));

% selection 2
i=1;
loglog(tracked_iterates,loss_regSGD2,'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',2,'DisplayName',sprintf('(p,q) = (%.2f,%.2f)', p2,q2));
plot_mean2 = loglog(tracked_iterates,mean(loss_regSGD2,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName',sprintf('(p,q) = (%.2f,%.2f)', p2,q2));

% selection 3
i=3;
loglog(tracked_iterates,loss_regSGD3,'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',2,'DisplayName',sprintf('(p,q) = (%.2f,%.2f)', p3,q3));
plot_mean3 = loglog(tracked_iterates,mean(loss_regSGD3,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName',sprintf('(p,q) = (%.2f,%.2f)', p3,q3));

title('$f(X_k)-f^\ast$','Interpreter','latex','FontSize',16)
% legend restricted to the mean curves; labels are taken from DisplayName
legend([plot_mean1,plot_mean2,plot_mean3,plot_mean4],'Interpreter','latex','FontSize',14,'location','SouthWest','BackgroundAlpha',0.5)
grid on;


% plotting the squared residuals ||X_k - xast||^2: all individual runs
% (shadow) and their mean over runs
fig2 = figure(2);
clf(fig2)
set(fig2, 'Units', 'normalized', 'Position', [0.1, 0.5, 0.3, 0.3]);

% iteration numbers of the tracked values; the initial state is drawn at 1
% because a log axis cannot show 0
tracked_iterates = [1,steps:steps:final_iterate];

% selection 4
i=4;
loglog(tracked_iterates,res_regSGD4,'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',2);
hold on
mean_plot4 = loglog(tracked_iterates,mean(res_regSGD4,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName',sprintf('(p,q) = (%.2f,%.2f)', p4,q4));

% selection 1
i=2;
loglog(tracked_iterates,res_regSGD,'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',2);
mean_plot1 = loglog(tracked_iterates,mean(res_regSGD,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName',sprintf('(p,q) = (%.2f,%.2f)', p,q));

% selection 2
i=1;
loglog(tracked_iterates,res_regSGD2,'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',2);
mean_plot2 = loglog(tracked_iterates,mean(res_regSGD2,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName',sprintf('(p,q) = (%.2f,%.2f)', p2,q2));

% selection 3
i=3;
loglog(tracked_iterates,res_regSGD3,'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',2);
mean_plot3 = loglog(tracked_iterates,mean(res_regSGD3,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName',sprintf('(p,q) = (%.2f,%.2f)', p3,q3));

title('$\|X_k-x_\ast\|^2$','Interpreter','latex','FontSize',16)
% legend built from the mean curves of THIS figure; labels are taken from
% DisplayName
legend([mean_plot1,mean_plot2,mean_plot3,mean_plot4],'Interpreter','latex','FontSize',14,'location','SouthWest','BackgroundAlpha',0.5)
grid on;
