import numpy as np
import utils
import math
import utils_sum as us
from itertools import product
import sys
sys.path.append("../../data")
import dataset_info
import utils_CP as CP

import sp_tensor
import importlib
importlib.reload(us)
importlib.reload(utils)
importlib.reload(dataset_info)


def EMCP_sparse(T, rnk, learn_noise=False, verbose=True, max_iter=100, 
                verbose_interval=1, noise_update_rule=0, tol=1.0e-4,
                conv_check_interval=10):
    """Fit a rank-``rnk`` CP model, optionally mixed with uniform noise, to a
    sparse tensor using an EM-style iteration.

    On the observed coordinates the model is
    ``P = (1 - noise) * sum_r prod_d A[r][d][i_d] + noise / |Omega_I|``,
    where ``|Omega_I|`` is the product of ``T.tensor_size``.

    NOTE: ``T.normalize()`` is called on the argument, so the caller's ``T``
    is modified in place.

    Parameters
    ----------
    T : sp_tensor.Sp_tensor
        Observed sparse tensor. Uses ``coords`` (indexed as ``coords[n][d]``),
        ``values``, ``tensor_size``, ``tensor_dim``, ``nnz`` and ``normalize()``.
    rnk : int
        Number of CP components.
    learn_noise : bool
        If True the noise weight is re-estimated every iteration; otherwise it
        is fixed to 0.
    verbose : bool
        Print progress messages.
    max_iter : int
        The loop runs for iterations ``0 .. max_iter`` (``max_iter + 1`` passes)
        unless convergence is detected earlier.
    verbose_interval : int
        When ``verbose``, print the NL error every this many iterations.
    noise_update_rule : int
        0: multiplicative update that uses the previous noise value;
        any other value: ``term1 / (term1 + term2)``.
    tol : float
        Stop when the NL change, averaged over ``conv_check_interval``
        iterations, drops below this value.
    conv_check_interval : int
        Check convergence every this many iterations (only after iteration 3).

    Returns
    -------
    A : dict
        ``A[r]`` is a list with one factor vector ``A[r][d]`` per mode ``d``.
    noise : float
        Final noise weight (0 when ``learn_noise`` is False).
    """
    T.normalize()
    if verbose:
        print("normalized done")

    tensor_size = T.tensor_size
    tensor_dim  = T.tensor_dim
    N = T.nnz
    # Total number of cells of the full tensor. The product is taken over
    # Python ints: a product of fixed-width numpy ints could wrap silently for
    # large, high-order tensors.
    AbsOmegaI = math.prod(int(s) for s in tensor_size)
    # Observed coordinates as an (N, tensor_dim) integer array, hoisted so the
    # Q update can gather factor entries with vectorised indexing.
    coords = np.asarray(T.coords).reshape(N, tensor_dim)

    # Initialization: random model P (passed normalize=True) and random
    # per-component values Q[r] on the observed coordinates.
    P = sp_tensor.Sp_tensor(T.coords, np.random.rand(N), tensor_size, normalize=True)
    Q = {r: sp_tensor.Sp_tensor(T.coords, np.random.rand(N), tensor_size, check_empty=False)
         for r in range(rnk)}
    M = {r: sp_tensor.Sp_tensor(T.coords, Q[r].values * T.values / P.values, tensor_size,
                                check_empty=False)
         for r in range(rnk)}
    A = {r: [] for r in range(rnk)}  # A[r][d]: dense factor vector for component r, mode d

    if verbose:
        print("initialized done")

    prev_error_nl = np.inf
    prev_error_nl_for_conv = 1.0e+10
    noise = np.random.rand(1)[0]
    for n_iter in range(max_iter + 1):

        # update M: M[r] = T * Q[r] / P on the observed coordinates.
        for r in range(rnk):
            M[r].values = T.values * Q[r].values / P.values
        Mr_sums = [np.sum(M[r].values) for r in range(rnk)]
        total = np.sum(Mr_sums)

        # update A.
        # A[r][d][i_d]: r = rank (0..rnk-1), d = mode (0..tensor_dim-1),
        # i_d = index along mode d (0..tensor_size[d]-1).
        # NOTE(review): assumes us.reduce_sum_each_dim returns, for each mode d,
        # a pair whose element [1] is the dense vector of sums of M[r] over all
        # other modes -- confirm against utils_sum.
        for r in range(rnk):
            sums_results = us.reduce_sum_each_dim(M[r].coords, M[r].values, tensor_dim)
            A[r] = [sums_results[d][1] * (total**(-1/tensor_dim)) * (Mr_sums[r])**(1/tensor_dim-1)
                    for d in range(tensor_dim)]

        # update Q.
        # A is dense, so Q could be dense too, but only its values on T.coords
        # are needed; Q is therefore kept as a sparse tensor. Values are
        # accumulated in the same left-to-right order as math.prod and written
        # back into Q[r].values in place.
        for r in range(rnk):
            q = np.ones(N)
            for d in range(tensor_dim):
                q = q * A[r][d][coords[:, d]]
            Q[r].values[:] = q

        # update noise
        if learn_noise:
            T_over_P = T.values / P.values
            term1 = 1.0 / AbsOmegaI * np.sum(T_over_P)
            # sum_n T_over_P[n] * sum_r Q[r][n] equals sum_r sum_n M[r][n] = total.
            term2 = total
            if noise_update_rule == 0:
                noise = noise * term1 / (noise * term1 + (1 - noise) * term2)
            else:
                noise = term1 / (term1 + term2)
        else:
            noise = 0

        # update P
        P.values = (1 - noise) * sum(Q[r].values for r in range(rnk)) + noise / AbsOmegaI

        # Progress report and convergence check share one NL evaluation.
        report = verbose and n_iter > 0 and n_iter % verbose_interval == 0
        check_conv = n_iter > 3 and n_iter % conv_check_interval == 0
        if report or check_conv:
            nl_error = utils.NL(T.values, P.values)

            if report:
                print(n_iter, noise, nl_error)
                if prev_error_nl < nl_error:
                    print("NL error is not monotonically decreasing...")
                prev_error_nl = nl_error

            if check_conv:
                res = abs(prev_error_nl_for_conv - nl_error) / conv_check_interval
                if res < tol:
                    break
                prev_error_nl_for_conv = nl_error

    return A, noise

