from sklearn.metrics.pairwise import pairwise_kernels
from analysis.eden import vectorize
from utils import mols_to_nx
import numpy as np
import time


def nspdk_stats(graph_ref_list, graph_pred_list, n_jobs=20):
    """Compute the NSPDK MMD between reference and predicted graphs.

    Predicted graphs with zero nodes are dropped before the comparison; the
    reference list is used as given.  The elapsed time of the computation is
    printed to stdout.

    Args:
        graph_ref_list: Reference graphs.
        graph_pred_list: Predicted graphs; each must provide
            ``number_of_nodes()``.
        n_jobs: Number of parallel jobs forwarded to ``compute_nspdk_mmd``.
            Defaults to 20, the value that used to be hard-coded here.

    Returns:
        The value returned by ``compute_nspdk_mmd``.
    """
    non_empty_pred = [G for G in graph_pred_list if G.number_of_nodes() > 0]

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    start = time.perf_counter()
    mmd_dist = compute_nspdk_mmd(
        graph_ref_list,
        non_empty_pred,
        metric='nspdk',
        is_hist=False,
        n_jobs=n_jobs,
    )
    elapsed = time.perf_counter() - start
    # The message previously said "degree mmd", which does not match what is
    # computed here.
    print('Time computing nspdk mmd: ', elapsed)
    return mmd_dist


def compute_nspdk_mmd(samples1, samples2, metric, is_hist=True, n_jobs=None):
    """Compute an MMD-style discrepancy between two graph collections.

    Both collections are transformed into numerical feature vectors with
    ``vectorize`` (NSPDK features, ``complexity=4``, ``discrete=True``) and
    compared through a linear kernel.  The result is::

        mean(K(X, X)) + mean(K(Y, Y)) - 2 * mean(K(X, Y))

    where ``X`` and ``Y`` are the feature matrices of ``samples1`` and
    ``samples2``.

    Args:
        samples1: First collection of graphs, passed to ``vectorize``.
        samples2: Second collection of graphs, passed to ``vectorize``.
        metric: Accepted for interface compatibility but not used; the kernel
            is always ``'linear'`` on the vectorized features.
        is_hist: Accepted for interface compatibility but not used.
        n_jobs: Forwarded to ``pairwise_kernels``.

    Returns:
        The scalar combination of the three kernel-matrix means shown above.
    """
    # Vectorize each collection exactly once and reuse the matrices for the
    # self- and cross-kernels; previously every collection was vectorized
    # twice, doubling the most expensive step.
    # NOTE(review): this assumes ``vectorize`` returns the same features for
    # the same input on every call -- the earlier code already relied on that
    # by mixing kernels built from separate vectorize calls.
    features1 = vectorize(samples1, complexity=4, discrete=True)
    features2 = vectorize(samples2, complexity=4, discrete=True)

    def mean_linear_kernel(A, B=None):
        # With B=None the kernel of A against itself is computed.
        return np.average(pairwise_kernels(A, B, metric='linear', n_jobs=n_jobs))

    self_term1 = mean_linear_kernel(features1)
    self_term2 = mean_linear_kernel(features2)
    cross_term = mean_linear_kernel(features1, features2)

    return self_term1 + self_term2 - 2 * cross_term


def compute_nspdk(graph_ref_list, graph_pred_list, methods=('nspdk',), kernels=None):
    """Compute, print and return the rounded NSPDK MMD for each method name.

    Args:
        graph_ref_list: Reference graphs, forwarded to ``nspdk_stats``.
        graph_pred_list: Predicted graphs, forwarded to ``nspdk_stats``.
        methods: Iterable of names used as keys of the returned dict.  The
            score does not depend on the name; every key maps to the same
            value.  The default is an immutable tuple to avoid a shared
            mutable default argument.
        kernels: Accepted for interface compatibility but not used.

    Returns:
        dict mapping each entry of ``methods`` to the score rounded to four
        decimal places.
    """
    num_digits = 4
    results = {}
    score = None
    for method in methods:
        if score is None:
            # The score is independent of ``method``: compute it once, and
            # only if there is at least one method to report.
            score = round(nspdk_stats(graph_ref_list, graph_pred_list), num_digits)
        results[method] = score
        print(f'{method:10s} : {score:.4f}')
    return results