% Demo of Figure 4
% Linear witness functions on MNIST
% Get extended MNIST
%url = 'http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/matlab.zip';
%load('emnist-mnist.mat')
% Every section below reads the struct `dataset` from the workspace. Load it
% from emnist-mnist.mat when that file is on the path; otherwise stop here
% with the instruction instead of failing later on an undefined variable.
if ~exist('dataset','var') && exist('emnist-mnist.mat','file') == 2
    load('emnist-mnist.mat')
end
if ~exist('dataset','var')
    error('Need to get, unzip, and load extended MNIST')
end


%% Experiment settings, result buffers and sample selection
N_test = [100 200 500 1000 2000 5000 9000];   % sample sizes that are swept
not_char = '7';                               % character thinned out of the second sample
max_iter = 200;                               % iteration budget handed to max_sliced_w2_adam

% NaN-filled buffers; the loop over sample sizes fills in the slots it computes.
timings_linear = nan(numel(N_test), 6);
divs_linear = nan(numel(N_test), 2, 6);
W = nan(size(dataset.train.images, 2), numel(N_test), 6);

% Row of dataset.mapping whose second column equals not_char gives the label.
not_label = dataset.mapping(dataset.mapping(:,2) == not_char, 1);

rng(10)

% First sample: max(N_test) distinct indices into the training set.
x_sel = randperm(numel(dataset.train.labels), max(N_test));

% Second sample: every test index with another label, plus a random half
% of the test indices that carry not_label.
subset = find(dataset.test.labels ~= not_label);
subset_eq = find(dataset.test.labels == not_label);
y_sel2 = subset_eq(randperm(numel(subset_eq), round(numel(subset_eq)/2)));
y_sel = [subset; y_sel2];

% Shuffle the pooled indices and keep max(N_test) of them.
y_sel = y_sel(randperm(numel(y_sel), max(N_test)));
% Original note: keeping only half of the not_char images lowers that
% class from about 1/10 of the second sample to about 1/20 = 5%.



