import sys
import os
sys.path.append((os.path.abspath('../utils')))
import time
import numpy as np
import pandas as pd
import InspectData as ID
import grakel
from grakel import GraphKernel
#from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
#from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold, cross_validate
from Torelli import *

# Command-line arguments: DATASET_INDEX KERNEL_INDEX, i.e. 0-based indices into
# data_names and kernel_names below.
data_names = ['AIDS','BZR','BZR_MD','COIL-DEL','COX2','COX2_MD','DHFR','DHFR_MD',
              'ER_MD','DD','ENZYMES','FIRSTMM_DB','FRANKENSTEIN','IMDB-BINARY',
              'IMDB-MULTI','MSRC_9','MSRC_21','MSRC_21C','MUTAG','NCI1','NCI109',
              'PROTEINS','REDDIT-BINARY','REDDIT-MULTI-5K']
kernel_names = ['torelli_wasserstein','torelli_euclidean','edge_histogram',
              'shortest_path','graphlet_sampling','odd_sth','core_framework']
# GraKeL kernels deliberately left out of kernel_names:
# random_walk has problem with scipy version
# svm_theta has problem with scipy version
# multiscale_laplacian has problem with graph formats
# graph_hopper has problem with code


def _parse_index(arg, options, what):
    """Convert command-line string *arg* into a valid index into *options*.

    Exits with a readable message (SystemExit) instead of a traceback when the
    argument is not an integer or lies outside ``0 .. len(options) - 1``.
    Negative values are rejected explicitly: plain list indexing would accept
    them and silently select an entry counted from the end of the list.
    """
    try:
        idx = int(arg)
    except ValueError:
        raise SystemExit(f"{what} index must be an integer, got {arg!r}") from None
    if not 0 <= idx < len(options):
        raise SystemExit(f"{what} index must be in 0..{len(options) - 1}, got {idx}")
    return idx


if len(sys.argv) < 3:
    raise SystemExit(
        f"usage: {sys.argv[0]} DATASET_INDEX (0..{len(data_names) - 1}) "
        f"KERNEL_INDEX (0..{len(kernel_names) - 1})"
    )
i = _parse_index(sys.argv[1], data_names, "dataset")
j = _parse_index(sys.argv[2], kernel_names, "kernel")
data_name = data_names[i]
kernel_name = kernel_names[j]

# Load the selected dataset: `.data` provides the graphs and `.target` the class labels.
dataset = ID.load_data(data_name)
graphs = dataset.data
labels = dataset.target
# InspectData preprocessing applied to the graphs before statistics are gathered
# (presumably fills in default weights/labels in place -- TODO confirm in InspectData).
ID.init_weights_labels(graphs)

# Genus statistics over the dataset; mean + one standard deviation of the genus
# is what the Torelli kernels use as their dimension bound.
stats = ID.show_stats(graphs)
genus = stats['genus']
ave_genus, std_genus = np.mean(genus), np.std(genus)


# Kernels implemented in the local Torelli module; all others come from GraKeL.
_TORELLI_KERNELS = ('torelli_wasserstein', 'torelli_euclidean')


def _gaussian_from_kernel(gram, gamma=0.5):
    """Return exp(-gamma * D), where D holds the squared distances induced by *gram*.

    D[a, b] = gram[a, a] + gram[b, b] - 2 * gram[a, b], the squared feature-space
    distance when *gram* is a positive semi-definite kernel matrix.
    """
    diag = np.diag(gram)
    sq_dist = diag[:, None] + diag[None, :] - 2 * gram
    # NOTE(review): if the Torelli kernel is not PSD, entries of sq_dist can be
    # negative and the resulting Gaussian values exceed 1 -- confirm this is acceptable.
    return np.exp(-gamma * sq_dist)


def _cross_validate_kernel(gram, y, seed):
    """Return (mean, std) of 10-fold CV accuracy of an SVC on precomputed kernel *gram*.

    *gram* must be square with rows/columns aligned to *y*.  *seed* fixes the
    KFold shuffle; passing the same seed for every kernel on one dataset makes
    all kernels see identical folds, so their accuracies are directly comparable.
    """
    clf = SVC(kernel='precomputed', max_iter=10000)
    kf = KFold(n_splits=10, shuffle=True, random_state=seed)
    # n_jobs=-1: run the folds in parallel on all available cores.
    scores = cross_validate(clf, gram, y, cv=kf, scoring='accuracy', n_jobs=-1)['test_score']
    return np.mean(scores), np.std(scores)


# classification
print(f"\nDataset: {data_name}")
print(f"\nTesting kernel: {kernel_name}")

if kernel_name in _TORELLI_KERNELS:
    # Dimension bound: floor(mean + one std) of the dataset's genus values.
    dimbound = np.floor(ave_genus + std_genus).astype(int)
    torelli_cls = TorelliWasserstein if kernel_name == 'torelli_wasserstein' else TorelliEuclidean
    kernel = torelli_cls(dimbound=dimbound, normalize=False)
elif kernel_name == 'graphlet_sampling':
    # hyperparameters for graphlet sampling
    kernel = GraphKernel(kernel={"name": kernel_name, "k": 6, "n_samples": 300}, normalize=False)
else:
    kernel = GraphKernel(kernel=kernel_name, normalize=False)

# Timing: t1 - t0 is the cost of fit(); t2 - t1 is the cost of fit_transform()
# (which refits), used because transform() is bugged for some GraKeL kernels.
t0 = time.perf_counter()
kernel.fit(graphs)
t1 = time.perf_counter()
K = kernel.fit_transform(graphs)
t2 = time.perf_counter()

# Torelli kernels are wrapped in a Gaussian of their induced distance before the SVC.
K_eval = _gaussian_from_kernel(K, gamma=0.5) if kernel_name in _TORELLI_KERNELS else K

# Seed by dataset index only, so every kernel on a dataset uses the same folds.
acc, std = _cross_validate_kernel(K_eval, labels, seed=i)

# store results: one CSV per (dataset, kernel) pair, with one row keyed by data_name
# and columns <kernel>_acc, <kernel>_std (percent), <kernel>_fit_time and
# <kernel>_trans_time (seconds).
path = os.path.abspath('../results/benchmark/')
os.makedirs(path, exist_ok=True)  # to_csv fails if the directory is missing
csv_file = os.path.join(path, data_name+kernel_name+".csv")

kernel_acc_key = f'{kernel_name}_acc'
kernel_std_key = f'{kernel_name}_std'
kernel_fit_time_key = f'{kernel_name}_fit_time'
kernel_trans_time_key = f'{kernel_name}_trans_time'

metrics = {
    kernel_acc_key: round(acc * 100, 2),
    kernel_std_key: round(std * 100, 2),
    kernel_fit_time_key: round(t1 - t0, 4),
    kernel_trans_time_key: round(t2 - t1, 4),
}

# Merge into any existing results so previously stored columns are kept.
try:
    df = pd.read_csv(csv_file, index_col=0)
except FileNotFoundError:
    df = pd.DataFrame()

if data_name in df.index:
    for key, value in metrics.items():
        df.loc[data_name, key] = value
else:
    row = pd.DataFrame(metrics, index=[data_name])
    df = pd.concat([df, row], axis=0)

# Overwrite the file (default mode 'w'). df already contains the file's old
# contents, so appending would duplicate every row plus a second header line,
# which the next read_csv would then parse as a data row.
df.to_csv(csv_file, encoding='utf-8', errors='ignore')
