import sys
import utils_CP as CP
import utils

from numpy import linalg
from functools import reduce
import numpy as np
sys.version

import importlib
importlib.reload(utils)
importlib.reload(CP)

def EMCP(T, rnk, learn_noise=False, verbose=True, max_iter=100,
         verbose_interval=10, noise_update_rule=0, random_state=None):
    """Fit a rank-``rnk`` CP model to a normalized tensor with an EM loop.

    ``T`` is first rescaled to sum to 1. The model tensor is
    ``P = (1 - noise) * CP.CP_from_factors_(P, As) + noise / |Omega|``,
    where ``|Omega|`` is the number of entries of ``T`` and ``noise`` is the
    weight of a uniform component. Each iteration:

    1. recomputes ``Q = CP.get_CP_Q(As)`` and ``T / P``;
    2. updates every factor column (M-step);
    3. optionally re-estimates ``noise``;
    4. rebuilds ``P`` from the new factors.

    The loop always runs ``max_iter`` iterations; there is no convergence test.

    Parameters
    ----------
    T : array_like
        Data tensor. Its total sum must be positive and finite.
        Non-negative entries are presumably expected, since progress is
        judged with ``utils.KL_div``; this is not checked.
    rnk : int
        Number of CP components (columns of every factor matrix).
    learn_noise : bool
        If True, re-estimate the uniform-noise weight every iteration.
        Otherwise ``noise`` is 0 in every E-step.
    verbose : bool
        If True, print progress every ``verbose_interval`` iterations.
        Each line shows the iteration index, the noise weight,
        ``||T - P|| / ||T||`` and ``utils.KL_div(T, P)``. A warning is
        printed when that KL value is larger than the previously printed one.
    max_iter : int
        Number of EM iterations.
    verbose_interval : int
        Print period. It must be non-zero when ``verbose`` is True.
    noise_update_rule : int
        With ``0`` (the default), the update is
        ``noise * t1 / (noise * t1 + (1 - noise) * t2)``. Any other value
        uses ``t1 / (t1 + t2)``.
    random_state : None, int or numpy.random.Generator
        Source of the random initialization. ``None`` (the default) draws
        from the global ``np.random`` state in the same order as before.
        Any other value is passed to ``np.random.default_rng`` for
        reproducible runs.

    Returns
    -------
    list of numpy.ndarray
        Factor matrices; ``As[d]`` has shape ``(T.shape[d], rnk)``.

    Raises
    ------
    ValueError
        If the sum of ``T`` is not positive and finite, or if ``verbose``
        is set while ``verbose_interval == 0``.
    """
    T = np.asanyarray(T)
    total_mass = np.sum(T)
    if not (np.isfinite(total_mass) and total_mass > 0):
        raise ValueError("T must have a positive, finite sum")
    if verbose and verbose_interval == 0:
        raise ValueError("verbose_interval must be non-zero when verbose=True")
    T = T / total_mass

    tensor_dim = np.ndim(T)
    tensor_size = np.shape(T)
    AbsOmegaI = np.prod(tensor_size)  # |Omega|: number of entries of T

    # Initialization: random factor matrices and a random starting noise weight.
    if random_state is None:
        # Global NumPy RNG, drawn in the same order as earlier versions
        # (factors first, then noise) so existing seeded runs are unchanged.
        As = [np.random.rand(tensor_size[d], rnk) for d in range(tensor_dim)]
        noise = np.random.rand(1)[0]
    else:
        rng = np.random.default_rng(random_state)
        As = [rng.random((tensor_size[d], rnk)) for d in range(tensor_dim)]
        noise = rng.random()
    P = CP.CP_from_factors(As)

    prev_error_kl = np.inf
    for n_iter in range(max_iter):
        # NOTE(review): assumes CP.get_CP_Q(As) returns one entry Q[r] per
        # rank component that broadcasts against T. Confirm in utils_CP.
        Q = CP.get_CP_Q(As)
        T_over_P = T / P

        # M-step: update every factor column.
        sum_rnk = [np.sum(T_over_P * Q[r]) for r in range(rnk)]
        total = np.sum(sum_rnk)
        for d in range(tensor_dim):
            axis_to_sum = utils.tuple_skipping_m(tensor_dim, d)
            for r in range(rnk):
                # Assume Q[r] is the r-th rank-one component and axis_to_sum
                # lists every axis except d (TODO confirm both). Then this
                # scaling makes column r of each factor sum to
                # (sum_rnk[r] / total) ** (1 / tensor_dim), so the CP part
                # of P sums to 1.
                As[d][:, r] = np.sum(Q[r] * T_over_P, axis=axis_to_sum) / (
                    total ** (1 / tensor_dim)
                    * sum_rnk[r] ** (1 - 1 / tensor_dim)
                )

        # Noise-weight update.
        if learn_noise:
            # noise * term1 == sum(T * (noise / |Omega|) / P): the data mass
            # the current P attributes to the uniform component.
            term1 = 1 / AbsOmegaI * np.sum(T_over_P)
            # sum_r sum(Q[r] * T / P) is exactly `total` from the M-step.
            # Reuse it instead of recomputing rnk full-tensor products.
            term2 = total
            if noise_update_rule == 0:
                noise = noise * term1 / (noise * term1 + (1 - noise) * term2)
            else:
                noise = term1 / (term1 + term2)
        else:
            noise = 0

        # E-step: rebuild P as a mixture of the CP model and uniform noise.
        # NOTE(review): initialization uses CP.CP_from_factors(As), while this
        # line uses CP.CP_from_factors_(P, As), which also receives the old P.
        # It may reuse P as an output buffer; confirm in utils_CP.
        P = (1 - noise) * CP.CP_from_factors_(P, As) + noise / AbsOmegaI

        if verbose and n_iter > 0 and n_iter % verbose_interval == 0:
            kl_error = utils.KL_div(T, P)
            print(n_iter, noise, np.linalg.norm(T-P)/np.linalg.norm(T), kl_error)
            if prev_error_kl < kl_error:
                print("KL is not monotonically decreasing")
            prev_error_kl = kl_error

    return As