%% Fresh session: clear everything, close all figures, clear the command window
clear all; close all; clc;

% Put the current folder and all of its subfolders on the MATLAB path.
addpath(genpath(pwd));
% Shared utilities folder (spyc() is used for the error plots at the end).
addpath('H:/My Drive/Workspace/matlab/common')

n = 11;  % number of points in each set
d = 84;  % dimensionality of each feature vector

% The universe X: n random points. Y and Z below contain exactly the same
% points, listed in a different (random) order.
X = rand(n, d);

% X serves as the universe, so its own map is the identity.
P_XX = speye(n);

% Row-permuted identities: P_YX maps the universe X to Y, P_ZX maps it to Z.
P_YX = P_XX(randperm(n), :);
Y = P_YX*X;

P_ZX = P_XX(randperm(n), :);
Z = P_ZX*X;

% Ground-truth embedding into the universe (stacked maps) and the induced
% 3n x 3n matrix of all pairwise maps.
U_gt = vertcat(P_XX, P_YX, P_ZX);
P_gt = U_gt*U_gt.';

% Sanity check: P_gt is cycle-consistent <=> rank(P_gt) == n. The rank of
% the full matrix is computed numerically, so round-off could matter.
assert(rank(full(P_gt)) == n)

% Map from Y to Z composed through the universe: P_YX' takes Y back to X,
% then P_ZX takes X on to Z.
P_ZY = P_ZX*P_YX.';

% The same ground truth assembled block by block. For the set order
% (X, Y, Z), block (i,j) maps the j-th set to the i-th set.
P_gt2 = [speye(n), P_YX.',   P_ZX.'; ...
         P_YX,     speye(n), P_ZY.'; ...
         P_ZX,     P_ZY,     speye(n)];

% Both constructions must agree.
assert(norm(P_gt - P_gt2, 'fro') < 1e-6)

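% old code (superseded, kept for reference): in this layout each off-diagonal
% block was the transpose of the corresponding block of U_gt*U_gt'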

% P_ZY = P_ZX*P_YX'; % permutation from Y to universe (X) to Z
% % however, we need from universe to Z
%
% % ground-truth permutations
% P_gt = [
%     speye(n), P_YX, P_ZX;
%     P_YX', speye(n), P_ZY;
%     P_ZX', P_ZY', speye(n)
%     ];

% Pairwise similarities between the point sets.
% Disabled distance-based alternative:
% S_YX = exp(-pdist2(Y, X));
% S_ZX = exp(-pdist2(Z, X));
% S_ZY = exp(-pdist2(Z, Y));
% In use: plain inner products. Entry (k,l) of S_AB compares point k of A
% with point l of B.
S_YX = Y * X.';
S_ZX = Z * X.';
S_ZY = Z * Y.';

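% old layout (superseded, kept for reference): each off-diagonal block is the
% transpose of the corresponding block of the S_in used below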
% S_in = [
%     eye(n), S_YX, S_ZX;
%     S_YX', eye(n), S_ZY;
%     S_ZX', S_ZY', eye(n)
%     ];

% Assemble the 3n x 3n similarity input. For the set order (X, Y, Z),
% block (i,j) is (set i)*(set j)' and the diagonal blocks are identities,
% so the layout matches P_gt = U_gt*U_gt'.
S_in = [eye(n), S_YX.', S_ZX.'; ...
        S_YX,   eye(n), S_ZY.'; ...
        S_ZX,   S_ZY,   eye(n)];

% Run the three synchronization methods on the same input. Each returns a
% matrix P and a factor U, and the assertions require P to match U*U'
% (Frobenius error below tol).
setSizes = repmat(n, 3, 1);  % presumably one entry per point set (X, Y, Z) -- TODO confirm
tol = 1e-6;

[P_our, U_our] = SparseStiefelSync(S_in, setSizes, n, 0);
assert(norm(P_our - U_our*U_our', 'fro') < tol)

[P_spectral, ~, ~, U_spectral] = mmatch_spectral(S_in, setSizes, n);
assert(norm(P_spectral - U_spectral*U_spectral', 'fro') < tol)

[P_nmf, U_nmf] = nmfSync(full(S_in), setSizes, n, []);
assert(norm(P_nmf - U_nmf*U_nmf', 'fro') < tol)

% To analyse the spectral result below instead of ours, uncomment:
% P_our = P_spectral;
% U_our = U_spectral;

% Extract the pairwise permutation blocks from P_our and P_gt.
% In the layout of P_gt = U_gt*U_gt' (set order X, Y, Z; block (i,j) maps
% set j to set i), the blocks BELOW the diagonal are P_YX, P_ZX and P_ZY
% themselves. The blocks above the diagonal hold their transposes.
blk = @(P, i, j) P((i-1)*n+1:i*n, (j-1)*n+1:j*n);  % (i,j)-th n x n block

P_YX_our = blk(P_our, 2, 1);
P_ZX_our = blk(P_our, 3, 1);
P_ZY_our = blk(P_our, 3, 2);

% Cycle X -> Y -> Z -> X (P_YX, then P_ZY, then back via P_ZX').
figure, spy(P_ZX_our'*P_ZY_our*P_YX_our), title('identity means cycle-consistent (our)')

P_YX_gt = blk(P_gt, 2, 1);
P_ZX_gt = blk(P_gt, 3, 1);
P_ZY_gt = blk(P_gt, 3, 2);

% The extracted ground-truth blocks must reproduce the permutations built above.
assert(isequal(P_YX_gt, P_YX) && isequal(P_ZX_gt, P_ZX) && isequal(P_ZY_gt, P_ZY))

figure, spy(P_ZX_gt'*P_ZY_gt*P_YX_gt), title('identity means cycle-consistent (gt)')

% Overlay each candidate P (its sparsity pattern) on the similarity matrix.
% Report two objectives per panel:
%   objP = P(:)'*S_in(:)         (sum of elementwise products of P and S_in)
%   objU = trace(U'*S_in*U)      (using the factor U paired with P)
% Every panel, including the ground truth, uses its own U for objU.
panels = {'gt',          P_gt,       U_gt; ...
          'StiefelSync', P_our,      U_our; ...
          'Spectral',    P_spectral, U_spectral};

figure;
for k = 1:size(panels, 1)
    [label, Pk, Uk] = panels{k, :};
    subplot(1, 3, k)
    imagesc(S_in);
    hold on;
    spy(Pk)
    title([label ' objP = ' num2str(Pk(:)'*S_in(:)) ' objU = ' num2str(trace(Uk'*S_in*Uk))])
end



% 3 x 3 comparison grid, one column per pair (YX, ZX, ZY):
% top row ground truth, middle row our estimate, bottom row their difference.
pairNames = {'YX', 'ZX', 'ZY'};
gtBlocks  = {P_YX_gt,  P_ZX_gt,  P_ZY_gt};
ourBlocks = {P_YX_our, P_ZX_our, P_ZY_our};

figure
for c = 1:3
    subplot(3, 3, c),     spy(gtBlocks{c}),  title([pairNames{c} ' gt'])
    subplot(3, 3, c + 3), spy(ourBlocks{c}), title([pairNames{c} ' our'])
end
for c = 1:3
    subplot(3, 3, c + 6), spyc(gtBlocks{c} - ourBlocks{c}), colorbar('off'), title('error')
end
