clear;
%% generate the data
% Synthetic dictionary-learning setup. A random dictionary Dstar with
% unit 2-norm columns is drawn, then two noisy datasets are generated from
% it with k_spar-sparse coefficient columns:
%   * the "old" data only uses atoms 1..round(4*p/5);
%   * the "new" data only uses atoms round(3*p/5)..p,
% so the two supports overlap on atoms round(3*p/5)..round(4*p/5).
% All randomness comes from the generator seeded below, so the order of the
% rand/randn/randi/randsample calls fixes the generated data for a given seed.
p = 50; % number of atoms
m = 25; % dimension of data
n_old = 250; % number of lower-level samples
n_new = 250; % number of upper-level samples
thres = 0.9; % threshold for recovery rate
sigma = 0.01; % noise level
max_time = 100;  % max running time (assigned again, to the same value, before the upper-level runs)

seed = 2020;
rng(seed) % fix the random stream so the experiment is reproducible
figs = true; % when true, the comparison figures at the end of the script are drawn

Dstar = randn(m,p); % true dictionary
Dstar = Dstar./vecnorm(Dstar); % normalize the atoms (each column divided by its 2-norm)
k_spar = 5; % number of nonzeros in each coefficient vector
mask = zeros(p,n_old);
for i=1:n_old
    % k_spar distinct rows among the first round(4*p/5) atoms; each nonzero has
    % magnitude 0.8*rand+0.2 (between 0.2 and 1) and a random sign (+1 or -1).
    mask(randsample(round(4*p/5),k_spar),i)=(0.8*rand(k_spar,1)+0.2).*(2*randi([0 1],k_spar,1)-1);
end
X_old = mask; % true coefficient matrix for the old dataset
A_old = Dstar*X_old+sigma*randn(m,n_old); % old observations: Dstar*X_old plus Gaussian noise scaled by sigma

mask = zeros(p,n_new); % mask is reused for the new dataset
for i=1:n_new
    % same construction, but the k_spar rows are drawn from atoms round(3*p/5)..p
    mask(randsample(round(3*p/5):p,k_spar),i)=(0.8*rand(k_spar,1)+0.2).*(2*randi([0 1],k_spar,1)-1);
end
X_new = mask; % true coefficient matrix for the new dataset
A_new = Dstar*X_new+sigma*randn(m,n_new); % new observations: Dstar*X_new plus Gaussian noise scaled by sigma


%% solving the lower-level problem initially by CG
% Runs CG_init on the old dataset with a dictionary restricted to
% round(4*p/5) atoms, then zero-pads the returned dictionary/coefficients
% to the full p atoms and records the resulting lower-level loss loss_0,
% which later sections subtract from their lower-level losses.
delta     = 3;
maxiter   = 10000;
epsilon_g = 1e-6;

% Solver options (the meaning of each field is defined by CG_init).
param.delta   = delta;
param.thres   = thres;
param.p       = round(4*p/5);
param.eps     = epsilon_g;
param.eps2    = 1e-10;
param.maxiter = maxiter;

disp('Initialization starts!');
[D_last, X_last, g_hist, rec_init, sample_init] = ...
    CG_init(A_old, Dstar, param);
disp('Initialization done!');

% Scale the returned history by the number of old samples.
g_hist = g_hist ./ n_old;

% Append zero columns/rows for the atoms that were left out of the initial solve.
p_res  = p - param.p;
D_last = horzcat(D_last, zeros(m, p_res));
X_last = vertcat(X_last, zeros(p_res, n_old));

% Reference lower-level loss: half the squared Frobenius norm of the residual.
loss_0 = 0.5 * norm(A_old - D_last*X_last, 'fro')^2;

%% solving the lower-level problem initially by STORM
% Same initial lower-level solve as the CG section, but using STORM_init.
% The returned dictionary/coefficients cover round(4*p/5) atoms and are
% zero-padded to p atoms; loss_01 is the corresponding lower-level loss.
epsilon_g = 1e-6;
maxiter = 1000;
delta = 3;

