function varargout = A2GD_weaklyconvex(x0, ~, data, max_iters, tol)
% A2GD_WEAKLYCONVEX  Accelerated adaptive gradient descent (A2GD) for
% weakly convex objectives.
%
%   [arf, gls, x] = A2GD_weaklyconvex(x0, ~, data, max_iters, tol)
%
% Inputs
%   x0        - starting point (vector or matrix; Frobenius norms are used).
%   ~         - unused argument (ignored).
%   data      - struct with fields
%                 f      : objective handle, f(x) -> scalar
%                 gradf  : gradient handle, gradf(x) -> same size as x
%                 L      : initial curvature estimate (first step is 1/L;
%                          presumably a gradient Lipschitz bound - confirm)
%                 f_true : reference value used to normalize f
%   max_iters - iteration cap; the accelerated phase also stops once the
%               number of gradient evaluations reaches max_iters.
%   tol       - relative tolerance: stop when ||g||_F < tol*||gradf(x0)||_F.
%
% Outputs (all optional)
%   arf - (f(x)-f_true)/(f0-f_true) of the current accepted iterate,
%         recorded after every gradient evaluation.
%   gls - 1 if that gradient evaluation was a line-search retry (an earlier
%         trial in the same iteration was rejected), 0 otherwise.
%   x   - final accepted iterate (the point whose gradient passed or last
%         failed the stopping test).

%% Initialization
f = data.f; grad = data.gradf; L = data.L; f_true = data.f_true;
arf = zeros(3*max_iters,1);      % normalized function value per gradient eval
gls = zeros(3*max_iters,1);      % 1 if that gradient eval was a line-search retry
arf_iter = zeros(max_iters+2,1); % normalized function value per iteration
x = x0; f0 = f(x); pb = 0; elapse = 0;
g = grad(x); g0 = g; arf(1) = 1; threshold = norm(g0,"fro") * tol;
arf_iter(1) = 1;
epsilon = L/10^6; mk = 5; % lower bound for mu_k, and its halving period
converged = false;

%% Warmup: adaptive (non-accelerated) gradient steps to estimate Lk and muk
warmup = 10;
theta = 1/3; Lk = L; muk = L/10;
alpha = 1/L; alpha_old = alpha;
x_old = x; g_old = g;
x = x - alpha*g; g = grad(x);
arf(2) = (f(x)-f_true)/(f0-f_true); arf_iter(2) = arf(2);
x_new = x;
geval = 2; % number of grad evaluations

for iter = 1:warmup
    tic;
    % Secant estimate of local curvature from the last two gradients.
    Lk = norm(g-g_old,"fro")/norm(x-x_old,"fro"); muk = min(muk, Lk);
    alpha = min(sqrt(2/3+theta)*alpha_old, alpha_old/sqrt(max(0,2*alpha_old^2*Lk^2-1)));
    x_new = x - alpha*g; theta = alpha/alpha_old; x_old = x; x = x_new;
    g_old = g; g = grad(x); geval = geval + 1;
    arf(geval) = (f(x)-f_true)/(f0-f_true); arf_iter(iter+2) = arf(geval);
    alpha_old = alpha;
    elapse = elapse + toc;
    % Same Frobenius-norm stopping test as the accelerated phase.
    if norm(g,"fro") < threshold
        fprintf('Converged after %d iterations.\n', iter);
        converged = true;
        break;
    end
end

%% Accelerated Adaptive Gradient Descent
% R is a heuristic radius estimate built from the warmup curvature bound.
fx = f(x); y = x; lambda = muk; R = 100*norm(g,"fro")/muk; inner_iter = 0;
last_iter = max_iters;
if converged
    last_iter = warmup; % already stationary: skip the accelerated phase
else
    fprintf("Estimated Radius %f. Actual Radius %f.\n", R, norm(x,"fro"));
