clear;
%% Load data and build the train / validation / test partition
seed = 123456;
rng(seed);                              % seed the generator before the shuffle below

load('wikivital_mathematics.mat');      % must provide A and b, which are used below
[m,n] = size(A);

% Row partition after shuffling:
%   first trp fraction          -> training set   (A2, b2)
%   next  (1-trp)/2 fraction    -> validation set (A1, b1)
%   remaining rows              -> test set       (A3, b3)
trp = 0.6;
idx = randperm(m);                      % random ordering of the m rows

tr_end  = round(trp*m);                 % last shuffled position of the training rows
val_end = round(trp*m+(m*(1-trp)/2));   % last shuffled position of the validation rows

rows_tr  = idx(1:tr_end);
rows_val = idx(tr_end+1:val_end);
rows_te  = idx(val_end+1:end);

A2 = A(rows_tr,:);
b2 = b(rows_tr,:);
A1 = A(rows_val,:);
b1 = b(rows_val,:);
A3 = A(rows_te,:);
b3 = b(rows_te,:);

%% Shared tolerances, constraint radius and starting point
epsilon_f = 1e-4;
epsilon_g = 1e-4;
lambda    = 1;              % radius used in the norm constraints of the CVX problems
x_init    = sparse(n,1);    % all-zero starting point
maxiter   = 1e5;

%% Curvature constants passed to the solvers through param
% Gram matrices of the validation block (A1) and the training block (A2).
A_val = A1'*A1;
A_tr  = A2'*A2;
% eigs(M,1) returns one eigenvalue (the largest in magnitude); for the
% least-squares objectives built from A1 and A2 this is the Lipschitz
% constant of the corresponding gradient.
param.L_f = eigs(A_val,1);
param.L_g = eigs(A_tr,1);

%% Objective functions and their gradients
% lambda=0.5;

% f: half the squared residual on the validation block (A1, b1).
fun_f  = @(x) sum_square(A1*x-b1)/2;
grad_f = @(x) A1'*(A1*x-b1);

% g: half the squared residual on the training block (A2, b2).
fun_g  = @(x) sum_square(A2*x-b2)/2;
grad_g = @(x) A2'*(A2*x-b2);
 
% Reference optimal values computed with CVX.
%
% Step 1: gstar = min fun_g(x)  subject to  ||x||_2 <= lambda.
cvx_begin quiet
variables xstar_g(n,1)
cvx_precision high
minimize fun_g(xstar_g)
subject to
    norm(xstar_g,2)<=lambda;
cvx_end
gstar = cvx_optval;
% gstar bounds the constraint of the next problem and is the baseline of
% every plotted g-gap, so a failed solve must not pass silently.
if ~contains(cvx_status,'Solved')
    warning('bilevel:cvxLowerLevel', ...
        'CVX solve for gstar ended with status "%s"; gstar may be unreliable.', cvx_status);
end

% Step 2: fstar = min fun_f(x)  subject to  ||x||_2 <= lambda  and
%         fun_g(x) <= gstar (x restricted to minimizers of step 1).
cvx_begin
variables xstar(n,1)
cvx_precision medium
cvx_solver sdpt3
minimize fun_f(xstar)
subject to
    norm(xstar,2)<=lambda;
    fun_g(xstar)<=gstar;
cvx_end
cvx_optval
fstar = cvx_optval;
% fstar is the baseline of every plotted f-gap; flag an unusable value.
if ~contains(cvx_status,'Solved')
    warning('bilevel:cvxUpperLevel', ...
        'CVX solve for fstar ended with status "%s"; fstar may be unreliable.', cvx_status);
end

% %% test1
% lambda = 0.1;
% 
% cvx_begin quiet
% variables xg(n,1)
% cvx_precision high
% minimize fun_g(xg)
% subject to
%     norm(xg,2)<=lambda;
% cvx_end
% gstar1 = cvx_optval;
% 
% %% test2
% cvx_begin quiet
% variables xf(n,1)
% cvx_precision medium
% minimize fun_f(xf)
% subject to
%       fun_g(xf)<=gstar1;
%       norm(xf,2)<=lambda;
% cvx_end
% fstar1 = cvx_optval;
%% AGM-BiO Algorithm
% Solver settings: tolerances, constraint radius, iteration cap and the
% reference lower-level optimum.
param.epsilonf = epsilon_f;
param.epsilong = epsilon_g;
param.lam      = lambda;
param.maxiter  = 5e4;
param.gstar    = gstar;

