import numpy as np
import torch as th
from data_handler import dataloader
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split 
import os
import pandas as pd
from sys import argv
from TFGW_GIN import GIN_FGWmachine
import pickle
from multiprocessing import cpu_count
#os.environ["CUDA_VISIBLE_DEVICES"]=""
# %% parser for TFGW GIN experiments
"""
graph_mode: input representation of the graphs ['ADJ','SP']
list_Katoms: number of templates
list_alphas: Default to [-1] for learning the alphas, other float values can be given between 0 and 1 
list_batch: batch size to consider for the experiment validated in {32,128}
list_split_trainval: list of seeds corresponding to each fold of the CV taken in {0, ..., 9}
list_hidden_dim_netgin: list of hidden dimension to consider in the GIN layers, validated in {16, 32, 64} depending on datasets as described in the supplementary
list_dropout: validated values for the dropout ratios (see supplementary), taking values in {0, 0.2, 0.5}
num_layers_gin: number of GIN layers to use for node features preprocessing. (L=2 by default)
force_cpu : force usage of cpu if set to 1 else if set to 0 will detect gpu
cpu_njobs: number of cpus to parallelize the computation of FGW distances, default is cpu_count() - 1
features_trans: default is None, i.e. operate on raw features. features_trans in ['degree', 'normalizeddegree', 'ones'] can be used for exploration on social network datasets
"""

"""
dataset_name = str(argv[1])
assert dataset_name in ['imdb-b', 'imdb-m', 'mutag', 'ptc', 'nci1', 'bzr', 'cox2', 'enzymes', 'protein']
graph_mode =str(argv[2]) # ['ADJ','SP'...]
if '[' in argv[3]: # number of graph templates
    list_Katoms = [int(x) for x in argv[3][1:-1].split(',')] 
else:
    list_Katoms = [int(argv[3])]
print('list_Katoms:', list_Katoms)
if '[' in argv[4]: # alphas for the FGW loss > if alpha = -1, it will be considered as a learnable parameter initialized at 0.5
    list_alphas = [float(x) for x in argv[4][1:-1].split(',')] 
else:
    list_alphas = [float(argv[4])]
print('list_alphas:', list_alphas)

if '[' in argv[5]:
    list_batch= [int(x) for x in argv[5][1:-1].split(',') ]
else:
    list_batch= [int(argv[5])]
print('list_batch:', list_batch)
if '[' in argv[6]: # seeds of the train/validation splits (one per fold of the CV).
    list_split_trainval = [int(x) for x in argv[6][1:-1].split(',')] 
else:
    list_split_trainval = [int(argv[6])]
print('list_split_trainval:', list_split_trainval)
if '[' in argv[7]: # hidden dimensions of the GIN layers.
    list_hidden_dim_netgin  = [int(x) for x in argv[7][1:-1].split(',')] 
else:
    list_hidden_dim_netgin  = [int(argv[7])]

if '[' in argv[8]:
    list_dropout = [float(x) for x in argv[8][1:-1].split(',') ]
else:
    list_dropout = [float(argv[8])]

num_layers_gin = int(argv[9])
force_cpu = bool(int(argv[10]))
try:
    cpu_njobs = int(argv[11])
except:
    cpu_njobs = None
try:
    features_trans = str(argv[12])
except:
    features_trans  = None
"""
# %%

# Hard-coded experiment configuration (stands in for the disabled CLI parser).
dataset_name = 'mutag'
graph_mode = 'ADJ'
list_Katoms = [4]
lr = 0.01

# Hidden dimensions of the GIN layers explored for each family of datasets.
if dataset_name in ['mutag', 'ptc', 'protein', 'nci1', 'enzymes']:
    list_hidden_dim_netgin = [16, 32]
elif dataset_name in ['imdb-b', 'imdb-m', 'collab']:
    list_hidden_dim_netgin = [64]
