%% For every shape pair, this demo computes a matching in which M is the smaller shape and N is the larger shape.

% Start from a clean session: put this folder and all sub-folders on the
% path, then close figures, wipe the workspace and clear the console.
addpath(genpath('.'));
close('all');
clear('all');
clc;


%% options
% All run settings live in one struct that is passed to every method below.
% option1 / option2 hold the fast-marching iteration caps used for M
% (the smaller shape, small cap) and N (the larger shape, large cap).
options = struct( ...
    'shot_num_bins', 10, ...                   % number of bins for SHOT
    'shot_radius', 5, ...
    'shuffle', false, ...                      % keep false: not implemented on different shapes for DIR
    'isometric', false, ...
    'topological_noise', true, ...
    'remeshed', false, ...
    'partial', false, ...
    'option1', struct('nb_iter_max', 30), ...  % fast-marching cap for M (smallest)
    'option2', struct('nb_iter_max', 120), ... % fast-marching cap for N (largest)
    'nb_iter_max', inf);


%% loading... for all pairs this demo matches the larger shape to the smaller one.
name = '';

% Isometric pair: enable options.isometric in the options section and keep
% exactly one (i, j, gt_in) triple below active. Ground truth is the identity.
if options.isometric
    % i = 'data\isometric\wolf1';
    % j = 'data\isometric\wolf2';
    % gt_in = (1:4344)';
    i = 'data\isometric\mesh03';
    j = 'data\isometric\mesh050';
    gt_in = (1:12500)';

    options.isometric = true;
    gt = [gt_in, gt_in];
end

% Pair with topological noise: enable options.topological_noise in the
% options section and keep exactly one (i, j, gt_in) triple below active.
% Ground truth is merged from the per-shape *_ref.txt files.
if options.topological_noise
    % i = 'data\topological_noise\kid16';
    % j = 'data\topological_noise\kid17';
    % gt_in = (1:10988)';
    i = 'data\topological_noise\kid19';
    j = 'data\topological_noise\kid20';
    gt_in = (1:8515)';

    ref_file_M = strcat(name, i, '_ref.txt');
    ref_file_N = strcat(name, j, '_ref.txt');
    gt_M_null = read_correspondence(ref_file_M);
    gt_N_null = read_correspondence(ref_file_N);
    gt = merge_ground_truth(gt_M_null, gt_N_null);
end

%% for FAUST remeshed
% Remeshed pair plus landmark files (<name>.vts) for both shapes; the
% landmarks are used as vertex indices by the evaluation sections.
if options.remeshed
    % Other available pairs (i / j):
    %   'data\remeshed\8_gorilla_04' / 'data\remeshed\8_gorilla_02'
    %   'data\remeshed\8_woman_01'   / 'data\remeshed\8_woman_12'
    %   'data\remeshed\8_man_04'     / 'data\remeshed\8_man_02'
    %   'data\remeshed\mesh02'       / 'data\remeshed\mesh055'
    i = 'data\remeshed\tr_reg_000';
    j = 'data\remeshed\tr_reg_001';

    lmk1 = load(strcat(i, '.vts'));
    lmk2 = load(strcat(j, '.vts'));

    options.remeshed = true;
end

%% specs for DIR
% Settings shared by both regimes.
options.maxIter = 60;
options.spec_dim = 200;
if options.isometric
    options.spec_dim_cut = 180;
    % Local-distortion thresholds (decreasing). The linspace schedule that
    % used to be computed here was overwritten immediately and never used.
    % alternative: options.th = 2.5:-0.1:1.5;
    options.th = [2.5, 2, 1.8, 1.4, 1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4];
    disp(options.th)
else
    options.th = 100;
end

%% LOAD MESHES
[N, M, n1, n2, diameters, corr_true] = load_and_preprocess(i, j,'', '', '', '', options);

