import numpy as np
from tqdm import tqdm
import ot
import GDL_utils as gwu
import SGWL.gromovWassersteinAveraging as gwa

#%% OPTIMAL GROMOV WASSERSTEIN QUANTIZATION

def degrees_interpolation(deg, Nt, eps=1e-3):
    """Interpolate a degree profile onto ``Nt`` target masses (Xu & al, 2019a).

    The degrees are sorted in decreasing order and linearly resampled onto
    ``Nt`` evenly spaced points of [0, 1]. They are then shifted by ``eps``
    so that no mass is zero, and normalized to sum to one.

    Parameters
    ----------
    deg : array-like
        Degrees (or any weights) of the source nodes; flattened to 1-D.
    Nt : int
        Number of target masses to produce.
    eps : float, optional
        Additive offset applied before normalization (default 1e-3, the
        value this function always used).

    Returns
    -------
    p_t : np.ndarray of shape (Nt,)
        Non-increasing probability vector summing to one.
    """
    # np.asarray lets callers pass plain lists (``deg.shape`` failed on them).
    deg = np.asarray(deg, dtype=np.float64).ravel()
    x_t = np.linspace(0, 1, Nt)
    sorted_deg = np.sort(deg)[::-1]
    x_s = np.linspace(0, 1, deg.shape[0])
    p_t = np.interp(x_t, x_s, sorted_deg) + eps
    p_t /= np.sum(p_t)
    return p_t

def initializer_semirelaxedGW(init_mode,p,N1,N2,seed=0):
    """Build an initial (N1, N2) transport plan whose rows sum to ``p``.

    Parameters
    ----------
    init_mode : {'product', 'random', 'random_product'}
        'product' (the default for the method) returns p q^T with q uniform
        over the N2 target nodes. 'random_product' uses a random normalized
        q instead. 'random' draws every entry uniformly in [0, 1] and
        rescales each row i to sum to p_i.
    p : array-like of shape (N1,)
        First marginal.
    N1, N2 : int
        Number of source / target nodes.
    seed : int or None
        For the random modes, reseeds numpy's global RNG when not None.

    Returns
    -------
    T : np.ndarray of shape (N1, N2)

    Raises
    ------
    ValueError
        If ``init_mode`` is not one of the modes above.
    """
    p = np.asarray(p, dtype=np.float64)
    if init_mode == 'random':
        if seed is not None:
            np.random.seed(seed)
        T = np.random.uniform(low=0, high=1, size=(N1, N2))
        # scaling to satisfy first marginal constraints
        T *= (p / np.sum(T, axis=1))[:, None]
    elif init_mode == 'product':
        q = np.ones(N2) / N2
        T = p[:, None].dot(q[None, :])
    elif init_mode == 'random_product':
        if seed is not None:
            np.random.seed(seed)
        q = np.random.uniform(low=0, high=1, size=N2)
        q /= np.sum(q)
        T = p[:, None].dot(q[None, :])
    else:
        # Raising a bare string is itself a TypeError in Python 3 and hid
        # the real problem; raise a proper exception instead.
        raise ValueError(
            "unknown init mode %r; expected 'product', 'random' or "
            "'random_product'" % (init_mode,))
    return T

def GW_relaxedmarginal(C1,p,C2, init_mode,T_init=None,use_log = False,eps=10**(-5),max_iter=1000,seed=0,verbose=False):
    """
        Exact solver for srGW,
        described in Algorithm 1. of the main paper.

        Note that our implementation uses the linearity of the GW loss gradient,
        which is why we only explicitly compute one gradient per iteration
        (the one of the direction X). G is then updated by the same convex
        combination as T.

        Parameters
        ----------
        C1 : np.ndarray
            Source structure matrix; N1 = C1.shape[0].
        p : np.ndarray of shape (N1,)
            Fixed first marginal: the row sums of every iterate equal p.
        C2 : np.ndarray
            Target structure matrix; N2 = C2.shape[0].
        init_mode : str
            Forwarded to initializer_semirelaxedGW when T_init is None.
        T_init : np.ndarray of shape (N1, N2) or None
            Optional initial plan (copied, not modified).
        use_log : bool
            If True, also return {'loss': [...]} with the loss history.
        eps : float
            Stop once the relative change of the loss is <= eps.
        max_iter : int
            Maximum number of iterations.
        seed : int or None
            Forwarded to initializer_semirelaxedGW.
        verbose : bool
            Not used by this function.

        Returns
        -------
        best_T, best_loss[, log]
            The iterate with the lowest loss and that loss. If no iteration
            runs, best_T is the initial plan and best_loss is np.inf.

        The exact line search assumes C1 and C2 are symmetric (see the GXT
        comment in the loop); use GW_relaxedmarginal_asym otherwise.
    """
    N1 = C1.shape[0]
    N2 = C2.shape[0]
    # NOTE: overwritten at the start of every iteration before being read.
    previous_loss = 10**(8)
    if T_init is None:
        T=initializer_semirelaxedGW(init_mode,p,N1,N2,seed=seed)
    else:
        T = T_init.copy()
    best_T = T.copy()
    # Second marginal of the current plan, needed to build the GW gradient terms.
    q= np.sum(T,axis=0)
    constC, hC1, hC2 = gwu.np_init_matrix_GW2(C1, C2, p, q)
    #lets evaluate the objective for the current state:
    G = ot.gromov.gwggrad(constC, hC1, hC2, T)
    # The loss is evaluated as 0.5 * <G, T>.
    current_loss = 0.5*np.sum(G*T)
    if use_log:
        log={}
        log['loss']=[current_loss]
    best_loss = np.inf
    convergence_criterion = np.inf
    outer_count=0
    while (convergence_criterion >eps) and (outer_count<max_iter):
        previous_loss = current_loss
        # 0. Gradient known from evaluation of the  cost function
        # 1. Direction finding by solving each subproblem on rows:
        # row i of X puts its whole mass p_i on the column(s) where G is
        # minimal, split equally among ties.
        min_ = np.min(G,axis=1)
        X = (G== min_[:,None]).astype(np.float64)
        X *= (p/np.sum(X,axis=1))[:,None]
        # 3. Exact line-search step
        # Compute litteral expressions of coefficients a*\gamma^2 +b \gamma +c
        # (c is the current loss, so that
        #  loss((1-gamma)*T + gamma*X) = current_loss + a*gamma^2 + b*gamma).
        constCX, hC1X, hC2X = gwu.np_init_matrix_GW2(C1, C2, p, np.sum(X,axis=0))
        GX = ot.gromov.gwggrad(constCX, hC1X, hC2X, X)
        GXX = 0.5*np.sum(GX*X)
        GXT = 0.5*np.sum(GX*T) # Here we say  GXT = GTX= 0.5*np.sum(G*X) what is true if C1 and C2 are symmetric
        a = current_loss + GXX - 2*GXT 
        b= -2*current_loss + 2*GXT 
        # Minimize a*gamma^2 + b*gamma over [0, 1]: clipped vertex when the
        # parabola is convex, otherwise the better endpoint (gamma=1 iff a+b < 0).
        if a>0:
            gamma = min(1, max(0, -b/(2*a)))
        elif a+b<0:
            gamma=1
        else:
            gamma=0
        T = (1-gamma)*T +gamma*X 
        # NOTE: q is not read again in this function.
        q= np.sum(T,axis=0)
        #compute new gradient (same convex combination as T, no new gwggrad call)
        G = (1-gamma)*G + gamma*GX
        current_loss+= a*(gamma**2)+ b*gamma 
        outer_count+=1
        if use_log:
            log['loss'].append(current_loss)
        # Relative change of the loss; the offset avoids dividing by zero.
        if previous_loss != 0:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss)
        else:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss+10**(-15))
        if current_loss < best_loss:
            best_loss = current_loss
            best_T = T.copy()
    if use_log:
        return best_T, best_loss, log
    else:
        return best_T,best_loss   
    



