% ---- Environment --------------------------------------------------------
clear('variables')
% close all

% Fixed seed so the synthetic problem, and therefore every figure, is
% reproducible from run to run.
seed = 2022;
rng(seed);

% Project helpers used below (logistic, AQNE_CR, fista, BFGS) are not
% defined in this script; presumably they live in test_func/.
addpath('test_func/');

% ---- Synthetic binary-classification data -------------------------------
% Labels come from a random linear separator applied to the clean Gaussian
% features.  Afterwards the features are corrupted with Gaussian noise and
% shifted by +1 (a shared offset that makes the columns correlated), and an
% all-ones intercept column is appended, giving n_samples x (p_features+1).
n_samples  = 2e3;
p_features = 149;
X = randn(n_samples, p_features);
w = randn(p_features, 1);
y = sign(X*w);                                   % +/-1 labels (0 only on an exact tie)
X = X + 0.8*randn(n_samples, p_features) + 1;    % noise first, then the constant shift
X(:, end+1) = 1;                                 % intercept column

% mu = 5e-3; % strong convexity parameter
mu = 0;                                          % strong-convexity parameter switched off

% Objective oracles.  How logistic uses mu is defined in the project class
% (presumably on test_func/) -- not visible here.
obj  = logistic(X, y, mu);
loss = @(w) obj.loss(w);
grad = @(w) obj.grad(w);
hess = @(w) obj.hessian(w);



% ---- Shared solver settings ----------------------------------------------
N_iter = 1000;                          % iteration budget (the FISTA call scales it by 1.5)

% Problem dimension: derived from the design matrix itself so it stays in
% sync with however X is built (features + intercept column), instead of
% being hard-coded as p_features+1.
n = size(X, 2);

