clear;
%% generate the data
% Problem dimensions
m = 25;         % dimension of each data vector
p = 50;         % total number of dictionary atoms
n_old = 250;    % number of lower-level (old) samples
n_new = 250;    % number of upper-level (new) samples

% Experiment settings
sigma = 0.01;   % scale of the additive Gaussian noise
thres = 0.9;    % threshold for recovery rate
max_time = 100; % max running time
figs = true;    % draw the figures at the end of the script

% Fix the random stream so that every run produces the same data
seed = 2020;
rng(seed)

% Ground-truth dictionary: Gaussian entries, columns scaled to unit 2-norm
Dstar = randn(m,p);
Dstar = Dstar ./ vecnorm(Dstar);

k_spar = 5; % number of nonzeros in each coefficient vector

% Old dataset. Each column of X_old gets k_spar nonzeros whose rows are
% drawn from the first round(4*p/5) atoms; magnitudes lie in [0.2, 1] and
% signs are random. The draw is kept as a single statement so the order in
% which the random stream is consumed is unchanged.
X_old = zeros(p,n_old);
for col = 1:n_old
    X_old(randsample(round(4*p/5),k_spar),col) = (0.8*rand(k_spar,1)+0.2).*(2*randi([0 1],k_spar,1)-1);
end
A_old = Dstar*X_old + sigma*randn(m,n_old);

% New dataset. Same construction, but the rows are drawn from the atoms
% round(3*p/5)..p, so the two supports only partially overlap.
X_new = zeros(p,n_new);
for col = 1:n_new
    X_new(randsample(round(3*p/5):p,k_spar),col) = (0.8*rand(k_spar,1)+0.2).*(2*randi([0 1],k_spar,1)-1);
end
A_new = Dstar*X_new + sigma*randn(m,n_new);


%% solving the lower-level problem initially by CG
% Settings for the initialization run
epsilon_g = 1e-6;
maxiter = 1e4;
delta = 3;

% Options handed to CG_init. Note that param.p is round(4*p/5) here, not
% the full number of atoms p.
param = struct( ...
    'delta',   delta, ...
    'thres',   thres, ...
    'p',       round(4*p/5), ...
    'eps',     epsilon_g, ...
    'eps2',    1e-10, ...
    'maxiter', maxiter);

disp('Initialization starts!');
[D_last,X_last,g_hist,rec_init,sample_init] = CG_init(A_old,Dstar,param);
disp('Initialization done!');

% Scale the returned history by the number of old samples
g_hist = g_hist ./ n_old;

% Pad the initial solution with p_res zero columns (dictionary) and zero
% rows (coefficients) for the atoms that were left out above
% (presumably bringing both up to p atoms -- CG_init's output sizes are
% not visible in this script).
p_res = p - round(4*p/5);
D_last = horzcat(D_last, zeros(m,p_res));
X_last = vertcat(X_last, zeros(p_res,n_old));

% Lower-level loss of the padded initial solution; used later as the
% reference value subtracted from every lower-level loss curve.
loss_0 = 0.5*norm(A_old - D_last*X_last,'fro')^2;

% %% solving the lower-level problem initially by STORM
% epsilon_g = 1e-6;
% maxiter = 1000;
% delta = 3;
% 
% param.delta = delta;
% param.thres = thres;
% param.p = round(4*p/5);
% param.eps = epsilon_g;
% param.eps2 = 1e-10;
% param.maxiter = maxiter;
% 
% disp('Initialization starts!');
% [D_last1,X_last1,g_hist1,rec_init1,sample1] = STORM_init(A_old,Dstar,param);
% disp('Initialization done!');
% 
% g_hist1 = g_hist1/n_old;
% p_res1 = p-round(4*p/5);
% D_last1 = [D_last1,zeros(m,p_res)];
% X_last1 = [X_last1;zeros(p_res,n_old)];
% loss_01 = norm(A_old-D_last1*X_last1,'fro')^2/2;
% 

% %% solving the lower-level problem initially by SPIDER
% epsilon_g = 1e-6;
% maxiter = 1e3;
% delta = 3;
% 
% param.delta = delta;
% param.thres = thres;
% param.p = round(4*p/5);
% param.eps = epsilon_g;
% param.eps2 = 1e-10;
% param.maxiter = maxiter;
% 
% disp('Initialization starts!');
% [D_last2,X_last2,g_hist2,rec_init2,sample2] = SPIDER_init(A_old,Dstar,param);
% disp('Initialization done!');
% 
% g_hist2 = g_hist2/n_old;
% p_res2 = p-round(4*p/5);
% D_last2 = [D_last2,zeros(m,p_res)];
% X_last2 = [X_last2;zeros(p_res,n_old)];
% loss_02 = norm(A_old-D_last2*X_last2,'fro')^2/2;


