import numpy as np
import torch as th
import os
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import pickle
import utils_networks
import GW_utils 
from joblib import Parallel, delayed
from sklearn.cluster import KMeans
#%%

class FGWmachine():
    
    def __init__(self, 
                 Katoms:int, 
                 n_labels:int, 
                 alpha:float, 
                 learn_hbar:bool,
                 experiment_repo:str, 
                 clf_net_dict:dict,
                 alpha_init:float = None,
                 dtype = th.float64,
                 device='cpu'):
        """
        Template-based FGW classifier: a layer of `Katoms` graph templates
        followed by an MLP that maps the `Katoms` distances to class scores.

        Parameters
        ----------
        Katoms : int
            number of templates; also the input dimension of the MLP.
        n_labels : int
            number of classes; also the output dimension of the MLP.
        alpha : float
            trade-off parameter for FGW.
            alpha == -1: alpha is a learnable tensor (mode 'learnable_shared').
            0 <= alpha <= 1: alpha is kept fixed (mode 'fixed').
        learn_hbar : bool
            whether the weights of the templates are learned.
        experiment_repo : str
            directory where the experiment is saved; created if missing.
        clf_net_dict : dict
            parameters of the MLP leading to label prediction.
            Must contain the keys
                'hidden_dim' :(int) dimension of the hidden layers
                'num_hidden' :(int) number of hidden layers in the architecture
                'dropout' :(float) dropout rate; 0 selects the MLP without dropout
        alpha_init : float, optional
            initial value of alpha when it is learned (alpha == -1).
            If None, alpha is initialized at 0.5. Ignored when alpha is fixed.
        dtype : torch.dtype, optional
            dtype of the model tensors. The default is th.float64.
        device : str or torch.device, optional
            device of the model tensors. The default is 'cpu'.

        Raises
        ------
        ValueError
            if alpha is neither -1 nor a value in [0, 1].
        """
        missing_keys = [k for k in ('hidden_dim', 'num_hidden', 'dropout') if k not in clf_net_dict]
        assert not missing_keys, 'clf_net_dict is missing the keys: %s' % missing_keys

        # Validate alpha before any side effect (directory creation, network
        # construction): an unsupported value used to leave self.alpha and
        # self.alpha_mode undefined and only failed much later.
        learn_alpha = (alpha == -1)
        if not learn_alpha and not (0 <= alpha <= 1):
            raise ValueError('alpha must be -1 (learned) or lie in [0, 1], got %r' % (alpha,))

        self.Katoms = Katoms
        self.n_labels = n_labels
        self.device = device
        self.dtype = dtype
        self.clf_net_dict = clf_net_dict
        self.classification_metrics = {'accuracy' : accuracy_score}

        self.experiment_repo = experiment_repo
        os.makedirs(self.experiment_repo, exist_ok=True)
        # Templates are created later by init_parameters() or load().
        self.Cbar, self.Fbar, self.hbar = None, None, None

        # Arguments shared by both classifier variants.
        mlp_kwargs = dict(
            input_dim = self.Katoms,
            output_dim = self.n_labels,
            hidden_dim = clf_net_dict['hidden_dim'],
            num_hidden = clf_net_dict['num_hidden'],
            output_activation = 'linear',
            device = self.device,
            dtype = self.dtype)
        if clf_net_dict['dropout'] != 0.:
            self.clf_Net = utils_networks.MLP_dropout(
                dropout = clf_net_dict['dropout'],
                skip_first = True,
                **mlp_kwargs)
        else:
            self.clf_Net = utils_networks.MLP(**mlp_kwargs)

        self.learn_hbar = learn_hbar

        if learn_alpha:
            initial_alpha = 0.5 if alpha_init is None else alpha_init
            self.alpha = th.tensor([initial_alpha], requires_grad=True, dtype=dtype, device=self.device)
            self.alpha_mode = 'learnable_shared'
        else:
            self.alpha = th.tensor([alpha], requires_grad=False, dtype=dtype, device=self.device)
            self.alpha_mode = 'fixed'
        self.loss = th.nn.CrossEntropyLoss().to(self.device)
        
        
    def init_parameters(self,list_Satoms:list, init_mode_atoms:str='sampling_supervised', graphs:list=None, features:list=None, labels:list=None, 
                        atoms_projection:str='clipped', verbose:bool=False):
        """
        Initialize the templates (Cbar, Fbar, hbar) and collect the trainable parameters.

        For every distinct template size S requested in `list_Satoms`, the
        required number of templates is split evenly between the labels
        (integer division; a remainder is not distributed). For each label:
          * enough graphs of size S in the label: sample graphs without replacement;
          * some graphs of size S, but too few: sample with replacement and
            perturb structures and features with Gaussian noise;
          * no graph of size S: draw a random symmetric structure from a
            Gaussian fitted on the mean values of the label's graphs and use
            KMeans centroids of the stacked features as template features.

        Parameters
        ----------
        list_Satoms : list
            number of nodes of each template; must have length Katoms.
        init_mode_atoms : str
            only 'sampling_supervised' is supported.
        graphs : list
            structure matrices of the dataset (tensors; indexed and cloned here).
        features : list
            node feature matrices matching `graphs`.
        labels : tensor
            label of each graph (used with th.unique / th.where).
        atoms_projection : str
            'clipped' rescales a randomly generated structure to [0, 1];
            any other value clamps it to be nonnegative (same convention as fit()).
        verbose : bool
            currently unused; progress messages are always printed.

        Raises
        ------
        ValueError
            if `init_mode_atoms` is unsupported, or if the sampling scheme does
            not produce exactly Katoms templates.
        """
        assert len(list_Satoms)==self.Katoms
        if init_mode_atoms != 'sampling_supervised':
            raise ValueError("only init_mode_atoms in ['sampling_supervised'] supported, got %r" % (init_mode_atoms,))

        print('init_mode_atoms = sampling_supervised' )
        self.Cbar = []
        self.Fbar = []
        self.hbar = []

        shapes = th.tensor([C.shape[0] for C in graphs])
        # np.unique returns the sizes sorted: templates are created in that order.
        unique_atom_shapes, counts = np.unique(list_Satoms,return_counts=True)
        unique_labels = th.unique(labels)
        idx_by_labels = [th.where(labels==label)[0] for label in unique_labels]
        n_unique_labels = unique_labels.shape[0]

        for shape, r in zip(unique_atom_shapes, counts):
            shape = int(shape)
            print('counts[i]:', r)
            perm = th.randperm(n_unique_labels)
            print('perm:', perm)
            count_by_label = r // n_unique_labels
            print('count by label :', count_by_label)
            n_required = min(r, count_by_label)
            # KMeans centroids only depend on the size: computed lazily, once per size.
            F_clusters = None

            for label_pos in perm:
                i_ = label_pos.item()
                label_idx = idx_by_labels[i_]
                shape_idx_by_label = th.where(shapes[label_idx]==shape)[0]
                print('shape_idx_by_label (S=%s) / %s'%(shape, shape_idx_by_label))
                n_available = shape_idx_by_label.shape[0]

                if n_available >= n_required:
                    print('enough samples of shape =%s / within label  =%s'%(shape, i_))
                    sample_idx = np.random.choice(shape_idx_by_label.numpy(), size=n_required, replace=False)
                    print('sample_idx:', sample_idx)
                    for idx in sample_idx:
                        graph_idx = label_idx[idx]
                        localC = graphs[graph_idx].clone().to(self.device)
                        localC.requires_grad_(True)
                        localF = features[graph_idx].clone().to(self.device)
                        localF.requires_grad_(True)
                        self.Cbar.append(localC)
                        self.Fbar.append(localF)

                elif n_available > 0:
                    print(' not enough samples of shape =%s / within label  =%s'%(shape, i_))
                    sample_idx = np.random.choice(shape_idx_by_label.numpy(), size=n_required, replace=True)
                    print('sample idx:', sample_idx)
                    for idx in sample_idx:
                        graph_idx = label_idx[idx]
                        C = graphs[graph_idx].clone().to(self.device)
                        noise_distrib_C = th.distributions.normal.Normal(loc=0., scale= C.std() + 1e-05)
                        noise_C = noise_distrib_C.rsample(C.size()).to(self.device)
                        if th.all(C == C.T):
                            # keep a symmetric structure symmetric after perturbation
                            noise_C = (noise_C + noise_C.T)/2
                        F = features[graph_idx].clone().to(self.device)
                        scale_F = F.std(axis=0) + 1e-05
                        # loc is built from scale so both share dtype and device
                        noise_distrib_F = th.distributions.normal.Normal(loc=th.zeros_like(scale_F), scale=scale_F)
                        noise_F = noise_distrib_F.rsample(th.Size([shape])).to(self.device)
                        new_C = C + noise_C
                        new_C.requires_grad_(True)
                        new_F = F + noise_F
                        new_F.requires_grad_(True)
                        self.Cbar.append(new_C)
                        self.Fbar.append(new_F)

                else:
                    if F_clusters is None:
                        print('no samples for shape S=%s -- computing kmeans on features within the label'%shape)
                        # NOTE(review): despite the message, the features of ALL graphs
                        # are stacked here, not only those of the current label.
                        stacked_features = th.cat(features)
                        print('stacked features shape:', stacked_features.shape)
                        km = KMeans(n_clusters = shape, init='k-means++', n_init=10, random_state = 0)
                        km.fit(stacked_features.detach().cpu().numpy())
                        F_clusters = th.tensor(km.cluster_centers_, dtype=self.dtype, device=self.device)
                        print('features from clusters:', F_clusters.shape)
                    print(' NO sample found with proper shape within label = %s'%i_)
                    # Random structure drawn from the distribution of the mean
                    # values of the graphs within this label.
                    means_C = th.tensor([graphs[idx].mean() for idx in label_idx])
                    distrib_C = th.distributions.normal.Normal(loc= means_C.mean(), scale= means_C.std()+ 1e-05)
                    for _ in range(n_required):
                        C = distrib_C.rsample((shape, shape)).to(device=self.device, dtype=self.dtype)
                        C = (C + C.T) / 2.
                        if atoms_projection == 'clipped':
                            new_C = ( C - C.min())/ (C.max() - C.min())
                        else:
                            new_C = C.clamp(0)
                        new_C.requires_grad_(True)
                        self.Cbar.append(new_C)
                        new_F = F_clusters.clone().to(self.device)
                        new_F.requires_grad_(True)
                        self.Fbar.append(new_F)

        if len(self.Cbar) != self.Katoms:
            # Every later stage indexes the templates with range(Katoms).
            raise ValueError('%s templates were initialized but Katoms=%s: for each template size, the number '
                             'of templates must be a multiple of the number of labels (%s)'
                             % (len(self.Cbar), self.Katoms, n_unique_labels))

        # Uniform weights, built from the templates themselves so that hbar[k]
        # always matches the size of Cbar[k] (templates are ordered by sorted size,
        # which may differ from the order of list_Satoms).
        for C in self.Cbar:
            S = C.shape[0]
            x = th.ones(S, dtype=self.dtype, device=self.device)
            x /= S
            x.requires_grad_(self.learn_hbar)
            self.hbar.append(x)

        self.clf_params = list(self.clf_Net.parameters())
        # Layout relied upon by load(): [Cbar..., Fbar..., (hbar...), (alpha)]
        self.atoms_params = [*self.Cbar, *self.Fbar]
        if self.learn_hbar:
            self.atoms_params += [*self.hbar]
        if 'learnable' in self.alpha_mode:
            self.atoms_params += [self.alpha]
        self.params = self.atoms_params +  self.clf_params
        self.dim_features = self.Fbar[0].shape[-1]
        self.shape_atoms = [C.shape[0] for C in self.Cbar]
        print('---  model initialized  ---')
        print('clf:', self.clf_Net.parameters)
        
    
    def get_pairwise_distance(self, Cs, Fs, hs, Ct, Ft, ht):
        """
        FGW computation between two attributed graphs (Cs, Fs, hs) and (Ct, Ft, ht).

        The feature cost M[i, j] = ||Fs[i]||^2 + ||Ft[j]||^2 - 2 <Fs[i], Ft[j]>
        is assembled with matrix products, then handed to GW_utils together
        with the structures, the masses and the current alpha (learn_alpha=False).
        Returns whatever GW_utils.parallel_fused_gromov_wasserstein2_learnablealpha returns.
        """
        tensor_opts = dict(dtype=self.dtype, device=self.device)
        d = self.dim_features

        # squared norms of the source rows, repeated along the target axis
        source_sq_norms = (Fs**2) @ th.ones((d, Ft.shape[0]), **tensor_opts)
        # squared norms of the target rows, repeated along the source axis
        target_sq_norms = th.ones((Fs.shape[0], d), **tensor_opts) @ (Ft**2).T
        M = source_sq_norms + target_sq_norms - 2* Fs @ Ft.T

        return GW_utils.parallel_fused_gromov_wasserstein2_learnablealpha(
            C1=Cs, C2=Ct, F1=Fs, F2=Ft, M=M, p=hs, q=ht,
            alpha=self.alpha[0], learn_alpha=False)
        
        
    def get_features_by_input(self, G, F, h, list_Cbar, list_Fbar, list_hbar, compute_gradients=True):
        """
        Compare one input graph (G, F, h) with each of the Katoms templates.

        Returns a list with one entry per template, each being the output of
        GW_utils.parallel_fused_gromov_wasserstein2_learnablealpha. The list is
        empty when alpha_mode is neither 'learnable_shared' nor 'fixed'.
        """
        tensor_opts = dict(dtype=self.dtype, device=self.device)
        d = self.dim_features

        # Quantities depending on the input graph only.
        F_squared = F**2
        ones_input = th.ones((F.shape[0], d), **tensor_opts)

        outputs = []
        if self.alpha_mode == 'learnable_shared':
            learn_alpha = True
        elif self.alpha_mode == 'fixed':
            learn_alpha = False
        else:
            return outputs

        for k in range(self.Katoms):
            Fbar_k = list_Fbar[k]
            # pairwise squared Euclidean distances between input and template features
            input_sq_norms = F_squared @ th.ones((d, Fbar_k.shape[0]), **tensor_opts)
            template_sq_norms = ones_input @ (Fbar_k**2).T
            M_k = input_sq_norms + template_sq_norms - 2* F @ Fbar_k.T

            outputs.append(
                GW_utils.parallel_fused_gromov_wasserstein2_learnablealpha(
                    C1=G, C2=list_Cbar[k], F1=F, F2=Fbar_k, M=M_k, p=h, q=list_hbar[k],
                    alpha=self.alpha[0], learn_alpha=learn_alpha, compute_gradients=compute_gradients))

        return outputs

    def parallelized_get_features(self,list_G, list_F, list_h, list_Cbar, list_Fbar, list_hbar, n_jobs=2, compute_gradients=True):
        """
        Build the (len(list_G), Katoms) matrix of distances between the input
        graphs and the templates.

        The per-input computations are dispatched with joblib under no_grad.
        When compute_gradients is True each result is unpacked as
        (value, gh, ghbar, gC, gCbar, gF, gFbar, galpha) and passed to
        GW_utils.set_gradients together with the matching tensors; the alpha
        gradient is only forwarded in 'learnable_shared' mode. Otherwise the
        results are stored as they are.
        """
        n_inputs = len(list_G)
        features = th.zeros((n_inputs, self.Katoms), dtype=self.dtype, device=self.device)

        with th.no_grad():
            res_by_input = Parallel(n_jobs=n_jobs)(
                delayed(self.get_features_by_input)(
                    list_G[i], list_F[i], list_h[i], list_Cbar, list_Fbar, list_hbar, compute_gradients)
                for i in range(n_inputs))

        if not compute_gradients:
            for row, per_template in enumerate(res_by_input):
                for col in range(self.Katoms):
                    features[row, col] = per_template[col]
            return features

        with_alpha = (self.alpha_mode == 'learnable_shared')
        if not with_alpha and self.alpha_mode != 'fixed':
            # unknown mode: nothing is filled in
            return features

        for row, per_template in enumerate(res_by_input):
            for col in range(self.Katoms):
                fgw, gh, ghbar, gC, gCbar, gF, gFbar, galpha = per_template[col]
                inputs = (list_h[row], list_hbar[col], list_G[row], list_Cbar[col], list_F[row], list_Fbar[col])
                grads = (gh, ghbar, gC, gCbar, gF, gFbar)
                if with_alpha:
                    inputs = inputs + (self.alpha[0],)
                    grads = grads + (galpha,)
                features[row, col] = GW_utils.set_gradients(self.ValFunction, fgw, inputs, grads)

        return features
    
    def set_model_to_train(self):
        """Put the classifier network in training mode."""
        classifier = self.clf_Net
        classifier.train()
    
    def set_model_to_eval(self):
        """Put the classifier network in evaluation mode."""
        classifier = self.clf_Net
        classifier.eval()
    
    def fit(self, model_name:str, 
            X_train:list, F_train:list, y_train:list, 
            X_val:list, F_val:list, y_val:list, 
            X_test:list, F_test:list, y_test:list,
            atoms_projection:str, lr:float, batch_size:int, supervised_sampler:bool, 
            epochs:int, val_timestamp:int, algo_seed:int, track_templates:bool=False, verbose:bool=False, n_jobs:int=2):
        """
        Train the templates and the classifier with Adam on mini-batches.

        Parameters
        ----------
        model_name : str
            prefix of every file written in experiment_repo.
        X_train, F_train, y_train :
            structures, node features and labels of the training graphs.
            y_train is used as a tensor (`.shape`, `th.unique`, tensor indexing).
        X_val, F_val, y_val, X_test, F_test, y_test :
            validation / test counterparts. If X_val is None, no validation is
            performed and the test set is ignored.
        atoms_projection : str
            'clipped' projects the template structures on [0, 1] after every
            step; any other value only enforces nonnegativity.
        lr : float
            learning rate of Adam.
        batch_size : int
            number of graphs per batch.
        supervised_sampler : bool
            if True each batch contains batch_size // n_labels graphs of every
            label, otherwise graphs are sampled uniformly without replacement.
        epochs : int
            number of epochs; each one runs n_train // batch_size + 1 batches.
        val_timestamp : int
            the model is evaluated every val_timestamp epochs (not at epoch 0)
            and at the last epoch.
        algo_seed : int
            seed for torch and numpy.
        track_templates : bool
            if True, per-label mean/std of the distances to the templates are
            logged and saved at every evaluation.
        verbose : bool
            print evaluation summaries.
        n_jobs : int
            number of joblib workers used to compute the distances.

        Files written in experiment_repo
        --------------------------------
        <model_name>_training_log.pkl
        <model_name>_best_val_accuracy_increasing_train_accuracy.pkl (with validation)
        <model_name>_best_train_accuracy.pkl (without validation)
        <model_name>_tracking_templates_log.pkl (if track_templates)
        """
        th.manual_seed(algo_seed)
        np.random.seed(algo_seed)

        self.ValFunction = GW_utils.ValFunction

        def uniform_masses(graphs):
            # uniform node weights for every graph
            return [th.ones(C.shape[0], dtype=self.dtype, device=self.device)/C.shape[0] for C in graphs]

        h_train = uniform_masses(X_train)
        n_train = y_train.shape[0]
        best_train_acc = - np.inf

        do_validation = X_val is not None
        if do_validation:
            h_val = uniform_masses(X_val)
            h_test = uniform_masses(X_test)
            sets = ['train', 'val', 'test']
            best_val_acc = - np.inf
            best_val_acc_train_acc = - np.inf
        else:
            sets = ['train']

        self.log = {'train_cumulated_batch_loss':[], 'train_epoch_loss':[], 'val_epoch_loss':[], 'test_epoch_loss':[] }
        for metric in self.classification_metrics:
            for s in sets:
                self.log['%s_%s'%(s, metric)]=[]

        if track_templates:
            self.log_templates = {}
            for label in range(self.n_labels):
                for s in sets:
                    self.log_templates['%s_mean_dists_label%s'%(s, label)] = []
                    self.log_templates['%s_std_dists_label%s'%(s, label)] = []

        self.optimizer = th.optim.Adam(params=self.params, lr=lr, betas=[0.9, 0.99])
        print('alpha requires grad : ', self.alpha.requires_grad)
        if supervised_sampler:
            unique_labels = th.unique(y_train)
            n_labels = unique_labels.shape[0]
            # one array of training indices per label, in the order of unique_labels
            train_idx_by_labels = [th.where(y_train==label)[0].cpu().numpy() for label in unique_labels]
            labels_by_batch = batch_size // n_labels
        batch_by_epoch = n_train//batch_size +1

        log_file = self.experiment_repo+'/%s_training_log.pkl'%model_name
        templates_log_file = self.experiment_repo+'/%s_tracking_templates_log.pkl'%model_name

        self.set_model_to_train()
        for e in tqdm(range(epochs), desc='epochs'):
            cumulated_batch_loss=0.
            for batch_i in range(batch_by_epoch):

                self.optimizer.zero_grad()
                if not supervised_sampler:
                    batch_idx = np.random.choice(n_train, size=batch_size, replace=False)
                else:
                    r = batch_size
                    batch_parts = []
                    # Permute label POSITIONS (not label values): train_idx_by_labels
                    # is a positional list, so this also works when the labels are
                    # not exactly 0..n_labels-1.
                    for label_pos in np.random.permutation(n_labels):
                        batch_parts.append(
                            np.random.choice(train_idx_by_labels[label_pos], size=min(r,labels_by_batch), replace=False))
                        r -= labels_by_batch
                    batch_idx = np.concatenate(batch_parts)

                batch_graphs = [X_train[idx] for idx in batch_idx]
                batch_features = [F_train[idx] for idx in batch_idx]
                batch_masses = [h_train[idx] for idx in batch_idx]
                features = self.parallelized_get_features(batch_graphs, batch_features, batch_masses, self.Cbar, self.Fbar, self.hbar, n_jobs, compute_gradients=True)
                batch_pred = self.clf_Net(features)
                batch_loss = self.loss(batch_pred, y_train[batch_idx])

                cumulated_batch_loss += batch_loss.item()
                batch_loss.backward()

                self.optimizer.step()
                # Project the parameters back on their feasible sets.
                with th.no_grad():
                    for C in self.Cbar:
                        if atoms_projection =='clipped':
                            C[:] = C.clamp(min=0,max=1)
                        else:
                            C[:] = C.clamp(min=0)

                    if self.alpha.requires_grad:
                        self.alpha[:] = self.alpha.clamp(min=0., max=1.)

                    if self.learn_hbar:
                        for h in self.hbar:
                            h[:] = GW_utils.probability_simplex_projection(h)
            self.log['train_cumulated_batch_loss'].append(cumulated_batch_loss)

            if (((e %val_timestamp) ==0) and e>0) or (e == (epochs - 1)):
                with th.no_grad():
                    self.set_model_to_eval()
                    # prune templates at test time
                    if self.learn_hbar:
                        pruned_Cbar, pruned_Fbar, pruned_hbar = self.prune_templates()
                    else:
                        # unchanged
                        pruned_Cbar, pruned_Fbar, pruned_hbar = self.Cbar, self.Fbar, self.hbar

                    features_train, _, _, loss_train, res_train = self.evaluate(X_train, F_train, h_train, y_train, pruned_Cbar, pruned_Fbar, pruned_hbar, n_jobs)
                    if do_validation:
                        features_val, _, _, loss_val, res_val = self.evaluate(X_val, F_val, h_val, y_val, pruned_Cbar, pruned_Fbar, pruned_hbar, n_jobs)
                        _, _, _, loss_test, res_test = self.evaluate(X_test, F_test, h_test, y_test, pruned_Cbar, pruned_Fbar, pruned_hbar, n_jobs)

                    if verbose:
                        if self.learn_hbar:
                            print('hbar:', self.hbar)
                        if 'learnable' in self.alpha_mode:
                            print('alpha: %s / requires_grad: %s'%(self.alpha, self.alpha.requires_grad))
                        if do_validation:
                            print('epoch= %s / loss_train = %s / res_train = %s / loss_val =%s/ res_val =%s'%(e,loss_train.item(),res_train,loss_val.item(),res_val))
                        else:
                            print('epoch= %s / loss_train = %s / res_train = %s '%(e, loss_train.item(), res_train))

                    self.log['train_epoch_loss'].append(loss_train.item())
                    if do_validation:
                        self.log['val_epoch_loss'].append(loss_val.item())
                        self.log['test_epoch_loss'].append(loss_test.item())
                    for metric in self.classification_metrics:
                        self.log['train_%s'%metric].append(res_train[metric])
                        if do_validation:
                            self.log['val_%s'%metric].append(res_val[metric])
                            self.log['test_%s'%metric].append(res_test[metric])
                    self._dump_pickle(self.log, log_file)

                    if do_validation:
                        # Save the model with the best val accuracy, provided the train
                        # accuracy did not decrease with respect to the last saved model.
                        if best_val_acc <= res_val['accuracy']:
                            best_val_acc = res_val['accuracy']
                            if best_val_acc_train_acc <= res_train['accuracy']:
                                best_val_acc_train_acc = res_train['accuracy']
                                str_file = self.experiment_repo+'/%s_best_val_accuracy_increasing_train_accuracy.pkl'%model_name
                                self._dump_pickle(self._checkpoint_state(e), str_file)
                    elif best_train_acc <= res_train['accuracy']:
                        # no validation set: save the model with the best train accuracy
                        best_train_acc = res_train['accuracy']
                        str_file = self.experiment_repo+'/%s_best_train_accuracy.pkl'%model_name
                        self._dump_pickle(self._checkpoint_state(e), str_file)

                    if track_templates:
                        for label in th.unique(y_train):
                            label_ = int(label.item())
                            subfeatures_train = features_train[th.where(y_train==label)]
                            if do_validation:
                                subfeatures_val = features_val[th.where(y_val==label)]
                            self.log_templates['train_mean_dists_label%s'%label_].append(subfeatures_train.mean(axis=0))
                            self.log_templates['train_std_dists_label%s'%label_].append(subfeatures_train.std(axis=0))
                            if do_validation:
                                self.log_templates['val_mean_dists_label%s'%label_].append(subfeatures_val.mean(axis=0))
                                self.log_templates['val_std_dists_label%s'%label_].append(subfeatures_val.std(axis=0))

                            if verbose:
                                print('[TRAIN] label = %s / mean dists: %s / std dists: %s'%(label,self.log_templates['train_mean_dists_label%s'%label_][-1], self.log_templates['train_std_dists_label%s'%label_][-1]))
                        self._dump_pickle(self.log_templates, templates_log_file)
                # after evaluation, make the model trainable again
                self.set_model_to_train()

    def _checkpoint_state(self, epoch):
        """Snapshot of the current model (CPU copies of the template parameters + classifier state dict)."""
        return {
            'epoch' : epoch,
            'atoms_params': [param.clone().detach().cpu() for param in self.atoms_params],
            'clf_params': self.clf_Net.state_dict()}

    @staticmethod
    def _dump_pickle(obj, path):
        """Pickle `obj` into `path`, making sure the file handle is closed."""
        with open(path, 'wb') as f:
            pickle.dump(obj, f)
                
    def evaluate(self, X:list, F:list, h:list, y:list, pruned_Cbar:list, pruned_Fbar:list, pruned_hbar:list, n_jobs:int=None):
        """
        Evaluate the current model on a set of graphs.

        The classifier is switched to evaluation mode and is NOT switched back
        to training mode here; callers that keep training must do it themselves.

        Parameters
        ----------
        X, F, h : list
            structures, node features and node weights of the graphs.
        y : tensor
            labels of the graphs.
        pruned_Cbar, pruned_Fbar, pruned_hbar : list
            templates to compare the graphs with.
        n_jobs : int, optional
            forwarded to parallelized_get_features.

        Returns
        -------
        dist_features : distances between graphs and templates
        pred : raw outputs of the classifier
        y_pred : predicted labels (argmax over dimension 1 of pred)
        loss : value of self.loss(pred, y)
        res : dict mapping each metric name to its value
        """
        self.set_model_to_eval()
        with th.no_grad():
            dist_features = self.parallelized_get_features(X, F, h, pruned_Cbar, pruned_Fbar, pruned_hbar, n_jobs, compute_gradients=False)
            pred = self.clf_Net(dist_features)
            loss = self.loss(pred, y)
            y_pred = pred.argmax(1)
            # .cpu() is required before .numpy() when the tensors live on another device
            y_true_np = y.detach().cpu().numpy()
            y_pred_np = y_pred.detach().cpu().numpy()
            res = {name: metric(y_true_np, y_pred_np)
                   for name, metric in self.classification_metrics.items()}
        return dist_features, pred, y_pred, loss, res
        
    def evaluate_templates(self,):
        """
        Classify the templates themselves.

        Fills a symmetric (Katoms, Katoms) matrix (zero diagonal) with the first
        output of get_pairwise_distance for every pair of templates, then feeds
        it to the classifier in evaluation mode.

        Returns the distance matrix, the raw predictions and the predicted labels.
        """
        self.set_model_to_eval()
        with th.no_grad():
            n_templates = self.Katoms
            dist_features = th.zeros((n_templates, n_templates), dtype=self.dtype, device=self.device)
            # strict upper triangle, row by row
            template_pairs = [(i, j) for i in range(n_templates - 1) for j in range(i + 1, n_templates)]
            for i, j in template_pairs:
                pair_output = self.get_pairwise_distance(
                    self.Cbar[i], self.Fbar[i], self.hbar[i],
                    self.Cbar[j], self.Fbar[j], self.hbar[j])
                dist_features[i, j] = dist_features[j, i] = pair_output[0]

            pred = self.clf_Net(dist_features)
            return dist_features, pred, pred.argmax(1)
        
    def prune_templates(self):
        """
        Drop from every template the nodes whose weight is not strictly positive.

        Returns the lists (pruned_Cbar, pruned_Fbar, pruned_hbar); the stored
        templates are left untouched.
        """
        pruned_Cbar, pruned_Fbar, pruned_hbar = [], [], []
        for weights, structure, feats in zip(self.hbar, self.Cbar, self.Fbar):
            kept_nodes = th.nonzero(weights > 0)[:, 0]
            pruned_hbar.append(weights[kept_nodes])
            # keep the rows, then the columns, of the remaining nodes
            kept_rows = structure[kept_nodes, :]
            pruned_Cbar.append(kept_rows[:, kept_nodes])
            pruned_Fbar.append(feats[kept_nodes, :])
        return pruned_Cbar, pruned_Fbar, pruned_hbar
    
    def load(self, model_name):
        """
        Restore the templates, alpha (when learned) and the classifier from a
        checkpoint written by fit().

        The checkpoint '<model_name>_best_val_accuracy_increasing_train_accuracy.pkl'
        is looked up in experiment_repo; if it does not exist, the checkpoint
        written when training without validation
        ('<model_name>_best_train_accuracy.pkl') is used instead when available.

        The list 'atoms_params' is expected to be laid out as
        [Cbar (Katoms), Fbar (Katoms), hbar (Katoms, only if learn_hbar), alpha (only if learned)].
        Loaded tensors do not require gradients and are moved to self.device.

        Raises
        ------
        FileNotFoundError
            if no checkpoint is found.
        """
        str_file = '%s/%s_best_val_accuracy_increasing_train_accuracy.pkl'%(self.experiment_repo, model_name)
        if not os.path.exists(str_file):
            train_only_file = '%s/%s_best_train_accuracy.pkl'%(self.experiment_repo, model_name)
            if os.path.exists(train_only_file):
                str_file = train_only_file
        # NOTE: pickle can execute arbitrary code; only load checkpoints you trust.
        with open(str_file, 'rb') as f:
            full_dict_state = pickle.load(f)
        atoms_params = full_dict_state['atoms_params']
        K = self.Katoms

        self.ValFunction = GW_utils.ValFunction

        def frozen_copy(p):
            # independent copy on the model device, excluded from autograd
            t = p.clone().to(self.device)
            t.requires_grad_(False)
            return t

        self.Cbar = [frozen_copy(p) for p in atoms_params[:K]]  # atoms structure
        self.Fbar = [frozen_copy(p) for p in atoms_params[K : 2 * K]]  # atoms features
        self.dim_features = self.Fbar[-1].shape[-1]
        self.shape_atoms = [C.shape[0] for C in self.Cbar]
        if self.learn_hbar:
            self.hbar = [frozen_copy(p) for p in atoms_params[2 * K : 3 * K]]  # atoms weights
        else:
            # weights were not learned: rebuild the uniform ones
            self.hbar = [th.ones(s, dtype=self.dtype, device=self.device)/s for s in self.shape_atoms]
        if self.alpha_mode == 'learnable_shared':
            # alpha is stored last in atoms_params
            self.alpha = atoms_params[-1].clone().to(self.device)
        self.clf_Net.load_state_dict(full_dict_state['clf_params'])

        print('[SUCCESSFULLY LOADED] ',str_file)
    
