% Profile the speeds of W distance
% Generate synthetic data: two random 2-D Gaussian point clouds
clear;

% Problem size (points per cloud) and the scalar eta that is later
% forwarded unchanged to PAM and used in the step size options.tau.
n   = 4;
eta = 0.05;

% Mean vectors and covariance matrices of the two 2-D Gaussians.
mean_x = [1; 1];
cov_x  = [5, 1; 1, 5];
mean_y = [2; 2];
cov_y  = [1, -0.2; -0.2, 1];

% Draw n samples per cloud (x first, then y). The sample matrices are
% transposed so that each COLUMN of supp_x / supp_y is one 2-D point.
supp_x = transpose(mvnrnd(mean_x, cov_x, n));
supp_y = transpose(mvnrnd(mean_y, cov_y, n));

% Uniform weights 1/n on both supports.
a = repmat(1 / n, 1, n);
b = a;

% Pairwise cost matrices between the two supports (points passed as rows).
% NOTE(review): 'sqeuclidean' and 'L1' are not metric names accepted by the
% Statistics Toolbox pdist2 (which uses 'squaredeuclidean' / 'cityblock');
% presumably a third-party pdist2 is on the path -- TODO confirm.
C1 = pdist2(supp_x.', supp_y.', 'sqeuclidean');   % squared Euclidean
C2 = pdist2(supp_x.', supp_y.', 'euclidean');     % Euclidean
C3 = pdist2(supp_x.', supp_y.', 'L1') .^ 1.5;     % L1 distance to the power 1.5

% Single-cost run with the squared-Euclidean cost C1.
% The cost is stored as an n-by-n-by-N stack (N = 1 here) before being
% handed to PAM.
N = 1;
C = zeros(n, n, N);
C(:, :, 1) = C1;

max_iter = 200;
clear options;
options.max_iter = max_iter;
options.N = N;
% max(C(:)) is the largest cost entry; same value as max(C, [], 'all')
% but it also runs on MATLAB releases older than R2018b.
options.tau = 5 * eta / max(C(:))^2;
options.n = n;
% PAM is defined outside this file; pi1 is indexed below as pi1(i, j, k),
% presumably the coupling between x-point i and y-point j for cost k
% -- TODO confirm against PAM.
[time1, fval1, pi1] = PAM(a, b, C, eta, options);

%% Plots
colors = ['b', 'r'];
figure;

% Node coordinates: columns 1..n are the x-support, n+1..2n the y-support.
coor = [supp_x, supp_y];

