% Start from a clean session: workspace, open figures and command window.
clear all
close all
clc
%% add paths
% Put the Manopt toolbox and the local helpers (recursively) on the path,
% in the same order as before (manopt first, then utils).
for toolbox_dir = {'manopt', './utils/'}
    addpath(genpath(toolbox_dir{1}));
end
clear toolbox_dir
%% experiment configuration
% -- dataset and partiality model --
dataset         = "minnesota";
partiality_perc = 0.3;
partiality_type = "patch";
% -- spectral basis sizes --
n_eigen_load    = 50;      % eigenpairs computed by fL and cached on disk
n_eigen         = 50;      % eigenpairs kept for the map (0 = all of them)
use_perc_eigen  = false;   % true -> keep ceil(perc_eigen * n) instead of n_eigen
perc_eigen      = 0.75;
% -- operator and landmark descriptors --
L_func          = "graphNorm_laplacian";
n_landmarks     = 50;
%% refinement settings
icp_max_iters   = 30;
% Handle to the function named by L_func; it is called as fL(graph, k) and its
% output struct supplies the evecs/evals used below.
fL = str2func(L_func);
%% load_dataset M
% Full graph M: adjacency forced to logical, plus all-pairs shortest-path
% distances on the corresponding graph object.
M = datasetLoader(dataset,0);
M.A = logical(M.A);
M.distMatrix = distances(graph(M.A));
%% output folder
% Layout: results/<dataset>/<perc> <type>/FM/<eig spec> <L_func>/
% where <eig spec> is perc_eigen (%.2f) or n_eigen (%i).
results_path = "results/" + dataset + "/" ...
    + sprintf("%.2f %s", partiality_perc, partiality_type) + "/FM/";
if use_perc_eigen
    results_path = results_path + sprintf("%.2f %s/", perc_eigen, L_func);
else
    results_path = results_path + sprintf("%i %s/", n_eigen, L_func);
end
mkdir(results_path);
%% compute M eigs
% Number of eigenpairs of M kept for the functional map.
if use_perc_eigen
    k_eigM = ceil(perc_eigen * M.n);
elseif n_eigen == 0
    k_eigM = M.n;                 % n_eigen == 0 requests every eigenpair
else
    k_eigM = min(n_eigen, M.n);
end
% Number of eigenpairs to compute/cache; cannot exceed the number of nodes.
n_eigen_load_M = min(n_eigen_load, M.n);

% The decomposition is cached on disk and reused on later runs.
eig_path = sprintf("data/partial/%s/Meigs_%i_%s.mat",...
            dataset,n_eigen_load_M,L_func);
if isfile(eig_path)
    S = load(eig_path);
    % Files written by the branch below wrap the decomposition in a variable
    % named 'L'; a file whose top level already has 'evecs' is used as-is.
    if ~isfield(S, 'evecs') && isfield(S, 'L')
        L = S.L;
    else
        L = S;
    end
    M = mergestructs(M, L); clear S L;
else
    if n_eigen_load_M == 0
        n_eigen_load_M = M.n;     % 0 requests the full spectrum
    end
    tic
    L = fL(M, n_eigen_load_M);
    M = mergestructs(M, L);
    time_eigs_full = toc;
    % save() does not create missing folders.
    eig_dir = fileparts(eig_path);
    if ~isfolder(eig_dir)
        mkdir(eig_dir);
    end
    save(eig_path,'L','time_eigs_full'); clear L;
end
% Fail with an explicit message instead of an opaque index error when more
% eigenpairs are requested than were computed/cached (e.g. n_eigen == 0 or
% use_perc_eigen with a small n_eigen_load).
if k_eigM > size(M.evecs, 2)
    error("Requested %i eigenpairs of M but only %i are available (%s); increase n_eigen_load.", ...
        k_eigM, size(M.evecs, 2), eig_path);
end
M.evecs = M.evecs(:,1:k_eigM);
M.evals = M.evals(1:k_eigM);
%% collect partial graphs N
% Partial graphs are stored as data/partial/<M.name>/N <perc> <type>/N_<id>.mat.
lookup_path = sprintf("data/partial/%s/N %.2f %s/", ...
    M.name, partiality_perc, partiality_type);

% Anchored, with the dot escaped, so only files named exactly N_<digits>.mat
% are picked up (the same folder also receives cached N_<id>eigs_*.mat files).
result_pattern = "^N_[0-9]+\.mat$";
listing = dir(lookup_path);
samples = regexp({listing(:).name},result_pattern,'match');
samples = samples(~cellfun('isempty',samples));   % each entry: {'N_<id>.mat'}

n_samples = size(samples,2);
if n_samples == 0
    warning("No partial graphs matching %s found in %s", result_pattern, lookup_path);