#%%

class GWmachine():
    
    def __init__(self, 
                 Katoms:int, 
                 n_labels:int, 
                 learn_hbar:bool,
                 experiment_repo:str, 
                 clf_net_dict:dict,
                 dtype = th.float64,
                 device='cpu'):
        """
        Template-based GW model: distances to Katoms templates feed an MLP classifier.

        Parameters
        ----------
        Katoms : int
            number of templates (atoms); also the input dimension of the classifier.
        n_labels : int
            number of classes; output dimension of the classifier.
        learn_hbar : bool
            either to learn the weights of the templates or not.
        experiment_repo : str
            repository to save the experiment during training (created if missing).
        clf_net_dict : dict
            Parameters of the MLP leading to label prediction, forwarded to
            utils_networks.MLP (or MLP_dropout when 'dropout' != 0).
            Must contain the keys 'hidden_dim', 'num_hidden' and 'dropout'.
        dtype : torch dtype, optional
            dtype used for the templates and the classifier. The default is th.float64.
        device : optional
            device used for the templates and the classifier. The default is 'cpu'.
        """
        assert np.all( [s in clf_net_dict.keys() for s in ['hidden_dim', 'num_hidden', 'dropout']])
        
        self.Katoms = Katoms
        self.n_labels = n_labels
        self.device = device
        self.clf_net_dict = clf_net_dict
        self.classification_metrics = {'accuracy' : (lambda y_true,y_pred : accuracy_score(y_true,y_pred))}
        
        self.experiment_repo = experiment_repo
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs
        os.makedirs(self.experiment_repo, exist_ok=True)
        # templates are created later by init_parameters() or load()
        self.Cbar, self.hbar = None, None
        self.dtype = dtype
        if clf_net_dict['dropout'] != 0.:
            self.clf_Net = utils_networks.MLP_dropout(
                input_dim = self.Katoms, 
                output_dim = self.n_labels,
                hidden_dim = clf_net_dict['hidden_dim'],
                num_hidden = clf_net_dict['num_hidden'],
                output_activation= 'linear', 
                dropout = clf_net_dict['dropout'],
                skip_first = True,
                device=self.device, dtype=self.dtype)
        else:
            self.clf_Net= utils_networks.MLP(
                  input_dim = self.Katoms, 
                  output_dim = self.n_labels,
                  hidden_dim = clf_net_dict['hidden_dim'],
                  num_hidden = clf_net_dict['num_hidden'],
                  output_activation= 'linear', 
                  device=self.device, dtype=self.dtype)
        self.learn_hbar = learn_hbar
        self.loss = th.nn.CrossEntropyLoss().to(self.device)
    
    def init_parameters(self,list_Satoms:list, init_mode_atoms:str='sampling_supervised', graphs:list=None, features:list=None, labels:list=None, 
                        atoms_projection:str='clipped', verbose:bool=False):
        assert len(list_Satoms)==self.Katoms
        self.Cbar = []
        self.hbar = []
        for S in list_Satoms:
            x = th.ones(S, dtype=self.dtype, device=self.device)
            x /= S
            x.requires_grad_(self.learn_hbar)
            self.hbar.append(x)
        
        if init_mode_atoms =='sampling_supervised': 
            print('init_mode_atoms = sampling_supervised' )
            # If not enough samples have the required shape within a label
            # we get the samples within the label and create perturbated versions of observed graphs
            shapes = th.tensor([C.shape[0] for C in graphs])
            unique_atom_shapes, counts = np.unique(list_Satoms,return_counts=True)
            unique_labels = th.unique(labels)
            idx_by_labels = [th.where(labels==label)[0] for label in unique_labels]
            print('counts:', counts)
            for i,shape in enumerate(unique_atom_shapes):
                #print('shape: %s / required_samples: %s/ occurrences in the dataset =%s  '%(shape,counts[i],len(shape_idx)))
                r = counts[i]
                perm = th.randperm(self.n_labels)
                if r < self.n_labels:
                    count_by_label = 1
                    perm = perm[:r]
                    print('Katoms < #labels : one atom selected each with sampled labels %s'%perm)
                    
                else:
                    count_by_label = r// self.n_labels
                print('r = %s / unique_labels: %s /count_by_label: %s'%(r, self.n_labels, count_by_label) )
                for i in perm:
                    i_ = i.item()
                    shapes_label = shapes[idx_by_labels[i_]]
                    shape_idx_by_label = th.where(shapes_label==shape)[0]
                    print('shape_idx_by_label (S=%s) / %s'%(shape, shape_idx_by_label))
                    try:
                        print('enough samples of shape =%s / within label  =%s'%(shape, i_))
                        sample_idx = np.random.choice(shape_idx_by_label.numpy(), size=min(r,count_by_label), replace=False)
                        print('sample_idx:', sample_idx)
                        for idx in sample_idx:
                            localC = graphs[idx_by_labels[i_][idx]].clone().to(self.device)
                            localC.requires_grad_(True)
                            self.Cbar.append(localC)      
                    except:
                        print(' not enough samples of shape =%s / within label  =%s'%(shape, i_))
                        try:
                            sample_idx = np.random.choice(shape_idx_by_label.numpy(), size=min(r,count_by_label), replace=True)
                            print('sample idx:', sample_idx)
                            for idx in sample_idx:
                                C = graphs[idx_by_labels[i_][idx]].clone().to(self.device)
                                noise_distrib_C = th.distributions.normal.Normal(loc=0., scale= C.std() + 1e-05)
                                noise_C = noise_distrib_C.rsample(C.size()).to(self.device)
                                if th.all(C == C.T):
                                    noise_C = (noise_C + noise_C)/2
                                new_C = C + noise_C
                                new_C.requires_grad_(True)                                             
                                self.Cbar.append(new_C)
                                
                        except:
                            print(' NO sample found with proper shape within label = %s'%i_)
                            # We generate a random graph from the distribution of graphs within this label
                            # Randomly over C
                            # by kmeans over F
                            local_idx = idx_by_labels[i_]
                            means_C = th.tensor([graphs[idx].mean() for idx in local_idx])
                            distrib_C = th.distributions.normal.Normal(loc= means_C.mean(), scale= means_C.std()+ 1e-05)
                            for _ in range(min(r,count_by_label)):
                                C = distrib_C.rsample((shape, shape)).to(self.device)
                                C = (C + C.T) / 2.
                                if atoms_projection == 'clipped':
                                    new_C = ( C - C.min())/ (C.max() - C.min())
                                elif atoms_projection == 'nonnegative':
                                    new_C = C.clamp(0)
                                new_C.requires_grad_(True)                                             
                                self.Cbar.append(new_C)
                                
                        continue
                        #print('asked label: %s / found label : %s' %(unique_labels[perm][i],labels[idx_by_labels[i_][idx]]))
        
        else:
            raise "only init_mode_atoms in ['random', 'sampling_supervised'] supported"
        self.clf_params = list(self.clf_Net.parameters())
        self.atoms_params = [*self.Cbar]
        if self.learn_hbar:
            self.atoms_params += [*self.hbar]
        self.params = self.atoms_params +  self.clf_params
        self.shape_atoms = [C.shape[0] for C in self.Cbar]
        print('---  model initialized  ---')
        #print('sanity check params (shape, requires grad):', [(param.shape, param.requires_grad) for param in self.params])
        print('clf:', self.clf_Net.parameters)
        
    def get_features_by_input(self, G, h):
        res = []
        for i in range(self.Katoms):
            res.append(GW_utils.parallel_gromov_wasserstein2(G, self.Cbar[i], h, self.hbar[i]))
        
        return res

    def parallelized_get_features(self,list_G, list_h, n_jobs=2):
        """
        Compute the (len(list_G), Katoms) matrix of distances to the templates.

        list_G: list of input structures
        list_h: list of corresponding masses
        n_jobs: number of joblib workers used to process the inputs

        Each worker result is unpacked as (gw, gh, ghbar, gC, gCbar) and handed
        to GW_utils.set_gradients together with (h, hbar, G, Cbar).
        Presumably gw is the distance value and the others its gradients with
        respect to those four inputs - confirm against GW_utils.
        """        
        n_inputs = len(list_G)
        features = th.zeros((n_inputs, self.Katoms), dtype=self.dtype, device=self.device)
        jobs = (delayed(self.get_features_by_input)(list_G[k], list_h[k]) for k in range(n_inputs))
        res_by_input = Parallel(n_jobs=n_jobs)(jobs)
        for row, per_atom in enumerate(res_by_input):
            for col in range(self.Katoms):
                gw, gh, ghbar, gC, gCbar = per_atom[col]
                inputs = (list_h[row], self.hbar[col], list_G[row], self.Cbar[col])
                if self.device == 'cpu':
                    grads = (gh, ghbar, gC, gCbar)
                else:
                    # worker outputs are moved to the model device before being attached
                    grads = (gh.to(self.device), ghbar.to(self.device), gC.to(self.device), gCbar.to(self.device))
                features[row, col] = GW_utils.set_gradients(self.ValFunction, gw, inputs, grads)
        
        return features
    
    def set_model_to_train(self):
        """Switch the classifier network (self.clf_Net) to training mode."""
        self.clf_Net.train()
    
    def set_model_to_eval(self):
        """Switch the classifier network (self.clf_Net) to evaluation mode."""
        self.clf_Net.eval()
    
    def fit(self, model_name:str, 
            X_train:list, y_train:list, 
            X_val:list, y_val:list, 
            X_test:list, y_test:list,
            atoms_projection:str, lr:float, batch_size:int, supervised_sampler:bool, 
            epochs:int, val_timestamp:int, algo_seed:int, track_templates:bool=False, verbose:bool=False, n_jobs:int=2):
        """
        Jointly train the templates and the classifier with Adam on mini-batches.

        After every optimizer step the template structures are clamped
        (to [0, 1] if atoms_projection == 'clipped', to [0, +inf) otherwise) and,
        when self.learn_hbar is set, the template weights are projected with
        GW_utils.probability_simplex_projection.

        Files written in self.experiment_repo (all prefixed by `model_name`):
          - '_training_log.pkl' : self.log, at every evaluation;
          - '_best_val_accuracy_increasing_train_accuracy.pkl' : checkpoint (with validation);
          - '_best_train_accuracy.pkl' : checkpoint (without validation);
          - '_tracking_templates_log.pkl' : self.log_templates (if track_templates).

        Parameters
        ----------
        model_name : str
            prefix of the files written during training.
        X_train, y_train : list of tensors, tensor
            training structures and their labels.
        X_val, y_val, X_test, y_test :
            validation / test structures and labels. If X_val is None, only the
            training set is evaluated and the checkpoint follows train accuracy.
        atoms_projection : str
            'clipped' clamps templates to [0, 1]; any other value clamps to >= 0.
        lr : float
            learning rate of Adam.
        batch_size : int
            number of graphs per batch (capped at the training set size for
            uniform sampling).
        supervised_sampler : bool
            if True, each batch takes batch_size // n_labels graphs per label;
            otherwise graphs are drawn uniformly without replacement.
        epochs : int
            number of epochs; each runs n_train // batch_size + 1 batches.
        val_timestamp : int
            evaluation is run every `val_timestamp` epochs (skipping epoch 0)
            and at the last epoch.
        algo_seed : int
            seed given to torch and numpy.
        track_templates : bool
            log per-label mean / std of the distances to the templates.
        verbose : bool
            print losses and metrics at each evaluation.
        n_jobs : int
            number of joblib workers used to compute the distances.
        """
        def _dump(obj, path):
            # context manager: the file is flushed and closed as soon as it is written
            with open(path, 'wb') as f:
                pickle.dump(obj, f)

        def _save_checkpoint(path, epoch):
            # snapshot of the templates (moved to cpu) and of the classifier weights
            full_dict_state = {
                'epoch' : epoch, 
                'atoms_params': [param.clone().detach().cpu() for param in self.atoms_params],
                'clf_params': self.clf_Net.state_dict()}
            _dump(full_dict_state, path)

        th.manual_seed(algo_seed)
        np.random.seed(algo_seed)
        self.ValFunction = GW_utils.ValFunction
        
        # uniform masses on the nodes of every input graph
        h_train = [th.ones(C.shape[0], dtype=self.dtype, device=self.device)/C.shape[0] for C in X_train]
        best_train_acc = - np.inf
        best_train_loss = np.inf     
        
        if X_val is not None:   
            h_val = [th.ones(C.shape[0], dtype=self.dtype, device=self.device)/C.shape[0] for C in X_val]
            h_test = [th.ones(C.shape[0], dtype=self.dtype, device=self.device)/C.shape[0] for C in X_test]
            do_validation = True
            sets = ['train', 'val', 'test']
            best_val_acc = - np.inf
            best_val_acc_train_acc = - np.inf

        else:
            do_validation = False
            sets = ['train']
            
        n_train = y_train.shape[0]
        
        self.log = {'train_cumulated_batch_loss':[], 'train_epoch_loss':[], 'val_epoch_loss':[], 'test_epoch_loss':[] }        
        for metric in self.classification_metrics.keys():
            for s in sets:
                self.log['%s_%s'%(s, metric)]=[]
                
        if track_templates:
            self.log_templates = {}
            for label in range(self.n_labels):
                for s in sets:
                    self.log_templates['%s_mean_dists_label%s'%(s, label)] = []
                    self.log_templates['%s_std_dists_label%s'%(s, label)] = []
        
        self.optimizer = th.optim.Adam(params=self.params, lr=lr, betas=[0.9, 0.99])
        
        if supervised_sampler:
            unique_labels = th.unique(y_train)
            n_labels = unique_labels.shape[0]
            train_idx_by_labels = [th.where(y_train==label)[0] for label in unique_labels]
            labels_by_batch = batch_size // n_labels
        batch_by_epoch = n_train//batch_size +1
            
        self.set_model_to_train()
        for e in tqdm(range(epochs), desc='epochs'):
            cumulated_batch_loss=0.
            for batch_i in range(batch_by_epoch):
                
                self.optimizer.zero_grad()
                if not supervised_sampler:
                    # cap the size: sampling without replacement cannot exceed n_train
                    batch_idx = np.random.choice(range(n_train), size=min(batch_size, n_train), replace=False)
                else:
                    r = batch_size
                    # NOTE(review): train_idx_by_labels is indexed by the label value,
                    # which assumes labels are 0..n_labels-1 - confirm
                    for idx_label,label in enumerate(np.random.permutation(unique_labels)):
                        
                        local_batch_idx = np.random.choice(train_idx_by_labels[label], size=min(r,labels_by_batch), replace=False)
                        if idx_label ==0:
                            batch_idx = local_batch_idx
                        else:
                            batch_idx = np.concatenate([batch_idx,local_batch_idx])
                        r -= labels_by_batch

                batch_graphs = [X_train[idx] for idx in batch_idx]
                batch_masses = [h_train[idx] for idx in batch_idx]

                features = self.parallelized_get_features(batch_graphs, batch_masses, n_jobs)
                batch_pred = self.clf_Net(features)
                batch_loss = self.loss(batch_pred, y_train[batch_idx])
                cumulated_batch_loss += batch_loss.item()
                batch_loss.backward()
                
                self.optimizer.step()        
                # project the updated templates back onto their constraint sets
                with th.no_grad():    
                    for C in self.Cbar:
                        if atoms_projection =='clipped':
                            C[:] = C.clamp(min=0,max=1)
                        else:
                            C[:] = C.clamp(min=0)
                    if self.learn_hbar:
                        for h in self.hbar:
                            h[:] = GW_utils.probability_simplex_projection(h)
            self.log['train_cumulated_batch_loss'].append(cumulated_batch_loss)    
                
            if (((e %val_timestamp) ==0) and e>0) or (e == (epochs - 1)):
                with th.no_grad():
                    self.set_model_to_eval()
                    
                    features_train, pred_train, y_pred_train, loss_train, res_train = self.evaluate(X_train, h_train, y_train, n_jobs)
                    
                    if do_validation:
                        features_val, pred_val, y_pred_val, loss_val, res_val = self.evaluate(X_val, h_val, y_val, n_jobs)
                        features_test, pred_test, y_pred_test, loss_test, res_test = self.evaluate(X_test, h_test, y_test, n_jobs)
                        
                        if verbose:
                            if self.learn_hbar:
                                print('hbar:', self.hbar)
                            print('epoch= %s / loss_train = %s / res_train = %s / loss_val =%s/ res_val =%s'%(e,loss_train.item(),res_train,loss_val.item(),res_val))
                        self.log['train_epoch_loss'].append(loss_train.item())
                        self.log['val_epoch_loss'].append(loss_val.item())
                        self.log['test_epoch_loss'].append(loss_test.item())
                        for metric in self.classification_metrics.keys():
                            self.log['train_%s'%metric].append(res_train[metric])
                            self.log['val_%s'%metric].append(res_val[metric])
                            self.log['test_%s'%metric].append(res_test[metric])
                        
                        str_log = self.experiment_repo+'/%s_training_log.pkl'%model_name
                        _dump(self.log, str_log)
                        if best_val_acc <= res_val['accuracy']: #Save model with best val acc assuring increase of train acc
                            
                            best_val_acc = res_val['accuracy']
                            if best_val_acc_train_acc <= res_train['accuracy']:
                                best_val_acc_train_acc = res_train['accuracy']
                                str_file = self.experiment_repo+'/%s_best_val_accuracy_increasing_train_accuracy.pkl'%model_name
                                _save_checkpoint(str_file, e)
                                                   
                        if track_templates:
                            
                            for label in th.unique(y_train):
                                label_ = int(label.item())
                                idx_train = th.where(y_train==label)
                                subfeatures_train = features_train[idx_train]
                                idx_val = th.where(y_val==label)
                                subfeatures_val = features_val[idx_val]
                                self.log_templates['train_mean_dists_label%s'%label_].append(subfeatures_train.mean(axis=0))
                                self.log_templates['train_std_dists_label%s'%label_].append(subfeatures_train.std(axis=0))
                                self.log_templates['val_mean_dists_label%s'%label_].append(subfeatures_val.mean(axis=0))
                                self.log_templates['val_std_dists_label%s'%label_].append(subfeatures_val.std(axis=0))
    
                                if verbose:
                                    print('[TRAIN] label = %s / mean dists: %s / std dists: %s'%(label,self.log_templates['train_mean_dists_label%s'%label_][-1], self.log_templates['train_std_dists_label%s'%label_][-1]))
                            str_log = self.experiment_repo+'/%s_tracking_templates_log.pkl'%model_name
                            _dump(self.log_templates, str_log)
                    else:
                        
                        if verbose:
                            if self.learn_hbar:
                                print('hbar:', self.hbar)
                            print('epoch= %s / loss_train = %s / res_train = %s '%(e, loss_train.item(), res_train))
                        self.log['train_epoch_loss'].append(loss_train.item())
                        for metric in self.classification_metrics.keys():
                            self.log['train_%s'%metric].append(res_train[metric])
                            
                        str_log = self.experiment_repo+'/%s_training_log.pkl'%model_name
                        _dump(self.log, str_log)
                        # Save the model with the best train accuracy; ties are broken by a lower train loss
                        if best_train_acc <= res_train['accuracy']:
                            save_epoch = False    
                            if best_train_acc == res_train['accuracy']:
                                if best_train_loss > loss_train.item():
                                    save_epoch = True
                                    best_train_loss = loss_train.item()
                                else:
                                    save_epoch = False
                            else:
                                save_epoch = True
                                best_train_loss = loss_train.item()
                                best_train_acc = res_train['accuracy']
                            if save_epoch:
                                if verbose:
                                    print('saving epoch')
                                str_file = self.experiment_repo+'/%s_best_train_accuracy.pkl'%model_name
                                _save_checkpoint(str_file, e)
                                                   
                        if track_templates:
                            
                            for label in th.unique(y_train):
                                label_ = int(label.item())
                                idx_train = th.where(y_train==label)
                                subfeatures_train = features_train[idx_train]
                                self.log_templates['train_mean_dists_label%s'%label_].append(subfeatures_train.mean(axis=0))
                                self.log_templates['train_std_dists_label%s'%label_].append(subfeatures_train.std(axis=0))
                                if verbose:
                                    print('[TRAIN] label = %s / mean dists: %s / std dists: %s'%(label,self.log_templates['train_mean_dists_label%s'%label_][-1], self.log_templates['train_std_dists_label%s'%label_][-1]))
                                
                            str_log = self.experiment_repo+'/%s_tracking_templates_log.pkl'%model_name
                            _dump(self.log_templates, str_log)
                # after evaluation, make the model trainable again
                self.set_model_to_train()
                
                
    def evaluate(self, X:list, h:list, y:list, n_jobs:int=None):
        """
        Score the current model on a labelled set of graphs, without gradients.

        X: list of input structures
        h: list of corresponding masses
        y: tensor of true labels
        n_jobs: forwarded to parallelized_get_features

        Returns (dist_features, pred, y_pred, loss, res) where `res` maps every
        name in self.classification_metrics to its value on (y, y_pred).
        The classifier is left in evaluation mode.
        """
        self.set_model_to_eval()
        with th.no_grad():
            dist_features = self.parallelized_get_features(X, h, n_jobs)
            pred = self.clf_Net(dist_features)
            loss = self.loss(pred, y)
            y_pred = pred.argmax(1)
            # the metrics work on numpy arrays living on the host
            targets_np = y.cpu().detach().numpy()
            predicted_np = y_pred.cpu().detach().numpy()
            res = {
                name: scorer(targets_np, predicted_np)
                for name, scorer in self.classification_metrics.items()
            }
        return dist_features, pred, y_pred, loss, res
    
    def fit_fixedtemplates(self, model_name:str, 
            X_train:list, y_train:list, 
            X_val:list, y_val:list, 
            X_test:list, y_test:list,
            atoms_projection:str, lr:float, batch_size:int, supervised_sampler:bool, 
            epochs:int, val_timestamp:int, algo_seed:int, track_templates:bool=False, verbose:bool=False, n_jobs:int=2):
        """
        Train only the classifier on top of frozen templates.

        Gradients are disabled on every template structure and weight, the
        distances of all graphs to the templates are computed once, and the
        classifier is then trained with Adam on these fixed features.

        Files written in self.experiment_repo (all prefixed by `model_name`):
          - '_training_log.pkl' : self.log, at every evaluation;
          - '_best_val_accuracy_increasing_train_accuracy.pkl' : checkpoint (with validation);
          - '_best_train_accuracy.pkl' : checkpoint (without validation).

        Parameters
        ----------
        model_name : str
            prefix of the files written during training.
        X_train, y_train : list of tensors, tensor
            training structures and their labels.
        X_val, y_val, X_test, y_test :
            validation / test structures and labels. If X_val is None, only the
            training set is evaluated and the checkpoint follows train accuracy.
        atoms_projection : str
            unused by this method (templates are not updated).
        lr : float
            learning rate of Adam.
        batch_size : int
            number of graphs per batch (capped at the training set size for
            uniform sampling).
        supervised_sampler : bool
            if True, each batch takes batch_size // n_labels graphs per label;
            otherwise graphs are drawn uniformly without replacement.
        epochs : int
            number of epochs; each runs n_train // batch_size + 1 batches.
        val_timestamp : int
            evaluation is run every `val_timestamp` epochs (skipping epoch 0)
            and at the last epoch.
        algo_seed : int
            seed given to torch and numpy.
        track_templates : bool
            only creates the empty self.log_templates lists; nothing is
            appended to them by this method.
        verbose : bool
            print losses and metrics at each evaluation.
        n_jobs : int
            number of joblib workers used to compute the distances.
        """
        def _dump(obj, path):
            # context manager: the file is flushed and closed as soon as it is written
            with open(path, 'wb') as f:
                pickle.dump(obj, f)

        def _save_checkpoint(path, epoch):
            # snapshot of the templates (moved to cpu) and of the classifier weights
            full_dict_state = {
                'epoch' : epoch, 
                'atoms_params': [param.clone().detach().cpu() for param in self.atoms_params],
                'clf_params': self.clf_Net.state_dict()}
            _dump(full_dict_state, path)

        th.manual_seed(algo_seed)
        np.random.seed(algo_seed)
        self.ValFunction = GW_utils.ValFunction
        # Set requires_grad to False for sampled templates
        for Cbar, hbar in zip(self.Cbar, self.hbar):
            Cbar.requires_grad_(False)
            hbar.requires_grad_(False)
        print('learning with fixed atoms: Cbar= %s / hbar=%s'%([C.requires_grad for C in self.Cbar], [h.requires_grad for h in self.hbar]))
        
        # templates are frozen, so the distance features are computed once and reused
        h_train = [th.ones(C.shape[0], dtype=self.dtype, device=self.device)/C.shape[0] for C in X_train]
        dist_features_train = self.parallelized_get_features(X_train, h_train, n_jobs)
        print('dist_features_train:', dist_features_train)
        print('labels_train:', y_train)
        
        best_train_acc = - np.inf
        best_train_loss = np.inf     
        
        if X_val is not None:   
            h_val = [th.ones(C.shape[0], dtype=self.dtype, device=self.device)/C.shape[0] for C in X_val]
            h_test = [th.ones(C.shape[0], dtype=self.dtype, device=self.device)/C.shape[0] for C in X_test]
            do_validation = True
            sets = ['train', 'val', 'test']
            best_val_acc = - np.inf
            best_val_acc_train_acc = - np.inf
            dist_features_val = self.parallelized_get_features(X_val, h_val, n_jobs)
            dist_features_test = self.parallelized_get_features(X_test, h_test, n_jobs)
            
        else:
            do_validation = False
            sets = ['train']
        
        n_train = y_train.shape[0]
        
        self.log = {'train_cumulated_batch_loss':[], 'train_epoch_loss':[], 'val_epoch_loss':[], 'test_epoch_loss':[] }        
        for metric in self.classification_metrics.keys():
            for s in sets:
                self.log['%s_%s'%(s, metric)]=[]
                
        if track_templates:
            self.log_templates = {}
            for label in range(self.n_labels):
                for s in sets:
                    self.log_templates['%s_mean_dists_label%s'%(s, label)] = []
                    self.log_templates['%s_std_dists_label%s'%(s, label)] = []
                    
        self.optimizer = th.optim.Adam(params=self.params, lr=lr, betas=[0.9, 0.99])
        if supervised_sampler:
            unique_labels = th.unique(y_train)
            n_labels = unique_labels.shape[0]
            train_idx_by_labels = [th.where(y_train==label)[0] for label in unique_labels]
            labels_by_batch = batch_size // n_labels
        batch_by_epoch = n_train//batch_size +1
            
        self.set_model_to_train()
        for e in tqdm(range(epochs), desc='epochs'):
            cumulated_batch_loss=0.
            for batch_i in range(batch_by_epoch):
                
                self.optimizer.zero_grad()
                if not supervised_sampler:
                    # cap the size: sampling without replacement cannot exceed n_train
                    batch_idx = np.random.choice(range(n_train), size=min(batch_size, n_train), replace=False)
                else:
                    r = batch_size
                    # NOTE(review): train_idx_by_labels is indexed by the label value,
                    # which assumes labels are 0..n_labels-1 - confirm
                    for idx_label,label in enumerate(np.random.permutation(unique_labels)):
                        
                        local_batch_idx = np.random.choice(train_idx_by_labels[label], size=min(r,labels_by_batch), replace=False)
                        if idx_label ==0:
                            batch_idx = local_batch_idx
                        else:
                            batch_idx = np.concatenate([batch_idx,local_batch_idx])
                        r -= labels_by_batch

                batch_dist_features = dist_features_train[batch_idx, :]           
                batch_pred = self.clf_Net(batch_dist_features)
                batch_loss = self.loss(batch_pred, y_train[batch_idx])
                
                cumulated_batch_loss += batch_loss.item()
                batch_loss.backward()
                
                self.optimizer.step()        
                
            self.log['train_cumulated_batch_loss'].append(cumulated_batch_loss)
                
            if (((e %val_timestamp) ==0) and e>0) or (e == (epochs - 1)):
                with th.no_grad():
                    self.set_model_to_eval()
                    features_train, pred_train, y_pred_train, loss_train, res_train = self.evaluate_fixedtemplates(dist_features_train, y_train, n_jobs)
                    
                    if do_validation:
                        # evaluate_fixedtemplates is the evaluation routine defined on this
                        # class (the former evaluate_fixedfeatures name does not exist)
                        features_val, pred_val, y_pred_val, loss_val, res_val = self.evaluate_fixedtemplates(dist_features_val, y_val, n_jobs)
                        features_test, pred_test, y_pred_test, loss_test, res_test = self.evaluate_fixedtemplates(dist_features_test , y_test, n_jobs)
                        
                        if verbose:
                            if self.learn_hbar:
                                print('hbar:', self.hbar)
                            print('epoch= %s / loss_train = %s / res_train = %s / loss_val =%s/ res_val =%s'%(e,loss_train.item(),res_train,loss_val.item(),res_val))
                        self.log['train_epoch_loss'].append(loss_train.item())
                        self.log['val_epoch_loss'].append(loss_val.item())
                        self.log['test_epoch_loss'].append(loss_test.item())
                        for metric in self.classification_metrics.keys():
                            self.log['train_%s'%metric].append(res_train[metric])
                            self.log['val_%s'%metric].append(res_val[metric])
                            self.log['test_%s'%metric].append(res_test[metric])
                        
                        str_log = self.experiment_repo+'/%s_training_log.pkl'%model_name
                        _dump(self.log, str_log)
                        if best_val_acc <= res_val['accuracy']: #Save model with best val acc assuring increase of train acc
                            
                            best_val_acc = res_val['accuracy']
                            if best_val_acc_train_acc <= res_train['accuracy']:
                                best_val_acc_train_acc = res_train['accuracy']
                                str_file = self.experiment_repo+'/%s_best_val_accuracy_increasing_train_accuracy.pkl'%model_name
                                _save_checkpoint(str_file, e)
                    else:
                        
                        if verbose:
                            if self.learn_hbar:
                                print('hbar:', self.hbar)
                            print('epoch= %s / loss_train = %s / res_train = %s '%(e, loss_train.item(), res_train))
                        self.log['train_epoch_loss'].append(loss_train.item())
                        for metric in self.classification_metrics.keys():
                            self.log['train_%s'%metric].append(res_train[metric])
                            
                        str_log = self.experiment_repo+'/%s_training_log.pkl'%model_name
                        _dump(self.log, str_log)
                        if best_train_acc <= res_train['accuracy']: # Save the model with the best train accuracy
                            
                            best_train_acc = res_train['accuracy']
                            
                            str_file = self.experiment_repo+'/%s_best_train_accuracy.pkl'%model_name
                            _save_checkpoint(str_file, e)
                # after evaluation, make the model trainable again
                self.set_model_to_train()
    
    def evaluate_fixedtemplates(self, dist_features , y:list, n_jobs:int=None):
        #print('--- evaluate current model ---')
        self.set_model_to_eval()
        with th.no_grad():
            
            pred = self.clf_Net(dist_features)
            loss = self.loss(pred, y)
            y_pred = pred.argmax(1)
            y_ = y.detach().numpy()
            y_pred_ = y_pred.detach().numpy()
            res = {}
            for metric in self.classification_metrics.keys():
                res[metric] = self.classification_metrics[metric](y_,y_pred_)                
        return dist_features, pred, y_pred, loss, res
    
    def load(self, str_file):
        """
        Restore the templates and the classifier from a checkpoint file.

        Rebuilds self.Cbar and self.hbar (with requires_grad disabled),
        self.shape_atoms and the classifier weights. The 'atoms_params' list of
        the checkpoint is read as [C_1..C_K, (h_1..h_K if learn_hbar)].

        Parameters
        ----------
        str_file : str
            path of the pickled checkpoint.

        Notes
        -----
        The checkpoint is read with pickle, which can execute arbitrary code
        while loading: only open files produced by a trusted training run.
        """
        # context manager so the handle is closed even when unpickling fails
        with open(str_file, 'rb') as f:
            full_dict_state = pickle.load(f)
        self.Cbar = []
        self.hbar = []
        self.ValFunction = GW_utils.ValFunction
        
        for p in full_dict_state['atoms_params'][:self.Katoms]:  # atoms structures
            C = p.clone()
            C.requires_grad_(False)
            self.Cbar.append(C)
        self.shape_atoms = [C.shape[0] for C in self.Cbar]
        if self.learn_hbar:
            for p in full_dict_state['atoms_params'][self.Katoms : 2 * self.Katoms]:  # atoms weights
                h = p.clone()
                h.requires_grad_(False)
                self.hbar.append(h)
        else:
            # weights were not learned: fall back to uniform weights on every template
            for s in self.shape_atoms:
                h = th.ones(s, dtype=self.dtype)/s
                self.hbar.append(h)
        self.clf_Net.load_state_dict(full_dict_state['clf_params'])
            
        print('[SUCCESSFULLY LOADED] ',str_file)
    