clear;
%% load data
% Fixed RNG seed so that the random train/validation/test split below is
% reproducible from run to run.
seed = 123456;

rng(seed);
% NOTE(review): assumes the .mat file provides A (samples x features) and
% b (one row per sample) — TODO confirm against the data file.
load('wikivital_mathematics.mat');
[m,n]= size(A);

% Random three-way split of the rows of (A,b):
%   first trp*m shuffled rows           -> training set   (A2,b2), used by g
%   next  (1-trp)/2*m shuffled rows     -> validation set (A1,b1), used by f
%   remaining shuffled rows             -> test set       (A3,b3), used for
%                                          the test error (TSA_LS)
% With trp = 1/3 each of the three sets holds about one third of the data.
trp=1/3;
idx=randperm(m); % random permutation used to shuffle the rows
n_tr = round(trp*m);                % last shuffled index of the training set
n_va = round(trp*m+(m*(1-trp)/2));  % last shuffled index of the validation set

% training model data
A2=A(idx(1:n_tr),:);
b2=b(idx(1:n_tr),:);
% validation model data
A1=A(idx(n_tr+1:n_va),:);
b1=b(idx(n_tr+1:n_va),:);

% test model data
A3=A(idx(n_va+1:end),:);
b3=b(idx(n_va+1:end),:);

epsilon_f = 1e-4; % tolerance for f, passed to the solvers as param.epsilonf
epsilon_g = 1e-4; % tolerance for g, passed as param.epsilong (halved for the warm start)
%% global lambda;
% Radius of the l1-ball constraint norm(x,1) <= lambda shared by all problems.
lambda=10;
% x_init = sparse(n,1);
%%
% NOTE(review): maxiter is not read anywhere in this script; each solver
% gets its own param.maxiter instead.
maxiter=1e5;
%%
% Initial point: the all-zero sparse column vector of length n. A random
% point inside the l1-ball can be used instead via the commented lines below.
% x_init = Sample_L1ball(n,1,lambda);
% x_init = x_init';
x_init = sparse(n,1);
% x_init(1) = 1;
%% function definition
% Upper-level objective f: half the squared residual on the validation set.
fun_f= @(x) sum_square(A1*x-b1)/2;
% Lower-level objective g: half the squared residual on the training set.
% Both handles are also used inside the CVX models below.
fun_g = @(x) sum_square(A2*x-b2)/2;
% Gradients of g and f respectively.
grad_g = @(x) A2'*(A2*x-b2);
grad_f= @(x) A1'*(A1*x-b1);
 
% Finding optimal solution
% Reference optimum of the lower-level problem:
%   gstar = min g(x)  s.t.  norm(x,1) <= lambda
cvx_begin quiet
variables xstar_g(n,1)
cvx_precision high
minimize fun_g(xstar_g)
subject to
    norm(xstar_g,1)<=lambda;
cvx_end
gstar = cvx_optval;
% A failed solve would silently poison every lower-level gap g - gstar.
if ~any(strcmp(cvx_status, {'Solved','Inaccurate/Solved'}))
    warning('Lower-level reference problem: CVX status is %s; gstar = %g may be unreliable.', ...
        cvx_status, gstar);
end

% Reference optimum of the bilevel problem:
%   fstar = min f(x)  s.t.  norm(x,1) <= lambda,  g(x) <= gstar
% The constraint g(x) <= gstar restricts x to the (numerically computed)
% lower-level solution set, so it is tight by construction; the status check
% below catches infeasible/failed solves that would make fstar Inf or NaN.
cvx_begin quiet
variables xstar(n,1)
cvx_precision high
cvx_solver sedumi
minimize fun_f(xstar)
subject to
    norm(xstar,1)<=lambda;
    fun_g(xstar)<=gstar;
cvx_end
cvx_optval
fstar = cvx_optval;
if ~any(strcmp(cvx_status, {'Solved','Inaccurate/Solved'}))
    warning('Bilevel reference problem: CVX status is %s; fstar = %g may be unreliable.', ...
        cvx_status, fstar);
