import numpy as np
import torch as th
from data_handler import dataloader
from sklearn.model_selection import train_test_split 
import os
from sys import argv
from GIN import GIN
import pickle


#os.environ["CUDA_VISIBLE_DEVICES"]=""
# %% command-line parser for GIN experiments

# python run_GIN.py 'mutag' [128] [0,1,2,3,4,5,6,7,8,9] [0.0] 16 4 1

# Remove the triple quotes around the block below to use the command-line parser
"""
dataset_name = str(argv[1])
assert dataset_name in ['mutag', 'ptc', 'nci1', 'enzymes', 'protein', 'imdb-b', 'imdb-m', 'collab']

if '[' in argv[2]:
    list_batch= [int(x) for x in argv[2][1:-1].split(',') ]
else:
    list_batch= [int(argv[2])]

if '[' in argv[3]: # random seed for initializating the atoms.
    list_split_trainval = [int(x) for x in argv[3][1:-1].split(',')] 
else:
    list_split_trainval = [int(argv[3])]

if '[' in argv[4]:
    list_dropout = [float(x) for x in argv[4][1:-1].split(',') ]
else:
    list_dropout = [float(argv[4])]
if '[' in argv[5]:
    list_hidden_dim_netgin = [int(x) for x in argv[5][1:-1].split(',') ]
else:
    list_hidden_dim_netgin = [int(argv[5])]

num_layers_gin = int(argv[6])
force_cpu = bool(int(argv[7]))
try:
    features_trans = str(argv[8])
except:
    features_trans = None
"""

#%%
# Hard-coded configuration. When using the parser above, disable this block by removing the leading '#' from the two #""" marker lines around it
#"""
dataset_name = 'mutag'
# Node features for datasets without node attributes (imdb-b, imdb-m, collab):
# None -> one-hot node degrees, 'degree', 'normalizeddegree' or 'ones'.
# Ignored for the other datasets, which use their own node attributes.
features_trans = None
use_lrschedule = True
if dataset_name in ['mutag', 'ptc', 'protein', 'nci1', 'enzymes']:
    list_hidden_dim_netgin = [16, 32]
elif dataset_name in ['imdb-b', 'imdb-m', 'collab']:
    list_hidden_dim_netgin = [64]
else:
    # Fail fast instead of hitting a NameError on list_hidden_dim_netgin later.
    raise ValueError('unsupported dataset: %r' % (dataset_name,))
num_layers_gin = 4
list_batch = [128]
list_split_trainval = [0]  # full protocol: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
list_dropout = [0.0]

num_hidden_netgin = 1
hidden_dim_clf = 128
num_hidden_clf = 2

lr = 0.01
dtype = th.float64
epochs = 500
skip_first_features = False

force_cpu = True
#"""
# Pick the compute device.
if force_cpu:
    # Hide every GPU from CUDA (only effective if CUDA is not initialized yet).
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    device = 'cpu'
elif th.cuda.is_available():
    device = th.device('cuda:0')
else:
    device = 'cpu'

print('device:', device)
abspath = os.path.abspath('../')
res_repo = abspath + '/results_GIN/%s/' % dataset_name
data_path = abspath + '/real_datasets/'
# Experiment-name tag derived from the configuration flag (rather than
# overwriting the flag), so setting use_lrschedule = False is honored.
str_lrschedule = 'scheduled' if use_lrschedule else ''

algo_seed = 0  # seeds torch / numpy before each model is built
split_traintest_seed = 0  # random_state of the outer train/test split
val_timestamp = 5  # forwarded to model.fit
verbose = True

def _experiment_completed(experiment_repo, model_name, epochs):
    """Return True if a finished training log for `model_name` already exists.

    The log counts as finished when its 'train_cumulated_batch_loss' entry has
    exactly `epochs` elements. Reading the log is best-effort: any failure to
    open or parse it means "not completed", so the experiment is (re)trained.
    """
    if not os.path.exists(experiment_repo):
        return False
    print('already existing repository')
    log_path = experiment_repo + '/%s_training_log.pkl' % model_name
    try:
        # NOTE: unpickling can execute arbitrary code; only load logs written
        # by this script.
        with open(log_path, 'rb') as log_file:
            training_log = pickle.load(log_file)
        saved_epochs = len(training_log['train_cumulated_batch_loss'])
    except Exception as err:  # not bare: Ctrl-C must still interrupt the run
        print('could not read training log %s: %r' % (log_path, err))
        return False
    print('training log found / saved epochs:', saved_epochs)
    if saved_epochs == epochs:  # safety check
        print('experiment already completed')
        return True
    return False


