import numpy as np
#import tensorly as tl
from functools import reduce
from itertools import product

def cp_n_params(tensor_size, rnk):
    """
    Count the parameters of a CP structure.

    Every mode ``d`` contributes ``tensor_size[d] - 1`` parameters per rank
    component, so the total is ``rnk * sum(tensor_size[d] - 1)``.
    """
    reduced_sizes = np.asarray(tensor_size) - 1
    return reduced_sizes.sum() * rnk

def tucker_n_params(tensor_size, rnk):
    """
    Count the parameters of a Tucker structure.

    The core contributes ``prod(rnk)`` entries, and each mode ``d`` adds a
    ``tensor_size[d] x rnk[d]`` factor matrix.
    """
    core_params = np.array(rnk).prod()

    factor_params = 0
    for d, mode_size in enumerate(tensor_size):
        factor_params += mode_size * rnk[d]

    return core_params + factor_params

def train_n_params(tensor_size, rnk):
    """
    Count the parameters of a tensor-train structure.

    With ``D = len(tensor_size)`` the per-core sizes are:
      * first core (d == 0):   ``tensor_size[0] * rnk[0]``
      * last core (d == D-1):  ``rnk[D-2] * tensor_size[D-1]``
      * middle cores:          ``rnk[d-1] * tensor_size[d] * rnk[d]``
    The ``d == 0`` case is checked first, so a single-mode input uses the
    first-core formula. An empty ``tensor_size`` gives 0.
    """
    last = len(tensor_size) - 1

    def core_size(d):
        if d == 0:
            return tensor_size[0] * rnk[0]
        if d == last:
            return rnk[d - 1] * tensor_size[d]
        return rnk[d - 1] * tensor_size[d] * rnk[d]

    return sum(core_size(d) for d in range(last + 1))

def tuple_skipping_m(N, m):
    """
    Return ``(0, 1, ..., N-1)`` with the index ``m`` left out.

    For example, this gives the axes of an ``N``-dimensional array other
    than axis ``m``. If ``m`` is not in ``range(N)``, nothing is removed.

    Examples
    --------
    >>> tuple_skipping_m(5, 2)
    (0, 1, 3, 4)
    >>> tuple_skipping_m(7, 3)
    (0, 1, 2, 4, 5, 6)
    >>> tuple_skipping_m(4, 1)
    (0, 2, 3)
    """
    return tuple(i for i in range(N) if i != m)

def NL(P,T, avoid_nan=False):
    """
    Return ``-sum(P * log(T))``.

    If ``avoid_nan`` is true, entries where ``P`` is zero are dropped before
    the sum. A zero in ``P`` paired with a zero in ``T`` would otherwise
    produce ``0 * log(0) = nan``. The dropping uses a boolean mask built
    from ``P``, so ``P`` and ``T`` must be NumPy arrays of matching shape.
    """
    if not avoid_nan:
        return - np.sum(P * np.log(T))

    support = P != 0
    return - np.sum(P[support] * np.log(T[support]))

def KL_div(P, T, avoid_nan=False):
    """
    Generalized Kullback-Leibler divergence from tensor P to tensor T:

        sum(P * log(P / T)) - sum(P) + sum(T)

    P and T should be positive and need not sum to 1. The trailing
    ``- sum(P) + sum(T)`` terms handle unnormalized inputs.

    If ``avoid_nan`` is true, the log term is summed only over entries where
    ``P`` is nonzero, so zeros in P do not yield nan. The ``sum(P)`` and
    ``sum(T)`` corrections still use every entry. A boolean mask is built
    from ``P``, so P and T must be NumPy arrays of matching shape.
    """
    if avoid_nan:
        nonzero = P != 0
        P_nz, T_nz = P[nonzero], T[nonzero]
        log_term = np.sum(P_nz * np.log(P_nz / T_nz))
    else:
        log_term = np.sum(P * np.log(P / T))
    return log_term - np.sum(P) + np.sum(T)

