import sys
import os
sys.path.append((os.path.abspath('../utils')))
import time
import pickle
import numpy as np
import pandas as pd
import InspectData as ID
#from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
#from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold, cross_validate
from Torelli import *

# python URN_classify.py URN2-1S TE 3e-7
# python URN_classify.py URN2-1S TW 1e-4
# python URN_classify.py URN2-1M TE 1e-7
# python URN_classify.py URN2-1M TW 1e-5
# python URN_classify.py URN2-2S TW 1e-4
# python URN_classify.py URN2-2S TE 1e-7
# python URN_classify.py URN2-2M TW 1e-4
# python URN_classify.py URN2-2M TE 4e-8
# python URN_classify.py URN3-S TE 1e-6
# python URN_classify.py URN3-S TW 1e-4
# python URN_classify.py URN3-M TE 4e-8
# python URN_classify.py URN3-M TW 1e-4
# python URN_classify.py URN4-S TE 3e-7
# python URN_classify.py URN4-S TW 1e-4
# python URN_classify.py URN4-M TE 4e-8
# python URN_classify.py URN4-M TW 1e-4


# Positional command-line arguments (see the usage examples above):
#   1. dataset name, 2. kernel type ('TW' or 'TE'), 3. Gaussian bandwidth gamma.
data_name, kernel_name, gamma = str(sys.argv[1]), str(sys.argv[2]), float(sys.argv[3])

# NOTE: data_path is used as a filename *prefix* below (no path separator is
# inserted), so the files read are '<data_name>networks.pkl' and
# '<data_name>labels.pkl' inside ../data/URN/.
data_path = os.path.join(os.path.abspath('../data/URN/'), data_name)

# load dataset
# pickle can execute arbitrary code while loading; only use trusted data files.
with open(f'{data_path}networks.pkl', 'rb') as networks_file:
    graphs = pickle.load(networks_file)
with open(f'{data_path}labels.pkl', 'rb') as labels_file:
    labels = pickle.load(labels_file)

# Project helper; its return value is not used here (presumably it prepares
# the graphs in place -- confirm against InspectData).
ID.init_labels(graphs)

# genus bound
# Per-graph statistics from the project helper; the mean and standard
# deviation of each quantity are logged later, and the genus statistics are
# also used to choose the kernel's dimension bound.
stats = ID.show_stats(graphs)
nodes, edges, genus = stats['num_nodes'], stats['num_edges'], stats['genus']

ave_nodes, std_nodes = np.mean(nodes), np.std(nodes)
ave_edges, std_edges = np.mean(edges), np.std(edges)
ave_genus, std_genus = np.mean(genus), np.std(genus)

# compute kernel matrix

# The dimension bound is floor(mean genus + one standard deviation).
genus_bound = np.floor(ave_genus + std_genus).astype(int)
# 'TW' selects the Wasserstein variant; any other value selects the Euclidean one.
kernel_cls = TorelliWasserstein if kernel_name == 'TW' else TorelliEuclidean
kernel = kernel_cls(dimbound=genus_bound, normalize=False)

# Time the two kernel stages separately.
# NOTE(review): fit_transform may repeat the fitting already done by fit(), in
# which case t2 - t1 also includes fitting time -- confirm against Torelli.
t0 = time.time()
kernel.fit(graphs)
t1 = time.time()
K = kernel.fit_transform(graphs)
t2 = time.time()

# D[i, j] = K[i, i] + K[j, j] - 2 * K[i, j]
self_similarity = np.diag(K)
ones = np.ones(K.shape[0])
D = np.outer(self_similarity, ones) + np.outer(ones, self_similarity) - 2*K

# apply gaussian
K_gaussian = np.exp(- gamma * D)

# normalize the kernel: K[i, j] / sqrt(K[i, i] * K[j, j]).
# For finite K the diagonal of D is zero, so K_gaussian should already have a
# unit diagonal and this step should leave it essentially unchanged.
normalize_diag = np.diag(np.diag(K_gaussian)**(-0.5))
K_classify = normalize_diag @ K_gaussian @ normalize_diag

# save the kernel matrix
path = os.path.abspath('../results/URN/')
kernel_file = os.path.join(path, f'{data_name}{kernel_name}kernel.npy')
np.save(kernel_file, K_classify)
print('kernel matrix saved')

# cross-validation
# 10-fold shuffled CV (fixed seed) of an SVM on the precomputed Gaussian kernel.
# NOTE(review): the matrix passed here is K_gaussian, not the saved K_classify.
kf = KFold(n_splits=10, shuffle=True, random_state=42)
clf = SVC(kernel='precomputed', max_iter=10000)
results = cross_validate(
    clf,
    K_gaussian,
    labels,
    cv=kf,
    scoring='accuracy',
    return_estimator=True,
    n_jobs=-1,
)

# Mean and spread of the per-fold test accuracy.
fold_scores = results['test_score']
acc = np.mean(fold_scores)
std = np.std(fold_scores)

# store results
# One CSV per (dataset, kernel) pair, holding a single row indexed by the
# dataset name. The file is read, the row is updated or added, and the whole
# table is written back.
csv_file = os.path.join(path, data_name+kernel_name+".csv")

kernel_acc_key = f'{kernel_name}_acc'
kernel_std_key = f'{kernel_name}_std'
kernel_fit_time_key = f'{kernel_name}_fit_time'
kernel_trans_time_key = f'{kernel_name}_trans_time'

# Column -> value for this run; accuracy figures are stored as percentages.
row_values = {
    'ave_node': round(ave_nodes, 2),
    'std_node': round(std_nodes, 2),
    'ave_edge': round(ave_edges, 2),
    'std_edge': round(std_edges, 2),
    'ave_genus': round(ave_genus, 2),
    'std_genus': round(std_genus, 2),
    kernel_acc_key: round(acc*100, 2),
    kernel_std_key: round(std*100, 2),
    kernel_fit_time_key: round(t1-t0, 4),
    kernel_trans_time_key: round(t2-t1, 4),
}

try:
    df = pd.read_csv(csv_file, index_col=0)
except FileNotFoundError:
    # First run for this dataset/kernel pair: start from an empty table.
    df = pd.DataFrame()

if data_name in df.index:
    # Refresh the existing row in place.
    for column, value in row_values.items():
        df.loc[data_name, column] = value
else:
    row = pd.DataFrame(row_values, index=[data_name])
    df = pd.concat([df, row], axis=0)

# df already contains everything read from the file, so the file must be
# overwritten. Appending (mode='a') would add a second header and duplicate
# rows on every rerun, breaking the next read_csv.
df.to_csv(csv_file, mode='w', encoding='utf-8', errors='ignore')