
import numpy as np
from tqdm import tqdm
import ot
import GDL_utils as gwu
import SGWL.gromovWassersteinAveraging as gwa
import torch as th

#%% TORCH IMPLEMENTATION OF SEMI-RELAXED GROMOV-WASSERSTEIN - TO PLUG ON GPU
"""
    Used for graph partitioning experiments.
    Running these solvers on GPU is mostly worth it when graphs are large enough.
    Otherwise, since autograd is not needed, parallelizing the unmixings
    on a single GPU can bring great speed-ups.
"""

def torch_initializer_semirelaxedGW(init_mode,p,N1,N2,seed=0,tensor_type=th.float64, device='cuda:0'):
    """Build an initial transport plan T of shape (N1, N2) whose rows sum to p.

    Parameters
    ----------
    init_mode : str
        'random'         : uniform random entries, each row rescaled to sum to p_i.
        'product'        : outer product p q^T with q uniform over the N2 columns.
        'random_product' : outer product p q^T with q a random probability vector.
    p : th.Tensor
        First marginal, of length N1. For the product modes it is matmul-ed
        with a tensor of dtype `tensor_type` on `device`, so it presumably
        must share that dtype/device — confirm with callers.
    N1, N2 : int
        Number of rows / columns of the plan.
    seed : int or None
        If not None, passed to `th.manual_seed` before sampling (random modes
        only). Note this reseeds torch's global generator.
    tensor_type : th.dtype
        dtype of the sampled / constructed tensors.
    device : str
        Device on which the tensors are created.

    Returns
    -------
    th.Tensor of shape (N1, N2).

    Raises
    ------
    ValueError
        If `init_mode` is not one of the modes above.
    """
    if init_mode == 'random':
        if seed is not None:
            th.manual_seed(seed)
        T = th.rand(size=(N1, N2), dtype=tensor_type, device=device)
        # scaling to satisfy the first marginal constraint T 1 = p
        scale = p / T.sum(axis=1)
        T *= scale[:, None]
    elif init_mode == 'product':
        q = th.ones(N2, dtype=tensor_type, device=device) / N2
        T = p[:, None] @ q[None, :]
    elif init_mode == 'random_product':
        if seed is not None:
            th.manual_seed(seed)
        # `size` must be a sequence: torch rejects a bare int for this keyword.
        q = th.rand(size=(N2,), dtype=tensor_type, device=device)
        q /= q.sum()
        T = p[:, None] @ q[None, :]
    else:
        # Raising a plain string is itself a TypeError in Python 3.
        raise ValueError('unknown init mode: %s' % (init_mode,))
    return T

def torch_init_matrix_GW2(C1, C2, p, q, ones_p, ones_q):
    """Precompute the square-loss factorization used by `torch_GWgrad`.

    Returns ``(constC, hC1, hC2)`` where
        constC = (C1**2) @ (p 1_q^T) + (1_p q^T) @ (C2**2),
        hC1    = C1,
        hC2    = 2 * C2.
    C2**2 is not transposed here (unlike `torch_init_matrix_asym`), which
    presumably assumes a symmetric C2.
    """
    row_mass = p[:, None] @ ones_q[None, :]    # every column equals p
    col_mass = ones_p[:, None] @ q[None, :]    # every row equals q
    left = (C1 ** 2) @ row_mass
    right = col_mass @ (C2 ** 2)
    return left + right, C1, 2 * C2


def torch_init_matrix_asym(C1, C2, p, q, ones_p, ones_q):
    """Precompute the (half) square-loss factorization for possibly asymmetric C1, C2.

    Returns ``(constC, hC1, hC2)`` where
        constC = (C1**2 / 2) @ (p 1_q^T) + (1_p q^T) @ (C2**2 / 2)^T,
        hC1    = C1,
        hC2    = C2.
    """
    half_sq_C1 = C1 ** 2 / 2.
    half_sq_C2 = C2 ** 2 / 2.
    row_mass = p[:, None] @ ones_q[None, :]
    col_mass = ones_p[:, None] @ q[None, :]
    constC = half_sq_C1 @ row_mass + col_mass @ half_sq_C2.T
    return constC, C1, C2

