clear variables
%% Run dict_learning_sto for several seeds and resample onto common time grids
% For each seed, dict_learning_sto returns, for four methods, a time vector
% together with the recovery rate and the upper- and lower-level losses at
% those times. Variable suffixes map to the plot legends in later cells:
%   (none) -> SBCGI, _proj -> SBCGF, 2 -> aR-IP-SeG, _sam -> DBGD-sto.
% Each series is linearly interpolated onto a uniform grid on [0, max_time]
% so runs from different seeds can be averaged; column i of every *_mat
% array holds seed i.
rng(2023)   % fixed generator state, so the same seed list is drawn every run
% num_seed = 5;
num_seed = 10;
seeds = randi(1e4,[num_seed,1]);

max_time = 6.4e5;
% max_time = 5;
for i=1:num_seed
    seed = seeds(i);
    % NOTE(review): the meaning of the second argument (false) is defined
    % inside dict_learning_sto - confirm there.
    [time_vec,time_vec_proj,time_vec2,time_vec_sam,...
    rec_vec,rec_vec_proj,rec_vec2,rec_vec_sam,...
    loss_vec_up,loss_vec_up_proj,loss_vec_up2,loss_vec_up_sam,...
    loss_vec_lo,loss_vec_lo_proj,loss_vec_lo2,loss_vec_lo_sam] = dict_learning_sto(seed,false);
    if i == 1
        % Grid sizes come from the first seed's output lengths; all later
        % seeds are interpolated onto these same grids.
        N_time = length(time_vec);
        N_time_proj = length(time_vec_proj);
        N_time2 = length(time_vec2);
        N_time_sam = length(time_vec_sam);

        time_v = linspace(0,max_time,N_time);
        time_v_proj = linspace(0,max_time,N_time_proj);
        time_v2 = linspace(0,max_time,N_time2);
        time_v_sam = linspace(0,max_time,N_time_sam);

        rec_mat = zeros(N_time,num_seed);
        rec_mat_proj = zeros(N_time_proj,num_seed);
        rec_mat2 = zeros(N_time2,num_seed);
        rec_mat_sam = zeros(N_time_sam,num_seed);

        loss_mat_up = zeros(N_time,num_seed);
        loss_mat_up_proj = zeros(N_time_proj,num_seed);
        loss_mat_up2 = zeros(N_time2,num_seed);
        loss_mat_up_sam = zeros(N_time_sam,num_seed);

        loss_mat_lo = zeros(N_time,num_seed);
        loss_mat_lo_proj = zeros(N_time_proj,num_seed);
        loss_mat_lo2 = zeros(N_time2,num_seed);
        loss_mat_lo_sam = zeros(N_time_sam,num_seed);
    end

    % griddedInterpolant requires strictly monotonic sample times. It
    % interpolates linearly and, by default, also extrapolates linearly, so
    % grid points beyond a run's last time stamp receive extrapolated values.

    % Recovery rate
    F = griddedInterpolant(time_vec,rec_vec);
    rec_mat(:,i) = F(time_v);
    F = griddedInterpolant(time_vec_proj,rec_vec_proj);
    rec_mat_proj(:,i) = F(time_v_proj);
    F = griddedInterpolant(time_vec2,rec_vec2);
    rec_mat2(:,i) = F(time_v2);
    F = griddedInterpolant(time_vec_sam,rec_vec_sam);
    rec_mat_sam(:,i) = F(time_v_sam);

    % Upper-level loss
    F = griddedInterpolant(time_vec,loss_vec_up);
    loss_mat_up(:,i) = F(time_v);
    F = griddedInterpolant(time_vec_proj,loss_vec_up_proj);
    loss_mat_up_proj(:,i) = F(time_v_proj);
    F = griddedInterpolant(time_vec2,loss_vec_up2);
    loss_mat_up2(:,i) = F(time_v2);
    F = griddedInterpolant(time_vec_sam,loss_vec_up_sam);
    loss_mat_up_sam(:,i) = F(time_v_sam);

    % Lower-level loss (each method uses its own *_lo series)
    F = griddedInterpolant(time_vec,loss_vec_lo);
    loss_mat_lo(:,i) = F(time_v);
    F = griddedInterpolant(time_vec_proj,loss_vec_lo_proj);
    loss_mat_lo_proj(:,i) = F(time_v_proj);
    F = griddedInterpolant(time_vec2,loss_vec_lo2);
    loss_mat_lo2(:,i) = F(time_v2);
    F = griddedInterpolant(time_vec_sam,loss_vec_lo_sam);
    loss_mat_lo_sam(:,i) = F(time_v_sam);
