"""
=========================================================================
                    GRAPH CLASSIFICATION EXPERIMENT 
=========================================================================
"""
from __future__ import print_function
# Echo the module docstring (the banner above) to stdout at start-up.
print(__doc__)

# -- dataset loader
import os
import sys
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch_geometric.utils import get_laplacian
from torch_geometric.utils import to_dense_adj
from torch_geometric.utils.isolated import remove_isolated_nodes, contains_isolated_nodes
from dataset_loader import *

# -- others
import time
import argparse
import numpy as np
from data_generator import *
from Predictor import *
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from collections import defaultdict
from prettytable import PrettyTable
from wl2 import*

# %% --------------------------------------------------------------------------------------------------
# SETTING
# -----------------------------------------------------------------------------------------------------

# -- path where dataset folder is
path = './dataset/'


# -- command line setting
def _build_parser():
    """Create the argument parser for this experiment's command-line options.

    Note on ``--no_npy``: it uses ``store_false``, so ``args.no_npy`` is True
    by default (a cached ``.npy`` matrix may be read) and becomes False when
    the flag is given (the matrix is recomputed).
    NOTE(review): ``--gridsearch`` is only echoed in the settings table in
    this script; nothing here performs a grid search.
    """
    option_specs = (
        ('--dataset', {'default': 'MUTAG', 'help': 'dataset name'}),
        ('--h', {'default': 4, 'type': int, 'help': 'iteration number'}),
        ('--mode', {'default': 's', 'type': str, 'help': 'SVM:s, k-NN:k'}),
        ('--C', {'default': 100, 'type': float, 'help': 'parameter C of SVC'}),
        ('--kernel', {'default': 'precomputed', 'type': str, 'help': 'Kernel function'}),
        ('--gamma', {'default': 0.01, 'type': float, 'help': 'parameter gamma of graph kernel'}),
        ('--k', {'default': 1, 'type': int, 'help': 'parameter k of k-NN'}),
        ('--gridsearch', {'action': 'store_true', 'help': 'Grid Search'}),
        ('--no_npy', {'action': 'store_false', 'help': 'does not read .npy file'}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli


parser = _build_parser()
args = parser.parse_args()

# Definition of the graph class -----------------------------------------------------------------------
class Graph:
    """Plain-Python copy of one dataset sample plus slots for derived features.

    ``data`` must expose ``edge_index`` and ``y`` (both with ``.tolist()``)
    and support ``data['x']``. Presumably this is a torch_geometric ``Data``
    object; verify against ``get_dataset``.
    """

    def __init__(self, data):
        self.edge = data.edge_index.tolist()  # which nodes are interconnected
        self.label = data.y.tolist()[0]       # graph label: first entry of y
        self.attr = data['x'].tolist()        # node labels / attributes
        self._reset_derived()

    def _reset_derived(self):
        """(Re)initialise the empty containers for derived per-graph features."""
        self.adj_node = {}
        self.adj_label = {}
        self.wl_subtrees = []
        self.paths = []
        self.wl_subtree_subtrees = []

    def _free(self):
        """Replace every stored value with an empty container.

        Presumably used to release memory once the graph is no longer needed.
        """
        self.edge, self.label, self.attr = [], [], []
        self._reset_derived()
# -----------------------------------------------------------------------------------------------------

# -- load dataset
name = args.dataset
print('Dataset is loading...')
dataset = get_dataset(name)
print('Done!')
# wrap every sample, then collect the graph-level labels (used for stratified splits)
graphs = list(map(Graph, dataset))
y = [g.label for g in graphs]

# -- SVM kernel, taken from --kernel (default 'precomputed') so the option
# actually takes effect. With 'precomputed' in SVM mode the distance matrix is
# turned into exp(-gamma * M) further down; any other value passes M unchanged.
SVC_kernel = args.kernel
# -- K-fold
K_fold = 10
# -- random seed
np.random.seed(42)
# -- random states: one StratifiedKFold shuffle per value
random_states = [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]

# -- output the setting information (classifier-specific columns)
if args.mode == 's':
    settings_header = ['Dataset', 'h', 'Classifier', 'C', 'Gamma', 'Grid Search']
    settings_row = [name, args.h, 'SVM', args.C, args.gamma, args.gridsearch]
else:
    settings_header = ['Dataset', 'h', 'Classifier', 'k', 'Grid Search']
    settings_row = [name, args.h, 'k-NN', args.k, args.gridsearch]
table1 = PrettyTable(settings_header)
table1.add_row(settings_row)
print(table1)

# %% --------------------------------------------------------------------------------------------------
# EXPERIMENT
# -----------------------------------------------------------------------------------------------------

# -- ordinal function for outputting the test number
def ordinal(n):
    """Return ``n`` with its English ordinal suffix, e.g. 1 -> '1st', 12 -> '12th'."""
    if n // 10 % 10 == 1:
        suffix = 'th'  # the teens: 11th, 12th, 13th, 111th, ...
    else:
        suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
    return '%d%s' % (n, suffix)

# -- cache location for the distance matrix of this (dataset, h) pair
filename = f'npy/{name}/{args.h}.npy'
elapsed_time = 0  # only updated when the matrix is actually computed

# args.no_npy defaults to True (store_false), so the cache is read unless --no_npy is given
if not (os.path.exists(filename) and args.no_npy):
    # -- calculate the distance matrix from scratch and time it
    start = time.time()
    print('Doing message-passing...')
    graph_embeddings = compute_wl_embeddings_discrete(graphs, args.h)
    print('Done!')
    # -- compute the Wasserstein distance matrix between the WL embeddings
    M = compute_wl_subtree_wasserstein_distance(graph_embeddings, args.h)
    print('\nDone!')
    elapsed_time = time.time() - start

    # -- save the matrix as a numpy file so later runs can reuse it
    os.makedirs('npy', exist_ok=True)
    os.makedirs(f'npy/{name}', exist_ok=True)
    np.save(filename, M)
else:
    # -- reuse the cached matrix; note the reported time then stays 0
    M = np.load(filename)

# For an SVM with a precomputed kernel, turn distances into similarities
# exp(-gamma * d); otherwise (k-NN, or another SVC kernel) use M as-is.
distance_matrices = (
    np.exp(-args.gamma * M)
    if args.mode == 's' and SVC_kernel == 'precomputed'
    else M
)

# -- repeated stratified k-fold cross validation, one shuffle per random state
count = 1
accs = []
table2 = PrettyTable(['No.','Result'])
print('Doing %d classification experiments...' % (len(random_states) * K_fold))

# any mode other than 's' is handed to the predictor as k-NN ('k')
predictor_mode = 's' if args.mode == 's' else 'k'
for random_state in random_states:
    splitter = StratifiedKFold(n_splits=K_fold, shuffle=True, random_state=random_state)
    for train_idx, test_idx in splitter.split(distance_matrices, y):
        X_train, y_train, X_test, y_test = data_generator(
            y=y, distance_matrices=distance_matrices,
            Train_index=train_idx, Test_index=test_idx)
        # NOTE(review): args.k is never passed to Graph_predictor here, so --k
        # cannot reach the k-NN classifier from this call. Confirm against Predictor.
        y_pred = Graph_predictor(
            X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test,
            SVC_kernel=SVC_kernel, SVC_C=args.C, name=name, mode=predictor_mode)

        acc = accuracy_score(y_test, y_pred)
        accs.append(acc)
        table2.add_row([ordinal(count), acc])
        count += 1
accs = np.array(accs)
print('Done!')
print(table2)

# -- print experimental results (percentages over every fold of every random state;
# the +- value is numpy's default population std, ddof=0)
average_res = f'{accs.mean()*100:.2f}+-{accs.std()*100:.2f}%'
max_res, min_res = (f'{extreme*100:.2f}%' for extreme in (accs.max(), accs.min()))
elapsed_time = f'{elapsed_time:.1f}s'
summary_cells = [average_res, max_res, min_res, elapsed_time]
if args.mode == 's':
    table3 = PrettyTable(['Dataset','h','Classifier','Best C','Best Gamma','Average','Max','Min', 'Time'])
    table3.add_row([name, args.h, 'SVM', args.C, args.gamma, *summary_cells])
else:
    table3 = PrettyTable(['Dataset','h', 'Classifier','k','Average','Max','Min', 'Time'])
    table3.add_row([name, args.h, 'k-NN', args.k, *summary_cells])
print(table3)
