import numpy as np
from data_handler import dataloader
import os
from sys import argv
from Kernels import FGW_kernels

# Blank out CUDA_VISIBLE_DEVICES (the CUDA convention for hiding every GPU),
# presumably so that the whole run stays on CPU.
os.environ.update(CUDA_VISIBLE_DEVICES="")
# %% parser for FGW kernels
# NOTE: the command-line parser below is wrapped in a bare string literal, so
# it is never executed. The values it would read from argv (dataset_name,
# graph_mode, list_alphas, n_jobs, features_mode) are set by hand in the next
# cell instead. It is kept as a reference for the intended CLI usage.
"""
# python run_FGWkernels.py 'mutag' 'ADJ' [0.5] 3

dataset_name = str(argv[1])
assert dataset_name in ['imdb-b', 'imdb-m', 'mutag', 'ptc', 'nci1', 'bzr', 'cox2', 'enzymes', 'protein', 'colab']
graph_mode =str(argv[2]) # ['ADJ','SP'...]
if '[' in argv[3]: 
    list_alphas = [float(x) for x in argv[3][1:-1].split(',')] 
else:
    list_alphas = [float(argv[3])]
n_jobs = int(argv[4])
try:
    features_mode = str(argv[5])
except:
    features_mode = None
"""
# %%

# ---- Hand-edited experiment settings ----

# Dataset to load and the key selecting how each graph is turned into a
# structure matrix (looked up in `str_to_method`).
dataset_name = 'mutag'
graph_mode = 'ADJ'
str_to_method = dict(ADJ='adjacency', SP='shortest_path')

# Node-feature construction mode. The main loop only checks it for the
# datasets 'imdb-b', 'imdb-m' and 'colab', where it must be one of
# 'degree', 'ones' or 'onehot'.
features_mode = None

# Feature distance name handed to the FGW computation, and the alpha values
# to sweep over (one result directory is written per value).
dist_features = 'euclidean'
list_alphas = [0.2]
# validated values from truncated logspace :[0., 0.000001, 0.0003, 0.005, 0.1, 0.25, 0.5, 0.75, 0.9, 0.995, 0.9997, 0.999999, 1.]

# Requested number of parallel jobs.
n_jobs = 3

# Input/output locations, resolved against the parent of the current
# working directory.
abspath = os.path.abspath('../')
data_path = abspath + '/real_datasets/'
res_repo = '%s/kernel_results/%s/' % (abspath, dataset_name)

# Datasets handled without node attributes of their own: their node features
# are derived from the graph structure according to `features_mode`.
_FEATURELESS_DATASETS = ('imdb-b', 'imdb-m', 'colab')
_STRUCTURAL_FEATURE_MODES = ('degree', 'ones', 'onehot')


def _relabel_consecutively(labels):
    """Map arbitrary class labels onto the integers 0..n_classes-1.

    Classes are numbered in the sorted order returned by ``np.unique``; the
    result is an integer array with the same shape as `labels`.
    """
    raw = np.asarray(labels)
    _, inverse = np.unique(raw, return_inverse=True)
    # Depending on the NumPy version the inverse is flat or input-shaped.
    return np.asarray(inverse).reshape(raw.shape)


def _structural_features(graphs, features_mode):
    """Derive node features from the structure matrices in `graphs`.

    Returns a ``(features, standardize)`` pair where `standardize` tells
    whether the features should be standardized afterwards.

    Raises:
        ValueError: if `features_mode` is not 'degree', 'ones' or 'onehot'.
    """
    if features_mode == 'degree':
        # Column sums of each structure matrix (the node degrees when the
        # matrix is an adjacency matrix).
        return [C.sum(axis=0) for C in graphs], True
    if features_mode == 'ones':
        return [np.ones(C.shape[0], dtype=np.float64) for C in graphs], False
    if features_mode == 'onehot':
        # Integer column sums, so they can be used directly as column indices.
        degs = [np.asarray(C.sum(0)).astype(np.int64) for C in graphs]
        min_deg = min(int(deg.min()) for deg in degs)
        max_deg = max(int(deg.max()) for deg in degs)
        features = []
        for deg in degs:
            n_nodes = deg.shape[0]
            F = np.zeros((n_nodes, max_deg - min_deg + 1), dtype=np.float64)
            F[np.arange(n_nodes), deg - min_deg] = 1.
            features.append(F)
        return features, False
    raise ValueError(
        'features_mode must be one of %s, got %r'
        % (list(_STRUCTURAL_FEATURE_MODES), features_mode))