end

%% Recovery rate
% Seed-averaged recovery rate vs. time for each method, with a shaded band
% spanning the min and max over seeds. This cell also computes the
% marker_idx* vectors, which the later figures reuse as MarkerIndices.
figure;
set(0,'defaulttextinterpreter','latex')
set(gcf,'DefaultLineLinewidth',5)
set(gcf,'DefaultLineMarkerSize',16);
set(gcf,'Position',[331,215,720,538])
% set(gcf,'WindowState','maximized');

% For N_marker evenly spaced target times, find the nearest index on each
% method's resampled grid; markers are drawn only at those indices.
N_marker = 10;
time_idx = linspace(0,max_time,N_marker);
marker_idx = zeros(N_marker,1);
marker_idx_proj = zeros(N_marker,1);
marker_idx2 = zeros(N_marker,1);
marker_idx_sam = zeros(N_marker,1);
for j=1:N_marker
    [~,idx] = min(abs(time_v-time_idx(j)));
    marker_idx(j) = idx;
    [~,idx] = min(abs(time_v_proj-time_idx(j)));
    marker_idx_proj(j) = idx;
    [~,idx] = min(abs(time_v2-time_idx(j)));
    marker_idx2(j) = idx;
    [~,idx] = min(abs(time_v_sam-time_idx(j)));
    marker_idx_sam(j) = idx;
end

% SBCGI. The x-data is time_v, the common grid rec_mat was resampled onto;
% time_vec holds only the last seed's raw time stamps and need not match
% the number of rows in rec_mat.
rec_mean = plot(time_v,mean(rec_mat,2),'o-','DisplayName','SBCGI','MarkerIndices', marker_idx);
c = get(rec_mean,'Color');
hold on
time_v_fill = [time_v,fliplr(time_v)];
fill_between = [min(rec_mat,[],2);flipud(max(rec_mat,[],2))];
fill(time_v_fill,fill_between,c,'LineStyle','none','FaceAlpha',0.3,'HandleVisibility','off')
% NOTE(review): steps ColorOrderIndex back by one, presumably because fill()
% advanced it; confirm on the MATLAB release in use.
color_index = get(gca,'ColorOrderIndex');
set(gca,'ColorOrderIndex',color_index-1);
uistack(rec_mean)   % raise the mean line one level, above its band

% SBCGF
rec_mean_proj = plot(time_v_proj,mean(rec_mat_proj,2),'^-','DisplayName','SBCGF','MarkerIndices', marker_idx_proj);
c = get(rec_mean_proj,'Color');
time_v_fill = [time_v_proj,fliplr(time_v_proj)];
fill_between = [min(rec_mat_proj,[],2);flipud(max(rec_mat_proj,[],2))];
fill(time_v_fill,fill_between,c,'LineStyle','none','FaceAlpha',0.3,'HandleVisibility','off')
color_index = get(gca,'ColorOrderIndex');
set(gca,'ColorOrderIndex',color_index-1);
uistack(rec_mean_proj)

% aR-IP-SeG
rec_mean2 = plot(time_v2,mean(rec_mat2,2),'d-','DisplayName','aR-IP-SeG','MarkerIndices', marker_idx2);
c = get(rec_mean2,'Color');
time_v_fill = [time_v2,fliplr(time_v2)];
fill_between = [min(rec_mat2,[],2);flipud(max(rec_mat2,[],2))];
fill(time_v_fill,fill_between,c,'LineStyle','none','FaceAlpha',0.3,'HandleVisibility','off')
color_index = get(gca,'ColorOrderIndex');
set(gca,'ColorOrderIndex',color_index-1);
uistack(rec_mean2)

% DBGD-sto (fixed green color)
rec_mean_sam = plot(time_v_sam,mean(rec_mat_sam,2),'>-','DisplayName','DBGD-sto','MarkerIndices', marker_idx_sam, 'Color',"#77AC30");
c = get(rec_mean_sam,'Color');
time_v_fill = [time_v_sam,fliplr(time_v_sam)];
fill_between = [min(rec_mat_sam,[],2);flipud(max(rec_mat_sam,[],2))];
fill(time_v_fill,fill_between,c,'LineStyle','none','FaceAlpha',0.3,'HandleVisibility','off')
color_index = get(gca,'ColorOrderIndex');
set(gca,'ColorOrderIndex',color_index-1);
uistack(rec_mean_sam)