def GW_relaxedmarginal_linear_regularization(C1,h1,C2, init_mode,lin_reg,T_init=None,use_log = False,eps=10**(-5),max_iter=1000,seed=0,verbose=False):
    """
        Exact solver for srGW penalized with a linear OT term,
        described in Algorithm 3. of the annex.

        Minimizes srGW(T) + <lin_reg, T> over plans whose row sums equal h1.

        Parameters
        ----------
        C1, C2 : np.ndarray
            Source / target structure matrices; N1 = C1.shape[0], N2 = C2.shape[0].
        h1 : np.ndarray of shape (N1,)
            Fixed first marginal.
        init_mode : str
            Forwarded to initializer_semirelaxedGW when T_init is None.
        lin_reg : scalar or np.ndarray broadcastable to (N1, N2)
            Linear penalty matrix.
        T_init : np.ndarray of shape (N1, N2) or None
            Optional initial plan (copied).
        use_log : bool
            If True, also return {'loss': [...]} with the penalized loss history.
        eps : float
            Stop once the relative change of the penalized loss is <= eps.
        max_iter : int
            Maximum number of iterations.
        seed : int or None
            Forwarded to initializer_semirelaxedGW.
        verbose : bool
            Not used by this function.

        Returns
        -------
        best_T, best_loss, saved_loss_unreg[, log]
            The best iterate, its penalized loss, and the unpenalized srGW
            loss of that iterate (np.inf if no iteration ran).
    """
    N1 = C1.shape[0]
    N2 = C2.shape[0]
    # NOTE: overwritten at the start of every iteration before being read.
    previous_loss = 10**(8)
    if T_init is None:
        T=initializer_semirelaxedGW(init_mode,h1,N1,N2,seed=seed)
    else:
        T = T_init.copy()
    best_T = T.copy()
    q= np.sum(T,axis=0)
    constC, hC1, hC2 = gwu.np_init_matrix_GW2(C1, C2, h1, q)
    #lets evaluate the objective for the current state:
    #1. Compute first gradient
    G = ot.gromov.gwggrad(constC, hC1, hC2, T)
    current_loss_unreg = 0.5*np.sum(G*T)
    current_loss_reg = np.sum(lin_reg*T)
    current_loss = current_loss_unreg+ current_loss_reg
    # From here on G is the gradient of the penalized objective.
    G+=lin_reg
    if use_log:
        log={}
        log['loss']=[current_loss]
    best_loss = np.inf
    saved_loss_unreg=np.inf
    convergence_criterion = np.inf
    outer_count=0
    while (convergence_criterion >eps) and (outer_count<max_iter):
        previous_loss = current_loss
        # 2. Direction-finding step: row i of X puts its mass h1_i on the
        # argmin column(s) of G, split equally among ties.
        min_ = np.min(G,axis=1)
        X = (G== min_[:,None]).astype(np.float64)
        X *= (h1/np.sum(X,axis=1))[:,None]
        # 3. Exact line-search step
        # Compute litteral expressions of coefficients a*\gamma^2 +b \gamma +c
        # The linear penalty only contributes <lin_reg, X - T> to b.
        constCX, hC1X, hC2X = gwu.np_init_matrix_GW2(C1, C2, h1, np.sum(X,axis=0))
        GX = ot.gromov.gwggrad(constCX, hC1X, hC2X, X)
        GXX = 0.5*np.sum(GX*X)
        X_reg = np.sum(lin_reg*X)
        GXT = 0.5*np.sum(GX*T) # Here we say  GXT = GTX= 0.5*np.sum(G*X) what is true if C1 and C2 are symmetric
        a = current_loss_unreg + GXX - 2*GXT 
        b= -2*current_loss_unreg + 2*GXT + X_reg - current_loss_reg
        # Minimize a*gamma^2 + b*gamma over [0, 1].
        if a>0:
            gamma = min(1, max(0, -b/(2*a)))
        elif a+b<0:
            gamma=1
        else:
            gamma=0
        T = (1-gamma)*T +gamma*X
        #1. Loop: compute new gradient (convex combination, penalty included)
        G = (1-gamma)*G + gamma*(GX+lin_reg)
        current_loss+= a*(gamma**2)+ b*gamma 
        # The penalty is linear in T, so its value follows the same combination.
        current_loss_reg=(1-gamma)*current_loss_reg + gamma*X_reg
        current_loss_unreg=current_loss-current_loss_reg
        outer_count+=1
        if use_log:
            log['loss'].append(current_loss)
            
        # Relative change of the loss; the offset avoids dividing by zero.
        if previous_loss != 0:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss)
        else:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss+10**(-15))
        if current_loss < best_loss:
            best_loss = current_loss
            saved_loss_unreg = current_loss_unreg
            best_T = T.copy()
    if use_log:
        return best_T, best_loss, saved_loss_unreg,log
    else:
        return best_T,best_loss,saved_loss_unreg  
    
