% Accelerated quasi-Newton extragradient method with line search, using conjugate residual for
% solving the linear system
% g_func: gradient oracle
% L1: Lipschitz constant of gradient
% init_pt: initial point
% alpha1, alpha2, beta \in (0,1): line search parameters
% sigma_0: the initial trial stepsize
% N_iter: number of iterations
% epsilon: accuracy for stopping (when ||g(y_k)|| < epsilon or ||g(x_hat_k)|| < epsilon)
function [list_loss, list_iter, list_eta, list_steps, best_iter] = AQNE_CR(loss, g_func, L1, ...
    alpha1, alpha2, beta, sigma_0, init_pt, B_0, N_iter, epsilon)
% AQNE_CR  Accelerated quasi-Newton extragradient loop with an online-learned
% Hessian approximation; the inner linear systems are solved by line_search2.
%
% Inputs:
%   loss      - handle returning the objective value at a point
%   g_func    - gradient oracle (handle)
%   L1        - Lipschitz constant of the gradient
%   alpha1, alpha2, beta - line-search parameters forwarded to line_search2
%   sigma_0   - initial trial step size
%   init_pt   - initial point
%   B_0       - initial Hessian approximation (d-by-d)
%   N_iter    - maximum number of iterations
%   epsilon   - stop once norm(g) or norm(g_hat) falls below this value
%
% Outputs:
%   list_loss  - loss(x) after each iteration; first entry is loss(init_pt)
%   list_iter  - iterates x stored as columns; first column is init_pt
%   list_eta   - step size carried into the next iteration, one per iteration
%   list_steps - [0; steps per iteration], where each entry is the count
%                reported by line_search2 plus one
%   best_iter  - recorded iterate with the smallest loss

d = length(init_pt);                 % problem dimension
list_loss  = zeros(N_iter+1,1);      % loss history
list_iter  = zeros(d,N_iter+1);      % iterate history
list_eta   = zeros(N_iter,1);        % step-size history
list_steps = zeros(N_iter,1);        % line-search step counts

list_iter(:,1) = init_pt;
list_loss(1) = loss(init_pt);

x = init_pt;
z = init_pt;
best_iter = x;
best_loss = list_loss(1);

% B is the Hessian approximation; W is its affine rescaling
% W = (2/L1)*(B - (L1/2)*I), which is the variable updated online below.
B = B_0;
W = 2/(L1)*(B-(L1)/2*eye(d));
gamma = 1;
S = 0;          % only read once gamma > 1, i.e. after ExtEvec_eig has set it
rho = 1/2;      % step size of the online update on W

A = 0;
eta = sigma_0;
n_done = 0;     % number of completed iterations (stays 0 when N_iter == 0)