ylabel('Recovery rate')
xlabel('time (s)')
set(gca,'FontSize',24);
set(gca,'YLim',[0,1])
legend('Interpreter','latex','Location','southeast')
grid on;
pbaspect([1 0.7 1])
% print('-depsc2','-r600','./figs/recovery.eps')
%% Upper-level objective
% Semilog plot of the seed-averaged upper-level loss (loss_mat_up*) vs. time
% for each method, with a shaded band spanning the min and max over seeds.
% Reuses time_v*, loss_mat_up* and marker_idx* computed in earlier cells.
figure;
set(0,'defaulttextinterpreter','latex')
set(gcf,'DefaultLineLinewidth',5)
set(gcf,'DefaultLineMarkerSize',16);
set(gcf,'Position',[331,215,720,538])
% set(gcf,'WindowState','maximized');

% SBCGI: mean curve, then the min-max band filled in the same color.
loss_mean = semilogy(time_v,mean(loss_mat_up,2),'o-','DisplayName','SBCGI','MarkerIndices', marker_idx);
c = get(loss_mean,'Color');
hold on
time_v_fill = [time_v,fliplr(time_v)];
fill_between = [min(loss_mat_up,[],2);flipud(max(loss_mat_up,[],2))];
fill(time_v_fill,fill_between,c,'LineStyle','none','FaceAlpha',0.3,'HandleVisibility','off')
% NOTE(review): steps ColorOrderIndex back by one, presumably because fill()
% advanced it; confirm on the MATLAB release in use.
color_index = get(gca,'ColorOrderIndex');
set(gca,'ColorOrderIndex',color_index-1);
% Raise the mean line one level in the stacking order, above its band.
uistack(loss_mean)

% SBCGF
loss_mean_proj = semilogy(time_v_proj,mean(loss_mat_up_proj,2),'^-','DisplayName','SBCGF','MarkerIndices', marker_idx_proj);
c = get(loss_mean_proj,'Color');
time_v_fill = [time_v_proj,fliplr(time_v_proj)];
fill_between = [min(loss_mat_up_proj,[],2);flipud(max(loss_mat_up_proj,[],2))];
fill(time_v_fill,fill_between,c,'LineStyle','none','FaceAlpha',0.3,'HandleVisibility','off')
color_index = get(gca,'ColorOrderIndex');
set(gca,'ColorOrderIndex',color_index-1);
uistack(loss_mean_proj)

% aR-IP-SeG
loss_mean2 = semilogy(time_v2,mean(loss_mat_up2,2),'d-','DisplayName','aR-IP-SeG','MarkerIndices', marker_idx2);
c = get(loss_mean2,'Color');
time_v_fill = [time_v2,fliplr(time_v2)];
fill_between = [min(loss_mat_up2,[],2);flipud(max(loss_mat_up2,[],2))];
fill(time_v_fill,fill_between,c,'LineStyle','none','FaceAlpha',0.3,'HandleVisibility','off')
color_index = get(gca,'ColorOrderIndex');
set(gca,'ColorOrderIndex',color_index-1);
uistack(loss_mean2)


% DBGD-sto (fixed green color)
loss_mean_sam = semilogy(time_v_sam,mean(loss_mat_up_sam,2),'>-','DisplayName','DBGD-sto','MarkerIndices', marker_idx_sam, 'Color',"#77AC30");
c = get(loss_mean_sam,'Color');
time_v_fill = [time_v_sam,fliplr(time_v_sam)];
fill_between = [min(loss_mat_up_sam,[],2);flipud(max(loss_mat_up_sam,[],2))];
fill(time_v_fill,fill_between,c,'LineStyle','none','FaceAlpha',0.3,'HandleVisibility','off')
color_index = get(gca,'ColorOrderIndex');
set(gca,'ColorOrderIndex',color_index-1);
uistack(loss_mean_sam)
legend

ylabel('$f(\tilde{\mathbf{D}}_k,\tilde{\mathbf{X}}_k)$')
xlabel('time (s)')
set(gca,'FontSize',24);
% set(gca,'YLim',[1e-5,1])
legend('Interpreter','latex','Location','southwest')
grid on;
pbaspect([1 0.7 1])
% print('-depsc2','-r600','./figs/nonconvex_upper.eps')
%% Lower-level objective
% Semilog plot of the seed-averaged lower-level loss (loss_mat_lo*) vs. time
% for each method, with a shaded band spanning the min and max over seeds.
%
% The first grid sample (t = 0) is dropped from every series. Presumably the
% plotted gap g(D_k)-g(D_0) is zero there and cannot appear on a log axis
% (TODO confirm). Trimmed copies are used instead of deleting from time_v*,
% loss_mat_lo* and marker_idx* in place. As a result, re-running this cell,
% or an earlier cell afterwards, still sees the full-length data. Every
% marker index is shifted down by one (clamped at 1) so the markers stay on
% the same time samples as in the other figures.
figure;
set(0,'defaulttextinterpreter','latex')
set(gcf,'DefaultLineLinewidth',5)
set(gcf,'DefaultLineMarkerSize',16);
set(gcf,'Position',[331,215,720,538])

