function [time, fval] = PAME(a, b, C, eta, options)
% PAME  Projected, extrapolated gradient ascent on the simplex weights
% lambda (1 x 1 x N) of a kernel K = sum_k exp(-C(:,:,k)/eta).^lambda(k),
% with one alternating Sinkhorn scaling update of (f, g) per iteration.
%
% Inputs
%   a, b    - length-n vectors. a divides the rows of K, and b is the
%             numerator of the column scaling g. Presumably the two
%             marginals of the transport problem -- TODO confirm with caller.
%   C       - n x n x N array of cost matrices.
%   eta     - scalar used to build the kernel A = exp(-C/eta).
%   options - struct with fields:
%               max_iter  number of outer iterations
%               N         number of cost slices (size(C, 3))
%               n         side length of each cost slice
%               theta     extrapolation uses the factor (1 - theta)
%               tau       gradient step size
%
% Outputs
%   time - 1 x (max_iter+1) cumulative seconds spent in the loop, time(1) = 0.
%   fval - 1 x (max_iter+1) values of sum(lambda .* grad). fval(1) is taken
%          at the initial point. The printed label "EOT" suggests an
%          entropic OT objective -- NOTE(review): confirm.

max_iter = options.max_iter;
N = options.N;
n = options.n;
theta = options.theta;
tau = options.tau;

% Initialization: uniform weights on the simplex and a unit row scaling.
f = ones(1, n);
lambda = ones(1, 1, N) / N;
fprintf('%f \n', sum(lambda));
A = exp(-C / eta);

[f, g, Alam] = pame_sinkhorn_step(A, lambda, a, b, f, n);
grad = pame_cost_gradient(f, g, Alam, C, n, N);

fval = zeros(1, max_iter + 1);
time = zeros(1, max_iter + 1);
fval(1, 1) = sum(lambda .* grad, 'all');

% lam_ holds the previous extrapolated point. Initialising it to lambda
% makes the first extrapolation step a no-op.
lam_ = lambda;

for iter = 1:max_iter

    timetic = tic();

    % Sinkhorn update using the kernel built from the current lambda.
    [f, g, Alam] = pame_sinkhorn_step(A, lambda, a, b, f, n);

    % Extrapolate and project back onto the probability simplex.
    lambda = lambda + (1 - theta) * (lambda - lam_);
    lambda = reshape(SimplexProj(reshape(lambda, 1, N)), 1, 1, N);

    % Alam is deliberately NOT recomputed for the extrapolated lambda, which
    % saves one elementwise power per iteration. The gradient below
    % therefore uses the kernel of the pre-extrapolation lambda.
    lam_ = lambda;
    grad = pame_cost_gradient(f, g, Alam, C, n, N);

    % Projected gradient ascent step.
    lambda = lambda + tau * grad;
    lambda = reshape(SimplexProj(reshape(lambda, 1, N)), 1, 1, N);

    time(1, iter + 1) = time(1, iter) + toc(timetic);
    fval(1, iter + 1) = sum(lam_ .* grad, 'all');

    fprintf("iter: %d EOT: %.16f\n", iter, fval(1, iter + 1));

end

end


function [f, g, Alam] = pame_sinkhorn_step(A, lambda, a, b, f, n)
% One alternating Sinkhorn scaling update for the kernel
% K = sum_k A(:,:,k).^lambda(k).
% Returns the new row scaling f (n x 1), the new column scaling g (n x 1)
% and the weighted kernel slices Alam = A.^lambda.
Alam = A.^lambda;
K = sum(Alam, 3);
Kp = reshape(1 ./ a, n, 1) .* K;
g = reshape(b, n, 1) ./ (K' * reshape(f, n, 1));
f = 1 ./ (Kp * g);
end


function grad = pame_cost_gradient(f, g, Alam, C, n, N)
% Per-slice costs grad(k) = sum_ij f(i) * Alam(i,j,k) * g(j) * C(i,j,k),
% returned with shape 1 x 1 x N so it broadcasts against lambda.
plan = reshape(f, n, 1) .* Alam .* reshape(g, 1, n);
grad = reshape(sum(sum(plan .* C, 1), 2), 1, 1, N);
end


function X = SimplexProj(Y)
% SimplexProj  Euclidean projection of every row of Y onto the probability
% simplex {x : x >= 0, sum(x) = 1}. The projection works by sorting each
% row and locating a per-row shift.
[numRows, numCols] = size(Y);
sortedY = sort(Y, 2, 'descend');

% Candidate shifts: (sum of the k largest entries - 1) / k, for k = 1..numCols.
thresholds = (cumsum(sortedY, 2) - 1) .* (1 ./ (1:numCols));

% rho(r) = number of sorted entries in row r lying strictly above their
% candidate shift. The shift at position rho(r) is the one to apply.
rho = sum(sortedY > thresholds, 2);
shift = thresholds(sub2ind([numRows, numCols], (1:numRows)', rho));

X = max(Y - shift, 0);
end


