function [sample_vec1,sample_vec2,sample_vec_proj,sample_vec_dbgd,...
   rec_vec1,rec_vec2,rec_vec_proj,rec_vec_dbgd,...
   loss_vec_up1,loss_vec_up2,loss_vec_up_proj,loss_vec_up_dbgd,...
   loss_vec_lo1,loss_vec_lo2,loss_vec_lo_proj,loss_vec_lo_dbgd]=dict_learning_sto(seed,figs)
%DICT_LEARNING_STO Stochastic bilevel dictionary-learning experiment.
%   Generates a synthetic dictionary-learning instance (a random
%   unit-norm dictionary Dstar with an "old" and a "new" sparse-coded
%   dataset), solves the lower-level problem on the old data with
%   CG_init, and then runs four stochastic bilevel solvers (CG_SBO1,
%   CG_SBO2, Alg_projection_sto, DBGD_sto) from the same starting point.
%
%   Inputs:
%     seed - value passed to rng before data generation (default 2020).
%     figs - if true, plot recovery rate and the upper/lower-level
%            objectives against each solver's sample vector
%            (default true).
%
%   Outputs (one per solver; suffixes 1 = CG_SBO1, 2 = CG_SBO2,
%   _proj = Alg_projection_sto, _dbgd = DBGD_sto):
%     sample_vec_*  - x-axis vector returned by the solver (presumably
%                     cumulative number of samples drawn -- confirm
%                     against the solver implementations).
%     rec_vec_*     - recovery-rate history returned by the solver.
%     loss_vec_up_* - upper-level loss history divided by n_new.
%     loss_vec_lo_* - lower-level loss history minus loss_0 (the
%                     lower-level loss at the CG_init solution), divided
%                     by n_old.

% default arguments (previously only available as commented-out lines)
if nargin < 1 || isempty(seed)
    seed = 2020;
end
if nargin < 2 || isempty(figs)
    figs = true;
end

% generate the data
p = 50; % number of atoms
m = 25; % dimension of data
n_old = 250; % number of lower-level samples
n_new = 250; % number of upper-level samples
thres = 0.9; % threshold for recovery rate
sigma = 0.01; % noise level
max_time = 39;  % max running time

rng(seed)

Dstar = randn(m,p); % true dictionary
Dstar = Dstar./vecnorm(Dstar); % normalize the atoms 
k_spar = 5; % number of nonzeros in each coefficient vector

% old dataset: supports drawn only from the first round(4*p/5) atoms,
% nonzero magnitudes in [0.2,1] with random signs
mask = zeros(p,n_old);
for i=1:n_old
    mask(randsample(round(4*p/5),k_spar),i)=(0.8*rand(k_spar,1)+0.2).*(2*randi([0 1],k_spar,1)-1);
end
X_old = mask; % true coefficient matrix for the old dataset
A_old = Dstar*X_old+sigma*randn(m,n_old);

% new dataset: supports drawn from atoms round(3*p/5):p, so it uses atoms
% the old dataset never touches
mask = zeros(p,n_new);
for i=1:n_new
    mask(randsample(round(3*p/5):p,k_spar),i)=(0.8*rand(k_spar,1)+0.2).*(2*randi([0 1],k_spar,1)-1);
end
X_new = mask; % true coefficient matrix for the new dataset
A_new = Dstar*X_new+sigma*randn(m,n_new);


%% solving the lower-level problem initially by CG
epsilon_g = 1e-6;
maxiter = 1e4;
delta = 3;

param.delta = delta;
param.thres = thres;
param.p = round(4*p/5);
param.eps = epsilon_g;
param.eps2 = 1e-10;
param.maxiter = maxiter;

disp('Initialization starts!');
[D_last,X_last,~,~] = CG_init(A_old,Dstar,param);
disp('Initialization done!');