def inv_KL_div(P, T, avoid_nan=False):
    """
    Reverse generalized KL divergence, from tensor T to tensor P.

    Calls ``KL_div`` with the two tensors swapped. ``avoid_nan`` is passed
    through unchanged, so in this direction it skips entries where T is zero.
    """
    source, target = T, P
    return KL_div(source, target, avoid_nan=avoid_nan)

######################
## alpha divergence ##
######################

def alpha_div_sparse(T,P,alpha,avoid_nan=False):
    """
    Alpha divergence between two tensors whose entries are read via ``.values``.

    * ``alpha == 1``: ``KL_div(T.values, P.values)`` (KL from T to P).
    * ``alpha == 0``: ``inv_KL_div(T.values, P.values)``.
    * otherwise: ``(np.size(T) - sum(T.values**alpha * P.values**(1-alpha)))
      / (alpha * (1 - alpha))``.

    NOTE(review): presumably T and P are pandas objects holding stored
    (nonzero) entries over a shared support. Confirm with callers.
    """
    if alpha == 1.0:
        return KL_div(T.values,P.values,avoid_nan=avoid_nan)
    if alpha == 0.0:
        return inv_KL_div(T.values,P.values,avoid_nan=avoid_nan)

    # NOTE(review): the dense ``alpha_div`` uses alpha*sum(T) + (1-alpha)*sum(P)
    # here, while this uses the element count np.size(T). The two agree only
    # when alpha*sum(T.values) + (1-alpha)*sum(P.values) == np.size(T).
    # Confirm this normalization is intended.
    n_entries = np.size(T)
    overlap = np.sum(T.values**alpha * P.values**(1 - alpha))
    return 1.0 / (alpha * (1 - alpha)) * (n_entries - overlap)

def alpha_div(T,P,α,avoid_nan=False):
    """
    Alpha divergence from tensor T to tensor P.

    * ``α == 1``: generalized KL divergence ``KL_div(T, P)``.
    * ``α == 0``: reverse KL ``inv_KL_div(T, P)``.
    * otherwise:
      ``(α*sum(T) + (1-α)*sum(P) - sum(T**α * P**(1-α))) / (α*(1-α))``.

    ``avoid_nan`` is forwarded only to the KL branches.
    """
    if α == 1.0:
        return KL_div(T,P,avoid_nan=avoid_nan)
    if α == 0.0:
        return inv_KL_div(T,P,avoid_nan=avoid_nan)

    weighted_T = α * np.sum(T)
    weighted_P = (1 - α) * np.sum(P)
    cross = np.sum(T**α * P**(1 - α))
    scale = 1.0 / (α * (1 - α))
    return scale * (weighted_T + weighted_P - cross)


def Fnorm(P, T):
    """ Frobenius norm of the difference between tensors P and T.

    Returns ``sqrt(sum((P - T) ** 2))`` taken over all entries, for inputs
    with any number of dimensions. P and T must have the same shape, or
    shapes that broadcast together under ``P - T``.
    """
    # NumPy rather than tensorly: this module's ``import tensorly as tl`` is
    # disabled, so ``tl.norm`` would raise NameError. With ord=None and
    # axis=None, np.linalg.norm returns the 2-norm of the flattened array,
    # which is the Frobenius norm for any ndim.
    return np.linalg.norm(P - T)

def get_rnk_indices_for_sum(k, ik, rnk):
    """
    Build per-axis index lists that select every rank index whose k-th entry is ik.

    The result has one list per entry of ``rnk``: ``[ik]`` at position
    ``k`` and ``list(range(rnk[d]))`` at every other position ``d``.
    Passing it to ``itertools.product(*...)`` enumerates every rank-index
    tuple whose k-th component equals ``ik``.

    Example
    -------
    >>> get_rnk_indices_for_sum(0, 1, [2, 2, 2])
    [[1], [0, 1], [0, 1]]
    >>> list(product(*get_rnk_indices_for_sum(0, 1, [2, 2, 2])))
    [(1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)]
    """
    return [[ik] if d == k else list(range(n_d)) for d, n_d in enumerate(rnk)]