% Solver options (the meaning of each field is defined by STORM_init).
param.delta = delta;
param.thres = thres;
param.p = round(4*p/5);
param.eps = epsilon_g;
param.eps2 = 1e-10;
param.maxiter = maxiter;

disp('Initialization starts!');
[D_last1,X_last1,g_hist1,rec_init1,sample1] = STORM_init(A_old,Dstar,param);
disp('Initialization done!');

% Scale the returned history by the number of old samples.
g_hist1 = g_hist1/n_old;
% Number of atoms left out of the initial solve. The padding below uses this
% section's own p_res1 (not p_res from the CG section), so the section does
% not depend on the CG section having been run first.
p_res1 = p-round(4*p/5);
D_last1 = [D_last1,zeros(m,p_res1)];
X_last1 = [X_last1;zeros(p_res1,n_old)];
% Half the squared Frobenius norm of the residual on the old data.
loss_01 = norm(A_old-D_last1*X_last1,'fro')^2/2;


%% solving the lower-level problem initially by SPIDER
% Same initial lower-level solve as the CG section, but using SPIDER_init.
% The returned dictionary/coefficients cover round(4*p/5) atoms and are
% zero-padded to p atoms; loss_02 is the corresponding lower-level loss.
epsilon_g = 1e-6;
maxiter = 1e3;
delta = 3;

% Solver options (the meaning of each field is defined by SPIDER_init).
param.delta = delta;
param.thres = thres;
param.p = round(4*p/5);
param.eps = epsilon_g;
param.eps2 = 1e-10;
param.maxiter = maxiter;

disp('Initialization starts!');
[D_last2,X_last2,g_hist2,rec_init2,sample2] = SPIDER_init(A_old,Dstar,param);
disp('Initialization done!');

% Scale the returned history by the number of old samples.
g_hist2 = g_hist2/n_old;
% Number of atoms left out of the initial solve. The padding below uses this
% section's own p_res2 (not p_res from the CG section), so the section does
% not depend on the CG section having been run first.
p_res2 = p-round(4*p/5);
D_last2 = [D_last2,zeros(m,p_res2)];
X_last2 = [X_last2;zeros(p_res2,n_old)];
% Half the squared Frobenius norm of the residual on the old data.
loss_02 = norm(A_old-D_last2*X_last2,'fro')^2/2;


%% initial plots
% Compares the three lower-level initialisation runs: g_hist (CG_init),
% g_hist1 (STORM_init) and g_hist2 (SPIDER_init) are drawn on a log y-axis
% against their sample counters, with N_marker markers on each curve.
figure;
set(0, 'defaulttextinterpreter', 'latex')   % root-level default: also applies to later figures
set(gcf, 'DefaultLineLinewidth', 5, ...
         'DefaultLineMarkerSize', 16, ...
         'Position', [331,215,720,538])

% Marker targets: N_marker values evenly spaced between 0 and the last
% sample count of the CG run.
maxsample = sample_init(end);
N_marker  = 10;
time_idx  = linspace(0, maxsample, N_marker);

marker_idx  = zeros(N_marker, 1);
marker_idx1 = zeros(N_marker, 1);
marker_idx2 = zeros(N_marker, 1);
% Allocated here but neither filled nor used by this section.
marker_idx_proj = zeros(N_marker, 1);
marker_idx_t    = zeros(N_marker, 1);
marker_idx_isam = zeros(N_marker, 1);
marker_idx_dbgd = zeros(N_marker, 1);

% For every target, mark the data point of each curve whose sample count is
% closest to it (min returns the first such index on ties).
for j = 1:N_marker
    [~, marker_idx(j)]  = min(abs(sample_init - time_idx(j)));
    [~, marker_idx1(j)] = min(abs(sample1     - time_idx(j)));
    [~, marker_idx2(j)] = min(abs(sample2     - time_idx(j)));
end