def torch_GWgrad(constC, hC1, hC2, T):
    """Return the GW gradient ``2 * (constC - hC1 @ T @ hC2^T)`` at plan T."""
    cross = -hC1 @ T @ hC2.T
    return 2 * (constC + cross)

    
def torch_semirelaxedGW(C1:th.Tensor,
                        p:th.Tensor,
                        C2:th.Tensor, 
                        init_mode:str,
                        T_init:th.Tensor=None,
                        use_log:bool = False,
                        eps:float=10**(-5),
                        max_iter:int=1000,
                        seed:int=0,
                        verbose:bool=False,
                        device:str='cpu',
                        tensor_type:type=th.float32):
    r""" 
        Frank-Wolfe (conditional gradient) solver for unregularized semi-relaxed GW
        with the square loss:
            \min_{T >= 0, T1 = p} <L(C_1, C_2) \otimes T, T>
        Only the first marginal (row sums = p) is enforced; q = T^\top 1 is free.

        Each iteration:
            1. direction finding: in every row, the mass p_i is put on the
               column(s) where the current gradient G is minimal (split evenly
               across ties);
            2. exact line search on the quadratic
               gamma -> loss((1 - gamma) T + gamma X) = a gamma^2 + b gamma + c,
               with gamma restricted to [0, 1];
            3. the gradient is updated by the same convex combination.
        The line-search cross term uses <G(X), T> in place of <G(T), X>, which
        is only exact when C1 and C2 are symmetric.

        Parameters
        ----------
        C1, C2 : structure matrices of sizes (N1, N1) and (N2, N2).
        p : first marginal, length N1.
        init_mode : passed to torch_initializer_semirelaxedGW when T_init is None.
        T_init : optional (N1, N2) starting plan (cloned, never modified).
        use_log : if True, also return {'loss': [...]} holding the loss before
            the first iteration and after each iteration.
        eps : stop once the relative change of the loss is <= eps.
        max_iter : maximum number of iterations.
        seed : forwarded to the initializer.
        verbose : currently unused.
        device, tensor_type : device/dtype of the initial plan (when T_init is
            None) and of the helper ones vectors.

        Returns
        -------
        (best_T, best_loss) or (best_T, best_loss, log). best_T is the loop
        iterate with the lowest loss and best_loss is that loss as a Python
        float. If no iteration runs (max_iter=0), best_T is the initial plan
        and best_loss is np.inf.
    """
    N1 = C1.shape[0]
    N2 = C2.shape[0]
    
    if T_init is None:
        T= torch_initializer_semirelaxedGW(init_mode,p,N1,N2,seed=seed,tensor_type=tensor_type, device=device)
    else:
        assert list(T_init.shape)==[N1,N2] 
        T = T_init.clone()
    best_T = T.clone()
    # Get gradient from the initial starting point
    q= T.sum(axis=0) 
    ones_p = th.ones(p.shape[0],dtype=tensor_type, device=device)
    ones_q = th.ones(q.shape[0],dtype=tensor_type, device=device)
    constC, hC1, hC2 = torch_init_matrix_GW2(C1, C2, p, q,ones_p, ones_q)
    G = torch_GWgrad(constC,hC1,hC2,T)
    # Loss at T, evaluated as 0.5 * <G, T>.
    current_loss = 0.5*th.sum(G*T)
    #current_loss = f1
    if use_log:
        log={}
        log['loss']=[current_loss.item()]

    # Starting at +inf: only iterates produced by the loop can become best_T.
    best_loss = np.inf
    convergence_criterion = np.inf
    outer_count=0
    while (convergence_criterion >eps) and (outer_count<max_iter):
        #print('iter : %s / T device: %s'%(outer_count, T.device))
        # clone(): current_loss is updated in place further down.
        previous_loss = current_loss.clone()
        # 0. Gradient known from evaluation of the  cost function
        # 1. Direction finding by solving each subproblem on rows:
        #    X puts the row mass p_i evenly on the argmin columns of G.
        min_, _ = G.min(axis=1)
        X = (G== min_[:,None]).type(tensor_type)
        X *= (p/X.sum(axis=1))[:,None]
        # 3. Exact line-search step
        # Compute literal expressions of coefficients a*\gamma^2 +b \gamma +c
        constCX, hC1X, hC2X = torch_init_matrix_GW2(C1, C2, p, X.sum(axis=0),ones_p,ones_q)
        GX = torch_GWgrad(constCX, hC1X, hC2X, X)
        GXX = 0.5*th.sum(GX*X)
        GXT = 0.5*th.sum(GX*T) # Here we say  GXT = GTX= 0.5*np.sum(G*X) what is true if C1 and C2 are symmetric
        a = current_loss + GXX - 2*GXT 
        b= -2*current_loss + 2*GXT 
        
        # Minimize a*gamma^2 + b*gamma on [0, 1]: clipped vertex when convex,
        # otherwise an endpoint (gamma=1 iff the value at 1, a+b, is negative).
        if a>0:
            gamma = min(1, max(0, -b.item()/(2*a.item())))
        elif a+b<0:
            gamma=1
        else:
            gamma=0
        T = (1-gamma)*T +gamma*X 
        q= T.sum(axis=0)
        #new grad: G is affine in T, so it follows the same convex combination
        G = (1-gamma)*G + gamma*GX
        # Closed-form loss update from the line-search quadratic (in place).
        current_loss+= a*(gamma**2)+ b*gamma 
        outer_count+=1
        if use_log:
            log['loss'].append(current_loss.item())
        # Relative change of the loss, guarded against a zero previous loss.
        if previous_loss != 0:
            convergence_criterion = abs(previous_loss.item() - current_loss.item())/ abs(previous_loss.item())
        else:
            convergence_criterion = abs(previous_loss.item() - current_loss.item())/ abs(previous_loss.item()+10**(-15))
        if current_loss.item() < best_loss:
            best_loss = current_loss.item()
            best_T = T.clone()
    if use_log:
        return best_T, best_loss, log
    else:
        return best_T,best_loss   
    