end
% %% CG for the sub problem
% param.epsilong = epsilon_g/2;
% param.lam1=lambda;
% param.maxiter=1e4;
% tic;
% [last_iter , f_hist, g_hist, sample] = CG_lowerlevel(fun_f,fun_g,grad_g,x_init,param);
% time_init = toc;
% %% STORM for the sub problem
% param.epsilong = epsilon_g/2;
% param.lam1=lambda;
% param.maxiter=1e1;
% tic;
% [last_iter1, f_hist1,g_hist1,sample1] = CG_lowerlevel_STORM(A2,b2,fun_f,fun_g,grad_g,x_init,param);
% time_init = toc;
%% SPIDER for the subproblem
% Warm start: run CG_lowerlevel_SPIDER on the lower-level problem from x_init.
% Its output last_iter2 is the starting point handed to CG_SBO1/CG_SBO2 below,
% and time_init is added to their time axes.
param.epsilong = epsilon_g/2;
param.lam1=lambda;
param.maxiter=5e3;
% Use a dedicated timer value: a bare tic/toc pair measures from the most
% recent global tic, which any tic call inside the solver would reset.
t_init = tic;
[last_iter2, f_hist2,g_hist2,sample2] = CG_lowerlevel_SPIDER(A2,b2,fun_f,fun_g,grad_g,x_init,param);
time_init = toc(t_init);
% %% Comparison of the first stage (solving the lower-level problem)
% param.maxsample = sample(end);
% % lower-level gap
% figure;
% set(0,'defaulttextinterpreter','latex')
% set(gcf,'DefaultLineLinewidth',5)
% set(gcf,'DefaultLineMarkerSize',16);
% set(gcf,'Position',[331,167,591,586])
% 
% N_marker = 10;
% time_idx = linspace(0,param.maxsample,N_marker);
% marker_idx1 = zeros(N_marker,1);
% marker_idx2 = zeros(N_marker,1);
% marker_idx3 = zeros(N_marker,1);
% marker_idx4 = zeros(N_marker,1);
% for j=1:N_marker
%     [~,idx] = min(abs(sample-time_idx(j)));
%     marker_idx1(j) = idx;
%     [~,idx] = min(abs(sample1-time_idx(j)));
%     marker_idx2(j) = idx;
%     [~,idx] = min(abs(sample2-time_idx(j)));
%     marker_idx3(j) = idx;
% %     [~,idx] = min(abs(sample_vec4-time_idx(j)));
% %     marker_idx4(j) = idx;
% end
% 
% semilogy(sample,(g_hist-gstar),'o-','DisplayName','CG-BiO','MarkerIndices', marker_idx1);
% hold on;
% semilogy(sample1,(g_hist1-gstar),'s-','DisplayName','CG-SBO1','MarkerIndices', marker_idx2);
% semilogy(sample2,(g_hist2-gstar),'^-','DisplayName','CG-SBO2','MarkerIndices', marker_idx3);
% %semilogy(sample_vec4,abs(g_vec4-gstar),'>-','DisplayName','CG-BiO (w/o cutting plane)','MarkerIndices', marker_idx4, 'Color',"#77AC30");
% ylabel('$g(\beta_k)-g^*$')
% xlabel('number of samples')
% set(gca,'FontSize',24);
% legend('Interpreter','latex','Location','northeast')
% grid on;
% grid minor
% pbaspect([1 0.8 1])

% print('-depsc2','-r600','./figs/nc_lower_time.eps')
%% Param Setting
% Settings shared by the bilevel solvers that follow. epsilong is raised
% back from the epsilon_g/2 used for the warm start to the full epsilon_g.
[param.epsilonf, param.epsilong] = deal(epsilon_f, epsilon_g);
param.lam = lambda;
% lower-level objective value at the warm-start point last_iter2
param.fun_g_x0 = fun_g(last_iter2);
%param.maxiter=1e3;