% %% initial plots
% maxsample = sample_init(end);
% figure;
% set(0,'defaulttextinterpreter','latex')
% set(gcf,'DefaultLineLinewidth',5)
% set(gcf,'DefaultLineMarkerSize',16);
% set(gcf,'Position',[331,215,720,538])
% % set(gcf,'WindowState','maximized');
% N_marker = 10;
% time_idx = linspace(0,maxsample,N_marker);
% marker_idx = zeros(N_marker,1);
% marker_idx1 = zeros(N_marker,1);
% marker_idx_proj = zeros(N_marker,1);
% marker_idx2 = zeros(N_marker,1);
% marker_idx_t = zeros(N_marker,1);
% marker_idx_isam = zeros(N_marker,1);
% marker_idx_dbgd = zeros(N_marker,1);
% for j=1:N_marker
%     [~,idx] = min(abs(sample_init-time_idx(j)));
%     marker_idx(j) = idx;
%     [~,idx] = min(abs(sample1-time_idx(j)));
%     marker_idx1(j) = idx;
% %        [~,idx] = min(abs(time_vec_proj-time_idx(j)));
% %        marker_idx_proj(j) = idx;
%     [~,idx] = min(abs(sample2-time_idx(j)));
%     marker_idx2(j) = idx;
% %         [~,idx] = min(abs(time_vec_sam-time_idx(j)));
% %         marker_idx_sam(j) = idx;
% %         [~,idx] = min(abs(time_vec_dbgd-time_idx(j)));
% %         marker_idx_dbgd(j) = idx;
% end
% semilogy(sample_init, g_hist,'o-','DisplayName','CG-BiO','MarkerIndices', marker_idx)
% hold on
% semilogy(sample1, g_hist1,'o-','DisplayName','CG-SBO1','MarkerIndices', marker_idx1)
% %hold on
% %     semilogy(time_vec_sam, loss_vec_lo_sam,'^-','DisplayName','BiG-SAM','MarkerIndices', marker_idx_sam)
% %     semilogy(time_vec_proj, loss_vec_lo_proj,'s-','DisplayName','a-IRG','MarkerIndices', marker_idx_proj)
% semilogy(sample2, g_hist2,'d-','DisplayName','SPIDER-SBO','MarkerIndices', marker_idx2)
% %     semilogy(time_vec_dbgd,loss_vec_lo_dbgd,'>-','DisplayName','DBGD','MarkerIndices', marker_idx_dbgd, 'Color',"#77AC30")
% 
% 
% legend
% 
% ylabel('$g(\tilde{\mathbf{D}}_k)-g(\tilde{\mathbf{D}}_0)$')
% xlabel('number of samples')
% set(gca,'FontSize',24);
% % set(gca,'YLim',[1e-22,1])
% legend('Interpreter','latex','Location','southwest')
% grid on;
% pbaspect([1 0.7 1])
% % print('-depsc2','-r600','./figs/nonconvex_lower_new.eps')
% 

%% initial for upper level
% Random starting coefficients shared by all upper-level solvers below;
% every column is rescaled so that its 1-norm equals delta.
X_up_init = randn(p,n_new);
X_up_init = (X_up_init ./ vecnorm(X_up_init,1)) * delta;

% %% CG-BiO
% max_time = 100;
% gamma0 = .3;
% maxiter_up = 1e3;
% 
% param.delta = delta;
% param.maxtime = max_time;
% param.gamma0 = gamma0; % initial stepsize
% param.thres = thres; % threshold for recovery
% param.p = p;
% param.maxiter = maxiter_up;
% 
% % X_up_init = randn(p,n_new);
% % X_up_init = X_up_init./vecnorm(X_up_init,1)*delta;
% [loss_vec_up,loss_vec_lo,rec_vec,time_vec,sample_vec] = CG_BiO(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);
% % normalize the loss
% loss_vec_up = loss_vec_up/n_new;
% loss_vec_lo = (loss_vec_lo-loss_0)/n_old;

