% Reset the session: workspace, figures and command window.
clear all; close all; clc;

%% add paths
% Project data folder and, recursively, the helper functions under utils/.
addpath('./data/')
addpath(genpath('./utils/'));

%% experiment configuration
dataset        = "citeseer";            % dataset folder name under data/
L_func         = "graphNorm_laplacian"; % routine whose output (with field 'evecs') is merged into each graph
fL             = str2func(L_func);      % handle to that routine, called as fL(graph, clamped n_eigen_load)
n_eigen_load   = 0;                     % eigenpair count passed to fL; also part of the cache file name
n_eigen        = 1000;                  % truncation size when use_perc_eigen is false (0 = all nodes)
use_perc_eigen = true;                  % truncate by a fraction of the node count instead of n_eigen
perc_eigen     = 0.01;                  % that fraction, used when use_perc_eigen is true
n_landmarks    = 50;                    % NOTE(review): not read anywhere in the visible script
normalized     = true;                  % rescale node features before transfer
pe             = "20_rwM";              % suffix of the feature file Mpe_<pe>.mat

sub_perc = 0.03;                        % NOTE(review): not read anywhere in the visible script
%%
% Input folder for this dataset (ends with a slash, so file names are appended directly).
data_path = sprintf("data/%s/", dataset);
% Output folder: results/func_transf/<dataset>/<truncation> <L_func>/, where
% <truncation> is perc_eigen (2 decimals) or the absolute n_eigen.
if use_perc_eigen
    results_path = sprintf("results/func_transf/%s/%.2f %s/", dataset, perc_eigen, L_func);
else
    results_path = sprintf("results/func_transf/%s/%i %s/", dataset, n_eigen, L_func);
end
if normalized
    % NOTE(review): "norm_" is appended after the trailing slash, so mkdir
    % below creates a folder literally named "norm_". If it was meant as a
    % file-name prefix, only the parent folder needs creating - confirm.
    results_path = sprintf("%snorm_", results_path);
end
% Create the folder only when missing; mkdir warns if it already exists.
if ~isfolder(results_path)
    mkdir(results_path);
end
%% load M
% Full graph M; the .mat file must provide the adjacency matrix A.
M = load(sprintf("%sM.mat", data_path));
M.n = size(M.A, 1);
M.name = dataset;
%% load pe
% Node features: variable 'pe' of Mpe_<pe>.mat in the same dataset folder,
% cast to double for the linear algebra below.
pe_data = load(sprintf("%sMpe_%s.mat", data_path, pe));
M.feature = double(pe_data.pe);
clear pe_data;
%% compute M eigs
% Number of eigenvectors kept for M: a fraction of the node count, or
% n_eigen (0 meaning all nodes), never more than M.n.
if use_perc_eigen
    k_eigM = ceil(perc_eigen * M.n);
elseif n_eigen == 0
    k_eigM = M.n;
else
    k_eigM = min(n_eigen, M.n);
end
% Eigenpair count requested from fL, clamped to the graph size.
n_eigen_load_M = min(n_eigen_load, M.n);
% Reuse a cached decomposition when present, otherwise compute and cache it.
eig_path = sprintf("%sMeigs_%i_%s.mat", data_path, n_eigen_load_M, L_func);
if isfile(eig_path)
    cached = load(eig_path);
    if isfield(cached, 'evecs')
        % fields stored at the top level of the file
        L = cached;
    elseif isfield(cached, 'L')
        % file written by save(eig_path,'L') below: fields live inside variable L
        L = cached.L;
    else
        error("Cached eigen file %s has neither 'evecs' nor 'L'.", eig_path);
    end
    clear cached;
else
    L = fL(M, n_eigen_load_M);
    save(eig_path, 'L');