def GW_relaxedmarginal_majorationminimization_lpl1(C1,h1,C2, init_mode,p_reg=1/2, lambda_reg = 0.001,gamma_entropy=0., T_init=None,
                                                   use_log = False,use_warmstart=True,eps_inner=10**(-6),
                                                   eps_outer=10**(-6),eps_reg = 10**(-15),max_iter_inner=1000,
                                                   max_iter_outer =50,seed=0,verbose=False,inner_log = False):
    """
        Exact solver for srGW penalized with the sparse regularizer
        described in Section 3 of the main paper,
        and detailed algorithm in Algorithm 5. of the annex.

        Objective: srGW(T) + lambda_reg * sum_j (q_j + eps_reg)**p_reg with
        q = T.sum(axis=0). Each outer iteration solves srGW with the linear
        penalty <lin_reg, T>, where lin_reg is the gradient of the regularizer
        at the previous q (lin_reg = 0 at the first iteration).

        Parameters
        ----------
        C1, C2 : np.ndarray
            Source / target structure matrices; N1 = C1.shape[0], N2 = C2.shape[0].
        h1 : np.ndarray of shape (N1,)
            Fixed first marginal.
        init_mode, T_init, seed :
            Initialization of the plan, forwarded to the inner solvers.
        p_reg : float
            Exponent of the regularizer; must satisfy 0 < p_reg < 1.
        lambda_reg : float
            Weight of the regularizer.
        gamma_entropy : float
            0 selects the exact inner solver
            (GW_relaxedmarginal_linear_regularization); any other value selects
            entropic_semirelaxedGW_linear_regularization with this weight.
        use_log : bool
            If True, also return a log dict.
        use_warmstart : bool
            If True, each inner solve starts from the previous inner solution.
        eps_inner, max_iter_inner : stopping criteria of the inner solver.
        eps_outer, max_iter_outer : stopping criteria of the outer loop.
        eps_reg : float
            Offset inside the power; keeps the gradient finite when q_j = 0.
        verbose : bool
            Forwarded to the inner solvers.
        inner_log : bool
            If True (and use_log), store each inner solver's log in log['inner_loss'].

        Returns
        -------
        best_T, best_loss[, log]
            Best plan over the outer iterations and its regularized loss; log
            has keys 'loss', 'loss_reg' and, with inner_log, 'inner_loss'.
    """
    assert 0<p_reg<1
    N1 = C1.shape[0]
    N2 = C2.shape[0]
    if T_init is None:
        T=initializer_semirelaxedGW(init_mode,h1,N1,N2,seed=seed)
    else:
        T = T_init.copy()
    
    # No penalty at the first outer iteration.
    lin_reg = 0  
    best_T = T.copy()
    if use_log:
        log={}
        log['loss']=[]
        log['loss_reg']=[]
        if inner_log:
            log['inner_loss']=[]
        #log['T']=[T.copy()]
    best_loss = np.inf
    # Sentinels: the first relative change is measured against 10**14.
    previous_loss = 10**15
    current_loss=  10**14
    ones_ = np.ones((N1,1))
    convergence_criterion = np.inf
    outer_count=0
    while (convergence_criterion >eps_outer) and (outer_count<max_iter_outer):
        # For a fixed linear penalty R_t, solve OT by exact or entropic solver
        previous_loss = current_loss
        if gamma_entropy ==0:
            if inner_log :
                T,majorization_current_loss,current_loss_unreg,inner_log_= GW_relaxedmarginal_linear_regularization(C1,h1,C2, init_mode,lin_reg,T_init=T_init,use_log = True,eps=eps_inner,max_iter=max_iter_inner,seed=seed,verbose=verbose)
            else:
                T,majorization_current_loss,current_loss_unreg= GW_relaxedmarginal_linear_regularization(C1,h1,C2, init_mode,lin_reg,T_init=T_init,use_log = False,eps=eps_inner,max_iter=max_iter_inner,seed=seed,verbose=verbose)        
        else:
            if inner_log :
    
                T,majorization_current_loss,current_loss_unreg,inner_log_= entropic_semirelaxedGW_linear_regularization(C1,h1,C2, gamma_entropy,init_mode,lin_reg,T_init=T_init,use_log = True,eps=eps_inner,max_iter=max_iter_inner,seed=seed,verbose=verbose)
            else:
                T,majorization_current_loss,current_loss_unreg= entropic_semirelaxedGW_linear_regularization(C1,h1,C2, gamma_entropy,init_mode,lin_reg,T_init=T_init,use_log = False,eps=eps_inner,max_iter=max_iter_inner,seed=seed,verbose=verbose)        
        
        if use_warmstart:#used as initialization for srGW solver at the next iteration
            T_init = T.copy()
        #Update the linear penality --> R_{t+1}
        q = np.sum(T,axis=0)
        # True (regularized) objective of the inner solution.
        current_loss = current_loss_unreg
        current_loss_reg = lambda_reg*np.sum((q+eps_reg)**p_reg)
        current_loss += current_loss_reg
        # Gradient of the regularizer w.r.t. T: the same row vector repeated on all N1 rows.
        lin_reg = lambda_reg*p_reg*(ones_.dot(q[None,:])+eps_reg)**(p_reg-1)
            
        outer_count+=1
        if use_log:
            log['loss'].append(current_loss)
            log['loss_reg'].append(current_loss_reg)
            if inner_log:
                log['inner_loss'].append(inner_log_)
        # Relative change of the loss; the offset avoids dividing by zero.
        if previous_loss != 0:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss)
        else:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss+10**(-15))
        if current_loss < best_loss:
            best_loss = current_loss
            best_T = T.copy()
    if use_log:
        return best_T, best_loss, log
    else:
        return best_T,best_loss    




def GW_relaxedmarginal_asym(C1,p,C2, T_init,init_mode,use_log = False,eps=10**(-5),max_iter=1000,seed=0,verbose=False):
    """
        Solver: GW with relaxed marginal for asymmetric/directed graphs.
        > It is the same as the solver for undirected graphs, except for a few
        factorizations that do not hold when C1 and C2 are asymmetric.

        The gradient is the sum of two terms: one built from (C1, C2) and one
        built from (C1.T, C2.T). The two cross terms of the line search are
        computed separately instead of being assumed equal.

        Parameters
        ----------
        C1, C2 : np.ndarray
            Source / target structure matrices; N1 = C1.shape[0], N2 = C2.shape[0].
        p : np.ndarray of shape (N1,)
            Fixed first marginal.
        T_init : np.ndarray of shape (N1, N2) or None
            Initial plan (copied); if None, initializer_semirelaxedGW is used.
        init_mode : str
            Forwarded to initializer_semirelaxedGW when T_init is None.
        use_log : bool
            If True, also return {'loss': [...]}.
        eps, max_iter :
            Stop on relative loss change <= eps or after max_iter iterations.
        seed : int or None
            Forwarded to initializer_semirelaxedGW.
        verbose : bool
            Not used by this function.

        Returns
        -------
        best_T, best_loss[, log]
    """
    
    N1 = C1.shape[0]
    N2 = C2.shape[0]
    # NOTE: overwritten at the start of every iteration before being read.
    previous_loss = 10**(8)
    if T_init is None:
        T=initializer_semirelaxedGW(init_mode,p,N1,N2,seed=seed)
    else:
        T = T_init.copy()
    best_T = T.copy()
    q= np.sum(T,axis=0)
    constC, hC1, hC2 = gwa.init_matrix(C1,C2,T,p,q) #T is not used here except through q 
    constCt, hC1t, hC2t = gwa.init_matrix(C1.T,C2.T,T,p,q) # T is not used here except through q
    subG = gwa.tensor_product(constC, hC1, hC2, T)
    subG_T = gwa.tensor_product(constCt, hC1t, hC2t,T)
    # Full gradient = term from (C1, C2) + term from the transposed matrices.
    G = subG+subG_T
    current_loss=0.5 * np.sum(G*T)
    if use_log:
        log={}
        log['loss']=[current_loss]
    best_loss = np.inf
    convergence_criterion = np.inf
    outer_count=0
    while (convergence_criterion >eps) and (outer_count<max_iter):
        previous_loss = current_loss
        #0. Gradient known from evaluation of the cost function
        #1. Direction finding by solving each subproblems on rows
        # (row i of X puts mass p_i on the argmin column(s) of G, split equally among ties)
        min_ = np.min(G,axis=1)
        X = (G== min_[:,None]).astype(np.float64)
        X *= (p/np.sum(X,axis=1))[:,None]
        # 3. Exact line-search step        
        constCX, hC1X, hC2X = gwa.init_matrix(C1, C2, X,p, np.sum(X,axis=0)) #X is not used here except through q
        constCXt, hC1Xt, hC2Xt= gwa.init_matrix(C1.T, C2.T,X, p, np.sum(X,axis=0)) # X is not used here except through q
        subGX = gwa.tensor_product(constCX, hC1X, hC2X, X)
        subGX_T = gwa.tensor_product(constCXt, hC1Xt, hC2Xt,X)
        GX = subGX+subGX_T
        GXX = 0.5*np.sum(GX*X)
        subGX_T_dotT = np.sum(subGX_T*T) # \sum_ijkl (C_ij - Cbar_kl)^2 X_ik T_jl
        subGT_T_dotX = np.sum(subG_T*X) # \sum_ijkl (C_ij - Cbar_kl)^2 T_ik X_jl
        # Coefficients of loss((1-gamma)*T + gamma*X) - current_loss = a*gamma^2 + b*gamma,
        # with the two (non-symmetric) cross terms kept separately.
        a = current_loss + GXX -subGX_T_dotT - subGT_T_dotX
        b= -2*current_loss + subGX_T_dotT + subGT_T_dotX
                
        # Minimize a*gamma^2 + b*gamma over [0, 1].
        if a>0:
            gamma = min(1, max(0, -b/(2*a)))
        elif a+b<0:
            gamma=1
        else:
            gamma=0
        T = (1-gamma)*T +gamma*X
        # NOTE: q is not read again in this function.
        q= np.sum(T,axis=0)
        # Both the transposed term and the full gradient follow the convex combination.
        subG_T = (1-gamma)*subG_T + gamma*subGX_T
        G= (1-gamma)*G + gamma*GX
        current_loss+= a*(gamma**2)+b*gamma
        outer_count+=1
        if use_log:
            log['loss'].append(current_loss)
        # Relative change of the loss; the offset avoids dividing by zero.
        if previous_loss != 0:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss)
        else:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss+10**(-10))
        if current_loss < best_loss:
            best_loss = current_loss
            best_T = T.copy()
    if use_log:
        return best_T, best_loss, log
    else:
        return best_T,best_loss