% Metric handle evaluated on the test block (A3, b3).
tsa_handle = @(x)TSA_LS(x,A3,b3);

disp('AGM-BiO starts');
% Variant that also requests two extra outputs:
% [f_vecAB,g_vecAB,time_vecAB,x_AB,tsa_AB, X, Z] = AGM_BiO(fun_f,grad_f,grad_g,fun_g,...
%     tsa_handle,param,x_init);
[f_vecAB,g_vecAB,time_vecAB,x_AB,tsa_AB] = AGM_BiO( ...
    fun_f,grad_f,grad_g,fun_g, ...
    tsa_handle,param,x_init);
disp('AGM-BiO Achieved!');

% The AGM-BiO time stamps are reused below: time_vec1(end) is the time
% budget given to the other solvers.
time_vec1 = time_vecAB;
% time_vec1 = time_vec1+time_init;

%% CG for the sub problem
% Solve the lower-level problem alone, starting from x_init, with half the
% lower-level tolerance used by the other sections. The result (last_iter)
% is the starting point of CG-BiO.
param.epsilong = epsilon_g/2;
% NOTE(review): the other sections set param.lam, this one sets param.lam1 --
% confirm which field CG_lowerlevel actually reads.
param.lam1=lambda;
param.maxiter=1e5;
tic;
[last_iter , f_hist] = CG_lowerlevel(fun_g,grad_g,x_init,param);
% Keep the measured wall-clock time of the warm start: it is subtracted from
% the CG-BiO time budget and added back onto the CG-BiO time axis, so it has
% to reflect the run that actually happened rather than a fixed constant.
time_init = toc;

%% CG-BiO algorithm
% Tolerances and constraint radius.
param.epsilonf = epsilon_f;
param.epsilong = epsilon_g;
param.lam      = lambda;
% Lower-level objective value at the warm-start point.
param.fun_g_x0 = fun_g(last_iter);
param.maxiter  = 1e6;
% Time budget: the last AGM-BiO time stamp minus the time already spent on
% the lower-level warm start.
cg_budget = time_vec1(end) - time_init;
param.maxtime = cg_budget;

% Metric handle evaluated on the test block (A3, b3).
tsa_handle = @(x)TSA_LS(x,A3,b3);

disp('CG-BiO starts');
[f_vecCB,g_vecCB,time_vecCB,x,tsa_CB] = CG_BiO( ...
    fun_f,grad_f,grad_g,fun_g, ...
    tsa_handle,param,last_iter);
disp('CG-BiO Achieved!');

% Shift the time axis so it includes the warm-start time.
time_vecCB = time_init + time_vecCB;


%% a-IRG Algorithm
% Time budget = last AGM-BiO time stamp; iteration cap 1e8.
param.maxtime = time_vec1(end);
param.maxiter = 1e8;

% Metric handle evaluated on the test block (A3, b3).
tsa_handle = @(x)TSA_LS(x,A3,b3);

disp('a-IRG Algorithm starts');
[f_vec2,g_vec2,time_vec2,xlast,tsa_AP] = Alg_Projection( ...
    fun_f,grad_f,grad_g,fun_g, ...
    tsa_handle,param,x_init);
disp('a-IRG Solution Achieved!');

