function X = prox_l1_sphere(A,lambda,r)
% PROX_L1_SPHERE l1-proximal step restricted to a Frobenius-norm sphere.
%   X = prox_l1_sphere(A, lambda, r) targets
%       min_X 0.5||X-A||_F^2 + lambda ||X||_1,  s.t. ||X||_F^2 = r
%   The problem is separable over entries, so it is solved on the column
%   vector A(:) by prox_l12_vec and the result is reshaped to size(A).
X = reshape(prox_l12_vec(A(:), lambda, r), size(A));

function [x] = prox_l12_vec(a,lambda,r)
% PROX_L12_VEC l1-proximal step for a vector restricted to a sphere.
%   x = prox_l12_vec(a, lambda, r) targets
%       min_x 0.5||x-a||_2^2 + lambda ||x||_1,  s.t. x'x = r
%   Nonzero input: the signs of a are kept and the magnitude subproblem on
%   abs(a) is delegated to prox_l12_vec_positive.
%   All-zero input: a column vector of tiny (1e-10 scaled) Gaussian noise
%   with one randomly chosen entry set to sqrt(r) is used instead.
%   In both cases the result is finally rescaled to have norm sqrt(r).
mag = abs(a);
if sum(mag) ~= 0
    x = sign(a) .* prox_l12_vec_positive(mag, lambda, r);
else
    len = length(a);
    x = 1e-10 * randn(len, 1);
    pick = randperm(len, 1);
    x(pick) = sqrt(r);
end

% enforce the sphere constraint x'x = r
x = sqrt(r) * x / norm(x);

function [x] = prox_l12_vec_positive(b,lambda,r)
% PROX_L12_VEC_POSITIVE Magnitude subproblem used by prox_l12_vec.
%   x = prox_l12_vec_positive(b, lambda, r) targets
%       min_x 0.5||x-b||_2^2 + lambda ||x||_1,  s.t. x'x = r, x >= 0
%   for an input b that is expected to be nonnegative (callers pass abs(a)).
%   On x >= 0 the l1 term is linear, lambda*||x||_1 = lambda*<x,1>, so
%   completing the square gives the equivalent problem
%       min_x 0.5||x-(b - lambda*1)||_2^2,  s.t. x'x = r, x >= 0
%   i.e. a projection of the shifted vector onto the nonnegative part of
%   the sphere, which is delegated to proj_l2_nn2.
shifted = b - lambda;
x = proj_l2_nn2(shifted, r);

function x = proj_l2_nn2(a,r)
% PROJ_L2_NN2 Project onto the nonnegative part of a sphere.
%   x = proj_l2_nn2(a, r) solves
%       min_x 0.5||x-a||_2^2,  s.t. x'x = r, x >= 0
%   where a may contain entries of either sign. Since ||x||^2 = r is fixed,
%   this is equivalent to maximizing <x,a> over the nonnegative sphere:
%     * if a has a positive entry, the maximizer is the positive part
%       max(a,0) rescaled to norm sqrt(r);
%     * otherwise (all entries <= 0) it puts all mass on the entry of a
%       closest to zero: x = sqrt(r)*e_i with i = argmax(a).
%
%   Inputs:
%     a - numeric array (any shape; treated entrywise)
%     r - target squared Euclidean norm (assumed r >= 0)
%   Output:
%     x - array of size(a), x >= 0, with norm(x(:))^2 == r up to rounding

x = max(a, 0);            % exact projection onto the nonnegative orthant
norm_x = norm(x(:));      % vectorize so matrix / N-D input is handled too
if norm_x == 0
    % No positive entry: a 1-sparse point on the least-negative
    % coordinate is optimal. Flooring negatives to a tiny positive value
    % (as was done previously) made this branch unreachable and produced
    % a near-uniform, suboptimal vector instead.
    [~, ind] = max(a(:));
    x = zeros(size(a));
    x(ind) = sqrt(r);
else
    x = (x / norm_x) * sqrt(r);
end