end
for iter = warmup+1:last_iter
    tic;
    ls_flag = 0;
    while true
        alpha = sqrt(lambda/Lk);
        x_new = alpha/(1+alpha)*y + 1/(1+alpha)*x - 1/(Lk*(1+alpha))*g;
        g_new = grad(x_new); geval = geval + 1; gls(geval) = ls_flag;
        y_new = y/(1+alpha) + alpha*x_new/(1+alpha)-(alpha/(lambda*(1+alpha)))*g_new;
        b11 = 0.5*norm(g-g_new,'fro')^2/Lk; b12 = 0.5*sum(dot(g_new-g,x_new-x));
        b1 = b11 - b12;
        gap = R^2-(1+alpha)*norm(x_new-y_new,"fro")^2;
        b2 = 0.5*(lambda^(3/2)/sqrt(Lk))*gap - 0.5*norm(g,"fro")^2/Lk;
        pb_temp = (pb + b1 + b2)/(1+alpha);
        if (pb_temp > 0) && (lambda > epsilon)
            % Trial rejected: enlarge Lk and/or shrink lambda, then retry.
            ls_flag = 1;
            arf(geval) = arf(geval-1);
            if b1 > 0
                v1 = b12/b11;
                if v1 > 0
                    Lk = min(3 / v1, 100) * Lk;
                else
                    % Non-positive curvature along the step: 3/v1 would be
                    % negative (making Lk < 0); use the cap, the v1->0+ limit.
                    Lk = 100 * Lk;
                end
            end
            if b2 > 0
                lambda = max(epsilon, min(lambda, lambda_bound(g, gap, Lk)));
            end
        else
            % Trial accepted: update pb and adapt Lk within [0.1, 10]x.
            % (v1 may be NaN when b11 == 0; max() then ignores it -> 0.1.)
            pb = pb_temp;
            v1 = b12/b11;
            Lk = min(max(1/v1, 0.1), 10) * Lk;
            if b2 > 0
                lambda = max(epsilon, min(lambda, lambda_bound(g, gap, Lk)));
            end
            break;
        end
    end

    if lambda > epsilon
        fprintf('pb works. ')
    else
        fprintf('epsilon works. ')
    end
    % Halve the lower bound epsilon every mk inner iterations, or earlier
    % once the gradient has shrunk enough relative to g0.
    sq_ratio = norm(g,"fro")^2/norm(g0,"fro")^2;
    if (inner_iter < mk) && (sq_ratio > (R^2+1)*epsilon/2)
        inner_iter = inner_iter + 1;
    else
        inner_iter = 0; epsilon = epsilon/2; mk = floor(sqrt(2)*mk)+1;
    end

    % Accept x_new only if it decreases f; y is always advanced.
    f_new = f(x_new);
    if fx - f_new > 0
        x = x_new; y = y_new; g = g_new; fx = f_new;
        arf(geval) = (fx-f_true)/(f0-f_true);
    else
        y = y_new;
        arf(geval) = arf(geval-1);
    end
    arf_iter(iter+2) = (fx-f_true)/(f0-f_true);
    % Restart (reset y to x) if f has not decreased over the last 4 iterations.
    if arf_iter(iter+2) >= arf_iter(iter-2)
        y = x;
    end
    elapse = elapse + toc;
    if norm(g,"fro") < threshold
        fprintf('A2GD converges after %d iterations, %d gradient evaluations.\n', iter, geval);
        break;
    end
    if geval >= max_iters
        break;
    end
end
t = elapse;
fprintf("Elapsed time = %f.\n",t);
% Report the accepted iterate, not the last (possibly rejected) trial point.
fprintf("Minimum value of f(x): %f.\n", fx);
arf = arf(1:geval); gls = gls(1:geval);
if nargout >= 1
    varargout{1} = arf;
end
if nargout >= 2
    varargout{2} = gls;
end
if nargout >= 3
    varargout{3} = x;
end
end

function lam = lambda_bound(g, gap, Lk)
% Value of lambda at which the b2 term vanishes, i.e. the solution of
% lambda^(3/2)/sqrt(Lk)*gap = ||g||_F^2/Lk (gap > 0 whenever b2 > 0).
lam = norm(g,"fro")^(4/3)/(gap^(2/3)*Lk^(1/3));
end