#%% Entropic version of the semi-relaxed GW

def entropic_semirelaxedGW(C1,p,C2, gamma_entropy, init_mode,T_init=None,use_log = False,eps=10**(-5),max_iter=1000,seed=0,verbose=False,force_learning=True):
    """
        Mirror descent for entropic srGW,
        algorithm described in Section 3.2 of the main paper,
        details provided in the annex cf Algorithm 4.

        Each iteration performs a single Bregman (KL) projection onto the plans
        whose row sums equal p:
            K = exp(-(G(T_k) - gamma_entropy*log(T_k)) / gamma_entropy)
            T_{k+1} = diag(p / K.sum(axis=1)) K

        Parameters
        ----------
        C1, C2 : np.ndarray
            Source / target structure matrices; N1 = C1.shape[0], N2 = C2.shape[0].
        p : np.ndarray of shape (N1,)
            Fixed first marginal.
        gamma_entropy : float
            Entropic weight (step size of the mirror descent); must be > 0.
        init_mode : str
            Forwarded to initializer_semirelaxedGW when T_init is None.
        T_init : np.ndarray of shape (N1, N2) or None
            Optional initial plan (copied).
        use_log : bool
            If True, also return {'loss': [...]}.
        eps, max_iter :
            Stop on relative loss change <= eps or after max_iter iterations.
        seed : int or None
            Forwarded to initializer_semirelaxedGW.
        verbose : bool
            Not used by this function.
        force_learning : bool
            If True, the initial loss is the reference for best_loss, so best_T
            only changes when an iterate strictly improves on the initial plan.
            If False, best_loss starts at +inf.

        Returns
        -------
        best_T, best_loss[, log]
    """
    assert gamma_entropy>0
    N1 = C1.shape[0]
    N2 = C2.shape[0]
    if T_init is None:
        T = initializer_semirelaxedGW(init_mode,p, N1,N2,seed=seed )
    else:
        T = T_init.copy()

    best_T = T.copy()
    q = np.sum(T,axis=0)
    constC, hC1, hC2 = gwu.np_init_matrix_GW2(C1, C2, p, q)
    # evaluate the objective for the initial plan
    G = ot.gromov.gwggrad(constC, hC1, hC2, T)
    current_loss = 0.5*np.sum(G*T)
    if use_log:
        log = {'loss': [current_loss]}
    best_loss = current_loss if force_learning else np.inf
    convergence_criterion = np.inf

    outer_count = 0
    while (convergence_criterion > eps) and (outer_count < max_iter):
        previous_loss = current_loss
        #1. Compute M_k(epsilon) = 2 <L(C1,C2) \kron T_k> - gamma_entropie* log(T_k)
        M = G - gamma_entropy*np.log(T)
        # Subtracting the row-wise minimum rescales each row of K by a positive
        # constant, which the row normalization below cancels exactly; it only
        # keeps exp() from underflowing to an all-zero row (-> NaN plan).
        M -= np.min(M, axis=1, keepdims=True)
        K = np.exp(-M/gamma_entropy)
        # single Bregman projection onto {T : T.sum(axis=1) == p}; row
        # broadcasting replaces the O(N1^2 N2) np.diag(scaling).dot(K)
        scaling = p/K.sum(axis=1)
        T = scaling[:, None]*K
        q = np.sum(T,axis=0)
        constC, hC1, hC2 = gwu.np_init_matrix_GW2(C1, C2, p, q)
        #2. evaluate the objective for the new plan
        G = ot.gromov.gwggrad(constC, hC1, hC2, T)
        current_loss = 0.5*np.sum(G*T)
        outer_count += 1
        if use_log:
            log['loss'].append(current_loss)
        # Relative change of the loss; the offset avoids dividing by zero.
        if previous_loss != 0:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss)
        else:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss+10**(-15))
        if current_loss < best_loss:
            best_loss = current_loss
            best_T = T.copy()
    if use_log:
        return best_T, best_loss, log
    else:
        return best_T,best_loss
    


def entropic_semirelaxedGW_linear_regularization(C1,p,C2, gamma_entropy,init_mode,lin_reg,T_init=None,use_log = False,eps=10**(-5),max_iter=1000,seed=0,verbose=False):
    """
        Mirror descent for entropic srGW with a linear penalty term <D,T>,
        algorithm described in Section 3.2 of the main paper,
        details provided in the annex cf Algorithm 4.

        Minimizes srGW(T) + <lin_reg, T>. Each iteration performs a single
        Bregman (KL) projection onto the plans whose row sums equal p.

        Parameters
        ----------
        C1, C2 : np.ndarray
            Source / target structure matrices; N1 = C1.shape[0], N2 = C2.shape[0].
        p : np.ndarray of shape (N1,)
            Fixed first marginal.
        gamma_entropy : float
            Entropic weight; must be > 0.
        init_mode : str
            Forwarded to initializer_semirelaxedGW when T_init is None.
        lin_reg : scalar or np.ndarray broadcastable to (N1, N2)
            Linear penalty matrix.
        T_init : np.ndarray of shape (N1, N2) or None
            Optional initial plan (copied).
        use_log : bool
            If True, also return {'loss': [...]} with the penalized losses.
        eps, max_iter :
            Stop on relative loss change <= eps or after max_iter iterations.
        seed : int or None
            Forwarded to initializer_semirelaxedGW.
        verbose : bool
            Not used by this function.

        Returns
        -------
        best_T, best_loss, saved_loss_unreg[, log]
            Best plan, its penalized loss, and its unpenalized srGW loss
            (np.inf if no iteration ran).
    """
    assert gamma_entropy>0
    N1 = C1.shape[0]
    N2 = C2.shape[0]
    if T_init is None:
        T=initializer_semirelaxedGW(init_mode,p,N1,N2,seed=seed)
    else:
        T = T_init.copy()
    best_T = T.copy()
    q = np.sum(T,axis=0)
    constC, hC1, hC2 = gwu.np_init_matrix_GW2(C1, C2, p, q)
    #1. evaluate the objective for the initial plan
    G = ot.gromov.gwggrad(constC, hC1, hC2, T)
    current_loss_unreg = 0.5*np.sum(G*T)
    current_loss_reg = np.sum(lin_reg*T)
    current_loss = current_loss_unreg + current_loss_reg
    # G is the gradient of the penalized objective from here on
    G += lin_reg
    if use_log:
        log = {'loss': [current_loss]}
    best_loss = np.inf
    saved_loss_unreg = np.inf
    convergence_criterion = np.inf
    outer_count = 0
    while (convergence_criterion > eps) and (outer_count < max_iter):
        previous_loss = current_loss
        #2. Compute M_k(epsilon) = 2 <L(C1,C2) \kron T_k> + C - gamma_entropie* log(T_k)
        M = G - gamma_entropy*np.log(T)
        # Row-wise shift: cancelled exactly by the row normalization below,
        # but prevents exp() from underflowing to an all-zero row.
        M -= np.min(M, axis=1, keepdims=True)
        K = np.exp(-M/gamma_entropy)
        #Single Bregman projection (row broadcasting instead of a dense diag product).
        scaling = p/K.sum(axis=1)
        T = scaling[:, None]*K
        q = np.sum(T,axis=0)
        constC, hC1, hC2 = gwu.np_init_matrix_GW2(C1, C2, p, q)
        # evaluate the objective for the new plan
        G = ot.gromov.gwggrad(constC, hC1, hC2, T)
        current_loss_unreg = 0.5*np.sum(G*T)
        current_loss_reg = np.sum(lin_reg*T)
        current_loss = current_loss_unreg + current_loss_reg
        G += lin_reg
        outer_count += 1
        if use_log:
            log['loss'].append(current_loss)
        # Relative change of the loss; the offset avoids dividing by zero.
        if previous_loss != 0:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss)
        else:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss+10**(-15))
        if current_loss < best_loss:
            best_loss = current_loss
            saved_loss_unreg = current_loss_unreg
            best_T = T.copy()

    if use_log:
        return best_T, best_loss, saved_loss_unreg,log
    else:
        return best_T,best_loss,saved_loss_unreg
    