% %% BiG-SAM Algorithm
% param.eta_g = 1/eigs(A2'*A2,1);
% param.eta_f = 2/eigs(A1'*A1,1);
% param.gamma = 10;
% disp('BiG-SAM Algorithm starts');
% [f_vec3,g_vec3,time_vec3,xlast3,tsa_SAM] = BigSAM(fun_f,grad_f,grad_g,fun_g,@(x)TSA_LS(x,A3,b3),param,x_init);
% disp('BiG-SAM Solution Achieved!');

%% Bi-SG Algorithm
% Disabled settings kept for reference:
% param.eta_g = 1/eigs(A2'*A2,1);
% param.eta_f = 2/eigs(A1'*A1,1);
% param.gamma = 10;

% Time budget = last AGM-BiO time stamp; iteration cap 1e8.
param.maxtime = time_vec1(end);
param.maxiter = 1e8;

% Metric handle evaluated on the test block (A3, b3).
tsa_handle = @(x)TSA_LS(x,A3,b3);

disp('Bi-SG Algorithm starts');
[f_vecBSG,g_vecBSG,time_vecBSG,xlastBSG,tsa_BSG] = BiSG( ...
    fun_f,grad_f,grad_g,fun_g, ...
    tsa_handle,param,x_init);
disp('Bi-SG Solution Achieved!');

% %% DBGD
% param.alpha = 1;
% param.beta = 1;
% param.stepsize = 5e-4;
% param.maxiter=1e8;
% param.maxtime = time_vec1(end);
% disp('DBGD Algorithm starts');
% [f_vec5,g_vec5,time_vec5,xlast5,tsa_DBGD] = DBGD(fun_f,grad_f,grad_g,fun_g,@(x)TSA_LS(x,A3,b3),param,x_init);
% disp('DBGD Solution Achieved!');

% %% MNG
% param.maxtime = time_vec1(end);
% param.maxiter=length(time_vec1);
% param.M = eigs(A2'*A2,1);
% 
% [f_vec4,g_vec4,time_vec4,xlast4,tsa_MNG] = MNG(A1,b1,fun_g,grad_g,@(x)TSA_LS(x,A3,b3),param,A1\b1);

%% SEA Algorithm
% Time budget = last AGM-BiO time stamp; iteration cap 1e6.
param.maxtime = time_vec1(end);
param.maxiter = 1e6;

% Metric handle evaluated on the test block (A3, b3).
tsa_handle = @(x)TSA_LS(x,A3,b3);

disp('SEA Algorithm starts');
% SEA receives x_init twice followed by a trailing 0.
[f_vecS,g_vecS,time_vecS,xlastS,tsa_S] = SEA( ...
    fun_f,grad_f,grad_g,fun_g, ...
    tsa_handle,param,x_init,x_init,0);
disp('SEA Solution Achieved!');

%% R-APM Algorithm
% Time budget = last AGM-BiO time stamp; iteration cap 1e6.
param.maxtime = time_vec1(end);
param.maxiter = 1e6;

% Metric handle evaluated on the test block (A3, b3).
tsa_handle = @(x)TSA_LS(x,A3,b3);

disp('R-APM Algorithm starts');
[f_vecR,g_vecR,time_vecR,xlastR,tsa_R] = R_APM( ...
    fun_f,grad_f,grad_g,fun_g, ...
    tsa_handle,param,x_init);
disp('R-APM Solution Achieved!');

%% PB-APG Algorithm
% Time budget = last AGM-BiO time stamp; iteration cap 1e6.
param.maxtime = time_vec1(end);
param.maxiter = 1e6;

% Metric handle evaluated on the test block (A3, b3).
tsa_handle = @(x)TSA_LS(x,A3,b3);

disp('PB-APG Algorithm starts');
[f_vecPB,g_vecPB,time_vecPB,xlastPB,tsa_PB] = PB_APG( ...
    fun_f,grad_f,grad_g,fun_g, ...
    tsa_handle,param,x_init);
disp('PB-APG Solution Achieved!');


%% lower-level gap
% Figure 1: |g(x_k) - gstar| against wall-clock time for every solver.
figure;
set(0,'defaulttextinterpreter','latex')
set(gcf,'DefaultLineLinewidth',5, ...
        'DefaultLineMarkerSize',16, ...
        'Position',[331,167,591,586])

% N_marker marker positions, evenly spaced in time over [0, time_vec1(end)].
param.maxtime = time_vec1(end);
N_marker = 10;
time_idx = linspace(0,param.maxtime,N_marker);

% Marker vectors of the series whose solver sections are disabled stay zero.
marker_idx3 = zeros(N_marker,1);
marker_idx4 = zeros(N_marker,1);
marker_idx5 = zeros(N_marker,1);

% For each time axis, find the sample closest to every marker time.
time_axes = {time_vec1, time_vec2, time_vecR, time_vecS, ...
             time_vecAB, time_vecBSG, time_vecCB, time_vecPB};
nearest = zeros(N_marker, numel(time_axes));
for s = 1:numel(time_axes)
    for j = 1:N_marker
        [~, nearest(j,s)] = min(abs(time_axes{s}-time_idx(j)));
    end
end
marker_idx1   = nearest(:,1);
marker_idx2   = nearest(:,2);
marker_idxR   = nearest(:,3);
marker_idxS   = nearest(:,4);
marker_idxAGM = nearest(:,5);
marker_idxBSG = nearest(:,6);
marker_idxCB  = nearest(:,7);
marker_idxPB  = nearest(:,8);

% One row per plotted series:
%   time axis, g history, line spec, legend label, marker indices, extra args.
% (MNG, BiG-SAM and DBGD are not plotted; their solver sections are disabled.)
gap_series = {
    time_vec2,   g_vec2,   's-', 'a-IRG',   marker_idx2,   {};
    time_vecCB,  g_vecCB,  'o-', 'CG-BiO',  marker_idxCB,  {"Color","#FF0000"};
    time_vecBSG, g_vecBSG, '+-', 'Bi-SG',   marker_idxBSG, {};
    time_vecS,   g_vecS,   '>-', 'SEA',     marker_idxS,   {};
    time_vecR,   g_vecR,   '^-', 'R-APM',   marker_idxR,   {};
    time_vecPB,  g_vecPB,  'v-', 'PB-APG',  marker_idxPB,  {"Color","#D95319"};
    time_vecAB,  g_vecAB,  '*-', 'AGM-BiO', marker_idxAGM, {"Color","#4DBEEE"};
};
for s = 1:size(gap_series,1)
    [t_s, g_s, spec_s, label_s, mk_s, extra_s] = gap_series{s,:};
    semilogy(t_s,abs(g_s-gstar),spec_s,'DisplayName',label_s,'MarkerIndices', mk_s,extra_s{:});
    hold on;   % issued after the first semilogy so the log scale is kept
end

ylabel('$|g(\beta_k)-g^*|$')
xlabel('time (s)')
set(gca,'FontSize',20);
set(gca,'YLim',[1e-8,10])
legend('Interpreter','latex','Location','southwest')
grid on;
grid minor
pbaspect([1 0.8 1])

% print('-depsc2','-r600','./figs/lower_subopt_time.eps')
%% upper-level gap
% Figure 2: |f(x_k) - fstar| against wall-clock time for every solver.
% Uses the marker index vectors already present in the workspace.
figure;
set(0,'defaulttextinterpreter','latex')
set(gcf,'DefaultLineLinewidth',5, ...
        'DefaultLineMarkerSize',16, ...
        'Position',[331,167,591,586])

% One row per plotted series:
%   time axis, f history, line spec, legend label, marker indices, extra args.
% (MNG, BiG-SAM and DBGD are not plotted; their solver sections are disabled.)
gap_series = {
    time_vec2,   f_vec2,   's-', 'a-IRG',   marker_idx2,   {};
    time_vecCB,  f_vecCB,  'o-', 'CG-BiO',  marker_idxCB,  {"Color","#FF0000"};
    time_vecBSG, f_vecBSG, '+-', 'Bi-SG',   marker_idxBSG, {};
    time_vecS,   f_vecS,   '>-', 'SEA',     marker_idxS,   {};
    time_vecR,   f_vecR,   '^-', 'R-APM',   marker_idxR,   {};
    time_vecPB,  f_vecPB,  'v-', 'PB-APG',  marker_idxPB,  {"Color","#D95319"};
    time_vecAB,  f_vecAB,  '*-', 'AGM-BiO', marker_idxAGM, {"Color","#4DBEEE"};
};
for s = 1:size(gap_series,1)
    [t_s, f_s, spec_s, label_s, mk_s, extra_s] = gap_series{s,:};
    semilogy(t_s,abs(f_s-fstar),spec_s,'DisplayName',label_s,'MarkerIndices', mk_s,extra_s{:});
    hold on;   % issued after the first semilogy so the log scale is kept
end

ylabel('$|f(\beta_k)-f^*|$')
xlabel('time (s)')
set(gca,'YLim',[1e-10,1e1])
set(gca,'FontSize',20);
% set(gca,'YLim',[0,1])
legend('Interpreter','latex','Location','northeast')
grid on;
grid minor
pbaspect([1 0.8 1])

% print('-depsc2','-r600','./figs/upper_subopt_time.eps')
% %% test error
% figure;
% set(0,'defaulttextinterpreter','latex')
% set(gcf,'DefaultLineLinewidth',5)
% set(gcf,'DefaultLineMarkerSize',16);
% set(gcf,'Position',[331,167,591,586])
% % 
% % semilogy(time_vec1,tsa_BCG,'o-','DisplayName','CG-BiO','MarkerIndices', marker_idx1);
% % hold on;
% % semilogy(time_vec4,tsa_MNG,'d-','DisplayName','MNG','MarkerIndices', marker_idx4);
% % hold on;
% % semilogy(time_vec3,tsa_SAM,'^-','DisplayName','BiG-SAM','MarkerIndices', marker_idx3);
% 
% semilogy(time_vec2,tsa_AP,'s-','DisplayName','a-IRG','MarkerIndices', marker_idx2);
% hold on;
% semilogy(time_vecCB,tsa_CB,'o-','DisplayName','CG-BiO','MarkerIndices', marker_idxCB);
% semilogy(time_vecBSG,tsa_BSG,'+-','DisplayName','Bi-SG','MarkerIndices', marker_idxBSG);
% %semilogy(time_vec5,tsa_DBGD,'>-','DisplayName','DBGD','MarkerIndices', marker_idx5);
% semilogy(time_vecS,tsa_S,'>-','DisplayName','SEA','MarkerIndices', marker_idxS);
% semilogy(time_vecR,tsa_R,'s-','DisplayName','R-APM','MarkerIndices', marker_idxR);
% semilogy(time_vecAB,tsa_AB,'*-','DisplayName','AGM-BiO','MarkerIndices', marker_idxAGM);
% 
% 
% ylabel('Test error')
% xlabel('time (s)')
% set(gca,'FontSize',24);
% % set(gca,'YLim',[0,1])
% legend('Interpreter','latex','Location','northeast')
% grid on;
% grid minor
% pbaspect([1 0.8 1])
% % print('-depsc2','-r600','./figs/test_error_time.eps')

%% lower-level gap (vs # of iterations)
% Figure 3: |g(x_k) - gstar| against the iteration counter, restricted to the
% first n_iter_plot iterations of every solver.
figure;
set(0,'defaulttextinterpreter','latex')
set(gcf,'DefaultLineLinewidth',5)
set(gcf,'DefaultLineMarkerSize',16);
set(gcf,'Position',[331,167,591,586])

n_iter_plot = 3e4;              % longest iteration prefix shown
x_axis = 1:n_iter_plot;
N_marker = 10;
time_idx = linspace(0,n_iter_plot,N_marker);

% Every series shares the same iteration axis, so one nearest-index search
% gives the marker positions for all of them.
[~, iter_markers] = min(abs(x_axis(:)-time_idx), [], 1);
iter_markers = iter_markers(:);
marker_idx1   = iter_markers;
marker_idx2   = iter_markers;
marker_idx3   = iter_markers;
marker_idx4   = iter_markers;
marker_idx5   = iter_markers;
marker_idxAGM = iter_markers;
marker_idxR   = iter_markers;
marker_idxS   = iter_markers;
marker_idxBSG = iter_markers;
marker_idxCB  = iter_markers;
marker_idxPB  = iter_markers;

% One row per plotted series:
%   g history, line spec, legend label, marker indices, extra args.
% (MNG, BiG-SAM and DBGD are not plotted; their solver sections are disabled.)
gap_series = {
    g_vec2,   's-', 'a-IRG',   marker_idx2,   {};
    g_vecCB,  'o-', 'CG-BiO',  marker_idxCB,  {"Color","#FF0000"};
    g_vecBSG, '+-', 'Bi-SG',   marker_idxBSG, {};
    g_vecS,   '>-', 'SEA',     marker_idxS,   {};
    g_vecR,   '^-', 'R-APM',   marker_idxR,   {};
    g_vecPB,  'v-', 'PB-APG',  marker_idxPB,  {"Color","#D95319"};
    g_vecAB,  '*-', 'AGM-BiO', marker_idxAGM, {"Color","#4DBEEE"};
};
for s = 1:size(gap_series,1)
    [g_s, spec_s, label_s, mk_s, extra_s] = gap_series{s,:};
    % A solver may stop (tolerance or time budget) before n_iter_plot
    % iterations; plot only what its history holds and keep the marker
    % positions inside that range instead of indexing past the end.
    n_s = min(n_iter_plot, numel(g_s));
    semilogy(x_axis(1:n_s),abs(g_s(1:n_s)-gstar),spec_s,'DisplayName',label_s, ...
        'MarkerIndices', mk_s(mk_s<=n_s),extra_s{:});
    hold on;   % issued after the first semilogy so the log scale is kept
end

ylabel('$|g(\beta_k)-g^*|$')
xlabel('number of iterations')
set(gca,'FontSize',20);
set(gca,'YLim',[1e-7,1e2])
legend('Interpreter','latex','Location','southwest')
grid on;
grid minor
pbaspect([1 0.8 1])

% print('-depsc2','-r600','./figs/lower_subopt_time.eps')
%% upper-level gap (vs # of iterations)
% Figure 4: |f(x_k) - fstar| against the iteration counter. Uses x_axis and
% the marker index vectors already present in the workspace.
figure;
set(0,'defaulttextinterpreter','latex')
set(gcf,'DefaultLineLinewidth',5)
set(gcf,'DefaultLineMarkerSize',16);
set(gcf,'Position',[331,167,591,586])

n_iter_plot = numel(x_axis);    % longest iteration prefix shown

% One row per plotted series:
%   f history, line spec, legend label, marker indices, extra args.
% (MNG, BiG-SAM and DBGD are not plotted; their solver sections are disabled.)
gap_series = {
    f_vec2,   's-', 'a-IRG',   marker_idx2,   {};
    f_vecCB,  'o-', 'CG-BiO',  marker_idxCB,  {"Color","#FF0000"};
    f_vecBSG, '+-', 'Bi-SG',   marker_idxBSG, {};
    f_vecS,   '>-', 'SEA',     marker_idxS,   {};
    f_vecR,   '^-', 'R-APM',   marker_idxR,   {};
    f_vecPB,  'v-', 'PB-APG',  marker_idxPB,  {"Color","#D95319"};
    f_vecAB,  '*-', 'AGM-BiO', marker_idxAGM, {"Color","#4DBEEE"};
};
for s = 1:size(gap_series,1)
    [f_s, spec_s, label_s, mk_s, extra_s] = gap_series{s,:};
    % A solver may stop (tolerance or time budget) before n_iter_plot
    % iterations; plot only what its history holds and keep the marker
    % positions inside that range instead of indexing past the end.
    n_s = min(n_iter_plot, numel(f_s));
    semilogy(x_axis(1:n_s),abs(f_s(1:n_s)-fstar),spec_s,'DisplayName',label_s, ...
        'MarkerIndices', mk_s(mk_s<=n_s),extra_s{:});
    hold on;   % issued after the first semilogy so the log scale is kept
end

ylabel('$|f(\beta_k)-f^*|$')
xlabel('number of iterations')
set(gca,'YLim',[1e-8,1e1])
set(gca,'FontSize',20);
% set(gca,'YLim',[0,1])
legend('Interpreter','latex','Location','northeast')
grid on;
grid minor
pbaspect([1 0.8 1])