import sys
# Side effect at import time: prints the interpreter version whenever this
# module is imported or run.
print(sys.version)
import numpy as np
import tensor_op as op
import utils
import utils_Tucker as Tucker
import utils_CP as CP
from itertools import product
from functools import reduce

import importlib
# Re-execute the already-imported helper modules so that edits made to them are
# picked up without restarting the interpreter. Presumably a convenience for
# interactive/notebook development -- TODO confirm this is wanted elsewhere.
importlib.reload(Tucker)
importlib.reload(utils)

def EMTucker(T, rnk, learn_noise=False, verbose=True, max_iter=100,
             verbose_interval=10, noise_update_rule=0):
    """Fit a Tucker model (core ``G`` and factors ``As``) to ``T`` with EM-style updates.

    ``T`` is first rescaled so that its entries sum to one.  Each iteration
    then recomputes the per-rank-index terms ``Q`` and the ratio ``T / P``,
    updates and normalizes the core and every factor column, optionally
    updates a uniform-noise weight, and rebuilds the model tensor ``P`` as
    ``(1 - noise) * Tucker_from_factors(G, As) + noise / T.size``.

    Parameters
    ----------
    T : array_like
        Input tensor.  It is divided element-wise by ``P``, so it presumably
        has to be non-negative with a non-zero total -- TODO confirm.
    rnk : sequence of int
        One rank per mode of ``T``; ``len(rnk)`` must equal ``T.ndim``.
    learn_noise : bool, optional
        If true, the noise weight is re-estimated every iteration; otherwise
        it is forced to ``0``.
    verbose : bool, optional
        If true, print diagnostics every ``verbose_interval`` iterations.
    max_iter : int, optional
        Number of iterations to run (there is no early stopping).
    verbose_interval : int, optional
        Iteration stride between diagnostic prints.
    noise_update_rule : int, optional
        ``0`` weights the two terms of the noise update by the current
        ``noise`` and ``1 - noise``; any other value uses the unweighted ratio
        ``term1 / (term1 + term2)``.

    Returns
    -------
    G : numpy.ndarray
        Core tensor of shape ``rnk``, normalized to sum to one.
    As : list of numpy.ndarray
        One ``(T.shape[d], rnk[d])`` factor per mode; every column sums to one.

    Raises
    ------
    AssertionError
        If ``len(rnk)`` differs from the number of modes of ``T``.
    """
    T = T / np.sum(T)

    tensor_dim = np.ndim(T)
    tensor_size = np.shape(T)
    AbsOmegaI = np.prod(tensor_size)
    assert len(rnk) == tensor_dim, "Rank needs to be vector"

    # Initialization
    # NOTE(review): G and As start unnormalized and the initial P carries no
    # noise term even when learn_noise is true; kept exactly as before.
    As = [np.random.rand(tensor_size[d], rnk[d]) for d in range(tensor_dim)]
    G = np.random.rand(*rnk)
    P = Tucker.Tucker_from_factors(G, As)

    indices_all_rnk = [list(range(rnk[d])) for d in range(tensor_dim)]

    prev_error_nl = np.inf
    # Drawn unconditionally so the random stream seen by callers is the same
    # whether or not the noise weight is learned.
    noise = np.random.rand(1)[0]
    for n_iter in range(max_iter):

        # Q and T / P are frozen here: every update below in this iteration
        # uses these values, not the partially updated G / As.
        Q = Tucker.get_Tucker_Q(G, As)
        T_over_P = T / P

        # M-step
        # Update G
        for r in product(*indices_all_rnk):
            # Plain tuple indexing; star-unpacking inside a subscript
            # (G[*r]) is a SyntaxError before Python 3.11.
            G[r] = np.sum(Q[r] * T_over_P)

        if learn_noise:
            # The unnormalized core entries are exactly the per-r sums needed
            # by the noise update, so add them up now (same C order as
            # product(...)) instead of recomputing every Q[r] * T_over_P.
            unnormalized_core_mass = sum(G.flat)

        # normalize G
        G /= np.sum(G)

        # Update As
        for d in range(tensor_dim):
            axis_to_sum = utils.tuple_skipping_m(tensor_dim, d)
            for rd in range(rnk[d]):
                indices_rnk = utils.get_rnk_indices_for_sum(d, rd, rnk)
                As[d][:, rd] = sum(
                    np.sum(Q[r] * T_over_P, axis=axis_to_sum)
                    for r in product(*indices_rnk)
                )
                # Normalize the column right away; columns are independent
                # because Q and T_over_P are fixed for this iteration.
                As[d][:, rd] /= np.sum(As[d][:, rd])

        if learn_noise:
            # Update noise
            term1 = 1.0 / AbsOmegaI * np.sum(T_over_P)
            term2 = unnormalized_core_mass
            if noise_update_rule == 0:
                noise = noise * term1 / (noise * term1 + (1 - noise) * term2)
            else:
                noise = term1 / (term1 + term2)
        else:
            noise = 0

        # E-step
        # Update P from (G, As)

        # NOTE:
        # Tucker.Tucker_from_factors(G, As) is twice faster than P = sum( Q.values() )
        P = (1 - noise) * Tucker.Tucker_from_factors(G, As) + noise / AbsOmegaI

        if verbose and n_iter % verbose_interval == 0:
            kl_error = utils.KL_div(T, P)
            # Relative error in the 2-norm of the flattened tensors.
            f_error = np.linalg.norm(T - P) / np.linalg.norm(T)
            nl_error = utils.NL(T, P)
            print(n_iter, noise, f_error, kl_error, nl_error)
            # The monotonicity check compares utils.NL values between
            # consecutive reports.
            if prev_error_nl < nl_error:
                print("KL is not monotonicaly decreasing")
            prev_error_nl = nl_error

    return G, As