def _standardize_features(features):
    """Return float64 copies of `features`, rescaled column by column.

    For each feature column, the centering and scaling constants are the mean
    and the standard deviation of the per-graph column averages (one average
    per graph). Columns whose standard deviation is zero are only centered,
    to avoid dividing by zero. The input arrays are left untouched and the
    returned arrays keep the input shapes (1-D features stay 1-D).
    """
    arrays = [np.array(F, dtype=np.float64) for F in features]
    if not arrays:
        return arrays
    # 2-D views (n_nodes, n_features) sharing memory with `arrays`, so that
    # 1-D per-node features are treated as a single column.
    views = [A.reshape(A.shape[0], -1) for A in arrays]
    graph_means = np.vstack([V.mean(axis=0) for V in views])
    center = graph_means.mean(axis=0)
    scale = graph_means.std(axis=0)
    scale[scale == 0] = 1.
    for V in views:
        V -= center
        V /= scale
    return arrays


# ---- Dataset-dependent preprocessing flags ----
featureless = dataset_name in _FEATURELESS_DATASETS
if dataset_name in ['mutag', 'ptc', 'nci1']:  # One-hot encoding
    one_hot = True
    standardized_features = False
elif dataset_name in ['enzymes', 'protein', 'bzr', 'cox2']:
    one_hot = False
    standardized_features = True
elif featureless:
    one_hot = False
    standardized_features = False
else:
    raise ValueError('unsupported dataset_name: %r' % (dataset_name,))

# Fail before any loading work if the feature mode is unusable.
if featureless and features_mode not in _STRUCTURAL_FEATURE_MODES:
    raise ValueError(
        'features_mode must be one of %s for dataset %r, got %r'
        % (list(_STRUCTURAL_FEATURE_MODES), dataset_name, features_mode))

# ---- Load and preprocess the data once: none of this depends on alpha ----
X, labels = dataloader.load_local_data(data_path, dataset_name, one_hot=one_hot)
labels = _relabel_consecutively(labels)
graphs = [X[t].distance_matrix(method=str_to_method[graph_mode]) for t in range(X.shape[0])]
# Uniform probability mass over the nodes of each graph.
masses = [np.ones(C.shape[0]) / C.shape[0] for C in graphs]
if featureless:
    # NOTE: features are built for every alpha. The previous version skipped
    # this step when alpha == 0, which left `features` undefined for the
    # FGW call below.
    features, standardized_features = _structural_features(graphs, features_mode)
else:
    features = [np.array(X[t].values()) for t in range(X.shape[0])]

if standardized_features:
    print('stardardizing features')
    features = _standardize_features(features)

# ---- One experiment (distance matrix + SVC results) per alpha ----
for alpha in list_alphas:
    if featureless:
        experiment_name = '/FGWkernel_%s%s_alpha%s_dist%s' % (graph_mode, features_mode, alpha, dist_features)
    else:
        experiment_name = '/FGWkernel_%s_alpha%s_dist%s' % (graph_mode, alpha, dist_features)

    experiment_repo = res_repo + experiment_name
    os.makedirs(experiment_repo, exist_ok=True)

    # The pairwise distance matrix is cached on disk and reused when present.
    distance_path = experiment_repo + '/pairwise_distance_matrix.npy'
    if not os.path.exists(distance_path):
        D = FGW_kernels.FGW_matrix_parallel(
            graphs, features, masses, alpha,
            dist=dist_features, dtype=np.float64, n_jobs=n_jobs)
        np.save(distance_path, D)
    else:
        D = np.load(distance_path)

    # Classification is skipped when its summary file already exists.
    if not os.path.exists(experiment_repo + '/res_SVC.csv'):
        res_svc, full_res_svc = FGW_kernels.classification_SVC(D, labels, 10, 10)
        res_svc.to_csv(experiment_repo + '/res_SVC.csv')
        full_res_svc.to_csv(experiment_repo + '/full_res_SVC.csv')
        