def torch_semirexaledGW_linear_regularization(C1:th.Tensor,
                                              p:th.Tensor,
                                              C2:th.Tensor,
                                              init_mode:str='product',
                                              lin_reg:th.Tensor=None,
                                              T_init:th.Tensor=None,
                                              use_log:bool= False,
                                              eps:float=10**(-5),
                                              max_iter:int=1000,
                                              seed:int=0,
                                              verbose:bool=False,
                                              tensor_type:type=th.float64,
                                              device:str='cpu'):
    r"""
        Frank-Wolfe solver for semi-relaxed GW (square loss) with an extra linear term:
            \min_{T >= 0, T1 = p} <L(C_1, C_2) \otimes T, T> + <lin_reg, T>

        Used as the inner solver of the majorization-minimization scheme for a
        concave sparsity-promoting penalty on the unmixing \overline{h} = T^\top 1,
        where lin_reg is the gradient of that penalty at the previous outer iterate.

        Parameters
        ----------
        C1, C2 : structure matrices of sizes (N1, N1) and (N2, N2). The line
            search uses <G(X), T> in place of <G(T), X>, which is only exact
            when C1 and C2 are symmetric.
        p : first marginal (row sums of T), length N1.
        init_mode : passed to torch_initializer_semirelaxedGW when T_init is None.
        lin_reg : tensor broadcastable to (N1, N2), or a scalar; None means 0.
        T_init : optional (N1, N2) starting plan (detached and cloned).
        use_log : if True, also return {'loss': [...]} with the total objective
            before the first iteration and after each iteration.
        eps : stop once the relative change of the objective is <= eps.
        max_iter : maximum number of Frank-Wolfe iterations.
        seed : forwarded to the initializer.
        verbose : if True, print a message when called.
        tensor_type, device : dtype/device of the initial plan (when T_init is
            None) and of the helper ones vectors.

        Returns
        -------
        (best_T, best_loss, saved_loss_unreg) or the same plus log: the loop
        iterate with the lowest total objective, that objective, and its GW
        (unregularized) part, as Python floats (np.inf if no iteration runs).
    """
    if verbose:
        print('call semi-relaxed GW with linear reg')
    N1 = C1.shape[0]
    N2 = C2.shape[0]
    if lin_reg is None:
        # No linear term: plain semi-relaxed GW (previously crashed on None).
        lin_reg = 0.
    if T_init is None:
        # Forward dtype/device: the initializer defaults to 'cuda:0'.
        T = torch_initializer_semirelaxedGW(init_mode, p, N1, N2, seed=seed,
                                            tensor_type=tensor_type, device=device)
    else:
        assert list(T_init.shape) == [N1, N2]
        T = T_init.detach().clone()
    best_T = T.detach().clone()
    q = T.sum(axis=0)
    ones_p = th.ones(p.shape[0], dtype=tensor_type, device=device)
    ones_q = th.ones(q.shape[0], dtype=tensor_type, device=device)
    constC, hC1, hC2 = torch_init_matrix_GW2(C1, C2, p, q, ones_p, ones_q)
    # Objective at the starting point, split into GW and linear parts.
    G = torch_GWgrad(constC, hC1, hC2, T)
    current_loss_unreg = 0.5 * th.sum(G * T)
    current_loss_reg = th.sum(lin_reg * T)
    current_loss = current_loss_unreg + current_loss_reg
    # Full gradient of the objective = GW gradient + linear term.
    G += lin_reg
    if use_log:
        log = {}
        log['loss'] = [current_loss.item()]

    best_loss = np.inf
    saved_loss_unreg = np.inf
    convergence_criterion = np.inf
    outer_count = 0
    while (convergence_criterion > eps) and (outer_count < max_iter):
        previous_loss = current_loss.item()
        # 1. Direction finding, row by row: the mass p_i goes to the argmin
        #    columns of the full gradient (split evenly across ties).
        min_, _ = G.min(axis=1)
        X = (G == min_[:, None]).type(tensor_type)
        X *= (p / X.sum(axis=1))[:, None]
        # 2. Exact line search: objective((1-gamma) T + gamma X) = a gamma^2 + b gamma + c
        constCX, hC1X, hC2X = torch_init_matrix_GW2(C1, C2, p, X.sum(axis=0), ones_p, ones_q)
        GX = torch_GWgrad(constCX, hC1X, hC2X, X)  # GW gradient at X, without the linear term
        GXX = 0.5 * th.sum(GX * X)
        X_reg = th.sum(lin_reg * X)
        # GXT = GTX = 0.5 * <G(X), T> only holds when C1 and C2 are symmetric.
        GXT = 0.5 * th.sum(GX * T)
        a = current_loss_unreg + GXX - 2 * GXT
        b = -2 * current_loss_unreg + 2 * GXT + X_reg - current_loss_reg
        if a.item() > 0:
            gamma = min(1, max(0, -b.item() / (2 * a.item())))
        elif a.item() + b.item() < 0:
            gamma = 1
        else:
            gamma = 0
        T = (1 - gamma) * T + gamma * X
        # The GW gradient is affine in T, so it follows the same convex
        # combination. The linear term must be added on the GX side as well,
        # otherwise it would be scaled down by (1 - gamma) at every step.
        G = (1 - gamma) * G + gamma * (GX + lin_reg)
        current_loss += a * (gamma ** 2) + b * gamma
        current_loss_reg = (1 - gamma) * current_loss_reg + gamma * X_reg
        current_loss_unreg = current_loss - current_loss_reg
        outer_count += 1
        if use_log:
            log['loss'].append(current_loss.item())
        # Relative change of the objective, guarded against a zero previous value.
        if previous_loss != 0:
            convergence_criterion = abs(previous_loss - current_loss.item()) / abs(previous_loss)
        else:
            convergence_criterion = abs(previous_loss - current_loss.item()) / abs(previous_loss + 10**(-15))
        if current_loss.item() < best_loss:
            best_loss = current_loss.item()
            saved_loss_unreg = current_loss_unreg.item()
            best_T = T.detach().clone()
    if use_log:
        return best_T, best_loss, saved_loss_unreg, log
    else:
        return best_T, best_loss, saved_loss_unreg




    