semilogy(sample_init, g_hist, 'o-', ...
    'DisplayName', 'CG-BiO', 'MarkerIndices', marker_idx)
hold on
semilogy(sample1, g_hist1, 'o-', ...
    'DisplayName', 'CG-SBO1', 'MarkerIndices', marker_idx1)
semilogy(sample2, g_hist2, 'd-', ...
    'DisplayName', 'SPIDER-SBO', 'MarkerIndices', marker_idx2)

legend

% Axis labels, fonts and layout.
ylabel('$g(\tilde{\mathbf{D}}_k)-g(\tilde{\mathbf{D}}_0)$')
xlabel('number of samples')
set(gca, 'FontSize', 24);
legend('Interpreter', 'latex', 'Location', 'southwest')
grid on;
pbaspect([1 0.7 1])
% To export: print('-depsc2','-r600','./figs/nonconvex_lower_new.eps')


%% initial for upper level
% Random Gaussian starting point shared by all upper-level solvers below;
% each column is divided by its 1-norm and then multiplied by delta.
X_up_init = randn(p, n_new);
X_up_init = delta .* (X_up_init ./ vecnorm(X_up_init, 1));

%% CG-BiO
% Runs CG_BiO from the CG initialisation (X_last, D_last) and the shared
% upper-level starting point X_up_init.
% The outputs are stored under *_bio names: the CG-SBO1-noK section further
% down writes to loss_vec_up / loss_vec_lo / rec_vec / time_vec / sample_vec,
% so using those names here would discard this run's results before they
% could be inspected.
max_time = 100;
gamma0 = .3;
maxiter_up = 1e3;

param.delta = delta;
param.maxtime = max_time;
param.gamma0 = gamma0; % initial stepsize
param.thres = thres; % threshold for recovery
param.p = p;
param.maxiter = maxiter_up;

[loss_vec_up_bio,loss_vec_lo_bio,rec_vec_bio,time_vec_bio,sample_vec_bio] = CG_BiO(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);
% normalize the loss: upper level per new sample, lower level relative to
% the initial loss loss_0 and per old sample
loss_vec_up_bio = loss_vec_up_bio/n_new;
loss_vec_lo_bio = (loss_vec_lo_bio-loss_0)/n_old;

%% CG-SBO1-noK
% CG_SBO1 run with param.K = 0, started from the CG initialisation and the
% shared X_up_init. Its results are the curves labelled 'SBCGI-M' in the
% figures section.
param.K       = 0;
param.p       = p;
param.delta   = delta;
param.thres   = thres;       % threshold for recovery
param.gamma0  = 0.5;
param.maxiter = 40000;
param.maxtime = max_time;

[loss_vec_up, loss_vec_lo, rec_vec, time_vec, sample_vec] = ...
    CG_SBO1(A_old, A_new, X_last, D_last, X_up_init, Dstar, param);

% Normalise: upper-level loss per new sample; lower-level loss relative to
% the initial loss loss_0, per old sample.
loss_vec_up = loss_vec_up ./ n_new;
loss_vec_lo = (loss_vec_lo - loss_0) ./ n_old;

%% CG-SBO2-noK
% CG_SBO2 run with param.K = 0; all param fields not set here keep the
% values assigned by earlier sections. Its results are the curves labelled
% 'SBCGF-M' in the figures section.
param.K       = 0;
param.p       = p;
param.gamma0  = 0.002;
param.maxiter = 20000;

[loss_vec_up_t, loss_vec_lo_t, rec_vec_t, time_vec_t, sample_vec_t] = ...
    CG_SBO2(A_old, A_new, X_last, D_last, X_up_init, Dstar, param);

% Normalise: upper-level loss per new sample; lower-level loss relative to
% the initial loss loss_0, per old sample.
loss_vec_up_t = loss_vec_up_t ./ n_new;
loss_vec_lo_t = (loss_vec_lo_t - loss_0) ./ n_old;

