import numpy as np
import torch

###################
## L2 score      ##
###################

def l2_score(T,P):
    """Return (L2 error, fit score) of P against T, picking the backend from T.

    Only the type of T is inspected: a torch.Tensor goes to l2_score_torch,
    anything else goes to l2_score_np. P is passed through unchanged.
    """
    backend = l2_score_torch if isinstance(T, torch.Tensor) else l2_score_np
    return backend(T,P)

def l2_score_np(T,P):
    """L2 reconstruction error and relative fit score for NumPy arrays.

    Parameters
    ----------
    T : reference array (any number of dimensions).
    P : approximation of T, broadcast-compatible with T.

    Returns
    -------
    (L2score, fitscore) where
        L2score  = sqrt(sum((T - P)**2)) over all entries,
        fitscore = 1 - L2score / sqrt(sum(T**2)).

    No guard is applied for an all-zero T; the division then follows NumPy's
    usual floating-point rules.
    """
    # With ord left at its default, np.linalg.norm flattens the input and
    # returns the 2-norm of all entries. For a matrix this equals the
    # Frobenius norm, but unlike ord='fro' it also accepts 1-D and >2-D input,
    # matching what the torch implementation does.
    L2score = np.linalg.norm(T - P)
    fitscore = 1 - L2score / np.linalg.norm(T)
    return L2score, fitscore

def l2_score_torch(T,P):
    """L2 reconstruction error and relative fit score for torch tensors.

    Returns the pair (L2score, fitscore) with
        L2score  = 2-norm of (T - P),
        fitscore = 1 - L2score / (2-norm of T).
    Both values are torch tensors.
    """
    residual = T - P
    err_norm = torch.norm(residual, p=2)
    ref_norm = torch.norm(T, p=2)
    return err_norm, 1 - err_norm / ref_norm



###################
## KL divergence ##
###################

def kl_div(T,P):
    """KL divergence between tensors T and P (inputs may be non-normalized).

    The implementation is chosen from the type of T alone: kl_div_torch for a
    torch.Tensor, kl_div_np otherwise.
    """
    impl = kl_div_torch if isinstance(T, torch.Tensor) else kl_div_np
    return impl(T,P)

def kl_div_np(T,P):
    """Generalized KL divergence between NumPy arrays T and P.

    Computes
        sum(T * log(T / P)) - sum(T) + sum(P),
    which reduces to the ordinary KL divergence when both inputs sum to 1.

    Parameters
    ----------
    T, P : arrays of matching (broadcast-compatible) shape.

    Returns
    -------
    NumPy scalar holding the divergence.

    Notes
    -----
    Zero entries get no special treatment: an element with T == 0 evaluates
    0 * log(0) and therefore yields nan.
    """
    log_ratio_term = np.sum(T * np.log(T / P))
    # The mass-correction terms are scalars and must be added exactly once.
    # Keeping them inside the element-wise sum would broadcast them and count
    # them once per element.
    return log_ratio_term - np.sum(T) + np.sum(P)
    
def kl_div_torch(T,P):
    """Generalized KL divergence between torch tensors T and P.

    Computes
        sum(T * log(T / P)) - sum(T) + sum(P),
    which reduces to the ordinary KL divergence when both inputs sum to 1.

    Parameters
    ----------
    T, P : tensors of matching (broadcast-compatible) shape.

    Returns
    -------
    0-dim torch tensor holding the divergence.

    Notes
    -----
    Zero entries get no special treatment: an element with T == 0 evaluates
    0 * log(0) and therefore yields nan.
    """
    log_ratio_term = torch.sum(T * torch.log(T / P))
    # The mass-correction terms are scalars and must be added exactly once.
    # Keeping them inside the element-wise sum would broadcast them and count
    # them once per element.
    return log_ratio_term - torch.sum(T) + torch.sum(P)

