% Script to test sliced distances/divergences for zero-mean Gaussians.
% Runs n_monte Monte Carlo trials for the d = 1000 dimensional case.
% Each trial draws samples X with rows ~ N(0, C1) and Y with rows ~ N(0, I).
% For every sample size n in N_tests, it stores:
%   msb_est   - one_side_max_sliced_bures_path on the sample second-moment
%               matrices; page 1 is (Q1_hat,Q2_hat), page 2 is (Q2_hat,Q1_hat)
%   msb_est_A - one_side_max_bures_adam on the same pairs (max_iter steps),
%               same layout as msb_est
%   msw_est   - max_sliced_w2_adam on the raw samples (max_iter steps)
% All arrays are n_monte x numel(N_tests) [x 2].
% The three helper functions are defined outside this script.
n_monte = 10;
d = 1000;

log2_max = 12;
% Rows drawn per Gaussian block in each trial. Only the first max(N_tests)
% rows are used. n_max only fixes how much of the RNG stream each draw
% consumes, so it is kept at 2^16 to reproduce the original samples.
n_max = 2^16;
N_tests = 2.^(7:log2_max);
n_keep = max(N_tests);
assert(n_keep <= n_max, ...
    'max(N_tests) = %d exceeds the number of drawn rows n_max = %d', n_keep, n_max);
max_iter = 1000; % for Adam
msb_est = nan(n_monte,numel(N_tests),2);
msb_est_A = nan(n_monte,numel(N_tests),2);
msw_est = nan(n_monte,numel(N_tests));

for monte = 1:n_monte
    rng(monte); % per-trial seed so each trial is reproducible on its own

    % Random mixing matrix with unit-norm rows, so C1 = Z1*Z1' has unit diagonal.
    Z1 = randn(d,d);
    Z1 = Z1./sqrt(sum(Z1.^2,2));
    C1 = Z1*Z1'; % population covariance of the rows of X (not used below)

    % Draw full n_max x d blocks so the RNG stream is unchanged, but mix and
    % keep only the n_keep rows that are actually used. This avoids an
    % n_max x d by d x d product per trial.
    G = randn(n_max,d);
    X = G(1:n_keep,:)*Z1';   % rows ~ N(0, C1)
    clear G
    Y = randn(n_max,d);
    Y = Y(1:n_keep,:);       % rows ~ N(0, I)

    for n_i = 1:numel(N_tests)
        n = N_tests(n_i);
        X_hat = X(1:n,:);
        Y_hat = Y(1:n,:);
        % Sample second-moment matrices (the means are known to be zero).
        % Form X_hat'*X_hat before scaling: 1/n*X'*X parses as (1/n*X')*X,
        % a general product that can carry round-off asymmetry.
        Q1_hat = (X_hat'*X_hat)/n;
        Q2_hat = (Y_hat'*Y_hat)/n;

        % Path-based estimate, both argument orders.
        msb_est(monte,n_i,1) = one_side_max_sliced_bures_path(Q1_hat,Q2_hat);
        msb_est(monte,n_i,2) = one_side_max_sliced_bures_path(Q2_hat,Q1_hat);

        % ADAM-based estimate, both argument orders.
        msb_est_A(monte,n_i,1) = one_side_max_bures_adam(Q1_hat,Q2_hat,max_iter);
        msb_est_A(monte,n_i,2) = one_side_max_bures_adam(Q2_hat,Q1_hat,max_iter);

        % ADAM-based max-sliced W2 on the samples themselves.
        msw_est(monte,n_i) = max_sliced_w2_adam(X_hat,Y_hat,max_iter);
    end
end
%%
% Figure 1: for each sample size, the fraction of Monte Carlo trials in which
%   - the ADAM max-sliced W2 estimate exceeds the path-based max-sliced
%     Bures estimate (larger of the two one-sided values),
%   - it exceeds (100 - pct_tol_msw)% of that estimate,
%   - the ADAM max-sliced Bures estimate is within pct_tol_adam% of it.
msb_max = max(msb_est,[],3);      % path-based estimate, larger of both orders
msb_max_A = max(msb_est_A,[],3);  % ADAM-based estimate, larger of both orders

pct_tol_msw = 5;   % tolerance in percentage, W2-vs-Bures comparison
pct_tol_adam = 1;  % tolerance in percentage, ADAM-vs-path Bures comparison

% Average over trials (dim 1) explicitly so the rates stay per sample size
% even when n_monte == 1.
bound_rate_msw = mean(msw_est > msb_max, 1);
tol_rate_msw = mean(msw_est > (1-pct_tol_msw/100)*msb_max, 1);
tol_rate = mean(abs(msb_max - msb_max_A) < pct_tol_adam/100*msb_max, 1);

figure(1),clf
semilogx(N_tests,bound_rate_msw,'*-','linewidth',2)
hold on
semilogx(N_tests,tol_rate_msw,'o--','linewidth',2)
semilogx(N_tests,tol_rate,'d-','linewidth',2,'markersize',12)

% Legend percentages are derived from the tolerances so they cannot drift.
loc = 'southeast';
legend(' max-sliced Wasserstein-2 \{ADAM\} > max-sliced Bures',...
    [' max-sliced Wasserstein-2 \{ADAM\} > ', num2str(100-pct_tol_msw), '%  max-sliced Bures'],...
    ['| max-sliced Bures \{ADAM\} - MSB | < ', num2str(pct_tol_adam), '% max-sliced Bures'],...
    'location',loc)

ylabel('Success rate')
xlabel('log_2(sample size)')
% One tick per tested sample size, labelled by its base-2 exponent.
set(gca,'xtick',N_tests,'xticklabel',log2(N_tests))
axis tight
set(gca,'ylim',[-0.05,1.05],'fontsize',12)

%%
% Figure 2: mean (over Monte Carlo trials) squared difference between each
% ADAM-based estimate and the path-based max-sliced Bures estimate. Where
% two one-sided values exist, the larger one is used.
msb_ref = max(msb_est,[],3);
mse_msb_adam = mean(abs(max(msb_est_A,[],3) - msb_ref).^2, 1);
mse_msw = mean(abs(msb_ref - msw_est).^2, 1);

figure(2), clf
% Keep the window where it is; only change its width and height.
fig_pos = get(gcf,'position');
set(gcf,'position',[fig_pos(1:2), 383, 274])

loglog(N_tests,mse_msw,'o-','linewidth',3)
hold on
loglog(N_tests,mse_msb_adam,'d-','linewidth',2,'markersize',12)

loc = 'east';
legend({'max-sliced Wasserstein-2 \{ADAM\}', 'max-sliced Bures \{ADAM\}'}, ...
    'location',loc)
ylabel('Mean squared difference from max-sliced Bures')
xlabel('log_2(sample size)')
% One tick per tested sample size, labelled by its base-2 exponent.
set(gca,'xtick',N_tests,'xticklabel',log2(N_tests))
axis tight
set(gca,'fontsize',12)
