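% Script to test sliced distances/divergences for zero-mean Gaussians.
% Runs Monte Carlo trials for d = 2, comparing optimizer-based max-sliced
% estimates against a brute-force grid search over slice angles.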
% ---- Experiment configuration ----
n_monte = 10;               % number of Monte Carlo trials
d = 2;                      % dimension (the slice grid below assumes d = 2)

log2_max = 16;
n_max = 2^log2_max;         % largest sample size drawn per trial
N_tests = 2.^(7:log2_max);  % sample sizes evaluated: 2^7, ..., 2^log2_max
max_iter = 1000; % for Adam

% Result arrays: rows = Monte Carlo trials, columns = sample sizes.
sb_est = zeros(n_monte,numel(N_tests));   % grid-search reference for max-sliced Bures
sw_est = sb_est;                          % grid-search reference for max-sliced W2
% Third dimension stores the two one-sided problems: (Q1,Q2) in slice 1,
% (Q2,Q1) in slice 2.
msb_est = nan(n_monte,numel(N_tests),2);
msb_est_A = nan(n_monte,numel(N_tests),2);
msw_est = nan(n_monte,numel(N_tests));

thetas = linspace(0,pi,2000); % grid of slice angles in 2D

for monte = 1:n_monte
    rng(monte); % per-trial seed so each trial is reproducible

    % Random mixing matrix with unit-norm rows, so Z1*Z1' has a unit diagonal.
    Z1 = randn(d,d);
    Z1 = Z1./sqrt(sum(Z1.^2,2));
    C1 = Z1*Z1'; % covariance of the rows of X (kept for reference; not used below)

    X = randn(n_max,d)*Z1'; % rows have covariance C1
    Y = randn(n_max,d);     % rows are standard normal

    for n_i = 1 : numel(N_tests)
        n = N_tests(n_i);
        % Nested prefixes of the same draw are used for increasing n.
        X_hat = X(1:n,:);
        Y_hat = Y(1:n,:);
        % Empirical second-moment matrices (zero-mean model: no centering).
        Q1_hat = 1/n*X_hat'*X_hat;
        Q2_hat = 1/n*Y_hat'*Y_hat;

        % Brute-force search over slice directions.
        sb = nan(numel(thetas),1);
        sw = nan(numel(thetas),1);
        for t = 1 : numel(thetas)
            theta = thetas(t);
            w = [cos(theta) sin(theta)]';  % unit slice direction
            sigma1 = sqrt(w'*Q1_hat*w);    % RMS of the projection X_hat*w
            sigma2 = sqrt(w'*Q2_hat*w);    % RMS of the projection Y_hat*w
            sb(t) = abs( sigma1 - sigma2);
            % 1D empirical W2: pair up the sorted projected samples.
            Xw_sort = sort(X_hat*w);
            Yw_sort = sort(Y_hat*w);
            sw(t) = sqrt(mean((Xw_sort-Yw_sort).^2));
        end
        sb_est(monte,n_i) = max(sb);
        sw_est(monte,n_i) = max(sw);

        % Path-based solver, both one-sided directions.
        msb_est(monte,n_i,1) = one_side_max_sliced_bures_path(Q1_hat,Q2_hat);
        msb_est(monte,n_i,2) = one_side_max_sliced_bures_path(Q2_hat,Q1_hat);

        % Adam-based solver, both one-sided directions. The (Q1,Q2) result
        % belongs in slice 1; previously it was written to slice 2 and then
        % overwritten, leaving slice 1 as NaN.
        msb_est_A(monte,n_i,1) = one_side_max_bures_adam(Q1_hat,Q2_hat,max_iter);
        msb_est_A(monte,n_i,2) = one_side_max_bures_adam(Q2_hat,Q1_hat,max_iter);

        msw_est(monte,n_i) = max_sliced_w2_adam(X_hat,Y_hat,max_iter);
    end
end
%% Check accuracy
% For each sample size, compute the fraction of Monte Carlo trials in which
% an optimizer's estimate lies within pct_tol percent of the grid-search
% reference (sb_est / sw_est). Note the "exact" reference is itself the
% maximum over a finite grid of slice angles.

pct_tol = 1;%tolerance in percent


% max(...,[],3) combines the two one-sided estimates; max omits NaN by default.
tol_rate_msb_path = mean(abs(sb_est - max(msb_est,[],3)) < sb_est*pct_tol/100,1);
tol_rate_msb_adam = mean(abs(sb_est - max(msb_est_A,[],3)) < sb_est*pct_tol/100,1);
tol_rate_msw = mean(abs(sw_est - msw_est) < sw_est*pct_tol/100,1);


figure(3),clf
% Keep the current window location, set a fixed width/height.
set(gcf,'position',get(gcf,'position').*[1 1 0 0] +  [0 0   383   274])
semilogx(N_tests,tol_rate_msw,'o--','linewidth',3)
hold all
semilogx(N_tests,tol_rate_msb_path,'o-','linewidth',3)
semilogx(N_tests,tol_rate_msb_adam,'d-','linewidth',2,'markersize',12)

loc = 'northeast';
legend('max-sliced Wasserstein-2 \{ADAM\}',...
    'max-sliced Bures',...
    'max-sliced Bures \{ADAM\}',...
    'location',loc)
% Build the label from pct_tol so it stays correct if the tolerance changes.
ylabel(sprintf('Success rate (w/in %g%% of exact)',pct_tol))
xlabel('log_2(sample size)')
set(gca,'xtick',N_tests,'xticklabel',log2(N_tests))
axis tight
set(gca,'ylim',[-0.05,1.05],'fontsize',12)

%%
% Mean squared error across Monte Carlo trials (one value per sample size)
% of each optimizer's estimate against the grid-search reference value.
sq_err_mean = @(ref,est) mean(abs(ref - est).^2,1);
mse_msb_path = sq_err_mean(sb_est, max(msb_est,[],3));
mse_msb_adam = sq_err_mean(sb_est, max(msb_est_A,[],3));
mse_msw = sq_err_mean(sw_est, msw_est);

figure(4),clf
fig_pos = get(gcf,'position');
% Same screen location, fixed size.
set(gcf,'position',fig_pos.*[1 1 0 0] + [0 0 383 274])
loglog(N_tests,mse_msw,'o--','linewidth',3)
hold all
loglog(N_tests,mse_msb_path,'o-','linewidth',3)
loglog(N_tests,mse_msb_adam,'d-','linewidth',2,'markersize',12)

loc = 'east';
legend_labels = {'max-sliced Wasserstein-2 \{ADAM\}', ...
                 'max-sliced Bures', ...
                 'max-sliced Bures \{ADAM\}'};
legend(legend_labels{:},'location',loc)
ylabel('Mean squared error')
xlabel('log_2(sample size)')
set(gca,'xtick',N_tests,'xticklabel',log2(N_tests))
axis tight
set(gca,'fontsize',12)