def torch_semirexaledGW_majorationminimization_lpl1(C1:th.Tensor,
                                                    h1:th.Tensor,
                                                    C2:th.Tensor,
                                                    init_mode:str='product',
                                                    p_reg:float=1/2,
                                                    lambda_reg:float = 0.001,
                                                    T_init:th.Tensor=None,
                                                    use_log:bool = False,
                                                    use_warmstart:bool=False,
                                                    eps_inner:float=10**(-6),
                                                    eps_outer:float=10**(-6),
                                                    eps_reg:float = 10**(-15),
                                                    max_iter_inner:int =1000,
                                                    max_iter_outer:int =50,
                                                    seed:int=0,
                                                    verbose:bool=False,
                                                    inner_log:bool = False,
                                                    tensor_type:type=th.float64,
                                                    device:str='cpu'):
    r"""
        Majorization-minimization (MM) solver for semi-relaxed GW with a concave
        l_p penalty (0 < p_reg < 1) on the unmixing q = T^\top 1:
            \Omega(T) = lambda_reg * \sum_j (\sum_i T_ij + eps_reg)^{p_reg}
            \min_{T \geq 0, T1 = h_1} <L(C_1, C_2) \otimes T, T> + \Omega(T)

        \Omega is concave in T, so at each outer iteration it is majorized by its
        linearization at the current iterate. The resulting linearly regularized
        problem is solved by torch_semirexaledGW_linear_regularization with
            lin_reg_ij = lambda_reg * p_reg * (q_j + eps_reg)^{p_reg - 1}.
        The first outer iteration uses no linear term.

        Parameters
        ----------
        C1, C2 : structure matrices of sizes (N1, N1) and (N2, N2).
        h1 : first marginal (row sums of T), length N1.
        init_mode, seed : initialization used when no starting plan is given.
        p_reg, lambda_reg, eps_reg : parameters of \Omega above.
        T_init : optional (N1, N2) starting plan for the inner solver.
        use_log : if True, also return a log with 'loss' (total objective) and
            'loss_reg' (\Omega part) after each outer iteration.
        use_warmstart : if True, each inner solve starts from the previous outer
            solution; otherwise every inner solve starts from T_init / init_mode.
        eps_inner, max_iter_inner : stopping rule of the inner solver.
        eps_outer, max_iter_outer : stop once the relative change of the total
            objective is <= eps_outer, or after max_iter_outer outer iterations.
        verbose : forwarded to the inner solver.
        inner_log : if True (and use_log), store inner solver logs in log['inner_loss'].
        tensor_type, device : dtype/device used for initialization and helper tensors.

        Returns
        -------
        (best_T, best_loss) or (best_T, best_loss, log).
    """
    assert 0 < p_reg < 1
    N1 = C1.shape[0]
    N2 = C2.shape[0]
    if T_init is None:
        T = torch_initializer_semirelaxedGW(init_mode, h1, N1, N2, seed=seed, tensor_type=tensor_type, device=device)
    else:
        assert list(T_init.shape) == [N1, N2]
        T = T_init.detach().clone()
    # The first inner problem has no linear term (plain semi-relaxed GW).
    lin_reg = 0
    best_T = T.detach().clone()
    if use_log:
        log = {}
        log['loss'] = []
        log['loss_reg'] = []
        if inner_log:
            log['inner_loss'] = []
    best_loss = np.inf
    # Sentinel for the first convergence test; a 0-d tensor afterwards.
    current_loss = 10**14
    convergence_criterion = np.inf
    outer_count = 0
    while (convergence_criterion > eps_outer) and (outer_count < max_iter_outer):
        # float() accepts both the int sentinel and 0-d tensors
        # (an int has no .item(), which crashed the first iteration).
        previous_loss = float(current_loss)
        inner_outputs = torch_semirexaledGW_linear_regularization(
            C1, h1, C2, init_mode, lin_reg, T_init=T_init, use_log=bool(inner_log),
            eps=eps_inner, max_iter=max_iter_inner, seed=seed, verbose=verbose,
            tensor_type=tensor_type, device=device)
        if inner_log:
            T, _, current_loss_unreg, inner_log_ = inner_outputs
        else:
            T, _, current_loss_unreg = inner_outputs
        if use_warmstart:
            T_init = T.detach().clone()
        q = T.sum(axis=0)
        current_loss_reg = lambda_reg * th.sum((q + eps_reg) ** p_reg)
        current_loss = current_loss_unreg + current_loss_reg
        # Gradient of Omega at T: the same row vector repeated over the N1 rows.
        # Broadcasting q (instead of the former `ones_p @ q[None, :]`, a matmul
        # that fails for N1 > 1) keeps q's dtype and device.
        lin_reg = lambda_reg * p_reg * (q[None, :].expand(N1, -1) + eps_reg) ** (p_reg - 1)

        outer_count += 1
        if use_log:
            log['loss'].append(current_loss.item())
            log['loss_reg'].append(current_loss_reg.item())
            if inner_log:
                log['inner_loss'].append(inner_log_)
        if previous_loss != 0:
            convergence_criterion = abs(previous_loss - current_loss.item()) / abs(previous_loss)
        else:
            convergence_criterion = abs(previous_loss - current_loss.item()) / abs(previous_loss + 10**(-15))
        if current_loss.item() < best_loss:
            best_loss = current_loss.item()
            best_T = T.detach().clone()
    if use_log:
        return best_T, best_loss, log
    else:
        return best_T, best_loss