% Give both shapes the same extra fields: per-axis coordinate columns and a
% copy of VERT/TRIV under .surface, plus face (nf) and vertex (nv) counts.
shapes = {M, N};
for s = 1:numel(shapes)
    S = shapes{s};
    S.surface.X = S.VERT(:,1);
    S.surface.Y = S.VERT(:,2);
    S.surface.Z = S.VERT(:,3);
    S.surface.VERT = S.VERT;
    S.surface.TRIV = S.TRIV;
    S.nf = size(S.TRIV,1);
    S.nv = size(S.VERT,1);
    shapes{s} = S;
end
[M, N] = shapes{:};
clear shapes S s

% Ground-truth index lists for the remeshed / partial settings.
if options.remeshed
    % Identity on M's vertices; the remeshed evaluation sections score the
    % map on the landmark lists lmk1/lmk2 instead of on gt.
    gt_in = (1:M.nv)';
    % gt_ = merge_ground_truth([l1, lmk1], [l2, lmk2]);
    % perm = sparse(gt_(:, 2), gt_(:, 1), 1, N.nv, M.nv);
    % [corr_true, col] = find(perm);
elseif options.partial
    % lmk1/lmk2 are only loaded by the remeshed section, which is disabled
    % whenever this branch runs; fail with a clear message instead of an
    % "undefined variable" error.
    if ~exist('lmk1', 'var') || ~exist('lmk2', 'var')
        error('options.partial requires landmark lists lmk1 and lmk2 to be loaded first.');
    end
    % The smaller shape's vertices are paired with its landmark list.
    if N.nv <= M.nv
        gt_in = (1:N.nv)';
        gt = [gt_in, lmk2];
    else
        gt_in = (1:M.nv)';
        gt = [gt_in, lmk1];
    end
end

%% PREPROCESS FOR DIR/HOPE
% Fast-marching distances from source vertices 1..n, run in five
% consecutive batches whose upper bounds are round(linspace(n/5, n, 5)).
% The batch results are concatenated column-wise. option1/option2 cap the
% number of fast-marching iterations for M and N respectively.
disp('---- preprocess ----')
tic
fprintf('geodesic processing for M...');
M.distances = [];
batch_bounds = [0, round(linspace(n1/5, n1, 5))];
for b = 1:numel(batch_bounds) - 1
    sources = batch_bounds(b) + 1 : batch_bounds(b + 1);
    M.distances = [M.distances, ...
        perform_fast_marching_mesh(M.VERT', M.TRIV', sources, options.option1)];
end
fprintf('done \n');

fprintf('geodesic processing for N...');
N.distances = [];
batch_bounds = [0, round(linspace(n2/5, n2, 5))];
for b = 1:numel(batch_bounds) - 1
    sources = batch_bounds(b) + 1 : batch_bounds(b + 1);
    N.distances = [N.distances, ...
        perform_fast_marching_mesh(N.VERT', N.TRIV', sources, options.option2)];
end
fprintf('done \n');
toc

%% geodesic distances for plot
% Full geodesic distance matrix on N, used by all error evaluations below.
fprintf('geodesic processing Full N for plot... ');
tic
distances = geodesic_distance(N.TRIV, N.VERT);
fprintf('done \n')
toc

%% plot multiple
% Iterations to visualise: sparse picks for the non-isometric runs, the
% first nine for isometric ones.
plot_ids = [1, 10, 15, 20, 25, 30, 35, 40, 50];
if options.isometric
    plot_ids = 1:9;
end

%% plot feats
% Witness matrices for each shape, built from its sparse adjacency with
% third argument 20 (D*) and 2 (G*); each shape is timed separately.
adj1 = sparse(M.adj);
tic
D1 = build_witnesses(adj1, n1, 20);
G1 = build_witnesses(adj1, n1, 2);
toc

adj2 = sparse(N.adj);
tic
D2 = build_witnesses(adj2, n2, 20);
G2 = build_witnesses(adj2, n2, 2);
toc

