import numpy as np
import torch as th
from data_handler import dataloader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split 
import os
import pandas as pd
from sys import argv
import pylab as pl
from TFGW_raw import FGWmachine
import pickle
from multiprocessing import cpu_count


#os.environ["CUDA_VISIBLE_DEVICES"]=""
# %% parser for GW measure machines experiments


"""
graph_mode: input representation of the graphs, in ['ADJ', 'SP']
list_Katoms: numbers of templates to try
list_alphas: defaults to [-1], which makes alpha a learnable parameter; fixed values must lie strictly between 0 and 1
list_batch: batch sizes to try (validated over {32, 128})
list_split_trainval: seeds defining each fold of the train/validation CV, taken in {0, ..., 9}
learn_hbar: whether to learn the template masses (see subsections 3.2 and 3.3 of the main paper)
list_dropout: dropout ratios to try (see the supplementary material), validated over {0, 0.2, 0.5}
"""
"""
dataset_name = str(argv[1])
assert dataset_name in ['imdb-b', 'imdb-m', 'mutag', 'ptc', 'nci1', 'bzr', 'cox2', 'enzymes', 'protein']
graph_mode =str(argv[2]) # ['ADJ','SP'...]
if '[' in argv[3]: # number of graph templates
    list_Katoms = [int(x) for x in argv[3][1:-1].split(',')] 
else:
    list_Katoms = [int(argv[3])]
if '[' in argv[4]: # alphas for the FGW loss > if alpha = -1, it will be considered as a learnable parameter initialized at 0.5
    list_alphas = [float(x) for x in argv[4][1:-1].split(',')] 
else:
    list_alphas = [float(argv[4])]

if '[' in argv[5]:
    list_batch= [int(x) for x in argv[5][1:-1].split(',') ]
else:
    list_batch= [int(argv[5])]
if '[' in argv[6]: # random seed for initializating the atoms.
    list_split_trainval = [int(x) for x in argv[6][1:-1].split(',')] 
else:
    list_split_trainval = [int(argv[6])]

learn_hbar = bool(int(argv[7]))
if '[' in argv[8]: 
    list_dropout = [float(x) for x in argv[8][1:-1].split(',')] 
else:
    list_dropout = [float(argv[8])]
"""
# %%

# ---------------------------------------------------------------------------
# Hard-coded experiment configuration (replaces the disabled argv parser).
# ---------------------------------------------------------------------------
dataset_name = 'mutag'
graph_mode = 'ADJ'
list_Katoms = [4, 8, 12, 16]
# mutag is an exception: template shapes are chosen so that every class has a
# training sample with the required shape (detailed in the supplementary
# material); every other dataset samples shapes around the median.
init_mode_atoms = ('sampling_supervised_max' if dataset_name in ['mutag']
                   else 'sampling_supervised_median')

# Optimisation / grid-search settings.
lr = 0.01
epochs = 500
list_batch = [128]
list_alphas = [-1]  # -1 means alpha is learned
list_split_trainval = [0]
list_dropout = [0.0]
learn_hbar = True

# Shortest-path inputs use the 'nonnegative' template projection; every other
# representation uses 'clipped'.
atoms_projection = 'nonnegative' if graph_mode == 'SP' else 'clipped'
# Maps a graph_mode key to the `method` name passed to distance_matrix().
str_to_method = {
    'ADJ': 'adjacency',
    'SP': 'shortest_path',
    'LAP': 'laplacian',
    'fullADJ': 'augmented_adjacency',
    'normADJ': 'normalized_adjacency',
}

# Results and datasets live next to the parent of the working directory.
abspath = os.path.abspath('../')
res_repo = '%s/results/%s/' % (abspath, dataset_name)
data_path = '%s/real_datasets/' % abspath

# Reproducibility, logging and runtime settings.
algo_seed = 0
split_traintest_seed = 0
val_timestamp = 5
track_templates = True
verbose = True
n_jobs = 3
device = 'cpu'
dtype = th.float64