for i_iter = 1:N_iter
    a = (eta+sqrt(eta^2+4*eta*A))/2;
    y = A/(A+a)*x + a/(A+a)*z;
    g = g_func(y);
    [x_hat,eta_hat,g_hat,dx,dg,steps,flag_b] = line_search2(y,g,B,g_func,alpha1,alpha2,beta,eta);
    steps = steps+1;

    if ~flag_b
        % Initial trial step accepted: full update, enlarge next trial step.
        if loss(x_hat) < loss(x)
            x = x_hat;
        end
        z = z-a*g_hat;
        A = A+a;
        eta = eta_hat/beta;
    else
        % Line search backtracked: damped update by the step-size ratio.
        gamma_e = eta_hat/eta;
        x_damped = (1-gamma_e)*A/(A+gamma_e*a)*x + gamma_e*(A+a)/(A+gamma_e*a)*x_hat;
        if loss(x_damped) < loss(x)
            x = x_damped;
        end
        z = z-gamma_e*a*g_hat;
        A = A+gamma_e*a;
        eta = eta_hat;

        % Hessian approximation update from the rejected pair (dx, dg).
        w = dg-B*dx;
        grad_B = -1/(2*(dx'*dx))*(dx*w'+w*dx');

        % Frobenius inner product; (:) replaces the non-built-in vec().
        % NOTE(review): this pairs grad_B with B rather than with the
        % normalised W -- confirm against the intended update rule.
        GB = 2/(L1)*grad_B(:)'*B(:);
        if gamma<=1 || GB >= 0
            G_tilde = 2/(L1)*grad_B;
        else
            G_tilde = 2/(L1)*grad_B-GB*S;
        end
        W = W-rho*G_tilde;

        % Projection onto the Frobenius ball of radius sqrt(d).
        W_norm = norm(W,'fro');
        if W_norm > sqrt(d)
            W = sqrt(d)*W/W_norm;
        end

        % Rescale W by gamma when gamma exceeds 1, then map back to B.
        [gamma,S] = ExtEvec_eig(W);
        if gamma<=1
            W_hat = W;
        else
            W_hat = W/gamma;
        end
        B = (L1)/2*W_hat+(L1)/2*eye(d);
    end

    loss_x = loss(x);
    list_loss(i_iter+1) = loss_x;
    if loss_x < best_loss
        best_iter = x;
        best_loss = loss_x;
    end
    list_iter(:,i_iter+1) = x;
    list_eta(i_iter) = eta;
    list_steps(i_iter) = steps;
    n_done = i_iter;
    if norm(g)<epsilon || norm(g_hat)<epsilon
        break
    end
end

% Trim the histories to the iterations actually performed.
list_loss = list_loss(1:n_done+1);
list_iter = list_iter(:,1:n_done+1);
list_eta = list_eta(1:n_done);
list_steps = [0;list_steps(1:n_done)];
end

% function [x_hat,eta,dis,res,steps] = line_search(x,g,dir,g_func,alpha,beta,const)
%     eta_lo = 0;
%     eta_hi = 1;
%     steps = 0;
% 
% 
%     flag = false;
%     while ~flag
%             eta_mid = (eta_lo+eta_hi)/2;
%             x_hat = x-eta_mid*dir;
%             g_new = g_func(x_hat);
%             dis = eta_mid*norm(dir);
%             res = g_new-(1-eta_mid)*g;
%             steps = steps+1;
%         if norm(res) > alpha*const*(1-eta_mid)*norm(dir)
%             eta_hi = eta_mid;
%         elseif norm(res) < beta*const*(1-eta_mid)*norm(dir)
%             eta_lo = eta_mid;
%         else
%             flag = true;
%         end
%     end
%     eta = eta_mid/(1-eta_mid)/const;
% end

function [x_hat_lo,eta_lo,g_lo,dx_hi,dg_hi,steps,flag_b] = line_search(x,g,B,g_func,alpha,beta,sigma)
% LINE_SEARCH  Step-size search using an exact linear solve for each trial.
%
% For a step size t the trial point is x - t*((I + t*B)\g). A trial is
% accepted when norm(x_hat - x + t*g_func(x_hat)) <= alpha*norm(x_hat - x).
%
% If the first trial (t = sigma) is accepted it is returned at once with
% flag_b = false. Otherwise the step is shrunk via t <- beta*t^2/sigma until
% accepted, and the bracket [eta_lo, eta_hi] is then refined by geometric
% bisection until eta_hi/eta_lo <= 1/beta + eps.
%
% Outputs:
%   x_hat_lo, eta_lo, g_lo - last accepted point, its step size and gradient
%   dx_hi, dg_hi           - displacement and gradient change (relative to g)
%                            of the last rejected trial; 0 if none rejected
%   steps                  - number of trial points evaluated
%   flag_b                 - true iff the first trial was rejected
    n = length(x);
    trial_point = @(t) x - t*((eye(n)+t*B)\g);

    % First trial at the suggested step size.
    t_cur = sigma;
    x_try = trial_point(t_cur);
    g_try = g_func(x_try);
    move  = x_try-x;
    resid = move+t_cur*g_try;
    steps = 1;

    dx_hi = 0;
    dg_hi = 0;

    flag_b = norm(resid) > alpha*norm(move);
    if ~flag_b
        % Accepted immediately: nothing to bracket.
        eta_lo = t_cur;
        x_hat_lo = x_try;
        g_lo = g_try;
        return
    end

    % Backtracking: remember the rejected trial, shrink, and retry.
    while true
        eta_hi = t_cur;
        dx_hi = move;
        dg_hi = g_try-g;

        t_cur = beta*t_cur^2/sigma;
        x_try = trial_point(t_cur);
        g_try = g_func(x_try);
        move  = x_try-x;
        resid = move+t_cur*g_try;
        steps = steps+1;

        if norm(resid) <= alpha*norm(move)
            break
        end
    end
    eta_lo = t_cur;
    x_hat_lo = x_try;
    g_lo = g_try;

    % Geometric bisection between the accepted and the rejected step size.
    while eta_hi/eta_lo > 1/beta+eps
        t_cur = sqrt(eta_lo*eta_hi);
        x_try = trial_point(t_cur);
        g_try = g_func(x_try);
        move  = x_try-x;
        resid = move+t_cur*g_try;
        steps = steps+1;

        if norm(resid) > alpha*norm(move)
            % Rejected: tighten the upper end.
            eta_hi = t_cur;
            dx_hi = move;
            dg_hi = g_try-g;
        else
            % Accepted: tighten the lower end.
            eta_lo = t_cur;
            x_hat_lo = x_try;
            g_lo = g_try;
        end
    end
end

function [x_hat_lo,eta_lo,g_lo,dx_hi,dg_hi,steps,flag_b] = line_search2(x,g,B,g_func,alpha1,alpha2,beta,sigma)
% LINE_SEARCH2  Backtracking step-size search with an inexact linear solve.
%
% For a step size t the displacement dx is obtained from LinearSolver applied
% to (I + t*B)*dx = -t*g with tolerance alpha1. The trial x_hat = x + dx is
% accepted when
%     norm(x_hat - x + t*g_func(x_hat)) <= (alpha1+alpha2)*norm(dx).
% Starting from t = sigma, the step is multiplied by beta until accepted;
% each re-solve is warm-started from beta times the previous displacement.
%
% Outputs:
%   x_hat_lo, eta_lo, g_lo - accepted point, its step size and gradient
%   dx_hi, dg_hi           - displacement and gradient change (relative to g)
%                            of the last rejected trial; 0 if none rejected
%   steps                  - number of trial points evaluated
%   flag_b                 - true iff the first trial was rejected
    n = length(x);
    tol = alpha1+alpha2;

    % First trial at the suggested step size, cold-started from zero.
    t_cur = sigma;
    dx = LinearSolver(eye(n)+t_cur*B,-t_cur*g,alpha1,zeros(n,1));
    x_try = x+dx;
    g_try = g_func(x_try);
    resid = x_try-x+t_cur*g_try;
    steps = 1;

    dx_hi = 0;
    dg_hi = 0;

    flag_b = norm(resid) > tol*norm(dx);
    if flag_b
        % Backtracking: record the rejected trial, shrink the step, re-solve.
        while true
            dx_hi = dx;
            dg_hi = g_try-g;

            t_cur = beta*t_cur;
            dx = LinearSolver(eye(n)+t_cur*B,-t_cur*g,alpha1,dx*beta);
            x_try = x+dx;
            g_try = g_func(x_try);
            resid = x_try-x+t_cur*g_try;
            steps = steps+1;

            if norm(resid) <= tol*norm(dx)
                break
            end
        end
    end

    % Whatever trial is current at this point is the accepted one.
    eta_lo = t_cur;
    x_hat_lo = x_try;
    g_lo = g_try;
end

function B = quasi_Newton_update(B,dx,dg,L,mu)
% QUASI_NEWTON_UPDATE  Symmetric secant-style correction of B followed by
% clipping of its extreme eigenvalues to the interval [mu, L].
%
% Inputs:
%   B      - current Hessian approximation
%   dx, dg - displacement and corresponding gradient change
%   L, mu  - upper and lower bounds enforced on the extreme eigenvalues
%
% Output:
%   B      - updated approximation
z = dg-B*dx;      % secant residual
stp = 1/2;        % step size of the correction
B_new = B+stp/(2*(dx'*dx))*(dx*z'+z*dx');
B_new = (B_new+B_new')/2;     % enforce exact symmetry (the transpose was missing)

% Only the two extreme eigenpairs are needed to clip the spectrum ends.
[V,D] = eigs(B_new,2,'bothendsreal');
if ~any(isnan(diag(D)))
    lambdan = real(D(1,1));
    lambda1 = real(D(2,2));
    vn = V(:,1);
    v1 = V(:,2);
    if lambdan>lambda1
        % eigs does not guarantee the order: make lambda1 the larger one.
        [lambda1,lambdan] = deal(lambdan,lambda1);
        [v1,vn] = deal(vn,v1);
    end
    B = B_new-max(0,lambda1-L)*(v1*v1')+max(0,mu-lambdan)*(vn*vn');
else
    % In case the eigenvalues did not converge: full decomposition.
    % Clamp only the eigenvalues; clamping the whole matrix D would turn its
    % off-diagonal zeros into mu whenever mu > 0.
    [V,D] = eig(B_new);
    D = diag(min(max(diag(D),mu),L));
    B = V*D*V';
end
end

function B = BFGS_update(B,s,y)
% BFGS_UPDATE  Standard BFGS update of the Hessian approximation B:
%     B + y*y'/(y'*s) - (B*s)*(B*s)'/(s'*B*s)
%
% Inputs:
%   B - current Hessian approximation
%   s - displacement
%   y - corresponding gradient change
%
% The update is skipped (B is returned unchanged) unless both y'*s and
% s'*B*s are strictly positive, so a zero step or failed curvature
% condition cannot introduce NaN/Inf entries.
Bs  = B*s;       % reused in the numerator and the denominator
ys  = y'*s;
sBs = s'*Bs;
if ~(ys > 0) || ~(sBs > 0)
    return
end
B = B + (y*y')/ys - (Bs*Bs')/sBs;
end

function [s,count_MV] = LinearSolver(A,b,alpha,s0)
% LINEARSOLVER  Conjugate-residual iteration for A*s = b, started from s0.
%
% Inputs:
%   A     - system matrix
%   b     - right-hand side
%   alpha - relative tolerance: stop once norm(b - A*s) <= alpha*norm(s)
%   s0    - initial guess
%
% Outputs:
%   s        - approximate solution
%   count_MV - counter incremented for each A*r_new product inside the loop
%              (starts at 1)
%
% The iteration also stops, returning the current s, when the step length or
% the direction-update coefficient is not finite (e.g. r'*A*r == 0 or
% A*p == 0). Without this guard the loop would continue forever on NaN values.
s = s0;
r = b-A*s;
p = r;
Ar = A*r;
Ap = Ar;
count_MV = 1;
rAr = r'*Ar;
while true
    if norm(r) <= alpha*norm(s)
        return;
    end
    alpha_k = rAr/(Ap'*Ap);
    if ~isfinite(alpha_k)
        return;     % breakdown: no usable step along p
    end
    s = s+alpha_k*p;
    r = r-alpha_k*Ap;
    Ar = A*r;
    count_MV = count_MV+1;
    rAr_new = r'*Ar;
    beta = rAr_new/rAr;
    if ~isfinite(beta)
        return;     % breakdown: the next direction would be NaN/Inf
    end
    p = r+beta*p;
    Ap = Ar+beta*Ap;
    rAr = rAr_new;
end
end

function [gamma,S] = ExtEvec(W,delta)
% EXTEVEC  Approximate extreme-eigenvalue oracle for the symmetric part of W,
% based on randomized Lanczos.
% adapted from https://github.com/alpyurtsever/SketchyCGAL/blob/master/solver/CGAL.m
%
% Inputs:
%   W     - square matrix (symmetrized internally)
%   delta - accuracy parameter of the second, finer Lanczos pass
%
% Outputs:
%   gamma - scaling estimate derived from lam_max, the larger of the top
%           Ritz value and the negated bottom Ritz value
%   S     - 0, or a signed rank-one matrix built from the Ritz vector of the
%           dominant extreme eigenvalue
%
% A coarse pass decides the clear cases (lam_max <= 0.5 or lam_max >= 2);
% only otherwise is the finer pass, whose length depends on delta, run.
W = (W+W')/2;
d = size(W,1);

% ---- Coarse pass --------------------------------------------------------
% The number of Lanczos steps is capped at d, as in the fine pass: a Krylov
% space cannot have more than d independent directions.
N = min(ceil(log(44*d)),d);
[lambdan,lambda1,vn,v1] = lanczos_extreme_pairs(W,N);

lam_max = max(lambda1,-lambdan);
if lam_max <= 0.5
    gamma = 2*lam_max;
    S = 0;
    return
elseif lam_max >= 2
    gamma = 2*lam_max;
    if lambda1 > -lambdan
        S = 3*(v1*v1');
    else
        S = -3*(vn*vn');
    end
    return
end

% ---- Fine pass ----------------------------------------------------------
N = min(ceil(1/(4*sqrt(2*delta))*log(44*d)),d);
[lambdan,lambda1,vn,v1] = lanczos_extreme_pairs(W,N);

lam_max = max(lambda1,-lambdan);
gamma = lam_max+delta;
if lam_max < 1-delta
    S = 0;
elseif lambda1 > -lambdan
    S = v1*v1';
else
    S = -vn*vn';
end
end

function [lambdan,lambda1,vn,v1] = lanczos_extreme_pairs(W,N)
% Runs at most N Lanczos steps on W from a random unit start vector and
% returns the smallest/largest Ritz values with their Ritz vectors.
d = size(W,1);
Q = zeros(d,N+1);                   % Lanczos vectors
aleph = zeros(N,1);                 % diagonal Lanczos coefficients
beth = zeros(N,1);                  % off-diagonal Lanczos coefficients
Q(:,1) = randn(d,1);                % first Lanczos vector is random
Q(:,1) = Q(:,1)/norm(Q(:,1));

for i = 1:N
    Q(:,i+1) = W*Q(:,i);                        % apply W to the previous vector
    aleph(i) = real(Q(:,i)'*Q(:,i+1));          % diagonal coefficient

    if i == 1                                   % three-term recurrence
        Q(:,i+1) = Q(:,i+1)-aleph(i)*Q(:,i);
    else
        Q(:,i+1) = Q(:,i+1)-aleph(i)*Q(:,i)-beth(i-1)*Q(:,i-1);
    end

    beth(i) = norm(Q(:,i+1));                   % off-diagonal coefficient

    % Stop when the new direction is numerically zero.
    if abs(beth(i)) < sqrt(d)*eps, break; end

    Q(:,i+1) = Q(:,i+1)/beth(i);                % normalize
end

% i holds the number of completed iterations; build the tridiagonal matrix.
T = diag(aleph(1:i),0)+diag(beth(1:(i-1)),+1)+diag(beth(1:(i-1)),-1);

[V,D] = eig(0.5*(T+T'));
ritz = real(diag(D));
[lambdan,min_index] = min(ritz);
[lambda1,max_index] = max(ritz);
vn = Q(:,1:i)*V(:,min_index);
v1 = Q(:,1:i)*V(:,max_index);
end

function [gamma,S] = ExtEvec_eig(W)
% EXTEVEC_EIG  Exact extreme-eigenvalue oracle using a full eigendecomposition.
%
% Input:
%   W     - square matrix (symmetrized internally)
%
% Outputs:
%   gamma - max(lambda_max, -lambda_min) of the symmetrized W
%   S     - 0 when gamma <= 1; otherwise v1*v1' if the largest eigenvalue is
%           the dominant one, or -vn*vn' if the most negative one is, where
%           v1 / vn are the corresponding eigenvectors
W = (W+W')/2;

[V,D] = eig(W);
eigenvalues = real(diag(D));
[lambdan,min_index] = min(eigenvalues);
[lambda1,max_index] = max(eigenvalues);

gamma = max(lambda1,-lambdan);
if gamma <= 1
    S = 0;
elseif lambda1 >= -lambdan
    % Compare the eigenvalues, not the eigenvectors: an elementwise vector
    % comparison is true only if every entry satisfies it and depends on the
    % arbitrary sign of the computed eigenvectors.
    v1 = V(:,max_index);
    S = v1*v1';
else
    vn = V(:,min_index);
    S = -vn*vn';
end
end