else:
    # Fail here with a clear message rather than with a NameError when the
    # grid loop first reads `list_hidden_dim_netgin`.
    raise ValueError(
        'no GIN hidden dimensions configured for dataset %s' % dataset_name)

list_batch = [128]
list_alphas = [-1]  # -1: alpha is learned; the experiment loop also accepts values in [0, 1]
epochs = 500

# One seed per train/validation split.
list_split_trainval = list(range(10))

if dataset_name in ['mutag']:
    # Exception to initialize the samples with shapes the closest possible to
    # the median, ensuring that for each class there exists in the train
    # dataset a sample with the required shape for templates
    # (this detail is developed in the supplementary material).
    init_mode_atoms = 'sampling_supervised_max'
else:
    init_mode_atoms = 'sampling_supervised_median'

list_dropout = [0.0]
skip_first_features = False
num_layers_gin = 2
learn_hbar = True

# Device selection: `force_cpu` hides every GPU from CUDA, otherwise the first
# GPU is used when one is available.
force_cpu = True
if force_cpu:
    # NOTE(review): this is set after `torch` has been imported; presumably it
    # still takes effect because CUDA has not been initialized yet - confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    device = 'cpu'
elif th.cuda.is_available():
    device = th.device('cuda:0')
else:
    device = 'cpu'

# Leave one core free, but never request zero workers on a single-core machine.
n_jobs = max(1, cpu_count() - 1)

# Projection applied to the templates: non-negative for shortest-path inputs,
# clipped for every other graph representation.
atoms_projection = 'nonnegative' if graph_mode == 'SP' else 'clipped'

# Graph-mode keyword -> `method` argument of the graphs' `distance_matrix`.
str_to_method = dict(
    ADJ='adjacency',
    SP='shortest_path',
    LAP='laplacian',
    fullADJ='augmented_adjacency',
    normADJ='normalized_adjacency',
)

# Results and datasets live next to the parent of the working directory.
abspath = os.path.abspath('../')
data_path = abspath + '/real_datasets/'
res_repo = abspath + '/results/%s/' % dataset_name

# Seeds and training bookkeeping.
algo_seed = 0
split_traintest_seed = 0
val_timestamp = 5
track_templates = True
verbose = True

# No feature transformation by default (raw node features).
features_trans = None

# Sizes of the classifier MLP and of the MLP inside each GIN layer.
hidden_dim_clf = 128
num_hidden_clf = 2
num_hidden_netgin = 1

# All tensors are created in double precision.
dtype = th.float64

print('device: %s / cpu_count:%s'%(device, n_jobs))

def relabel_to_contiguous(labels):
    """Return a copy of `labels` whose classes are renamed to 0..C-1.

    Classes are numbered in the sorted order given by `np.unique`; the result
    keeps the dtype of `np.array(labels)`.
    """
    raw_labels = np.array(labels)
    contiguous = np.zeros_like(raw_labels)
    for class_id, class_value in enumerate(np.unique(raw_labels)):
        contiguous[raw_labels == class_value] = class_id
    return contiguous


def one_hot_degree_features(degs, device, dtype):
    """One-hot encode integer node degrees, one feature matrix per graph.

    `degs` is a list of 1-D integer tensors (one entry per node). Each returned
    matrix has `max_deg - min_deg + 1` columns, where the extrema are taken
    over all graphs; column j flags the nodes whose degree is `min_deg + j`.
    """
    max_deg = max(int(deg.max()) for deg in degs)
    min_deg = min(int(deg.min()) for deg in degs)
    n_bins = max_deg - min_deg + 1
    encoded = []
    for deg in degs:
        n_nodes = deg.shape[0]
        F = th.zeros((n_nodes, n_bins), device=device, dtype=dtype)
        rows = th.arange(n_nodes, device=F.device)
        cols = (deg - min_deg).to(F.device)
        F[rows, cols] = 1.
        encoded.append(F)
    return encoded


