import numpy as np
import torch as th
import os
from tqdm import tqdm
import pickle
import utils_networks
from sklearn.metrics import accuracy_score
from torch_geometric.nn import GINConv
from torch_geometric.data import Data
#%%

    

class GIN():
    
    def __init__(self,
                 input_shape:int,
                 n_labels:int, 
                 experiment_repo:str, 
                 gin_net_dict:dict,
                 gin_layer_dict:dict,
                 clf_net_dict:dict,
                 skip_first_features:bool=False,
                 dtype = th.float64,
                 device='cpu'):
        """
        Implementation of the GIN model 
        from https://openreview.net/forum?id=ryGs6iA5Km

        Builds a stack of GINConv layers (each wrapping a batch-normalised MLP)
        followed by an MLP classifier that consumes the concatenation of the
        outputs of every GIN layer (optionally preceded by the raw features).

        Parameters
        ----------
        input_shape : int
            dimension of the features in the input.
        n_labels : int
            number of classes in the dataset.
        experiment_repo : str
            repository to save the experiment during training.
            It is created if it does not exist yet.
        gin_net_dict : dict
            Dictionary containing parameters for the global architecture of GIN.
            Must contain the keys:
                'hidden_dim' : dimension of the hidden layer validated in {16,32,64} depending on datasets.
                                The output dimension of the layer is the same as the hidden dimension.
                'num_hidden' : the number of hidden layers in each MLP.                                 
        gin_layer_dict : dict
            Dictionary containing parameters for the GIN layers. The parameters (eps, train_eps) are fixed to (0, False) as suggested by authors.
            Must contain the key
                'num_layers' : number of GIN layers in the architecture (strictly positive integer)
        clf_net_dict : dict
            Dictionary containing parameters for the MLP leading to label prediction
            Must contain the keys
                'hidden_dim' :(int) dimension of the hidden layer (fixed to 128)
                'num_hidden' :(int) number of hidden layers in the architecture
                'dropout' :(float) dropout rate to use. When equal to 0 a plain MLP
                           (without dropout) is instantiated.
        skip_first_features : bool, optional
            Either to skip the input features or not in the concatenation of all GIN layer outputs. (see Jumping Knowledge Networks)
            The default is False.
        dtype : torch.dtype, optional
            dtype forwarded to the constructors of every MLP. The default is th.float64.
        device : str, optional
            device forwarded to the constructors of every MLP; the GIN layers and
            the loss are also moved to it. The default is 'cpu'.

        Raises
        ------
        AssertionError
            If a required key is missing from one of the dictionaries or if
            'num_layers' is not strictly positive.
        """
        assert all(key in gin_net_dict for key in ('hidden_dim', 'num_hidden'))
        assert all(key in gin_layer_dict for key in ('num_layers',))
        assert all(key in clf_net_dict for key in ('hidden_dim', 'num_hidden', 'dropout'))
        assert gin_layer_dict['num_layers'] > 0
        self.input_shape = input_shape
        self.n_labels = n_labels
        self.device = device
        self.gin_net_dict = gin_net_dict
        self.gin_layer_dict = gin_layer_dict
        # If True, the raw input features are left out of the concatenated node embeddings.
        self.skip_first_features = skip_first_features
        self.classification_metrics = {'accuracy': accuracy_score}

        self.experiment_repo = experiment_repo
        # exist_ok avoids the check-then-create race when several runs share the repository.
        os.makedirs(self.experiment_repo, exist_ok=True)

        self.dtype = dtype
        # Instantiate network for GIN embeddings
        self.GIN_layers = th.nn.ModuleList()
        hidden_dim = gin_net_dict['hidden_dim']
        # The classifier input is the concatenation of (optionally) the raw
        # features and the output of every GIN layer.
        self.clf_input_shape = 0 if self.skip_first_features else self.input_shape
        for layer in range(gin_layer_dict['num_layers']):
            # Only the first layer reads the raw features.
            local_input_shape = self.input_shape if layer == 0 else hidden_dim
            MLP = utils_networks.MLP_batchnorm(
                local_input_shape,
                hidden_dim,  # output_dim = hidden_dim
                hidden_dim,
                gin_net_dict['num_hidden'], 'relu', device=self.device, dtype=self.dtype)
            GIN_layer = GINConv(MLP, eps=0., train_eps=False).to(self.device)
            self.GIN_layers.append(GIN_layer)
            self.clf_input_shape += hidden_dim

        print('clf_input_shape with aggregation:', self.clf_input_shape)
        # Instantiate network for classification
        if clf_net_dict['dropout'] != 0.:
            self.clf_Net = utils_networks.MLP_dropout(
                input_dim = self.clf_input_shape, 
                output_dim = self.n_labels,
                hidden_dim = clf_net_dict['hidden_dim'],
                num_hidden = clf_net_dict['num_hidden'],
                output_activation= 'linear', 
                dropout = clf_net_dict['dropout'],
                skip_first = True,
                device=self.device, dtype=self.dtype)
        else:
            self.clf_Net = utils_networks.MLP(
                self.clf_input_shape, 
                self.n_labels,
                clf_net_dict['hidden_dim'],
                clf_net_dict['num_hidden'],
                'linear', device=self.device, dtype=self.dtype)

        self.loss = th.nn.CrossEntropyLoss().to(self.device)
        # Flat list of every trainable parameter, handed to the optimizer in fit.
        self.params = list(self.GIN_layers.parameters()) + list(self.clf_Net.parameters())

        print('GIN_layers:', self.GIN_layers.parameters)
        print('Clf:', self.clf_Net.parameters)
    
    
    
    def set_model_to_train(self):
        self.GIN_layers.train()
        self.clf_Net.train()
    
    def set_model_to_eval(self):
        self.GIN_layers.eval()
        self.clf_Net.eval()
    
    def GIN_forward(self, batch_graphs:list, batch_features:list, batch_shapes:th.Tensor, cumsum_shapes:th.Tensor):
        processed_batch_graphs = th.block_diag(*[C for C in batch_graphs])                 
        batch_edge_index = th.argwhere(processed_batch_graphs == 1.).T
        processed_batch_features = th.cat([F for F in batch_features])
        
        layered_batch_features = []
        if not self.skip_first_features:
            layered_batch_features.append(processed_batch_features)
    
        batch_embedded_features = self.GIN_layers[0](x=processed_batch_features, edge_index=batch_edge_index)
        
        layered_batch_features.append(batch_embedded_features)
            
        for layer in range(1, self.gin_layer_dict['num_layers']):
            batch_embedded_features = self.GIN_layers[layer](x=batch_embedded_features, edge_index=batch_edge_index)
            layered_batch_features.append(batch_embedded_features)
        # for a given node, we concatenate embeddings generated at each k GIN layer inducing a k-hop smoothing
        batch_embedded_features = th.cat(layered_batch_features, dim = 1) 
        batch_embedded_features_uncat =  [batch_embedded_features[cumsum_shapes[k] : cumsum_shapes[k + 1], :] for k in range(len(batch_shapes))]
                               
        # global pooling scheme = sum
        batch_embedded_graphs = [F.sum(0) for F in batch_embedded_features_uncat]
        batch_embedded_graphs = th.stack(batch_embedded_graphs, dim=0)
        return batch_embedded_graphs
        
    def fit(self, 
            model_name:str, 
            X_train:list, F_train:list, y_train:list, 
            X_val:list, F_val:list, y_val:list, 
            X_test:list, F_test:list, y_test:list,
            lr:float, batch_size:int, supervised_sampler:bool, 
            epochs:int, val_timestamp:int, algo_seed:int, use_lrschedule:bool, decay_step_size:int=50, verbose:bool=False):
        """
        Train the GIN layers and the classifier with Adam on random mini-batches.

        Every `val_timestamp` epochs (except epoch 0) and at the last epoch the
        model is evaluated, the training log is pickled to
        '<experiment_repo>/<model_name>_training_log.pkl' and, when the
        selection criterion improves, the weights are pickled as well:

        * with validation data: '<model_name>_best_val_accuracy_increasing_train_accuracy.pkl'
          is written when the validation accuracy is at least the best one seen
          AND the train accuracy is at least the best train accuracy recorded
          at a previous save;
        * without validation data: '<model_name>_best_train_accuracy.pkl' is
          written when the train accuracy improves, ties being broken by a
          strictly lower train loss.

        Parameters
        ----------
        model_name : str
            Prefix of every file written in the experiment repository.
        X_train, F_train, y_train :
            Training graphs, node features and labels. `y_train` is indexed
            with an integer array and passed to the cross-entropy loss.
        X_val, F_val, y_val :
            Validation split. Validation is disabled when `X_val` is None.
        X_test, F_test, y_test :
            Test split, scored at the same epochs as the validation split so
            that epochs can be validated across runs without storing every
            model. Ignored when `X_val` or `X_test` is None.
        lr : float
            Learning rate of Adam.
        batch_size : int
            Number of graphs per batch (clipped to the number of training graphs).
        supervised_sampler : bool
            If True every batch contains batch_size // n_labels graphs of each
            label present in `y_train`; otherwise graphs are drawn uniformly
            without replacement.
        epochs : int
            Number of epochs.
        val_timestamp : int
            Evaluation period, in epochs.
        algo_seed : int
            Seed given to both torch and numpy.
        use_lrschedule : bool
            If True the learning rate is halved every `decay_step_size` epochs.
        decay_step_size : int, optional
            Period of the learning-rate decay. The default is 50.
        verbose : bool, optional
            Print evaluation results / checkpoint events. The default is False.

        Raises
        ------
        ValueError
            If `supervised_sampler` is set and `batch_size` is smaller than the
            number of distinct labels (every batch would be empty).
        """
        th.manual_seed(algo_seed)
        np.random.seed(algo_seed)

        def dump_pickle(obj, path):
            # Context manager: the handle is closed even if pickling fails.
            with open(path, 'wb') as handle:
                pickle.dump(obj, handle)

        def evaluate(graphs, features, labels):
            # Full-batch evaluation on CPU, mini-batch evaluation otherwise.
            if self.device == 'cpu':
                return self.evaluate_fullbatch(graphs, features, labels)
            return self.evaluate_minibatch(graphs, features, labels, batch_size)

        n_train = y_train.shape[0]
        do_validation = X_val is not None
        do_test = do_validation and (X_test is not None)
        if do_validation:
            sets = ['train', 'val', 'test']
            best_val_acc = -np.inf
            best_val_acc_train_acc = -np.inf
        else:
            sets = ['train']
            best_train_acc = -np.inf
            best_train_loss = np.inf
        self.log = {'train_cumulated_batch_loss': []}
        for metric in list(self.classification_metrics) + ['epoch_loss']:
            for s in sets:
                self.log['%s_%s' % (s, metric)] = []

        self.optimizer = th.optim.Adam(params=self.params, lr=lr, betas=(0.9, 0.99))
        if use_lrschedule:
            self.scheduler = th.optim.lr_scheduler.StepLR(self.optimizer, step_size=decay_step_size, gamma=0.5)

        if n_train <= batch_size:
            batch_by_epoch = 1
            batch_size = n_train
            print('batch size bigger than #samples > batch_size set to ', batch_size)
        else:
            batch_by_epoch = n_train // batch_size + 1

        if supervised_sampler:
            y_train_ = y_train.detach().cpu()
            unique_labels = th.unique(y_train_)
            n_labels = unique_labels.shape[0]
            # Entry i holds the training indices whose label is unique_labels[i].
            train_idx_by_labels = [th.where(y_train_ == label)[0].numpy() for label in unique_labels]
            labels_by_batch = batch_size // n_labels
            if labels_by_batch < 1:
                raise ValueError(
                    'batch_size (%s) must be at least the number of labels (%s) '
                    'when supervised_sampler is True' % (batch_size, n_labels))

        log_file = '%s/%s_training_log.pkl' % (self.experiment_repo, model_name)

        self.set_model_to_train()
        for e in tqdm(range(epochs), desc='epochs'):

            cumulated_batch_loss = 0.
            for _ in range(batch_by_epoch):

                self.optimizer.zero_grad()
                if supervised_sampler:
                    # Visit the label groups by POSITION in a random order, so
                    # label values need not be 0..K-1.
                    batch_idx = np.concatenate([
                        np.random.choice(train_idx_by_labels[pos], size=labels_by_batch, replace=False)
                        for pos in np.random.permutation(n_labels)])
                else:
                    batch_idx = np.random.choice(n_train, size=batch_size, replace=False)

                batch_graphs = [X_train[idx] for idx in batch_idx]
                batch_features = [F_train[idx] for idx in batch_idx]
                batch_shapes = [C.shape[0] for C in batch_graphs]
                cumsum_shapes = th.tensor([0] + batch_shapes).cumsum(dim=0)

                batch_embedded_graphs = self.GIN_forward(batch_graphs, batch_features, batch_shapes, cumsum_shapes)
                batch_pred = self.clf_Net(batch_embedded_graphs)
                batch_loss = self.loss(batch_pred, y_train[batch_idx])

                cumulated_batch_loss += batch_loss.item()

                batch_loss.backward()
                self.optimizer.step()

            self.log['train_cumulated_batch_loss'].append(cumulated_batch_loss)
            if use_lrschedule:
                self.scheduler.step()

            if not ((((e % val_timestamp) == 0) and e > 0) or (e == (epochs - 1))):
                continue

            with th.no_grad():
                self.set_model_to_eval()
                # The train split is scored once, whatever the device.
                _, _, _, loss_train, res_train = evaluate(X_train, F_train, y_train)
                train_loss = loss_train.item()
                train_acc = res_train['accuracy']
                self.log['train_epoch_loss'].append(train_loss)
                for metric in self.classification_metrics:
                    self.log['train_%s' % metric].append(res_train[metric])

                if do_validation:
                    _, _, _, loss_val, res_val = evaluate(X_val, F_val, y_val)
                    self.log['val_epoch_loss'].append(loss_val.item())
                    for metric in self.classification_metrics:
                        self.log['val_%s' % metric].append(res_val[metric])

                    if do_test:
                        # Record test results next to the validation ones.
                        _, _, _, loss_test, res_test = evaluate(X_test, F_test, y_test)
                        self.log['test_epoch_loss'].append(loss_test.item())
                        for metric in self.classification_metrics:
                            self.log['test_%s' % metric].append(res_test[metric])

                    if verbose:
                        print('epoch= %s / loss_train = %s / res_train = %s / loss_val =%s/ res_val =%s'%(e,loss_train.item(),res_train,loss_val.item(),res_val))

                    dump_pickle(self.log, log_file)

                    # Save model with best val acc assuring increase of train acc
                    if best_val_acc <= res_val['accuracy']:
                        best_val_acc = res_val['accuracy']
                        if best_val_acc_train_acc <= train_acc:
                            best_val_acc_train_acc = train_acc
                            str_file = '%s/%s_best_val_accuracy_increasing_train_accuracy.pkl' % (self.experiment_repo, model_name)
                            dump_pickle({'epoch': e,
                                         'GIN_params': self.GIN_layers.state_dict(),
                                         'clf_params': self.clf_Net.state_dict()},
                                        str_file)
                else:
                    dump_pickle(self.log, log_file)

                    # Save on a strictly better train accuracy, or on an equal
                    # accuracy reached with a strictly lower train loss.
                    if train_acc > best_train_acc:
                        save_epoch = True
                        best_train_acc = train_acc
                        best_train_loss = train_loss
                    elif train_acc == best_train_acc and best_train_loss > train_loss:
                        save_epoch = True
                        best_train_loss = train_loss
                    else:
                        save_epoch = False

                    if save_epoch:
                        if verbose:
                            print('saving epoch')
                        str_file = '%s/%s_best_train_accuracy.pkl' % (self.experiment_repo, model_name)
                        dump_pickle({'epoch': e,
                                     'GIN_params': self.GIN_layers.state_dict(),
                                     'clf_params': self.clf_Net.state_dict()},
                                    str_file)

            # after evaluation, make the model trainable again
            self.set_model_to_train()
                
    def evaluate_fullbatch(self, list_C:list, list_F:list, list_y:list):
        #print('--- evaluate current model ---')
        self.set_model_to_eval()
        
        with th.no_grad():
            batch_shapes = [C.shape[0] for C in list_C]
            cumsum_shapes = th.tensor([0] + batch_shapes).cumsum(dim=0)
            batch_embedded_graphs = self.GIN_forward(list_C, list_F, batch_shapes, cumsum_shapes)
            
            pred = self.clf_Net(batch_embedded_graphs)
            loss = self.loss(pred, list_y)
            y_pred = pred.argmax(1)
            y_ = list_y.detach().numpy()
            y_pred_ = y_pred.detach().numpy()
            res = {}
            for metric in self.classification_metrics.keys():
                res[metric] = self.classification_metrics[metric](y_,y_pred_)                
        return batch_embedded_graphs, pred, y_pred, loss, res
    
    def evaluate_minibatch(self, full_list_C:list, full_list_F:list, full_list_y:list, batch_size:int):
        #print('--- evaluate current model ---')
        self.set_model_to_eval()
        
        with th.no_grad():
            full_embedded_graphs = []
            len_ = len(full_list_C)
            n_splits = len_ // batch_size + 1
            full_idx = np.arange(len_)
            for k in range(n_splits):
                idx_batch = full_idx[k * batch_size : (k + 1) * batch_size]
                #print('idx_batch:', idx_batch[0], idx_batch[-1])
                batch_graphs = [full_list_C[idx] for idx in idx_batch]
                batch_features = [full_list_F[idx] for idx in idx_batch]
                batch_shapes = [C.shape[0] for C in batch_graphs]
                cumsum_shapes = th.tensor([0] + batch_shapes).cumsum(dim=0)
                batch_embedded_graphs = self.GIN_forward(batch_graphs, batch_features, batch_shapes, cumsum_shapes)
                full_embedded_graphs.append(batch_embedded_graphs)
            full_embedded_graphs = th.cat(full_embedded_graphs, dim=0)                              
            pred = self.clf_Net(full_embedded_graphs)
            loss = self.loss(pred, full_list_y)
            y_pred = pred.argmax(1)
            y_ = full_list_y.detach().cpu().numpy()
            y_pred_ = y_pred.detach().cpu().numpy()
            res = {}
            for metric in self.classification_metrics.keys():
                res[metric] = self.classification_metrics[metric](y_,y_pred_)                
        return full_embedded_graphs, pred, y_pred, loss, res
    
    def load(self, model_name:str, dtype:type=th.float64):
        str_file = '%s/%s_best_val_accuracy_increasing_train_accuracy.pkl'%(self.experiment_repo, model_name)
        full_dict_state = pickle.load(open(str_file, 'rb'))
        self.clf_Net.load_state_dict(full_dict_state['clf_params'])
        self.GIN_layers.load_state_dict(full_dict_state['GIN_params'])
        print('[SUCCESSFULLY LOADED] ',str_file)
    