def _torch_srgw_gradient(C1, C2, p, q, T, ones_p, ones_q, graph_mode):
    """Return the GW gradient at plan T (column marginal q) for 'sym' or 'asym' structures."""
    if graph_mode == 'sym':
        constC, hC1, hC2 = torch_init_matrix_GW2(C1, C2, p, q, ones_p, ones_q)
        return torch_GWgrad(constC, hC1, hC2, T)
    # Asymmetric structures: sum of the contributions of (C1, C2) and (C1^T, C2^T).
    constC, hC1, hC2 = torch_init_matrix_asym(C1, C2, p, q, ones_p, ones_q)
    constCt, hC1t, hC2t = torch_init_matrix_asym(C1.T, C2.T, p, q, ones_p, ones_q)
    subG = torch_GWgrad(constC, hC1, hC2, T)
    subG_T = torch_GWgrad(constCt, hC1t, hC2t, T)
    return subG + subG_T


def entropic_semirelaxedGW(C1:th.Tensor,
                           p:th.Tensor,
                           C2:th.Tensor,
                           gamma_entropy:float,
                           init_mode:str,
                           T_init:th.Tensor=None,
                           use_log:bool= False,
                           eps:float=10**(-5),
                           max_iter:int=1000,
                           seed:int=0,
                           verbose:bool=False,
                           device:str='cpu',
                           tensor_type:type=th.float64,
                           graph_mode='sym',
                           force_learning:bool=False):
    r"""
        Entropic solver for semi-relaxed GW (square loss):
            \min_{T >= 0, T1 = p} <L(C_1, C_2) \otimes T, T>
        Each iteration computes
            K = T_k * exp(-G(T_k) / gamma_entropy),   T_{k+1} = diag(p / K1) K,
        where G(T_k) is the GW gradient at the current plan. Only the row
        marginal p is enforced; q = T^\top 1 is free.

        Parameters
        ----------
        C1, C2 : structure matrices of sizes (N1, N1) and (N2, N2).
        p : first marginal, length N1.
        gamma_entropy : strictly positive entropic parameter.
        init_mode, seed : initialization used when T_init is None.
        T_init : optional (N1, N2) starting plan (cloned).
        use_log : if True, also return {'loss': [...]}.
        eps : stop once the relative change of the loss is <= eps.
        max_iter : maximum number of iterations.
        verbose : print intermediate quantities at each iteration (K and
            scaling are printed after the row-wise stabilizing shift).
        device, tensor_type : device/dtype of the initial plan and helper vectors.
        graph_mode : 'sym' uses the symmetric factorization; 'asym' sums the
            gradients obtained from (C1, C2) and (C1^T, C2^T).
        force_learning : if True, the starting plan also competes for the
            returned best plan; otherwise only loop iterates do.

        Returns
        -------
        (best_T, best_loss) or (best_T, best_loss, log), with best_loss a
        Python float. The tracked loss is the GW term 0.5 * <G, T> (no entropy).
    """
    assert graph_mode in ['sym', 'asym']
    assert gamma_entropy > 0
    N1 = C1.shape[0]
    N2 = C2.shape[0]
    if T_init is None:
        T = torch_initializer_semirelaxedGW(init_mode, p, N1, N2, seed=seed, tensor_type=tensor_type, device=device)
    else:
        assert list(T_init.shape) == [N1, N2]
        T = T_init.clone()
    best_T = T.clone()
    # Gradient and loss at the initial starting point
    q = T.sum(axis=0)
    ones_p = th.ones(p.shape[0], dtype=tensor_type, device=device)
    ones_q = th.ones(q.shape[0], dtype=tensor_type, device=device)
    G = _torch_srgw_gradient(C1, C2, p, q, T, ones_p, ones_q, graph_mode)
    current_loss = 0.5 * th.sum(G * T)
    if use_log:
        log = {}
        log['loss'] = [current_loss.item()]
    # Python float in both cases (it used to be a tensor with force_learning).
    best_loss = current_loss.item() if force_learning else np.inf
    convergence_criterion = np.inf

    outer_count = 0
    while (convergence_criterion > eps) and (outer_count < max_iter):
        previous_loss = current_loss.clone()
        # M_k = G(T_k) - gamma_entropy * log(T_k), so exp(-M_k / gamma) = T_k * exp(-G / gamma)
        M = G - gamma_entropy * th.log(T)
        logits = -M / gamma_entropy
        # Row-wise shift before exp: it cancels in the row normalization below
        # but prevents exp overflow/underflow (inf/NaN plans for small gamma).
        logits = logits - logits.max(dim=1, keepdim=True).values
        K = th.exp(logits)
        scaling = p / K.sum(axis=1)
        if verbose:
            print('current loss: %s / embedding :%s / gradient: %s / K :%s / scaling: %s  /scaling_denominator: %s'%(current_loss,q,M,K,scaling,K.sum(axis=1)))
        # Row rescaling so that T 1 = p; broadcasting avoids a dense N1 x N1 diag matrix.
        T = scaling[:, None] * K
        q = T.sum(axis=0)
        G = _torch_srgw_gradient(C1, C2, p, q, T, ones_p, ones_q, graph_mode)
        current_loss = 0.5 * th.sum(G * T)
        outer_count += 1
        if use_log:
            log['loss'].append(current_loss.item())
        if previous_loss != 0:
            convergence_criterion = abs(previous_loss.item() - current_loss.item()) / abs(previous_loss.item())
        else:
            convergence_criterion = abs(previous_loss.item() - current_loss.item()) / abs(previous_loss.item() + 10**(-15))
        if current_loss.item() < best_loss:
            best_loss = current_loss.item()
            best_T = T.clone()
    if use_log:
        return best_T, best_loss, log
    else:
        return best_T, best_loss
    