function out = simplex_proj(x,lam)
% SIMPLEX_PROJ  Euclidean projection of a vector onto a scaled simplex.
%
%   out = simplex_proj(x) returns the point of the set
%       { w : w >= 0, sum(w) = 1 }
%   that is closest (in the Euclidean sense) to x.
%
%   out = simplex_proj(x,lam) projects onto { w : w >= 0, sum(w) = lam }.
%   lam defaults to 1 and is expected to be a positive scalar; it is not
%   validated here.
%
%   x may be a row or a column vector; out has the same size as x.
%   An empty x is returned unchanged.

if (nargin < 2)
    lam = 1;
end

n = numel(x);
if (n == 0)
    out = x; % nothing to project
    return;
end

x = x ./ lam; % rescale so the problem becomes projection onto the unit simplex

% Sort the entries in descending order. x(:) forces a column so that row
% vectors are actually sorted (sortrows on a 1-by-n row is a no-op).
y = sort(x(:), 'descend');
csum = cumsum(y);

% ind(i) = sum(y(1:i) - y(i)): the mass removed if y(i) were used as the
% threshold. It is non-decreasing in i and ind(1) = 0.
ind = csum - (1:n)' .* y;

% The support size k is the last index whose removed mass does not exceed 1.
% Selecting it with find avoids a numeric "not found" sentinel, which could
% collide with a legitimate threshold value such as t = -1.
k = find(ind <= 1, 1, 'last');
if isempty(k)
    k = n; % only reachable when the input contains NaN; the result is then NaN
end
t = (csum(k) - 1) / k; % shift that makes the k largest entries sum to 1

out = x - t;
out(out < 0) = 0; % clip entries outside the support
out = out .* lam; % undo the initial rescaling

return;
