function [sol, info] = proxlintaurobust(A, b, gamma, beta, init_x, maxiter, tol, ...
    early_stop, alpha_0, show_info, robust)
% PROXLINTAUROBUST Stochastic linearized (prox-linear style) method with
% momentum for the objective (1/m) * sum(abs((A*x).^2 - b)).
%
% Each epoch visits every row of A once, in random order. For the sampled
% row a = A(i,:), the momentum point y = (1+beta)*x_k - beta*x_{k-1} is
% formed, (a*x)^2 - b(i) is linearized, and a clipped step is taken
% (see truncated_prox_step below).
%
% Inputs:
%   A          - m-by-n matrix whose rows are sampled
%   b          - m-by-1 vector compared against (A*x).^2
%   gamma      - step parameter; each step uses gamma / alpha_0
%   beta       - momentum coefficient
%   init_x     - n-by-1 starting point
%   maxiter    - maximum number of epochs (passes over the rows of A)
%   tol        - target for the best objective value
%   early_stop - if true, stop once the best objective is below tol
%   alpha_0    - divisor applied to gamma in every step
%   show_info  - if true, print progress and a final summary
%   robust     - if true, updates in the branch i > maxiter*m - 2 may be
%                skipped (see the loop body)
%
% Outputs:
%   sol.x         - last iterate
%   sol.bestx     - iterate with the smallest objective seen
%   info.status   - "*" (tol reached), "x" (not reached) or "Diverged"
%   info.nepochs  - epoch index k at whose start tol was first detected
%                   (maxiter if never reached)
%   info.niter    - iterations completed when tol was first detected,
%                   i.e. (nepochs - 1) * m; maxiter * m otherwise or when
%                   the run diverged
%   info.objs     - objective after every iteration; entry 1 is init_x
%   info.bestobjs - running best objective, same layout as info.objs

% Get problem size
[m, ~] = size(A);

% Total iteration budget. It stays fixed for the whole run; niter (below)
% is only a reported count and is overwritten when tol is reached.
total_iter = maxiter * m;

% Get initial point
x_before = init_x;
x_after = x_before;
bestx = init_x;

% Objective histories: entry 1 belongs to init_x and entry
% (k - 1) * m + idx + 1 to the idx-th sample of epoch k. They are
% pre-filled with the initial objective so untouched slots are defined.
% NOTE(review): getprObj is presumably the same objective that is computed
% inline in the loop below — confirm.
obj = getprObj(A, b, x_before);
bestobj = obj;
objs = repmat(obj, total_iter + 1, 1);
bestobjs = repmat(bestobj, total_iter + 1, 1);

% Reported epoch / iteration counts (overwritten when tol is reached)
nepochs = maxiter;
niter = total_iter;

% Initialize info struct ("x": tolerance not reached yet)
info.status = "x";

if show_info
    fprintf("%6s %10s %10s %10s\n", 'epoch', 'obj', 'bobj', 'status');
end % End if

for k = 1:maxiter

    % The tolerance is checked once per epoch, before its first sample,
    % so exactly (k - 1) epochs, i.e. (k - 1) * m iterations, have run.
    if bestobj < tol && nepochs == maxiter
        nepochs = k;
        niter = (k - 1) * m;
        info.status = "*";
        if early_stop
            % Pad the histories with the final values; this must not
            % depend on whether progress is printed.
            bestobjs((k - 1) * m + 1:end) = bestobj;
            objs((k - 1) * m + 1:end) = obj;
            if show_info
                fprintf("%6d %10.2e %10.2e %10s\n", k, obj, bestobj, info.status);
                fprintf("Early stopped. Status: Optimal\n");
            end % End if
            break;
        end % End if
    end % End if

    idx = 0;

    for i = randperm(m) % for i = randsample(1:m, m, true)

        idx = idx + 1;

        % Slot of this iteration in objs / bestobjs
        pos = (k - 1) * m + idx + 1;

        % Sample from dataset
        a = A(i, :);

        % Momentum (extrapolation) point
        y = (1 + beta) * x_after - beta * x_before;

        if i > total_iter - 2
            % Linearize around the initial point instead of the iterate
            xtarget = init_x;
            % NOTE(review): k * total_iter + idx is not the global
            % iteration count ((k - 1) * m + idx); kept as written — confirm.
            if k * total_iter + idx > 0.1 * sqrt(m * maxiter) && robust
                % Skip the update but record the unchanged objective in
                % the same slot a regular iteration would use.
                bestobjs(pos) = bestobj;
                objs(pos) = obj;
                continue;
            end % End if
        else
            xtarget = x_after;
        end % End if

        x_before = x_after;

        % Scaled step parameter, computed locally so gamma itself is never
        % modified (repeated divide/multiply would accumulate rounding).
        % NOTE(review): the constant term is anchored at the previous
        % iterate (x_after), even when xtarget = init_x — confirm intent.
        x_after = truncated_prox_step(a, b(i), y, xtarget, x_after, ...
            gamma / alpha_0);

        obj = sum(abs((A * x_after).^2 - b)) / m;

        if obj < bestobj
            bestobj = obj;
            bestx = x_after;
        end % End if

        bestobjs(pos) = bestobj;
        objs(pos) = obj;

        if isnan(obj) || isinf(obj)
            info.status = "Diverged";
            break;
        end % End if

    end % End for

    if isnan(obj) || isinf(obj)
        break;
    end % End if

    if show_info && (mod(k, 50) == 0 || k == 1)
        fprintf("%6d %10.2e %10.2e %10s\n", k, obj, bestobj, info.status);
    end % End if

end % End for

% Collect information
% Solution array
sol.x = x_after;
sol.bestx = bestx;

% Information array
info.nepochs = nepochs;
info.niter = niter;
info.objs = objs;
info.bestobjs = bestobjs;

% A diverged run reports the full iteration budget, independently of
% whether the summary is printed.
if info.status == "Diverged"
    info.niter = total_iter;
end % End if

% Display summary
if show_info
    if info.status == "*"
        disp("- Algorithm reaches optimal after " + nepochs + " epochs (" + ...
            niter + " iterations)");
    elseif info.status == "x"
        disp("- Algorithm fails to reach desired accuracy after " +...
            nepochs + " epochs");
    elseif info.status == "Diverged"
        disp("- Algorithm diverges");
    end % End if
end % End if

end % End function

function x_new = truncated_prox_step(a, bi, y, x_lin, x_ref, gam)
% TRUNCATED_PROX_STEP One clipped linearized step on |(a*x)^2 - bi|.
%
%   a     - sampled row of A (1-by-n)
%   bi    - matching entry of b
%   y     - point the step starts from
%   x_lin - point at which (a*x)^2 and its gradient are evaluated
%   x_ref - point the constant term is anchored at (y - x_ref)
%   gam   - scaled step parameter
%
% Returns y + c * zeta / gam, where zeta = 2*(a*x_lin)*a' and c is
% -ksi * gam / (zeta'*zeta) clipped to [-1, 1].

aTx = a * x_lin;
zeta = 2 * aTx * a';
zz = zeta' * zeta;

% Zero gradient: the linear model is constant, so the step is a no-op.
% Without this guard, ksi = 0 would give 0/0 = NaN and a false divergence.
if zz == 0
    x_new = y;
    return;
end % End if

ksi = aTx^2 + 2 * aTx * a * (y - x_ref) - bi;
coef = - ksi / zz * gam;
coef = sign(coef) * min(1, abs(coef));
x_new = y + (coef * zeta / gam);

end % End function