import numpy as np
import torch as th
from data_handler import dataloader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split 
import os
import pandas as pd
from sys import argv
from OT_GNN import OT_GNN
import pickle
from multiprocessing import cpu_count
from time import time
#os.environ["CUDA_VISIBLE_DEVICES"]=""
# %% Command-line parser for the OT-GIN experiments (currently DISABLED).
# The parser below is wrapped in a triple-quoted string literal, so it is never
# executed. To read the configuration from the command line instead, remove the
# opening and closing triple quotes around it.
# Positional arguments, as parsed below:
#   argv[1]  dataset name (must be one of the names in the assert)
#   argv[2]  number(s) of graph templates: an int, or a list written like "[4,8]"
#   argv[3]  batch size(s): int or "[...]" list
#   argv[4]  train/validation split seed(s): int or "[...]" list
#   argv[5]  dropout rate(s): float or "[...]" list
#   argv[6]  GIN hidden dimension(s): int or "[...]" list
#   argv[7]  number of GIN layers
#   argv[8]  force CPU: 0 or 1
#   argv[9]  optional number of CPU jobs (falls back to None when missing)
#   argv[10] optional node-feature transformation (falls back to None when missing)
# NOTE(review): if re-enabled, argv[6] is stored in `hidden_dim_netgin`, whereas the
# experiment loop iterates over `list_hidden_dim_netgin`; also, the hard-coded
# configuration that follows reassigns most of these variables unconditionally.
"""
dataset_name = str(argv[1])
assert dataset_name in ['imdb-b', 'imdb-m', 'mutag', 'ptc', 'nci1', 'bzr', 'cox2', 'enzymes', 'protein']
if '[' in argv[2]: # number of graph templates
    list_Katoms = [int(x) for x in argv[2][1:-1].split(',')] 
else:
    list_Katoms = [int(argv[2])]
print('list_Katoms:', list_Katoms)

if '[' in argv[3]:
    list_batch= [int(x) for x in argv[3][1:-1].split(',') ]
else:
    list_batch= [int(argv[3])]
print('list_batch:', list_batch)
if '[' in argv[4]: 
    list_split_trainval = [int(x) for x in argv[4][1:-1].split(',')] 
else:
    list_split_trainval = [int(argv[4])]
print('list_split_trainval:', list_split_trainval)

if '[' in argv[5]:
    list_dropout = [float(x) for x in argv[5][1:-1].split(',') ]
else:
    list_dropout = [float(argv[5])]
if '[' in argv[6]:
    hidden_dim_netgin = [int(x) for x in argv[6][1:-1].split(',') ]
else:
    hidden_dim_netgin = [int(argv[6])]

num_layers_gin = int(argv[7])

force_cpu = bool(int(argv[8]))
try:
    cpu_njobs = int(argv[9]) #number of cpus to parallelize computation of Wass distances using Frank-Wolfe algorithm from POT
except:
    cpu_njobs = None
try:
    features_trans = str(argv[10])
except:
    features_trans = None
"""
#%%
# ---------------------------------------------------------------------------
# Hard-coded experiment configuration (used in place of the disabled CLI parser).
# ---------------------------------------------------------------------------
dataset_name = 'mutag'
graph_mode = 'ADJ'  # key of str_to_method: which matrix represents each graph
list_Katoms = [16]  # numbers of graph templates to try
lr = 0.01
use_lrschedule = False
if dataset_name in ['mutag', 'ptc', 'protein', 'nci1', 'enzymes']:
    list_hidden_dim_netgin = [16, 32]
elif dataset_name in ['imdb-b', 'imdb-m', 'collab']:
    list_hidden_dim_netgin = [64]
else:
    # Fail early and explicitly: otherwise datasets such as 'bzr' or 'cox2'
    # crash much later with a NameError on list_hidden_dim_netgin.
    raise ValueError('no GIN hidden dimensions configured for dataset %s' % dataset_name)
num_layers_gin = 4
list_batch = [128]
epochs = 500
list_split_trainval = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
sizes_scaling = False
if dataset_name in ['mutag']:
    # Exception: initialize the templates with samples whose shapes are as close
    # as possible to the median, ensuring that for each class the train dataset
    # contains a sample with the shape required for the templates
    # (this detail is developed in the supplementary material).
    init_mode_atoms = 'sampling_supervised_max'
else:
    init_mode_atoms = 'sampling_supervised_median'

