% Radon transform

% start from an empty workspace and without open figures
clear;
close all;

% load the data set 'mri' (it is expected to provide the array D) and use
% slice 10 of the squeezed array, converted to double, as reference image
load('mri')
D = squeeze(D);
P = double(D(:,:,10));

% side length of the image, taken from its first dimension
n = size(P,1);

% the image is viewed on [-1,1]^2:
% x holds the n+1 equally spaced points -1, -1+2/n, ..., 1
x = -1 + (2/n)*(0:n);

% 2d grids, flattened column by column into row vectors:
% (xd,yd) are built from the first n points of x,
% (xu,yu) from the last n points of x
% NOTE(review): presumably lower and upper pixel corners as expected by
% radon_forward -- confirm against that function
[Y,X] = meshgrid(x(1:n));
xd = reshape(X,1,[]);
yd = reshape(Y,1,[]);
[Y,X] = meshgrid(x(2:end));
xu = reshape(X,1,[]);
yu = reshape(Y,1,[]);

% number of all possible illumination angles
M = 32;
% number of parallel rays per illumination angle
K = 100;
% M equally spaced illumination angles 0, pi/M, ..., (M-1)*pi/M
angles_all = (pi/M) * (0:M-1);

% forward model based on the Radon transform, assembled by the project
% function radon_forward
A = radon_forward(K,n,xd,yd,xu,yu,angles_all);

% noise-free data: forward model applied to the vectorized image
y = A*P(:);

% data misfit of a candidate x
residual = @(x) A*x-y;

% least-squares loss f(x) = 1/2*||A*x-y||^2
f = @(x) 1/2*residual(x)'*residual(x);

% exact (full) gradient of the loss
gradf = @(x) A'*residual(x);

% minimum-norm least-squares solution, used as reference point
xast = lsqminnorm(A,y);

% SGD and reg-SGD are all started at the zero vector of length n^2
x0 = zeros(n^2,1);

% total number of iterations per run
final_iterate = 10^3;

% loss and residual are recorded every 'steps' iterations
steps = 100;

% decay exponents
q = 2/3; % learning rate decay
p = 1/3; % regularization decay

% expected rate
beta = 2*q-1;

% number of independent runs
MC_runs = 10;

% number of recorded values per run: the initial point plus one value
% every 'steps' iterations
n_tracked = final_iterate/steps+1;

% storage for loss and residual, one row per run
loss_SGD = zeros(MC_runs,n_tracked);
res_SGD = zeros(MC_runs,n_tracked);
loss_regSGD = zeros(MC_runs,n_tracked);
res_regSGD = zeros(MC_runs,n_tracked);
loss_regSGD2 = zeros(MC_runs,n_tracked);
res_regSGD2 = zeros(MC_runs,n_tracked);

% storage for the final iterate of every run, one row per run
recon_SGD = zeros(MC_runs,n^2);
recon_regSGD = zeros(MC_runs,n^2);
recon_regSGD2 = zeros(MC_runs,n^2);

% Monte Carlo comparison of three stochastic gradient iterations started at
% x0: plain SGD (step size 20/k^(1/2)) and two regularized variants
% (step size 20/k^q with regularization 0.01/k^p and 0.01/k^(2/3)).
% In every iteration one illumination angle is drawn uniformly at random,
% the gradient of the corresponding block of K rays is evaluated, and the
% same Gaussian perturbation is added to the gradient of all three methods.
% Every 'steps' iterations the loss f and the squared distance to xast are
% recorded; the final iterates of each run are stored in recon_*.
tic;

% loss and squared distance at the common starting point; these do not
% depend on the run or on the method, so they are evaluated only once
loss_init = f(x0);
res_init = norm(x0-xast)^2;