% SBCGI
t_lo = time_v(2:end);
loss_lo = loss_mat_lo(2:end,:);
mk_lo = unique(max(marker_idx-1,1));
loss_mean = semilogy(t_lo,mean(loss_lo,2),'o-','DisplayName','SBCGI','MarkerIndices', mk_lo);
c = get(loss_mean,'Color');
hold on
time_v_fill = [t_lo,fliplr(t_lo)];
fill_between = [min(loss_lo,[],2);flipud(max(loss_lo,[],2))];
fill_between = max(fill_between,eps);   % keep band vertices positive for the log axis
fill(time_v_fill,fill_between,c,'LineStyle','none','FaceAlpha',0.3,'HandleVisibility','off')
% NOTE(review): steps ColorOrderIndex back by one, presumably because fill()
% advanced it; confirm on the MATLAB release in use.
color_index = get(gca,'ColorOrderIndex');
set(gca,'ColorOrderIndex',color_index-1);
uistack(loss_mean)   % raise the mean line one level, above its band

% SBCGF
t_lo = time_v_proj(2:end);
loss_lo = loss_mat_lo_proj(2:end,:);
mk_lo = unique(max(marker_idx_proj-1,1));
loss_mean_proj = semilogy(t_lo,mean(loss_lo,2),'^-','DisplayName','SBCGF','MarkerIndices', mk_lo);
c = get(loss_mean_proj,'Color');
time_v_fill = [t_lo,fliplr(t_lo)];
fill_between = [min(loss_lo,[],2);flipud(max(loss_lo,[],2))];
% fill_between = max(fill_between,eps);
fill(time_v_fill,fill_between,c,'LineStyle','none','FaceAlpha',0.3,'HandleVisibility','off')
color_index = get(gca,'ColorOrderIndex');
set(gca,'ColorOrderIndex',color_index-1);
uistack(loss_mean_proj)

% aR-IP-SeG
t_lo = time_v2(2:end);
loss_lo = loss_mat_lo2(2:end,:);
mk_lo = unique(max(marker_idx2-1,1));
loss_mean2 = semilogy(t_lo,mean(loss_lo,2),'d-','DisplayName','aR-IP-SeG','MarkerIndices', mk_lo);
c = get(loss_mean2,'Color');
time_v_fill = [t_lo,fliplr(t_lo)];
fill_between = [min(loss_lo,[],2);flipud(max(loss_lo,[],2))];
fill(time_v_fill,fill_between,c,'LineStyle','none','FaceAlpha',0.3,'HandleVisibility','off')
color_index = get(gca,'ColorOrderIndex');
set(gca,'ColorOrderIndex',color_index-1);
uistack(loss_mean2)

% DBGD-sto (fixed green color)
t_lo = time_v_sam(2:end);
loss_lo = loss_mat_lo_sam(2:end,:);
mk_lo = unique(max(marker_idx_sam-1,1));
loss_mean_sam = semilogy(t_lo,mean(loss_lo,2),'>-','DisplayName','DBGD-sto','MarkerIndices', mk_lo, 'Color',"#77AC30");
c = get(loss_mean_sam,'Color');
time_v_fill = [t_lo,fliplr(t_lo)];
fill_between = [min(loss_lo,[],2);flipud(max(loss_lo,[],2))];
fill(time_v_fill,fill_between,c,'LineStyle','none','FaceAlpha',0.3,'HandleVisibility','off')
color_index = get(gca,'ColorOrderIndex');
set(gca,'ColorOrderIndex',color_index-1);
uistack(loss_mean_sam)
legend

ylabel('$g(\tilde{\mathbf{D}}_k)-g(\tilde{\mathbf{D}}_0)$')
xlabel('time (s)')
set(gca,'FontSize',24);
% set(gca,'YLim',[1e-22,1])
legend('Interpreter','latex','Location','southwest')
grid on;
pbaspect([1 0.7 1])
% print('-depsc2','-r600','./figs/nonconvex_lower.eps')