list_dropout = [0.0]
skip_first_features = False

force_cpu = True
cpu_njobs = 3
if force_cpu:
    # Hide all GPUs from this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    device = 'cpu'
elif th.cuda.is_available():
    device = th.device('cuda:0')
else:
    device = 'cpu'

print('device:', device)
features_trans = None

# graph_mode -> method name passed to Graph.distance_matrix(method=...)
str_to_method = {'ADJ': 'adjacency', 'SP': 'shortest_path', 'LAP': 'laplacian',
                 'fullADJ': 'augmented_adjacency', 'normADJ': 'normalized_adjacency'}

abspath = os.path.abspath('../')
res_repo = abspath + '/results_OTGIN/%s/' % dataset_name
data_path = abspath + '/real_datasets/'
str_lrschedule = 'scheduled' if use_lrschedule else ''

algo_seed = 0
split_traintest_seed = 0
val_timestamp = 5
track_templates = True

dtype = th.float64
verbose = True

eps_gin = 0.
train_eps_gin = False
num_hidden_netgin = 1
hidden_dim_clf = 128
num_hidden_clf = 2

if cpu_njobs is None:
    # Leave two cores free, but never request fewer than one job
    # (cpu_count() - 2 is <= 0 on machines with one or two cores).
    n_jobs = max(1, cpu_count() - 2)
else:
    n_jobs = cpu_njobs

# ---------------------------------------------------------------------------
# Grid search: one OT-GIN model is trained per combination of batch size,
# dropout, GIN hidden dimension, train/val split seed and number of templates.
# Runs whose training log already reports `epochs` epochs are skipped.
# ---------------------------------------------------------------------------


def _training_already_completed(experiment_repo, model_name, epochs):
    """Return True if the saved training log of `model_name` covers `epochs` epochs.

    Any failure to read or interpret '<experiment_repo>/<model_name>_training_log.pkl'
    (missing file, unreadable pickle, missing key) counts as "not completed".
    """
    log_path = experiment_repo + '/%s_training_log.pkl' % model_name
    try:
        # NOTE: pickle.load can execute arbitrary code; only point this at logs
        # produced by this script.
        with open(log_path, 'rb') as f:
            training_log = pickle.load(f)
        saved_epochs = len(training_log['train_cumulated_batch_loss'])
    except Exception:
        return False
    print('training log found / saved epochs:', saved_epochs)
    if saved_epochs == epochs:
        print('experiment already completed')
        return True
    return False


def _remap_labels(labels):
    """Map arbitrary class labels onto contiguous integers 0..C-1 (sorted label order)."""
    labels = np.asarray(labels)
    _, inverse = np.unique(labels, return_inverse=True)
    # reshape keeps the input's shape across NumPy versions (1.x returns a flat array)
    return inverse.reshape(labels.shape)


def _degree_one_hot_features(graphs, device, dtype):
    """One-hot encode node degrees (column sums of each matrix, as int64).

    Columns span [global min degree, global max degree] over all graphs.
    """
    degs = [C.sum(0).to(th.int64) for C in graphs]
    max_deg = max(int(deg.max()) for deg in degs)
    min_deg = min(int(deg.min()) for deg in degs)
    n_bins = max_deg - min_deg + 1
    features = []
    for deg in degs:
        n_nodes = deg.shape[0]
        F = th.zeros((n_nodes, n_bins), device=device, dtype=dtype)
        F[th.arange(n_nodes, device=deg.device), deg - min_deg] = 1.
        features.append(F)
    return features


def _standardize_features(features):
    """Standardize every feature column with statistics of the per-graph means.

    For column i, the mean and std are taken over the per-graph mean of that
    column (one value per graph), not over all nodes. Columns with zero std are
    only centered, to avoid producing NaN/inf. Returns a new list of tensors.
    """
    graph_means = np.stack([F.mean(dim=0).cpu().numpy() for F in features])
    mean_ = graph_means.mean(axis=0)
    std_ = graph_means.std(axis=0)
    std_[std_ == 0] = 1.0
    mean_t = th.as_tensor(mean_, dtype=features[0].dtype, device=features[0].device)
    std_t = th.as_tensor(std_, dtype=features[0].dtype, device=features[0].device)
    return [(F - mean_t) / std_t for F in features]