# Classifier head (MLP) and optional node-feature transformation.
hidden_dim_clf = 128
num_hidden_clf = 2
features_trans = None

# Name fragments used to build experiment directories. They depend only on the
# configuration above, so they are computed once rather than once per run.
init_graph_str = {'sampling_supervised_median': 'samplabelmed',
                  'sampling_supervised_modes': 'samplabelmod',
                  'sampling_supervised_max': 'samplabelmax'}
proj_str = {'clipped': '', 'nonnegative': '_nonnegative_'}
learn_hbar_str = 'learnhbarbis' if learn_hbar else ''
str_features_trans = '' if features_trans is None else features_trans


def _saved_epochs(log_path):
    """Return how many epochs a pickled training log records, or None.

    None means the log could not be read or lacks the
    'train_cumulated_batch_loss' entry; the caller then (re)trains the run.
    NOTE: pickle.load can execute arbitrary code, so only point this at
    result directories written by this script.
    """
    try:
        with open(log_path, 'rb') as log_file:
            training_log = pickle.load(log_file)
        return len(training_log['train_cumulated_batch_loss'])
    except Exception as err:  # best effort: any unreadable log means "retrain"
        print('could not read training log (%s: %s)' % (type(err).__name__, err))
        return None


def _dataset_feature_flags(dataset_name):
    """Return (one_hot, standardized_features) for a supported benchmark.

    one_hot is forwarded to dataloader.load_local_data; standardized_features
    tells whether node features are standardized after loading.

    Raises:
        ValueError: if dataset_name is not a supported benchmark.
    """
    if dataset_name in ['mutag', 'ptc', 'nci1']:  # one-hot encoding
        return True, False
    if dataset_name in ['enzymes', 'protein', 'bzr', 'cox2']:
        return False, True
    if dataset_name in ['imdb-b', 'imdb-m']:
        return False, False
    raise ValueError('unsupported dataset_name: %s' % dataset_name)


def _remap_labels(labels):
    """Map class labels onto consecutive integers 0..C-1 in sorted label order.

    The result keeps the shape of `labels`; works for numeric or string labels.
    """
    raw_labels = np.asarray(labels)
    _, inverse = np.unique(raw_labels, return_inverse=True)
    return inverse.reshape(raw_labels.shape)


def _structure_matrices(X, graph_mode, dtype):
    """Return one structure-matrix tensor per graph of X for graph_mode.

    'ADJ'/'embADJ' use adjacency matrices, 'SP'/'embSP' shortest-path matrices.

    Raises:
        ValueError: for any other graph_mode.
    """
    if graph_mode in ['ADJ', 'embADJ']:
        method = str_to_method['ADJ']
    elif graph_mode in ['SP', 'embSP']:
        method = str_to_method['SP']
    else:
        raise ValueError('unsupported graph_mode: %s' % graph_mode)
    return [th.tensor(X[t].distance_matrix(method=method), dtype=dtype) for t in range(X.shape[0])]


def _degree_features(graphs, features_trans, dtype, device):
    """Build node features from node degrees (used for the attribute-free IMDB datasets).

    Degrees are column sums of each structure matrix.

    Returns:
        (features, standardized_features): the list of per-graph feature tensors and
        whether they should be standardized afterwards.

    Raises:
        ValueError: for an unknown features_trans.
    """
    if features_trans is None:
        # One-hot encoding of node degrees over the dataset-wide degree range.
        degs = [C.sum(0).to(th.int64) for C in graphs]
        max_deg = max(int(deg.max()) for deg in degs)
        min_deg = min(int(deg.min()) for deg in degs)
        n_bins = max_deg - min_deg + 1
        features = [th.nn.functional.one_hot(deg - min_deg, num_classes=n_bins).to(device=device, dtype=dtype)
                    for deg in degs]
        return features, False
    if features_trans == 'degree':
        return [C.sum(0)[:, None].to(device=device, dtype=dtype) for C in graphs], True
    if features_trans == 'normalizeddegree':
        degs = [C.sum(0).to(device=device, dtype=dtype) for C in graphs]
        return [d[:, None] / d.sum() for d in degs], True
    if features_trans == 'ones':
        return [th.ones((C.shape[0], 1), device=device, dtype=dtype) for C in graphs], False
    raise ValueError('unsupported features_trans: %s' % features_trans)