def standardize_features_inplace(features):
    """Standardize every feature column of every graph, in place.

    The centering and scaling statistics of a column are the mean and the
    standard deviation of the per-graph mean feature vectors.
    """
    print('stardardizing features')
    # One row per graph; `vstack` keeps this 2-D even with a single graph.
    graph_means = np.vstack([F.mean(axis=0).cpu().numpy() for F in features])
    print('before norm: means F[0] = ', graph_means[0])
    for col in range(graph_means.shape[1]):
        mean_ = float(graph_means[:, col].mean())
        std_ = float(graph_means[:, col].std())
        if std_ == 0.:
            # Constant column: only center it instead of producing NaN/inf.
            std_ = 1.
        for F in features:
            F[:, col] = (F[:, col] - mean_) / std_
    print('after norm: means F[0] = ', features[0].mean(axis=0))


# Lookup tables used to build the experiment directory name (loop invariant).
init_graph_str = {'sampling_supervised_median':'samplabelmed','sampling_supervised_modes':'samplabelmod',
                'sampling_supervised_max':'samplabelmax'}
proj_str  = {'clipped':'', 'nonnegative':'_nonnegative_'}

# Grid over every hyper-parameter combination; each combination trains one
# model unless a complete training log is already stored for it.
for batch_size in list_batch:
    for dropout in list_dropout:
        for hidden_dim_netgin in list_hidden_dim_netgin:
            for split_trainval_seed in list_split_trainval:
                for alpha in list_alphas:
                    for Katoms in list_Katoms:
                        # MUTAG with batches of 128 uses the non-supervised
                        # ('random') batch sampler.
                        supervised_sampler = not ((dataset_name == 'mutag') and (batch_size == 128))
                        str_batch = '' if supervised_sampler else 'random'
                        str_skip_first_features = 'skipfl' if skip_first_features else ''
                        th.manual_seed(algo_seed)
                        np.random.seed(algo_seed)

                        gin_net_dict = {'hidden_dim':hidden_dim_netgin,
                                        'num_hidden':num_hidden_netgin}
                        gin_layer_dict = {'num_layers':num_layers_gin}
                        gnn_str = 'GINL%s%s_MLPh%so%sL%sbatchnorm_'%(num_layers_gin, str_skip_first_features , hidden_dim_netgin, hidden_dim_netgin, num_hidden_netgin)
                        clf_net_dict = {'hidden_dim':hidden_dim_clf,
                                        'num_hidden':num_hidden_clf,
                                        'dropout':dropout}
                        clf_str = 'clfMLPh%sL%sdropout%s'%(hidden_dim_clf, num_hidden_clf, dropout)
                        learn_hbar_str = 'learnhbarbis' if learn_hbar else ''

                        # alpha in [0, 1] is a fixed trade-off, alpha == -1 a learned one.
                        if 0. <= alpha <= 1.:
                            experiment_name = '/V2%s%s_%s_aggregfdistK%s%s%s_alpha%s_%s_lr%s_%sbatch%s_ep%s_splitseed%s/'%(
                                graph_mode, learn_hbar_str, gnn_str, Katoms, proj_str[atoms_projection],
                                init_graph_str[init_mode_atoms], alpha,
                                clf_str, lr, str_batch, batch_size, epochs, split_traintest_seed)
                        elif alpha == -1:
                            experiment_name = '/V2%s%s_%s_aggregfdistK%s%s%s_alphalearnable_%s_lr%s_%sbatch%s_ep%s_splitseed%s/'%(
                                graph_mode, learn_hbar_str, gnn_str, Katoms, proj_str[atoms_projection],
                                init_graph_str[init_mode_atoms], clf_str, lr, str_batch, batch_size, epochs, split_traintest_seed)
                        else:
                            raise ValueError('alpha= %s / not supported yet' % alpha)

                        experiment_repo = res_repo + experiment_name
                        model_name = 'model_splitvalseed%s'%split_trainval_seed
                        print('experiment_repo:', experiment_repo)
                        print('model_name:', model_name)

                        # Skip the run only when a stored log covers all the epochs.
                        train_model = True
                        if os.path.exists(experiment_repo):
                            print('already existing repository')
                            try:
                                with open(experiment_repo+'/%s_training_log.pkl'%model_name, 'rb') as log_file:
                                    training_log = pickle.load(log_file)
                                n_saved_epochs = len(training_log['train_cumulated_batch_loss'])
                                print('training log found / saved epochs:', n_saved_epochs)
                                if n_saved_epochs == epochs:
                                    print('experiment already completed')
                                    train_model = False
                            except Exception:
                                # Missing, truncated or unreadable log: train again.
                                train_model = True
                        print('train model? ', train_model)
                        if not train_model:
                            continue

                        # Dataset-specific handling of the node features.
                        if dataset_name in ['mutag', 'ptc', 'nci1']:  # One-hot encoding
                            one_hot = True
                            standardized_features = False
                        elif dataset_name in ['enzymes', 'protein', 'bzr', 'cox2']:
                            one_hot = False
                            standardized_features = True
                        elif dataset_name in ['imdb-b', 'imdb-m']:
                            one_hot = False
                            standardized_features = False
                        else:
                            raise ValueError('dataset %s is not supported' % dataset_name)

                        X,labels=dataloader.load_local_data(data_path,dataset_name, one_hot=one_hot)
                        # Fit the labels to [0, C-1].
                        labels = relabel_to_contiguous(labels)
                        unique_labels = np.unique(labels)

                        if graph_mode == 'ADJ':
                            graphs = [th.tensor(X[t].distance_matrix(method=str_to_method[graph_mode]), device=device, dtype=dtype) for t in range(X.shape[0])]
                            # Add self-loops for node features propagation
                            graphs = [C + th.eye(C.shape[0], device=device, dtype=dtype) for C in graphs]
                            # Self-loops are then omitted in the optimization of FGW templates
                            adjacency_graphs = graphs
                        elif graph_mode == 'SP':
                            # join both to be able to pass the adjacency matrix for GNN Message Passing
                            adj_graphs = [th.tensor(X[t].distance_matrix(method=str_to_method['ADJ']), device='cpu', dtype=dtype) for t in range(X.shape[0])]
                            adj_graphs = [C + th.eye(C.shape[0], device='cpu', dtype=dtype) for C in adj_graphs]
                            # `.to` is a no-op when the target device is already the CPU.
                            adj_graphs = [C.to(device=device) for C in adj_graphs]
                            sp_graphs = [th.tensor(X[t].distance_matrix(method=str_to_method['SP']),device=device, dtype=dtype) for t in range(X.shape[0])]
                            graphs = list(zip(sp_graphs, adj_graphs))
                            adjacency_graphs = adj_graphs
                        else:
                            raise ValueError('graph_mode %s is not supported' % graph_mode)
                        # `masses` is not read again in this part of the script.
                        masses = [th.ones(C.shape[0], dtype=dtype, device=device)/C.shape[0] for C in adjacency_graphs]
                        shapes = [C.shape[0] for C in adjacency_graphs]

                        if dataset_name not in ['imdb-b', 'imdb-m']:
                            features= [th.tensor(np.array(X[t].values()), dtype=dtype, device=device) for t in range(X.shape[0])]
                        elif features_trans is None: # one-hot encoding of node degrees
                            degs = [C.sum(0).to(th.int64) for C in adjacency_graphs]
                            features = one_hot_degree_features(degs, device, dtype)
                            standardized_features = False
                        elif features_trans == 'degree':
                            features = [C.sum(0)[:, None].to(device=device, dtype=dtype) for C in adjacency_graphs]
                            standardized_features = True
                        elif features_trans == 'normalizeddegree':
                            degs = [C.sum(0).to(device=device, dtype=dtype) for C in adjacency_graphs]
                            features = [d[:,None]/d.sum() for d in degs]
                            standardized_features = True
                        elif features_trans == 'ones':
                            features = [th.ones((C.shape[0], 1), device=device, dtype=dtype) for C in adjacency_graphs]
                            standardized_features = False
                        else:
                            raise ValueError('features_trans %s is not supported' % features_trans)

                        if standardized_features:
                            standardize_features_inplace(features)

                        dataset_size = len(graphs)
                        input_shape = features[0].shape[-1]
                        # Instantiate the class for GW measure machine
                        model = GIN_FGWmachine(
                            graph_mode = graph_mode,
                            input_shape = input_shape,
                            Katoms = Katoms,
                            n_labels = unique_labels.shape[0],
                            alpha = alpha,
                            learn_hbar = learn_hbar,
                            experiment_repo = experiment_repo,
                            gin_net_dict = gin_net_dict,
                            gin_layer_dict = gin_layer_dict,
                            clf_net_dict = clf_net_dict,
                            skip_first_features =skip_first_features ,
                            dtype = dtype,
                            device = device)

                        # Stratified 90/10 train/test split, then 90/10 subtrain/validation split.
                        idx_train, idx_test, y_train, y_test = train_test_split(np.arange(dataset_size), labels, test_size=0.1, stratify=labels, random_state=split_traintest_seed)
                        X_train, X_test = [graphs[idx] for idx in idx_train], [graphs[idx] for idx in idx_test]
                        F_train, F_test = [features[idx] for idx in idx_train], [features[idx] for idx in idx_test]
                        idx_subtrain, idx_val, y_subtrain, y_val = train_test_split(np.arange(len(X_train)),y_train, test_size=0.1, stratify=y_train, random_state=split_trainval_seed)
                        X_subtrain, X_val = [X_train[idx] for idx in idx_subtrain], [X_train[idx] for idx in idx_val]
                        F_subtrain, F_val = [F_train[idx] for idx in idx_subtrain], [F_train[idx] for idx in idx_val]

                        # Number of nodes of each of the Katoms templates.
                        if ('median' in init_mode_atoms) or ('GDL' in init_mode_atoms):
                            if dataset_name in ['mutag']:
                                Satoms = int(np.median(shapes))
                            else:
                                Satoms = round(np.median(shapes))
                            list_Satoms = [Satoms] * Katoms
                            print('median atom shapes:', list_Satoms)
                        elif 'max' in init_mode_atoms:
                            if dataset_name in ['mutag']:
                                list_Satoms = [16] * Katoms
                            else:
                                raise ValueError('max only supported for mutag')
                        elif 'mode' in init_mode_atoms:
                            raise NotImplementedError('atom shape by modes - not implemented yet')
                        else:
                            raise ValueError('init_mode_atoms %s is not supported' % init_mode_atoms)

                        # In 'SP' mode each sample is a (shortest-path, adjacency) pair and
                        # the templates are initialized from the shortest-path matrices.
                        if graph_mode == 'ADJ':
                            init_graphs = X_subtrain
                        else:
                            init_graphs = [C[0] for C in X_subtrain]
                        model.init_parameters_with_aggregation(
                            list_Satoms=list_Satoms, init_mode_atoms=init_mode_atoms, labels=th.tensor(y_subtrain, device=device, dtype=th.long),
                            graphs=init_graphs, features= F_subtrain, atoms_projection=atoms_projection)

                        model.fit(model_name=model_name, X_train=X_subtrain, F_train=F_subtrain, y_train=th.tensor(y_subtrain, device=device, dtype=th.long),
                                  X_val=X_val, F_val=F_val, y_val=th.tensor(y_val, device=device, dtype=th.long),
                                  X_test=X_test, F_test=F_test, y_test=th.tensor(y_test, device=device, dtype=th.long),
                                  atoms_projection=atoms_projection, lr=lr, batch_size=batch_size,
                                  supervised_sampler=supervised_sampler, epochs=epochs, val_timestamp=val_timestamp,
                                  algo_seed=algo_seed, track_templates=track_templates, verbose=verbose, n_jobs=n_jobs)
                         