def _load_graph_dataset(data_path, dataset_name, distance_method, features_trans, device, dtype):
    """Load a benchmark dataset and build its graph matrices and node features.

    Returns (graphs, features, labels, shapes):
      graphs   -- one tensor per graph: distance_matrix(method=distance_method) + identity
      features -- one node-feature tensor per graph
      labels   -- class labels remapped to 0..C-1
      shapes   -- number of nodes of each graph
    Raises ValueError for an unknown dataset_name or features_trans.
    """
    # Load graphs and corresponding labels for the classification benchmark.
    if dataset_name in ['mutag', 'ptc', 'nci1']:  # one-hot encoding of node labels
        one_hot, standardized_features = True, False
    elif dataset_name in ['enzymes', 'protein', 'bzr', 'cox2']:
        one_hot, standardized_features = False, True
    elif dataset_name in ['imdb-b', 'imdb-m', 'collab']:
        one_hot, standardized_features = False, False
    else:
        raise ValueError('unknown dataset_name: %s' % dataset_name)
    X, labels = dataloader.load_local_data(data_path, dataset_name, one_hot=one_hot)
    labels = _remap_labels(labels)

    graphs = [th.tensor(X[t].distance_matrix(method=distance_method), device=device, dtype=dtype)
              for t in range(X.shape[0])]
    # Add self-loops for node features propagation
    # (self-loops are then omitted in the optimization of FGW templates).
    graphs = [C + th.eye(C.shape[0], device=device, dtype=dtype) for C in graphs]
    shapes = [C.shape[0] for C in graphs]

    if dataset_name not in ['imdb-b', 'imdb-m', 'collab']:
        features = [th.tensor(np.array(X[t].values()), dtype=dtype, device=device)
                    for t in range(X.shape[0])]
    elif features_trans is None:  # one-hot encoding of node degrees
        features = _degree_one_hot_features(graphs, device, dtype)
        standardized_features = False
    elif features_trans == 'degree':
        features = [C.sum(0)[:, None].to(device=device, dtype=dtype) for C in graphs]
        standardized_features = True
    elif features_trans == 'normalizeddegree':
        degs = [C.sum(0).to(device=device, dtype=dtype) for C in graphs]
        features = [d[:, None] / d.sum() for d in degs]
        standardized_features = True
    elif features_trans == 'ones':
        features = [th.ones((C.shape[0], 1), device=device, dtype=dtype) for C in graphs]
        standardized_features = False
    else:
        raise ValueError('unknown features_trans: %s' % features_trans)

    if standardized_features:
        print('stardardizing features')
        features = _standardize_features(features)
    return graphs, features, labels, shapes


init_graph_str = {'sampling_supervised_median': 'samplabelmed',
                  'sampling_supervised_modes': 'samplabelmod',
                  'sampling_supervised_max': 'samplabelmax'}
# Loop-invariant fragments of the experiment name.
str_skip_first_features = 'skipfl' if skip_first_features else ''
str_sizes_scaling = 'sizesscaling' if sizes_scaling else ''
str_features_trans = '' if features_trans is None else features_trans