% % %% CG-SBO1-noK
% % % seed = 2020;
% % % rng(seed)
% % param.maxiter = 4e4;
% % param.gamma0 = 0.5;
% % param.maxtime = max_time;
% % param.delta = delta;
% % param.thres = thres; % threshold for recovery
% % param.K = 0;
% % param.p = p;
% % 
% % % param.delta = 1;
% % % X_up_init = randn(p,n_new);
% % % X_up_init = X_up_init./vecnorm(X_up_init,1)*delta;
% % [loss_vec_up,loss_vec_lo,rec_vec,time_vec,sample_vec] = CG_SBO1(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);
% % % normalize the loss
% % loss_vec_up = loss_vec_up/n_new;
% % loss_vec_lo = (loss_vec_lo-loss_0)/n_old;
% 
% %% CG-SBO2-noK
% % seed = 2020;
% % rng(seed)
% param.maxiter = 2e4;
% param.gamma0 = 2e-3;
% param.K = 0;
% param.p = p;
% 
% % param.delta = 1;
% % X_up_init = randn(p,n_new);
% % X_up_init = X_up_init./vecnorm(X_up_init,1)*delta;
% [loss_vec_up_t,loss_vec_lo_t,rec_vec_t,time_vec_t,sample_vec_t] = CG_SBO2(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);
% % normalize the loss
% loss_vec_up_t = loss_vec_up_t/n_new;
% loss_vec_lo_t = (loss_vec_lo_t-loss_0)/n_old;

%% CG-SBO1
% Fields that already exist in param are refreshed first; the new fields
% (gamma0, maxtime, K) are then added in the same order as before.
param.p = p;            % from here on the full set of atoms is used
param.delta = delta;
param.thres = thres;    % threshold for recovery
param.maxiter = 4e4;
param.gamma0 = 0.1;
param.maxtime = max_time;
param.K = 0.01;

[loss_vec_up1,loss_vec_lo1,rec_vec1,time_vec1,sample_vec1] = ...
    CG_SBO1(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);

% Normalize: upper-level loss per new sample, lower-level loss as the
% per-old-sample excess over the initial loss loss_0.
loss_vec_up1 = loss_vec_up1 ./ n_new;
loss_vec_lo1 = (loss_vec_lo1 - loss_0) ./ n_old;

%% CG-SBO2
% Same param struct as above with a different iteration budget and
% gamma0; all other fields are left as they are.
param.maxiter = 2e4;
param.gamma0 = 1e-3;
param.K = 0.01;

[loss_vec_up2,loss_vec_lo2,rec_vec2,time_vec2,sample_vec2] = ...
    CG_SBO2(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);

% Normalize in the same way as for CG-SBO1
loss_vec_up2 = loss_vec_up2 ./ n_new;
loss_vec_lo2 = (loss_vec_lo2 - loss_0) ./ n_old;
%% CG with the same stepsize but without cutting plane
% param.maxiter = 100;
% X_up_init = zeros(p,n_new);
% [loss_vec_up2,loss_vec_lo2,rec_vec2,time_vec2,sample_vec2] = CG_upper(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);
% loss_vec_up2 = loss_vec_up2/n_new;
% loss_vec_lo2 = (loss_vec_lo2-loss_0)/n_old;