% Main experiment: for every sample size N, take the first N pre-selected
% train rows (X) and test rows (Y), then store each divergence value in
% divs_linear, its direction vector in W, and its run time in timings_linear.
for m = 1:numel(N_test)
    % Reseed at every sample size so anything random inside this iteration
    % starts from the same generator state.
    rng(10)
    N= N_test(m);
    
    X = double(dataset.train.images(x_sel(1:N),:));
    Y = double(dataset.test.images(y_sel(1:N),:));
    
    % Compute m.s. Bures
    % The timed region covers the two uncentered second-moment matrices
    % (no mean subtraction) and the routine called in both argument orders.
    tic
    rho_X = 1/N*(X'*X);
    rho_Y = 1/N*(Y'*Y);
    [divs_linear(m,1,2),w1] = one_side_max_sliced_bures_path(rho_X,rho_Y);
    [divs_linear(m,2,2),w2] = one_side_max_sliced_bures_path(rho_Y,rho_X);
    t1 = toc;
    timings_linear(m,2)=t1;

    % Compute m.s. Wasserstein 2
    % Only divs_linear(m,1,3) is assigned; divs_linear(m,2,3) keeps its NaN.
    tic
    [divs_linear(m,1,3),w3] = max_sliced_w2_adam(X,Y,max_iter);
    timings_linear(m,3)=toc;    
    % Slices 1-2: directions from the two Bures calls; slice 3: W2 direction.
    W(:,m,1:3) =cat(3,w1,w2,w3);
    
    % Compute m.s. Frechet
    % Same routine as for Bures, but applied to covariances of the
    % mean-centered samples (normalized by N-1).
    tic    
    Xc = X - mean(X);
    Yc = Y - mean(Y);
    
    Sigma_x = 1/(N-1)*(Xc'*Xc);
    Sigma_y = 1/(N-1)*(Yc'*Yc);
    
    [divs_linear(m,1,4),w4] = one_side_max_sliced_bures_path(Sigma_x,Sigma_y);
    [divs_linear(m,2,4),w5] = one_side_max_sliced_bures_path(Sigma_y,Sigma_x);
    
    % Add the squared distance between the sample means under the square
    % root, overwriting the values just stored in slot 4.
    diff_mean = sum((mean(X)-mean(Y)).^2);
    divs_linear(m,1,4)= sqrt(diff_mean + divs_linear(m,1,4)^2);
    divs_linear(m,2,4)= sqrt(diff_mean + divs_linear(m,2,4)^2);
    timings_linear(m,4)=toc;
    % Slices 4-5: directions from the two covariance-based calls.
    W(:,m,4:5) =cat(3,w4,w5);
    
    
    % Compute MMD (linear case)
    % Only the mean difference is timed. No divergence value is stored for
    % slot 6; the direction kept in W is the difference scaled to unit norm.
    tic;
    delta = mean(X)- mean(Y);
    timings_linear(m,6)=toc;
    
    W(:,m,6) = delta/norm(delta);
    
end

%% Figure 1: divergence values and computation times against sample size
figure(1)
clf

% Left panel: divergence values. For slots 4 and 2 the larger of the two
% stored argument orders is drawn; slot 3 has a single stored value.
subplot(1,2,1)
semilogx(N_test, divs_linear(:,1,3), ...
    'x--', 'linewidth',2, 'markersize',10)
hold all
semilogx(N_test, max(divs_linear(:,:,4),[],2), ...
    'd:', 'linewidth',3, 'markersize',10)
semilogx(N_test, max(divs_linear(:,:,2),[],2), ...
    'o-', 'linewidth',3, 'markersize',10)
xlabel('Sample size (n = m)')
ylabel('Divergence')
set(gca, 'fontsize',12)
legend('max-sliced Wasserstein-2', ...
    'max-sliced Frechet', ...
    'max-sliced Bures')

% Right panel: recorded run times, both axes logarithmic.
subplot(1,2,2)
loglog(N_test, timings_linear(:,3), ...
    'x--', 'linewidth',2, 'markersize',10)
hold all
loglog(N_test, timings_linear(:,4), ...
    'd:', 'linewidth',2, 'markersize',10)
loglog(N_test, timings_linear(:,2), ...
    'o-', 'linewidth',2, 'markersize',10)
xlabel('Sample size (n = m)')
set(gca, 'fontsize',12)
ylabel('Computation time (s)')
legend('max-sliced Wasserstein-2', ...
    'max-sliced Frechet', ...
    'max-sliced Bures', ...
    'location','best')

%% Figure 2: all stored direction vectors shown as 28-by-28 image tiles
figure(2)
clf

% Fold each column W(:,m,l) into a 28-by-28 tile, wrap every tile in its own
% cell, and arrange the cells so that rows follow the third dimension of W
% and columns follow the sample sizes.
tiles = reshape(W, 28, 28, numel(N_test), size(W,3));
weights = permute(num2cell(tiles, [1 2]), [4 3 1 2]);

% Stitch the grid of tiles into one matrix and show it in gray scale.
imagesc(cell2mat(weights))
colormap(gray)
axis equal

%% Compute Monte-Carlo trials of precision
% One trial per character in not_chars: that character is thinned out of the
% second sample, and the direction vectors of every method are recomputed
% for each sample size. Results:
%   w_all(:,m,s,l)     direction for sample size m, slot s, trial l
%   timings_all(l,m,s) run time of slot s (only slots 2, 3 and 4 are timed)
% Reads x_sel, N_test and max_iter from the earlier sections.
not_chars = '0123456789';

% Sized from the dataset rather than from a leftover X, so this section
% does not depend on the loop of the first experiment having run.
w_all  = nan(size(dataset.train.images,2),numel(N_test),6,numel(not_chars));

timings_all = zeros(numel(not_chars),numel(N_test),6);
for  l = 1:numel(not_chars)
    not_char = not_chars(l);
    % Same seed for every trial.
    rng(10)
    
    % Second sample: all test indices with another label plus a random half
    % of the indices labelled not_label, shuffled and cut to max(N_test).
    not_label = dataset.mapping(dataset.mapping(:,2)==not_char,1);
    subset = find(dataset.test.labels~=not_label);
    subset_eq = find(dataset.test.labels==not_label);
    y_sel2 = subset_eq(randperm(numel(subset_eq),round(numel(subset_eq)/2)));
    y_sel = cat(1,subset,y_sel2);
    y_sel = y_sel(randperm(numel(y_sel),max(N_test)));
    % half the time it is missing the other half it appears 1/10
    % so in total it is 1/20 =5%
    
    
    for m = 1:numel(N_test)
        N= N_test(m);
        
        X = double(dataset.train.images(x_sel(1:N),:));
        Y = double(dataset.test.images(y_sel(1:N),:));
        
        % Compute m.s. Bures (uncentered second moments, both argument orders)
        tic
        rho_X = 1/N*(X'*X);
        rho_Y = 1/N*(Y'*Y);
        [div1,w1] = one_side_max_sliced_bures_path(rho_X,rho_Y);
        [div2,w2] = one_side_max_sliced_bures_path(rho_Y,rho_X);
        timings_all(l,m,2)=toc;
        
        % Compute m.s. Wasserstein 2
        tic
        [div3,w3] = max_sliced_w2_adam(X,Y,max_iter);
        timings_all(l,m,3)=toc;
        
        w_all(:,m,1:3,l) =cat(3,w1,w2,w3);
        
        % Compute m.s. Frechet (covariances of the mean-centered samples)
        tic
        
        Xc = X - mean(X);
        Yc = Y - mean(Y);
        
        Sigma_x = 1/(N-1)*(Xc'*Xc);
        Sigma_y = 1/(N-1)*(Yc'*Yc);
        
        [div4,w4] = one_side_max_sliced_bures_path(Sigma_x,Sigma_y);
        [div5,w5] = one_side_max_sliced_bures_path(Sigma_y,Sigma_x);
        
        % The divergence values are not stored in this section; these lines
        % are kept so the timed region does the same work as in the first
        % experiment.
        diff_mean = sum((mean(X)-mean(Y)).^2);
        div4= sqrt(diff_mean + div4^2);
        div5= sqrt(diff_mean + div5^2);
        timings_all(l,m,4)=toc;
        w_all(:,m,4:5,l) =cat(3,w4,w5);
        
        
        % Linear MMD direction: mean difference scaled to unit norm (untimed)
        delta = mean(X)- mean(Y);
        w_all(:,m,6,l) = delta/norm(delta);
        
    end
end
%% Precision at k of the rankings induced by each direction vector
% For every trial l, sample size m and method, the first sample X is
% projected onto the stored direction and ranked. p_at_k(l,m,method,1) is the
% fraction of the k highest-ranked samples whose label equals l-1, and
% p_at_k(l,m,method,2) the same fraction for the k lowest-ranked samples.

k = 10;
% indexing corresponds to  3rd dimension of w_all 
methods = [3    3   1   1   6];
sort_by = [1    2   1   2   2]; % 1 = magnitude/abs, 2 = value
%          W2^2 W2  B^2 B   M
p_at_k = nan(10,numel(N_test),numel(methods),2);

% The sample-size loop is outermost because X and L_X depend only on m;
% this builds them once per sample size instead of once per trial.
for m = 1:numel(N_test)
    N = N_test(m);
    L_X = dataset.train.labels(x_sel(1:N));
    X = double(dataset.train.images(x_sel(1:N),:));
    for l = 1: 10
        for method_i = 1:numel(methods)
            w = w_all(:,m,methods(method_i),l);
            % Rank either by magnitude or by signed value of the projection.
            if sort_by(method_i)==1
                f = @(x) abs(x);
            else
                f = @(x) x;
            end
            [~,idx] = sort(f(X*w),'descend');
            % Trial l is compared against label value l-1.
            p_at_k(l,m,method_i,1) = mean(L_X(idx(1:k))==l-1);
            p_at_k(l,m,method_i,2) = mean(L_X(idx(end-k+1:end))==l-1);
        end
    end
end
% Average over trials: numel(N_test)-by-numel(methods)-by-2.
AP_at_k = squeeze(mean(p_at_k,1)); % Average precision at k

%% One figure with subplots
% Three panels against sample size: divergence values, computation times,
% and average precision at k of the witness-function rankings.
figure(100),clf

% Panel 1: divergence values (larger of the two stored argument orders for
% slots 4 and 2; slot 3 has a single stored value).
subplot(1,3,1)
semilogx(N_test,divs_linear(:,1,3),'x--','linewidth',2,'markersize',10)
set(gca,'fontsize',12)
hold all
semilogx(N_test,max(divs_linear(:,:,4),[],2),'d:','linewidth',2,'markersize',10)
semilogx(N_test,max(divs_linear(:,:,2),[],2),'o-','linewidth',2,'markersize',10)
xlabel('Sample size (n = m)')
ylabel('Divergence')
legend('max-sliced Wasserstein-2','max-sliced Frechet','max-sliced Bures')

% Panel 2: recorded run times on log-log axes.
subplot(1,3,2)
loglog(N_test,timings_linear(:,3),'x--','linewidth',2,'markersize',10)
set(gca,'fontsize',12)
hold all
loglog(N_test,timings_linear(:,4),'d:','linewidth',2,'markersize',10)
loglog(N_test,timings_linear(:,2),'o-','linewidth',2,'markersize',10)
xlabel('Sample size (n = m)')
ylabel('Computation time (s)')
legend('max-sliced Wasserstein-2','max-sliced Frechet','max-sliced Bures','location','best')

% Panel 3: average precision among the k highest-ranked samples.
% Second index of AP_at_k is the method column:
%   1 = W2 direction ranked by magnitude, 2 = W2 direction ranked by value,
%   3 = Bures direction ranked by magnitude, 5 = mean difference by value.
subplot(1,3,3)
h1=semilogx(N_test,AP_at_k(:,1,1),'x--','linewidth',2,'markersize',10);
hold all
set(gca,'fontsize',12)
h2=semilogx(N_test,AP_at_k(:,2,1),'kx--','linewidth',2,'markersize',10);
h5=semilogx(N_test,AP_at_k(:,5,1),'d:','linewidth',2,'markersize',10);
h3=semilogx(N_test,AP_at_k(:,3,1),'o-','linewidth',2,'markersize',10);
xlabel('Sample size (n = m)')
ylabel('Average precision at 10')


% Handles are listed so each label matches its curve: h2 ranks by the raw
% projection w'X, h1 by its magnitude, i.e. (w'X)^2.
legend([h2,h1,h5,h3],'w''X | w = max sliced Wasserstein-2','(w''X)^2 | w = max sliced Wasserstein-2',...
    'w''X | w = m_X-m_Y','(w''X)^2 | w= one-sided max sliced Bures',...
    'location','best')