import numpy as np
#import tensorly as tl
from functools import reduce
from itertools import product

def tuple_skipping_m(N, m):
    """Return the tuple ``(0, 1, ..., N-1)`` with the value ``m`` left out.

    If ``m`` is not in ``range(N)``, nothing is skipped and the result is
    simply ``tuple(range(N))``.

    Examples
    --------
    >>> tuple_skipping_m(5, 2)
    (0, 1, 3, 4)
    >>> tuple_skipping_m(7, 3)
    (0, 1, 2, 4, 5, 6)
    >>> tuple_skipping_m(4, 1)
    (0, 2, 3)
    """
    return tuple(i for i in range(N) if i != m)

def NL(P, T):
    """Return ``-sum(P * log(T))``, summed over every entry.

    Presumably "NL" stands for negative log-likelihood (the cross-entropy
    of T weighted by P) — TODO confirm with callers. T should be positive
    wherever P is non-zero; otherwise the logarithm produces -inf/nan.
    """
    log_T = np.log(T)
    weighted = np.multiply(P, log_T)
    return -np.sum(weighted)

def KL_div(P, T):
    """Generalized KL divergence from tensor P to tensor T.

    Computes ``sum(P * log(P / T)) - sum(P) + sum(T)``. Because of the
    ``- sum(P) + sum(T)`` correction, the totals of P and T do not have to
    be 1 (they may be larger).

    Entries where P is exactly zero contribute 0 to the first sum (the usual
    ``0 * log 0 = 0`` convention) instead of turning the result into NaN.
    T must be positive wherever P is positive; a zero there yields ``inf``.

    Parameters
    ----------
    P, T : array_like
        Non-negative tensors of the same (or broadcastable) shape.

    Returns
    -------
    NumPy scalar with the divergence value.
    """
    P = np.asarray(P)
    T = np.asarray(T)
    # Evaluate P / T only where P is non-zero. Elsewhere the ratio stays at 1,
    # so its log is 0 and the term P * log(ratio) is 0 rather than 0 * nan.
    ratio = np.ones(np.broadcast(P, T).shape, dtype=np.result_type(P, T, 1.0))
    np.divide(P, T, out=ratio, where=(P != 0))
    return np.sum(P * np.log(ratio)) - np.sum(P) + np.sum(T)

def Fnorm(P, T):
    """Frobenius norm of the difference between tensors P and T.

    Returns ``sqrt(sum((P - T) ** 2))`` taken over all entries, for tensors
    of any order. P and T must have the same shape (or shapes that
    broadcast against each other).
    """
    # Computed with NumPy directly: the module-level tensorly import is
    # commented out, so a tl.norm call would raise NameError.
    diff = np.asarray(P) - np.asarray(T)
    # With ord=None and axis=None, np.linalg.norm returns the 2-norm of the
    # flattened array, which is the Frobenius norm for any ndim.
    return np.linalg.norm(diff)

def get_rnk_indices_for_sum(k, ik, rnk):
    """Per-dimension index choices for every rank vector whose k-th entry is ik.

    Returns one list per dimension ``d`` of ``rnk``:
    ``[ik]`` when ``d == k``, otherwise ``[0, 1, ..., rnk[d] - 1]``.
    Feeding the result to ``itertools.product(*...)`` enumerates all index
    vectors with ``ik`` at position ``k``. If ``k`` is not a valid dimension,
    no position is fixed and every dimension gets its full range.

    Parameters
    ----------
    k : int
        Dimension whose index is held fixed.
    ik : int
        Value used for dimension ``k``.
    rnk : sequence of int
        Size of each dimension.

    Returns
    -------
    list of list of int

    Examples
    --------
    >>> get_rnk_indices_for_sum(0, 1, [2, 2, 2])
    [[1], [0, 1], [0, 1]]
    >>> list(product(*get_rnk_indices_for_sum(0, 1, [2, 2, 2])))
    [(1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)]
    """
    return [[ik] if d == k else list(range(size)) for d, size in enumerate(rnk)]