%% CG-SBO1
% CG_SBO1 run with param.K = 0.01, started from the CG initialisation and
% the shared X_up_init. Its results are the curves labelled 'SBCGI' in the
% figures section.
param.K       = 0.01;
param.p       = p;
param.delta   = delta;
param.thres   = thres;       % threshold for recovery
param.gamma0  = 0.1;
param.maxiter = 40000;
param.maxtime = max_time;

[loss_vec_up1, loss_vec_lo1, rec_vec1, time_vec1, sample_vec1] = ...
    CG_SBO1(A_old, A_new, X_last, D_last, X_up_init, Dstar, param);

% Normalise: upper-level loss per new sample; lower-level loss relative to
% the initial loss loss_0, per old sample.
loss_vec_up1 = loss_vec_up1 ./ n_new;
loss_vec_lo1 = (loss_vec_lo1 - loss_0) ./ n_old;

%% CG-SBO1 (no-cp)
% CG_SBO1_upper run with the same options as the CG-SBO1 section
% (K = 0.01, gamma0 = 0.1, maxiter = 4e4); results go to the *_proj variables.
param.K       = 0.01;
param.p       = p;
param.delta   = delta;
param.thres   = thres;       % threshold for recovery
param.gamma0  = 0.1;
param.maxiter = 40000;
param.maxtime = max_time;

[loss_vec_up_proj, loss_vec_lo_proj, rec_vec_proj, time_vec_proj, sample_vec_proj] = ...
    CG_SBO1_upper(A_old, A_new, X_last, D_last, X_up_init, Dstar, param);

% Normalise: upper-level loss per new sample; lower-level loss relative to
% the initial loss loss_0, per old sample.
loss_vec_up_proj = loss_vec_up_proj ./ n_new;
loss_vec_lo_proj = (loss_vec_lo_proj - loss_0) ./ n_old;

%% CG-SBO2
% CG_SBO2 run with param.K = 0.01; all param fields not set here keep the
% values assigned by earlier sections. Its results are the curves labelled
% 'SBCGF' in the figures section, and sample_vec2(end) sets the range used
% to place the markers there.
param.K       = 0.01;
param.gamma0  = 0.001;
param.maxiter = 20000;

[loss_vec_up2, loss_vec_lo2, rec_vec2, time_vec2, sample_vec2] = ...
    CG_SBO2(A_old, A_new, X_last, D_last, X_up_init, Dstar, param);

% Normalise: upper-level loss per new sample; lower-level loss relative to
% the initial loss loss_0, per old sample.
loss_vec_up2 = loss_vec_up2 ./ n_new;
loss_vec_lo2 = (loss_vec_lo2 - loss_0) ./ n_old;
%% CG-SBO2(no-cp)
% CG_SBO2_upper run with the same options as the CG-SBO2 section
% (K = 0.01, gamma0 = 1e-3, maxiter = 2e4); results go to the *_dbgd variables.
param.K       = 0.01;
param.gamma0  = 0.001;
param.maxiter = 20000;

[loss_vec_up_dbgd, loss_vec_lo_dbgd, rec_vec_dbgd, time_vec_dbgd, sample_vec_dbgd] = ...
    CG_SBO2_upper(A_old, A_new, X_last, D_last, X_up_init, Dstar, param);

% Normalise: upper-level loss per new sample; lower-level loss relative to
% the initial loss loss_0, per old sample.
loss_vec_up_dbgd = loss_vec_up_dbgd ./ n_new;
loss_vec_lo_dbgd = (loss_vec_lo_dbgd - loss_0) ./ n_old;
%% CG with the same stepsize but without cutting plane
% param.maxiter = 100;
% X_up_init = zeros(p,n_new);
% [loss_vec_up2,loss_vec_lo2,rec_vec2,time_vec2,sample_vec2] = CG_upper(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);
% loss_vec_up2 = loss_vec_up2/n_new;
% loss_vec_lo2 = (loss_vec_lo2-loss_0)/n_old;