#%%semi_relaxed FUSED GROMOV-WASSERSTEIN :SOLVERS
"""
Note that all these algorithms are variants of the srGW solver with a linear penalty <D,T>,
as detailed in the annex.
"""


def FGW_relaxedmarginal(C1,A1,p,C2,A2,alpha,
                        init_mode,T_init=None,use_log = False,
                        eps=10**(-5),max_iter=1000,seed=0):
    """
        Exact (conditional-gradient) solver for srFGW:
            alpha * srGW(T) + (1 - alpha) * <D, T>
        where D[i, j] = ||A1[i] - A2[j]||^2 is computed from the node features.
        Same scheme as Algorithm 1. of the main paper for srGW.

        Note that our implementation uses the linearity of the GW loss gradient,
        which is why we only explicitly compute one GW gradient per iteration.
        The feature term is linear in T, so it only contributes to the
        b coefficient of the line search.

        Parameters
        ----------
        C1, C2 : np.ndarray
            Source / target structure matrices; N1 = C1.shape[0], N2 = C2.shape[0].
        A1, A2 : np.ndarray
            Source (N1, d) / target (N2, d) node features.
        p : np.ndarray of shape (N1,)
            Fixed first marginal.
        alpha : float
            Trade-off between the GW term (alpha) and the feature term (1 - alpha).
        init_mode : str
            Forwarded to initializer_semirelaxedGW when T_init is None.
        T_init : np.ndarray of shape (N1, N2) or None
            Optional initial plan (copied).
        use_log : bool
            If True, also return {'loss': [...]}.
        eps, max_iter :
            Stop on relative loss change <= eps or after max_iter iterations.
        seed : int or None
            Forwarded to initializer_semirelaxedGW.

        Returns
        -------
        best_T, best_loss[, log]
    """
    N1 = C1.shape[0]
    N2 = C2.shape[0]
    # NOTE: overwritten at the start of every iteration before being read.
    previous_loss = 10**(8)
    if T_init is None:
        T=initializer_semirelaxedGW(init_mode,p,N1,N2,seed=seed)
    else:
        T = T_init.copy()
    
    best_T = T.copy()
    q= np.sum(T,axis=0)
    # Squared Euclidean distances between rows of A1 and rows of A2:
    # ||a||^2 + ||b||^2 - 2 <a, b>.
    FS2 = (A1**2).dot(np.ones((A1.shape[1], A2.shape[0])))
    FT2 = (np.ones((A1.shape[0], A1.shape[1]))).dot((A2**2).T)
    D= FS2+FT2 - 2*A1.dot(A2.T)
    #Evaluate first gradient then use linearity
    constC, hC1, hC2 = gwu.np_init_matrix_GW2(C1, C2, p, q)
    G_gw = alpha*ot.gromov.gwggrad(constC, hC1, hC2, T) # gradient from GW term 
    G_w = (1-alpha)*D #gradients from W term
    G=G_gw+G_w
    GW_current_loss = 0.5*np.sum(G_gw*T)
    W_current_loss = np.sum(G_w*T)
    current_loss = GW_current_loss + W_current_loss
    if use_log:
        log={}
        log['loss']=[current_loss]
    best_loss = np.inf
    convergence_criterion = np.inf
    outer_count=0
    while (convergence_criterion >eps) and (outer_count<max_iter):
        previous_loss = current_loss
        # 0. Gradient known from evaluation of the  cost function
        # 1. Direction finding by solving each subproblem on rows
        # (row i of X puts mass p_i on the argmin column(s) of G, split equally among ties)
        min_ = np.min(G,axis=1)
        X = (G== min_[:,None]).astype(np.float64)
        X *= (p/np.sum(X,axis=1))[:,None]
        # 3. Exact line-search step
        # Compute litteral expressions of coefficients a*\gamma^2 +b \gamma +c
        # Get terms from GW - a_gw = alpha*a_fgw
        constCX, hC1X, hC2X = gwu.np_init_matrix_GW2(C1, C2, p, np.sum(X,axis=0))
        GX = alpha*ot.gromov.gwggrad(constCX, hC1X, hC2X, X)
        GXX = 0.5*np.sum(GX*X)
        GXT = 0.5*np.sum(GX*T) 
        WX= np.sum(G_w*X)
        a_gw = GW_current_loss + GXX - 2*GXT # a coefficient only depends on GW terms
        b_gw= -2*GW_current_loss + 2*GXT
        # Feature term is linear: its change along the segment is gamma*<G_w, X - T>.
        b_w=-W_current_loss +WX
        b = b_gw+b_w
        # Minimize a_gw*gamma^2 + b*gamma over [0, 1].
        if a_gw>0:
            gamma = min(1, max(0, -b/(2*a_gw)))
        elif a_gw+b<0:
            gamma=1
        else:
            gamma=0

        T = (1-gamma)*T +gamma*X
        # GW gradient follows the convex combination; G_w is constant.
        G_gw = (1-gamma)*G_gw + gamma*GX
        G=G_gw+G_w
        GW_current_loss = a_gw*(gamma**2)+b_gw*gamma + GW_current_loss
        W_current_loss = b_w*gamma + W_current_loss
        current_loss = GW_current_loss + W_current_loss
        
        outer_count+=1
        if use_log:
            log['loss'].append(current_loss)
            
        # Relative change of the loss; the offset avoids dividing by zero.
        if previous_loss != 0:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss)
        else:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss+10**(-10))
        if current_loss < best_loss:
            best_loss = current_loss
            best_T = T.copy()
    if use_log:
        return best_T, best_loss, log
    else:
        return best_T,best_loss