for m=1:MC_runs

    % initialization of loss and residual
    loss_SGD_paths = zeros(1,final_iterate/steps+1);
    res_SGD_paths = zeros(1,final_iterate/steps+1);
    loss_regSGD_paths = zeros(1,final_iterate/steps+1);
    res_regSGD_paths = zeros(1,final_iterate/steps+1);
    loss_regSGD_paths2 = zeros(1,final_iterate/steps+1);
    res_regSGD_paths2 = zeros(1,final_iterate/steps+1);
    
    % same initialization of SGD and reg-SGD
    x_SGD = x0;
    loss_SGD_paths(1,1) = loss_init;
    res_SGD_paths(1,1) = res_init;
    
    x_regSGD = x0;
    loss_regSGD_paths(1,1) = loss_init;
    res_regSGD_paths(1,1) = res_init;

    x_regSGD2 = x0;
    loss_regSGD_paths2(1,1) = loss_init;
    res_regSGD_paths2(1,1) = res_init;
    
    for k=1:final_iterate
        
        % generate synthetic noise for gradient evaluation, same realization
        % for SGD and reg-SGD
        noisygrad = 0.5*randn(n^2,1);

        % decaying regularization
        lambdak = 0.01/k^p;
        lambdak2 = 0.01/k^(2/3);
        
        % learning rate for reg-SGD
        step_sizek = 20/k^q;
        step_sizek2 = 20/k^(q);
        
        % learning rate for SGD without reg
        step_sizek_literature = 20/k^(1/2);    
    
        
        % random sample of angle; the rows of A and y belonging to one
        % angle form a consecutive block of K rows (one row per ray), so
        % the block is addressed through K instead of a literal 100
        rand_index = randi(M);
        rows = (rand_index-1)*K + (1:K);
        A_index = A(rows,:);
        y_index = y(rows,1);
        grad_SGD = A_index'*(A_index*x_SGD-y_index);
        grad_regSGD = A_index'*(A_index*x_regSGD-y_index);
        grad_regSGD2 = A_index'*(A_index*x_regSGD2-y_index);

        % iterate SGD and reg-SGD
        x_SGD = x_SGD-step_sizek_literature*(grad_SGD+noisygrad);
        x_regSGD = x_regSGD-step_sizek*(grad_regSGD+noisygrad)-step_sizek*lambdak*x_regSGD;
        x_regSGD2 = x_regSGD2-step_sizek2*(grad_regSGD2+noisygrad)-step_sizek2*lambdak2*x_regSGD2;

        % track loss and residual every 'steps' iterations; column 1 holds
        % the starting point, column k/steps+1 the state after k iterations
        if mod(k,steps) == 0
            track_idx = k/steps+1;
            loss_SGD_paths(1,track_idx) = f(x_SGD);
            res_SGD_paths(1,track_idx) = norm(x_SGD-xast)^2;
            loss_regSGD_paths(1,track_idx) = f(x_regSGD);
            res_regSGD_paths(1,track_idx) = norm(x_regSGD-xast)^2;
            loss_regSGD_paths2(1,track_idx) = f(x_regSGD2);
            res_regSGD_paths2(1,track_idx) = norm(x_regSGD2-xast)^2;
        end
    end

    % store the recorded paths and the final iterates of this run
    loss_SGD(m,:) = loss_SGD_paths;
    res_SGD(m,:) = res_SGD_paths;
    loss_regSGD(m,:) = loss_regSGD_paths;
    res_regSGD(m,:) = res_regSGD_paths;
    loss_regSGD2(m,:) = loss_regSGD_paths2;
    res_regSGD2(m,:) = res_regSGD_paths2;
    recon_SGD(m,:) = x_SGD;
    recon_regSGD(m,:) = x_regSGD;
    recon_regSGD2(m,:) = x_regSGD2;
end

% take running time
runningtime = toc; 

% uncomment if you want to save results
% save('results_radon.mat')


% colors and line styles shared by the plots:
% col        - RGB colors
% col_shadow - four-element colors (lighter RGB plus a fourth entry 0.5)
% linest     - line styles
col = {[0.7 0 0],[0 0.5 0],[0 0 0.7],[0.7 0.2 0.7],[0.7 0.7 0.2]};
col_shadow = {[0.9 0 0 0.5],[0 0.7 0 0.5],[0 0 0.9 0.5],[0.9 0.2 0.9 0.5],[0.9 0.9 0.2 0.5]};
linest = {'-','-.','--',':',':'};

% placement and title formatting common to all image figures
image_fig_position = [0.1, 0.5, 0.3, 0.3];
image_title_opts = {'FontSize',18,'FontWeight','normal'};

% colormap used for all image figures
% NOTE(review): magma is not defined in this script and must be available
% on the path
mycolormap = magma();

% figure 1: ground truth
fig1 = figure(1);
clf(fig1)
set(fig1, 'Units', 'normalized', 'Position', image_fig_position);
imagesc(P)
colormap(mycolormap);
title('ground truth',image_title_opts{:})

% figure 2: noise-free data arranged as sinogram, one column per angle
fig2 = figure(2);
clf(fig2)
set(fig2, 'Units', 'normalized', 'Position', image_fig_position);
sino = reshape(y,K,M);
imagesc(sino)
colormap(mycolormap);
title('observation (noisefree)',image_title_opts{:})

% figure 3: minimum-norm solution
fig3 = figure(3);
clf(fig3)
set(fig3, 'Units', 'normalized', 'Position', image_fig_position);
imagesc(reshape(xast,n,n))
colormap(mycolormap);
title('min-norm solution',image_title_opts{:})

% figure 4: final SGD iterate of the first run
fig4 = figure(4);
clf(fig4)
set(fig4, 'Units', 'normalized', 'Position', image_fig_position);
imagesc(reshape(recon_SGD(1,:),n,n))
colormap(mycolormap);
title('reconstruction SGD',image_title_opts{:})