% %% Yousefian's algorithm
% disp('a-IRG algorithm starts!')
% [loss_vec_up_proj,loss_vec_lo_proj,rec_vec_proj,time_vec_proj] = Alg_projection(A_old,X_last,A_new,Dstar,...
%     D_last,X_up_init,param);
% disp('a-IRG algorithm done!');
% 
% loss_vec_up_proj = loss_vec_up_proj/n_new;
% loss_vec_lo_proj = (loss_vec_lo_proj-loss_0)/n_old;
% 
% %% BiG-SAM
% % eta_up_list = [1e-2,1e-1,1];
% % eta_lo_list = [1e-2,1e-1,1];
% % gamma_list = [1e-1,1,10];
% % 
% % loss_list = zeros(3,3,3);
% % 
% % for idx_up = 1:3
% %     for idx_lo = 1:3
% %         for idx_gamma = 1:3
% %             param.eta_up = eta_up_list(idx_up);
% %             param.eta_lo = eta_lo_list(idx_lo);
% %             param.gamma =  gamma_list(idx_gamma);
% %             param.maxtime = 10;
% %             disp('BiG-SAM algorithm starts!')
% %             [loss_vec_up_sam,loss_vec_lo_sam,~,~] = BigSAM(A_old,A_new,X_last,...
% %             D_last,X_up_init,Dstar,param);
% %             disp('BiG-SAM algorithm done!');
% %             loss_list(idx_up,idx_lo,idx_gamma) = loss_vec_up_sam(end)/n_new+(loss_vec_lo_sam(end)-loss_0)/n_old;
% %         end
% %     end
% % end
% % [~,I] = min(loss_list,[],'all','linear');
% % [I1,I2,I3] = ind2sub([3,3,3],I);
% % param.eta_up = eta_up_list(I1);
% % param.eta_lo = eta_lo_list(I2);
% % param.gamma =  gamma_list(I3);
% 
% param.eta_up = 0.1;
% param.eta_lo = 0.1;
% param.gamma = 10;
% 
% param.maxtime = 39;
% disp('BiG-SAM algorithm starts!')
% [loss_vec_up_sam,loss_vec_lo_sam,rec_vec_sam,time_vec_sam] = BigSAM(A_old,A_new,X_last,...
%     D_last,X_up_init,Dstar,param);
% disp('BiG-SAM algorithm done!');
% 
% loss_vec_up_sam = loss_vec_up_sam/n_new;
% loss_vec_lo_sam = (loss_vec_lo_sam-loss_0)/n_old;
% 
% %% DBGD
% param.alpha = 1;
% param.beta = 1;
% param.stepsize = .1;
% param.delta = delta;
% param.thres = thres;
% 
% param.maxtime = 39;
% 
% disp('DBGD algorithm starts!')
% [loss_vec_up_dbgd,loss_vec_lo_dbgd,rec_vec_dbgd,time_vec_dbgd] = DBGD(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);
% disp('DBGD algorithm done!');
% 
% loss_vec_up_dbgd = loss_vec_up_dbgd/n_new;
% loss_vec_lo_dbgd = (loss_vec_lo_dbgd-loss_0)/n_old;

%% Yousefian's algorithm
% Runs the stochastic a-IRG baseline (Alg_projection_sto) from the same
% starting point as the other solvers. Note the argument order differs from
% the CG_* calls: (A_old, X_last, A_new, Dstar, D_last, X_up_init, param).
% Results are stored under *_airg names so they do not overwrite the *_proj
% variables written by the CG-SBO1 (no-cp) section, which the figures
% section plots under the label 'STORM-FW'.
param.maxiter = 2e4;
disp('a-IRG algorithm starts!')
[loss_vec_up_airg,loss_vec_lo_airg,rec_vec_airg,time_vec_airg,sample_vec_airg] = Alg_projection_sto(A_old,X_last,A_new,Dstar,...
    D_last,X_up_init,param);
disp('a-IRG algorithm done!');