def FGW_relaxedmarginal_linear_regularization(C1,A1,p,C2,A2,D,
                                              alpha, init_mode,lin_reg,
                                              T_init=None,use_log = False,
                                              eps=10**(-5),max_iter=1000,seed=0):
    """
        Exact (conditional-gradient) solver for srFGW with a linear penalty:
            alpha * srGW(T) + (1 - alpha) * <D, T> + <lin_reg, T>

        Parameters
        ----------
        C1, C2 : np.ndarray
            Source / target structure matrices; N1 = C1.shape[0], N2 = C2.shape[0].
        A1, A2 : np.ndarray
            Source (N1, d) / target (N2, d) node features; only used when D is None.
        p : np.ndarray of shape (N1,)
            Fixed first marginal.
        D : np.ndarray of shape (N1, N2) or None
            Precomputed squared Euclidean feature distances; computed from A1
            and A2 when None.
        alpha : float
            Trade-off between the GW term (alpha) and the feature term (1 - alpha).
        init_mode : str
            Forwarded to initializer_semirelaxedGW when T_init is None.
        lin_reg : scalar or np.ndarray broadcastable to (N1, N2)
            Linear penalty matrix.
        T_init : np.ndarray of shape (N1, N2) or None
            Optional initial plan (copied).
        use_log : bool
            If True, also return {'loss_unreg': [...], 'loss': [...]}.
        eps, max_iter :
            Stop on relative loss change <= eps or after max_iter iterations.
        seed : int or None
            Reseeds numpy's global RNG and is forwarded to the initializer.

        Returns
        -------
        best_T, best_loss, corresponding_loss_unreg[, log], D
            Best plan, its penalized loss, its unpenalized srFGW loss (np.inf
            if no iteration ran), the optional log, and D so that callers can
            reuse it.
    """
    np.random.seed(seed)
    N1 = C1.shape[0]
    N2 = C2.shape[0]
    if T_init is None:
        T=initializer_semirelaxedGW(init_mode,p,N1,N2,seed=seed)
    else:
        T = T_init.copy()

    best_T = T.copy()
    q= np.sum(T,axis=0)
    #compute distance matrix of features (squared Euclidean, rows of A1 vs rows of A2)
    if D is None:
        FS2 = (A1**2).dot(np.ones((A1.shape[1], A2.shape[0])))
        FT2 = (np.ones((A1.shape[0], A1.shape[1]))).dot((A2**2).T)
        D= FS2+FT2 - 2*A1.dot(A2.T)
    # GW
    constC, hC1, hC2 = gwu.np_init_matrix_GW2(C1, C2, p, q)
    #lets evaluate the objective for the current state:
    G_gw = alpha*ot.gromov.gwggrad(constC, hC1, hC2, T) # gradient from GW term
    G_w = (1-alpha)*D #gradients from W term
    G=G_gw+G_w + lin_reg
    GW_current_loss = 0.5*np.sum(G_gw*T)
    W_current_loss = np.sum(G_w*T)
    current_reg = np.sum(lin_reg*T)
    current_loss_unreg = GW_current_loss + W_current_loss
    current_loss = current_loss_unreg + current_reg
    if use_log:
        log={}
        log['loss_unreg']=[current_loss_unreg]
        log['loss']=[current_loss]
    best_loss = np.inf
    # Defined up front so that max_iter <= 0 no longer raises UnboundLocalError.
    corresponding_loss_unreg = np.inf
    convergence_criterion = np.inf
    outer_count=0
    while (convergence_criterion >eps) and (outer_count<max_iter):
        previous_loss = current_loss
        # 1. Direction finding by solving each subproblem on rows
        # (row i of X puts mass p_i on the argmin column(s) of G, split equally among ties)
        min_ = np.min(G,axis=1)
        X = (G== min_[:,None]).astype(np.float64)
        X *= (p/np.sum(X,axis=1))[:,None]
        # 3. Exact line-search step
        # Compute litteral expressions of coefficients a*\gamma^2 +b \gamma +c
        # Get terms from GW - a_gw = alpha*a_fgw
        constCX, hC1X, hC2X = gwu.np_init_matrix_GW2(C1, C2, p, np.sum(X,axis=0))
        GX = alpha*ot.gromov.gwggrad(constCX, hC1X, hC2X, X)
        GXX = 0.5*np.sum(GX*X)
        GXT = 0.5*np.sum(GX*T)
        WX= np.sum(G_w*X)
        regX = np.sum(lin_reg*X)
        a_gw = GW_current_loss + GXX - 2*GXT # a coefficient only depends on GW terms
        b_gw= -2*GW_current_loss + 2*GXT
        # Linear terms (features and penalty) only contribute to b.
        b_w=-W_current_loss +WX
        b_reg = regX - current_reg
        b = b_gw+b_w + b_reg
        if a_gw>0:
            gamma = min(1, max(0, -b/(2*a_gw)))
        elif a_gw+b<0:
            gamma=1
        else:
            gamma=0

        T = (1-gamma)*T +gamma*X
        G_gw = (1-gamma)*G_gw + gamma*GX
        # The direction-finding step must see the full penalized objective,
        # so lin_reg stays in the gradient (it used to be dropped here,
        # silently ignoring the penalty after the first iteration).
        G=G_gw+G_w+lin_reg
        GW_current_loss = a_gw*(gamma**2)+b_gw*gamma + GW_current_loss
        W_current_loss += b_w*gamma
        current_reg += gamma*b_reg
        current_loss_unreg = GW_current_loss + W_current_loss
        current_loss =current_loss_unreg + current_reg
        outer_count+=1
        if use_log:
            log['loss'].append(current_loss)
            log['loss_unreg'].append(current_loss_unreg)

        # Relative change of the loss; the offset avoids dividing by zero.
        if previous_loss != 0:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss)
        else:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss+10**(-10))
        if current_loss < best_loss:
            best_loss = current_loss
            corresponding_loss_unreg = current_loss_unreg
            best_T = T.copy()
    if use_log:
        return best_T, best_loss, corresponding_loss_unreg ,log,D
    else:
        return best_T,best_loss,corresponding_loss_unreg,D


