# -----------------------------------------------------------------------------
# Adapted from the script that runs the experiments reported in the WWL paper
# (original credits below).
# /!\ It requires a remastered version of the authors' `wwl` package that
# supports continuous features and uses the same dependencies as ours.
# October 2019, M. Togninalli, E. Ghisu, B. Rieck
# -----------------------------------------------------------------------------
import numpy as np
import os
import wwl
import igraph as ig
from data_handler import dataloader
from Kernels import FGW_kernels, WL_kernels
from sys import argv
#%%
"""
# python run_WWLkernels.py 'imdb-b' 1 [1,2,3] 3

dataset_name = str(argv[1])
assert dataset_name in ['imdb-b', 'imdb-m', 'mutag', 'ptc', 'nci1', 'bzr', 'cox2', 'enzymes', 'protein', 'collab']
use_wass = bool(int(argv[2]))
if '[' in argv[3]: # alphas for the FGW loss > if alpha = -1, it will be considered as a learnable parameter initialized at 0.5
    list_wl = [int(x) for x in argv[3][1:-1].split(',')] 
else:
    list_wl = [int(argv[3])]
n_jobs = int(argv[4])
"""
#%%

# Hard-coded experiment configuration.
dataset_name = 'mutag'
list_wl = list(range(1, 11))  # numbers of WL iterations to evaluate (1..10)
use_wass = True               # Wasserstein distance between WL embeddings
n_jobs = 3                    # parallel workers for the distance computation

# Datasets outside this list get node-degree features, as set by the authors.
features_mode = (
    None
    if dataset_name in ['mutag', 'ptc', 'enzymes', 'protein', 'nci1']
    else 'degree'
)

# Paths are resolved relative to the current working directory's parent.
abspath = os.path.abspath('../')
res_repo = '{}/kernel_results/{}/'.format(abspath, dataset_name)
data_path = '{}/real_datasets/'.format(abspath)

# Run the (Wasserstein) WL kernel experiment for every number of WL iterations
# in `list_wl`: build or load the cached pairwise distance matrix between
# graphs, then evaluate an SVC on it. Each configuration caches its distance
# matrix and SVC results under `res_repo`, and cached results are not recomputed.
social_datasets = ['imdb-b', 'imdb-m', 'collab']

# Validate the configuration up front. Unsupported datasets (e.g. 'bzr',
# 'cox2') previously crashed mid-run with a NameError on `embedding_type`.
if dataset_name in ['mutag', 'ptc', 'nci1'] + social_datasets:
    embedding_type = 'discrete'
    dist = 'hamming'
elif dataset_name in ['protein', 'enzymes']:
    embedding_type = 'continuous'
    dist = 'euclidean'
else:
    raise ValueError('No WL embedding is configured for dataset %r' % dataset_name)
if dataset_name in social_datasets and features_mode not in ['degree', 'ones', 'onehot']:
    raise ValueError('Unsupported features_mode %r for dataset %r'
                     % (features_mode, dataset_name))

# Nothing between here and the loop depends on the number of WL iterations,
# so the dataset is loaded and preprocessed once instead of once per `wl`.
X, labels = dataloader.load_local_data(data_path, dataset_name, one_hot=False)

# Remap class labels onto [0, C-1] following the sorted order of the original
# labels, keeping the original dtype.
unprocessed_labels = np.array(labels)
_, label_idx = np.unique(unprocessed_labels, return_inverse=True)
labels = label_idx.reshape(unprocessed_labels.shape).astype(unprocessed_labels.dtype)

# Adjacency matrix of every graph.
graphs = [g.distance_matrix(method='adjacency') for g in X]
if dataset_name not in social_datasets:
    features = [np.array(g.values()) for g in X]
else:
    if features_mode == 'degree':
        features = [C.sum(axis=0) for C in graphs]
    elif features_mode == 'ones':
        features = [np.ones(C.shape[0], dtype=np.float64) for C in graphs]
    else:  # 'onehot': one-hot encode node degrees over [min_deg, max_deg]
        degs = [C.sum(0) for C in graphs]
        max_deg = np.max([deg.max() for deg in degs]).astype(np.int64)
        min_deg = np.min([deg.min() for deg in degs]).astype(np.int64)
        features = []
        for deg in degs:
            n_nodes = deg.shape[0]
            F = np.zeros((n_nodes, max_deg - min_deg + 1), dtype=np.float64)
            # Integer cast so float-valued degrees can still be used as indices.
            F[np.arange(n_nodes), deg.astype(np.int64) - min_deg] = 1.
            features.append(F)
    # Attach the synthetic features to the networkx nodes, presumably so they
    # reach the igraph graphs built from `nx_graph` below.
    for x, x_features in zip(X, features):
        for idx_node, node in enumerate(x.nx_graph.nodes):
            x.add_one_attribute(node, x_features[idx_node])

for wl in list_wl:
    str_use_wass = 'Wass' if use_wass else ''
    if dataset_name not in social_datasets:
        experiment_name = '/%sWLkernel_wl%s' % (str_use_wass, wl)
    else:
        experiment_name = '/%sWLkernel_%s_wl%s' % (str_use_wass, features_mode, wl)
    experiment_repo = res_repo + experiment_name
    os.makedirs(experiment_repo, exist_ok=True)

    distance_path = experiment_repo + '/pairwise_distance_matrix.npy'
    if os.path.exists(distance_path):
        D = np.load(distance_path)
    else:
        # WL embeddings are only needed to build D, so they are computed only
        # when no cached matrix exists. Fresh igraph graphs are built for each
        # `wl` so no state can leak between iterations.
        ig_graphs = [ig.Graph.from_networkx(x.nx_graph) for x in X]
        if embedding_type == 'discrete':
            wl_transform = wwl.WeisfeilerLehman()
            label_dicts = wl_transform.fit_transform(ig_graphs, wl)
        else:
            wl_transform = wwl.ContinuousWeisfeilerLehman()
            node_features = np.concatenate(features, axis=0)
            label_dicts = wl_transform.fit_transform(ig_graphs, node_features, wl)
        if use_wass:
            # Uniform mass over the rows of each graph's WL embedding.
            masses = [np.ones(F.shape[0]) / F.shape[0] for F in label_dicts]
            D = WL_kernels.compute_wasserstein_distance_parallel(label_dicts, masses, dist, n_jobs)
        else:
            D = WL_kernels.compute_wl_distance_parallel(label_dicts, dist, n_jobs)
        np.save(distance_path, D)

    res_path = experiment_repo + '/res_SVC.csv'
    if not os.path.exists(res_path):
        res_svc, full_res_svc = FGW_kernels.classification_SVC(D, labels, 10, 10)
        res_svc.to_csv(res_path)
        full_res_svc.to_csv(experiment_repo + '/full_res_SVC.csv')