% normalize the loss: upper level per new sample, lower level relative to
% the initial loss loss_0 and per old sample
loss_vec_up_airg = loss_vec_up_airg/n_new;
loss_vec_lo_airg = (loss_vec_lo_airg-loss_0)/n_old;


%% DBGD-sto
% Runs the stochastic DBGD baseline (DBGD_sto) from the same starting point
% as the other solvers.
% Results are stored under *_dbgd_sto names so they do not overwrite the
% *_dbgd variables written by the CG-SBO2(no-cp) section, which the figures
% section plots under the label 'SPIDER-FW'.
param.alpha = 100;
param.beta = 100;
param.stepsize = 5e-3;
param.delta = delta;
param.thres = thres;

% param.maxtime = 39;
param.maxiter = 4e4;

disp('DBGD algorithm starts!')
[loss_vec_up_dbgd_sto,loss_vec_lo_dbgd_sto,rec_vec_dbgd_sto,time_vec_dbgd_sto,sample_vec_dbgd_sto] = DBGD_sto(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);
disp('DBGD algorithm done!');

% normalize the loss: upper level per new sample, lower level relative to
% the initial loss loss_0 and per old sample
loss_vec_up_dbgd_sto = loss_vec_up_dbgd_sto/n_new;
loss_vec_lo_dbgd_sto = (loss_vec_lo_dbgd_sto-loss_0)/n_old;

%% Figures
% Draws three figures when figs is true, all against the number of samples:
% recovery rate, upper-level objective (log scale) and lower-level objective
% (log scale). The same six result sets appear in each figure, with the same
% legend labels and marker positions.
maxsample = sample_vec2(end);
if figs
    figure;
    set(0,'defaulttextinterpreter','latex')
    set(gcf,'DefaultLineLinewidth',5)
    set(gcf,'DefaultLineMarkerSize',16);
    set(gcf,'Position',[331,215,720,538])
    % set(gcf,'WindowState','maximized');

    % Marker positions: N_marker targets evenly spaced over [0, maxsample];
    % each curve gets a marker at the data point whose sample count is
    % closest to each target.
    N_marker = 10;
    time_idx = linspace(0,maxsample,N_marker);
    marker_idx = zeros(N_marker,1);
    marker_idx1 = zeros(N_marker,1);
    marker_idx_proj = zeros(N_marker,1);
    marker_idx2 = zeros(N_marker,1);
    marker_idx_t = zeros(N_marker,1);
    marker_idx_dbgd = zeros(N_marker,1);
    for j=1:N_marker
        [~,idx] = min(abs(sample_vec-time_idx(j)));
        marker_idx(j) = idx;
        [~,idx] = min(abs(sample_vec1-time_idx(j)));
        marker_idx1(j) = idx;
        [~,idx] = min(abs(sample_vec_proj-time_idx(j)));
        marker_idx_proj(j) = idx;
        [~,idx] = min(abs(sample_vec2-time_idx(j)));
        marker_idx2(j) = idx;
        [~,idx] = min(abs(sample_vec_t-time_idx(j)));
        marker_idx_t(j) = idx;
        [~,idx] = min(abs(sample_vec_dbgd-time_idx(j)));
        marker_idx_dbgd(j) = idx;
    end

    %% Recovery rate
    plot(sample_vec1,rec_vec1,'o-','DisplayName','SBCGI','MarkerIndices', marker_idx1)
    hold on
    plot(sample_vec,rec_vec,'o-','DisplayName','SBCGI-M','MarkerIndices', marker_idx)
    plot(sample_vec_proj, rec_vec_proj,'o-','DisplayName','STORM-FW','MarkerIndices', marker_idx_proj)
    plot(sample_vec2,rec_vec2,'^-','DisplayName','SBCGF','MarkerIndices', marker_idx2)
    plot(sample_vec_t,rec_vec_t,'^-','DisplayName','SBCGF-M','MarkerIndices', marker_idx_t)
    plot(sample_vec_dbgd,rec_vec_dbgd,'^-','DisplayName','SPIDER-FW','MarkerIndices', marker_idx_dbgd)


    lgd = legend();
    %lgd.Position=[0.521319094763862 0.416352230128746 0.37569514380561 0.252044610374479];
    ylabel('Recovery rate')
    xlabel('number of samples')
    set(gca,'FontSize',24);
    set(gca,'YLim',[0,1])
    legend('Interpreter','latex')
    grid on;
    pbaspect([1 0.7 1])
    % print('-depsc2','-r600','./figs/recovery_new.eps')
    %% Upper-level objective
    figure;
    set(0,'defaulttextinterpreter','latex')
    set(gcf,'DefaultLineLinewidth',5)
    set(gcf,'DefaultLineMarkerSize',16);
    set(gcf,'Position',[331,215,720,538])
    % set(gcf,'WindowState','maximized');
    semilogy(sample_vec1, loss_vec_up1,'o-','DisplayName','SBCGI','MarkerIndices', marker_idx1)
    hold on
    semilogy(sample_vec, loss_vec_up,'o-','DisplayName','SBCGI-M','MarkerIndices', marker_idx)
    semilogy(sample_vec_proj, loss_vec_up_proj,'o-','DisplayName','STORM-FW','MarkerIndices', marker_idx_proj)
    semilogy(sample_vec2, loss_vec_up2,'^-','DisplayName','SBCGF','MarkerIndices', marker_idx2)
    semilogy(sample_vec_t, loss_vec_up_t,'^-','DisplayName','SBCGF-M','MarkerIndices', marker_idx_t)
    % label spelled consistently with the recovery-rate figure ('SPIDER-FW')
    semilogy(sample_vec_dbgd,loss_vec_up_dbgd,'^-','DisplayName','SPIDER-FW','MarkerIndices', marker_idx_dbgd)

    legend

    ylabel('$f(\tilde{\mathbf{D}}_k,\tilde{\mathbf{X}}_k)$')
    xlabel('number of samples')
    set(gca,'FontSize',24);