def FGW_relaxedmarginal_majorationminimization_lpl1(C1,A1,h1,C2, A2,
                                                    init_mode='product',gamma_entropy=0,
                                                    alpha_fgw=0.5, p_reg=1/2, lambda_reg = 0.0,
                                                    T_init=None,use_log = False,eps_inner=10**(-6),
                                                    eps_outer=10**(-6),eps_reg = 10**(-15),max_iter_inner=1000,
                                                    max_iter_outer =50,seed=0,verbose=False,inner_log = False,warmstart=True):
    """
        Majorization-minimization solver for srFGW with the sparse regularizer
        lambda_reg * sum_j (q_j + eps_reg)**p_reg, where q = T.sum(axis=0).

        Each outer iteration solves srFGW with the linear penalty <lin_reg, T>,
        where lin_reg is the gradient of the regularizer at the previous q
        (lin_reg = 0 at the first iteration).

        Parameters
        ----------
        C1, C2 : np.ndarray
            Source / target structure matrices; N1 = C1.shape[0], N2 = C2.shape[0].
        A1, A2 : np.ndarray
            Source / target node features.
        h1 : np.ndarray of shape (N1,)
            Fixed first marginal.
        init_mode, T_init, seed :
            Initialization of the plan, forwarded to the inner solvers.
        gamma_entropy : float
            0 selects the exact inner solver; any other value selects the
            entropic mirror-descent inner solver with this weight.
        alpha_fgw : float
            FGW trade-off between structure (alpha) and features (1 - alpha).
        p_reg : float
            Exponent of the regularizer; must satisfy 0 < p_reg < 1.
        lambda_reg : float
            Weight of the regularizer.
        use_log : bool
            If True, also return a log dict ('loss', 'loss_reg' and, with
            inner_log, 'inner_count').
        eps_inner, max_iter_inner : stopping criteria of the inner solver.
        eps_outer, max_iter_outer : stopping criteria of the outer loop.
        eps_reg : float
            Offset inside the power; keeps the gradient finite when q_j = 0.
        verbose : bool
            Print progress at every outer iteration.
        inner_log : bool
            If True (and use_log), record the number of inner iterations.
        warmstart : bool
            If True, each inner solve starts from the previous inner solution,
            for both the exact and the entropic inner solvers.

        Returns
        -------
        best_T, best_loss[, log]
    """
    assert 0<p_reg<1
    N1 = C1.shape[0]
    N2 = C2.shape[0]
    if T_init is None:
        T=initializer_semirelaxedGW(init_mode,h1,N1,N2,seed=seed)
    else:
        T = T_init.copy()

    # No penalty at the first outer iteration.
    lin_reg = 0
    best_T = T.copy()
    if use_log:
        log={}
        log['loss']=[]
        log['loss_reg']=[]
        if inner_log:
            log['inner_count']=[]

    best_loss = np.inf
    # Sentinel: the first relative change is measured against 10**14.
    current_loss=  10**14
    ones_ = np.ones((N1,1))
    convergence_criterion = np.inf
    outer_count=0
    # Feature distance matrix, computed by the first inner call and reused afterwards.
    D=None
    while (convergence_criterion >eps_outer) and (outer_count<max_iter_outer):
        previous_loss = current_loss
        # Both branches start from T_init, which is only updated when
        # warmstart is True (the entropic branch used to always warm-start).
        if gamma_entropy==0:
            if inner_log :
                T,majorization_current_loss,current_loss_unreg,inner_log_,D= FGW_relaxedmarginal_linear_regularization(C1,A1,h1,C2,A2,D, alpha_fgw,init_mode,lin_reg,T_init=T_init,use_log = True,eps=eps_inner,max_iter=max_iter_inner,seed=seed)
            else:
                T,majorization_current_loss,current_loss_unreg,D= FGW_relaxedmarginal_linear_regularization(C1,A1,h1,C2,A2, D,alpha_fgw, init_mode,lin_reg,T_init=T_init,use_log = False,eps=eps_inner,max_iter=max_iter_inner,seed=seed)
        else:
            if inner_log :
                T,majorization_current_loss,current_loss_unreg,inner_log_,D= entropic_semirelaxedFGW_linear_regularization(C1,A1,h1,C2,A2, D,alpha_fgw,gamma_entropy,init_mode,lin_reg,T_init=T_init,use_log = True,eps=eps_inner,max_iter=max_iter_inner,seed=seed)
            else:
                T,majorization_current_loss,current_loss_unreg,D= entropic_semirelaxedFGW_linear_regularization(C1,A1,h1,C2,A2, D,alpha_fgw, gamma_entropy,init_mode,lin_reg,T_init=T_init,use_log = False,eps=eps_inner,max_iter=max_iter_inner,seed=seed)
        if warmstart:
            T_init = T.copy()
        q = np.sum(T,axis=0)

        # True (regularized) objective of the inner solution.
        current_loss = current_loss_unreg
        current_loss_reg = lambda_reg*np.sum((q[None]+eps_reg)**p_reg)
        current_loss += current_loss_reg
        # Gradient of the regularizer w.r.t. T: the same row vector on all N1 rows.
        lin_reg = lambda_reg*p_reg*(ones_.dot(q[None,:])+eps_reg)**(p_reg-1)
        if verbose:
            # `log` only exists when use_log is True (previously a NameError).
            logged_losses = log['loss'] if use_log else None
            print('---outer_count: %s / log : %s  / q : %s / current_loss : %s'%(outer_count, logged_losses, q, current_loss ))

        outer_count+=1
        if use_log:
            log['loss'].append(current_loss)
            log['loss_reg'].append(current_loss_reg)
            if inner_log:
                log['inner_count'].append(len(inner_log_['loss']))
        # Relative change of the loss; the offset avoids dividing by zero.
        if previous_loss != 0:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss)
        else:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss+10**(-15))
        if current_loss < best_loss:
            best_loss = current_loss
            best_T = T.copy()
    if use_log:
        return best_T, best_loss, log
    else:
        return best_T,best_loss
#%% Entropic srFGW



def entropic_semirelaxedFGW(C1,A1,p, # input structure
                            C2,A2,   # target structure
                            D=None,  # D != None if the pairwise matrix of distances between features is precomputed
                            alpha=0.5,
                            gamma_entropy=0.1, init_mode='product',
                            T_init=None,use_log = False,eps=10**(-5),
                            max_iter=1000,seed=0,verbose=False):
    """
        Solver: semi-relaxed Fused Gromov-Wasserstein with Mirror Descent scheme.

        Minimizes alpha * srGW(T) + (1 - alpha) * <D, T>. Each iteration
        performs a single Bregman (KL) projection onto the plans whose row sums
        equal p.

        Parameters
        ----------
        C1, C2 : np.ndarray
            Source / target structure matrices; N1 = C1.shape[0], N2 = C2.shape[0].
        A1, A2 : np.ndarray
            Source (N1, d) / target (N2, d) node features; only used when D is None.
        p : np.ndarray of shape (N1,)
            Fixed first marginal.
        D : np.ndarray of shape (N1, N2) or None
            Precomputed squared Euclidean feature distances.
        alpha : float
            Trade-off between the GW term (alpha) and the feature term (1 - alpha).
        gamma_entropy : float
            Entropic weight; must be > 0.
        init_mode, T_init, seed :
            Initialization (see initializer_semirelaxedGW); T_init is copied.
        use_log : bool
            If True, also return {'loss': [...]}.
        eps, max_iter :
            Stop on relative loss change <= eps or after max_iter iterations.
        verbose : bool
            Not used by this function.

        Returns
        -------
        best_T, best_loss[, log]
    """
    assert gamma_entropy>0
    N1 = C1.shape[0]
    N2 = C2.shape[0]
    if T_init is None:
        T=initializer_semirelaxedGW(init_mode,p,N1,N2,seed=seed)
    else:
        T = T_init.copy()

    best_T = T.copy()
    q = np.sum(T,axis=0)
    if D is None:
        # squared Euclidean distances between rows of A1 and rows of A2
        FS2 = (A1**2).dot(np.ones((A1.shape[1], A2.shape[0])))
        FT2 = (np.ones((A1.shape[0], A1.shape[1]))).dot((A2**2).T)
        D = FS2+FT2 - 2*A1.dot(A2.T)

    constC, hC1, hC2 = gwu.np_init_matrix_GW2(C1, C2, p, q)
    # evaluate the objective for the initial plan
    G_gw = alpha*ot.gromov.gwggrad(constC, hC1, hC2, T) # gradient from GW term
    G_w = (1-alpha)*D #gradients from W term
    GW_current_loss = 0.5*np.sum(G_gw*T)
    W_current_loss = np.sum(G_w*T)
    current_loss = GW_current_loss + W_current_loss
    if use_log:
        log = {'loss': [current_loss]}

    best_loss = np.inf
    convergence_criterion = np.inf

    outer_count = 0
    while (convergence_criterion > eps) and (outer_count < max_iter):
        previous_loss = current_loss
        #1. Compute M_k(epsilon) = 2*alpha* <L(C1,C2) \kron T_k> + (1-alpha) D - gamma_entropie* log(T_k)
        M = G_gw + G_w - gamma_entropy*np.log(T)
        # Row-wise shift: cancelled exactly by the row normalization below,
        # but prevents exp() from underflowing to an all-zero row.
        M -= np.min(M, axis=1, keepdims=True)
        K = np.exp(-M/gamma_entropy)
        # single Bregman projection (row broadcasting instead of a dense diag product)
        scaling = p/K.sum(axis=1)
        T = scaling[:, None]*K
        q = np.sum(T,axis=0)
        constC, hC1, hC2 = gwu.np_init_matrix_GW2(C1, C2, p, q)
        #2. evaluate the objective for the new plan and update the gradient.
        # The alpha weighting must be kept (it was previously dropped here), and
        # the GW loss must use the gradient of the *new* plan, not a stale one.
        G_gw = alpha*ot.gromov.gwggrad(constC, hC1, hC2, T)
        GW_current_loss = 0.5*np.sum(G_gw*T)
        W_current_loss = np.sum(G_w*T)
        current_loss = GW_current_loss + W_current_loss
        outer_count += 1
        if use_log:
            log['loss'].append(current_loss)
        # Relative change of the loss; the offset avoids dividing by zero.
        if previous_loss != 0:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss)
        else:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss+10**(-15))
        if current_loss < best_loss:
            best_loss = current_loss
            best_T = T.copy()
    if use_log:
        return best_T, best_loss, log
    else:
        return best_T,best_loss
    

