from data_handler import dataloader
import numpy as np
from tqdm import tqdm 
import algo_relaxedGW as algo

from data_handler.graph_class import random_edge_removal, random_edge_addition
import os
import pandas as pd
import pickle 
from joblib import Parallel, delayed
from sklearn.model_selection import train_test_split
import pylab as pl
from sklearn.cluster import KMeans
from  scipy.sparse.csgraph import shortest_path
import networkx as nx
import GDL_utils as gwu
#%%


# Representation name -> function mapping an adjacency matrix to that
# representation (used when building train graphs for completion experiments).
str2rpzfunctions = dict(
    ADJ=lambda C: C,  # identity: the adjacency matrix is used as is
    normADJ=gwu.compute_normadj,
    SP=shortest_path,
    LAP=gwu.compute_laplacian,
    normLAP=gwu.compute_normlaplacian,
    SIF=gwu.compute_sif_distance,
)



class GWh_datasets_graph():
    """Learn a single srGW target graph summarizing a dataset of graphs.

    The target ``self.Ctarget`` is an ``Ntarget x Ntarget`` matrix learned by
    stochastic gradient descent (plain or with a numpy Adam implementation).
    At each step, transport plans between sampled input graphs and the current
    target are computed with the operator returned by
    :meth:`create_srGW_operator`, and used to form the gradient w.r.t. the target.
    """

    def __init__(self,
                 graphs:list=None,
                 masses:list=None,
                 dataset_name:str=None,
                 mode:str='ADJ',
                 Ntarget:int=None,
                 experiment_repo:str=None,
                 experiment_name:str=None,
                 degrees:bool=False,
                 completion_parameters:dict=None,
                 data_path:str='../data/'):
        """
        Parameters
        ----------
        graphs : optional list of precomputed square graph matrices. When given,
            no data is loaded and ``masses`` is stored as is.
        masses : node distributions matching ``graphs`` (only used when ``graphs`` is given).
        dataset_name : name of the dataset to load when ``graphs`` is None.
            Only 'imdb-b' and 'imdb-m' can be loaded by this class.
        mode : representation of input graphs, e.g. 'ADJ' (adjacency) or 'SP' (shortest path).
        Ntarget : size of the target graph which summarizes the dataset.
        experiment_repo : sub-repository where results of the experiment are saved.
        experiment_name : sub-repository, under ``experiment_repo``, for this experiment.
        degrees : use uniform node distributions (False) or normalized degree
            distributions (True) for each graph.
        completion_parameters : dict used to split the data into D_train and D_test
            (reads keys 'split_rate' and 'split_seed'). None or {} means no split.
        data_path : path where data is. The default is '../data/'.

        Raises
        ------
        ValueError : if ``graphs`` is None and ``dataset_name`` cannot be loaded,
            if ``mode`` is unknown, or if an imdb dataset contains directed graphs.
        """
        if completion_parameters is None:
            # Avoid a shared mutable default; {} keeps the "no split" meaning.
            completion_parameters = {}
        self.experiment_repo = experiment_repo
        self.experiment_name = experiment_name
        print('dataset_name:', dataset_name)
        # representation name -> `method` argument of the loaded graphs' distance_matrix
        str_to_method = {'ADJ': 'adjacency', 'SP':'shortest_path','LAP':'laplacian',
                         'fullADJ':'augmented_adjacency','normADJ':'normalized_adjacency',
                         'SIF':'sif_distance', 'SLAP':'signed_laplacian','normLAP':'normalized_laplacian'}
        self.degrees = degrees
        self.Ntarget = Ntarget
        self.completion_parameters = completion_parameters
        self.mode = mode
        self.dataset_name = dataset_name

        if graphs is not None:
            # The graphs to learn on are given already.
            self.graphs = graphs
            self.masses = masses
        elif dataset_name not in ['imdb-b', 'imdb-m']:
            raise ValueError('unknown dataset (only imdb-b and imdb-m can be loaded): %s' % dataset_name)
        elif completion_parameters == {}:
            if self.mode not in str_to_method:
                raise ValueError('unknown mode /graph representation')
            X, self.y = dataloader.load_local_data(data_path, dataset_name)
            self.graphs = [np.array(X[t].distance_matrix(method=str_to_method[mode]), dtype=np.float64)
                           for t in range(X.shape[0])]
            self.masses = self._compute_masses(self.graphs, self.degrees)
        else:
            # Completion experiments: split adjacency matrices into train/test
            # (stratified on labels) and learn on the train split only.
            X, y = dataloader.load_local_data(data_path, dataset_name)
            list_X = [np.array(X[t].distance_matrix(method=str_to_method['ADJ']), dtype=np.float64)
                      for t in range(X.shape[0])]
            N = len(list_X)
            train_idx, test_idx = train_test_split(np.arange(N),
                                                   test_size=self.completion_parameters['split_rate'],
                                                   stratify=np.array(y),
                                                   random_state=self.completion_parameters['split_seed'])
            self.raw_train_graphs = [list_X[i] for i in train_idx]
            self.raw_test_graphs = [list_X[i] for i in test_idx]
            self.train_y = [y[i] for i in train_idx]
            self.test_y = [y[i] for i in test_idx]
            if self.mode == 'ADJ':
                self.graphs = self.raw_train_graphs
            else:
                self.graphs = [str2rpzfunctions[self.mode](C) for C in self.raw_train_graphs]
            # Masses are computed from the raw adjacency matrices, whatever the representation.
            self.masses = self._compute_masses(self.raw_train_graphs, self.degrees)
            print('number of graphs in the train dataset for completion experiments:', len(self.graphs))

        # Check whether all graphs are symmetric (undirected) to pick the proper OT solvers.
        self.dataset_size = len(self.graphs)
        self.undirected = np.all([np.all(C == C.T) for C in self.graphs])
        if self.dataset_name in ['imdb-b', 'imdb-m']:
            if not self.undirected:
                raise ValueError('imdb datasets are expected to contain undirected graphs')
        else:
            print('All graphs in the dataset are undirected?', self.undirected)

    @staticmethod
    def _compute_masses(graphs, degrees):
        """Return one node distribution per graph: uniform, or normalized column sums if ``degrees``."""
        if not degrees:
            return [np.ones(C.shape[0]) / C.shape[0] for C in graphs]
        print('computing degree distributions')
        masses = []
        for C in graphs:
            h = np.sum(C, axis=0)
            masses.append(h / np.sum(h))
        return masses

    @staticmethod
    def _adam_step(param, grad, moment1, moment2, count, lr, beta_1, beta_2):
        """Apply one Adam update to ``param`` IN PLACE; return the new (moment1, moment2)."""
        m1_t = beta_1 * moment1 + (1 - beta_1) * grad
        m2_t = beta_2 * moment2 + (1 - beta_2) * (grad ** 2)
        m1_t_unbiased = m1_t / (1 - beta_1 ** count)
        m2_t_unbiased = m2_t / (1 - beta_2 ** count)
        param -= lr * m1_t_unbiased / (np.sqrt(m2_t_unbiased) + 10 ** (-15))
        return m1_t, m2_t

    def init_Ctarget(self,init_mode_graph:str='random',seed:int=0, use_checkpoint:bool = True):
        """Initialize ``self.Ctarget`` as an ``Ntarget x Ntarget`` matrix.

        'random': entries drawn from N(0.5, 0.01**2) after seeding numpy's global
        RNG with ``seed``, symmetrized, with unit diagonal (cluster-like graph).
        Reads ``self.proj``, which must be 'nsym' or 'sym'.
        When ``use_checkpoint`` is True, resets ``self.checkpoint_Ctarget`` to [].

        Raises
        ------
        ValueError : unknown ``init_mode_graph`` or unsupported ``self.proj``.
        """
        if init_mode_graph != 'random':
            raise ValueError('unknown init graph mode')
        if self.proj not in ['nsym', 'sym']:
            raise ValueError("only proj in ['nsym','sym'] is supported for now")
        np.random.seed(seed)
        x = np.random.normal(loc=0.5, scale=0.01, size=(self.Ntarget, self.Ntarget))
        self.Ctarget = (x + x.T) / 2
        np.fill_diagonal(self.Ctarget, 1)
        if use_checkpoint:
            self.checkpoint_Ctarget = []

    def initialize_optimizer(self):
        """Reset the state of the numpy Adam optimizer used on ``self.Ctarget``."""
        self.adam_moment1 = np.zeros((self.Ntarget, self.Ntarget))  # first moment vector
        self.adam_moment2 = np.zeros((self.Ntarget, self.Ntarget))  # second moment vector
        self.adam_count = 1

    def create_srGW_operator(self,init_mode:str='product',
                             eps_inner:float=10**(-6),
                             max_iter_inner:int=1000,
                             gamma_entropy:float=0,
                             lambda_reg:float=None,
                             eps_inner_MM:float=0,
                             max_iter_MM:int=0,
                             use_warmstart_MM:bool=True,
                             seed:int=0):
        """Build the solver used for each unmixing step of the dictionary learning.

        Returns a callable ``op(C1, h1, C2, T_init) -> (T, loss)`` wrapping a
        solver from ``algo``, chosen from ``self.undirected``, ``gamma_entropy``
        and ``max_iter_MM``. ``max_iter_MM`` equal to 0 or None means the
        sparsity-promoting Majorization-Minimization (MM) solver is NOT used.

        Raises
        ------
        NotImplementedError : MM solver with entropic regularization on directed graphs.
        """
        if not max_iter_MM:  # 0 or None: no concave sparsity-promoting regularization.
            if self.undirected:
                if gamma_entropy == 0:
                    srGW_operator = (lambda C1,h1,C2,T_init: algo.GW_relaxedmarginal(C1,h1,C2, init_mode=init_mode,
                                                                                     T_init=T_init, use_log=False, eps=eps_inner, max_iter=max_iter_inner,seed=seed))
                else:
                    srGW_operator = (lambda C1,h1,C2,T_init: algo.entropic_semirelaxedGW(C1,h1,C2,gamma_entropy=gamma_entropy, init_mode=init_mode,
                                                                                     T_init=T_init, use_log=False, eps=eps_inner, max_iter=max_iter_inner,seed=seed))
            else:  # directed graphs; entropic solver not implemented yet
                srGW_operator = (lambda C1,h1,C2,T_init: algo.GW_relaxedmarginal_asym(C1,h1,C2, init_mode=init_mode,
                                                                                     T_init=T_init, use_log=False, eps=eps_inner, max_iter=max_iter_inner,seed=seed))
        else:  # Majorization-Minimization solver.
            if self.undirected:
                srGW_operator = (lambda C1,h1,C2,T_init: algo.GW_relaxedmarginal_majorationminimization_lpl1(C1,h1,C2,p_reg=0.5, init_mode=init_mode, lambda_reg=lambda_reg, gamma_entropy=gamma_entropy,
                                                                                 T_init=T_init, use_log=False, eps_inner=eps_inner, max_iter_inner=max_iter_inner,eps_outer = eps_inner_MM, eps_reg = 10**(-15),
                                                                                 max_iter_outer=max_iter_MM,seed=seed, verbose=False, inner_log=False, use_warmstart=use_warmstart_MM))
            else:  # directed graphs; entropic solver not implemented yet
                if gamma_entropy != 0:
                    raise NotImplementedError('Warning: Maj-Min solver with mirror descent inner solver has not been implemented for DIRECTED graphs')
                srGW_operator = (lambda C1,h1,C2,T_init: algo.GW_relaxedmarginal_majorationminimization_lpl1_asym(C1,h1,C2,init_mode=init_mode, p_reg=0.5,lambda_reg=lambda_reg, T_init=T_init,
                                                                                                                  eps_inner=eps_inner,eps_outer=eps_inner_MM,eps_reg = 10**(-15),use_log=False,
                                                                                                                  max_iter_inner=max_iter_inner,max_iter_outer =max_iter_MM, verbose=False, inner_log = False))
        return srGW_operator

    def Learn_Ctarget(self,lambda_reg:float,max_iter_inner:int,
              eps_inner:float,lr:float,batch_size:int,epochs:int,
              algo_seed:int,max_iter_MM:int=None, eps_inner_MM:float=None,use_warmstart_MM:bool=False,
              gamma_entropy:float=0.,beta_1:float=0.9, beta_2:float=0.99,
              init_mode_graph:str='random',use_optimizer:bool=True,checkpoint_freq = 5,
              use_checkpoint:bool = True, proj:str= 'nsym',init_GW:str='product'):
        """
        Stochastic algorithm to learn srGW dictionaries
        described in Section 4 of the main paper and Algorithm 2,
        further details in the supplementary material.

        Parameters
        ----------
        lambda_reg : sparse regularization coefficient (MM solver only).
        max_iter_inner, eps_inner : iteration budget / precision of the inner srGW solver.
        lr : (initial) learning rate.
        batch_size : number of graphs sampled (without replacement) per step; must be >= 1.
        epochs : maximum number of passes over the dataset.
        algo_seed : seed used to initialize the target graph.
        max_iter_MM, eps_inner_MM : MM solver settings; max_iter_MM None/0 disables MM.
        use_warmstart_MM : forwarded to the MM solver.
        gamma_entropy : entropic regularization coefficient of the srGW solver (0 = none).
        beta_1, beta_2 : Adam parameters on gradient / gradient**2.
        init_mode_graph : initialization mode of the target graph (see init_Ctarget).
        use_optimizer : use Adam (True) or plain gradient steps (False).
        checkpoint_freq : every this many epochs, save and evaluate the mean unmixing
            loss; training stops after two consecutive evaluations without improvement.
        use_checkpoint : keep a history of target states in ``self.checkpoint_Ctarget``.
        proj : 'nsym' projects the target onto nonnegative matrices after each step.
        init_GW : initialization mode of transport plans in the srGW solver.

        Raises
        ------
        ValueError : if ``batch_size < 1`` (the epoch loop would never end).
        """
        if batch_size < 1:
            raise ValueError('batch_size must be a positive integer')
        if max_iter_MM is None:
            self.settings = {'Ntarget':self.Ntarget, 'max_iter_inner':max_iter_inner,'eps_inner':eps_inner,'epochs':epochs,
                             'lr':lr,'init_mode_graph':init_mode_graph,'batch_size':batch_size,
                             'algo_seed':algo_seed, 'beta1':beta_1, 'beta2':beta_2,'l2_reg':0,'lambda_reg':0, #to make it compatible with past versions
                             'use_optimizer':use_optimizer,'init_GW':init_GW, 'proj':proj}
        else:
            self.settings = {'Ntarget':self.Ntarget, 'max_iter_FW':max_iter_inner,'eps_inner_FW':eps_inner,'max_iter_MM':max_iter_MM,'eps_inner_MM':eps_inner_MM,
                 'lr':lr,'init_mode_graph':init_mode_graph,'batch_size':batch_size,'epochs':epochs,
                 'algo_seed':algo_seed, 'beta1':beta_1, 'beta2':beta_2,'lambda_reg':lambda_reg, #to make it compatible with past versions
                 'use_optimizer':use_optimizer,'init_GW':init_GW, 'proj':proj,'use_warmstart_MM':use_warmstart_MM}
        if gamma_entropy != 0:
            self.settings['gamma_entropy'] = gamma_entropy

        self.proj = proj
        self.init_Ctarget(init_mode_graph, algo_seed, use_checkpoint)

        # The global numpy seed was set while initializing the target graph;
        # solvers are then run without re-seeding.
        algo_seed = None
        srGW_operator = self.create_srGW_operator(init_mode=init_GW,eps_inner=eps_inner, max_iter_inner=max_iter_inner,
                                                  eps_inner_MM=eps_inner_MM, max_iter_MM=max_iter_MM,lambda_reg=lambda_reg,
                                                  gamma_entropy=gamma_entropy,use_warmstart_MM=use_warmstart_MM,seed=algo_seed)
        if use_optimizer:
            self.initialize_optimizer()
        T = len(self.graphs)
        self.log = {'batch_loss': [], 'epoch_loss': []}
        best_epoch_global_rec = np.inf
        consecutive_global_rec_drops = 0
        # threshold for the projection onto nonnegative matrices
        threshold_C = np.zeros_like(self.Ctarget)

        for epoch in tqdm(range(epochs), desc='epochs'):
            seen_graphs_count = 0
            while seen_graphs_count < self.dataset_size:
                # batch sampling
                seen_graphs_count += batch_size
                batch_t = np.random.choice(range(T), size=batch_size, replace=False)
                best_T = []
                batch_loss = 0
                for t in batch_t:
                    local_T, local_loss = srGW_operator(C1=self.graphs[t], h1=self.masses[t], C2=self.Ctarget, T_init=None)
                    best_T.append(local_T)
                    batch_loss += local_loss
                self.log['batch_loss'].append(batch_loss)
                # Stochastic gradient w.r.t. Ctarget:
                # (2/B) * sum_k [Ctarget * (h_k h_k^T) - T_k^T C_k T_k], with h_k the column sums of T_k.
                grad = np.zeros_like(self.Ctarget)
                for local_T, t in zip(best_T, batch_t):
                    hk = np.sum(local_T, axis=0)
                    grad += self.Ctarget*(hk[:,None].dot(hk[None,:])) - (local_T.T).dot(self.graphs[t].dot(local_T))
                grad *= (2/batch_size)
                if not use_optimizer:
                    self.Ctarget -= lr*grad
                else:
                    self.adam_moment1, self.adam_moment2 = self._adam_step(
                        self.Ctarget, grad, self.adam_moment1, self.adam_moment2,
                        self.adam_count, lr, beta_1, beta_2)
                    self.adam_count += 1
                # projection on nonnegative matrices
                if proj == 'nsym':
                    self.Ctarget = np.maximum(threshold_C, self.Ctarget)

            pl.plot(self.log['batch_loss']);pl.title('loss evolution by batches'); pl.show()
            if epoch == 0:
                print('saved settings:', self.settings)
                self.save_elements(save_settings=True, use_checkpoint=use_checkpoint)
            elif epoch % checkpoint_freq == 0:
                print('checkpoint_step to evaluate embeddings and decide on early stopping')
                self.save_elements(save_settings=False, use_checkpoint=use_checkpoint)
                _, list_losses = self.compute_unmixing(lambda_reg,gamma_entropy,eps_inner,max_iter_inner,
                                                       eps_inner_MM, max_iter_MM,use_warmstart_MM,
                                                       algo_seed,init_GW='product',use_checkpoint=False)
                mean_rec = np.mean(list_losses)
                if mean_rec < best_epoch_global_rec:
                    best_epoch_global_rec = mean_rec
                    consecutive_global_rec_drops = 0
                    print('epoch:%s / new best epoch global rec :%s'%(epoch,best_epoch_global_rec))
                else:
                    consecutive_global_rec_drops += 1
                    print('[not improved] epoch :%s / current epoch loss :%s / fails:%s '%(epoch,mean_rec,consecutive_global_rec_drops))
                    if consecutive_global_rec_drops > 1:
                        break

    def compute_unmixing(self,lambda_reg:float, gamma_entropy:float,eps_inner:float,max_iter_inner:int,
                         eps_inner_MM:float, max_iter_MM:int,use_warmstart_MM:bool,
                         algo_seed:int,init_GW:str='product',use_checkpoint:bool = False, verbose:bool=False):
        """Compute transport plans and losses between every graph and the target.

        Parameters
        ----------
        lambda_reg, gamma_entropy, eps_inner, max_iter_inner, eps_inner_MM,
        max_iter_MM, use_warmstart_MM, algo_seed, init_GW :
            forwarded to :meth:`create_srGW_operator`.
        use_checkpoint : if False, unmix against ``self.Ctarget``; if True, against
            every state in ``self.checkpoint_Ctarget``.
        verbose : unused.

        Returns
        -------
        (list of plans, list of losses), one entry per graph; with ``use_checkpoint``
        each is a list (one per checkpoint) of such per-graph lists.
        """
        srGW_operator = self.create_srGW_operator(init_mode=init_GW,eps_inner=eps_inner, max_iter_inner=max_iter_inner,
                                                  eps_inner_MM=eps_inner_MM, max_iter_MM=max_iter_MM,lambda_reg=lambda_reg,
                                                  gamma_entropy=gamma_entropy,use_warmstart_MM=use_warmstart_MM,seed=algo_seed)
        T = len(self.graphs)
        if not use_checkpoint:
            best_T = []
            best_losses = []
            for t in tqdm(range(T),desc='unmixing'):
                local_T, local_loss = srGW_operator(self.graphs[t], self.masses[t], self.Ctarget, T_init=None)
                best_T.append(local_T)
                best_losses.append(local_loss)
            return best_T, best_losses
        # run over all saved target graph states
        list_best_T = []
        list_best_losses = []
        for checkpoint_C in self.checkpoint_Ctarget:
            local_list_T = []
            local_list_losses = []
            for t in tqdm(range(T),desc='unmixing'):
                local_T, local_loss = srGW_operator(self.graphs[t], self.masses[t], checkpoint_C, T_init=None)
                local_list_T.append(local_T)
                local_list_losses.append(local_loss)
            list_best_T.append(local_list_T)
            list_best_losses.append(local_list_losses)
        return list_best_T, list_best_losses

    def compute_unmixing_inits_parallelized(self,l2_reg:float,eps_inner:float,max_iter_inner:int,n_seeds:int,use_checkpoint:bool = False, verbose:bool=False,njobs:int=2):
        """
        Compute unmixings over ``n_seeds`` random initializations per sample to
        get a better local minimum.

        WARNING: not up to date with the current solvers; always raises
        NotImplementedError. The previous implementation is kept below
        (unreachable) for reference.
        """
        raise NotImplementedError('NOT UP TO DATE')

        T = len(self.graphs)
        if not use_checkpoint:
            best_T = []
            best_losses = []
            for t in tqdm(range(T)):
                results = Parallel(n_jobs=njobs)(delayed(algo.GW_relaxedmarginal)(self.graphs[t],self.masses[t], self.Ctarget,
                                                                   reg=l2_reg,init_mode='random',T_init=None,use_log = False,
                                                                   eps=eps_inner,max_iter=max_iter_inner,seed=seed) for seed in range(n_seeds))
                best_idx = np.argmin([x[1] for x in results])
                best_T.append(results[best_idx][0])
                best_losses.append(results[best_idx][1])
            return best_T, best_losses
        full_list_best_T = []
        full_list_best_losses = []
        for i in range(len(self.checkpoint_Ctarget)):
            checkpoint_list_T = []
            checkpoint_list_losses = []
            for t in tqdm(range(T)):
                results = Parallel(n_jobs=njobs)(delayed(algo.GW_relaxedmarginal)(self.graphs[t],self.masses[t], self.checkpoint_Ctarget[i],
                                                               reg=l2_reg,init_mode='random',T_init=None,use_log = False,
                                                               eps=eps_inner,max_iter=max_iter_inner,seed=seed) for seed in range(n_seeds))
                best_idx = np.argmin([x[1] for x in results])
                checkpoint_list_T.append(results[best_idx][0])
                checkpoint_list_losses.append(results[best_idx][1])
            full_list_best_T.append(checkpoint_list_T)
            full_list_best_losses.append(checkpoint_list_losses)
        return full_list_best_T, full_list_best_losses

    def complete_patch(self,patch:np.array,
                       Nfullpatch:int,
                       lambda_reg:float=0.,
                       gamma_entropy:float=0.,
                       lr:float=0.01,
                       init_GW:str='product',
                       max_iter:int=100,
                       eps:float=10**(6),
                       max_iter_inner:int=1000,
                       eps_inner:float=10**(-6),
                       max_iter_MM:int=20,
                       eps_inner_MM:float=10**(-6),
                       use_warmstart_MM:bool=False,
                       proj:str='nsym',
                       algo_seed:int = 0,
                       use_optimizer:bool=True,
                       use_warmstart:bool=False,
                       beta_1:float=0.9,
                       beta_2:float=0.99,
                       use_log:bool=False,
                       init_patch:str='random'):
        """Complete a partially observed graph using the learned target ``self.Ctarget``.

        The top-left ``Npatch x Npatch`` block of an ``Nfullpatch x Nfullpatch``
        matrix is fixed to ``patch``. The other entries are optimized by gradient
        steps on the srGW loss between the completed graph (uniform node masses)
        and ``self.Ctarget``.

        Parameters
        ----------
        patch : observed square matrix.
        Nfullpatch : size of the completed graph (>= patch.shape[0]).
        lr, beta_1, beta_2, use_optimizer : step size / Adam settings (plain steps if not use_optimizer).
        max_iter, eps : stop after ``max_iter`` steps or when the relative loss change is < eps.
            NOTE(review): the default eps=10**(6) stops after a single step; 10**(-6)
            was probably intended. The default is kept for backward compatibility.
        init_patch : 'fixed', 'random', 'scaleddegrees' or 'noisy_scaleddegrees'.
        proj : 'nsym' projects onto nonnegative matrices after each step; 'sym' does not.
        use_warmstart : reuse the previous transport plan as solver initialization.
        use_warmstart_MM : forwarded to the MM solver (also enabled by ``use_warmstart``
            for backward compatibility).
        Remaining solver parameters are forwarded to :meth:`create_srGW_operator`.

        Returns
        -------
        (best_completed_patch, best_loss, local_log, init_completed_patch), where
        local_log is {'loss': [...]} if use_log else None.

        Raises
        ------
        ValueError : unsupported ``proj`` or ``init_patch``, or Nfullpatch < patch.shape[0].
        """
        if proj not in ['nsym', 'sym']:
            raise ValueError("only proj in ['nsym','sym'] is supported for now")
        Npatch = patch.shape[0]
        if Nfullpatch < Npatch:
            raise ValueError('Nfullpatch must be >= patch.shape[0]')
        local_log = {'loss': []} if use_log else None
        # 1 on entries to learn, 0 on the fixed observed block.
        learnable_mask = np.ones((Nfullpatch,Nfullpatch))
        learnable_mask[:Npatch,:Npatch] = 0

        if init_patch == 'fixed':
            completed_patch = 0.5*(np.ones((Nfullpatch,Nfullpatch)) - np.eye(Nfullpatch))
        elif init_patch == 'random':
            np.random.seed(algo_seed)
            x = np.random.normal(loc=0.5, scale=0.01, size=(Nfullpatch,Nfullpatch))
            completed_patch = (x+x.T)/2
            np.fill_diagonal(completed_patch,0)  # no diagonal as we do not seek for super nodes
        elif 'scaleddegrees' in init_patch:
            # Observed-to-new entries C[i, new_j] are the max-normalized degrees of
            # the observed node i; new-to-new entries are set to 0.5.
            completed_patch = np.zeros((Nfullpatch,Nfullpatch))
            patch_degrees = np.sum(patch, axis=0)
            # not in-place: works for integer-valued patches too
            patch_degrees = patch_degrees / np.max(patch_degrees)
            completed_patch[:Npatch, Npatch:] = patch_degrees[:,None]
            completed_patch[Npatch:, :Npatch] = patch_degrees[None,:]
            completed_patch[Npatch:, Npatch:] = 0.5
            if init_patch == 'noisy_scaleddegrees':
                perturbation_range = 0.1*np.min(patch_degrees)
                noise = np.random.uniform(low=-perturbation_range,high=perturbation_range, size=(Nfullpatch,Nfullpatch))
                completed_patch += (noise+noise.T)/2
        else:
            raise ValueError('unknown init_patch: %s' % init_patch)

        if use_optimizer:
            adam_moment1 = np.zeros((Nfullpatch,Nfullpatch))  # first moment vector
            adam_moment2 = np.zeros((Nfullpatch,Nfullpatch))  # second moment vector
            adam_count = 1
        srGW_operator = self.create_srGW_operator(init_mode=init_GW,eps_inner=eps_inner, max_iter_inner=max_iter_inner,
                                                  eps_inner_MM=eps_inner_MM, max_iter_MM=max_iter_MM,lambda_reg=lambda_reg,
                                                  gamma_entropy=gamma_entropy,
                                                  use_warmstart_MM=(use_warmstart_MM or use_warmstart),
                                                  seed=algo_seed)

        completed_patch_masses = np.ones(Nfullpatch)/Nfullpatch
        weight_mask = completed_patch_masses[:,None].dot(completed_patch_masses[None,:])
        completed_patch[:Npatch,:Npatch] = patch
        init_completed_patch = completed_patch.copy()

        curr_loss = 10**7
        best_loss = np.inf
        best_completed_patch = completed_patch.copy()
        convergence_criterion = np.inf
        count = 0
        if proj == 'nsym':
            threshold_C = np.zeros((Nfullpatch,Nfullpatch))
        T_init = None

        while (convergence_criterion >= eps) and (count < max_iter):
            prev_loss = curr_loss
            # transport between the completed patch and the target graph
            local_OT, curr_loss = srGW_operator(completed_patch, completed_patch_masses, self.Ctarget, T_init=T_init)
            if use_warmstart:
                T_init = local_OT
            if curr_loss < best_loss:
                best_loss = curr_loss
                best_completed_patch = completed_patch.copy()
            # gradient (up to a factor 2) restricted to the learnable entries
            grad = learnable_mask*(completed_patch*weight_mask - local_OT.dot(self.Ctarget).dot(local_OT.T))
            if not use_optimizer:
                completed_patch -= 2*lr*grad
            else:
                adam_moment1, adam_moment2 = self._adam_step(completed_patch, grad, adam_moment1, adam_moment2,
                                                             adam_count, lr, beta_1, beta_2)
                adam_count += 1
            # projection on nonnegative matrices
            if proj == 'nsym':
                completed_patch = np.maximum(threshold_C, completed_patch)
            if prev_loss != 0:
                convergence_criterion = np.abs(prev_loss-curr_loss)/prev_loss
            else:  # if no reg the loss can not be negative
                break
            count += 1
            if use_log:
                local_log['loss'].append(curr_loss)

        return best_completed_patch, best_loss, local_log, init_completed_patch

    def save_elements(self,save_settings=False,use_checkpoint = False):
        """Save the target graph, logs and (optionally) settings.

        Files go under ``os.path.abspath('../') + experiment_repo + experiment_name``
        (plain string concatenation, so ``experiment_repo`` must carry its separators).
        With ``use_checkpoint``, the current target is appended to
        ``self.checkpoint_Ctarget`` and the whole history is saved instead.
        """
        path = os.path.abspath('../')+self.experiment_repo
        print('path',path)
        experiment_path = path+self.experiment_name
        if not os.path.exists(experiment_path):
            os.makedirs(experiment_path)
            print('made dir', experiment_path)
        if not use_checkpoint:
            np.save(path+'%s/Ctarget.npy'%self.experiment_name, self.Ctarget)
        else:
            self.checkpoint_Ctarget.append(self.Ctarget.copy())
            np.save(path+'%s/checkpoint_Ctarget.npy'%self.experiment_name, np.array(self.checkpoint_Ctarget))
        for key in self.log:
            np.save(path+'%s/%s.npy'%(self.experiment_name,key), np.array(self.log[key]))
        if save_settings:
            pd.DataFrame(self.settings, index=self.settings.keys()).to_csv(path+'%s/settings'%self.experiment_name)

    def load_elements(self, use_checkpoint=False):
        """Load the saved target graph (or checkpoint history) and set ``self.Ntarget``."""
        path = os.path.abspath('../')+self.experiment_repo
        if not use_checkpoint:
            self.Ctarget = np.load(path+'%s/Ctarget.npy'%self.experiment_name)
            self.Ntarget = self.Ctarget.shape[0]
        else:
            self.checkpoint_Ctarget = np.load(path+'%s/checkpoint_Ctarget.npy'%self.experiment_name)
            self.Ntarget = self.checkpoint_Ctarget.shape[-1]



            