%     set(gca,'YLim',[1e-6,1])
    legend('Interpreter','latex','Location','southwest')
    grid on;
    pbaspect([1 0.7 1])
    % print('-depsc2','-r600','./figs/nonconvex_upper_new.eps')
    %% Lower-level objective
    figure;
    set(0,'defaulttextinterpreter','latex')
    set(gcf,'DefaultLineLinewidth',5)
    set(gcf,'DefaultLineMarkerSize',16);
    set(gcf,'Position',[331,215,720,538])
    % set(gcf,'WindowState','maximized');

    semilogy(sample_vec1, loss_vec_lo1,'o-','DisplayName','SBCGI','MarkerIndices', marker_idx1)
    hold on
    semilogy(sample_vec, loss_vec_lo,'o-','DisplayName','SBCGI-M','MarkerIndices', marker_idx)
    semilogy(sample_vec_proj, loss_vec_lo_proj,'o-','DisplayName','STORM-FW','MarkerIndices', marker_idx_proj)
    semilogy(sample_vec2, loss_vec_lo2,'^-','DisplayName','SBCGF','MarkerIndices', marker_idx2)
    semilogy(sample_vec_t, loss_vec_lo_t,'^-','DisplayName','SBCGF-M','MarkerIndices', marker_idx_t)
    % label spelled consistently with the recovery-rate figure ('SPIDER-FW')
    semilogy(sample_vec_dbgd,loss_vec_lo_dbgd,'^-','DisplayName','SPIDER-FW','MarkerIndices', marker_idx_dbgd)


    legend

    ylabel('$g(\tilde{\mathbf{D}}_k)-g(\tilde{\mathbf{D}}_0)$')
    xlabel('number of samples')
    set(gca,'FontSize',24);
%     set(gca,'YLim',[1e-22,1])
    legend('Interpreter','latex','Location','southwest')
    grid on;
    pbaspect([1 0.7 1])
    % print('-depsc2','-r600','./figs/nonconvex_lower_new.eps')

end