% %% CG-BiO algorithm
% param.maxiter=100;
% disp('CG-BiO starts');
% [f_vec1,g_vec1,time_vec1,sample_vec1,x1,tsa_BCG] = CG_BiO(fun_f,grad_f,grad_g,fun_g,...
%     @(x)TSA_LS(x,A3,b3),param,last_iter);
% disp('CG-BiO Achieved!');
% 
% time_vec1 = time_vec1+time_init;

%% CG-SBO1 algorithm
% Runs CG_SBO1 from the warm-start point last_iter2 (labelled 'SBCGI' in
% the figures); gamma and K are forwarded through param.
param.maxiter = 4.5e5;
param.gamma   = 1e-2;
param.K       = 1e-4;
disp('CG-SBO1 starts');
[f_vec2, g_vec2, time_vec2, sample_vec2, x2, tsa_SBO1] = CG_SBO1( ...
    fun_f, fun_g, @(x)TSA_LS(x,A3,b3), param, last_iter2, A1, A2, b1, b2);
disp('CG-SBO1 Achieved!');

% Account for the warm-start cost on the time axis, then record the total
% elapsed time in param.maxtime (read by the solver that runs next).
time_vec2 = time_init + time_vec2;
param.maxtime = time_vec2(end);

% %% modify sample_vec1
% for i = 100003:500003
%     sample_vec2(i) = sample_vec2(i)+100002;
% end
%% CG-SBO2 algorithm
% Runs CG_SBO2 from the warm-start point last_iter2 (labelled 'SBCGF' in
% the figures); gamma and K are forwarded through param.
param.maxiter = 23685;
param.gamma   = 1e-5;
param.K       = 1e-4;
disp('CG-SBO2 starts');
[f_vec3, g_vec3, time_vec3, sample_vec3, x3, tsa_SBO2] = CG_SBO2( ...
    fun_f, fun_g, @(x)TSA_LS(x,A3,b3), param, last_iter2, A1, A2, b1, b2);
disp('CG-SBO2 Achieved!');

% Account for the warm-start cost on the time axis, then record the total
% elapsed time in param.maxtime (read by the solvers that run next).
time_vec3 = time_init + time_vec3;
param.maxtime = time_vec3(end);

% %% CG-SBO2-ALT algorithm (different K)
% param.maxiter=1e5;
% param.gamma = 1e-2;
% param.K = 1;
% disp('CG-SBO2 starts');
% [f_vec_ATL,g_vec_ATL,time_vec_ATL,sample_vec_ATL,x_ATL,tsa_ATL] = CG_SBO1(fun_f, fun_g, @(x)TSA_LS(x,A3,b3),param,last_iter2, A1, A2, b1, b2);
% disp('CG-SBO2 Achieved!');
% %% CG-SBO2-ALT algorithm
% param.maxiter=1.6e4;
% param.gamma = 2;
% disp('CG-SBO2 starts');
% [f_vec_ATL,g_vec_ATL,time_vec_ATL,sample_vec_ATL,gz_vec,hl_vec,x_ATL,tsa_ATL] = CG_SBO2_ALT(fun_f, fun_g, grad_g, @(x)TSA_LS(x,A3,b3),param,x_init, A1, A2, b1, b2);
% disp('CG-SBO2 Achieved!');

%% check convergence
% plot(sample_vec_ATL,gz_vec);
%plot(sample_vec_ATL,g_vec_ATL);
% time_vec3 = time_vec3+time_init;
% param.maxtime = time_vec3(end);
%% CG-SBO3 algorithm
% param.maxiter=1e4;
% disp('CG-SBO3 starts');
% [f_vec4,g_vec4,time_vec4,x,tsa_SBO3] = CG_SBO3(fun_f, fun_g, @(x)TSA_LS(x,A3,b3),param,last_iter, A1, A2, b1, b2);
% disp('CG-SBO3 Achieved!');
% 
% time_vec4 = time_vec4+time_init;
%param.maxtime = time_vec4(end);