% pad the lower-level solution with zero atoms/coefficients so it has the
% full p atoms used by the upper-level problem
p_res = p-round(4*p/5);
D_last = [D_last,zeros(m,p_res)];
X_last = [X_last;zeros(p_res,n_old)];
loss_0 = norm(A_old-D_last*X_last,'fro')^2/2;

% %% solving the lower-level problem initially by SPIDER
% epsilon_g = 1e-6;
% maxiter = 1e3;
% delta = 3;
% 
% param.delta = delta;
% param.thres = thres;
% param.p = round(4*p/5);
% param.eps = epsilon_g;
% param.eps2 = 1e-10;
% param.maxiter = maxiter;
% 
% disp('Initialization starts!');
% [D_last2,X_last2,g_hist2,rec_init2,sample2] = SPIDER_init(A_old,Dstar,param);
% disp('Initialization done!');
% 
% g_hist2 = g_hist2/n_old;
% p_res = p-round(4*p/5);
% D_last2 = [D_last2,zeros(m,p_res)];
% X_last2 = [X_last2;zeros(p_res,n_old)];
% loss_0 = norm(A_old-D_last2*X_last2,'fro')^2/2;

%% initial for upper level
% random start, each column scaled to have l1-norm equal to delta
X_up_init = randn(p,n_new);
X_up_init = X_up_init./vecnorm(X_up_init,1)*delta;


% %% CG-BiO
% 
% gamma0 = .3;
% maxiter_up = 5e4;
% 
% param.delta = delta;
% param.maxtime = max_time;
% param.gamma0 = gamma0; % initial stepsize
% param.thres = thres; % threshold for recovery
% param.p = p;
% param.maxiter = maxiter_up;
% 
% X_up_init = randn(p,n_new);
% X_up_init = X_up_init./vecnorm(X_up_init,1)*delta;
% [loss_vec_up,loss_vec_lo,rec_vec,time_vec] = CG_BiO(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);
% % normalize the loss
% loss_vec_up = loss_vec_up/n_new;
% loss_vec_lo = (loss_vec_lo-loss_0)/n_old;

%% CG-SBO1
param.maxiter = 4e4;
param.gamma0 = 0.1;
param.maxtime = max_time;
param.delta = delta;
param.thres = thres; % threshold for recovery
param.p = p;

[loss_vec_up1,loss_vec_lo1,rec_vec1,~,sample_vec1] = CG_SBO1(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);
% normalize the loss
loss_vec_up1 = loss_vec_up1/n_new;
loss_vec_lo1 = (loss_vec_lo1-loss_0)/n_old;

%% CG-SBO2
param.maxiter = 2e4;
param.gamma0 = 1e-3;
[loss_vec_up2,loss_vec_lo2,rec_vec2,~,sample_vec2] = CG_SBO2(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);
% normalize the loss
loss_vec_up2 = loss_vec_up2/n_new;
loss_vec_lo2 = (loss_vec_lo2-loss_0)/n_old;



%% The same stepsize but without cutting plane
% X_up_init = zeros(p,n_new);
% [loss_vec_up2,loss_vec_lo2,rec_vec2,time_vec2] = CG_upper(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);
% loss_vec_up2 = loss_vec_up2/n_new;
% loss_vec_lo2 = (loss_vec_lo2-loss_0)/n_old;

% %% Yousefian's algorithm (deterministic)
% disp('a-IRG algorithm starts!')
% [loss_vec_up_proj,loss_vec_lo_proj,rec_vec_proj,time_vec_proj] = Alg_projection(A_old,X_last,A_new,Dstar,...
%     D_last,X_up_init,param);
% disp('a-IRG algorithm done!');
% 
% loss_vec_up_proj = loss_vec_up_proj/n_new;
% loss_vec_lo_proj = (loss_vec_lo_proj-loss_0)/n_old;

%% Yousefian's algorithm (stochastic)
param.maxiter = 2e4;
disp('a-IRG algorithm starts!')
[loss_vec_up_proj,loss_vec_lo_proj,rec_vec_proj,~, sample_vec_proj] = Alg_projection_sto(A_old,X_last,A_new,Dstar,...
    D_last,X_up_init,param);
