function [last_iter , f_hist, g_hist, sample_vec] = CG_lowerlevel_STORM(A2,b2,fun_f,fun_g,grad_g,x0,param)
% CG_LOWERLEVEL_STORM  Stochastic conditional gradient (Frank-Wolfe) method
% for the lower-level problem, using a recursive-momentum (STORM-type, per
% the function name) gradient estimator.
%
% Each pass samples one row index i uniformly from 1..n (n = number of rows
% of A2), forms the stochastic gradient
%       n * A2(i,:)' * (A2(i,:)*x - b2(i,:))
% at the current and the previous iterate, updates the estimator
%       g <- (1-alpha)*g + grad_i(x) - (1-alpha)*grad_i(x_prev),  alpha = 1/k,
% calls the linear minimization oracle linear_l1(g, lambda1), and moves
% towards the returned point with step size gam = 2/(k+2).
%
% Inputs:
%   A2, b2  - data used row-wise to build the sampled gradient (see above).
%   fun_f   - function handle; evaluated at x0 and at every new iterate,
%             values are recorded in f_hist.
%   fun_g   - function handle; evaluated at x0 and at every new iterate,
%             values are recorded in g_hist.
%   grad_g  - currently unused. It was only referenced by a stopping test
%             of the form grad_g(x)'*(x-dir) <= param.epsilong, which is
%             disabled; the argument is kept so existing callers still work.
%   x0      - initial point.
%   param   - struct; fields read here: param.lam1 (second argument passed
%             to linear_l1) and param.maxiter (iteration budget).
%
% Outputs:
%   last_iter  - final iterate.
%   f_hist     - column vector: fun_f at x0 followed by fun_f at each iterate.
%   g_hist     - column vector: fun_g at x0 followed by fun_g at each iterate.
%   sample_vec - column vector [0; 1; ...; K], the iteration counter matching
%                each history entry (one row is sampled per pass).
%
% NOTE: the histories are preallocated, which assumes fun_f and fun_g return
% scalars - confirm against the callers.
disp('CG for the lower level starts');

lambda1 = param.lam1;
maxiteration = param.maxiter;
n = size(A2,1);   % number of rows available for sampling

% The original loop tested `iteration <= maxiteration` BEFORE incrementing
% the counter, so it performs maxiter+1 passes. That count is preserved here
% so the lengths of the returned histories do not change for callers.
numIters = max(0, floor(maxiteration) + 1);

x = x0;        % current iterate (initial point)
x_prev = x0;   % previous iterate, needed by the recursive estimator
g = 0;         % gradient estimator; fully overwritten on pass 1 (alpha = 1)

% Preallocate the histories instead of growing them on every pass.
f_hist = zeros(numIters+1,1);
g_hist = zeros(numIters+1,1);
f_hist(1) = fun_f(x0);
g_hist(1) = fun_g(x0);
sample_vec = (0:numIters)';

for iteration = 1:numIters
    gam = 2/(iteration+2);   % conditional-gradient step size
    alpha = 1/iteration;     % estimator mixing weight

    % Sample one row and pull it out once; it is reused for both gradients.
    lowidx = randsample(n,1);
    a_i = A2(lowidx,:);
    b_i = b2(lowidx,:);
    scaled_ai = (n)*a_i';

    % Same sample evaluated at the current and at the previous iterate.
    grad_x = scaled_ai*(a_i*x - b_i);
    grad_prev = scaled_ai*(a_i*x_prev - b_i);

    % Build estimator
    g = (1-alpha)*g + grad_x - (1-alpha)*grad_prev;

    % Solve linear minimization
    dir = linear_l1(g,lambda1);

    x_prev = x;
    x = (1-gam)*x + gam*dir;

    f_hist(iteration+1) = fun_f(x);
    g_hist(iteration+1) = fun_g(x);
end
disp('CG for the lower level is solved!');
last_iter = x;
end

function x = linear_l1(c,lambda)
% LINEAR_L1  Minimize c'*x subject to norm(x,1) <= lambda.
%   Returns a sparse column vector with at most one nonzero entry: the
%   coordinate where abs(c) is largest (first one on ties) is set to
%   lambda with the sign opposite to that entry of c.
dim = length(c);

% Coordinate with the largest magnitude in c.
[~, k] = max(abs(c));

% Value placed at that coordinate.
vertexValue = -sign(c(k))*lambda;

x = sparse(dim,1);
x(k) = vertexValue;
end
