% ODE example: plain SGD vs. regularized SGD on a linear inverse problem
clear;
close all;

% ---- discretization -------------------------------------------------
ell = 8;          % mesh refinement level
I   = 2^ell;      % number of mesh cells (linspace below yields I+1 nodes)
n   = I - 1;      % number of unknowns = number of interior nodes
K   = 2^6 - 1;    % number of observation points

% all nodes of [0,1], both boundary points included
xx_b = linspace(0, 1, I + 1);
% interior nodes only: drop the first and the last node
xx = xx_b(2:I);

% Sine basis for drawing random functions via a truncated KL-type expansion:
% column j samples sqrt(2*pi)*sin(j*pi*x), j = 1..N_x.
% NOTE(review): an L2([0,1])-normalized sine basis would use sqrt(2); the
% sqrt(2*pi) factor may be a deliberate scaling -- confirm.
N_x = 100;
freq = (1:N_x)*pi;                           % row vector of j*pi
V_basis   = sqrt(2*pi)*sin(xx'   .* freq);   % interior nodes, (I-1) x N_x
V_basis_b = sqrt(2*pi)*sin(xx_b' .* freq);   % all nodes,      (I+1) x N_x

% diagonal weights D(j,j) = (j^2)^(-2) = j^(-4)
D = diag(((1:N_x).^2).^(-4/2));

%%% generate reference rhs: uncomment for new references
% (draws KL coefficients with standard deviations sqrt(diag(D)))
% m_ref = V_basis*sqrt(D)*randn(N_x,1);
% save('m_ref.mat','m_ref');

% load reference rhs
load('m_ref.mat','m_ref');

% forward model: observation operator composed with the inverse of the
% discretized PDE operator. Right division B/M solves X*M = B, i.e.
% X = B*inv(M) for square nonsingular M, without forming inv(M) explicitly.
% NOTE(review): assumes A_ellipticPDE(I-1) is square, as its original use
% with eye(I-1) suggests -- confirm.
A = observation_matrix(K,I)/A_ellipticPDE(I-1);
% synthetic (noise-free) data generated from the reference rhs
y = A*m_ref;

% least-squares objective f(x) = 1/2*||A*x - y||^2 (one product with A per
% evaluation) and its gradient
f = @(x) 1/2*norm(A*x-y)^2;
gradf = @(x) A'*(A*x-y);

% minimum-norm least-squares solution
xast = lsqminnorm(A,y);
% global minimum value of f
f_min = f(xast);

% initialization at zero (n = I-1 unknowns)
x0 = zeros(n,1);

% total number of iterations per run
final_iterate = 10^4;

% loss and residual are recorded at the initial point and then every
% `steps` iterations, i.e. at k = steps, 2*steps, ..., final_iterate
steps = 100;
assert(mod(final_iterate,steps) == 0, 'final_iterate must be a multiple of steps');
n_track = final_iterate/steps+1;

% all iterates
iterates = 1:final_iterate;

% decay exponents of the reg-SGD schedules used in the loop below
q = 2/3; % learning-rate decay:  step size  100/k^q
p = 1/3; % regularization decay: lambda_k = 0.002/k^p

beta = 2*q-1; % expected rate

% number of independent runs
MC_runs = 10;

% number of observation rows drawn (with replacement) per iteration
batch_size = 2^4;

% per-run loss/residual trajectories (one row per run) and final iterates
loss_SGD    = zeros(MC_runs,n_track);
res_SGD     = zeros(MC_runs,n_track);
loss_regSGD = zeros(MC_runs,n_track);
res_regSGD  = zeros(MC_runs,n_track);

recon_SGD    = zeros(n,MC_runs);
recon_regSGD = zeros(n,MC_runs);

tic;

for m=1:MC_runs

    % per-run trajectories; column 1 holds the value at the initial point
    loss_SGD_paths    = zeros(1,n_track);
    res_SGD_paths     = zeros(1,n_track);
    loss_regSGD_paths = zeros(1,n_track);
    res_regSGD_paths  = zeros(1,n_track);

    % same initialization of SGD and reg-SGD; the loss is stored as the gap
    % f(x) - f_min, matching the f(X_k)-f^* labels used in the plots
    x_SGD = x0;
    loss_SGD_paths(1) = f(x0)-f_min;
    res_SGD_paths(1)  = norm(x0-xast)^2;

    x_regSGD = x0;
    loss_regSGD_paths(1) = f(x0)-f_min;
    res_regSGD_paths(1)  = norm(x0-xast)^2;

    for k=1:final_iterate

        % synthetic gradient noise; the same realization is used for SGD
        % and reg-SGD
        noisygrad = 0.001*randn(n,1);

        % decaying regularization (exponent p)
        lambdak = 0.002/k^p;

        % learning rate for reg-SGD (exponent q)
        step_sizek = 100/k^q;

        % learning rate for SGD without regularization
        step_sizek_literature = 100/k^(1/2);

        % random mini-batch of observation rows, shared by both methods
        rand_index = randi(K,batch_size,1);
        A_index = A(rand_index,:);
        y_index = y(rand_index,1);
        grad_SGD    = A_index'*(A_index*x_SGD-y_index);
        grad_regSGD = A_index'*(A_index*x_regSGD-y_index);

        % SGD step; reg-SGD step with the extra regularization term lambda_k*x
        x_SGD = x_SGD-step_sizek_literature*(grad_SGD+noisygrad);
        x_regSGD = x_regSGD-step_sizek*(grad_regSGD+noisygrad)-step_sizek*lambdak*x_regSGD;

        % record loss gap and squared distance to the minimum-norm solution
        if mod(k,steps) == 0
            idx = k/steps+1;
            loss_SGD_paths(idx)    = f(x_SGD)-f_min;
            res_SGD_paths(idx)     = norm(x_SGD-xast)^2;
            loss_regSGD_paths(idx) = f(x_regSGD)-f_min;
            res_regSGD_paths(idx)  = norm(x_regSGD-xast)^2;
        end
    end
    loss_SGD(m,:)    = loss_SGD_paths;
    res_SGD(m,:)     = res_SGD_paths;
    loss_regSGD(m,:) = loss_regSGD_paths;
    res_regSGD(m,:)  = res_regSGD_paths;
    recon_SGD(:,m)    = x_SGD;
    recon_regSGD(:,m) = x_regSGD;
end

% take running time
runningtime = toc;

% uncomment to save the whole workspace (all results) to disk
% save('results_ODE.mat')


% Per-curve styling, indexed by the same integer i in every figure.
% col: RGB triples for the bold (mean) curves.
col = { ...
    [0.7 0   0  ], ...
    [0   0.5 0  ], ...
    [0   0   0.7], ...
    [0.7 0.2 0.7], ...
    [0.7 0.7 0.2]};
% col_shadow: lighter RGB plus a 4th entry 0.5 (presumably line transparency,
% an undocumented Line 'Color' feature) for individual-run curves.
col_shadow = { ...
    [0.9 0   0   0.5], ...
    [0   0.7 0   0.5], ...
    [0   0   0.9 0.5], ...
    [0.9 0.2 0.9 0.5], ...
    [0.9 0.9 0.2 0.5]};
% line styles
linest = {'-', '-.', '--', ':', ':'};


% plotting the loss: per-run curves (translucent) and their mean over runs
% (bold) for SGD and reg-SGD, plus reference slopes
fig1 = figure(1);
clf(fig1)
set(fig1, 'Units', 'normalized', 'Position', [0.1, 0.5, 0.25, 0.25]);

% iterations at which the loss was recorded; the first column is the value
% at the initial point and is drawn at k = 1 (k = 0 cannot appear on a log axis)
k_track = [1,steps:steps:final_iterate];

i=1;
loglog(k_track,loss_SGD,'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',1,'DisplayName','SGD: $f(X_k)-f^\ast$');hold on
plot_mean1 = loglog(k_track,mean(loss_SGD,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName','SGD: $f(X_k)-f^\ast$');hold on

i=2;
loglog(k_track,loss_regSGD,'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',1,'DisplayName','reg-SGD: $f(X_k)-f^\ast$');hold on
plot_mean2 = loglog(k_track,mean(loss_regSGD,1),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName','reg-SGD: $f(X_k)-f^\ast$');hold on

% reference slopes drawn over the recorded range steps..final_iterate
% (ascending and not past final_iterate), anchored at run 1's initial loss
iterate_indexes = steps:steps:final_iterate;

i=3;
% reg-SGD reference rate k^(-min(beta,p)); the label is derived from the
% exponent so it stays in sync if q or p change
rate_exp = min(beta,p);
rate_label = ['$k^{-' strtrim(rats(rate_exp)) '}$'];
rate_plot = loglog(iterate_indexes,loss_regSGD(1,1)*iterate_indexes.^(-rate_exp),'linestyle',linest{i},'Color',[0 0 0],'LineWidth',2,'DisplayName',rate_label);hold on

i=4;
% SGD reference rate k^(-1/2), scaled by 0.01
rate_plot2 = loglog(iterate_indexes,0.01*loss_SGD(1,1)*iterate_indexes.^(-1/2),'linestyle',linest{i},'Color',[0 0 0],'LineWidth',2,'DisplayName','$k^{-1/2}$');hold on

legend('show',[plot_mean1,plot_mean2,rate_plot,rate_plot2],'Interpreter','latex','FontSize',14,'location','SouthWest','BackgroundAlpha',0.8)
grid on;

% Figure 2: squared distance to the minimum-norm solution xast,
% translucent per-run curves plus bold curves averaged over runs
fig2 = figure(2);
clf(fig2)
set(fig2, 'Units', 'normalized', 'Position', [0.1, 0.5, 0.25, 0.25]);

% recorded iterations (the initial value is drawn at k = 1)
k_rec = [1, steps:steps:final_iterate];

% --- SGD ---
i = 1;
loglog(k_rec, res_SGD, ...
    'linestyle', linest{i}, 'Color', col_shadow{i}, 'LineWidth', 1);
hold on
mean_plot1 = loglog(k_rec, mean(res_SGD, 1), ...
    'linestyle', linest{i}, 'Color', col{i}, 'LineWidth', 2, ...
    'DisplayName', 'SGD: $\|X_k-x_\ast\|^2$');

% --- reg-SGD ---
i = 2;
loglog(k_rec, res_regSGD, ...
    'linestyle', linest{i}, 'Color', col_shadow{i}, 'LineWidth', 1);
mean_plot2 = loglog(k_rec, mean(res_regSGD, 1), ...
    'linestyle', linest{i}, 'Color', col{i}, 'LineWidth', 2, ...
    'DisplayName', 'reg-SGD: $\|X_k-x_\ast\|^2$');

% legend lists only the two mean curves
legend('show', [mean_plot1, mean_plot2], 'Interpreter', 'latex', ...
    'FontSize', 14, 'location', 'SouthWest', 'BackgroundAlpha', 0.8)
grid on;

% Figure 3: right-hand sides reconstructed in the first run, compared with
% the reference m_ref and the minimum-norm solution xast
fig3 = figure(3);
clf(fig3)
set(fig3, 'Units', 'normalized', 'Position', [0.1, 0.5, 0.25, 0.25]);

i = 1;
plot(xx, recon_SGD(:,1), 'linestyle', linest{i}, 'Color', col_shadow{i}, ...
    'LineWidth', 2, 'DisplayName', 'SGD');
hold on

i = 2;
plot(xx, recon_regSGD(:,1), 'linestyle', linest{i}, 'Color', col{i}, ...
    'LineWidth', 2, 'DisplayName', 'reg-SGD');

% reference rhs in black, dotted
plot(xx, m_ref, 'linestyle', ':', 'Color', [0 0 0], ...
    'LineWidth', 2, 'DisplayName', '$x^\dagger$');

i = 3;
plot(xx, xast, 'linestyle', linest{i}, 'Color', col{i}, ...
    'LineWidth', 2, 'DisplayName', '$x_\ast$');

xlabel('s')
legend('show', 'Interpreter', 'latex', 'FontSize', 14, ...
    'location', 'NorthWest', 'BackgroundAlpha', 0.8)

% plotting the ode estimation: the inverse of the discretized operator
% A_ellipticPDE(I-1) applied to each right-hand side (first-run SGD and
% reg-SGD reconstructions, reference m_ref, minimum-norm solution xast)
fig4 = figure(4);
clf(fig4)
set(fig4, 'Units', 'normalized', 'Position', [0.1, 0.5, 0.25, 0.25]);

% Assemble the operator once and solve for all four right-hand sides with a
% single backslash, instead of forming its explicit inverse for every curve.
% NOTE(review): assumes A_ellipticPDE(I-1) is square and nonsingular, as its
% use with eye(I-1) suggests -- confirm.
A_pde = A_ellipticPDE(I-1);
p_states = A_pde\[recon_SGD(:,1), recon_regSGD(:,1), m_ref, xast];

i=1;
plot(xx,p_states(:,1),'linestyle',linest{i},'Color',col_shadow{i},'LineWidth',2,'DisplayName','SGD');hold on
i=2;
plot(xx,p_states(:,2),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName','reg-SGD');hold on
i=3;
plot(xx,p_states(:,3),'linestyle','-','Color',[0 0 0],'LineWidth',2,'DisplayName','$p_{ref}$');hold on
plot(xx,p_states(:,4),'linestyle',linest{i},'Color',col{i},'LineWidth',2,'DisplayName','$p_\ast$');hold on

xlabel('s')
legend('show','Interpreter','latex','FontSize',14,'location','SouthWest','BackgroundAlpha',0.8)