def _standardize_features(features):
    """Standardize node-feature columns in place.

    Column statistics are the mean and population std (ddof=0), taken over graphs,
    of each graph's average feature vector for the whole list passed in. Columns
    with zero std are only centered, since dividing by 0 would fill them with NaN/inf.
    """
    graph_means = np.vstack([F.mean(dim=0).detach().cpu().numpy() for F in features])
    col_mean = graph_means.mean(axis=0)
    col_std = graph_means.std(axis=0)
    col_std = np.where(col_std > 0, col_std, 1.0)
    for F in features:
        F -= th.as_tensor(col_mean, dtype=F.dtype, device=F.device)
        F /= th.as_tensor(col_std, dtype=F.dtype, device=F.device)


def _template_sizes(init_mode_atoms, dataset_name, shapes, Katoms):
    """Return the list of Katoms template sizes (node counts) for init_mode_atoms.

    Raises:
        ValueError: for unsupported init modes / dataset combinations.
    """
    if 'median' in init_mode_atoms:
        if dataset_name in ['mutag']:
            Satoms = int(np.median(shapes))
        else:
            Satoms = round(np.median(shapes))
        list_Satoms = [Satoms] * Katoms
        print('median atom shapes:', list_Satoms)
        return list_Satoms
    if 'max' in init_mode_atoms:
        if dataset_name in ['mutag']:
            return [16] * Katoms
        raise ValueError('not supported')
    raise ValueError('unknown init_mode_atoms: %s' % init_mode_atoms)


