function arf = AdProxGD(x0, ~, data, max_iters, tol)
% ADPROXGD Adaptive proximal gradient method (Malitsky & Mishchenko, NeurIPS 2024).
%   arf = AdProxGD(x0, ~, data, max_iters, tol) runs one fixed step of size
%   1/data.L, then up to max_iters adaptive proximal gradient steps whose step
%   size is set from a local Lipschitz estimate of the gradient.
%
%   Inputs:
%     x0        - starting point (any array accepted by data.f / data.gradh / data.prox)
%     ~         - ignored (presumably kept for a signature shared with other solvers; confirm)
%     data      - struct with fields
%                   gradh  : function handle, gradient oracle used by the method
%                            (presumably of the smooth part of the objective)
%                   L      : positive constant; only used for the initial step 1/L
%                   f      : function handle, objective value used for reporting
%                   f_true : reference optimal value used to normalize the residuals
%                   prox   : function handle prox(v, alpha) applied after each gradient step
%     max_iters - maximum number of adaptive iterations (after the initial step)
%     tol       - relative tolerance: stop once
%                 ||(x_new - x)/alpha||_F < tol * ||gradh(x0)||_F
%
%   Output:
%     arf - column vector of relative residuals (f(x_k) - f_true)/(f(x0) - f_true):
%           arf(1) = 1 for x0, arf(2) for the initial 1/L step, and arf(k+2)
%           for adaptive iteration k. It is truncated to the entries actually computed.

grad = data.gradh; L = data.L; f = data.f;
f_true = data.f_true; prox = data.prox;

% Initialization: the very first step uses the global constant 1/L.
alpha = 1/L; alpha_old = alpha; theta = 1/3;
x = x0; f0 = f(x);
g = grad(x); threshold = norm(g, "fro")*tol;
x_old = x; g_old = g;
x = prox(x - alpha*g, alpha);
fx = f(x);

% Preallocate room for x0, the initial step, and up to max_iters iterations.
arf = zeros(max_iters + 2, 1);
arf(1) = 1; arf(2) = (fx - f_true)/(f0 - f_true);
n_rec = 2;   % number of entries of arf filled so far
x_new = x;   % latest iterate; defined even when the loop body never runs

tic;
% Adaptive proximal gradient loop
for iter = 1:max_iters
    g = grad(x);
    % Local Lipschitz estimate from the two most recent iterates. If x == x_old
    % this is 0/0 = NaN; max(0, NaN) returns 0, so the second bound becomes Inf
    % and the growth bound sqrt(2/3 + theta)*alpha_old is used instead.
    Lk = norm(g - g_old, "fro")/norm(x - x_old, "fro");
    alpha = min(sqrt(2/3 + theta)*alpha_old, ...
                alpha_old/sqrt(max(0, 2*alpha_old^2*Lk^2 - 1)));
    x_new = prox(x - alpha*g, alpha);
    compo_grad = (x_new - x)/alpha;   % composite gradient mapping, used for stopping
    theta = alpha/alpha_old;
    fx = f(x_new);
    n_rec = iter + 2;
    arf(n_rec) = (fx - f_true)/(f0 - f_true);
    if norm(compo_grad, "fro") < threshold
        fprintf('AdProxGD converges after %d iterations.\n', iter);
        break;
    end
    x_old = x; x = x_new; g_old = g; alpha_old = alpha;
end

elapsed_time = toc;
fprintf("Elapsed time = %f.\n", elapsed_time);
% fx holds f at the latest iterate (x_new), including when max_iters == 0.
fprintf("Minimum value of f(x): %f.\n", fx);
arf = arf(1:n_rec);

end