disp('a-IRG algorithm done!');

loss_vec_up_proj = loss_vec_up_proj/n_new;
loss_vec_lo_proj = (loss_vec_lo_proj-loss_0)/n_old;


%% DBGD-sto
param.alpha = 100;
param.beta = 100;
param.stepsize = 5e-3;
param.delta = delta;
param.thres = thres;
param.maxiter = 4e4;

disp('DBGD algorithm starts!')
[loss_vec_up_dbgd,loss_vec_lo_dbgd,rec_vec_dbgd,~,sample_vec_dbgd] = DBGD_sto(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);
disp('DBGD algorithm done!');

loss_vec_up_dbgd = loss_vec_up_dbgd/n_new;
loss_vec_lo_dbgd = (loss_vec_lo_dbgd-loss_0)/n_old;


% %% BiG-SAM (grid search over eta_up, eta_lo, gamma, then a final run)
% eta_up_list = [1e-2,1e-1,1];
% eta_lo_list = [1e-2,1e-1,1];
% gamma_list = [1e-1,1,10];
% 
% loss_list = zeros(3,3,3);
% 
% for idx_up = 1:3
%     for idx_lo = 1:3
%         for idx_gamma = 1:3
%             param.eta_up = eta_up_list(idx_up);
%             param.eta_lo = eta_lo_list(idx_lo);
%             param.gamma =  gamma_list(idx_gamma);
%             param.maxtime = 10;
%             disp('BiG-SAM algorithm starts!')
%             [loss_vec_up_sam,loss_vec_lo_sam,~,~] = BigSAM(A_old,A_new,X_last,...
%             D_last,X_up_init,Dstar,param);
%             disp('BiG-SAM algorithm done!');
%             loss_list(idx_up,idx_lo,idx_gamma) = loss_vec_up_sam(end)/n_new+(loss_vec_lo_sam(end)-loss_0)/n_old;
%         end
%     end
% end
% %%
% [~,I] = min(loss_list,[],'all','linear');
% [I1,I2,I3] = ind2sub([3,3,3],I);
% param.eta_up = eta_up_list(I1);
% param.eta_lo = eta_lo_list(I2);
% param.gamma =  gamma_list(I3);
% 
% param.maxtime = 39;
% disp('BiG-SAM algorithm starts!')
% [loss_vec_up_sam,loss_vec_lo_sam,rec_vec_sam,time_vec_sam] = BigSAM(A_old,A_new,X_last,...
%     D_last,X_up_init,Dstar,param);
% disp('BiG-SAM algorithm done!');
% 
% loss_vec_up_sam = loss_vec_up_sam/n_new;
% loss_vec_lo_sam = (loss_vec_lo_sam-loss_0)/n_old;

