function [arf,gls] = A2GD_weaklyconvex_composite(x0, ~, data, max_iters, tol)
% An implementation of A2GD method for composite objectives.
%
% Runs 10 adaptive proximal-gradient warmup steps (to estimate the local
% Lipschitz constant Lk and a lower curvature estimate muk), then an
% accelerated adaptive loop with a line search on Lk and lambda.
%
% Inputs
%   x0        - starting point (vector or matrix; all norms are Frobenius).
%   (2nd arg) - ignored.
%   data      - struct with fields
%                 f      : objective handle used for the reported values
%                 gradh  : gradient handle, called as gradh(x)
%                 L      : initial Lipschitz-constant estimate
%                 f_true : reference value used to normalise arf
%                 prox   : proximal handle, called as prox(v, step)
%   max_iters - index bound of the accelerated loop (warmup+1:max_iters);
%               the warmup always runs unless it converges first.
%   tol       - relative tolerance; stop when the composite-gradient norm
%               drops below tol*norm(gradh(x0),"fro").
%
% Outputs
%   arf - (f(x)-f_true)/(f0-f_true) recorded per gradient evaluation.
%   gls - 1 for gradient evaluations spent on line-search retries, else 0.

%% Initialization
f = data.f; grad = data.gradh; L = data.L; f_true = data.f_true; prox = data.prox;
arf = zeros(3*max_iters,1); % function values by gradient eval #
gls = zeros(3*max_iters,1); % gradient eval for line search, 0:no/1:yes
x = x0; f0 = f(x); pb = 0; elapse = 0;
g = grad(x); g0 = g; arf(1) = 1; threshold = norm(g0,"fro") * tol;
epsilon = L/10^6; mk = 5; % epsilon: lower bound for lambda (the mu_k estimate)
converged = false;

%% Warmup
warmup = 10;
theta = 1/3; Lk = L; muk = L/10;
alpha = 1/L; alpha_old = alpha;
x_old = x; g_old = g;
x = prox(x-alpha*g,alpha); g = grad(x); arf(2) = (f(x)-f_true)/(f0-f_true);
geval = 2; % number of grad evaluations

for iter = 1:warmup
    tic;
    % Local Lipschitz estimate from the two most recent iterates.
    Lk = norm(g-g_old,"fro")/norm(x-x_old,"fro"); muk = min(muk, Lk);
    alpha = min(sqrt(2/3+theta)*alpha_old, alpha_old/sqrt(max(0,2*alpha_old^2*Lk^2-1)));
    x_new = prox(x-alpha*g,alpha); compo_grad = (x_new - x)/alpha;
    theta = alpha/alpha_old; x_old = x; x = x_new;
    g_old = g; g = grad(x); geval = geval + 1; arf(geval) = (f(x)-f_true)/(f0-f_true);
    alpha_old = alpha;
    elapse = elapse + toc;
    if norm(compo_grad,"fro") < threshold
        fprintf('Converged after %d iterations.\n', iter);
        converged = true; % skip the accelerated phase below
        break;
    end
end

%% Accelerated Adaptive Gradient Descent
if ~converged
    fx = f(x); y = x; lambda = muk; inner_iter = 0;
    R = min(100*norm(compo_grad,"fro")/muk,10000);
    % fprintf("Estimated Radius %f. Actual Radius %f.", R, norm(x,"fro"));
    for iter = warmup+1:max_iters
        tic;
        ls_flag = 0;
        while true
            alpha = sqrt(lambda/Lk);
            x_new = alpha/(1+alpha)*y + 1/(1+alpha)*x - 1/(Lk*(1+alpha))*g;
            x_new = prox(x_new, 1/(Lk*(1+alpha)));
            g_new = grad(x_new); geval = geval + 1; gls(geval) = ls_flag;
            % NOTE(review): if prox(v,t) = argmin_u r(u) + ||u-v||^2/(2t), then
            % q_new is the subgradient of r at x_new implied by the prox step.
            q_new = -Lk*(x_new-x) + Lk*alpha*(y-x_new)-g;
            y_new = y/(1+alpha) + alpha*x_new/(1+alpha)-(alpha/(lambda*(1+alpha)))*(g_new+q_new);
            b11 = 0.5*norm(g-g_new,'fro')^2/Lk; b12 = 0.5*sum(dot(g_new-g,x_new-x));
            b1 = b11 - b12;
            r_norm = norm(g+q_new,"fro");
            slack = R^2-(1+alpha)*norm(x_new-y_new,"fro")^2;
            b2 = 0.5*(lambda^(3/2)/sqrt(Lk))*slack - 0.5*r_norm^2/Lk;
            pb_temp = (pb + b1 + b2)/(1+alpha);
            % Retry only if the step fails AND some parameter can still be
            % adjusted; otherwise nothing would change and this would loop forever.
            if (pb_temp > 0) && (lambda > epsilon) && (b1 > 0 || b2 > 0)
                ls_flag = 1;
                arf(geval) = arf(geval-1);
                if b1 > 0
                    % b1 > 0 implies b11 > 0, so v1 is finite and v1 < 1.
                    v1 = b12/b11;
                    if v1 > 0
                        Lk = min(3 / v1, 100) * Lk;
                    else
                        % Non-positive curvature along the step: 3/v1 would make
                        % Lk negative, so grow by the cap instead.
                        Lk = 100 * Lk;
                    end
                end
                if b2 > 0
                    lambda = shrink_lambda(lambda, epsilon, r_norm, slack, Lk);
                end
            else
                pb = pb_temp;
                v1 = b12/b11;
                Lk = min(max(1/v1, 0.1), 10) * Lk;
                if b2 > 0
                    lambda = shrink_lambda(lambda, epsilon, r_norm, slack, Lk);
                end
                break;
            end
        end
        sq_ratio = norm(g,"fro")^2/norm(g0,"fro")^2;
        if (inner_iter < mk) && (sq_ratio > (R^2+1)*epsilon/2)
            inner_iter = inner_iter + 1;
        else
            inner_iter = 0; epsilon = epsilon/2; mk = floor(sqrt(2)*mk)+1;
        end
        % update iterates: x only moves when the objective decreases
        f_new = f(x_new);
        if fx - f_new > 0
            x = x_new; y = y_new; g = g_new; fx = f_new;
            arf(geval) = (fx-f_true)/(f0-f_true);
        else
            y = y_new;
            arf(geval) = arf(geval-1);
        end
        % Restart when the gap is no better than 4 gradient evaluations ago.
        if geval > 4 && arf(geval) >= arf(geval - 4)
            y = x; % restart
        end
        elapse = elapse + toc;
        if norm(g+q_new,"fro") < threshold
            fprintf('A2GD converges after %d iterations, %d gradient evaluations.\n', iter, geval);
            break;
        end
    end
end
t = elapse;
fprintf("Elapsed time = %f.\n",t);
% x is the best accepted iterate (x_new may be a rejected trial point).
fprintf("Minimum value of f(x): %f.\n", f(x));
arf = arf(1:geval); gls = gls(1:geval);
end

function lambda = shrink_lambda(lambda, epsilon, r_norm, slack, Lk)
% Shrink lambda to the value at which b2 = 0.5*lambda^(3/2)*slack/sqrt(Lk)
% - 0.5*r_norm^2/Lk vanishes for the given Lk, clipped to [epsilon, lambda].
% Callers only invoke this when b2 > 0, which implies slack > 0.
lambda_est = r_norm^(4/3)/(slack^(2/3)*Lk^(1/3));
lambda = max(epsilon, min(lambda, lambda_est));
end