def sparse_CP_from_A_with_noise(A, noise, indices):
    """Evaluate the noise-mixed CP model at a sequence of multi-indices.

    For every ``idx`` in ``indices`` the returned value is
    ``(1 - noise) * sparse_CP_from_A(A, idx) + noise / |Omega_I|``, where
    ``|Omega_I|`` is the product of the factor lengths of ``A[0]``.

    Parameters
    ----------
    A : list or dict
        ``A[r][d]`` is the factor vector of component ``r`` for mode ``d``.
    noise : float
        Weight of the uniform noise component.
    indices : sized iterable of index tuples
        Cells at which to evaluate the model.

    Returns
    -------
    numpy.ndarray
        float64 array with one value per entry of ``indices``.
    """
    first_factors = A[0]
    mode_lengths = np.array(
        [len(first_factors[d]) for d in range(len(first_factors))],
        dtype=np.float64,
    )
    # The cell count is computed in float64 because the lengths are stored as floats.
    n_cells = math.prod(mode_lengths)
    uniform_mass = noise / n_cells

    values = np.zeros(len(indices))
    for position, multi_index in enumerate(indices):
        values[position] = (1 - noise) * sparse_CP_from_A(A, multi_index) + uniform_mass
    return values

def sparse_CP_from_A(A, idx):
    """Return the value of the CP tensor built from ``A`` at multi-index ``idx``.

    Computes ``sum_r prod_d A[r][d][idx[d]]``. ``A`` may be a list or a dict
    keyed ``0 .. len(A)-1``; components are always accessed as ``A[r]``.
    """
    n_modes = len(idx)
    per_component = np.array(
        [math.prod(A[r][d][idx[d]] for d in range(n_modes)) for r in range(len(A))],
        dtype=np.float64,
    )
    return sum(per_component)

def sparse_CP_to_dense(A):
    """Materialise the CP tensor ``sum_r outer(A[r][0], ..., A[r][D-1])`` densely.

    ``A`` may be a list or a dict keyed ``0 .. len(A)-1``; the tensor shape is
    taken from the factor lengths of ``A[0]``.

    Each rank-1 term is built with chained outer products starting from 1.0,
    so entries are multiplied in the same left-to-right order as an
    element-wise ``math.prod``. Terms are summed in rank order, avoiding a
    Python-level loop over every cell.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(len(A[0][0]), ..., len(A[0][D-1]))``.
    """
    tensor_dim = len(A[0])
    tensor_shape = tuple(len(A[0][d]) for d in range(tensor_dim))
    dense_CP = np.zeros(tensor_shape)
    for r in range(len(A)):
        term = np.ones(())  # 0-d seed: multiplicative identity
        for d in range(tensor_dim):
            term = np.multiply.outer(term, np.asarray(A[r][d], dtype=np.float64))
        dense_CP += term
    return dense_CP

def sparse_CP_total_sum(A):
    """Return the sum of all entries of the CP tensor built from ``A``.

    By distributivity,
    ``sum_{i_1..i_D} sum_r prod_d A[r][d][i_d] = sum_r prod_d sum_i A[r][d][i]``.
    The total is therefore computed in ``O(R * sum_d I_d)`` time instead of
    enumerating every cell of the dense tensor. ``A`` may be a list or a dict
    keyed ``0 .. len(A)-1``; the number of modes is taken from ``A[0]``.
    """
    tensor_dim = len(A[0])
    return sum(
        math.prod(np.sum(A[r][d]) for d in range(tensor_dim))
        for r in range(len(A))
    )
