function [sol, info] = proxsgd(A, b, gamma, beta, init_x, maxiter, tol, ...
    early_stop, alpha_0, show_info)
% PROXSGD  Stochastic subgradient method with heavy-ball momentum.
%
%   [sol, info] = proxsgd(A, b, gamma, beta, init_x, maxiter, tol, ...
%                         early_stop, alpha_0, show_info)
%
%   Each epoch visits every row of A exactly once in a random order (one
%   row per step, i.e. batch size 1). For the sampled row a = A(i, :):
%       y       = (1 + beta) * x_k - beta * x_{k-1}
%       g       = 2 * (a * x_k) * a' * s,   s = sign((a * x_k)^2 - b(i))
%       x_{k+1} = y - (alpha_0 / gamma) * g
%   g is a subgradient of |(a * x)^2 - b(i)| at x_k; where s is zero a
%   uniform random value in [-1, 1] is used instead.
%
%   Inputs
%     A          - matrix whose rows are sampled one at a time (m rows).
%     b          - vector with one entry per row of A; b(i) is compared
%                  against (A(i,:) * x)^2.
%     gamma      - step-size divisor (effective step is alpha_0 / gamma).
%     beta       - momentum weight.
%     init_x     - starting point (column vector; A(i,:) * x must be valid).
%     maxiter    - maximum number of epochs.
%     tol        - threshold on the best objective seen so far.
%     early_stop - if true, stop at the start of the first epoch whose best
%                  objective is already below tol.
%     alpha_0    - step-size numerator.
%     show_info  - if true, print progress (epoch 1 and every 50 epochs)
%                  and a final summary.
%
%   Outputs
%     sol.x         - last iterate.
%     sol.bestx     - iterate with the smallest objective observed.
%     info.status   - "*" (tol reached), "x" (not reached) or "Diverged".
%     info.nepochs  - index of the epoch at whose start tol was first
%                     detected; maxiter if it never was.
%     info.niter    - nepochs * m when tol was detected, maxiter * m
%                     otherwise (and always maxiter * m on divergence).
%     info.objs     - objective after each step; entry 1 is init_x. Entries
%                     never reached keep the last/initial value.
%     info.bestobjs - running best objective, same layout as info.objs.
%
%   NOTE(review): the objective is evaluated by getprObj(A, b, x), defined
%   elsewhere; presumably the full-data loss sum_i |(A(i,:)*x)^2 - b(i)|
%   matching the subgradient above -- confirm against its definition.

batch = 1;
% Get problem size
[m, ~] = size(A);

% Current and previous iterate (the previous one feeds the momentum term)
x_before = init_x;
x_after = x_before;
bestx = init_x;

% History buffers: one slot for the initial point plus one per step,
% pre-filled with the initial objective.
nslots = maxiter * floor(m / batch) + 1;
obj = getprObj(A, b, x_before);
bestobj = obj;
objs     = zeros(nslots, 1) + obj;
bestobjs = zeros(nslots, 1) + bestobj;

% Defaults reported when tol is never reached
nepochs = maxiter;
nbatchiter = maxiter * m / batch;

% Initialize info struct
info.status = "x";

if show_info
    fprintf("%6s %10s %10s %10s\n", 'epoch', 'obj', 'bobj', 'status');
end % End if

niter = m;  % steps per epoch (batch = 1)
for k = 1:maxiter

    % Tolerance check at the start of epoch k; only the first crossing is
    % recorded (afterwards nepochs ~= maxiter unless k == maxiter).
    if bestobj < tol && nepochs == maxiter
        nepochs = k;
        % Counted with the same convention as nepochs. (The previous
        % "+ idx" re-added the already-finished epoch's steps.)
        nbatchiter = nepochs * niter;
        info.status = "*";
        if early_stop
            % Freeze the unused tail of the history at the final values;
            % done regardless of show_info so the outputs do not depend
            % on whether progress is printed.
            bestobjs((k - 1) * niter + 1:end) = bestobj;
            objs((k - 1) * niter + 1:end) = obj;
            if show_info
                fprintf("%6d %10.2e %10.2e %10s\n", k, obj, bestobj, info.status);
                fprintf("Early stopped. Status: Optimal\n");
            end % End if
            break;
        end % End if
    end % End if

    idx = 0;  % position within the current epoch

    for i = randperm(niter) % for i = randsample(1:m, m, true)

        idx = idx + 1;

        % Sample from dataset
        a = A(i, :);

        % Heavy-ball extrapolation from the last two iterates
        y = (1 + beta) * x_after - beta * x_before;

        aTx = a * x_after;
        x_before = x_after;

        aTxA = aTx .* a;

        % Subgradient of |(a*x)^2 - b(i)|; at the kink (sign == 0) draw a
        % random element of [-1, 1] instead of 0.
        sgn = sign(aTx.^2 - b(i));
        sgn = sgn + (2 * rand - 1).*(1 - abs(sgn));
        subgrad = 2 * aTxA' .* sgn;
        x_after = y - (alpha_0 * subgrad / gamma);
        obj = getprObj(A, b, x_after);

        if obj < bestobj
            bestobj = obj;
            bestx = x_after;
        end % End if

        % Slot 1 holds the initial point, so step idx of epoch k goes to
        % (k - 1) * niter + idx + 1.
        bestobjs((k - 1) * niter + idx + 1) = bestobj;
        objs((k - 1) * niter + idx + 1) = obj;

        if isnan(obj) || isinf(obj)
            info.status = "Diverged";
            break;
        end % End if

    end % End for

    % Propagate the divergence break out of the epoch loop
    if isnan(obj) || isinf(obj)
        break;
    end % End if

    if show_info && (mod(k, 50) == 0 || k == 1)
        fprintf("%6d %10.2e %10.2e %10s\n", k, obj, bestobj, info.status);
    end % End if

end % End for

% Collect information
% Solution array
sol.x = x_after;
sol.bestx = bestx;

% Information array
info.nepochs = nepochs;
info.niter = nbatchiter;
% A divergent run always reports the full iteration budget; this used to
% be set only when show_info was true.
if info.status == "Diverged"
    info.niter = maxiter * m;
end % End if
info.objs = objs;
info.bestobjs = bestobjs;

% Display summary
if show_info
    if info.status == "*"
        disp("- Algorithm reaches optimal after " + nepochs + " epochs (" + ...
            nbatchiter + " iterations)");
    elseif info.status == "x"
        disp("- Algorithm fails to reach desired accuracy after " +...
            nepochs + " epochs");
    elseif info.status == "Diverged"
        disp("- Algorithm diverges");
    end % End if
end % End if

end % End function