function [pert, C21, all_T12, all_C21] = NEXUS(M, N, options, corr_true)
%NEXUS Iteratively refine a vertex-to-vertex map from M to N.
%
%   [pert, C21, all_T12, all_C21] = NEXUS(M, N, options, corr_true)
%
%   A map T12 (one N-vertex index per M-vertex) is initialized with a
%   nearest-neighbour search (knnsearch) from M.shots into N.shots. It is
%   then refined for options.maxIter iterations. In each iteration, vertices
%   whose local error is not below the current threshold are re-matched with
%   greedy_match on a cost built from the build_witnesses matrices.
%
%   Inputs
%     M, N       structs with fields .n (vertex count), .adj (adjacency,
%                converted to sparse here) and .shots (passed to knnsearch;
%                assumed one feature row per vertex -- TODO confirm).
%     options    struct with field .maxIter (number of refinement
%                iterations). Optional field .th: vector of thresholds.
%                Iteration kk uses th(min(kk, numel(th))). Default is
%                linspace(1.8, 0.6, 10).
%     corr_true  not used by this function (kept for interface
%                compatibility).
%
%   Outputs
%     pert       sparse N.n-by-M.n matrix with a 1 in column j at row
%                T12(j), i.e. the final map as a matrix.
%     C21        always 0 in this implementation.
%     all_T12    (only when nargout > 2) cell array of map matrices: the
%                initial map twice, then one entry per iteration.
%     all_C21    (only when nargout > 2) matching cell array of C21 values.

%% vertex counts
n1 = M.n;
n2 = N.n;
num = min(n1, n2);   % only the first num vertices of M are scored for error

if nargout > 2
    all_T12 = {};
    all_C21 = {};
end

adj1 = sparse(M.adj);
adj2 = sparse(N.adj);

%% params
max_T_o = 1;   % level max_T wraps back to after exceeding max_F
max_T   = 1;   % witness level used by the next global re-match
max_F   = 8;   % number of witness levels precomputed per shape

% Threshold schedule for landmark selection; overridable via options.th.
if isfield(options, 'th') && ~isempty(options.th)
    th = options.th;
else
    th = linspace(1.8, 0.6, 10);
end

%% neighborhoods
% build_witnesses(adj, n, i) is defined elsewhere; presumably returns an
% n-by-n level-i neighborhood matrix -- TODO confirm.
D1 = cell(1, max_F);
D2 = cell(1, max_F);
for i = 1:max_F
    D1{i} = build_witnesses(adj1, n1, i);
    D2{i} = build_witnesses(adj2, n2, i);
end

%% initialize map with shot descriptors
T12  = knnsearch(N.shots, M.shots, 'NSMethod', 'kdtree');
pert = sparse(T12, 1:length(T12), 1, n2, n1);
C21  = 0;
if nargout > 2
    % The initial map is recorded twice (entries 1 and 2). This keeps the
    % history layout that callers may index into.
    all_T12{end+1} = pert; all_C21{end+1} = 0;
    all_T12{end+1} = pert; all_C21{end+1} = C21;
end

%% loop-invariant part of the local error (depends only on M)
good   = 1:num;
D1T    = D1{2}(good, good);
D1_row = sum(D1T, 2);

%% LOCAL CORRECTION: smoothness and accuracy
for kk = 1:options.maxIter
    % Per-vertex error: L1 distance between rows of D1{2} and rows of D2{2}
    % pulled back through T12, divided by the row sum of D1{2}.
    % A zero row sum yields NaN/Inf, so that vertex is never a landmark.
    D2T = D2{2}(T12(good), T12(good));
    e = sum(abs(D1T - D2T), 2) ./ D1_row;

    % Landmarks: vertices whose error is below this iteration's threshold.
    % The last threshold is reused once kk runs past the schedule.
    landmarks = find(e < th(min(kk, numel(th))));
    bad = setdiff(1:n1, landmarks);

    % Global improvement: re-match every non-landmark vertex of M.
    if ~isempty(bad)
        cost = D2{max_T} * (pert * D1{max_T}(:, bad));
        % greedy_match is defined elsewhere. Its nonzeros are read as
        % (row = vertex of N, column = position in bad) -- TODO confirm.
        match = greedy_match(cost);
        [idx2, idx] = find(match);
        T12(bad(idx)) = idx2;

        % Cycle the witness level through max_T_o..max_F.
        max_T = max_T + 1;
        if max_T > max_F
            max_T = max_T_o;
        end

        pert = sparse(T12, 1:length(T12), 1, n2, n1);
    end

    % History gets one entry per iteration, even when nothing was re-matched.
    if nargout > 2
        all_T12{end+1} = pert; %#ok<AGROW>
        all_C21{end+1} = C21;  %#ok<AGROW>
    end
end