% %% Yousefian's algorithm
% disp('a-IRG algorithm starts!')
% [loss_vec_up_proj,loss_vec_lo_proj,rec_vec_proj,time_vec_proj] = Alg_projection(A_old,X_last,A_new,Dstar,...
%     D_last,X_up_init,param);
% disp('a-IRG algorithm done!');
% 
% loss_vec_up_proj = loss_vec_up_proj/n_new;
% loss_vec_lo_proj = (loss_vec_lo_proj-loss_0)/n_old;
% 
% %% BiG-SAM
% % eta_up_list = [1e-2,1e-1,1];
% % eta_lo_list = [1e-2,1e-1,1];
% % gamma_list = [1e-1,1,10];
% % 
% % loss_list = zeros(3,3,3);
% % 
% % for idx_up = 1:3
% %     for idx_lo = 1:3
% %         for idx_gamma = 1:3
% %             param.eta_up = eta_up_list(idx_up);
% %             param.eta_lo = eta_lo_list(idx_lo);
% %             param.gamma =  gamma_list(idx_gamma);
% %             param.maxtime = 10;
% %             disp('BiG-SAM algorithm starts!')
% %             [loss_vec_up_sam,loss_vec_lo_sam,~,~] = BigSAM(A_old,A_new,X_last,...
% %             D_last,X_up_init,Dstar,param);
% %             disp('BiG-SAM algorithm done!');
% %             loss_list(idx_up,idx_lo,idx_gamma) = loss_vec_up_sam(end)/n_new+(loss_vec_lo_sam(end)-loss_0)/n_old;
% %         end
% %     end
% % end
% % [~,I] = min(loss_list,[],'all','linear');
% % [I1,I2,I3] = ind2sub([3,3,3],I);
% % param.eta_up = eta_up_list(I1);
% % param.eta_lo = eta_lo_list(I2);
% % param.gamma =  gamma_list(I3);
% 
% param.eta_up = 0.1;
% param.eta_lo = 0.1;
% param.gamma = 10;
% 
% param.maxtime = 39;
% disp('BiG-SAM algorithm starts!')
% [loss_vec_up_sam,loss_vec_lo_sam,rec_vec_sam,time_vec_sam] = BigSAM(A_old,A_new,X_last,...
%     D_last,X_up_init,Dstar,param);
% disp('BiG-SAM algorithm done!');
% 
% loss_vec_up_sam = loss_vec_up_sam/n_new;
% loss_vec_lo_sam = (loss_vec_lo_sam-loss_0)/n_old;
% 
% %% DBGD
% param.alpha = 1;
% param.beta = 1;
% param.stepsize = .1;
% param.delta = delta;
% param.thres = thres;
% 
% param.maxtime = 39;
% 
% disp('DBGD algorithm starts!')
% [loss_vec_up_dbgd,loss_vec_lo_dbgd,rec_vec_dbgd,time_vec_dbgd] = DBGD(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);
% disp('DBGD algorithm done!');
% 
% loss_vec_up_dbgd = loss_vec_up_dbgd/n_new;
% loss_vec_lo_dbgd = (loss_vec_lo_dbgd-loss_0)/n_old;

%% Yousefian's algorithm
% Baseline labelled aR-IP-SeG in the figures. Only maxiter is changed
% here; every other field of param keeps the value set earlier.
param.maxiter = 2e4;

disp('aR-IP-SeG algorithm starts!')
[loss_vec_up_proj,loss_vec_lo_proj,rec_vec_proj,time_vec_proj,sample_vec_proj] = ...
    Alg_projection_sto(A_old,X_last,A_new,Dstar,D_last,X_up_init,param);
disp('aR-IP-SeG algorithm done!');

% Normalize: upper-level loss per new sample, lower-level loss as the
% per-old-sample excess over the initial loss loss_0.
loss_vec_up_proj = loss_vec_up_proj ./ n_new;
loss_vec_lo_proj = (loss_vec_lo_proj - loss_0) ./ n_old;


%% DBGD-sto
% Existing fields are refreshed first; the new fields (alpha, beta,
% stepsize) are then added in the same order as before.
param.delta = delta;
param.thres = thres;
% param.maxtime = 39;
param.maxiter = 4e4;
param.alpha = 100;
param.beta = 100;
param.stepsize = 5e-3;

disp('DBGD algorithm starts!')
[loss_vec_up_dbgd,loss_vec_lo_dbgd,rec_vec_dbgd,time_vec_dbgd,sample_vec_dbgd] = ...
    DBGD_sto(A_old,A_new,X_last,D_last,X_up_init,Dstar,param);
disp('DBGD algorithm done!');

% Normalize in the same way as above
loss_vec_up_dbgd = loss_vec_up_dbgd ./ n_new;
loss_vec_lo_dbgd = (loss_vec_lo_dbgd - loss_0) ./ n_old;

