import numpy as np
import ot
import torch as th
import os
from tqdm import tqdm
import pickle
import utils_networks
import GW_utils 
from sklearn.metrics import accuracy_score
from joblib import Parallel, delayed
from sklearn.cluster import KMeans
from torch_geometric.nn import GINConv
#try:
#    import nvidia_smi
#except:
#    print('could not find nvidia-smi for GPU memory tracking')
#    pass
#%%


class OT_GIN():
    """
    OT-GIN graph classifier: GIN node embeddings compared to Wasserstein templates.

    Implementation adapted from https://openreview.net/forum?id=o1O5nc48rn
    for a fair benchmark with TFGW. A graph is embedded node-wise by a stack of
    GIN layers whose outputs are concatenated (see Jumping Knowledge Networks).
    It is then represented by the vector of its ``Katoms`` exact optimal-transport
    costs (squared Euclidean ground cost, computed with ``ot.emd``) to learnable
    point-cloud templates ``(Fbar[k], hbar[k])``. This vector feeds an MLP classifier.

    ``init_parameters_with_aggregation`` must be called before ``fit``, because
    it builds the templates and ``self.params``. ``load`` restores a checkpoint
    written by ``fit``.
    """

    def __init__(self,
                 input_shape:int,
                 Katoms:int, 
                 n_labels:int, 
                 experiment_repo:str, 
                 gin_net_dict:dict,
                 gin_layer_dict:dict,
                 clf_net_dict:dict,
                 skip_first_features:bool=False,
                 sizes_scaling:bool=False,
                 dtype = th.float64,
                 device='cpu'):
        """
        Build the GIN layers and the classification MLP of the OT-GIN model.

        Parameters
        ----------
        input_shape : int
            Dimension of the node features in the input.
        Katoms : int
            Number of templates in the Wasserstein layer. This is also the input
            dimension of the classifier.
        n_labels : int
            Number of classes in the dataset.
        experiment_repo : str
            Directory where logs and checkpoints are saved during training.
            It is created if it does not exist.
        gin_net_dict : dict
            Parameters for the global architecture of GIN. Must contain the keys:
                'hidden_dim' : dimension of the hidden layers, validated in {16,32,64}
                               depending on datasets. The output dimension of each
                               GIN layer is the same as the hidden dimension.
                'num_hidden' : number of hidden layers in each GIN MLP.
        gin_layer_dict : dict
            Parameters for the GIN layers. (eps, train_eps) are fixed to (0, False),
            as suggested by the authors. Must contain the key:
                'num_layers' : number of GIN layers (strictly positive integer).
        clf_net_dict : dict
            Parameters for the MLP producing the label predictions. Must contain the keys:
                'hidden_dim' : (int) dimension of the hidden layers (fixed to 128)
                'num_hidden' : (int) number of hidden layers
                'dropout'    : (float) dropout rate. When it is 0, a plain MLP without
                               dropout is used.
        skip_first_features : bool, optional
            Whether to skip the input features when concatenating the outputs of
            all GIN layers (see Jumping Knowledge Networks). The default is False.
        sizes_scaling : bool, optional
            Whether to multiply each Wasserstein distance (and its gradients) by the
            product of the two point-cloud sizes. The authors suggested this scaling,
            but it clearly hurt the performances of our adaptation of their model.
            The default (False) led to much better performances.
        dtype : torch.dtype, optional
            Dtype of all parameters and intermediate tensors. The default is th.float64.
        device : str, optional
            Torch device. Only 'cpu' is supported by ``parallelized_get_features``.
            The default is 'cpu'.
        """
        assert all(s in gin_net_dict for s in ['hidden_dim', 'num_hidden'])
        assert all(s in gin_layer_dict for s in ['num_layers'])
        assert all(s in clf_net_dict for s in ['hidden_dim', 'num_hidden', 'dropout'])
        assert gin_layer_dict['num_layers'] > 0
        self.input_shape = input_shape
        self.Katoms = Katoms
        self.n_labels = n_labels
        self.device = device
        self.gin_net_dict = gin_net_dict
        self.gin_layer_dict = gin_layer_dict
        # If True, the raw input features are not part of the concatenated node embeddings.
        self.skip_first_features = skip_first_features

        self.classification_metrics = {'accuracy' : (lambda y_true,y_pred : accuracy_score(y_true,y_pred))}

        self.experiment_repo = experiment_repo
        if not os.path.exists(self.experiment_repo):
            os.makedirs(self.experiment_repo)

        # Templates are point clouds: feature matrices Fbar with node weights hbar
        # fixed to uniform, as in the original paper. Set by init_parameters_with_aggregation / load.
        self.Fbar, self.hbar = None, None
        self.sizes_scaling = sizes_scaling
        self.dtype = dtype
        # Custom autograd function used by parallelized_get_features to attach the
        # OT gradients. It is set here so that it exists whatever n_jobs is passed to fit.
        self.ValFunction = GW_utils.ValFunction

        # Instantiate the networks producing the GIN embeddings.
        self.GIN_layers = th.nn.ModuleList()
        for layer in range(gin_layer_dict['num_layers']):
            # The first layer consumes the raw features; the next ones consume hidden embeddings.
            local_input_shape = self.input_shape if layer == 0 else gin_net_dict['hidden_dim']
            mlp = utils_networks.MLP_batchnorm(
                local_input_shape, 
                gin_net_dict['hidden_dim'], 
                gin_net_dict['hidden_dim'], 
                gin_net_dict['num_hidden'], 'relu', device=self.device, dtype=self.dtype)
            self.GIN_layers.append(GINConv(mlp, eps= 0., train_eps=False).to(self.device))

        # Instantiate the classification network (its input is the vector of Katoms distances).
        if clf_net_dict['dropout'] != 0.:
            self.clf_Net = utils_networks.MLP_dropout(
                input_dim = self.Katoms, 
                output_dim = self.n_labels,
                hidden_dim = clf_net_dict['hidden_dim'],
                num_hidden = clf_net_dict['num_hidden'],
                output_activation= 'linear', 
                dropout = clf_net_dict['dropout'],
                skip_first = True,
                device=self.device, dtype=self.dtype)
        else:
            self.clf_Net = utils_networks.MLP(
                self.Katoms, 
                self.n_labels,
                clf_net_dict['hidden_dim'],
                clf_net_dict['num_hidden'],
                'linear', device=self.device, dtype=self.dtype)

        self.loss = th.nn.CrossEntropyLoss().to(self.device)

    def init_parameters_with_aggregation(self, 
                                         list_Satoms:list, 
                                         init_mode_atoms:str,
                                         graphs:list=None, features:list=None, 
                                         labels:list=None, atoms_projection:str='clipped'):
        """
        Initialize the Wasserstein templates from training samples.

        No initialization scheme was discussed in the original paper. In a similar
        fashion as for FGW templates, the Wasserstein templates are initialized in a
        supervised fashion: input graphs with the requested number of nodes are
        sampled within each label, and their initial features are passed through
        the GIN layers.

        Parameters
        ----------
        list_Satoms : list of int
            Number of nodes of each template. Must contain ``Katoms`` entries.
        init_mode_atoms : str
            Initialization mode. Only modes containing 'sampling_supervised' are supported.
        graphs : list of th.Tensor
            Adjacency matrices of the training graphs. Entries equal to 1 are edges.
        features : list of th.Tensor
            Node feature matrices of the training graphs.
        labels : th.Tensor
            Labels of the training graphs.
        atoms_projection : str
            Unused; kept for interface compatibility.

        Raises
        ------
        ValueError
            If ``init_mode_atoms`` is not supported, or if a label that must provide
            a template contains no graph with the requested number of nodes.
        """
        assert len(list_Satoms)==self.Katoms
        if 'sampling_supervised' not in init_mode_atoms:
            # Without sampling, Fbar would stay empty and training would fail later on.
            raise ValueError('unsupported init_mode_atoms = %s' % init_mode_atoms)
        print('init_mode_atoms = sampling_supervised' )

        self.sampled_structures = []
        self.Fbar = []

        shapes = th.tensor([C.shape[0] for C in graphs])
        unique_atom_shapes, counts = np.unique(list_Satoms,return_counts=True)
        unique_labels = th.unique(labels)
        n_unique_labels = unique_labels.shape[0]
        idx_by_labels = [th.where(labels==label)[0] for label in unique_labels]

        for shape, r in zip(unique_atom_shapes, counts):
            perm = th.randperm(n_unique_labels)
            # Spread the r templates of this size over the labels. The first
            # r % n_labels labels of the random permutation receive one extra
            # template, so that exactly r templates are sampled.
            base_count, n_extra = divmod(int(r), n_unique_labels)
            for rank, label_pos in enumerate(perm):
                n_samples = base_count + (1 if rank < n_extra else 0)
                if n_samples == 0:
                    continue
                self._sample_templates_within_label(
                    graphs, features, idx_by_labels[label_pos.item()], shapes,
                    shape, n_samples, label_pos.item())

        assert len(self.Fbar) == self.Katoms

        # Embed the template features through the GIN layers, once all templates are sampled.
        with th.no_grad():
            for idx_atom in range(self.Katoms):
                F = self.Fbar[idx_atom].to(self.device)
                processedC = self.sampled_structures[idx_atom].to(self.device)
                edge_index = th.argwhere( processedC == 1.).T
                embedded_features = self._gin_layered_features(F, edge_index)
                embedded_features.requires_grad_(True)
                self.Fbar[idx_atom] = embedded_features
        print('processed Fbar shapes:', [F.shape for F in self.Fbar])

        # Uniform (non-trainable) node weights, built from Fbar so that hbar[k] always
        # matches the size of Fbar[k]. Fbar is ordered by sorted template size, which
        # may differ from the order of list_Satoms.
        self.hbar = [th.ones(F.shape[0], dtype=self.dtype, device=self.device) / F.shape[0] for F in self.Fbar]

        self.atoms_params = [*self.Fbar, *self.hbar]
        self.params =  self.atoms_params + list(self.GIN_layers.parameters()) + list(self.clf_Net.parameters())

        self.shape_atoms = [F.shape[0] for F in self.Fbar]
        print('---  model initialized  ---')
        print('Fbar dims:', [F.shape for F in self.Fbar])
        print('sanity check params (shape, requires grad, device):', [(param.shape, param.requires_grad, param.device) for param in self.params])
        print('GIN_layers:', self.GIN_layers.parameters)
        print('Clf:', self.clf_Net.parameters)

    def _sample_templates_within_label(self, graphs, features, label_idx, shapes, shape, n_samples, label_pos):
        """
        Append ``n_samples`` (structure, features) templates with ``shape`` nodes to
        ``self.sampled_structures`` / ``self.Fbar``. They are drawn among the graphs
        indexed by ``label_idx``.

        Distinct graphs are drawn when enough exist. Otherwise graphs are drawn with
        replacement, and their features are perturbed with Gaussian noise scaled by
        the per-feature std.

        Raises
        ------
        ValueError
            If no graph with ``shape`` nodes exists within the label.
        """
        shapes_label = shapes[label_idx]
        shape_idx_by_label = th.where(shapes_label==shape)[0]
        print('shape_idx_by_label (S=%s) / %s'%(shape, shape_idx_by_label))
        if shape_idx_by_label.shape[0] == 0:
            print(' NO sample found with proper shape within label = %s'%label_pos)
            raise ValueError('provide an existing number of nodes within samples with label:%s'%label_pos)

        candidates = shape_idx_by_label.numpy()
        if candidates.shape[0] >= n_samples:
            print('enough samples of shape =%s / within label  =%s'%(shape, label_pos))
            sample_idx = np.random.choice(candidates, size=n_samples, replace=False)
            for idx in sample_idx:
                localC = graphs[label_idx[idx]].clone().to(self.device)
                localF = features[label_idx[idx]].clone().to(self.device)
                self.Fbar.append(localF)
                self.sampled_structures.append(localC)
        else:
            print(' not enough samples of shape =%s / within label  =%s'%(shape, label_pos))
            sample_idx = np.random.choice(candidates, size=n_samples, replace=True)
            print('sample idx:', sample_idx)
            for idx in sample_idx:
                C = graphs[label_idx[idx]].clone().to(self.device)
                F = features[label_idx[idx]].clone().to(self.device)
                # Noise breaks ties between duplicated templates.
                noise_distrib_F = th.distributions.normal.Normal(
                    loc=th.zeros(F.shape[-1], device=F.device), scale= F.std(dim=0) + 1e-05)
                noise_F = noise_distrib_F.rsample(th.Size([shape])).to(self.device)
                self.Fbar.append(F + noise_F)
                self.sampled_structures.append(C)

    def _squared_euclidean_cost(self, F, Fbar):
        """
        Return the matrix of squared Euclidean distances between the rows of F and
        the rows of Fbar. The matrix has shape (F.shape[0], Fbar.shape[0]).

        It uses the expansion |f|^2 + |fbar|^2 - 2 <f, fbar>, with the squared norms
        broadcast through products with matrices of ones.
        """
        dim_features = F.shape[-1]
        ones_source = th.ones((F.shape[0], dim_features), dtype=self.dtype, device=self.device)
        ones_target = th.ones((dim_features, Fbar.shape[0]), dtype=self.dtype, device=self.device)
        first_term = (F**2) @ ones_target
        second_term = ones_source @ (Fbar**2).T
        return first_term + second_term - 2* F @ Fbar.T

    def compute_pairwise_euclidean_distance(self, list_F, list_Fbar, detach = True):
        """
        Return ``list_Mik`` where ``list_Mik[i][k]`` is the squared Euclidean cost
        matrix between ``list_F[i]`` and ``list_Fbar[k]``.

        ``detach`` is unused; it is kept for interface compatibility.
        """
        return [[self._squared_euclidean_cost(F, Fbar) for Fbar in list_Fbar] for F in list_F]

    def get_OT_by_input(self, F, h, list_Fbar, list_hbar):
        """
        Compute exact OT plans between one input point cloud (F, h) and each template.

        Returns a list of ``Katoms`` pairs ``(OT plan, cost matrix)``. Each plan is
        given by ``ot.emd`` with a squared Euclidean ground cost.
        """
        res = []
        for i in range(self.Katoms):
            Mi = self._squared_euclidean_cost(F, list_Fbar[i])
            res.append((ot.emd(h, list_hbar[i], Mi), Mi))
        return res

    def parallelized_get_features(self, list_F, list_h, list_Fbar, list_hbar, n_jobs=2):
        """
        Compute the (len(list_F), Katoms) matrix of Wasserstein distances between
        every input point cloud and every template.

        The OT plans are solved in parallel with joblib, without autograd. Each
        distance is then wrapped with ``GW_utils.set_gradients`` together with its
        closed-form gradients with respect to the input features and the template
        features, so that backpropagation reaches both.

        Raises
        ------
        NotImplementedError
            If ``self.device`` is not 'cpu'.
        """
        if self.device != 'cpu':
            raise NotImplementedError('GPU not handled yet')
        features = th.zeros((len(list_F), self.Katoms), dtype=self.dtype, device=self.device)
        with th.no_grad():
            res_by_input = Parallel(n_jobs=n_jobs)(delayed(self.get_OT_by_input)(list_F[i], list_h[i], list_Fbar, list_hbar) for i in range(len(list_F)))
        for idx_res, res in enumerate(res_by_input):
            for idx_atom, (OT, M_ik) in enumerate(res):
                with th.no_grad():
                    wass = (OT * M_ik).sum()
                    # Closed-form gradients of <OT, M> w.r.t. the features, for a fixed
                    # optimal plan whose marginals are h and hbar.
                    gF = 2 * ( list_F[idx_res] * list_h[idx_res][:, None] - OT @ list_Fbar[idx_atom])
                    gFbar = 2 * ( list_Fbar[idx_atom] * list_hbar[idx_atom][:, None] - (OT.T) @ list_F[idx_res])
                    if self.sizes_scaling:
                        scaling = OT.shape[0] * OT.shape[1]
                        wass *= scaling
                        gF *= scaling
                        gFbar *= scaling
                wass = GW_utils.set_gradients(self.ValFunction, wass.to(self.device), 
                                             (list_F[idx_res], list_Fbar[idx_atom]), 
                                             (gF.to(self.device), gFbar.to(self.device)))
                features[idx_res, idx_atom] = wass
        return features

    def set_model_to_train(self):
        """Put the GIN layers and the classifier in training mode."""
        self.GIN_layers.train()
        self.clf_Net.train()

    def set_model_to_eval(self):
        """Put the GIN layers and the classifier in evaluation mode."""
        self.GIN_layers.eval()
        self.clf_Net.eval()

    def _gin_layered_features(self, x, edge_index):
        """
        Run all GIN layers on node features ``x`` and concatenate the outputs of
        every layer along dim 1 (Jumping Knowledge style).

        The input features ``x`` are prepended unless ``skip_first_features`` is set.
        """
        layered_features = [] if self.skip_first_features else [x]
        embedded = x
        for gin_layer in self.GIN_layers:
            embedded = gin_layer(x=embedded, edge_index=edge_index)
            layered_features.append(embedded)
        return th.cat(layered_features, dim=1)

    def GIN_forward(self, batch_graphs, batch_features, batch_shapes, cumsum_shapes):
        """
        Embed a batch of graphs with the GIN layers.

        The graphs are merged into one block-diagonal adjacency matrix. Entries equal
        to 1 are treated as edges. The concatenated layer-wise node embeddings are
        then split back per graph, using ``cumsum_shapes`` (the cumulative node
        counts, starting at 0).

        Returns a list with one (n_nodes_k, embedding_dim) tensor per graph.
        """
        processed_batch_graphs = th.block_diag(*batch_graphs)
        batch_edge_index = th.argwhere(processed_batch_graphs == 1.).T
        processed_batch_features = th.cat(list(batch_features))
        # For a given node, the embeddings produced by each of the k GIN layers
        # (inducing a k-hop smoothing) are concatenated.
        batch_embedded_features = self._gin_layered_features(processed_batch_features, batch_edge_index)
        return [batch_embedded_features[cumsum_shapes[k] : cumsum_shapes[k + 1], :] for k in range(len(batch_shapes))]

    def fit(self, 
            model_name:str, 
            X_train:list, F_train:list, y_train:list, 
            X_val:list, F_val:list, y_val:list, 
            X_test:list, F_test:list, y_test:list,
            lr:float, batch_size:int, supervised_sampler:bool, 
            epochs:int, val_timestamp:int, algo_seed:int, use_lrschedule:bool,
            track_templates:bool=False, verbose:bool=False, n_jobs:int=None):
        """
        Train the model with Adam on mini-batches of training graphs.

        Every ``val_timestamp`` epochs (and at the last epoch), the model is evaluated
        on the train, validation and test sets. Logs are pickled to
        ``experiment_repo/<model_name>_training_log.pkl``. A checkpoint is pickled to
        ``experiment_repo/<model_name>_best_val_accuracy_increasing_train_accuracy.pkl``
        when the validation accuracy does not decrease and the train accuracy does not
        decrease either.

        Parameters
        ----------
        model_name : str
            Prefix of the saved log and checkpoint files.
        X_*, F_*, y_* :
            Adjacency matrices, node features and labels of each split.
        lr : float
            Adam learning rate.
        batch_size : int
            Number of graphs per mini-batch. It is clipped to the training set size.
        supervised_sampler : bool
            If True, each batch draws ``batch_size // n_labels`` graphs from each label
            without replacement. Otherwise batches are drawn uniformly.
        epochs : int
            Number of epochs.
        val_timestamp : int
            Evaluation period, in epochs.
        algo_seed : int
            Seed for torch and numpy.
        use_lrschedule : bool
            If True, the learning rate is halved every 50 epochs.
        track_templates : bool
            If True, log per-label mean/std of the distances to the templates.
        verbose : bool
            Print evaluation results.
        n_jobs : int, optional
            Number of joblib workers used for the OT computations.
        """
        th.manual_seed(algo_seed)
        np.random.seed(algo_seed)
        # Needed by parallelized_get_features regardless of n_jobs.
        self.ValFunction = GW_utils.ValFunction

        h_train = [th.ones(C.shape[0], dtype=self.dtype, device=self.device)/C.shape[0] for C in X_train]
        n_train = y_train.shape[0]

        h_val = [th.ones(C.shape[0], dtype=self.dtype, device=self.device)/C.shape[0] for C in X_val]
        h_test = [th.ones(C.shape[0], dtype=self.dtype, device=self.device)/C.shape[0] for C in X_test]

        sets = ['train', 'val', 'test']
        best_val_acc = - np.inf
        best_val_acc_train_acc = - np.inf

        self.log = {'train_cumulated_batch_loss':[]}
        for metric in list(self.classification_metrics.keys()) + ['epoch_loss']:
            for s in sets:
                self.log['%s_%s'%(s, metric)]=[]

        if track_templates:
            self.log_templates = {}
            for label in range(self.n_labels):
                for s in sets:
                    self.log_templates['%s_mean_dists_label%s'%(s, label)] = []
                    self.log_templates['%s_std_dists_label%s'%(s, label)] = []

        self.optimizer = th.optim.Adam(params=self.params, lr=lr, betas=[0.9, 0.99])
        if use_lrschedule:
            self.scheduler = th.optim.lr_scheduler.StepLR(self.optimizer, step_size=50, gamma=0.5)

        if n_train <= batch_size:
            batch_by_epoch = 1
            batch_size = n_train
            print('batch size bigger than #samples > batch_size set to ', batch_size)
        else:
            batch_by_epoch = n_train//batch_size +1
        y_train_ = y_train.detach().cpu()
        if supervised_sampler:
            unique_labels = th.unique(y_train_)
            n_labels = unique_labels.shape[0]
            train_idx_by_labels = [th.where(y_train_==label)[0] for label in unique_labels]
            labels_by_batch = batch_size // n_labels

        self.set_model_to_train()
        for e in tqdm(range(epochs), desc='epochs'):

            cumulated_batch_loss=0.
            for batch_i in range(batch_by_epoch):

                self.optimizer.zero_grad()
                if not supervised_sampler:
                    batch_idx = np.random.choice(range(n_train), size=batch_size, replace=False)
                else:
                    r = batch_size
                    batch_parts = []
                    # Permute label *positions*: train_idx_by_labels is indexed by the
                    # position of a label in unique_labels, not by its value.
                    for label_pos in np.random.permutation(n_labels):
                        batch_parts.append(np.random.choice(train_idx_by_labels[label_pos], size=min(r,labels_by_batch), replace=False))
                        r -= labels_by_batch
                    batch_idx = np.concatenate(batch_parts)

                batch_graphs = [X_train[idx] for idx in batch_idx]
                batch_shapes = [C.shape[0] for C in batch_graphs]

                cumsum_shapes = th.tensor([0]+batch_shapes).cumsum(dim=0)
                batch_features = [F_train[idx] for idx in batch_idx]
                batch_masses = [h_train[idx] for idx in batch_idx]

                batch_embedded_features_uncat =  self.GIN_forward(batch_graphs, batch_features, batch_shapes, cumsum_shapes)
                dist_features = self.parallelized_get_features(
                    batch_embedded_features_uncat, batch_masses, 
                    self.Fbar, self.hbar, n_jobs=n_jobs)

                batch_pred = self.clf_Net(dist_features)
                batch_loss = self.loss(batch_pred, y_train[batch_idx])
                cumulated_batch_loss += batch_loss.item()
                batch_loss.backward()
                self.optimizer.step()

            self.log['train_cumulated_batch_loss'].append(cumulated_batch_loss)
            if use_lrschedule:
                self.scheduler.step()
            if (((e %val_timestamp) ==0) and e>0) or (e == (epochs - 1)):
                with th.no_grad():
                    self.set_model_to_eval()

                    if self.device== 'cpu':
                        features_train, pred_train, y_pred_train, loss_train, res_train = self.evaluate_fullbatch(X_train, F_train, h_train, y_train, n_jobs)
                        features_val, pred_val, y_pred_val, loss_val, res_val = self.evaluate_fullbatch(X_val, F_val, h_val, y_val, n_jobs)
                        features_test, pred_test, y_pred_test, loss_test, res_test = self.evaluate_fullbatch(X_test, F_test, h_test, y_test, n_jobs)
                    else:
                        features_train, pred_train, y_pred_train, loss_train, res_train = self.evaluate_minibatch(X_train, F_train, h_train, y_train, n_jobs, batch_size=128)
                        features_val, pred_val, y_pred_val, loss_val, res_val = self.evaluate_minibatch(X_val, F_val, h_val, y_val, n_jobs, batch_size=128)
                        features_test, pred_test, y_pred_test, loss_test, res_test = self.evaluate_minibatch(X_test, F_test, h_test, y_test, n_jobs, batch_size=128)

                    if verbose:
                        print('epoch= %s / loss_train = %s / res_train = %s / loss_val =%s/ res_val =%s'%(e,loss_train.item(),res_train,loss_val.item(),res_val))
                    self.log['train_epoch_loss'].append(loss_train.item())
                    self.log['val_epoch_loss'].append(loss_val.item())
                    self.log['test_epoch_loss'].append(loss_test.item())
                    for metric in self.classification_metrics.keys():
                        self.log['train_%s'%metric].append(res_train[metric])
                        self.log['val_%s'%metric].append(res_val[metric])
                        self.log['test_%s'%metric].append(res_test[metric])

                    str_log = self.experiment_repo+'/%s_training_log.pkl'%model_name
                    with open(str_log, 'wb') as f_log:
                        pickle.dump(self.log, f_log)
                    if best_val_acc <= res_val['accuracy']: #Save model with best val acc assuring increase of train acc

                        best_val_acc = res_val['accuracy']
                        if best_val_acc_train_acc <= res_train['accuracy']:
                            best_val_acc_train_acc = res_train['accuracy']
                            str_file = self.experiment_repo+'/%s_best_val_accuracy_increasing_train_accuracy.pkl'%model_name
                            full_dict_state = {'epoch' : e,
                                               'GIN_params':self.GIN_layers.state_dict(),
                                               'clf_params':self.clf_Net.state_dict(),
                                               'atoms_params':[param.clone().detach().cpu() for param in self.atoms_params]
                                               }
                            with open(str_file, 'wb') as f_ckpt:
                                pickle.dump(full_dict_state, f_ckpt)

                    if track_templates:
                        for label in th.unique(y_train):
                            label_ = int(label.item())
                            idx_train = th.where(y_train==label)[0]
                            subfeatures_train = features_train[idx_train]
                            idx_val = th.where(y_val==label)[0]
                            subfeatures_val = features_val[idx_val]
                            self.log_templates['train_mean_dists_label%s'%label_].append(subfeatures_train.mean(axis=0))
                            self.log_templates['train_std_dists_label%s'%label_].append(subfeatures_train.std(axis=0))
                            self.log_templates['val_mean_dists_label%s'%label_].append(subfeatures_val.mean(axis=0))
                            self.log_templates['val_std_dists_label%s'%label_].append(subfeatures_val.std(axis=0))

                            if verbose:
                                print('[TRAIN] label = %s / mean dists: %s / std dists: %s'%(label,self.log_templates['train_mean_dists_label%s'%label_][-1], self.log_templates['train_std_dists_label%s'%label_][-1]))
                        str_log = self.experiment_repo+'/%s_tracking_templates_log.pkl'%model_name
                        with open(str_log, 'wb') as f_tpl:
                            pickle.dump(self.log_templates, f_tpl)

                # After evaluation, make the model trainable again.
                self.set_model_to_train()

    def evaluate_fullbatch(self, list_C:list, list_F:list, list_h:list, list_y:list, n_jobs:int=None):
        """
        Evaluate the model on a whole dataset in a single forward pass.

        Returns (distance features, logits, predicted labels, loss, metrics dict).
        Puts the model in evaluation mode.
        """
        self.set_model_to_eval()

        with th.no_grad():
            # Distance features from the GIN embeddings of the graphs.
            batch_shapes = [C.shape[0] for C in list_C]
            cumsum_shapes = th.tensor([0] + batch_shapes).cumsum(dim=0)

            batch_embedded_features_uncat =  self.GIN_forward(list_C, list_F, batch_shapes, cumsum_shapes)
            dist_features = self.parallelized_get_features(
                batch_embedded_features_uncat, list_h, 
                self.Fbar, self.hbar, n_jobs=n_jobs)

            pred = self.clf_Net(dist_features)
            loss = self.loss(pred, list_y)
            y_pred = pred.argmax(1)
            y_ = list_y.detach().cpu().numpy()
            y_pred_ = y_pred.detach().cpu().numpy()
            res = {metric: fn(y_, y_pred_) for metric, fn in self.classification_metrics.items()}
        return dist_features, pred, y_pred, loss, res

    def evaluate_minibatch(self, full_list_C:list, full_list_F:list, full_list_h:list, full_list_y:list, n_jobs:int, batch_size:int=128):
        """
        Evaluate the model on a dataset, computing the distance features by chunks of
        ``batch_size`` graphs.

        Returns (distance features, logits, predicted labels, loss, metrics dict).
        Puts the model in evaluation mode.
        """
        self.set_model_to_eval()

        with th.no_grad():
            full_embedded_features = []
            len_ = len(full_list_C)
            # Iterate over batch start offsets: no empty trailing batch when
            # len_ is a multiple of batch_size.
            for start in range(0, len_, batch_size):
                idx_batch = range(start, min(start + batch_size, len_))
                batch_graphs = [full_list_C[idx] for idx in idx_batch]
                batch_features = [full_list_F[idx] for idx in idx_batch]
                batch_masses = [full_list_h[idx] for idx in idx_batch]
                batch_shapes = [C.shape[0] for C in batch_graphs]
                cumsum_shapes = th.tensor([0] + batch_shapes).cumsum(dim=0)
                batch_embedded_features_uncat = self.GIN_forward(batch_graphs, batch_features, batch_shapes, cumsum_shapes)
                dist_features = self.parallelized_get_features(
                    batch_embedded_features_uncat, batch_masses, 
                    self.Fbar, self.hbar, n_jobs=n_jobs)
                full_embedded_features.append(dist_features)
            full_embedded_features = th.cat(full_embedded_features, dim=0)
            pred = self.clf_Net(full_embedded_features)
            loss = self.loss(pred, full_list_y)
            y_pred = pred.argmax(1)
            y_ = full_list_y.detach().cpu().numpy()
            y_pred_ = y_pred.detach().cpu().numpy()
            res = {metric: fn(y_, y_pred_) for metric, fn in self.classification_metrics.items()}
        return full_embedded_features, pred, y_pred, loss, res

    def load(self, model_name:str, dtype:type=th.float64):
        """
        Restore templates, GIN and classifier weights from the checkpoint saved by ``fit``.

        Template features are the first ``Katoms`` entries of 'atoms_params'. Uniform
        weights hbar are rebuilt from their sizes. ``dtype`` is unused; ``self.dtype``
        is used instead. The checkpoint is unpickled, so only load trusted files.
        """
        str_file = '%s/%s_best_val_accuracy_increasing_train_accuracy.pkl'%(self.experiment_repo, model_name)
        with open(str_file, 'rb') as f_ckpt:
            full_dict_state = pickle.load(f_ckpt)
        self.ValFunction = GW_utils.ValFunction
        self.Fbar = []
        for p in full_dict_state['atoms_params'][: self.Katoms]:  # template features
            F = p.clone().to(self.device)
            F.requires_grad_(False)
            self.Fbar.append(F)
        self.dim_features = self.Fbar[-1].shape[-1]
        self.shape_atoms = [F.shape[0] for F in self.Fbar]
        self.hbar = [th.ones(s, dtype=self.dtype, device=self.device)/s for s in self.shape_atoms]
        self.clf_Net.load_state_dict(full_dict_state['clf_params'])
        self.GIN_layers.load_state_dict(full_dict_state['GIN_params'])
        print('[SUCCESSFULLY LOADED] ',str_file)