import sys
import utils_train as train
import utils
from itertools import product
import os
import numpy as np

import importlib
# Re-execute the utils_train module's code so that edits to it take effect
# when this file is re-run inside an already-running interpreter (presumably
# for interactive/notebook development).
importlib.reload(train)

def _normalize_cores(G):
    """Rescale each core in place so every slice ``core[:, :, rd]`` sums to 1."""
    for core in G:
        core /= np.sum(core, axis=(0, 1), keepdims=True)


def EMTrain(T, R, learn_noise=False, verbose=True, max_iter=10, 
            verbose_interval=1, noise_update_rule=0):
    """Fit a nonnegative tensor-train model to ``T`` with the EM algorithm.

    ``T`` is rescaled to sum to one and approximated by

        P = (1 - noise) * Q + noise / |Omega|,

    where ``Q = train.train_from_cores(G)``, ``|Omega|`` is the number of
    entries of ``T``, and core ``G[d]`` has shape ``(r[d-1], J[d], r[d])``
    with boundary ranks equal to 1. After every update each slice
    ``G[d][:, :, rd]`` sums to one. Assuming ``train_from_cores`` contracts
    the cores in order, this normalization makes ``Q`` sum to one as well.

    Parameters
    ----------
    T : array_like
        Nonnegative tensor with ``D`` axes; it does not need to be normalized.
    R : sequence of int
        Internal TT ranks. At least ``D - 1`` entries are required; any
        extra trailing entries are ignored.
    learn_noise : bool
        If True, the weight of the uniform-noise component is estimated by
        EM. Otherwise it is fixed at 0.
    verbose : bool
        If True, print ``n_iter, noise, relative Frobenius error, KL`` every
        ``verbose_interval`` iterations, and warn when the KL value (from
        ``utils.KL_div(T, P)``) increased since the last print.
    max_iter : int
        Number of EM iterations.
    verbose_interval : int
        Print period used when ``verbose`` is True.
    noise_update_rule : int
        0 selects the EM mixture-weight update for ``noise``. Any other value
        selects ``term1 / (term1 + term2)``.

    Returns
    -------
    list of numpy.ndarray
        The fitted cores ``G``.

    Raises
    ------
    ValueError
        If ``R`` has fewer than ``T.ndim - 1`` entries.
    """
    T = np.asarray(T)
    T = T / np.sum(T)
    D = np.ndim(T)
    J = np.shape(T)
    AbsOmegaI = np.prod(J)
    if len(R) < D - 1:
        raise ValueError(
            f"EMTrain needs {D - 1} TT ranks for a {D}-way tensor, got {len(R)}")

    # Internal ranks with the boundary rank 1 appended. Negative indexing makes
    # r[-1] == 1 double as the leading boundary rank of the first core, so
    # core d always has shape (r[d - 1], J[d], r[d]).
    r = [*list(R)[:D - 1], 1]

    G = [np.random.rand(r[d - 1], J[d], r[d]) for d in range(D)]
    # Start from normalized cores so that Q (and hence P) begins as a
    # distribution, consistent with what the E-step assumes.
    _normalize_cores(G)

    noise = np.random.rand(1)[0]
    if not learn_noise:
        noise = 0.0
    # Build the initial P from the same noisy mixture used in every iteration.
    P = (1 - noise) * train.train_from_cores(G) + noise / AbsOmegaI

    prev_error_kl = np.inf
    for n_iter in range(max_iter):
        ## E-step: everything below is computed from the current (old) parameters.
        T_over_P = T / P
        # NOTE(review): GR[k] presumably contracts cores 0..k and GL[k]
        # contracts cores k+1..D-1, with a trivial boundary entry at
        # GR[-1] / GL[D-1]. TODO confirm against utils_train.
        GR = train.get_train_R(G) # G(  --> d )
        GL = train.get_train_L(G) # G( d <--  )

        ## M-step for the cores. New values go into fresh arrays, so core d
        ## is never updated from already-updated cores (or from views of
        ## them that GR / GL might hold).
        new_G = [np.empty_like(core) for core in G]
        for d in range(D):
            # All axes except d; the sum leaves a vector over the mode-d index j.
            sum_axes = utils.tuple_skipping_m(D, d)
            for rdm1, rd in product(range(r[d - 1]), range(r[d])):
                slice_GR = [slice(None)] * (GR[d - 1].ndim - 1) + [rdm1]
                GR_new = np.tensordot(GR[d - 1][tuple(slice_GR)],
                                      G[d][rdm1, :, rd], axes=0)

                slice_GL = [rd] + [slice(None)] * (GL[d].ndim - 1)
                X = np.tensordot(GR_new, GL[d][tuple(slice_GL)], axes=0)
                # Store the unnormalized expected counts for (rdm1, j, rd).
                # They are normalized jointly over (rdm1, j) below. Normalizing
                # each (rdm1, rd) fiber separately would force every rdm1 to
                # receive the same weight 1 / r[d - 1].
                new_G[d][rdm1, :, rd] = np.sum(T_over_P * X, axis=sum_axes)

        ## Normalize G: every slice new_G[d][:, :, rd] sums to one.
        _normalize_cores(new_G)

        ## M-step for the noise weight, using the old model: GR[D - 1] is
        ## presumably the noise-free tensor Q built from the old cores.
        if learn_noise:
            term1 = 1 / AbsOmegaI * np.sum(T_over_P)
            term2 = np.sum(T_over_P * np.squeeze(GR[D - 1]))
            if noise_update_rule == 0:
                noise = noise * term1 / (noise * term1 + (1 - noise) * term2)
            else:
                noise = term1 / (term1 + term2)
        else:
            noise = 0.0

        G = new_G

        ## Update P with the new parameters.
        P = (1 - noise) * train.train_from_cores(G) + noise / AbsOmegaI

        if verbose and n_iter % verbose_interval == 0:
            kl_error = utils.KL_div(T, P)
            f_error = np.linalg.norm(T - P) / np.linalg.norm(T)
            print(n_iter, noise, f_error, kl_error)
            if prev_error_kl < kl_error:
                print("KL is not monotonicaly decreasing")
            prev_error_kl = kl_error

    return G