def entropic_semirelaxedFGW_linear_regularization(C1,A1,p,
                                                  C2,A2,
                                                  D=None,
                                                  alpha=0.5,
                                                  gamma_entropy=0.1,
                                                  init_mode='product',lin_reg=0,T_init=None,use_log = False,eps=10**(-5),max_iter=1000,seed=0,verbose=False):
    """
        Solver: semi relaxed Fused Gromov-Wasserstein with additional linear OT penalty term.

        Minimizes alpha * srGW(T) + (1 - alpha) * <D, T> + <lin_reg, T> by
        mirror descent, with one Bregman (KL) projection per iteration onto the
        plans whose row sums equal p.

        Parameters
        ----------
        C1, C2 : np.ndarray
            Source / target structure matrices; N1 = C1.shape[0], N2 = C2.shape[0].
        A1, A2 : np.ndarray
            Source (N1, d) / target (N2, d) node features; only used when D is None.
        p : np.ndarray of shape (N1,)
            Fixed first marginal.
        D : np.ndarray of shape (N1, N2) or None
            Precomputed squared Euclidean feature distances.
        alpha : float
            Trade-off between the GW term (alpha) and the feature term (1 - alpha).
        gamma_entropy : float
            Entropic weight; must be > 0.
        init_mode, T_init : initialization (see initializer_semirelaxedGW).
        lin_reg : scalar or np.ndarray broadcastable to (N1, N2)
            Linear penalty matrix.
        use_log : bool
            If True, also return {'loss_unreg': [...], 'loss': [...]}.
        eps, max_iter :
            Stop on relative loss change <= eps or after max_iter iterations.
        seed : int or None
            Reseeds numpy's global RNG and is forwarded to the initializer.
        verbose : bool
            Print diagnostics at every iteration.

        Returns
        -------
        best_T, best_loss, saved_loss_unreg[, log], D
            Best plan, its penalized loss, its unpenalized srFGW loss, the
            optional log, and D so that callers can reuse it.
    """
    # Same precondition as the other entropic solvers: gamma_entropy <= 0
    # makes the mirror step meaningless.
    assert gamma_entropy>0
    np.random.seed(seed)
    N1 = C1.shape[0]
    N2 = C2.shape[0]
    if T_init is None:
        T=initializer_semirelaxedGW(init_mode,p,N1,N2,seed=seed)
    else:
        T = T_init.copy()

    best_T = T.copy()
    q = np.sum(T,axis=0)
    #compute distance matrix of features (squared Euclidean, rows of A1 vs rows of A2)
    if D is None:
        FS2 = (A1**2).dot(np.ones((A1.shape[1], A2.shape[0])))
        FT2 = (np.ones((A1.shape[0], A1.shape[1]))).dot((A2**2).T)
        D = FS2+FT2 - 2*A1.dot(A2.T)
    # GW
    constC, hC1, hC2 = gwu.np_init_matrix_GW2(C1, C2, p, q)
    # evaluate the objective for the initial plan
    G_gw = alpha*ot.gromov.gwggrad(constC, hC1, hC2, T) # gradient from GW term
    G_w = (1-alpha)*D #gradients from W term
    G = G_gw+G_w + lin_reg
    GW_current_loss = 0.5*np.sum(G_gw*T)
    W_current_loss = np.sum(G_w*T)
    current_loss_reg = np.sum(lin_reg*T)
    current_loss_unreg = GW_current_loss + W_current_loss
    current_loss = current_loss_unreg + current_loss_reg
    if use_log:
        log = {}
        log['loss_unreg'] = [current_loss_unreg]
        log['loss'] = [current_loss]
    best_loss = np.inf
    saved_loss_unreg = current_loss_unreg
    convergence_criterion = np.inf
    outer_count = 0
    while (convergence_criterion > eps) and (outer_count < max_iter):
        previous_loss = current_loss

        # Compute M_k(epsilon) = 2*alpha* <L(C1,C2) \kron T_k> + (1-alpha) D +C - gamma_entropie* log(T_k)
        M = G - gamma_entropy*np.log(T)
        # Row-wise shift: cancelled exactly by the row normalization below,
        # but prevents exp() from underflowing to an all-zero row.
        M -= np.min(M, axis=1, keepdims=True)
        K = np.exp(-M/gamma_entropy)
        scaling = p/K.sum(axis=1)
        if verbose:
            print('current loss: %s / embedding :%s / scaling: %s  /scaling_denominator: %s'%(current_loss,q,scaling,K.sum(axis=1)))
        # single Bregman projection (row broadcasting instead of a dense diag product)
        T = scaling[:, None]*K
        q = np.sum(T,axis=0)
        constC, hC1, hC2 = gwu.np_init_matrix_GW2(C1, C2, p, q)
        # Evaluate the objective for the new plan. The alpha weighting must be
        # kept (it was previously dropped here), and the GW loss must use the
        # new GW gradient rather than the previous full gradient G.
        G_gw = alpha*ot.gromov.gwggrad(constC, hC1, hC2, T)
        GW_current_loss = 0.5*np.sum(G_gw*T)
        W_current_loss = np.sum(G_w*T)
        current_loss_unreg = GW_current_loss + W_current_loss
        current_loss_reg = np.sum(lin_reg*T)
        G = G_gw+G_w+lin_reg
        current_loss = current_loss_unreg+current_loss_reg
        outer_count += 1

        if use_log:
            log['loss'].append(current_loss)
            log['loss_unreg'].append(current_loss_unreg)
        # Relative change of the loss; the offset avoids dividing by zero.
        if previous_loss != 0:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss)
        else:
            convergence_criterion = abs(previous_loss - current_loss)/ abs(previous_loss+10**(-10))
        if current_loss < best_loss:
            best_loss = current_loss
            saved_loss_unreg = current_loss_unreg
            best_T = T.copy()
    if use_log:
        return best_T, best_loss, saved_loss_unreg ,log,D
    else:
        return best_T,best_loss,saved_loss_unreg,D



