% ---- session reset ----
% Wipe the workspace and command window, then restrict MATLAB to a single
% computational thread (presumably so the recorded run times are not
% influenced by implicit multithreading -- confirm).
clear all
clc
maxNumCompThreads(1);

% ---- experiment configuration ----
dsName = 'amazon';
maxKC = 1000;
% Ground-graph variant (selects which precomputed graph file is loaded):
%   'RandLLE' -> log-linear #edges
%   'RandSLE' -> sqrt-linear #edges
typeGG = 'RandLLE';

% Load the precomputed ground graph and document data into the workspace.
% Later code reads GG, nGG, N, WW and XX_ID from it.
load(sprintf('%s_%d_%s_Graph.mat', dsName, maxKC, typeGG));

% Ground metric GM (nGG x nGG): GM(i,j) is the shortest-path distance in
% graph GG between ground nodes i and j. As before, distances are taken
% from the lower-index node (source) to the higher-index node and mirrored,
% so GM is symmetric with a zero diagonal.
disp('compute the ground graph metric');
tic
% A single distances() call replaces nGG-1 shortestpathtree() calls, which
% also built path cell arrays that were discarded. Restricting sources and
% targets to 1:nGG keeps GM exactly nGG x nGG, as in the original loop.
GM = triu(distances(GG, 1:nGG, 1:nGG), 1);   % keep source<target entries only
GM = GM + GM.';                               % mirror into the lower triangle
runTime_GroundGM = toc;
 
% histogram
% Dense document-by-word matrix: row d holds the weights WW{d} (original
% note: unit-mass for each word) at the columns XX_ID{d}; all other entries
% stay zero.
% NOTE(review): if XX_ID{d} repeats an ID, the later weight overwrites the
% earlier one instead of summing -- confirm IDs are unique per document.
XX_ID_vec = zeros(N, nGG);
tic
for iDoc = 1:N
    XX_ID_vec(iDoc, XX_ID{iDoc}) = WW{iDoc}';
end
runTime_Hist = toc;

% ================
% ---- document pairs to evaluate ----
nPair = 10000;
ff = load(sprintf('%s_ID%d.mat', dsName, nPair));
% ID: each row (ii, jj) indexes two rows of XX_ID_vec; the main loop reads
% rows 1..nPair, so the file is assumed to hold at least nPair rows.
ID = ff.ID;
DD_OrliczEPT = zeros(nPair, 1);   % one distance per evaluated pair

% ---- N-function (Phi2) and its inverse on [0, inf) ----
phi = @(X) exp(X.^2) - 1;
invphi = @(y) sqrt(log(y+1));

% ---- sinkhorn parameters ----
epsilon = 0.1;
tol = 1e-6;
% NOTE(review): options is only stored in paraALL; it is never passed to
% OrliczEPT in this script -- confirm whether it is still needed.
options.niter = 1000;
options.tau = 0;
options.verb = 0;

% ---- weights w_1, w_2 (w = b * GM(z0, .) + a0) and lambda ----
b = 1;
lambda = 1;
a0 = 1;

% ---- root node z0: uniformly random ground node (unseeded RNG) ----
randID = randperm(nGG);
z0_ID = randID(1);

tic
% Compute the OrliczEPT distance for every document pair listed in ID.
for iiID = 1:nPair

    % progress report every 5 pairs
    if mod(iiID, 5) == 0
        disp(['++++++++++' num2str(iiID)]);
    end

    % ---- input (possibly unbalanced) measures ----
    % Support of each document = word IDs with positive weight.
    histX = XX_ID_vec(ID(iiID, 1), :);
    histY = XX_ID_vec(ID(iiID, 2), :);
    suppX = find(histX > 0);
    suppY = find(histY > 0);
    mu = histX(suppX)';
    nu = histY(suppY)';
    massMu = sum(mu);
    massNu = sum(nu);

    % ---- balanced measures ----
    % Each side gets one extra point carrying the other side's total mass;
    % both are then normalized by the combined mass.
    totMass = massMu + massNu;
    mu_hat = [mu; massNu] / totMass;
    nu_hat = [nu; massMu] / totMass;

    % ---- costs ----
    c = GM(suppX, suppY);   % ground cost between the two supports
    maxL = max(c(:));

    % weights w(x) = b * GM(z0, x) + a0 (row vectors)
    w1 = b * GM(z0_ID, suppX) + a0;
    w2 = b * GM(z0_ID, suppY) + a0;

    % augmented (Nx+1) x (Ny+1) cost:
    %   [ b*c            | w1' + b*lambda ]
    %   [ w2 + b*lambda  | b*lambda       ]
    bLambda = b * lambda;
    c_hat = [b * c,         (w1 + bLambda)'; ...
             w2 + bLambda,  bLambda];

    DD_OrliczEPT(iiID) = OrliczEPT(phi, invphi, mu, nu, maxL, mu_hat, nu_hat, c_hat, b, lambda, epsilon, tol);
end
runTime_Dist = toc;

% Total wall-clock time: preprocessing (ground metric + histograms) plus all
% pairwise distance evaluations.
runTime_Dist_ALL = runTime_Dist + runTime_GroundGM + runTime_Hist;

outName = [dsName '_Time_OrliczEPT_EXP2_' num2str(maxKC) '_' typeGG '_' num2str(nPair) '.mat'];

% Average time per pair, amortizing the one-off preprocessing over all
% nPair evaluations. (runTime_Dist_ALL is a scalar, so no sum() is needed.)
avgRunTime = runTime_Dist_ALL/nPair;

% for saving: record the parameters this run used
paraALL.phi = phi;
paraALL.invphi = invphi;
paraALL.epsilon = epsilon;
paraALL.options = options;   % NOTE(review): not passed to OrliczEPT in this script
paraALL.tol = tol;
paraALL.b = b;
paraALL.lambda = lambda;
paraALL.a0 = a0;
paraALL.z0_ID = z0_ID;       % stored because the root node is drawn randomly

% avgRunTime was previously computed but never saved; it is now included.
save(outName, 'DD_OrliczEPT', ...
     'runTime_Dist', 'runTime_GroundGM', 'runTime_Hist', 'runTime_Dist_ALL', ...
     'avgRunTime', 'nPair', 'paraALL');

disp('FINISH !!!');