% GRAPH LAP
% Uniform (graph) Laplacian basis of M, 5 eigenvectors.
% [~, U, lambda] = laplacian_from_TRIV_adj(surf2.adj, options.spec_dim);
[~, V, mu] = laplacian_from_TRIV_adj(M.adj, 5);

% LBO: 5 generalized eigenpairs of (M.Phi.W, M.Phi.A) nearest the shift 1e-6.
try
    [evecs, evals] = eigs(M.Phi.W, M.Phi.A, 5, 1e-6);
catch
    % In case of trouble make the laplacian definite. The identity is sized
    % from W itself: only M.nv is set in this script, M.n may not exist.
    [evecs, evals] = eigs(M.Phi.W - 1e-8*speye(size(M.Phi.W, 1)), M.Phi.A, 5, 'sm');
end
evals = diag(evals);
% W/A are presumably the symmetric stiffness/mass pair, so any imaginary
% part is round-off. Keep the real part; the former odd/even element
% shuffle wrote zeros into half the entries and errored on odd sizes.
if ~isreal(evecs)
    evecs = real(evecs);
end
[evals, order] = sort(abs(evals),'ascend');
evecs = evecs(:,order);
B1_all = evecs;

% Witness columns of vertex 100 as dense per-vertex functions.
hop2_100 = full(double(G1(:, 100)));
hop6_100 = full(double(D1(:, 100)));
disp(size(hop2_100))
disp(size(hop6_100))

figure();
subplot(1,6,1); plot_func_on_mesh(M, V(:, 2)); title('uniform laplacian 2');
subplot(1,6,2); plot_func_on_mesh(M, B1_all(:, 2)); title('LBO 2');
subplot(1,6,3); plot_func_on_mesh(M, hop2_100); title('2-hop of n=100');
subplot(1,6,4); plot_func_on_mesh(M, hop6_100); title('6-hop of n=100');
subplot(1,6,5); plot_func_on_mesh(M, M.shots(:, 2)); title('SHOT 2');
subplot(1,6,6); plot_func_on_mesh(M, full(M.distances(:, 2))); title('1-ring geodesic of n=2');

%% SHOT
% Baseline map: each vertex of M goes to its nearest neighbour on N in
% SHOT descriptor space (kd-tree search). Timing includes the plots.
tic
corr_SHOT = knnsearch(N.shots, M.shots, 'NSMethod', 'kdtree');

figure();
subplot(1, 2, 1); visualize_map_on_source(M, N, corr_SHOT); title('Source');
subplot(1, 2, 2); visualize_map_on_target(M, N, corr_SHOT); title('SHOT')
toc

%% NEXUS
disp('------ NEXUS ------')
tic
[corr_NEXUS, C21, all_corr_NEXUS, all_C21] = NEXUS(M, N, options, corr_true);
toc
% The map comes back as a matrix; keep the row index of each nonzero.
[corr_NEXUS, ~] = find(corr_NEXUS);

figure();
subplot(1,2,1); visualize_map_on_source(M, N, corr_NEXUS); title('Source');
subplot(1,2,2); visualize_map_on_target(M, N, corr_NEXUS); title('NEXUS')

% Intermediate maps and C21 matrices at the first five iterations
% (plot_ids is also reused by the HOPE section).
plot_ids = 1:5;

figure()
subplot(2,3,1); visualize_map_on_source(M,N,corr_NEXUS); title('Source');
for p = 1:numel(plot_ids)
    it = plot_ids(p);
    [corr, ~] = find(all_corr_NEXUS{it});
    subplot(2,3,p+1); visualize_map_on_target(M, N, corr); title(['iteration: ' num2str(it)])
end

figure()
subplot(2,3,1); imagesc(C21); title('Source');
for p = 1:numel(plot_ids)
    it = plot_ids(p);
    subplot(2,3,p+1); imagesc(all_C21{it}); title(['iteration: ' num2str(it)])
end

