% Start from a clean session: drop all variables/functions from memory,
% close every figure window and clear the command window.
clear all; close all; clc;

%============
% Single run: build one random instance and evaluate OrliczEPT on it.

% Make the random draws below reproducible by resetting the generator to
% its startup settings.
rng default
% Project helper defined elsewhere -- presumably puts the toolbox folders
% (sqdistance, OrliczEPT, ...) on the MATLAB path; confirm.
addpath_toolbox

% ---- Problem size --------------------------------------------------------
Nx  = 20;   % number of support points of mu
Ny  = 15;   % number of support points of nu
dim = 5;    % ambient dimension of the support points

% ---- Random supports, one point per column ------------------------------
% NOTE: the order of the rand calls (x, y, wx, wy) fixes the random stream
% and must not be changed.
x = rand(dim, Nx);
y = rand(dim, Ny);

% ---- Random positive weights rescaled so sum(wx) = Nx and sum(wy) = Ny ----
wx = rand(Nx, 1);
wx = Nx*(wx/sum(wx));
wy = rand(Ny, 1);
wy = Ny*(wy/sum(wy));

% ---- Unbalanced input measures ------------------------------------------
mu = wx;
nu = wy;

% ---- Balanced extension --------------------------------------------------
% Each measure gets one extra atom carrying the total mass of the other
% one, so both extended vectors have mass sum(mu) + sum(nu); dividing by
% that common total makes each of them sum to one.
norm_term = sum(mu) + sum(nu);
mu_hat = [mu; sum(nu)] / norm_term;
nu_hat = [nu; sum(mu)] / norm_term;

% ---- Ground cost c(i,j) = ||x_i - y_j|| ----------------------------------
% sqdistance is a project helper, assumed to return pairwise squared
% Euclidean distances between columns -- TODO confirm. Entries below EPS
% (round-off, possibly slightly negative) are zeroed before the sqrt so
% that no complex values can appear.
EPS = 1e-10;
c = sqdistance(x, y);
c(c < EPS) = 0;
c = sqrt(c);

maxL = max(c(:));   % largest ground cost, passed on to OrliczEPT

% ---- Parameters of the weight functions w1, w2 and of c_hat --------------
b = 1;
lambda = 1;
a0 = 1;
% root node: a random point in the same space as the supports
z0 = rand(dim, 1);

% ---- Weights w1(i) = b*||z0 - x_i|| + a0 and w2(j) = b*||z0 - y_j|| + a0 --
% (same EPS clamp as for c before taking the square root)
w1 = sqdistance(z0, x);
w1(w1 < EPS) = 0;
w1 = b*sqrt(w1) + a0;

w2 = sqdistance(z0, y);
w2(w2 < EPS) = 0;
w2 = b*sqrt(w2) + a0;

% ---- Extended cost c_hat, size (Nx+1) x (Ny+1) ---------------------------
% The extra last row/column pair with the extra last entries of
% nu_hat / mu_hat. w1 and w2 are reshaped explicitly ((:) gives a column,
% (:).' a row) so the assignments do not depend on the orientation that
% sqdistance returns.
c_hat = zeros(Nx+1, Ny+1);
c_hat(1:Nx, 1:Ny) = b*c;
c_hat(Nx+1, 1:Ny) = w2(:).' + b*lambda;
c_hat(1:Nx, Ny+1) = w1(:) + b*lambda;
c_hat(Nx+1, Ny+1) = b*lambda;

% ---- Sinkhorn / solver parameters -----------------------------------------
% passed to OrliczEPT (presumably the entropic regularization -- confirm)
epsilon = 0.1;

% NOTE(review): this options struct is never passed to OrliczEPT below; it
% is kept only so the script leaves the same variables in the workspace.
options.niter = 1000;
options.tau = 0;
options.verb = 0;
tol = 1e-6;         % tolerance passed to OrliczEPT

% ---- N-function Phi(t) = exp(t^2) - 1 and its inverse (for y >= 0) -------
% expm1/log1p compute exp(.)-1 and log(1+.) accurately for small arguments,
% avoiding the cancellation of the literal exp(X.^2)-1 and log(y+1) forms
% (e.g. log(1+1e-12) is off by ~1e-4 relative in double precision).
phi = @(X) expm1(X.^2);
invphi = @(y) sqrt(log1p(y));

% Run the project solver OrliczEPT (defined elsewhere) on the instance built
% above, store its result in val, and print the elapsed wall-clock time.
tic
val = OrliczEPT(phi, invphi, ...            % N-function and its inverse
                mu, nu, maxL, ...           % unbalanced inputs, max ground cost
                mu_hat, nu_hat, c_hat, ...  % extended measures and extended cost
                b, lambda, ...              % parameters used to build c_hat
                epsilon, tol);              % solver parameters
toc