for batch_size in list_batch:
    # mutag with batch size 128 uses a random (non class-supervised) batch sampler.
    supervised_sampler = not ((dataset_name == 'mutag') and (batch_size == 128))
    str_batch = '' if supervised_sampler else 'random'
    for dropout in list_dropout:
        for hidden_dim_netgin in list_hidden_dim_netgin:
            for split_trainval_seed in list_split_trainval:
                for Katoms in list_Katoms:
                    output_dim_netgin = hidden_dim_netgin
                    th.manual_seed(algo_seed)
                    np.random.seed(algo_seed)

                    gin_net_dict = {'hidden_dim': hidden_dim_netgin,
                                    'num_hidden': num_hidden_netgin}
                    gin_layer_dict = {'num_layers': num_layers_gin}
                    gnn_str = 'GINL%s%s_MLPh%so%sL%sbatchnorm_' % (
                        num_layers_gin, str_skip_first_features, hidden_dim_netgin,
                        output_dim_netgin, num_hidden_netgin)
                    clf_net_dict = {'hidden_dim': hidden_dim_clf,
                                    'num_hidden': num_hidden_clf,
                                    'dropout': dropout}
                    clf_str = 'clfMLPh%sL%sdropout%s' % (hidden_dim_clf, num_hidden_clf, dropout)

                    experiment_name = '/OTGIN%s_%s_aggregf%sdistK%s%s_%s_%slr%s_%sbatch%s_ep%s_splitseed%s/' % (
                        str_features_trans, gnn_str, Katoms,
                        init_graph_str[init_mode_atoms], str_sizes_scaling,
                        clf_str, str_lrschedule, lr, str_batch, batch_size, epochs, split_traintest_seed)
                    experiment_repo = res_repo + experiment_name
                    model_name = 'model_splitvalseed%s' % split_trainval_seed
                    print('experiment_repo:', experiment_repo)
                    print('model_name:', model_name)

                    train_model = True
                    if os.path.exists(experiment_repo):
                        print('already existing repository')
                        train_model = not _training_already_completed(experiment_repo, model_name, epochs)
                    print('train model? ', train_model)
                    if not train_model:
                        continue

                    graphs, features, labels, shapes = _load_graph_dataset(
                        data_path, dataset_name, str_to_method[graph_mode],
                        features_trans, device, dtype)
                    unique_labels = np.unique(labels)
                    dataset_size = len(graphs)
                    input_shape = features[0].shape[-1]

                    # Instantiate the OT-GIN model.
                    model = OT_GNN.OT_GIN(
                        input_shape=input_shape,
                        Katoms=Katoms,
                        n_labels=unique_labels.shape[0],
                        experiment_repo=experiment_repo,
                        gin_net_dict=gin_net_dict,
                        gin_layer_dict=gin_layer_dict,
                        clf_net_dict=clf_net_dict,
                        skip_first_features=skip_first_features,
                        sizes_scaling=sizes_scaling,
                        dtype=dtype,
                        device=device)

                    # Stratified 90/10 train/test split (fixed seed), then a
                    # stratified 90/10 subtrain/validation split of the train part.
                    idx_train, idx_test, y_train, y_test = train_test_split(
                        np.arange(dataset_size), labels, test_size=0.1, stratify=labels,
                        random_state=split_traintest_seed)
                    X_train, X_test = [graphs[idx] for idx in idx_train], [graphs[idx] for idx in idx_test]
                    F_train, F_test = [features[idx] for idx in idx_train], [features[idx] for idx in idx_test]
                    idx_subtrain, idx_val, y_subtrain, y_val = train_test_split(
                        np.arange(len(X_train)), y_train, test_size=0.1, stratify=y_train,
                        random_state=split_trainval_seed)
                    X_subtrain, X_val = [X_train[idx] for idx in idx_subtrain], [X_train[idx] for idx in idx_val]
                    F_subtrain, F_val = [F_train[idx] for idx in idx_subtrain], [F_train[idx] for idx in idx_val]

                    # Number of nodes of each graph template.
                    if 'median' in init_mode_atoms:
                        median_shape = np.median(shapes)
                        # mutag truncates the median; other datasets round it.
                        if dataset_name in ['mutag']:
                            Satoms = int(median_shape)
                        else:
                            Satoms = round(median_shape)
                        list_Satoms = [Satoms] * Katoms
                        print('median atom shapes:', list_Satoms)
                    elif 'max' in init_mode_atoms:
                        if dataset_name not in ['mutag']:
                            raise NotImplementedError(
                                'init_mode_atoms=%s is only supported for mutag' % init_mode_atoms)
                        list_Satoms = [16] * Katoms
                    else:
                        raise ValueError('unknown init_mode_atoms: %s' % init_mode_atoms)

                    y_subtrain_t = th.tensor(y_subtrain, device=device, dtype=th.long)
                    model.init_parameters_with_aggregation(
                        list_Satoms=list_Satoms, init_mode_atoms=init_mode_atoms, labels=y_subtrain_t,
                        graphs=X_subtrain, features=F_subtrain)

                    model.fit(model_name=model_name, X_train=X_subtrain, F_train=F_subtrain, y_train=y_subtrain_t,
                              X_val=X_val, F_val=F_val, y_val=th.tensor(y_val, device=device, dtype=th.long),
                              X_test=X_test, F_test=F_test, y_test=th.tensor(y_test, device=device, dtype=th.long),
                              lr=lr, batch_size=batch_size,
                              supervised_sampler=supervised_sampler, epochs=epochs, val_timestamp=val_timestamp,
                              use_lrschedule=use_lrschedule,
                              algo_seed=algo_seed, track_templates=track_templates, verbose=verbose, n_jobs=n_jobs)
                        
                    