#%%
class DropGIN():
    
    def __init__(self,
                 input_shape:int,
                 n_labels:int,
                 r:int,
                 rdropout:float,
                 experiment_repo:str, 
                 gin_net_dict:dict,
                 gin_layer_dict:dict,
                 clf_net_dict:dict,
                 skip_first_features:bool=False,
                 dtype = th.float64,
                 device='cpu'):
        
        """
        Implementation of the DropGIN model 
        from https://openreview.net/forum?id=fpQojkIV5q8

        Builds a stack of GINConv layers (each wrapping a batch-normalised MLP)
        followed by an MLP classifier that consumes the concatenation of the
        outputs of every GIN layer (optionally preceded by the raw features).
        
        Parameters
        ----------
        input_shape : int
            dimension of the features in the input.
        n_labels : int
            number of classes in the dataset.
        r : int
            number of runs for each forward (r perturbations of a graph are used to get a final graph representation using averaged perturbed graph representations)
        rdropout : float
            rate of node dropout used to get the perturbed graphs.
        experiment_repo : str
            repository to save the experiment during training.
            It is created if it does not exist yet.
        gin_net_dict : dict
            Dictionary containing parameters for the global architecture of GIN.
            Must contain the keys:
                'hidden_dim' : dimension of the hidden layer validated in {16,32,64} depending on datasets.
                                The output dimension of the layer is the same as the hidden dimension.
                'num_hidden' : the number of hidden layers in each MLP.                                 
        gin_layer_dict : dict
            Dictionary containing parameters for the GIN layers. The parameters (eps, train_eps) are fixed to (0, False) as suggested by authors.
            Must contain the key
                'num_layers' : number of GIN layers in the architecture (strictly positive integer)
        clf_net_dict : dict
            Dictionary containing parameters for the MLP leading to label prediction
            Must contain the keys
                'hidden_dim' :(int) dimension of the hidden layer (fixed to 128)
                'num_hidden' :(int) number of hidden layers in the architecture
                'dropout' :(float) dropout rate to use. When equal to 0 a plain MLP
                           (without dropout) is instantiated.
        skip_first_features : bool, optional
            Either to skip the input features or not in the concatenation of all GIN layer outputs. (see Jumping Knowledge Networks)
            The default is False.
        dtype : torch.dtype, optional
            dtype forwarded to the constructors of every MLP. The default is th.float64.
        device : str, optional
            device forwarded to the constructors of every MLP; the GIN layers and
            the loss are also moved to it. The default is 'cpu'.

        Raises
        ------
        AssertionError
            If a required key is missing from one of the dictionaries or if
            'num_layers' is not strictly positive.
        """
        assert all(key in gin_net_dict for key in ('hidden_dim', 'num_hidden'))
        assert all(key in gin_layer_dict for key in ('num_layers',))
        assert all(key in clf_net_dict for key in ('hidden_dim', 'num_hidden', 'dropout'))
        assert gin_layer_dict['num_layers'] > 0
        self.input_shape = input_shape
        self.n_labels = n_labels
        self.r = r
        self.rdropout = rdropout
        self.device = device
        self.gin_net_dict = gin_net_dict
        self.gin_layer_dict = gin_layer_dict
        # If True, the raw input features are left out of the concatenated node embeddings.
        self.skip_first_features = skip_first_features
        self.classification_metrics = {'accuracy': accuracy_score}

        self.experiment_repo = experiment_repo
        # exist_ok avoids the check-then-create race when several runs share the repository.
        os.makedirs(self.experiment_repo, exist_ok=True)

        self.dtype = dtype
        # Instantiate network for GIN embeddings
        self.GIN_layers = th.nn.ModuleList()
        hidden_dim = gin_net_dict['hidden_dim']
        # The classifier input is the concatenation of (optionally) the raw
        # features and the output of every GIN layer.
        self.clf_input_shape = 0 if self.skip_first_features else self.input_shape
        for layer in range(gin_layer_dict['num_layers']):
            # Only the first layer reads the raw features.
            local_input_shape = self.input_shape if layer == 0 else hidden_dim
            MLP = utils_networks.MLP_batchnorm(
                local_input_shape,
                hidden_dim,  # output_dim = hidden_dim
                hidden_dim,
                gin_net_dict['num_hidden'], 'relu', device=self.device, dtype=self.dtype)
            GIN_layer = GINConv(MLP, eps=0., train_eps=False).to(self.device)
            self.GIN_layers.append(GIN_layer)
            self.clf_input_shape += hidden_dim

        print('clf_input_shape with aggregation:', self.clf_input_shape)
        # Instantiate network for classification
        if clf_net_dict['dropout'] != 0.:
            self.clf_Net = utils_networks.MLP_dropout(
                input_dim = self.clf_input_shape, 
                output_dim = self.n_labels,
                hidden_dim = clf_net_dict['hidden_dim'],
                num_hidden = clf_net_dict['num_hidden'],
                output_activation= 'linear', 
                dropout = clf_net_dict['dropout'],
                skip_first = True,
                device=self.device, dtype=self.dtype)
        else:
            self.clf_Net = utils_networks.MLP(
                self.clf_input_shape, 
                self.n_labels,
                clf_net_dict['hidden_dim'],
                clf_net_dict['num_hidden'],
                'linear', device=self.device, dtype=self.dtype)

        self.loss = th.nn.CrossEntropyLoss().to(self.device)
        # Flat list of every trainable parameter, handed to the optimizer in fit.
        self.params = list(self.GIN_layers.parameters()) + list(self.clf_Net.parameters())

        print('GIN_layers:', self.GIN_layers.parameters)
        print('Clf:', self.clf_Net.parameters)
    
    def set_model_to_train(self):
        self.GIN_layers.train()
        self.clf_Net.train()
    
    def set_model_to_eval(self):
        self.GIN_layers.eval()
        self.clf_Net.eval()

    def DropGIN_forward(self, batch_graphs:list, batch_features:list, batch_shapes:th.Tensor, cumsum_shapes:th.Tensor):
        processed_batch_features = th.cat([F for F in batch_features])
        # GNN filters
        aggbatch_embedded_features = None
        processed_batch_graphs = th.block_diag(*[C for C in batch_graphs])                 
        # for each run drop nodes and its connections within the graphs
        for _ in range(self.r):
            drop = th.bernoulli(self.rdropout * th.ones(processed_batch_graphs.shape[0], device=self.device)).bool()
            mask = th.ones(processed_batch_graphs.shape, device=self.device, dtype=self.dtype)
            mask[drop, :] = 0.
            mask[:, drop] = 0.
            dropbatch_edge_index = th.argwhere((processed_batch_graphs * mask) == 1.).T
                
            droplayered_batch_features = []
            if not self.skip_first_features:
                droplayered_batch_features.append(processed_batch_features)
        
            dropbatch_embedded_features = self.GIN_layers[0](x=processed_batch_features, edge_index=dropbatch_edge_index)
            
            droplayered_batch_features.append(dropbatch_embedded_features)
                
            for layer in range(1, self.gin_layer_dict['num_layers']):
                dropbatch_embedded_features = self.GIN_layers[layer](x=dropbatch_embedded_features, edge_index=dropbatch_edge_index)
                droplayered_batch_features.append(dropbatch_embedded_features)
            # for a given node, we concatenate embeddings generated at each k GIN layer inducing a k-hop smoothing
            dropbatch_embedded_features = th.cat(droplayered_batch_features, dim = 1) 
            if aggbatch_embedded_features is None:
                aggbatch_embedded_features = dropbatch_embedded_features
            else:
                aggbatch_embedded_features += dropbatch_embedded_features
        aggbatch_embedded_features /= self.r
        aggbatch_embedded_features_uncat =  [aggbatch_embedded_features[cumsum_shapes[k] : cumsum_shapes[k + 1], :] for k in range(len(batch_shapes))]
                                   
        # global pooling scheme = sum
        batch_embedded_graphs = [F.sum(0) for F in aggbatch_embedded_features_uncat]
        batch_embedded_graphs = th.stack(batch_embedded_graphs, dim=0)
        return batch_embedded_graphs
    
    def fit(self, 
            model_name:str, 
            X_train:list, F_train:list, y_train:list, 
            X_val:list, F_val:list, y_val:list, 
            X_test:list, F_test:list, y_test:list,
            lr:float, batch_size:int, supervised_sampler:bool, 
            epochs:int, val_timestamp:int, algo_seed:int, use_lrschedule:bool, decay_step_size:int=50, verbose:bool=False):
        th.manual_seed(algo_seed)
        np.random.seed(algo_seed)

        n_train = y_train.shape[0]
        
        if not (X_val is None):
            
            do_validation = True
            sets = ['train', 'val', 'test']
            best_val_acc = - np.inf
            best_val_acc_train_acc = - np.inf
            
        else:
            do_validation = False
            sets = ['train']
            best_train_acc = - np.inf
            best_train_loss = np.inf  
        self.log = {'train_cumulated_batch_loss':[]}
        for metric in list(self.classification_metrics.keys()) + ['epoch_loss']:
            for s in sets:
                self.log['%s_%s'%(s, metric)]=[]
        
                    
        
        self.optimizer = th.optim.Adam(params=self.params, lr=lr, betas=[0.9, 0.99])
        if use_lrschedule:
            self.scheduler = th.optim.lr_scheduler.StepLR(self.optimizer, step_size=decay_step_size, gamma=0.5)

        if n_train <= batch_size:
            batch_by_epoch = 1
            batch_size = n_train
            print('batch size bigger than #samples > batch_size set to ', batch_size)
        else:
            batch_by_epoch = n_train//batch_size +1
        y_train_ = y_train.detach().cpu()
        if supervised_sampler:
            unique_labels = th.unique(y_train_)
            n_labels = unique_labels.shape[0]
            train_idx_by_labels = [th.where(y_train_==label)[0] for label in unique_labels]
            labels_by_batch = batch_size // n_labels
        
        self.set_model_to_train()
        for e in tqdm(range(epochs), desc='epochs'):
            
            cumulated_batch_loss=0.
            #for batch_i in tqdm(range(batch_by_epoch),desc='epoch %s'%e):
            for batch_i in range(batch_by_epoch):
                
                self.optimizer.zero_grad()
                if not supervised_sampler:
                    batch_idx = np.random.choice(range(n_train), size=batch_size, replace=False)
                else:
                    r = batch_size
                    for idx_label,label  in enumerate(np.random.permutation(unique_labels)):
                        
                        local_batch_idx = np.random.choice(train_idx_by_labels[label], size=min(r,labels_by_batch), replace=False)
                        if idx_label ==0:
                            batch_idx = local_batch_idx
                        else:
                            batch_idx = np.concatenate([batch_idx,local_batch_idx])
                        r -= labels_by_batch

                batch_graphs = [X_train[idx] for idx in batch_idx]
                batch_shapes = [C.shape[0] for C in batch_graphs]
                
                cumsum_shapes = th.tensor([0]+batch_shapes).cumsum(dim=0)
                batch_features = [F_train[idx] for idx in batch_idx]                
                batch_embedded_graphs = self.DropGIN_forward(batch_graphs, batch_features, batch_shapes, cumsum_shapes)
                    
                batch_pred = self.clf_Net(batch_embedded_graphs)
                batch_loss = self.loss(batch_pred, y_train[batch_idx])
                cumulated_batch_loss += batch_loss.item()
                batch_loss.backward()
                self.optimizer.step()
                
            self.log['train_cumulated_batch_loss'].append(cumulated_batch_loss)
            if use_lrschedule:
                self.scheduler.step()
                #print('current lr with scheduler:', self.scheduler.get_last_lr())
            if (((e %val_timestamp) ==0) and e>0) or (e == (epochs - 1)):
                with th.no_grad():
                    self.set_model_to_eval()
                    if self.device== 'cpu':
                        features_train, pred_train, y_pred_train, loss_train, res_train = self.evaluate_fullbatch(X_train, F_train, y_train)
                    else:
                        features_train, pred_train, y_pred_train, loss_train, res_train = self.evaluate_minibatch(X_train, F_train, y_train, batch_size)
                    
                    if do_validation:
                        if self.device== 'cpu':
                            features_val, pred_val, y_pred_val, loss_val, res_val = self.evaluate_fullbatch(X_val, F_val, y_val)
                            features_test, pred_test, y_pred_test, loss_test, res_test = self.evaluate_fullbatch(X_test, F_test, y_test)

                        else:
                            features_train, pred_train, y_pred_train, loss_train, res_train = self.evaluate_minibatch(X_train, F_train, y_train, batch_size)
                            features_val, pred_val, y_pred_val, loss_val, res_val = self.evaluate_minibatch(X_val, F_val, y_val, batch_size)
                            features_test, pred_test, y_pred_test, loss_test, res_test = self.evaluate_minibatch(X_test, F_test, y_test, batch_size)
                        
                        if verbose:
                            print('epoch= %s / loss_train = %s / res_train = %s / loss_val =%s/ res_val =%s'%(e,loss_train.item(),res_train,loss_val.item(),res_val))
                        self.log['train_epoch_loss'].append(loss_train.item())
                        self.log['val_epoch_loss'].append(loss_val.item())
                        for metric in self.classification_metrics.keys():
                            self.log['train_%s'%metric].append(res_train[metric])
                            self.log['val_%s'%metric].append(res_val[metric])
                            
                        str_log = self.experiment_repo+'/%s_training_log.pkl'%model_name
                        pickle.dump(self.log, open(str_log,'wb'))
                        if best_val_acc <= res_val['accuracy']: #Save model with best val acc assuring increase of train acc
                            
                            best_val_acc = res_val['accuracy']
                            if best_val_acc_train_acc <= res_train['accuracy']:
                                best_val_acc_train_acc = res_train['accuracy']
                                str_file = self.experiment_repo+'/%s_best_val_accuracy_increasing_train_accuracy.pkl'%model_name
                                full_dict_state = {'epoch' : e,
                                                   'GIN_params':self.GIN_layers.state_dict(),
                                                   'clf_params':self.clf_Net.state_dict(),
                                                   }
                                pickle.dump(full_dict_state, open(str_file, 'wb'))
                    else:
                        self.log['train_epoch_loss'].append(loss_train.item())
                        for metric in self.classification_metrics.keys():
                            self.log['train_%s'%metric].append(res_train[metric])
                            
                        str_log = self.experiment_repo+'/%s_training_log.pkl'%model_name
                        pickle.dump(self.log, open(str_log,'wb'))
                         
                        if best_train_acc <= res_train['accuracy']: #Save model with best val acc assuring increase of train acc
                             save_epoch = False    
                             if best_train_acc == res_train['accuracy']:
                                 if best_train_loss > loss_train.item():
                                     save_epoch = True
                                     best_train_loss = loss_train.item()
                                 else:
                                     save_epoch = False
                             else:
                                 save_epoch = True
                                 best_train_loss = loss_train.item()
                                 best_train_acc = res_train['accuracy']
                             if save_epoch:
                                 if verbose:
                                     print('saving epoch')
                                 str_file = self.experiment_repo+'/%s_best_train_accuracy.pkl'%model_name
                                 full_dict_state = {
                                     'epoch' : e, 
                                     'GIN_params':self.GIN_layers.state_dict(),
                                     'clf_params':self.clf_Net.state_dict()}
                                 pickle.dump(full_dict_state, open(str_file, 'wb'))
                                          
                # after evaluation, make the model trainable again
                self.set_model_to_train()
                
    def evaluate_fullbatch(self, list_C:list, list_F:list, list_y:list):
        """
        Evaluate the current model on a whole dataset with a single forward pass.

        The model is switched to evaluation mode and everything is computed
        without gradient tracking. The model is NOT switched back to training
        mode here; that is left to the caller.

        Parameters
        ----------
        list_C : list
            Graph structures. Only `C.shape[0]` is read here (used as the size
            of each graph); the list is then forwarded to `self.DropGIN_forward`.
        list_F : list
            Features associated with the graphs, forwarded as is to
            `self.DropGIN_forward`.
        list_y : list
            Ground-truth labels. NOTE: despite the annotation, `.detach()` is
            called on it, so it must be a tensor-like object.

        Returns
        -------
        batch_embedded_graphs :
            Output of `self.DropGIN_forward` on the whole dataset.
        pred :
            Output of `self.clf_Net` on the embedded graphs.
        y_pred :
            `pred.argmax(1)`, i.e. the predicted label of each graph.
        loss :
            `self.loss(pred, list_y)`.
        res : dict
            One entry per key of `self.classification_metrics`, holding the
            metric evaluated on (true labels, predicted labels) as numpy arrays.
        """
        self.set_model_to_eval()

        with th.no_grad():
            batch_shapes = [C.shape[0] for C in list_C]
            # Offsets delimiting each graph inside the stacked batch.
            cumsum_shapes = th.tensor([0] + batch_shapes).cumsum(dim=0)
            batch_embedded_graphs = self.DropGIN_forward(list_C, list_F, batch_shapes, cumsum_shapes)

            pred = self.clf_Net(batch_embedded_graphs)
            loss = self.loss(pred, list_y)
            y_pred = pred.argmax(1)
            # `.cpu()` leaves CPU tensors untouched and makes the numpy
            # conversion valid for tensors living on another device,
            # consistently with `evaluate_minibatch`.
            y_true_np = list_y.detach().cpu().numpy()
            y_pred_np = y_pred.detach().cpu().numpy()
            res = {
                metric_name: metric_fn(y_true_np, y_pred_np)
                for metric_name, metric_fn in self.classification_metrics.items()
            }
        return batch_embedded_graphs, pred, y_pred, loss, res
    
    def evaluate_minibatch(self, full_list_C:list, full_list_F:list, full_list_y:list, batch_size:int):
        """
        Evaluate the current model on a dataset, embedding the graphs by chunks.

        Graphs are embedded with `self.DropGIN_forward` by consecutive chunks of
        at most `batch_size` graphs (in the given order). The embeddings are
        then concatenated along dim 0, and the classifier, the loss and the
        metrics are computed once on the full set.

        The model is switched to evaluation mode and everything is computed
        without gradient tracking. The model is NOT switched back to training
        mode here; that is left to the caller.

        Parameters
        ----------
        full_list_C : list
            Graph structures. Only `C.shape[0]` is read here (used as the size
            of each graph) before forwarding to `self.DropGIN_forward`.
        full_list_F : list
            Features associated with the graphs, aligned with `full_list_C`.
        full_list_y : list
            Ground-truth labels. NOTE: despite the annotation, `.detach()` is
            called on it, so it must be a tensor-like object.
        batch_size : int
            Maximum number of graphs embedded at once (strictly positive).

        Returns
        -------
        full_embedded_graphs :
            Concatenation (dim 0) of the per-chunk outputs of `self.DropGIN_forward`.
        pred :
            Output of `self.clf_Net` on the embedded graphs.
        y_pred :
            `pred.argmax(1)`, i.e. the predicted label of each graph.
        loss :
            `self.loss(pred, full_list_y)`.
        res : dict
            One entry per key of `self.classification_metrics`, holding the
            metric evaluated on (true labels, predicted labels) as numpy arrays.

        Raises
        ------
        ValueError
            If `batch_size` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be a strictly positive integer, got %s' % batch_size)
        self.set_model_to_eval()

        with th.no_grad():
            len_ = len(full_list_C)
            full_embedded_graphs = []
            # Iterating over start offsets yields exactly ceil(len_ / batch_size)
            # chunks, so no trailing empty chunk is sent to the network when
            # len_ is a multiple of batch_size. `max(len_, 1)` keeps the former
            # behaviour on an empty dataset (a single empty chunk).
            for start in range(0, max(len_, 1), batch_size):
                stop = min(start + batch_size, len_)
                batch_graphs = [full_list_C[idx] for idx in range(start, stop)]
                batch_features = [full_list_F[idx] for idx in range(start, stop)]
                batch_shapes = [C.shape[0] for C in batch_graphs]
                # Offsets delimiting each graph inside the stacked chunk.
                cumsum_shapes = th.tensor([0] + batch_shapes).cumsum(dim=0)
                batch_embedded_graphs = self.DropGIN_forward(batch_graphs, batch_features, batch_shapes, cumsum_shapes)
                full_embedded_graphs.append(batch_embedded_graphs)
            full_embedded_graphs = th.cat(full_embedded_graphs, dim=0)
            pred = self.clf_Net(full_embedded_graphs)
            loss = self.loss(pred, full_list_y)
            y_pred = pred.argmax(1)
            # Move to CPU before the numpy conversion required by the metrics.
            y_true_np = full_list_y.detach().cpu().numpy()
            y_pred_np = y_pred.detach().cpu().numpy()
            res = {
                metric_name: metric_fn(y_true_np, y_pred_np)
                for metric_name, metric_fn in self.classification_metrics.items()
            }
        return full_embedded_graphs, pred, y_pred, loss, res
    
    def load(self, model_name:str, dtype:type=th.float64):
        """
        Restore the network parameters from a checkpoint saved on disk.

        The checkpoint read is
        `<experiment_repo>/<model_name>_best_val_accuracy_increasing_train_accuracy.pkl`.
        It must be a pickled dictionary holding the state dicts under the keys
        'clf_params' (loaded into `self.clf_Net`) and 'GIN_params' (loaded into
        `self.GIN_layers`).

        Parameters
        ----------
        model_name : str
            Name of the model, used as prefix of the checkpoint file name.
        dtype : type, optional
            Not used by this method; kept so that existing calls remain valid.

        Raises
        ------
        FileNotFoundError
            If the checkpoint file does not exist.
        KeyError
            If the checkpoint lacks 'clf_params' or 'GIN_params'.
        """
        str_file = '%s/%s_best_val_accuracy_increasing_train_accuracy.pkl'%(self.experiment_repo, model_name)
        # SECURITY NOTE: unpickling can execute arbitrary code, only load
        # checkpoints coming from a trusted experiment repository.
        # The context manager guarantees the file handle is closed, even if
        # unpickling fails.
        with open(str_file, 'rb') as checkpoint_file:
            full_dict_state = pickle.load(checkpoint_file)
        self.clf_Net.load_state_dict(full_dict_state['clf_params'])
        self.GIN_layers.load_state_dict(full_dict_state['GIN_params'])
        print('[SUCCESSFULLY LOADED] ',str_file)
    