for k = 1:N
    % Connect x-point i and y-point j whenever their coupling exceeds 1e-3.
    linked = pi1(1:n, 1:n, k) > 1e-3;
    A = zeros(2 * n, 2 * n);
    A(1:n, n+1:2*n) = linked;      % edges x_i -> y_j
    A(n+1:2*n, 1:n) = linked';     % mirrored so A is symmetric
    gplot(A, coor', colors(k));
    axis square
    hold on;
end

% Overlay the support points: x in red, y in green.
sz = 100;
scatter(supp_x(1, :), supp_x(2, :), sz, 'red', 'filled');
scatter(supp_y(1, :), supp_y(2, :), sz, 'g', 'filled');
hold off;

title('OT: Square Euclidean Cost');
ax = gca;
ax.TitleFontSizeMultiplier = 2;


% Single-cost run with the Euclidean cost C2.
% The cost is stored as an n-by-n-by-N stack (N = 1 here) before being
% handed to PAM.
N = 1;
C = zeros(n, n, N);
C(:, :, 1) = C2;

max_iter = 200;
clear options;
options.max_iter = max_iter;
options.N = N;
% max(C(:)) is the largest cost entry; same value as max(C, [], 'all')
% but it also runs on MATLAB releases older than R2018b.
options.tau = 5 * eta / max(C(:))^2;
options.n = n;
% PAM is defined outside this file; pi1 is indexed below as pi1(i, j, k),
% presumably the coupling between x-point i and y-point j for cost k
% -- TODO confirm against PAM.
[time1, fval1, pi1] = PAM(a, b, C, eta, options);

%% Plots
colors = ['b', 'r'];
figure;

% Node coordinates: columns 1..n are the x-support, n+1..2n the y-support.
coor = [supp_x, supp_y];

for k = 1:N
    % Connect x-point i and y-point j whenever their coupling exceeds 1e-3.
    linked = pi1(1:n, 1:n, k) > 1e-3;
    A = zeros(2 * n, 2 * n);
    A(1:n, n+1:2*n) = linked;      % edges x_i -> y_j
    A(n+1:2*n, 1:n) = linked';     % mirrored so A is symmetric
    gplot(A, coor', colors(k));
    axis square
    hold on;
end

% Overlay the support points: x in red, y in green.
sz = 100;
scatter(supp_x(1, :), supp_x(2, :), sz, 'red', 'filled');
scatter(supp_y(1, :), supp_y(2, :), sz, 'g', 'filled');
hold off;

title('OT: Euclidean Cost');
ax = gca;
ax.TitleFontSizeMultiplier = 2;

% Single-cost run with C3 (L1 distance raised to the power 1.5).
% The cost is stored as an n-by-n-by-N stack (N = 1 here) before being
% handed to PAM.
N = 1;
C = zeros(n, n, N);
C(:, :, 1) = C3;

max_iter = 200;
clear options;
options.max_iter = max_iter;
options.N = N;
% max(C(:)) is the largest cost entry; same value as max(C, [], 'all')
% but it also runs on MATLAB releases older than R2018b.
options.tau = 5 * eta / max(C(:))^2;
options.n = n;
% PAM is defined outside this file; pi1 is indexed below as pi1(i, j, k),
% presumably the coupling between x-point i and y-point j for cost k
% -- TODO confirm against PAM.
[time1, fval1, pi1] = PAM(a, b, C, eta, options);

%% Plots
colors = ['b', 'r'];
figure;

% Node coordinates: columns 1..n are the x-support, n+1..2n the y-support.
coor = [supp_x, supp_y];

for k = 1:N
    % Connect x-point i and y-point j whenever their coupling exceeds 1e-3.
    linked = pi1(1:n, 1:n, k) > 1e-3;
    A = zeros(2 * n, 2 * n);
    A(1:n, n+1:2*n) = linked;      % edges x_i -> y_j
    A(n+1:2*n, 1:n) = linked';     % mirrored so A is symmetric
    gplot(A, coor', colors(k));
    axis square
    hold on;
end

% Overlay the support points: x in red, y in green.
sz = 100;
scatter(supp_x(1, :), supp_x(2, :), sz, 'red', 'filled');
scatter(supp_y(1, :), supp_y(2, :), sz, 'g', 'filled');
hold off;

title('OT: 1.5 L1 Cost');
ax = gca;
ax.TitleFontSizeMultiplier = 2;


% Joint run with all three costs: C1, C2 and C3 are stacked along the
% third dimension and passed to PAM in a single call (N = 3).
N = 3;
C = zeros(n, n, N);
C(:, :, 1) = C1;
C(:, :, 2) = C2;
C(:, :, 3) = C3;

max_iter = 200;
clear options;
options.max_iter = max_iter;
options.N = N;
% max(C(:)) is the largest entry over all three cost slices; same value as
% max(C, [], 'all') but it also runs on MATLAB releases older than R2018b.
options.tau = 5 * eta / max(C(:))^2;
options.n = n;
% PAM is defined outside this file; pi1 is indexed below as pi1(i, j, k),
% presumably one n-by-n coupling per cost slice k -- TODO confirm.
[time1, fval1, pi1] = PAM(a, b, C, eta, options);


%% Plots
colors = ['b', 'r', 'g'];
% One label per cost slice. This must be a cell array: concatenating the
% names with [...] would merge them into a single character vector.
agents = {'Square Euclidean', 'Euclidean', '1.5 L1'};

% Node coordinates: columns 1..n are the x-support, n+1..2n the y-support.
coor = [supp_x, supp_y];
sz = 100;

% One figure per slice of pi1, drawn in that slice's color.
for k = 1:N
    figure;

    % Connect x-point i and y-point j whenever pi1(i, j, k) exceeds 1e-3.
    linked = pi1(1:n, 1:n, k) > 1e-3;
    A = zeros(2 * n, 2 * n);
    A(1:n, n+1:2*n) = linked;      % edges x_i -> y_j
    A(n+1:2*n, 1:n) = linked';     % mirrored so A is symmetric
    gplot(A, coor', colors(k));
    axis square
    hold on;

    % Overlay the support points: x in red, y in green.
    scatter(supp_x(1, :), supp_x(2, :), sz, 'red', 'filled');
    scatter(supp_y(1, :), supp_y(2, :), sz, 'g', 'filled');
    hold off;

    title(sprintf('EOT: Agent %s', agents{k}));
    ax = gca;
    ax.TitleFontSizeMultiplier = 2;
end







