% Script to test max-sliced distances/divergences for zero-mean Gaussians.
% Runs Monte Carlo trials for d=2 and d=1000,
% repeating both experiments for each iteration budget listed in max_iters.


% Experiment-wide settings.
n_monte  = 10;          % number of Monte Carlo trials per configuration
log2_max = 16;          % log2 of the number of rows drawn per trial
n_max    = 2^log2_max;  % number of rows drawn per trial

% Iteration budgets passed to the Adam-based estimators, one outer pass each.
max_iters = [50, 100, 200, 500, 1000];

for n_iter = 1:numel(max_iters)
max_iter = max_iters(n_iter);
%% 2 dimensional case
% In 2-D every unit slice can be written w = [cos(theta); sin(theta)], so a
% dense grid over theta yields reference values (sb_est, sw_est) against
% which the optimizer-based estimates are compared later in the script.
d = 2;
N_tests = 2.^(7:16); % sample sizes to evaluate

% Grid-search reference values: rows index trials, columns index sample sizes.
sb_est = zeros(n_monte,numel(N_tests));
sw_est = sb_est;
% One-sided estimates are stored along the third dimension:
% page 1 is called with (Q1_hat,Q2_hat), page 2 with (Q2_hat,Q1_hat).
msb_est = nan(n_monte,numel(N_tests),2);   % from one_side_max_sliced_bures_path
msb_est_A = nan(n_monte,numel(N_tests),2); % from one_side_max_bures_adam
msw_est = nan(n_monte,numel(N_tests));     % from max_sliced_w2_adam

thetas = linspace(0,pi,2000); % grid of slice angles in 2D