end
%% loop
% For each partial graph N: load it, build its spectral basis, compute the
% ground-truth map, landmark descriptors, then run fmW and ZoomOut.
for i = 1:n_samples
    %% test
    close all;
    sample_file = samples{i}{1};                 % 'N_<id>.mat'
    i_sample = str2double(sample_file(3:end-4));
    %% load N
    load(strcat(lookup_path,sample_file)); % N, nodes_removed
    M.nodes_removed = nodes_removed;
    N.n = size(N.A,1);
    if dataset == "cora"
        N.coords = M.coords(setdiff(1:M.n,nodes_removed),:);
    end
    nodes_perc = N.n / M.n;
    edges_perc = full(sum(triu(N.A),'all') / sum(triu(M.A),'all'));
    fprintf("Loaded N%i: %.2f nodes; %.3f edges.\n",i_sample,nodes_perc,edges_perc);
    title_plot = sprintf("M-N%i \n%s:%.2f; %s; %i eigs; %s\n %.2f nodes; %.3f edges\n",...
                i_sample,M.name,partiality_perc,partiality_type,n_eigen,L_func,nodes_perc,edges_perc);
    %% compute laplacian
    if use_perc_eigen
        k_eigN = ceil(perc_eigen * N.n);
    elseif n_eigen == 0
        k_eigN = N.n;             % n_eigen == 0 requests every eigenpair
    else
        k_eigN = min(n_eigen,N.n);
    end
    n_eigen_load_N = min(n_eigen_load, N.n);
    eig_path = sprintf("%sN_%ieigs_%i_%s.mat",...
                lookup_path,i_sample,n_eigen_load_N,L_func);
    if isfile(eig_path)
        S = load(eig_path);
        % Cache files written below wrap the decomposition in 'L'; a file
        % whose top level already has 'evecs' is used as-is.
        if ~isfield(S, 'evecs') && isfield(S, 'L')
            L = S.L;
        else
            L = S;
        end
        N = mergestructs(N,L); clear S L;
    else
        if n_eigen_load_N == 0
            n_eigen_load_N = N.n;
        end
        tic
        L = fL(N,n_eigen_load_N);
        time_eigs = toc;
        N = mergestructs(N,L);
        save(eig_path,'L','time_eigs'); clear L;
    end
    if k_eigN > size(N.evecs, 2)
        error("Requested %i eigenpairs of N%i but only %i are available (%s); increase n_eigen_load.", ...
            k_eigN, i_sample, size(N.evecs, 2), eig_path);
    end
    N.evecs = N.evecs(:,1:k_eigN);
    N.evals = N.evals(1:k_eigN);
    %% ground truth
    method_name = "gt";
    P = speye(M.n);
    P(:,M.nodes_removed) = []; % M.n x N.n
    C = CfromP(N.evecs,M.evecs,P);
    Cstruct.(method_name) = C; Pstruct.(method_name) = P;
    %% descriptor based on landmarks
    alpha = 10;
    M_landmarks = N.gt(N.landmarks(1:n_landmarks));
    % One column per landmark: distances from every node of M to that landmark,
    % rescaled column-wise to [0,1], then turned into an exponential kernel.
    G = M.distMatrix(M_landmarks,:)';
    G = rescale(G,'InputMin',min(G),'InputMax',max(G));
    G = exp(-0.5*alpha*G);
    F = Pstruct.gt' * G;          % same descriptors restricted to N's nodes
    % coefficient
    A = N.evecs\F;
    B = M.evecs\G;
    %% eigs mask
    [W] = eigs_mask(N,M,'complRes');
    %% partial FM
    clear params
    method_name = "fmW";
    fprintf('\n#### Fmap with W ####\r\n');
    params.mask_algorithm = 'complRes';
    params.mu0 = 1;
    params.mu1 = 1e-1; % mask 1e-1 1e-3
    params.mu2 = 1e-2; % ortho 1e-2 1e-4
    params.refinement = 'icp'; % zoomout
    params.refinement_max_iter = icp_max_iters;

    % NOTE(review): method_path was used here but never defined anywhere in the
    % script; assumed to be a per-method subfolder of results_path — confirm.
    method_path = sprintf("%s%s/", results_path, method_name);
    if ~isfolder(method_path)
        mkdir(method_path);
    end
    [fileID, open_msg] = fopen(strcat(method_path,"note.txt"),'w');
    if fileID < 0
        error("Cannot open %snote.txt for writing: %s", method_path, open_msg);
    end
    fprintf(fileID,"Fmap with W\r\n");
    fn = fieldnames(params);
    for k = 1:numel(fn)
        value = params.(fn{k});
        % Text-valued params must use %s; %.2e would print their character codes.
        if ischar(value) || isstring(value)
            fprintf(fileID,"%s = %s \r\n",fn{k},value);
        else
            fprintf(fileID,"%s = %.2e \r\n",fn{k},value);
        end
    end
    fclose(fileID);

    tic
    [C,info] = fmW(M,N,G,F,params);
    time_fmW = toc;               % wall time of fmW for this sample
    fprintf('\n#### ####\n');
    %% Zoom Out
    clear params
    method_name = "fmW_ZM";
    div_factor = 3/4;
    fprintf('\n#### FmapW and ZoomOut ####\r\n');
    % Start ZoomOut from the leading div_factor fraction of the fmW map.
    Co = C(1:fix(k_eigM*div_factor),1:fix(k_eigN*div_factor));
    [C,~] = zoom_out_rect(M,N, Co,k_eigM-size(Co,1),k_eigN-size(Co,2));
    fprintf('\n#### ####\n');
end
fprintf("Finish!\n");