for batch_size in list_batch:
    # mutag with 128-graph batches is the one setting trained without the
    # supervised sampler; such runs are tagged 'random' in their directory name.
    supervised_sampler = not (dataset_name == 'mutag' and batch_size == 128)
    str_batch = '' if supervised_sampler else 'random'
    for dropout in list_dropout:
        clf_str = 'clfMLPh%sL%sdropout%s' % (hidden_dim_clf, num_hidden_clf, dropout)
        for split_trainval_seed in list_split_trainval:
            for alpha in list_alphas:
                for Katoms in list_Katoms:
                    # Every run starts from the same seeds.
                    th.manual_seed(algo_seed)
                    np.random.seed(algo_seed)
                    # Fresh dict per run so no state leaks between runs through FGWmachine.
                    clf_net_dict = {'hidden_dim': hidden_dim_clf,
                                    'num_hidden': num_hidden_clf,
                                    'dropout': dropout}
                    if alpha == -1:  # learnable alpha
                        alpha_str = 'alphalearnable'
                    elif 0. < alpha < 1:
                        alpha_str = 'alpha%s' % alpha
                    else:
                        raise ValueError('alpha= %s / not supported ' % alpha)
                    experiment_name = '/V5%s%s%s_%sK%s%s_%s_%s_lr%s_%sbatch%s_ep%s_splitseed%s/' % (
                        graph_mode, learn_hbar_str, str_features_trans, init_graph_str[init_mode_atoms],
                        Katoms, proj_str[atoms_projection], clf_str, alpha_str, lr, str_batch,
                        batch_size, epochs, split_traintest_seed)
                    experiment_repo = res_repo + experiment_name
                    model_name = 'model_splitvalseed%s' % split_trainval_seed
                    print('experiment_repo:', experiment_repo)
                    print('model_name:', model_name)

                    train_model = True
                    if os.path.exists(experiment_repo):
                        print('already existing repository')
                        saved_epochs = _saved_epochs(experiment_repo + '/%s_training_log.pkl' % model_name)
                        if saved_epochs is not None:
                            print('training log found / saved epochs:', saved_epochs)
                            # Only skip runs whose log already covers every epoch.
                            train_model = saved_epochs != epochs
                    print('train model? ', train_model)
                    if not train_model:
                        continue

                    # Load graphs, node features and labels for the classification benchmark.
                    one_hot, standardized_features = _dataset_feature_flags(dataset_name)
                    X, labels = dataloader.load_local_data(data_path, dataset_name, one_hot=one_hot)
                    labels = _remap_labels(labels)
                    unique_labels = np.unique(labels)
                    graphs = _structure_matrices(X, graph_mode, dtype)
                    if dataset_name in ['imdb-b', 'imdb-m']:
                        features, standardized_features = _degree_features(graphs, features_trans, dtype, device)
                    else:
                        features = [th.tensor(np.array(X[t].values()), device=device, dtype=dtype)
                                    for t in range(X.shape[0])]
                    if standardized_features:
                        print('stardardizing features')
                        _standardize_features(features)

                    masses = [th.ones(C.shape[0], dtype=dtype, device=device) / C.shape[0] for C in graphs]
                    shapes = [C.shape[0] for C in graphs]
                    dataset_size = len(graphs)

                    # Stratified 90/10 train/test split, then a stratified 90/10
                    # sub-train/validation split of the training part.
                    idx_train, idx_test, y_train, y_test = train_test_split(
                        np.arange(len(graphs)), labels, test_size=0.1, stratify=labels,
                        random_state=split_traintest_seed)
                    X_train, X_test = [graphs[idx] for idx in idx_train], [graphs[idx] for idx in idx_test]
                    F_train, F_test = [features[idx] for idx in idx_train], [features[idx] for idx in idx_test]
                    idx_subtrain, idx_val, y_subtrain, y_val = train_test_split(
                        np.arange(len(X_train)), y_train, test_size=0.1, stratify=y_train,
                        random_state=split_trainval_seed)
                    X_subtrain, X_val = [X_train[idx] for idx in idx_subtrain], [X_train[idx] for idx in idx_val]
                    F_subtrain, F_val = [F_train[idx] for idx in idx_subtrain], [F_train[idx] for idx in idx_val]

                    list_Satoms = _template_sizes(init_mode_atoms, dataset_name, shapes, Katoms)

                    # Instantiate the FGW measure machine.
                    model = FGWmachine(
                        Katoms=Katoms,
                        n_labels=unique_labels.shape[0],
                        alpha=alpha,
                        experiment_repo=experiment_repo,
                        learn_hbar=learn_hbar,
                        clf_net_dict=clf_net_dict,
                        dtype=dtype,
                        device=device)

                    if 'sampling_supervised' in init_mode_atoms:
                        model.init_parameters(list_Satoms=list_Satoms, init_mode_atoms='sampling_supervised',
                                              graphs=X_subtrain, features=F_subtrain,
                                              labels=th.tensor(y_subtrain, dtype=th.long),
                                              atoms_projection=atoms_projection, verbose=verbose)
                    else:
                        raise ValueError('init_mode_atoms =%s / not supported' % init_mode_atoms)

                    model.fit(model_name=model_name, X_train=X_subtrain, F_train=F_subtrain,
                              y_train=th.tensor(y_subtrain, dtype=th.long),
                              X_val=X_val, F_val=F_val, y_val=th.tensor(y_val, dtype=th.long),
                              X_test=X_test, F_test=F_test, y_test=th.tensor(y_test, dtype=th.long),
                              atoms_projection=atoms_projection, lr=lr, batch_size=batch_size,
                              supervised_sampler=supervised_sampler, epochs=epochs, val_timestamp=val_timestamp,
                              algo_seed=algo_seed, track_templates=track_templates, verbose=verbose, n_jobs=n_jobs)
                  