end
M = mergestructs(M, L); clear L;
% NOTE(review): keeps the first k_eigM-1 columns (one fewer than k_eigM);
% confirm this offset is intended.
M.evecs = M.evecs(:, 1:k_eigM-1);
%% find sub-graph samples
% Sub-graph files are named exactly N_<index>.mat. The pattern is anchored and
% the dot escaped so names like "MN_5.mat" or "N_5.mat.bak" are not matched
% (the loop below parses the index from, and loads, the matched text).
result_pattern = "^N_[0-9]+\.mat$";
listing = dir(data_path);
samples = regexp({listing(:).name}, result_pattern, 'match');
samples = samples(~cellfun('isempty', samples));
n_samples = size(samples, 2);
mse = zeros(n_samples, 1);   % one range-normalized transfer MSE per sample
%% loop-invariant feature preparation
% Rescale each feature column to unit standard deviation and shift it to
% mean 1. Done once: repeating it inside the loop re-normalized the already
% normalized M.feature on every sample.
% NOTE(review): constant columns (std 0) become NaN/Inf; they are excluded
% from the per-sample average below.
if normalized
    feat = M.feature ./ std(M.feature, 0, 1);
    M.feature = feat - mean(feat, 1) + 1;
    clear feat;
end
% Per-column range of the features, used to scale the transfer error.
feat_range = max(M.feature, [], 1) - min(M.feature, [], 1);
% Least-squares coefficients of the features in M's truncated eigenbasis.
feat_coeffs = M.evecs \ M.feature;

%% per-sample transfer
for i = 1:n_samples
    close all;
    sample_file = samples{i}{1};                    % "N_<index>.mat"
    i_sample = str2double(sample_file(3:end-4));
    N = load(data_path + sample_file);
    N.n = size(N.A, 1);
    nodes_perc = N.n / M.n;
    edges_perc = full(sum(triu(N.A), 'all') / sum(triu(M.A), 'all'));
    fprintf("Loaded N%i: %.2f nodes; %.3f edges.\n", i_sample, nodes_perc, edges_perc);

    % eigenvector budget for N, same rule as for M
    if use_perc_eigen
        k_eigN = ceil(perc_eigen * N.n);
    elseif n_eigen == 0
        k_eigN = N.n;
    else
        k_eigN = min(n_eigen, N.n);
    end
    n_eigen_load_N = min(n_eigen_load, N.n);

    % cached decomposition of N, or compute and cache it
    eig_path = sprintf("%sN_%ieigs_%i_%s.mat", data_path, i_sample, n_eigen_load_N, L_func);
    if isfile(eig_path)
        cached = load(eig_path);
        if isfield(cached, 'evecs')
            L = cached;
        elseif isfield(cached, 'L')
            % file written by save(eig_path,'L') below
            L = cached.L;
        else
            error("Cached eigen file %s has neither 'evecs' nor 'L'.", eig_path);
        end
        clear cached;
    else
        L = fL(N, n_eigen_load_N);
        save(eig_path, 'L');
    end
    N = mergestructs(N, L); clear L;
    N.evecs = N.evecs(:, 1:k_eigN-1);

    % ground-truth correspondence: P(r, N.gt(r)) = 1, with N.gt indexing M's nodes
    P = sparse(int64(1:N.n), int64(N.gt), 1, N.n, M.n);
    C = CfromP(M.evecs, N.evecs, P); % full -> sub

    % transfer the features to N and compare with the ground-truth rows
    feat_gt = M.feature(N.gt, :);
    feat_transfer = N.evecs * C * feat_coeffs;
    diff_features = ((feat_gt - feat_transfer) ./ feat_range).^2;
    mse_features = mean(diff_features, 1);
    % skip columns whose error is undefined (NaN) or unbounded (zero range -> Inf)
    mse(i) = mean(mse_features(isfinite(mse_features)));
end
%% summary over samples
% Each per-sample MSE is a squared error divided by the squared feature range,
% so a fixed "%.2f" can print 0.00; report significant digits instead.
mean_mse = mean(mse);
std_mse = std(mse);
fprintf("Mse mean:%.4g; std:%.4g\n", mean_mse, std_mse);