%% Figures
if figs
    % Place markers evenly over the full x-range covered by ANY method,
    % not only the range of CG-SBO2, so every curve gets spread markers.
    maxsample = max([sample_vec1(end), sample_vec2(end), ...
        sample_vec_proj(end), sample_vec_dbgd(end)]);
    N_marker = 10;
    sample_idx = linspace(0,maxsample,N_marker);
    marker_idx1 = nearest_indices(sample_vec1,sample_idx);
    marker_idx2 = nearest_indices(sample_vec2,sample_idx);
    marker_idx_proj = nearest_indices(sample_vec_proj,sample_idx);
    marker_idx_dbgd = nearest_indices(sample_vec_dbgd,sample_idx);

    %% Recovery rate
    setup_figure();
    plot(sample_vec1,rec_vec1,'o-','DisplayName','SBCGI','MarkerIndices', marker_idx1)
    hold on
    plot(sample_vec2,rec_vec2,'^-','DisplayName','SBCGF','MarkerIndices', marker_idx2)
    plot(sample_vec_proj, rec_vec_proj,'d-','DisplayName','aR-IP-SeG','MarkerIndices', marker_idx_proj)
    plot(sample_vec_dbgd,rec_vec_dbgd,'>-','DisplayName','DBGD-Sto','MarkerIndices', marker_idx_dbgd, 'Color',"#77AC30")

    lgd = legend();
    lgd.Position=[0.521319094763862 0.416352230128746 0.37569514380561 0.252044610374479];
    ylabel('Recovery rate')
    % the x-axis is the sample vector, not wall-clock time
    xlabel('Number of samples')
    set(gca,'FontSize',24);
    set(gca,'YLim',[0,1])
    legend('Interpreter','latex')
    grid on;
    pbaspect([1 0.7 1])
    % print('-depsc2','-r600','./figs/recovery_new.eps')

    %% Upper-level objective
    setup_figure();
    semilogy(sample_vec1,loss_vec_up1,'o-','DisplayName','SBCGI','MarkerIndices', marker_idx1)
    hold on
    semilogy(sample_vec2,loss_vec_up2,'^-','DisplayName','SBCGF','MarkerIndices', marker_idx2)
    semilogy(sample_vec_proj,loss_vec_up_proj,'d-','DisplayName','aR-IP-SeG','MarkerIndices', marker_idx_proj)
    semilogy(sample_vec_dbgd,loss_vec_up_dbgd,'>-','DisplayName','DBGD-Sto','MarkerIndices', marker_idx_dbgd, 'Color',"#77AC30")

    legend

    ylabel('$f(\tilde{\mathbf{D}}_k,\tilde{\mathbf{X}}_k)$')
    xlabel('Number of samples')
    set(gca,'FontSize',24);
    set(gca,'YLim',[1e-6,1])
    legend('Interpreter','latex','Location','southwest')
    grid on;
    pbaspect([1 0.7 1])
    % print('-depsc2','-r600','./figs/nonconvex_upper_new.eps')

    %% Lower-level objective
    % NOTE(review): semilogy drops non-positive values, so iterations where
    % the lower-level loss falls below loss_0 will not be drawn.
    setup_figure();
    semilogy(sample_vec1, loss_vec_lo1,'o-','DisplayName','SBCGI','MarkerIndices', marker_idx1)
    hold on
    semilogy(sample_vec2, loss_vec_lo2,'^-','DisplayName','SBCGF','MarkerIndices', marker_idx2)
    semilogy(sample_vec_proj, loss_vec_lo_proj,'d-','DisplayName','aR-IP-SeG','MarkerIndices', marker_idx_proj)
    semilogy(sample_vec_dbgd,loss_vec_lo_dbgd,'>-','DisplayName','DBGD-Sto','MarkerIndices', marker_idx_dbgd, 'Color',"#77AC30")

    legend

    ylabel('$g(\tilde{\mathbf{D}}_k)-g(\tilde{\mathbf{D}}_0)$')
    xlabel('Number of samples')
    set(gca,'FontSize',24);
    set(gca,'YLim',[1e-22,1])
    legend('Interpreter','latex','Location','southwest')
    grid on;
    pbaspect([1 0.7 1])
    % print('-depsc2','-r600','./figs/nonconvex_lower_new.eps')
end


function idx = nearest_indices(vec, targets)
%NEAREST_INDICES For each value in TARGETS, the index of the closest entry of VEC.
%   Returns a numel(TARGETS)-by-1 column of indices (ties resolved by MIN,
%   i.e. the first closest entry).
idx = zeros(numel(targets),1);
for j = 1:numel(targets)
    [~,idx(j)] = min(abs(vec - targets(j)));
end


function setup_figure()
%SETUP_FIGURE Open a new figure with the shared text/line/marker/size settings.
%   Note: also sets the ROOT default text interpreter to LaTeX, which
%   persists for figures created afterwards in the session.
figure;
set(0,'defaulttextinterpreter','latex')
set(gcf,'DefaultLineLinewidth',5)
set(gcf,'DefaultLineMarkerSize',16);
set(gcf,'Position',[331,215,720,538])