%% Figures
% Draws three figures (recovery rate, upper-level objective, lower-level
% objective), each plotted against the number of samples for the four
% methods run above. Skipped entirely when figs is false.
if figs
    % Root-level default: text objects created from here on use LaTeX.
    % It persists across figures, so it is set once.
    set(0,'defaulttextinterpreter','latex')

    % Marker locations: N_marker equally spaced sample counts between 0
    % and the last sample count recorded by CG-SBO2.
    maxsample = sample_vec2(end);
    N_marker = 10;
    time_idx = linspace(0,maxsample,N_marker);

    % For every method, the index of the recorded point whose sample count
    % is closest to each marker location. Each sample vector is turned
    % into a column and compared against the row time_idx, giving one
    % column of distances per marker; min along dimension 1 returns the
    % first closest point, exactly as a per-marker min would.
    [~,marker_idx1]     = min(abs(sample_vec1(:)     - time_idx),[],1);
    [~,marker_idx2]     = min(abs(sample_vec2(:)     - time_idx),[],1);
    [~,marker_idx_proj] = min(abs(sample_vec_proj(:) - time_idx),[],1);
    [~,marker_idx_dbgd] = min(abs(sample_vec_dbgd(:) - time_idx),[],1);

    % ---- Recovery rate ----
    figure;
    set(gcf,'DefaultLineLinewidth',5,'DefaultLineMarkerSize',16)
    set(gcf,'Position',[331,215,720,538])
    % set(gcf,'WindowState','maximized');
    plot(sample_vec1,rec_vec1,'o-','DisplayName','SBCGI','MarkerIndices', marker_idx1)
    hold on
    plot(sample_vec2,rec_vec2,'^-','DisplayName','SBCGF','MarkerIndices', marker_idx2)
    plot(sample_vec_proj, rec_vec_proj,'d-','DisplayName','aR-IP-SeG','MarkerIndices', marker_idx_proj)
    plot(sample_vec_dbgd,rec_vec_dbgd,'>-','DisplayName','DBGD-Sto','MarkerIndices', marker_idx_dbgd, 'Color',"#77AC30")

    lgd = legend;
    ylabel('Recovery rate')
    xlabel('number of samples')
    set(gca,'FontSize',24);
    set(gca,'YLim',[0,1])
    set(lgd,'Interpreter','latex')
    grid on;
    pbaspect([1 0.7 1])
    % print('-depsc2','-r600','./figs/recovery_new.eps')

    % ---- Upper-level objective ----
    figure;
    set(gcf,'DefaultLineLinewidth',5,'DefaultLineMarkerSize',16)
    set(gcf,'Position',[331,215,720,538])
    % set(gcf,'WindowState','maximized');
    semilogy(sample_vec1,loss_vec_up1,'o-','DisplayName','SBCGI','MarkerIndices', marker_idx1)
    hold on
    semilogy(sample_vec2,loss_vec_up2,'^-','DisplayName','SBCGF','MarkerIndices', marker_idx2)
    semilogy(sample_vec_proj,loss_vec_up_proj,'d-','DisplayName','aR-IP-SeG','MarkerIndices', marker_idx_proj)
    semilogy(sample_vec_dbgd,loss_vec_up_dbgd,'>-','DisplayName','DBGD-Sto','MarkerIndices', marker_idx_dbgd, 'Color',"#77AC30")

    lgd = legend;
    ylabel('$f(\tilde{\mathbf{D}}_k,\tilde{\mathbf{X}}_k)$')
    xlabel('number of samples')
    set(gca,'FontSize',24);
    % set(gca,'YLim',[1e-6,1])
    set(lgd,'Interpreter','latex','Location','southwest')
    grid on;
    pbaspect([1 0.7 1])
    % print('-depsc2','-r600','./figs/nonconvex_upper_new.eps')

    % ---- Lower-level objective ----
    % The plotted values are (loss - loss_0)/n_old, i.e. relative to the
    % initial lower-level loss.
    figure;
    set(gcf,'DefaultLineLinewidth',5,'DefaultLineMarkerSize',16)
    set(gcf,'Position',[331,215,720,538])
    % set(gcf,'WindowState','maximized');
    semilogy(sample_vec1, loss_vec_lo1,'o-','DisplayName','SBCGI','MarkerIndices', marker_idx1)
    hold on
    semilogy(sample_vec2, loss_vec_lo2,'^-','DisplayName','SBCGF','MarkerIndices', marker_idx2)
    semilogy(sample_vec_proj, loss_vec_lo_proj,'d-','DisplayName','aR-IP-SeG','MarkerIndices', marker_idx_proj)
    semilogy(sample_vec_dbgd,loss_vec_lo_dbgd,'>-','DisplayName','DBGD-Sto','MarkerIndices', marker_idx_dbgd, 'Color',"#77AC30")

    lgd = legend;
    ylabel('$g(\tilde{\mathbf{D}}_k)-g(\tilde{\mathbf{D}}_0)$')
    xlabel('number of samples')
    set(gca,'FontSize',24);
    % set(gca,'YLim',[1e-22,1])
    set(lgd,'Interpreter','latex','Location','southwest')
    grid on;
    pbaspect([1 0.7 1])
    % print('-depsc2','-r600','./figs/nonconvex_lower_new.eps')

end