% figure 5: final reg-SGD (ours) iterate of the first run
fig5 = figure(5);
clf(fig5)
set(fig5, 'Units', 'normalized', 'Position', image_fig_position);
imagesc(reshape(recon_regSGD(1,:),n,n))
colormap(mycolormap);
title('reconstruction reg-SGD (ours)',image_title_opts{:})

% figure 6: final reg-SGD (aggressive) iterate of the first run
fig6 = figure(6);
clf(fig6)
set(fig6, 'Units', 'normalized', 'Position', image_fig_position);
imagesc(reshape(recon_regSGD2(1,:),n,n))
colormap(mycolormap);
title('reconstruction reg-SGD (aggressive)',image_title_opts{:})


% plotting loss
% Figure 7 shows, on log-log axes, the recorded loss of every run (colors
% from col_shadow) and the mean over all runs (colors from col) for SGD
% and both reg-SGD variants, together with two reference rate lines.
fig7 = figure(7);
clf(fig7)
set(fig7, 'Units', 'normalized', 'Position', [0.1, 0.5, 0.25, 0.25]);

% abscissae of the recorded values: the starting point is drawn at 1,
% followed by every 'steps'-th iteration
tracked_iterations = [1,steps:steps:final_iterate];

i=1;
loglog(tracked_iterations,loss_SGD,'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',2);hold on
plot_mean1 = loglog(tracked_iterations,mean(loss_SGD,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName','SGD');hold on


i=2;
loglog(tracked_iterations,loss_regSGD,'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',2);hold on
plot_mean2 = loglog(tracked_iterations,mean(loss_regSGD,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName','reg-SGD (ours)');hold on

i=3;
loglog(tracked_iterations,loss_regSGD2,'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',2);hold on
mean_plot3 = loglog(tracked_iterations,mean(loss_regSGD2,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName','reg-SGD (aggressive)');hold on

% abscissae of the reference rate lines, sorted in increasing order so
% that each line is drawn once from left to right; listing the largest
% point first made the polyline run to the left end and then double back
% over itself
% NOTE(review): the extra point 1+100*steps extends the reference lines
% beyond final_iterate whenever 100*steps > final_iterate -- confirm that
% this extent is intended
iterate_indexes = sort([1+100*steps,steps:steps:final_iterate]);

% reference rate with exponent min(beta,p); the legend text is fixed to
% k^{-1/3} and matches only as long as min(beta,p) equals 1/3
rate_plot = loglog(iterate_indexes,0.1*loss_regSGD(1)*iterate_indexes.^(-min(beta,p)),'linestyle',linest{2},'Color',[0 0 0],'LineWidth',2,'DisplayName','theoretical $k^{-1/3}$');hold on

i=4;
rate_plot2 = loglog(iterate_indexes,0.01*loss_SGD(1)*iterate_indexes.^(-1/2),'linestyle',linest{i},'Color',[0 0 0],'LineWidth',2,'DisplayName','theoretical $k^{-1/2}$');hold on

title('$f(X_k)-f(x_\ast)$','Interpreter','latex','FontSize',16)
legend('show',[plot_mean1,plot_mean2,mean_plot3,rate_plot,rate_plot2],'Interpreter','latex','FontSize',14,'location','SouthWest','BackgroundAlpha',0.8)
grid on;

ylim([1,10^8])
xlim([1,10^7])



% plotting residuals
% Figure 8 shows, on log-log axes, the recorded squared distance to xast
% of every run (colors from col_shadow) and the mean over all runs (colors
% from col) for SGD and both reg-SGD variants.
fig8 = figure(8);
clf(fig8)
set(fig8, 'Units', 'normalized', 'Position', [0.1, 0.5, 0.25, 0.25]);

% abscissae of the recorded values: the starting point is drawn at 1,
% followed by every 'steps'-th iteration
tracked_iterations = [1,steps:steps:final_iterate];

% SGD
i=1;
loglog(tracked_iterations,res_SGD,'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',2);
hold on
mean_plot1 = loglog(tracked_iterations,mean(res_SGD,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName','SGD');
hold on

% reg-SGD (ours)
i=2;
loglog(tracked_iterations,res_regSGD,'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',2);
hold on
mean_plot2 = loglog(tracked_iterations,mean(res_regSGD,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName','reg-SGD (ours)');
hold on

% reg-SGD (aggressive)
i=3;
loglog(tracked_iterations,res_regSGD2,'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',2);
hold on
mean_plot3 = loglog(tracked_iterations,mean(res_regSGD2,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName','reg-SGD (aggressive)');
hold on

% only the mean curves appear in the legend
residual_legend_handles = [mean_plot1,mean_plot2,mean_plot3];

title('$\|X_k-x_\ast\|^2$','Interpreter','latex','FontSize',16)
legend('show',residual_legend_handles,'Interpreter','latex','FontSize',14,'location','SouthWest','BackgroundAlpha',0.8)
grid on;
xlim([1,10^7])
