function[f_vec,g_vec,time_vec,x,acc_vec] = SEA(fun_f,grad_f,grad_g,fun_g,TSA, param,...
    initial_inner, initial_bi, initial_dual)
%SEA Primal-dual iteration coupling an inner minimization of fun_g with an
%   outer minimization of fun_f subject to a fun_g-based constraint.
%
%   Every iteration:
%     * takes a projected gradient step on fun_g for the inner iterate z and
%       keeps min_z, the inner iterate with the smallest fun_g value so far;
%     * takes a projected gradient step for the outer iterate x along
%       grad_f(x) + y*grad_g(x);
%     * updates the dual variable y (clipped at 0) using the extrapolated
%       constraint value 2*g1 - g0, where
%           g = fun_g(x) - fun_g(min_z) - slater.
%   Both steps are passed through ProjectOntoL2Ball0(., param.lam), which is
%   defined elsewhere (presumably an L2-ball projection of radius param.lam).
%
%   Inputs:
%     fun_f, grad_f  - handles for the outer objective and its gradient
%     grad_g, fun_g  - handles for the inner objective and its gradient
%     TSA            - handle evaluated at x after every iteration (test set
%                      accuracy); results are stored in acc_vec
%     param          - struct with fields lam, maxiter, maxtime (compared
%                      against toc, i.e. seconds). Optional fields
%                      stepsize_inner, stepsize_bi, slater override their
%                      default value of 1e-4.
%     initial_inner  - starting point for the inner iterate z
%     initial_bi     - starting point for the outer iterate x
%     initial_dual   - starting value for the dual variable y (presumably a
%                      scalar, since it multiplies grad_g(x) with '*')
%
%   Outputs:
%     f_vec, g_vec   - fun_f(x) and fun_g(x), appended after every iteration
%     time_vec       - toc value, appended after every iteration
%     x              - final outer iterate (initial_bi if no iteration ran)
%     acc_vec        - TSA(x), appended after every iteration
%
%   The loop stops after param.maxiter iterations, or earlier once the
%   elapsed time exceeds param.maxtime.

stepsize_inner = 1e-4;
if isfield(param, 'stepsize_inner'), stepsize_inner = param.stepsize_inner; end
stepsize_bi = 1e-4;
if isfield(param, 'stepsize_bi'), stepsize_bi = param.stepsize_bi; end
slater = 1e-4;
if isfield(param, 'slater'), slater = param.slater; end

z0 = initial_inner;
x0 = initial_bi;
y0 = initial_dual;
g0 = -slater;
lambda = param.lam;

f_vec = [];
g_vec = [];
time_vec = [];
acc_vec = [];

% Best inner point seen so far and its fun_g value. The value is cached so
% fun_g is not re-evaluated at min_z on every iteration.
min_z = z0;
g_min = fun_g(z0);

% Defined up front so the output exists even if the loop never runs.
x = x0;

maxiter = param.maxiter;
maxtime = param.maxtime;
tic;
for i = 1 : maxiter
    % Inner step: projected gradient step on fun_g.
    z1 = ProjectOntoL2Ball0(z0 - stepsize_inner * grad_g(z0), lambda);
    fz1 = fun_g(z1);
    % Only the best-point bookkeeping is conditional; every other update
    % below must happen on every iteration, otherwise the iterates freeze
    % as soon as one inner step fails to decrease fun_g.
    if fz1 < g_min
        min_z = z1;
        g_min = fz1;
    end

    % Outer step along grad_f + y*grad_g, then projection.
    x1 = ProjectOntoL2Ball0(x0 - stepsize_bi * (grad_f(x0) + y0 * grad_g(x0)), lambda);
    % Constraint value: positive when fun_g(x1) exceeds the best inner
    % value by more than slater.
    g1 = fun_g(x1) - g_min - slater;
    % Dual update with the extrapolated constraint value, kept nonnegative.
    y1 = max(0, y0 + 2*g1 - g0);

    z0 = z1;
    x0 = x1;
    y0 = y1;
    g0 = g1;

    x = x0;
    cpu_t = toc;
    f_vec = [f_vec;fun_f(x)];
    g_vec = [g_vec;fun_g(x)];
    time_vec = [time_vec;cpu_t];
    % test set accuracy
    acc_vec = [acc_vec;TSA(x)];
    if cpu_t>maxtime
        break
    end
end