def _remap_labels(labels):
    """Map arbitrary class labels to integers 0..K-1, in sorted label order.

    Works for numeric and string labels alike and keeps the input's shape.
    """
    raw = np.asarray(labels)
    _, inverse = np.unique(raw, return_inverse=True)
    return inverse.reshape(raw.shape)


def _one_hot_degree_features(graphs, device, dtype):
    """One-hot encode node degrees, taken as column sums of each matrix in `graphs`.

    The columns span the degree range observed over the whole dataset, so all
    graphs share the feature dimension max_deg - min_deg + 1.
    """
    degs = [C.sum(0).to(th.int64) for C in graphs]
    max_deg = max(int(deg.max()) for deg in degs)
    min_deg = min(int(deg.min()) for deg in degs)
    features = []
    for deg in degs:
        n_nodes = deg.shape[0]
        F = th.zeros((n_nodes, max_deg - min_deg + 1), device=device, dtype=dtype)
        F[th.arange(n_nodes, device=deg.device), deg - min_deg] = 1.
        features.append(F)
    return features


def _degree_features(graphs, features_trans, device, dtype):
    """Build structural node features for datasets without node attributes.

    Returns (features, standardize), where `standardize` says whether the
    features must be standardized afterwards.

    Raises:
        ValueError: if `features_trans` is not one of the supported options.
    """
    if features_trans is None:  # one-hot encoding of node degrees
        return _one_hot_degree_features(graphs, device, dtype), False
    if features_trans == 'degree':
        return [C.sum(0)[:, None].to(device=device, dtype=dtype) for C in graphs], True
    if features_trans == 'normalizeddegree':
        degs = [C.sum(0).to(device=device, dtype=dtype) for C in graphs]
        return [d[:, None] / d.sum() for d in degs], True
    if features_trans == 'ones':
        return [th.ones((C.shape[0], 1), device=device, dtype=dtype) for C in graphs], False
    raise ValueError('unknown features_trans: %r' % (features_trans,))


def _standardize_features(features):
    """Standardize node features column-wise, modifying the tensors in place.

    For each column, the statistics are the mean and population std (ddof=0),
    across graphs, of each graph's mean of that column. Constant columns
    (std == 0) are only centered instead of being divided by zero, which would
    fill them with NaN/inf.
    """
    # NOTE(review): statistics use every graph, including those later held out
    # for validation/test; confirm this is intended.
    graph_means = np.vstack([F.mean(dim=0).cpu().numpy() for F in features])
    mean_ = graph_means.mean(axis=0)
    std_ = graph_means.std(axis=0)
    std_[std_ == 0] = 1.
    for F in features:
        F -= th.as_tensor(mean_, dtype=F.dtype, device=F.device)
        F /= th.as_tensor(std_, dtype=F.dtype, device=F.device)
    return features


def _load_dataset(data_path, dataset_name, features_trans, device, dtype):
    """Load (graphs, features, labels) for `dataset_name`.

    graphs: adjacency matrices returned by the loader, with self-loops added.
    features: one node-feature tensor per graph, standardized when required.
    labels: integer class labels in 0..K-1.

    Raises:
        ValueError: for an unsupported dataset or features_trans.
    """
    if dataset_name in ['mutag', 'ptc', 'nci1']:  # One-hot encoding
        one_hot, standardized_features = True, False
    elif dataset_name in ['enzymes', 'protein']:
        one_hot, standardized_features = False, True
    elif dataset_name in ['imdb-b', 'imdb-m', 'collab']:
        one_hot, standardized_features = False, False
    else:
        raise ValueError('unsupported dataset: %r' % (dataset_name,))
    X, raw_labels = dataloader.load_local_data(data_path, dataset_name, one_hot=one_hot)
    labels = _remap_labels(raw_labels)
    graphs = [th.tensor(X[t].distance_matrix(method='adjacency'), device=device, dtype=dtype)
              for t in range(X.shape[0])]
    # Add self-loops for node features propagation
    graphs = [C + th.eye(C.shape[0], device=device, dtype=dtype) for C in graphs]

    if dataset_name not in ['imdb-b', 'imdb-m', 'collab']:
        features = [th.tensor(np.array(X[t].values()), dtype=dtype, device=device)
                    for t in range(X.shape[0])]
    else:
        # No node attributes: derive features from the degrees, which include
        # the self-loops added above.
        features, standardized_features = _degree_features(graphs, features_trans, device, dtype)

    if standardized_features:
        print('stardardizing features')
        _standardize_features(features)
    return graphs, features, labels