% Smoothness estimate for the logistic loss: spectral norm of
% X'*X/(4*n_samples), plus mu.  NOTE(review): this bound assumes obj.loss
% averages over samples -- confirm against the logistic class.
L1 = norm(X'*X/n_samples/4, 2) + mu;
M = 4*L1;                               % NOTE(review): not used anywhere in this script
epsilon = 5e-9;                         % stopping tolerance passed to AQNE_CR and fista

init = 5*ones(n, 1);                    % common starting point for all methods (all entries 5)


%% Numerical methods

% A-QPNE (run via AQNE_CR)
% B0 = hess(init);
B0 = zeros(n);                          % all-zero initial Hessian approximation
[list_loss_aqne,list_iter_aqne, list_eta_aqne, list_steps_aqne, best_iter] = ...
    AQNE_CR(loss, grad, L1, 1,0.5,0.5, 1, init, B0, N_iter, epsilon);
[~,N_iter_aqne] = size(list_iter_aqne);

% Seed the running best (var_opt, pt_opt).  The FISTA and BFGS sections only
% assign pt_opt when they improve on var_opt, so without this pt_opt stayed
% undefined whenever A-QPNE reached the lowest loss.
[var_opt, I_opt] = min(list_loss_aqne);
% NOTE(review): assumes list_loss_aqne(k) is the loss at list_iter_aqne(:,k),
% the same alignment the FISTA/BFGS sections assume -- confirm against AQNE_CR.
if I_opt <= size(list_iter_aqne, 2)
    pt_opt = list_iter_aqne(:, I_opt);
else
    pt_opt = best_iter;                 % presumably AQNE_CR's best iterate -- TODO confirm
end


% 
% FISTA (shown as "NAG w/ LS" in the plots), given 1.5x the iteration budget.
[list_loss_fista, list_iter_fista, list_steps_fista] = ...
    fista(loss, grad, L1, 0.5, init, 1.5*N_iter, epsilon);
N_iter_fista = size(list_iter_fista, 2);
% fista already returns its loss history, so the losses are not recomputed.

% Update the running best if FISTA went lower.
[var_opt_new, I_new] = min(list_loss_fista);
if var_opt > var_opt_new
    var_opt = var_opt_new;
    pt_opt  = list_iter_fista(:, I_new);
end
%%




% BFGS with line search, started from the inverse-Hessian guess I/L1.
H0 = eye(n)/L1;
% H0 = inv(hess(init));

% BFGS runs with a larger tolerance than AQNE/FISTA (5e-9); note that this
% overwrites the shared epsilon for everything after this point.
epsilon = 1e-8;

[list_iter_bfgs, list_steps_bfgs] = BFGS(loss, grad, init, H0, N_iter, epsilon);
N_iter_bfgs = size(list_iter_bfgs, 2);

% BFGS only returns iterates (one per column), so evaluate the loss on each.
list_loss_bfgs = arrayfun(@(k) loss(list_iter_bfgs(:, k)), (1:N_iter_bfgs).');

% Update the running best if BFGS went lower.
[var_opt_new, I_new] = min(list_loss_bfgs);
if var_opt > var_opt_new
    var_opt = var_opt_new;
    pt_opt  = list_iter_bfgs(:, I_new);
end

%% Plots

% The function values: suboptimality f(x_k) - f* against cumulative gradient
% queries.  f* is approximated by var_opt, the lowest loss any method reached,
% so the best point of the winning method is exactly 0 and is not drawn on
% the log axis.
mark_skip = 50;
figure;
set(0,'defaulttextinterpreter','latex')
set(gcf,'DefaultLineLinewidth',3)

% MarkerIndices index into the plotted series, so bound them by the length
% of each loss history rather than by the number of stored iterates.
semilogy(cumsum(list_steps_fista), list_loss_fista-var_opt, '^-', 'DisplayName','NAG w/ LS', ...
    'MarkerIndices', 1:mark_skip:numel(list_loss_fista))
hold on;
semilogy(cumsum(list_steps_aqne), list_loss_aqne-var_opt, '^-', 'DisplayName','A-QPNE', ...
    'MarkerIndices', 1:mark_skip:numel(list_loss_aqne))
semilogy(cumsum(list_steps_bfgs), list_loss_bfgs-var_opt, '^-', 'DisplayName','BFGS w/ LS', ...
    'MarkerIndices', 1:mark_skip:numel(list_loss_bfgs))

ylabel('$f(x_k)-f(x^*)$','Interpreter','latex')
xlabel('Gradient queries')
set(gca,'FontSize',18);
legend('Interpreter','latex','Location','northeast')
grid on;

% print does not create missing folders; make sure the output folder exists.
if ~exist('figs', 'dir')
    mkdir('figs');
end
print('-depsc2','-r600','figs/subopt_gradient.eps')

% per iteration: suboptimality f(x_k) - f* against the iteration index, with
% f* approximated by var_opt (lowest loss reached by any method).
mark_skip = 50;
figure;
set(0,'defaulttextinterpreter','latex')
set(gcf,'DefaultLineLinewidth',3)

% x-axis is 1:numel(series), so marker indices are bounded by that length
% (not by the number of stored iterates, which may differ).
semilogy(list_loss_fista-var_opt, '^-', 'DisplayName','NAG w/ LS', ...
    'MarkerIndices', 1:mark_skip:numel(list_loss_fista))
hold on;
semilogy(list_loss_aqne-var_opt, '^-', 'DisplayName','A-QPNE', ...
    'MarkerIndices', 1:mark_skip:numel(list_loss_aqne))
semilogy(list_loss_bfgs-var_opt, '^-', 'DisplayName','BFGS w/ LS', ...
    'MarkerIndices', 1:mark_skip:numel(list_loss_bfgs))

ylabel('$f(x_k)-f(x^*)$','Interpreter','latex')
xlabel('Iteration')
set(gca,'FontSize',18);
legend('Interpreter','latex','Location','northeast')
grid on;

% print does not create missing folders; make sure the output folder exists.
if ~exist('figs', 'dir')
    mkdir('figs');
end
print('-depsc2','-r600','figs/subopt_iteration.eps')





%% Plot the histogram of number of steps per iteration
% Histogram of gradient queries per A-QPNE iteration; bar heights are
% fractions of iterations ('probability' normalization).  The first entry of
% list_steps_aqne is excluded -- presumably it counts the start-up
% evaluation rather than a regular iteration (TODO confirm in AQNE_CR).
figure;
set(gcf,'DefaultLineLinewidth',3)
histogram(list_steps_aqne(2:end),'Normalization','probability')
xlabel('Gradient queries')
set(gca,'FontSize',28);
set(gcf,'Position',[410,165,653,491])

% print does not create missing folders; make sure the output folder exists.
if ~exist('figs', 'dir')
    mkdir('figs');
end
print('-depsc2','-r600','figs/hist.eps')