function [time, fval, pi] = PAM(a, b, C, eta, options)
% PAM  Alternating scaling updates with projected gradient ascent on lambda.
%
%   [time, fval, pi] = PAM(a, b, C, eta, options)
%
%   Inputs
%     a, b     - vectors with n elements each (any orientation; they are
%                reshaped to columns). b is divided by K'*f and 1./a scales
%                the rows of K in the scaling updates.
%     C        - cost array; expected to be n-by-n-by-N so that it combines
%                elementwise with the n-by-n-by-N plan.
%     eta      - scalar used to build the kernel A = exp(-C / eta).
%     options  - struct with fields
%                  max_iter : number of outer iterations
%                  N        : number of slices / length of lambda
%                  n        : length of a and b
%                  tau      : step size of the lambda update
%
%   Outputs
%     time     - 1-by-(max_iter+1) cumulative time; time(1) is 0 and each
%                later entry adds the tic/toc duration of one iteration
%                (the progress print is outside the timed region).
%     fval     - 1-by-(max_iter+1) values of sum(lambda .* grad), where
%                grad(k) = sum(pi(:,:,k) .* C(:,:,k)) and lambda is the
%                weight vector the plan was computed with.
%     pi       - n-by-n-by-N plan from the last scaling step. Note that this
%                output name shadows the built-in constant inside this file.

% Unpack solver settings.
numIters = options.max_iter;
numCosts = options.N;
n        = options.n;
stepSize = options.tau;

% Start from uniform weights and print their sum.
lambda = repmat(1 / numCosts, [1, 1, numCosts]);
fprintf('%f \n', sum(lambda));

% Quantities that never change across iterations.
A       = exp(-C / eta);
invA    = reshape(1 ./ a, n, 1);
bCol    = reshape(b, n, 1);
rowScal = ones(n, 1);

% One scaling pass before the loop provides the initial objective value.
[rowScal, pi, grad] = pamScalingStep(A, C, lambda, rowScal, invA, bCol, numCosts);

time = zeros(1, numIters + 1);
fval = zeros(1, numIters + 1);
fval(1) = sum(lambda .* grad, 'all');

for iter = 1:numIters

    stopwatch = tic();

    % Refresh the scalings, the plan and the per-slice costs for the
    % current lambda.
    [rowScal, pi, grad] = pamScalingStep(A, C, lambda, rowScal, invA, bCol, numCosts);

    % Ascent step on lambda, followed by projection via SimplexProj.
    prevLambda = lambda;
    stepped    = reshape(prevLambda + stepSize * grad, 1, numCosts);
    lambda     = reshape(SimplexProj(stepped), 1, 1, numCosts);

    time(iter + 1) = time(iter) + toc(stopwatch);

    % The reported value uses the lambda that generated the current plan.
    objective = sum(prevLambda .* grad, 'all');
    fval(iter + 1) = objective;

    fprintf("iter: %d EOT: %.16f\n", iter, objective);

end

end


function [f, plan, grad] = pamScalingStep(A, C, lambda, f, invA, bCol, numCosts)
% One g-then-f scaling update for fixed lambda, plus the resulting plan and
% the per-slice cost totals.
%   A, C     : kernel and cost arrays (third dimension indexes the slices)
%   lambda   : 1-by-1-by-numCosts exponents applied slice-wise to A
%   f        : n-by-1 current row scaling (input and output)
%   invA     : n-by-1 vector 1 ./ a
%   bCol     : n-by-1 vector b
slices = A .^ lambda;
K      = sum(slices, 3);

g = bCol ./ (K' * f);
f = 1 ./ ((invA .* K) * g);

plan = f .* slices .* g.';
grad = reshape(sum(sum(plan .* C, 1), 2), 1, 1, numCosts);
end


function X = SimplexProj(Y)
% SimplexProj  Row-wise projection onto the probability simplex.
%
%   X = SimplexProj(Y) treats every row of the matrix Y as a point and
%   returns, in the matching row of X, max(y - theta, 0), where the shift
%   theta is found with the sort-and-threshold rule below. Each returned
%   row is nonnegative.
[numRows, numCols] = size(Y);

% Sort each row from largest to smallest.
sortedRows = sort(Y, 2, 'descend');

% Candidate shifts: (sum of the j largest entries - 1) * (1 / j).
reciprocals = 1 ./ (1:numCols);
candidates  = (cumsum(sortedRows, 2) - 1) .* reciprocals;

% Per row, count the sorted entries that exceed their candidate shift;
% that count selects which candidate becomes the row's shift.
activeCount = sum(sortedRows > candidates, 2);
pick        = sub2ind([numRows, numCols], (1:numRows)', activeCount);
theta       = candidates(pick);

% Shift every row and clip at zero.
X = max(Y - theta, 0);
end