# Experiment-name fragments that do not depend on the swept hyperparameters.
str_skip_first_features = 'skipfl' if skip_first_features else ''
str_features_trans = '' if features_trans is None else features_trans

for batch_size in list_batch:
    for dropout in list_dropout:
        for hidden_dim_netgin in list_hidden_dim_netgin:
            for split_trainval_seed in list_split_trainval:
                # supervised_sampler is only disabled for mutag with batch_size 128;
                # such runs are tagged 'random' in the experiment name.
                supervised_sampler = not (dataset_name == 'mutag' and batch_size == 128)
                str_batch = '' if supervised_sampler else 'random'

                th.manual_seed(algo_seed)
                np.random.seed(algo_seed)
                gin_net_dict = {'hidden_dim': hidden_dim_netgin,
                                'num_hidden': num_hidden_netgin}
                gin_layer_dict = {'num_layers': num_layers_gin}
                gnn_str = 'GINL%s%s_MLPh%so%sL%sbatchnorm_' % (
                    num_layers_gin, str_skip_first_features, hidden_dim_netgin,
                    hidden_dim_netgin, num_hidden_netgin)
                clf_net_dict = {'hidden_dim': hidden_dim_clf,
                                'num_hidden': num_hidden_clf,
                                'dropout': dropout}
                clf_str = 'clfMLPh%sL%sdropout%s' % (hidden_dim_clf, num_hidden_clf, dropout)

                experiment_name = '/GINagg%s_%s_%s_%slr%s_%sbatch%s_ep%s_splitseed%s/' % (
                    str_features_trans, gnn_str, clf_str, str_lrschedule, lr, str_batch,
                    batch_size, epochs, split_traintest_seed)
                experiment_repo = res_repo + experiment_name
                model_name = 'model_splitvalseed%s' % split_trainval_seed
                print('experiment_repo:', experiment_repo)
                print('model_name:', model_name)
                train_model = not _experiment_completed(experiment_repo, model_name, epochs)
                print('train model? ', train_model)
                if not train_model:
                    continue

                graphs, features, labels = _load_dataset(
                    data_path, dataset_name, features_trans, device, dtype)

                model = GIN.GIN(
                    input_shape=features[0].shape[-1],
                    n_labels=np.unique(labels).shape[0],
                    experiment_repo=experiment_repo,
                    gin_net_dict=gin_net_dict,
                    gin_layer_dict=gin_layer_dict,
                    clf_net_dict=clf_net_dict,
                    skip_first_features=skip_first_features,
                    dtype=dtype,
                    device=device)

                # Outer stratified split: 10% test (seeded by split_traintest_seed).
                idx_train, idx_test, y_train, y_test = train_test_split(
                    np.arange(len(graphs)), labels, test_size=0.1, stratify=labels,
                    random_state=split_traintest_seed)
                X_train = [graphs[i] for i in idx_train]
                X_test = [graphs[i] for i in idx_test]
                F_train = [features[i] for i in idx_train]
                F_test = [features[i] for i in idx_test]
                # Inner stratified split of the train part: 10% validation.
                idx_subtrain, idx_val, y_subtrain, y_val = train_test_split(
                    np.arange(len(X_train)), y_train, test_size=0.1, stratify=y_train,
                    random_state=split_trainval_seed)
                X_subtrain = [X_train[i] for i in idx_subtrain]
                X_val = [X_train[i] for i in idx_val]
                F_subtrain = [F_train[i] for i in idx_subtrain]
                F_val = [F_train[i] for i in idx_val]

                model.fit(model_name=model_name,
                          X_train=X_subtrain, F_train=F_subtrain,
                          y_train=th.tensor(y_subtrain, device=device, dtype=th.long),
                          X_val=X_val, F_val=F_val,
                          y_val=th.tensor(y_val, device=device, dtype=th.long),
                          X_test=X_test, F_test=F_test,
                          y_test=th.tensor(y_test, device=device, dtype=th.long),
                          lr=lr, batch_size=batch_size,
                          supervised_sampler=supervised_sampler, epochs=epochs,
                          val_timestamp=val_timestamp, use_lrschedule=use_lrschedule,
                          algo_seed=algo_seed, verbose=verbose)