function [div,w,P] = max_sliced_w2_adam(X,Y,max_iter)
% MAX_SLICED_W2_ADAM  Maximize a projected quadratic discrepancy with Adam.
%
%   [div,w,P] = max_sliced_w2_adam(X,Y,max_iter)
%
%   Searches for a direction w that maximizes the ratio
%       ( w'*R*w - 2*(X*w)'*sortOT(X*w,Y*w)*(Y*w) ) / (w'*w)
%   with R = X'*X/m + Y'*Y/n, by running Adam on the negated objective
%   (the gradient comes from the local function W2_grad).
%
%   Inputs
%     X        - m-by-d matrix.
%     Y        - n-by-d matrix (same number of columns as X).
%     max_iter - maximum number of Adam iterations per restart
%                (default 50 when omitted or empty).
%
%   Outputs
%     div - square root of the best objective value over all restarts.
%     w   - unit-norm direction that achieved the best objective.
%     P   - sortOT(X*w, Y*w) evaluated at the returned w.
%
%   NOTE(review): sortOT is defined outside this file; it is presumably the
%   coupling matrix between the two 1-D projections -- confirm.

nMonte = 1; % number of random initializations

if nargin<3 || isempty(max_iter)
    max_iter = 50;
end

m = size(X,1);
n = size(Y,1);
d = size(X,2);
stop_delta = 1e-6; % stop when the largest coordinate update falls below this

verbose  = 1;
assert( d == size(Y,2))

R = 1/m*(X'*X) + 1/n*(Y'*Y);

h = @(w) w'*R*w...
    -2*(w'*X')*sortOT(X*w, Y*w)*(Y*w);

g = @(w) w'*w;

obj = @(x) h(x)/g(x); % to maximize
objs = nan(nMonte,1); % objective reached from each starting point

% ADAM hyper-parameters (identical for every restart)
lr = 1e-3;
beta1 = 0.9;
beta2 = 0.999;
epsilon = 1e-08;

for monte_ii = 1:nMonte
    % Draw a fresh random start for EVERY restart; reusing the previous
    % final iterate would make the restarts dependent on each other.
    w_k = randn(d,1);
    % ensure that it is feasible (unit norm)
    w_k = w_k / sqrt(w_k'*w_k);

    % Evaluate the start so a zero-iteration run still reports a finite
    % objective instead of -Inf.
    obj_val = obj(w_k);

    % Adam moment estimates. Named so they do not clobber the sample
    % count m used to build R above.
    mom1 = 0*w_k;
    mom2 = 0*w_k;

    for iter = 1 : max_iter
        [~,grad] = W2_grad(w_k,R,X,Y);
        mom1 = beta1*mom1 + (1-beta1)*grad;    % running mean of gradient
        mom2 = beta2*mom2 + (1-beta2)*grad.^2; % running mean of squared gradient
        % step size with the bias correction folded in
        alpha_t = lr*sqrt(1-beta2^(iter))/(1-beta1^(iter));
        w_prev = w_k;
        w_k = w_k - alpha_t*mom1./(sqrt(mom2)+epsilon);

        % objective is evaluated at the normalized iterate
        obj_val = obj(w_k/sqrt(w_k'*w_k));

        step_abs = abs(w_k - w_prev);
        stop_crit = max(step_abs);
        if verbose
            fprintf('%2d:  del=%.7f obj=%.7f \n',iter, norm(step_abs),obj_val)
        end
        if stop_crit < stop_delta
            if verbose
                fprintf('Stop: stop_crit=%0.7f < %0.7f \n', stop_crit,stop_delta)
            end
            break;
        end
    end

    objs(monte_ii) = obj_val;
    % keep the direction from the best restart seen so far
    if monte_ii == 1 || objs(monte_ii) > max(objs(1:monte_ii-1))
        w = w_k/sqrt(w_k'*w_k);
        P = sortOT(X*w, Y*w);
    end

end

best_obj = max(objs);
% Round-off can leave the objective marginally below zero (e.g. X == Y),
% which would make sqrt return a complex value. An explicit comparison is
% used (rather than max(best_obj,0)) so that a NaN objective still
% propagates instead of being silently replaced by 0.
if best_obj < 0
    best_obj = 0;
end
div = sqrt(best_obj);
if verbose
    fprintf('Obj= %.7f \n\n', div)
end
end
   

function [f,grad] = W2_grad(w,R,X,Y)
% W2_GRAD  Negated ratio objective and its gradient at direction w.
%
%   With C = R - Q - Q' and Q = X'*sortOT(X*w,Y*w)*Y (the sortOT output is
%   treated as constant when differentiating), returns
%     f    = -(w'*C*w)/(w'*w)
%     grad = gradient of f with respect to w, for use by a minimizer.

plan = sortOT( X*w, Y*w);

% cross term and the symmetric matrix of the quadratic form
crossXY = X'*plan*Y;
C = R - crossXY - crossXY';

numer = w'*C*w; % numerator of the ratio
denom = w'*w;   % squared norm of w

f = -numer/denom;

% quotient-rule gradient of numer/denom, then negated for minimization
ascent_dir = 2/denom*(C*w) - 2*numer/denom^2*(w);
grad = -ascent_dir;
end
    