%% a-IRG Algorithm
% param.maxtime = time_vec2(end);
% param.maxiter=1e1;
% 
% disp('a-IRG Algorithm starts')
% [f_vec_ap,g_vec_ap,time_vec_ap,sample_vec_ap,xlast,tsa_AP] = Alg_Projection(fun_f,grad_f,grad_g,...
%     fun_g,@(x)TSA_LS(x,A3,b3),param,x_init, A1, A2, b1, b);
% disp('a-IRG Solution Achieved!');
%% a-IRG_Sto
% Stochastic a-IRG ('aR-IP-SeG' in the figures). Unlike the CG-SBO runs it
% starts from x_init, not from the warm start; fields not reset here (e.g.
% param.K, param.maxtime) keep the values left by the previous run.
param.maxiter = 1e4;
param.gamma   = 1e-7;
param.eta     = 1e3;
disp('a-IRG Algorithm starts')
[f_vec_ap, g_vec_ap, time_vec_ap, sample_vec_ap, xlast, tsa_AP] = ...
    Alg_Projection_Sto(fun_f, grad_f, grad_g, fun_g, @(x)TSA_LS(x,A3,b3), ...
                       param, x_init, A1, A2, b1, b2);
disp('a-IRG Solution Achieved!');
% %% BiG-SAM Algorithm
% param.eta_g = 1/eigs(A2'*A2,1);
% param.eta_f = 2/eigs(A1'*A1,1);
% param.gamma = 10;
% disp('BiG-SAM Algorithm starts');
% [f_vec3,g_vec3,time_vec3,xlast3,tsa_SAM] = BigSAM(fun_f,grad_f,grad_g,fun_g,@(x)TSA_LS(x,A3,b3),param,x_init);
% disp('BiG-SAM Solution Achieved!');
% 
% %% DBGD
% param.alpha = 1;
% param.beta = 1;
% param.stepsize = 1e-4;
% param.maxiter=1e7;
% param.maxtime = time_vec1(end);
% disp('DBGD Algorithm starts');
% [f_vec5,g_vec5,time_vec5,xlast5,tsa_DBGD] = DBGD(fun_f,grad_f,grad_g,fun_g,@(x)TSA_LS(x,A3,b3),param,x_init);
% disp('DBGD Solution Achieved!');
% 
%% DBGD-sto
% Stochastic DBGD ('DBGD-Sto' in the figures), started from x_init.
% param.maxtime is left at whatever the previous runs set.
param.alpha    = 1;
param.beta     = 1;
param.stepsize = 1e-6;
param.maxiter  = 5e5;
% param.maxtime = time_vec1(end);
disp('DBGD Algorithm starts');
[f_vec_DBGD, g_vec_DBGD, time_vec_DBGD, sample_vec_DBGD, xlast5, tsa_DBGD] = ...
    DBGD_Sto(fun_f, grad_f, grad_g, fun_g, @(x)TSA_LS(x,A3,b3), ...
             param, x_init, A1, A2, b1, b2);
disp('DBGD Solution Achieved!');
% 
% %% MNG
% param.maxtime = time_vec1(end);
% param.maxiter=length(time_vec1);
% param.M = eigs(A2'*A2,1);
% 
% [f_vec4,g_vec4,time_vec4,xlast4,tsa_MNG] = MNG(A1,b1,fun_g,grad_g,@(x)TSA_LS(x,A3,b3),param,A1\b1);

%% modify plot
% a-IRG and DBGD start from x_init, so prepend the x_init values (at zero
% samples) to their histories so the curves start at the initial point.
sample_vec_ap   = vertcat(0, sample_vec_ap);
sample_vec_DBGD = vertcat(0, sample_vec_DBGD);
g_vec_ap   = vertcat(fun_g(x_init), g_vec_ap);
g_vec_DBGD = vertcat(fun_g(x_init), g_vec_DBGD);
f_vec_ap   = vertcat(fun_f(x_init), f_vec_ap);
f_vec_DBGD = vertcat(fun_f(x_init), f_vec_DBGD);
tsa_AP   = vertcat(TSA_LS(x_init,A3,b3), tsa_AP);
tsa_DBGD = vertcat(TSA_LS(x_init,A3,b3), tsa_DBGD);
%% Figures
%% lower-level gap v.s. number of samples
% The SPIDER-SBO-ATL run that would produce sample_vec_ATL / g_vec_ATL is
% commented out in this script, so those variables are normally undefined.
% Draw that curve only when it exists instead of stopping with an
% undefined-variable error.
has_atl = exist('sample_vec_ATL','var') == 1;

figure;
set(0,'defaulttextinterpreter','latex')
set(gcf,'DefaultLineLinewidth',5)
set(gcf,'DefaultLineMarkerSize',16);
set(gcf,'Position',[331,167,591,586])


% Marker placement: for each of N_marker evenly spaced sample counts in
% [0, sample_vec3(end)], mark the point of each curve whose sample count is
% closest. marker_idx1..marker_idx5 are also used by the later figures.
N_marker = 10;
time_idx = linspace(0,sample_vec3(end),N_marker);
marker_idx1 = zeros(N_marker,1);
marker_idx2 = zeros(N_marker,1);
marker_idx3 = zeros(N_marker,1);
marker_idx4 = zeros(N_marker,1);
marker_idx5 = zeros(N_marker,1);
for j=1:N_marker
    % k (not idx) so the data-split permutation idx is not overwritten
    if has_atl
        [~,k] = min(abs(sample_vec_ATL-time_idx(j)));
        marker_idx1(j) = k;
    end
    [~,k] = min(abs(sample_vec2-time_idx(j)));
    marker_idx2(j) = k;
    [~,k] = min(abs(sample_vec3-time_idx(j)));
    marker_idx3(j) = k;
    [~,k] = min(abs(sample_vec_ap-time_idx(j)));
    marker_idx4(j) = k;
    [~,k] = min(abs(sample_vec_DBGD-time_idx(j)));
    marker_idx5(j) = k;
end

% hold is switched on only after the first curve: with hold already on,
% semilogy would not turn a fresh linear y-axis into a log axis.
if has_atl
    semilogy(sample_vec_ATL,abs(g_vec_ATL-gstar),'s-','DisplayName','SPIDER-SBO-ATL','MarkerIndices', marker_idx1);
    hold on;
end
semilogy(sample_vec2,abs(g_vec2-gstar),'o-','DisplayName','SBCGI','MarkerIndices', marker_idx2);
hold on;
semilogy(sample_vec3,abs(g_vec3-gstar),'^-','DisplayName','SBCGF','MarkerIndices', marker_idx3);
semilogy(sample_vec_ap,abs(g_vec_ap-gstar),'d-','DisplayName','aR-IP-SeG','MarkerIndices', marker_idx4);
semilogy(sample_vec_DBGD,abs(g_vec_DBGD-gstar),'>-','DisplayName','DBGD-Sto','MarkerIndices', marker_idx5, 'Color',"#77AC30");

ylabel('$|g(\beta_k)-g^*|$')
xlabel('number of samples')
set(gca,'FontSize',24);
set(gca,'YLim',[1e-6,10])
legend('Interpreter','latex','Location','southwest')
grid on;
grid minor
pbaspect([1 0.8 1])

% print('-depsc2','-r600','./figs/lower_subopt_time.eps')
%% upper-level gap v.s. # of samples
% Reuses the marker indices marker_idx1..marker_idx5 computed for the
% lower-level figure. The SPIDER-SBO-ATL curve is drawn only when its
% (commented-out) run actually produced sample_vec_ATL / f_vec_ATL.
figure;
set(0,'defaulttextinterpreter','latex')
set(gcf,'DefaultLineLinewidth',5)
set(gcf,'DefaultLineMarkerSize',16);
set(gcf,'Position',[331,167,591,586])

% hold is switched on only after the first curve so the y-axis stays log.
if exist('sample_vec_ATL','var') == 1
    semilogy(sample_vec_ATL,abs(f_vec_ATL-fstar),'s-','DisplayName','SPIDER-SBO-ATL','MarkerIndices', marker_idx1);
    hold on;
end
% semilogy(sample_vec2,abs(f_vec2-fstar),'s-','DisplayName','CG-SBO1','MarkerIndices', marker_idx2);
% hold on;
% semilogy(sample_vec3,abs(f_vec3-fstar),'^-','DisplayName','CG-SBO2','MarkerIndices', marker_idx3);
% semilogy(sample_vec_ap,abs(f_vec_ap-fstar),'d-','DisplayName','a-IRG','MarkerIndices', marker_idx4);
% semilogy(sample_vec_DBGD,abs(f_vec_DBGD-fstar),'>-','DisplayName','DBGD','MarkerIndices', marker_idx5, 'Color',"#77AC30");
semilogy(sample_vec2,abs(f_vec2-fstar),'o-','DisplayName','SBCGI','MarkerIndices', marker_idx2);
hold on;
semilogy(sample_vec3,abs(f_vec3-fstar),'^-','DisplayName','SBCGF','MarkerIndices', marker_idx3);
semilogy(sample_vec_ap,abs(f_vec_ap-fstar),'d-','DisplayName','aR-IP-SeG','MarkerIndices', marker_idx4);
semilogy(sample_vec_DBGD,abs(f_vec_DBGD-fstar),'>-','DisplayName','DBGD-Sto','MarkerIndices', marker_idx5, 'Color',"#77AC30");

ylabel('$|f(\beta_k)-f^*|$')
xlabel('number of samples')
% set(gca,'YLim',[1e-11,1e3])
set(gca,'FontSize',24);
% set(gca,'YLim',[0,1])
legend('Interpreter','latex','Location','northeast')
grid on;
grid minor
pbaspect([1 0.8 1])

% print('-depsc2','-r600','./figs/upper_subopt_time.eps')
%% test error v.s. # of samples
% The SBCGI/SBCGF (and SPIDER-SBO-ATL) test errors are plotted against
% sample_vec(2:end), i.e. one entry shorter than the full sample vectors the
% marker indices were computed on. Shift those indices down by one (clamped
% at 1) so each marker stays on the same sample count and none points past
% the end of the curve.
shift_idx = @(mi) unique(max(mi-1,1));

figure;
set(0,'defaulttextinterpreter','latex')
set(gcf,'DefaultLineLinewidth',5)
set(gcf,'DefaultLineMarkerSize',16);
set(gcf,'Position',[331,167,591,586])

% The ATL curve exists only if its (commented-out) run produced tsa_ATL;
% hold is switched on only after the first curve so the y-axis stays log.
if exist('sample_vec_ATL','var') == 1
    semilogy(sample_vec_ATL(2:end),tsa_ATL,'s-','DisplayName','SPIDER-SBO-ATL','MarkerIndices', shift_idx(marker_idx1));
    hold on;
end
semilogy(sample_vec2(2:end),tsa_SBO1,'o-','DisplayName','SBCGI','MarkerIndices', shift_idx(marker_idx2));
hold on;
semilogy(sample_vec3(2:end),tsa_SBO2,'^-','DisplayName','SBCGF','MarkerIndices', shift_idx(marker_idx3));
% a-IRG / DBGD use their full (prepended) sample vectors, so no shift.
semilogy(sample_vec_ap,tsa_AP,'d-','DisplayName','aR-IP-SeG','MarkerIndices', marker_idx4);
semilogy(sample_vec_DBGD,tsa_DBGD,'>-','DisplayName','DBGD-Sto','MarkerIndices', marker_idx5, 'Color',"#77AC30");


ylabel('Test error')
xlabel('number of samples')
set(gca,'FontSize',24);
% set(gca,'YLim',[0,1])
legend('Interpreter','latex','Location','northeast')
grid on;
grid minor
pbaspect([1 0.8 1])
% print('-depsc2','-r600','./figs/test_error_time.eps')