def inv_kl_div(T,P):
    """Reverse KL divergence: kl_div evaluated with the arguments swapped.

    Returns kl_div(P, T). As with kl_div, input might be non-normalized.
    """
    first, second = P, T
    return kl_div(first, second)

######################
## Renyi divergence ##
######################

def renyi_div(T,P,α):
    """Renyi divergence of order α from T to P.

    T and P are expected to be normalized (probability tensors).
    Validation of that is left to the backend: only renyi_div_np checks its
    inputs; renyi_div_torch performs no normalization check.
    See the definition in https://arxiv.org/pdf/1206.2459

    The backend is chosen from the type of T alone: renyi_div_torch for a
    torch.Tensor, renyi_div_np otherwise.
    """
    impl = renyi_div_torch if isinstance(T, torch.Tensor) else renyi_div_np
    return impl(T,P,α)

def renyi_div_torch(T,P,α):
    """Renyi divergence of order α from T to P for torch tensors.

    Parameters
    ----------
    T, P : tensors of the same shape, expected to be probability tensors.
           No normalization or non-negativity check is performed here.
    α    : order of the divergence.

    Returns
    -------
    0-dim torch tensor:
        α == 0 : -log( sum of P over the entries where T > 0 )
        α == 1 : kl_div(T, P)
        else   : 1/(α-1) * log( sum( T**α * P**(1-α) ) )
    """
    if α == 0.0:
        # D_0(T||P) = -log P(T > 0): only the mass of P on the support of T
        # counts. This is also the α -> 0+ limit of the general formula,
        # because T**α vanishes wherever T == 0.
        return -torch.log( torch.sum( P[T > 0] ) )
    if α == 1.0:
        # Order 1 is defined as the KL divergence (the limit α -> 1).
        return kl_div(T,P)
    return 1.0/(α-1.0) * torch.log( torch.sum( T**α * P**(1-α) ) )

def renyi_div_np(T,P,α):
    """Renyi divergence of order α from T to P for NumPy arrays.

    Parameters
    ----------
    T, P : arrays of the same shape; both must pass check_prob
           (non-negative and summing to 1 within its tolerance).
    α    : order of the divergence (int or float).

    Returns
    -------
    np.nan if T or P is not a probability tensor, otherwise
        α == 0 : -log( sum of P over the entries where T > 0 )
        α == 1 : kl_div(T, P)
        else   : 1/(α-1) * log( sum( T**α * P**(1-α) ) )
    """
    # Invalid input is reported by returning nan. An `assert` is not used for
    # this: it would be stripped under `python -O`, so the outcome would
    # depend on the interpreter flags.
    if not (check_prob(T) and check_prob(P)):
        return np.nan

    # Force a float order; presumably so the exponents below are always
    # floating-point even when the caller passes an int.
    α = α * 1.0

    if α == 0.0:
        # D_0(T||P) = -log P(T > 0): only the mass of P on the support of T
        # counts. This is also the α -> 0+ limit of the general formula,
        # because T**α vanishes wherever T == 0.
        return -np.log( np.sum( P[T > 0] ) )
    if α == 1.0:
        # Order 1 is defined as the KL divergence (the limit α -> 1).
        return kl_div(T,P)
    return 1.0/(α-1.0) * np.log( np.sum( T**α * P**(1-α) ) )

def mix_renyi_div(T,P,αs,weights=None):
    """Weighted mixture of Renyi divergences from T to P.

        w_1 Dα1(T,P) + w_2 Dα2(T,P) + ... + w_K DαK(T,P)

    αs      : either a single int/float order (the plain renyi_div is then
              returned) or a sized sequence of K orders.
    weights : indexable collection of K weights; when None, every order gets
              the uniform weight 1/K.
    """
    # A bare number is not a mixture: hand it straight to renyi_div.
    if isinstance(αs, (int, float)):
        return renyi_div(T,P,αs)

    n_orders = len(αs)
    if weights is None:
        # Default weights are uniform, 1/K each.
        weights = [1/n_orders for _ in range(n_orders)]

    total = 0
    for idx, order in enumerate(αs):
        total += weights[idx] * renyi_div(T,P,order)
    return total

