function [time, fval] = APGA(a, b, C, eta, options)
%APGA  Accelerated projected gradient ascent over (f, g, lambda).
%   [time, fval] = APGA(a, b, C, eta, options) runs options.max_iter
%   iterations of a Nesterov-style accelerated gradient ascent.
%   f and g (n-by-1) take plain gradient steps. lambda (1-by-1-by-N) takes
%   a gradient step followed by a Euclidean projection onto the
%   probability simplex (SimplexProj).
%
%   Each iteration:
%     1. extrapolates  y_k = x_k + beta_k * (x_k - x_{k-1}),
%        with beta_k = max(k-2, 0)/(k+1);
%     2. evaluates every gradient at y_k (see apga_gradients);
%     3. sets x_{k+1} = y_k + tau * grad, then projects lambda.
%
%   Inputs:
%     a, b    - length-n vectors (reshaped to n-by-1).
%     C       - cost array, broadcast against lambda along dimension 3
%               (presumably n-by-n-by-N, or n-by-n shared by all k
%               -- TODO confirm with callers).
%     eta     - positive smoothing parameter.
%     options - struct with fields max_iter, N, n and tau (step size).
%
%   Outputs (both 1-by-(max_iter+1)):
%     time - cumulative wall-clock seconds; time(1) = 0 and time(k+1) is
%            the elapsed time after iteration k.
%     fval - sum_k lambda(k) * <P(:,:,k), C(:,:,k)>, evaluated at the point
%            where iteration k computed its gradient (stored at k+1).
%            fval(1) is the value at the initial point.

max_iter = options.max_iter;
N = options.N;
n = options.n;
tau = options.tau;

% Initialization: f = g = 1 and uniform weights lambda on the simplex.
f = ones(n, 1);
g = ones(n, 1);
lambda = ones(1, 1, N) / N;

% Previous iterates x_{k-1}. They start equal to x_0, so the first
% iteration is a plain projected gradient step.
f_prev = f;
g_prev = g;
lam_prev = lambda;

fprintf('%f \n', sum(lambda));

fval = zeros(1, max_iter + 1);
time = zeros(1, max_iter + 1);

for iter = 1:max_iter

    timetic = tic();

    % Nesterov extrapolation. beta is 0 for the first two iterations,
    % which matches the original (iter - 2)/(iter + 1) schedule.
    beta = max(iter - 2, 0) / (iter + 1);
    f_y = f + beta * (f - f_prev);
    g_y = g + beta * (g - g_prev);
    lam_y = lambda + beta * (lambda - lam_prev);

    % Evaluate all gradients at the same extrapolated point (f_y, g_y, lam_y).
    [grad_f, grad_g, grad_lam] = apga_gradients(a, b, C, f_y, g_y, lam_y, eta, n, N);

    % Store x_k before it is overwritten, so the next momentum term is
    % x_{k+1} - x_k (not x_{k+1} - y_k).
    f_prev = f;
    g_prev = g;
    lam_prev = lambda;

    % Ascent step from y_k; only lambda is projected onto the simplex.
    f = f_y + tau * grad_f;
    g = g_y + tau * grad_g;
    lambda = reshape(SimplexProj(reshape(lam_y + tau * grad_lam, 1, N)), 1, 1, N);

    obj = sum(lam_y .* grad_lam, 'all');
    time(1, iter + 1) = time(1, iter) + toc(timetic);
    fval(1, iter + 1) = obj;
    if iter == 1
        % Iteration 1 evaluates at the initial point, so this is also
        % the value at time 0.
        fval(1, 1) = obj;
    end

    fprintf("iter: %d EOT: %.16f\n", iter, obj);

end

end


function [grad_f, grad_g, grad_lam] = apga_gradients(a, b, C, f, g, lambda, eta, n, N)
%APGA_GRADIENTS  Gradients with respect to f, g and lambda at one point.
%   Builds the normalized kernel
%       P(i,j,k) proportional to exp((f(i) + g(j) - lambda(k)*C(i,j,k)) / eta),
%   with sum(P, 'all') == 1. This equals the product
%   exp(f/eta) .* exp(-C/eta).^lambda .* exp(g/eta)', but it is formed in
%   the log domain with a max-shift so that small eta cannot overflow.
logP = (reshape(f, n, 1) + reshape(g, 1, n) - lambda .* C) / eta;
P = exp(logP - max(logP, [], 'all'));
P = P / sum(P, 'all');

P2 = sum(P, 3);                           % marginalize over k
grad_f = reshape(a, n, 1) - sum(P2, 2);   % a minus row sums
grad_g = reshape(b, n, 1) - sum(P2, 1)';  % b minus column sums
grad_lam = reshape(sum(sum(P .* C, 1), 2), 1, 1, N);
end


function X = SimplexProj(Y)
%SIMPLEXPROJ  Row-wise Euclidean projection onto the probability simplex.
%   X = SimplexProj(Y) returns an array the same size as Y. Each row X(i,:)
%   is the closest point to Y(i,:) with X(i,:) >= 0 and sum(X(i,:)) == 1.
%   The method sorts each row in descending order and takes the largest
%   prefix whose entries stay above the running threshold
%   (cumsum - 1) / k. It then shifts the row by that threshold and clips
%   at zero.
[nRows, nCols] = size(Y);
sortedY = sort(Y, 2, 'descend');

% Candidate thresholds: (sum of the k largest entries - 1) * (1/k).
invK = 1 ./ (1:nCols);
thresholds = (cumsum(sortedY, 2) - 1) .* invK;

% rho(i): number of sorted entries in row i above their candidate threshold.
rho = sum(sortedY > thresholds, 2);

% Pick thresholds(i, rho(i)) for every row using column-major linear indexing.
theta = thresholds((rho - 1) * nRows + (1:nRows)');

X = max(Y - theta, 0);
end


