% Script to test sliced distances/divergences for 2D Gaussians
%
% Population setup: mu = N(0, C1), where C1 = Z1*Z1' is built from a random
% 2x2 matrix whose rows are normalised to unit length, and nu = N(0, C2) with
% C2 = I. For every slice direction u = [cos(theta); sin(theta)], theta in
% [0, pi], this section computes
%   sb_exact - sliced Bures divergence |sqrt(u'*C1*u) - sqrt(u'*C2*u)|
%   sw_exact - sliced Wasserstein-2, approximated by a Riemann sum over
%              interior quantile levels of the squared difference of the two
%              1-D Gaussian quantile functions (norminv)
rng(10);
n_max = 2^16;      % largest sample size used by the finite-sample sections
max_iter = 1000;   % for Adam
d = 2;

Z1 = randn(d,d);
Z1 = Z1./sqrt(sum(Z1.^2,2));   % unit-length rows => diag(C1) == 1
C1 = Z1*Z1';
Z2 = eye(d);
C2 = eye(d);

thetas = linspace(0,pi,2000);

% Interior quantile levels: drop 0 and 1, where norminv is infinite.
q = linspace(0,1,4000);
q = q(2:end-1);
dq = q(2) - q(1);   % uniform quadrature step

n_theta = numel(thetas);
sb_exact = zeros(n_theta,1);
sw_exact = zeros(n_theta,1);
for k = 1:n_theta
    u = [cos(thetas(k)); sin(thetas(k))];
    s_mu = sqrt(u'*C1*u);   % std. dev. of mu projected onto u
    s_nu = sqrt(u'*C2*u);   % std. dev. of nu projected onto u
    sb_exact(k) = abs(s_mu - s_nu);
    quantile_gap = norminv(q,0,s_mu) - norminv(q,0,s_nu);
    sw_exact(k) = sqrt(dq*sum(quantile_gap.^2));
end
% Map a slice vector w to its slice angle in (0, pi]. The vectors w and -w
% define the same slice, so atan2's (-pi, pi] output is folded by adding pi
% to every non-positive angle (an angle of exactly 0 therefore maps to pi).
proj2pos = @(theta) theta + pi*(theta<=0);
vec2angle = @(w) proj2pos(atan2(w(2),w(1)));

% One-sided max-sliced Bures divergences from the true covariances, in both
% directions: msb_pop(1) for (C1,C2) and msb_pop(2) for (C2,C1), together
% with the maximising slice vectors w_star1 / w_star2.
msb_pop = nan(1,2);
[msb_pop(1),w_star1] = one_side_max_sliced_bures_path(C1,C2);
[msb_pop(2),w_star2] = one_side_max_sliced_bures_path(C2,C1);

% Slice angles in (0, pi] of the two maximising directions.
theta_1 = vec2angle(w_star1);
theta_2 = vec2angle(w_star2);

% Figure 1, subplot 1: divergences per slice angle computed from the true
% covariances. Plot order matters: the legend below labels the first five
% line objects in the order they are created.
figure(1),clf
subplot(2,2,1)
plot(thetas,sw_exact,'-','linewidth',5)
hold all
plot(thetas,sb_exact,':','linewidth',2)
% Horizontal line at the larger of the two one-sided values. The legend
% calls it max-sliced Wasserstein-2, presumably because slice-wise W2 and
% Bures coincide for these zero-mean Gaussians (compare sw_exact and
% sb_exact above) -- NOTE(review): confirm this is the intended reading.
plot([0;pi],repmat(max(msb_pop),2,1),'-.','linewidth',2)
% repmat(msb_pop,2,1) is 2x2, so this draws two horizontal lines and h2
% holds two handles: h2(1) for the (C1,C2) value, h2(2) for (C2,C1).
% h2 is also read by the figure-2 section to colour the direction lines.
h2=plot([0;pi],repmat(msb_pop,2,1),'--','linewidth',2);
set(h2(2),'linestyle','-')
% Vertical markers at the maximising slice angles, colour-matched to h2.
plot([theta_1;theta_1],[0,msb_pop(1)],'--','linewidth',2,'color',get(h2(1),'color'))
plot([theta_2;theta_2],[0,msb_pop(2)],'-','linewidth',2,'color',get(h2(2),'color'))


legend('sliced Wasserstein-2','sliced Bures','max-sliced Wasserstein-2',...
    'one-sided max-sliced Bures(X,Y)','one-sided max-sliced Bures(Y,X)',...
    'location','northeast')
title('Divergence between 2D Gaussians with known covariance','fontweight','normal')
xlabel('angle (radians) defining slice')
axis tight
set(gca,'ylim',[0,0.5],'fontsize',12)

% Sample version: draw n_max i.i.d. samples from each Gaussian once. Later
% sections use the leading n rows for each sample size n.
N_tests = [100,1000,10000];
rng(10);   % re-seed so the samples do not depend on earlier random draws
std_normal = randn(n_max,d);
X = std_normal*Z1';   % rows have covariance Z1*Z1' = C1
std_normal = randn(n_max,d);
Y = std_normal*Z2';   % rows have covariance Z2*Z2' = C2
clear std_normal
%

figure(2),clf
% Contours of the two population densities on a 50x50 grid over [-3,3]^2.
tt = linspace(-3,3,50);
[XX,YY] = meshgrid(tt,tt);
grid_pts = [XX(:),YY(:)];
p_X = reshape(mvnpdf(grid_pts,[0 0],C1),size(XX));
p_Y = reshape(mvnpdf(grid_pts,[0 0],C2),size(XX));

contour(tt,tt,p_X,8,'-')
hold all
contour(tt,tt,p_Y,8,':','linewidth',2)
% Unit vectors along the two population maximising slice directions,
% coloured like the matching one-sided lines (h2) in figure 1.
plot([0 cos(theta_1)],[0 sin(theta_1)],'--k','linewidth',3,'color',get(h2(1),'color'))
plot([0 cos(theta_2)],[0 sin(theta_2)],'-b','linewidth',3,'color',get(h2(2),'color'))

colorbar
legend('\mu','\nu')
set(gca,'fontsize',14)
axis equal


%% Finite-sample estimates for each n in N_tests (figure 1, subplots 2-4)
figure(1)

% Per-slice estimates: rows index thetas, columns index N_tests.
sb_est = zeros(numel(thetas),numel(N_tests));   % sliced Bures from Q_hat
sgw_est = zeros(numel(thetas),numel(N_tests));  % Gaussian-quantile SW2 from Q_hat
sw_est = sgw_est;                               % SW2 via moment expansion (fidelity check)
sw_est2 = sgw_est;                              % empirical SW2 from sorted projections

% Maximised estimates: row 1 = (X,Y) direction, row 2 = (Y,X) direction.
msb_est = nan(2,numel(N_tests));    % path-based one-sided max-sliced Bures
msb_est_A = msb_est;                % Adam-based one-sided max-sliced Bures
msb_MF = msb_est;                   % not filled in the visible script

msw_est = nan(1,numel(N_tests));    % Adam-based max-sliced W2
msw_MF = msw_est;                   % not filled in the visible script

% Interior quantile levels for the Gaussian quantile-function integral.
q = linspace(0,1,4000);
q = q(2:end-1);
dq = diff(q(1:2));   % uniform quadrature step, hoisted out of the loop

% Sample size whose maximising slice directions are drawn in figure 3.
n_plot=100;
% NaN defaults let figure 3 degrade gracefully (no direction lines) instead
% of failing with an undefined variable when n_plot is not in N_tests.
theta_w2 = NaN;
theta_b = NaN;
if ~any(N_tests == n_plot)
    warning('n_plot = %d is not in N_tests; figure 3 will show no slice directions.',n_plot);
end

for n_i = 1 : numel(N_tests)
    n = N_tests(n_i);
    X_hat = X(1:n,:);
    Y_hat = Y(1:n,:);
    % Uncentred second-moment matrices (both populations are zero-mean).
    Q1_hat = 1/n*X_hat'*X_hat;
    Q2_hat = 1/n*Y_hat'*Y_hat;

    for t = 1 : numel(thetas)
        theta = thetas(t);
        w = [cos(theta) sin(theta)]';
        sigma1 = sqrt(w'*Q1_hat*w);
        sigma2 = sqrt(w'*Q2_hat*w);
        sb_est(t,n_i) = abs( sigma1 - sigma2 );
        diff_F_inv = norminv(q,0,sigma1) - norminv(q,0,sigma2);
        sgw_est(t,n_i) = sqrt(dq*sum(diff_F_inv.^2));
        % Empirical 1-D W2 pairs the i-th smallest projections of X and Y.
        Xw_sort = sort(X_hat*w);
        Yw_sort = sort(Y_hat*w);
        % Algebraically equal to sw_est2; clamp at 0 so round-off in the
        % cancellation cannot produce a complex square root.
        sw_est(t,n_i) = sqrt(max(sigma1^2+sigma2^2 - 2*Xw_sort'*Yw_sort/n,0));
        sw_est2(t,n_i) = sqrt(mean((Xw_sort-Yw_sort).^2));
    end

    % One-sided max-sliced Bures (both directions) and max-sliced W2, each
    % with its maximising slice vector.
    [msb_est(1,n_i),w_star1] = one_side_max_sliced_bures_path(Q1_hat,Q2_hat);
    [msb_est(2,n_i),w_star2] = one_side_max_sliced_bures_path(Q2_hat,Q1_hat);

    [msb_est_A(1,n_i),w_starA1] = one_side_max_bures_adam(Q1_hat,Q2_hat,max_iter);
    [msb_est_A(2,n_i),w_starA2] = one_side_max_bures_adam(Q2_hat,Q1_hat,max_iter);

    [msw_est(1,n_i),w_starW2] = max_sliced_w2_adam(X_hat,Y_hat,max_iter);

    % Plot order matters: the legend labels the first five line objects.
    subplot(2,2,1+n_i)
    plot(thetas,sw_est2(:,n_i),'-','linewidth',3)
    hold all
    %plot(thetas,sw_est(:,n_i),'-.','linewidth',2) % check for fidelity
    %plot(thetas,sgw_est(:,n_i),'-.','linewidth',2) % check for fidelity

    plot(thetas,sb_est(:,n_i),':','linewidth',2)
    h1=plot([0;pi],repmat(msw_est(1,n_i),2,1),'-.','linewidth',2);
    h2=plot([0;pi],repmat(max(msb_est(:,n_i)),2,1),'m-','linewidth',3);
    h3=plot([0;pi],repmat(max(msb_est_A(:,n_i)),2,1),'c--','linewidth',2);

    % Vertical markers at the maximising slice angle of each estimator; for
    % the Bures estimators the direction with the larger one-sided value wins.
    theta_star = vec2angle(w_starW2);
    plot([theta_star;theta_star],[0,msw_est(1,n_i)],'-.','linewidth',2,'color',get(h1,'color'))
    if msb_est(1,n_i)>msb_est(2,n_i)
        theta_star = vec2angle(w_star1);
    else
        theta_star = vec2angle(w_star2);
    end
    plot([theta_star;theta_star],[0;max(msb_est(:,n_i))],'m-','linewidth',2,'color',get(h2,'color'))
    if  n==n_plot
        theta_w2 = vec2angle(w_starW2);
        theta_b = theta_star;
    end
    if msb_est_A(1,n_i)>msb_est_A(2,n_i)
        theta_star = vec2angle(w_starA1);
    else
        theta_star = vec2angle(w_starA2);
    end
    plot([theta_star;theta_star],[0;max(msb_est_A(:,n_i))],'c--','linewidth',2,'color',get(h3,'color'))

    loc = 'northeast';
    legend('sliced Wasserstein-2','sliced Bures','max-sliced Wasserstein-2 \{ADAM\}',...
        'max-sliced Bures','max-sliced Bures \{ADAM\}',...
        'location',loc)

    title(sprintf('Estimated divergence between 2D Gaussians (m=n=%i)',n),...
        'fontweight','normal')
    xlabel('angle (radians) defining slice')
    axis tight
    set(gca,'ylim',[0,0.5],'fontsize',12)
end

%%
% Figure 3: the first n_plot samples of each distribution, overlaid with the
% maximising slice directions recorded in the sample-size loop when
% n == n_plot (theta_w2 from max_sliced_w2_adam, theta_b from the
% path-based one-sided max-sliced Bures). h1/h2 are the line handles from
% the last subplot drawn in figure 1 and are only used to reuse their colours.

figure(3),clf

plot(X(1:n_plot,1),X(1:n_plot,2),'k+')
hold all
plot(Y(1:n_plot,1),Y(1:n_plot,2),'ko')

% Unit-length segments from the origin along each maximising direction.
plot([0 cos(theta_w2)],[0 sin(theta_w2)],'-.','linewidth',3,'color',get(h1,'color'))
plot([0 cos(theta_b)],[0 sin(theta_b)],'-b','linewidth',3,'color',get(h2,'color'))
legend('\mu','\nu')
set(gca,'fontsize',14)
axis equal
set(gca,'xlim',[-3 3],'ylim',[-3 3])
%% Convergence study over n = 2^7, ..., n_max (plotted in figure 4)
% Derive the largest sample size from n_max so X(1:n,:) can never index past
% the generated samples; with n_max = 2^16 this is exactly 2.^(7:16).
N_tests = 2.^(7:floor(log2(n_max)));

% Per-slice estimates: rows index thetas, columns index N_tests.
sb_est = zeros(numel(thetas),numel(N_tests));
sgw_est = zeros(numel(thetas),numel(N_tests));
sw_est = sgw_est;
sw_est2 = sgw_est;

% Maximised estimates: row 1 = (X,Y) direction, row 2 = (Y,X) direction.
msb_est = nan(2,numel(N_tests));
msb_est_A = nan(2,numel(N_tests));
msb_MF = msb_est;   % not filled in the visible script

msw_est = nan(1,numel(N_tests));
msw_MF = msw_est;   % not filled in the visible script

% Interior quantile levels for the Gaussian quantile-function integral.
q = linspace(0,1,4000);
q = q(2:end-1);
dq = diff(q(1:2));   % uniform quadrature step, hoisted out of the loop

for n_i = 1 : numel(N_tests)
    n = N_tests(n_i);
    X_hat = X(1:n,:);
    Y_hat = Y(1:n,:);
    % Uncentred second-moment matrices (both populations are zero-mean).
    Q1_hat = 1/n*X_hat'*X_hat;
    Q2_hat = 1/n*Y_hat'*Y_hat;

    for t = 1 : numel(thetas)
        theta = thetas(t);
        w = [cos(theta) sin(theta)]';
        sigma1 = sqrt(w'*Q1_hat*w);
        sigma2 = sqrt(w'*Q2_hat*w);
        sb_est(t,n_i) = abs( sigma1 - sigma2 );
        diff_F_inv = norminv(q,0,sigma1) - norminv(q,0,sigma2);
        sgw_est(t,n_i) = sqrt(dq*sum(diff_F_inv.^2));
        Xw_sort = sort(X_hat*w);
        Yw_sort = sort(Y_hat*w);
        % Algebraically equal to sw_est2; clamp at 0 so round-off in the
        % cancellation cannot produce a complex square root.
        sw_est(t,n_i) = sqrt(max(sigma1^2+sigma2^2 - 2*Xw_sort'*Yw_sort/n,0));
        sw_est2(t,n_i) = sqrt(mean((Xw_sort-Yw_sort).^2));
    end
    msb_est(1,n_i) = one_side_max_sliced_bures_path(Q1_hat,Q2_hat);
    msb_est(2,n_i) = one_side_max_sliced_bures_path(Q2_hat,Q1_hat);

    msb_est_A(1,n_i) = one_side_max_bures_adam(Q1_hat,Q2_hat,max_iter);
    msb_est_A(2,n_i) = one_side_max_bures_adam(Q2_hat,Q1_hat,max_iter);

    msw_est(1,n_i) = max_sliced_w2_adam(X_hat,Y_hat,max_iter);
end
%%
% Figure 4: maximised divergence estimates versus sample size n (log x-axis).
% Plot order matters: the legend labels the six semilogx lines in order.

% Grid-search maxima over the slice angles in thetas, one value per n.
grid_sw = max(sw_est2,[],1)';
grid_sb = max(sb_est,[],1)';

figure(4),clf
% Keep the current window position, set width x height to 383 x 274 pixels.
set(gcf,'position',get(gcf,'position').*[1 1 0 0] +  [0 0   383   274])
% Reference line: grid maximum of the sliced Bures divergence computed from
% the true covariances. NOTE(review): max(msb_pop) is the non-grid
% population value and could be used here instead -- confirm intent.
semilogx(N_tests,max(sb_exact)*ones(size(N_tests)),'k-','linewidth',5)

hold all

semilogx(N_tests,grid_sw,'o-','linewidth',3)
semilogx(N_tests,msw_est,'d-','linewidth',2,'markersize',12)
semilogx(N_tests,grid_sb,'v','linewidth',2,'markersize',14)
% Max-sliced Bures = larger of the two one-sided estimates for each n.
semilogx(N_tests,max(msb_est,[],1),'m-','linewidth',3)
semilogx(N_tests,max(msb_est_A,[],1),'c--','linewidth',2)

loc = 'northeast';
    legend('max-sliced Bures/Wasserstein-2 \{True covariance\}',...
    'max sliced Wasserstein-2 \{grid\}',...
    'max-sliced Wasserstein-2 \{ADAM\}',...
    'max sliced Bures \{grid\}',...
    'max-sliced Bures \{Eig.\}',...
    'max-sliced Bures \{ADAM\}',...
    'location',loc)

title('Estimated divergence between 2D Gaussians',...
    'fontweight','normal')
xlabel('m=n')

axis tight
set(gca,'ylim',[.1,.4],'fontsize',12)


