import argparse
import random
import numpy as np
import networkx as nx
from grakel import GraphKernel, Graph
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from torch_geometric.datasets import TUDataset
from torch_geometric.utils import to_networkx



def set_seed(seed):
    """Seed Python's ``random`` module and NumPy's global RNG.

    Args:
        seed: integer seed applied to both generators, ``random`` first.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    

def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``dataset`` (str, TU dataset
        name), ``seed`` (int), ``k`` (int, graphlet size for the GL kernel),
        ``n_iter`` (int, WL iterations) and ``kernel`` (one of
        ``'gl'``, ``'wl'``, ``'dgk'``).
    """
    cli = argparse.ArgumentParser(description="sub2vec.")
    cli.add_argument('--dataset', type=str, default='MUTAG')
    cli.add_argument('--seed', type=int, default=42)
    cli.add_argument('--k', type=int, default=5, help='k for graphlet kernel')
    cli.add_argument('--n_iter', type=int, default=5, help='n_iter for wl kernel')
    cli.add_argument('--kernel', choices=['gl', 'wl', 'dgk'], default='wl')
    return cli.parse_args()


def graphlet_kernel():
    """Evaluate the graphlet-sampling kernel (GL) on the current CV fold.

    Reads module-level globals assigned by the ``__main__`` loop:
    ``search`` (grid-search ``C`` when true), ``args`` (uses ``args.k``),
    ``X_train`` / ``X_test`` (lists of grakel ``Graph``) and
    ``y_train`` / ``y_test`` (label arrays).

    Returns:
        float: accuracy on the test fold of an SVM trained on the
        normalized graphlet kernel matrix.
    """
    # The SVM must consume the Gram matrix as a kernel ('precomputed').
    # With the default RBF kernel, each row of the Gram matrix would be
    # treated as a feature vector and a second kernel applied on top.
    if search:
        params = {'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]}
        # sklearn slices precomputed kernels correctly during inner CV.
        classifier = GridSearchCV(SVC(kernel='precomputed'), params, cv=5,
                                  scoring='accuracy', verbose=0)
    else:
        classifier = SVC(kernel='precomputed', C=10)
    gk = GraphKernel(kernel={"name": "graphlet_sampling", "k": args.k}, normalize=True)
    K_train = gk.fit_transform(X_train)  # train-vs-train kernel matrix
    K_test = gk.transform(X_test)        # test-vs-train kernel matrix
    classifier.fit(K_train, y_train)
    y_pred = classifier.predict(K_test)
    return accuracy_score(y_test, y_pred)
    
    
def weisfeiler_lehman_kernel():
    """Evaluate the Weisfeiler-Lehman kernel (WL) on the current CV fold.

    Reads module-level globals assigned by the ``__main__`` loop:
    ``search`` (grid-search ``C`` when true), ``args`` (uses
    ``args.n_iter``), ``X_train`` / ``X_test`` (lists of grakel ``Graph``)
    and ``y_train`` / ``y_test`` (label arrays).

    Returns:
        float: accuracy on the test fold of an SVM trained on the
        normalized WL kernel matrix.
    """
    # The SVM must consume the Gram matrix as a kernel ('precomputed').
    # With the default RBF kernel, each row of the Gram matrix would be
    # treated as a feature vector and a second kernel applied on top.
    if search:
        params = {'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]}
        # sklearn slices precomputed kernels correctly during inner CV.
        classifier = GridSearchCV(SVC(kernel='precomputed'), params, cv=5,
                                  scoring='accuracy', verbose=0)
    else:
        classifier = SVC(kernel='precomputed', C=10)
    wl_kernel = GraphKernel(kernel={"name": "weisfeiler_lehman", "n_iter": args.n_iter}, normalize=True)
    K_train = wl_kernel.fit_transform(X_train)  # train-vs-train kernel matrix
    K_test = wl_kernel.transform(X_test)        # test-vs-train kernel matrix
    classifier.fit(K_train, y_train)
    y_pred = classifier.predict(K_test)
    return accuracy_score(y_test, y_pred)

def deep_graph_kernel():
    """Evaluate the kernel selected by ``--kernel dgk`` on the current CV fold.

    NOTE: despite the name, this builds GraKeL's ``random_walk`` kernel with
    ``with_labels=False``, not a learned Deep Graph Kernel.

    Reads module-level globals assigned by the ``__main__`` loop:
    ``search`` (grid-search ``C`` when true), ``X_train`` / ``X_test``
    (lists of grakel ``Graph``) and ``y_train`` / ``y_test`` (label arrays).

    Returns:
        float: accuracy on the test fold of an SVM trained on the
        normalized random-walk kernel matrix.
    """
    # The SVM must consume the Gram matrix as a kernel ('precomputed').
    # With the default RBF kernel, each row of the Gram matrix would be
    # treated as a feature vector and a second kernel applied on top.
    if search:
        params = {'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]}
        # sklearn slices precomputed kernels correctly during inner CV.
        classifier = GridSearchCV(SVC(kernel='precomputed'), params, cv=5,
                                  scoring='accuracy', verbose=0)
    else:
        classifier = SVC(kernel='precomputed', C=10)
    dgk_kernel = GraphKernel(kernel={"name": "random_walk", "with_labels": False}, normalize=True)
    K_train = dgk_kernel.fit_transform(X_train)  # train-vs-train kernel matrix
    K_test = dgk_kernel.transform(X_test)        # test-vs-train kernel matrix
    classifier.fit(K_train, y_train)
    y_pred = classifier.predict(K_test)
    return accuracy_score(y_test, y_pred)


if __name__ == '__main__':
    # Entry point: 10-fold stratified CV of the selected graph kernel + SVM,
    # printing mean and std of the per-fold test accuracy.
    args = parse_args()
    print(args)
    print("---------------------------------------------")
    set_seed(args.seed)

    dataset = TUDataset(root='data', name=args.dataset)
    graphs = [to_networkx(data) for data in dataset]
    labels = np.array([data.y.item() for data in dataset])
    # Node labels fed to the kernels are node degrees; the dataset's own node
    # features are not used. NOTE(review): to_networkx is called without
    # to_undirected, so degree here is presumably in+out degree -- confirm.
    features = [{int(node): deg for node, deg in nx.degree(graph)} for graph in graphs]

    kernels = {'gl': graphlet_kernel, 'wl': weisfeiler_lehman_kernel, 'dgk': deep_graph_kernel}
    kernel_func = kernels[args.kernel]

    # Seed the fold shuffling explicitly so the split is reproducible from
    # --seed alone, rather than depending on global NumPy RNG state.
    kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=args.seed)
    # The kernel functions read `search`, `args`, `X_train`, `X_test`,
    # `y_train` and `y_test` as module globals, so these names must stay.
    search = True
    accuracies = []
    for train_index, test_index in kf.split(graphs, labels):
        X_train = [Graph(list(graphs[i].edges()), node_labels=features[i]) for i in train_index]
        X_test = [Graph(list(graphs[i].edges()), node_labels=features[i]) for i in test_index]
        y_train, y_test = labels[train_index], labels[test_index]
        acc = kernel_func()
        accuracies.append(acc)
    print(f'test acc: {np.mean(accuracies):.4f} +- {np.std(accuracies):.4f}')