% Geodesic error of the NEXUS map, measured on N with the full `distances`.
if options.remeshed
    % Landmark evaluation: image of each M landmark vs. the paired N landmark.
    corr = corr_NEXUS;
    errors = zeros(size(lmk1,1), 1);  % reset, never reuse stale entries
    for m = 1:size(lmk1,1)
        errors(m) = distances(corr(lmk1(m)), lmk2(m));
    end
else
    corr = [(1:length(corr_NEXUS))', corr_NEXUS];
    n_ = min(length(corr_NEXUS), length(gt_in));
    disp(n_)
    % Only the first n_ vertices are scored. Vertices without a ground-truth
    % partner stay NaN and are excluded from the curve. The old -1 sentinel
    % and zero padding were both counted as exact matches.
    errors = nan(n_, 1);
    for m = 1:n_
        gt_match = gt(gt(:,1) == corr(m,1), 2);
        if ~isempty(gt_match)
            % geodesic distance on N between the true and predicted image
            errors(m) = distances(gt_match, corr(m,2));
        end
    end
end

thresholds = 0:0.01:0.25;
errors = errors / diameters;
disp(errors(1:min(10, end)))
valid = ~isnan(errors);
curve = zeros(1, length(thresholds));
for m = 1:length(thresholds)
    % percentage of scored vertices whose normalised error is within threshold
    curve(m) = 100*sum(errors(valid) <= thresholds(m)) / max(nnz(valid), 1);
end
curve_NEXUS = curve;

figure()
plot(thresholds', mean(curve, 1)'), 
ylim([0 100]), 
line_width=1.5;
hline = findobj(gcf, 'type', 'line');
set(hline,'LineWidth',line_width);
legend({'NEXUS'},'FontSize', 10);

%% HOPE
disp('------HOPE------')
tic
[corr_HOPE, C21, all_corr_HOPE, all_C21] = HOPE(M, N, corr_true, options);
toc
% Collapse the returned map matrix to row indices of its nonzeros.
[corr_HOPE, ~] = find(corr_HOPE);

figure();
subplot(1,2,1); visualize_map_on_source(M, N, corr_HOPE); title('Source');
subplot(1,2,2); visualize_map_on_target(M, N, corr_HOPE); title('HOPE')

% Intermediate HOPE maps at the iterations listed in plot_ids.
figure()
subplot(2,5,1); visualize_map_on_source(M,N,corr_HOPE); title('Source');
for p = 1:numel(plot_ids)
    it = plot_ids(p);
    [corr, ~] = find(all_corr_HOPE{it});
    subplot(2,5,p+1); visualize_map_on_target(M, N, corr); title(['iteration: ' num2str(it)])
end

% Geodesic error of the HOPE map, measured on N with the full `distances`.
if options.remeshed
    % Landmark evaluation: image of each M landmark vs. the paired N landmark.
    corr = corr_HOPE;
    errors = zeros(size(lmk1,1), 1);  % reset, never reuse stale entries
    for m = 1:size(lmk1,1)
        errors(m) = distances(corr(lmk1(m)), lmk2(m));
    end
else
    corr = [(1:length(corr_HOPE))', corr_HOPE];
    n_ = min(length(corr_HOPE), length(gt_in));
    % Only the first n_ vertices are scored. Vertices without a ground-truth
    % partner stay NaN and are excluded from the curve. The old -1 sentinel
    % and zero padding were both counted as exact matches.
    errors = nan(n_, 1);
    for m = 1:n_
        gt_match = gt(gt(:,1) == corr(m,1), 2);
        if ~isempty(gt_match)
            % geodesic distance on N between the true and predicted image
            errors(m) = distances(gt_match, corr(m,2));
        end
    end
end

thresholds = 0:0.01:0.25;
errors = errors / diameters;
disp(errors(1:min(10, end)))
valid = ~isnan(errors);
curve = zeros(1, length(thresholds));
for m = 1:length(thresholds)
    % percentage of scored vertices whose normalised error is within threshold
    curve(m) = 100*sum(errors(valid) <= thresholds(m)) / max(nnz(valid), 1);
end

figure()
plot(thresholds', mean(curve, 1)'), 
ylim([0 100]), 
line_width=1.5;
hline = findobj(gcf, 'type', 'line');
set(hline,'LineWidth',line_width);
legend({'HOPE'},'FontSize', 10);

% Summary figure: source colouring, HOPE vs. NEXUS accuracy curves (HOPE's
% curve is the most recent `curve`), and both transferred maps.
figure()
subplot(2,2,1);
visualize_map_on_source(M,N,corr_HOPE); title('Source');

subplot(2,2,2);
curve_x = thresholds';
plot(curve_x, mean(curve, 1)', curve_x, mean(curve_NEXUS, 1)');
ylim([0 100]), 
line_width = 1.5;
hline = findobj(gcf, 'type', 'line');
set(hline, 'LineWidth', line_width);
legend({'HOPE', 'NEXUS'}, 'FontSize', 10);

subplot(2,2,3); visualize_map_on_target(M, N, corr_NEXUS); title('NEXUS')
subplot(2,2,4); visualize_map_on_target(M, N, corr_HOPE); title('HOPE')


%% DIR
% DIR settings: a decreasing threshold list for isometric pairs, a single
% threshold of 100 otherwise.
options.maxIter = 60; 
options.spec_dim = 200;
options.spec_dim_cut = 180;
if options.isometric
    options.th = [1, 0.9, 0.8, 0.7, 0.6];
else
    options.th = 100;
end
disp('------DIR------')
tic
src_off = strcat(i, '.off');
dst_off = strcat(j, '.off');
corr_DIR = DIR(src_off, dst_off, options, corr_true);
toc
figure();
subplot(1,2,1); visualize_map_on_source(M, N, corr_DIR); title('Source');
subplot(1,2,2); visualize_map_on_target(M, N, corr_DIR); title('DIR')

% Geodesic error of the DIR map, measured on N with the full `distances`.
if options.remeshed
    % Landmark evaluation: image of each M landmark vs. the paired N landmark.
    corr = corr_DIR;
    errors = zeros(size(lmk1,1), 1);  % reset, never reuse stale entries
    for m = 1:size(lmk1,1)
        errors(m) = distances(corr(lmk1(m)), lmk2(m));
    end
else
    disp(size(corr_DIR))
    disp(size(gt_in))
    corr = [(1:length(corr_DIR))', corr_DIR];
    n_ = min(length(corr_DIR), length(gt_in));
    % Only the first n_ vertices are scored. Vertices without a ground-truth
    % partner stay NaN and are excluded from the curve. The old -1 sentinel
    % and zero padding were both counted as exact matches.
    errors = nan(n_, 1);
    for m = 1:n_
        gt_match = gt(gt(:,1) == corr(m,1), 2);
        if ~isempty(gt_match)
            % geodesic distance on N between the true and predicted image
            errors(m) = distances(gt_match, corr(m,2));
        end
    end
end

thresholds = 0:0.01:0.25;
errors = errors / diameters;
valid = ~isnan(errors);
curve = zeros(1, length(thresholds));
for m = 1:length(thresholds)
    % percentage of scored vertices whose normalised error is within threshold
    curve(m) = 100*sum(errors(valid) <= thresholds(m)) / max(nnz(valid), 1);
end
    
figure()
plot(thresholds', mean(curve, 1)'), 
ylim([0 100]), 
line_width=1.5;
hline = findobj(gcf, 'type', 'line');
set(hline,'LineWidth',line_width);
legend({'DIR'},'FontSize', 10);

%% GRAMPA
disp('------GRAMPA------')
tic
corr_GRAMPA = GRAMPA(M, N, options, 1); %eta 1 as in original paper
% Collapse the returned map matrix to row indices of its nonzeros.
[corr_GRAMPA, ~] = find(corr_GRAMPA);
toc
% Open a fresh figure. figure(4) would reactivate, and draw over, the
% fourth figure this script created earlier.
figure();
subplot(1,2,1); visualize_map_on_source(M, N, corr_GRAMPA); title('Source');
subplot(1,2,2); visualize_map_on_target(M, N, corr_GRAMPA); title('GRAMPA')

%% ZoomOut
% Spectral upsampling from a 20-dim to a 200-dim map in steps of 1
% (the original paper starts from 4 or 20 and goes up to 200).
options.spec_dim = 200;
disp('------ZoomOut------')
options.k_init  = 20;
options.k_step  = 1;
options.k_final = 200;
tic
[corr_ZoomOut, ~, ~, ~] = zoomOut_refine(M, N, options);
toc
figure()
subplot(1,2,1); visualize_map_on_source(M, N, corr_ZoomOut); title('Source');
subplot(1,2,2); visualize_map_on_target(M, N, corr_ZoomOut); title('ZoomOut')

%% getting curves
% Accuracy curve for every method, in legend order:
% HOPE, GRAMPA, DIR, ZoomOut, NEXUS, SHOT.
method_corrs = {corr_HOPE, corr_GRAMPA, corr_DIR, corr_ZoomOut, corr_NEXUS, corr_SHOT};
all_corr = cell(6,1);
for k = 1:6
    if options.remeshed
        all_corr{k} = method_corrs{k};
    else
        % pair every M vertex index with its predicted N vertex
        all_corr{k} = [(1:length(method_corrs{k}))', method_corrs{k}];
    end
end

thresholds = 0:0.01:0.25;
all_curves = cell(6,1);
% The loop index is k, not i: i holds the source mesh path.
for k = 1:6
    corr = all_corr{k};
    if options.remeshed
        errors = zeros(size(lmk1,1), 1);  % reset per method
        for m = 1:size(lmk1,1)
            errors(m) = distances(corr(lmk1(m)), lmk2(m));
        end
    else
        % size(corr,1) = number of mapped vertices (length() of an n-by-2
        % matrix would be max(n, 2)).
        n_ = min(size(corr,1), length(gt_in));
        % Vertices without a ground-truth partner stay NaN and are excluded
        % from the curve instead of counting as exact matches.
        errors = nan(n_, 1);
        for m = 1:n_
            gt_match = gt(gt(:,1) == corr(m,1), 2);
            if ~isempty(gt_match)
                % geodesic distance on N between the true and predicted image
                errors(m) = distances(gt_match, corr(m,2));
            end
        end
    end
    errors = errors / diameters;
    valid = ~isnan(errors);
    curve = zeros(1, length(thresholds));
    for m = 1:length(thresholds)
        curve(m) = 100*sum(errors(valid) <= thresholds(m)) / max(nnz(valid), 1);
    end
    all_curves{k} = curve;
end


%% combined
% One plot with every method's accuracy curve (order matches the legend),
% then the source colouring next to each method's transferred map.
method_names = {'HOPE', 'GRAMPA', 'DIR', 'ZoomOut', 'NEXUS', 'SHOT'};
plot_args = cell(1, 2*numel(method_names));
for k = 1:numel(method_names)
    plot_args{2*k - 1} = thresholds';
    plot_args{2*k} = mean(all_curves{k}, 1)';
end
figure();
plot(plot_args{:}), 
ylim([0 100]), 
line_width = 1.5;
hline = findobj(gcf, 'type', 'line');
set(hline, 'LineWidth', line_width);
legend(method_names, 'FontSize', 10);

method_corrs = {corr_HOPE, corr_GRAMPA, corr_DIR, corr_ZoomOut, corr_NEXUS, corr_SHOT};
figure();
subplot(1,7,1); visualize_map_on_source(M, N, corr_HOPE); title('Source');
for k = 1:numel(method_corrs)
    subplot(1,7,k+1); visualize_map_on_target(M, N, method_corrs{k}); title(method_names{k})
end