for monte = 1:n_monte
    rng(monte); % seed with the trial index so every trial is reproducible
    Z1 = randn(d,d);
    Z1 = Z1./sqrt(sum(Z1.^2,2)); % scale each row of Z1 to unit Euclidean norm
    C1 = Z1*Z1'; % population covariance of the rows of X (not used below)

    % Draw n_max rows once; each sample size reuses the leading n rows.
    X = randn(n_max,d)*Z1';
    Y = randn(n_max,d);

    for n_i = 1 : numel(N_tests)
        n = N_tests(n_i);
        X_hat = X(1:n,:);
        Y_hat = Y(1:n,:);
        % Uncentered second-moment matrices of the first n samples.
        Q1_hat = 1/n*X_hat'*X_hat;
        Q2_hat = 1/n*Y_hat'*Y_hat;

        % Grid search over slice directions.
        sb = nan(numel(thetas),1);
        sw = nan(numel(thetas),1);
        for t = 1 : numel(thetas)
            theta = thetas(t);
            w = [cos(theta) sin(theta)]';
            % Absolute difference of the projected (uncentered) standard
            % deviations along w.
            sigma1 = sqrt(w'*Q1_hat*w);
            sigma2 = sqrt(w'*Q2_hat*w);
            sb(t) = abs( sigma1 - sigma2);
            % Root-mean-square difference between the sorted projections
            % of the two equal-size samples.
            Xw_sort = sort(X_hat*w);
            Yw_sort = sort(Y_hat*w);
            sw(t) = sqrt(mean((Xw_sort-Yw_sort).^2));
        end
        sb_est(monte,n_i) = max(sb);
        sw_est(monte,n_i) = max(sw);

        msb_est(monte,n_i,1) = one_side_max_sliced_bures_path(Q1_hat,Q2_hat);
        msb_est(monte,n_i,2) = one_side_max_sliced_bures_path(Q2_hat,Q1_hat);

        % Both argument orders must land on their own page; later code takes
        % max(...,[],3) over the two pages.
        msb_est_A(monte,n_i,1) = one_side_max_bures_adam(Q1_hat,Q2_hat,max_iter);
        msb_est_A(monte,n_i,2) = one_side_max_bures_adam(Q2_hat,Q1_hat,max_iter);

        msw_est(monte,n_i) = max_sliced_w2_adam(X_hat,Y_hat,max_iter);
    end
end
%% Check accuracy
% For each sample size, report the fraction of Monte Carlo trials whose
% estimate falls within pct_tol percent of the grid-search value.

pct_tol = 1; % tolerance in percent

% The two one-sided Bures values are combined by taking the larger one
% (maximum along the third dimension) before comparing with sb_est.
tol_rate_msb_path = mean( ...
    abs(sb_est - max(msb_est,[],3)) < sb_est*pct_tol/100, 1);
tol_rate_msb_adam = mean( ...
    abs(sb_est - max(msb_est_A,[],3)) < sb_est*pct_tol/100, 1);
tol_rate_msw = mean( ...
    abs(sw_est - msw_est) < sw_est*pct_tol/100, 1);

% One figure per outer iteration (numbered 101, 102, ...).
figure(100+n_iter); clf
fig_pos = get(gcf,'position');
set(gcf,'position',[fig_pos(1:2) 383 274]) % keep the origin, force the size

semilogx(N_tests,tol_rate_msw,'o--','linewidth',3)
hold on
semilogx(N_tests,tol_rate_msb_path,'o-','linewidth',3)
semilogx(N_tests,tol_rate_msb_adam,'d-','linewidth',2,'markersize',12)

loc = 'northeast';
legend({'max-sliced Wasserstein-2 \{ADAM\}', ...
        'max-sliced Bures', ...
        'max-sliced Bures \{ADAM\}'}, ...
       'location',loc)

xlabel('log_2(sample size)')
ylabel('Success rate (w/in 1% of exact)')

% Ticks sit at the tested sample sizes and are labelled with their log2.
set(gca,'xtick',N_tests,'xticklabel',log2(N_tests))
axis tight
set(gca,'ylim',[-0.05,1.05],'fontsize',12)

%% Mean squared error against the grid-search values
% Squared deviation of each estimate from its grid-search counterpart,
% averaged over the Monte Carlo trials (first dimension).
mse_msb_path = mean((sb_est - max(msb_est,[],3)).^2, 1);
mse_msb_adam = mean((sb_est - max(msb_est_A,[],3)).^2, 1);
mse_msw      = mean((sw_est - msw_est).^2, 1);

% One figure per outer iteration (numbered 201, 202, ...).
figure(200+n_iter); clf
fig_pos = get(gcf,'position');
set(gcf,'position',[fig_pos(1:2) 383 274]) % keep the origin, force the size

loglog(N_tests,mse_msw,'o--','linewidth',3)
hold on
loglog(N_tests,mse_msb_path,'o-','linewidth',3)
loglog(N_tests,mse_msb_adam,'d-','linewidth',2,'markersize',12)

loc = 'east';
legend({'max-sliced Wasserstein-2 \{ADAM\}', ...
        'max-sliced Bures', ...
        'max-sliced Bures \{ADAM\}'}, ...
       'location',loc)

xlabel('log_2(sample size)')
ylabel('Mean squared error')

% Ticks sit at the tested sample sizes and are labelled with their log2.
set(gca,'xtick',N_tests,'xticklabel',log2(N_tests))
axis tight
set(gca,'fontsize',12)

%% 1000 dimensional case
% No grid search is run in this section; only the optimizer-based estimates
% are collected for each trial and sample size.

log2_max = 12;
N_tests = 2.^(6:2:log2_max);
d = 1000;
% Only the leading max(N_tests) rows are ever read, so draw exactly that
% many instead of n_max rows. n_max itself is left untouched because the
% 2-D section relies on it on the next pass of the outer loop.
% NOTE: for a given seed the realisations differ from those obtained by
% drawing n_max rows, since randn fills the matrix column by column.
n_draw = max(N_tests);

sb_est = zeros(n_monte,numel(N_tests));
sw_est = sb_est;
% One-sided estimates are stored along the third dimension:
% page 1 is called with (Q1_hat,Q2_hat), page 2 with (Q2_hat,Q1_hat).
msb_est = nan(n_monte,numel(N_tests),2);
msb_est_A = nan(n_monte,numel(N_tests),2);
msw_est = nan(n_monte,numel(N_tests));

for monte = 1:n_monte
    rng(monte); % seed with the trial index so every trial is reproducible
    Z1 = randn(d,d);
    Z1 = Z1./sqrt(sum(Z1.^2,2)); % scale each row of Z1 to unit Euclidean norm
    C1 = Z1*Z1'; % population covariance of the rows of X (not used below)

    X = randn(n_draw,d)*Z1';
    Y = randn(n_draw,d);

    for n_i = 1 : numel(N_tests)
        n = N_tests(n_i);
        X_hat = X(1:n,:);
        Y_hat = Y(1:n,:);
        % Uncentered second-moment matrices of the first n samples.
        Q1_hat = 1/n*X_hat'*X_hat;
        Q2_hat = 1/n*Y_hat'*Y_hat;

        msb_est(monte,n_i,1) = one_side_max_sliced_bures_path(Q1_hat,Q2_hat);
        msb_est(monte,n_i,2) = one_side_max_sliced_bures_path(Q2_hat,Q1_hat);

        msb_est_A(monte,n_i,1) = one_side_max_bures_adam(Q1_hat,Q2_hat,max_iter);
        msb_est_A(monte,n_i,2) = one_side_max_bures_adam(Q2_hat,Q1_hat,max_iter);

        msw_est(monte,n_i) = max_sliced_w2_adam(X_hat,Y_hat,max_iter);
    end
end
%% Plot the raw estimates against sample size
% Each estimate matrix has one row per Monte Carlo trial, so every call
% below returns an array of line handles; only the first handle of each
% group is passed to legend so that each estimator gets a single entry.

colors = lines(8);
figure(300+n_iter); clf

h1 = semilogx(N_tests,msw_est,'-.','linewidth',2,'color',colors(3,:));
hold on
% Two-sided Bures values: larger of the two one-sided results.
h2 = semilogx(N_tests,max(msb_est,[],3),'m-','linewidth',2);
h3 = semilogx(N_tests,max(msb_est_A,[],3),'c--','linewidth',2);

loc = 'northeast';
legend([h1(1) h2(1) h3(1)], ...
       {'max-sliced Wasserstein-2 \{ADAM\} ', ...
        'max-sliced Bures \{Eig.\}', ...
        'max-sliced Bures \{ADAM\}'}, ...
       'location',loc)

xlabel('log_2(sample size)')
ylabel('Divergence')

% Ticks sit at the tested sample sizes and are labelled with their log2.
set(gca,'xtick',N_tests,'xticklabel',log2(N_tests))
axis tight
set(gca,'fontsize',12)



%% Success rates across Monte Carlo trials
% Every rate is averaged explicitly over the first (trial) dimension so the
% result stays one value per sample size even when n_monte is 1.

% How often the max-sliced Wasserstein-2 estimate exceeds the two-sided
% max-sliced Bures estimate.
bound_rate_msw = mean(...
   msw_est> max(msb_est,[],3), 1);

pct_tol = 5; %tolerance in percentage (the legend text below hard-codes 5%/95%)

% Same comparison with the Bures estimate relaxed by pct_tol percent.
tol_rate_msw = mean(...
   msw_est> (1-pct_tol/100)*max(msb_est,[],3), 1);

% How often the Adam-based Bures estimate is within pct_tol percent of the
% one returned by one_side_max_sliced_bures_path.
tol_rate = mean(...
   abs(max(msb_est,[],3) -max(msb_est_A,[],3))  < pct_tol/100*max(msb_est,[],3), 1);

figure(400+n_iter),clf
semilogx(N_tests,bound_rate_msw,'*-','linewidth',2)
hold on
semilogx(N_tests,tol_rate_msw,'o--','linewidth',2)
semilogx(N_tests,tol_rate,'d-','linewidth',2,'markersize',12)

loc = 'best';
legend(' max-sliced Wasserstein-2 \{ADAM\} > max-sliced Bures',...
    ' max-sliced Wasserstein-2 \{ADAM\} > 95%  max-sliced Bures',...
    '| max-sliced Bures \{ADAM\} - MSB | < 5% max-sliced Bures',...
    'location',loc)

ylabel('Success rate')
xlabel('log_2(sample size)')
% Ticks sit at the tested sample sizes and are labelled with their log2.
set(gca,'xtick',N_tests,'xticklabel',log2(N_tests))
axis tight
set(gca,'ylim',[-0.05,1.05],'fontsize',12)

end