% MFISTA; See First order Methods in Optimization (Beck 2017)
function [list_loss,list_iter,list_steps] = fista(f_func, g_func, L0, beta, init_pt, N_iter, epsilon)
% FISTA  Monotone FISTA (MFISTA) for a smooth objective, with backtracking.
%
%   [list_loss,list_iter,list_steps] = fista(f_func,g_func,L0,beta,init_pt,N_iter,epsilon)
%
%   Inputs
%     f_func   handle returning the scalar objective value at a point
%     g_func   handle returning the gradient at a point
%     L0       initial estimate of the step-size constant L (step is 1/L)
%     beta     backtracking factor: L is multiplied by beta at the start of
%              every outer iteration and divided by beta each time the
%              sufficient-decrease test fails
%     init_pt  starting point (vector)
%     N_iter   maximum number of outer iterations
%     epsilon  the loop stops after an iteration whose gradient, evaluated
%              at the extrapolated point y at the start of that iteration,
%              has norm below epsilon
%
%   Outputs (k = number of outer iterations actually executed)
%     list_loss   (k+1)-by-1, best objective value so far; entry 1 is the
%                 value at init_pt
%     list_iter   length(init_pt)-by-(k+1), column j+1 is the iterate x
%                 after iteration j; column 1 is init_pt
%     list_steps  (k+1)-by-1, a leading 0 followed by the per-iteration
%                 counter, which is incremented once per outer iteration
%                 (after the gradient evaluation), not per backtracking trial
%
%   Raises fista:backtrackingStalled when the sufficient-decrease test fails
%   and dividing L by beta does not yield a finite, strictly larger L
%   (for example beta >= 1, L0 <= 0, or a NaN objective value). Without
%   this check the backtracking loop would never terminate.

list_iter = zeros(length(init_pt),N_iter+1);
list_iter(:,1) = init_pt;
list_loss = zeros(N_iter+1,1);
list_steps = zeros(N_iter,1);

x = init_pt;
y = x;
t = 1;
% Pre-divide so that the first in-loop multiplication by beta restores L0.
L = L0/beta;
f_func_x = f_func(x);
list_loss(1) = f_func_x;

% Number of completed outer iterations. Tracked explicitly because the loop
% variable is left empty when the loop body never runs (N_iter == 0), which
% would otherwise truncate the outputs to empty arrays.
n_done = 0;

for i_iter = 1:N_iter
    n_done = i_iter;
    L = L*beta; % decrease L (optimistic larger step for this iteration)

    f = f_func(y);
    g = g_func(y);
    steps = 1;

    % Backtracking: raise L until the quadratic upper bound holds at z.
    while true
        z = y-g/L;
        d = z-y;
        f_func_z = f_func(z);
        % (:) makes these inner products for row as well as column vectors.
        if f_func_z <= f+g(:)'*d(:)+L/2*d(:)'*d(:)
            break
        end
        L_next = L/beta; % increase L
        if ~(L_next > L) || ~isfinite(L_next)
            error('fista:backtrackingStalled', ...
                'Backtracking cannot increase L (L = %g, beta = %g) at iteration %d.', ...
                L, beta, i_iter);
        end
        L = L_next;
    end

    % Monotone safeguard: accept z only if it strictly improves the best value.
    if f_func_z < f_func_x
        x_new = z;
        f_func_x = f_func_z;
    else
        x_new = x;
    end
    % non-monotone version
    % x_new = z;
    % f_func_x = f_func_z;

    list_loss(i_iter+1) = f_func_x;

    % MFISTA momentum update (uses both the accepted iterate and z).
    t_new = (1+sqrt(1+4*t^2))/2;
    y = x_new+(t-1)/t_new*(x_new-x)+t/t_new*(z-x_new);
    x = x_new;
    t = t_new;

    list_iter(:,1+i_iter) = x;
    list_steps(i_iter) = steps;

    if norm(g) < epsilon
        break
    end
end

% Trim the preallocated storage to the iterations that actually ran.
list_iter = list_iter(:,1:1+n_done);
list_steps = [0;list_steps(1:n_done)];
list_loss = list_loss(1:1+n_done);
end