######################
## alpha divergence ##
######################

def alpha_div(T,P,α):
    """Alpha divergence of order α between tensors T and P.

    Input might be non-normalized. The backend is chosen from the type of T
    alone: alpha_div_torch for a torch.Tensor, alpha_div_np otherwise.
    """
    impl = alpha_div_torch if isinstance(T, torch.Tensor) else alpha_div_np
    return impl(T,P,α)
        
def alpha_div_np(T,P,α):
    """Alpha divergence of order α between NumPy arrays T and P.

    α == 1 returns kl_div(T, P) and α == 0 returns inv_kl_div(T, P).
    Any other order evaluates
        ( α*sum(T) + (1-α)*sum(P) - sum(T**α * P**(1-α)) ) / ( α*(1-α) ).
    """
    # The two limiting orders are delegated to the KL implementations.
    if α == 1.0:
        return kl_div(T,P)
    if α == 0.0:
        return inv_kl_div(T,P)

    mass_T = α * np.sum( T )
    mass_P = (1-α) * np.sum( P )
    overlap = np.sum( T**α * P**(1-α) )
    scale = 1.0 / ( α*(1-α) )
    return scale * (mass_T + mass_P - overlap)

def alpha_div_torch(T,P,α):
    """Alpha divergence of order α between torch tensors T and P.

    α == 1 returns kl_div(T, P) and α == 0 returns inv_kl_div(T, P).
    Any other order evaluates
        ( α*sum(T) + (1-α)*sum(P) - sum(T**α * P**(1-α)) ) / ( α*(1-α) ).
    """
    # The two limiting orders are delegated to the KL implementations.
    if α == 1.0:
        return kl_div(T,P)
    if α == 0.0:
        return inv_kl_div(T,P)

    mass_T = α * torch.sum( T )
    mass_P = (1-α) * torch.sum( P )
    overlap = torch.sum( T**α * P**(1-α) )
    scale = 1.0 / ( α*(1-α) )
    return scale * (mass_T + mass_P - overlap)

def mix_alpha_div(T,P,αs,weights=None):
    """Weighted mixture of alpha divergences from T to P.

        w_1 Dα1(T,P) + w_2 Dα2(T,P) + ... + w_K DαK(T,P)

    αs      : either a single int/float order (the plain alpha_div is then
              returned) or a sized sequence of K orders.
    weights : indexable collection of K weights; when None, every order gets
              the uniform weight 1/K.
    """
    # A bare number is not a mixture: hand it straight to alpha_div.
    if isinstance(αs, (int, float)):
        return alpha_div(T,P,αs)

    n_orders = len(αs)
    if weights is None:
        # Default weights are uniform, 1/K each.
        weights = [1/n_orders for _ in range(n_orders)]

    total = 0
    for idx, order in enumerate(αs):
        total += weights[idx] * alpha_div(T,P,order)
    return total

###############
## Utilities ##
###############

def check_prob(T, tol=1.0e-9):
    """Tell whether tensor T is a probability tensor.

    True when T is non-negative and sums to 1 within `tol`, else False.
    """
    if not check_nonnegative(T):
        return False
    return check_normalize(T, tol=tol)

def check_normalize(T, tol=1.0e-9):
    """Tell whether the entries of tensor T sum to 1.

    Returns True when |sum(T) - 1| is strictly below `tol`, else False.
    """
    deviation = np.abs(np.sum(T) - 1.0)
    return True if deviation < tol else False

def check_nonnegative(T):
    """Tell whether every entry of tensor T is non-negative.

    Returns True when the smallest entry of T is >= 0, else False.
    """
    smallest = np.min(T)
    return True if smallest >= 0.0 else False