import numpy as np
import utils
import math
import utils_sum as us
import utils_train as ut
import sparse_em_cp as secp
import sparse_em_train as setr
from itertools import product
import sys
sys.path.append("../../data")
import dataset_info

import sp_tensor
import importlib
# Force a re-import of a subset of the project modules imported above
# (us, sp_tensor, utils, dataset_info). The other project modules
# (ut, secp, setr) are NOT reloaded here.
# NOTE(review): presumably this exists so that edits to those modules are
# picked up in a long-lived interactive session (e.g. a notebook) -- confirm.
importlib.reload(us)
importlib.reload(sp_tensor)
importlib.reload(utils)
importlib.reload(dataset_info)

       
def EMMix_sparse(T, R_cp, R_train, verbose=True, max_iter=10, verbose_interval=1,
                mix_update_rule=0, model=(1,1,1), tol=1.0e-4, conv_check_interval=10):
    """Fit a CP + tensor-train + uniform-noise mixture to a sparse tensor with EM-style updates.

    At the observed coordinates of ``T`` the model is

        P = eta_cp * Pcp + eta_train * Ptrain + eta_noise / prod(J)

    where ``Pcp`` is a sum of ``R_cp`` rank-1 terms built from the factors ``A``,
    ``Ptrain`` is built from the train cores ``G`` (via ``setr.get_sparse_train_L``),
    and ``J = T.tensor_size``.

    Parameters
    ----------
    T : sparse tensor object
        This function reads ``T.coords``, ``T.values``, ``T.tensor_size``,
        ``T.tensor_dim`` and ``T.nnz``, and calls ``T.normalize()`` first.
        NOTE(review): ``normalize()`` presumably rescales ``T`` in place so its
        values sum to 1 -- confirm in ``sp_tensor``.
    R_cp : int
        Number of CP (rank-1) components.
    R_train : sequence of int
        Train ranks; ``R_train[d]`` is the size of the axis shared by core ``d``
        and core ``d+1``, so indices ``0 .. D-2`` are read.
    verbose : bool
        If true, print the NL error and the mixture ratios every
        ``verbose_interval`` iterations (starting from iteration 1).
    max_iter : int
        The main loop runs for ``n_iter = 0 .. max_iter`` inclusive unless the
        convergence check stops it earlier.
    verbose_interval : int
        Printing period, in iterations.
    mix_update_rule : int
        ``0``: the per-component statistics are multiplied by the current
        weight before renormalising; any other value: they are not.
    model : tuple
        Passed unchanged to ``init_eta`` to choose the active components.
    tol : float
        Convergence threshold on the per-iteration change of the NL error.
    conv_check_interval : int
        The convergence check runs when ``n_iter > 3`` and
        ``n_iter % conv_check_interval == 0``.

    Returns
    -------
    (A, G, eta_cp, eta_train, eta_noise)
        ``A`` is a dict ``r -> list`` with one entry per mode (it stays an
        empty list per ``r`` if the CP component is inactive); ``G`` is a list
        of ``D`` numpy arrays with shapes ``(1, J[0], R_train[0])``,
        ``(R_train[d-1], J[d], R_train[d])`` and ``(R_train[D-2], J[D-1], 1)``;
        the three ``eta_*`` values are the final mixture weights.

    Notes
    -----
    * Initial values come from ``np.random.rand``, so results depend on the
      global NumPy RNG state and on the order of the draws below.
    * The train update uses star-unpacking inside subscripts (``x[*idx]``),
      which is only valid syntax on Python 3.11 or newer.
    """
    T.normalize()

    J = T.tensor_size
    D = T.tensor_dim
    N = T.nnz
    # Number of cells of the full (dense) index space; used by the uniform noise term.
    AbsOmegaI = math.prod(J)

    ### Initialization
    
    ## Normalized weights
    ## NOTE: the total sum of the etas should be 1.0
    eta_cp, eta_train, eta_noise = init_eta(model)
    # An inactive component contributes a constant 0 to the weight update below;
    # the statistics of active components are (re)computed inside the loop.
    if eta_cp == 0:
        Ccp = 0
    if eta_train == 0:
        Ctrain = 0
    if eta_noise == 0:
        Cnoise = 0
  
    ## Mixture tensor (random start, defined on the observed coordinates only)
    P = sp_tensor.Sp_tensor( T.coords, np.random.rand(N), J, normalize=True )
    # NOTE(review): T_over_P is built again, identically, just before the main loop.
    T_over_P = sp_tensor.Sp_tensor( T.coords,  T.values / P.values, J)
    
    ## Pure (single-component) tensors
    Pcp = sp_tensor.Sp_tensor( T.coords, np.random.rand(N), J, normalize=True )
    # NOTE(review): this Ptrain is rebound to GL[-1] below and never read before
    # that; removing it would change the sequence of np.random draws.
    Ptrain = sp_tensor.Sp_tensor( T.coords, np.random.rand(N), J, normalize=True )
    
    ## CP factors
    # Qcp[r]: value of the r-th rank-1 term at each observed coordinate.
    Qcp = { r : sp_tensor.Sp_tensor(T.coords, np.random.rand(N), J, check_empty=False) for r in range(R_cp) }
    # Mcp[r]: Qcp[r] weighted by T / P (recomputed with eta_cp at the top of every iteration).
    Mcp = { r : sp_tensor.Sp_tensor(T.coords, Qcp[r].values * T.values / P.values, J, check_empty=False) for r in range(R_cp) } 
    A = { r : [] for r in range(R_cp) } # Dense vectors
    Mcpr_sums = np.zeros(R_cp)

    ## Train cores
    # First and last cores have a boundary rank of 1.
    G = [ np.array([]) for d in range(D) ]
    G[0] = np.random.rand(1, J[0], R_train[0])
    for d in range(1,D-1):
        G[d] = np.random.rand(R_train[d-1], J[d], R_train[d])
    G[D-1] = np.random.rand(R_train[D-2], J[D-1], 1)
    
    # Partial contractions of the cores evaluated on the observed coordinates.
    # NOTE(review): the exact layout of GL / GR is defined in sparse_em_train;
    # from their use below, GR[d-1] is keyed by (idx[0:d] + [r]) and GL[d] by
    # ([r] + idx[d+1:]), and GL[-1] holds the full train values -- confirm there.
    coords_L = setr.get_coords_L(T.coords)
    coords_R = setr.get_coords_R(T.coords)
    GR = setr.get_sparse_train_R(coords_R, G)
    GL = setr.get_sparse_train_L(coords_L, G)
    Ptrain = GL[-1]
    
    ## For train:
    ## Get the coordinates whose d-th index equals id.
    ## G is obtained by summing GR, G, GL over these indices.
    idx_d_id = setr.get_idx_d_id(T.coords, D, J)

    T_over_P = sp_tensor.Sp_tensor( T.coords,  T.values / P.values, J)
    prev_error_nl = np.inf
    prev_error_nl_for_conv = np.inf
    for n_iter in range(max_iter+1):

        if eta_cp != 0.0:
            ###########################
            ## M Step for CP
            ###########################
            for r in range(R_cp):
                Mcp[r].values = eta_cp * T.values * Qcp[r].values / P.values
                #Mcp[r].values = T.values * Qcp[r].values / P.values
                Mcpr_sums[r] = np.sum(Mcp[r].values)
            total = np.sum(Mcpr_sums)
    
            # update A
            # A[:][d][:] is a dense matrix
            # A[r][d][id] where
            # r is the rank index, r=1,2,...,R, [rnk] 
            # d is the tensor mode, d=1,2,...,D, [tensor_dim] 
            # id is the d-th index of the tensor, id=1,2,..,Id [tensor_size[d]]
            for r in range(R_cp):
                # Update by the closed-form update rule:
                # per-mode sums of Mcp[r], scaled by sum(Mcp[r])**(1/D - 1) * total**(-1/D).
                sums_results = us.reduce_sum_each_dim(Mcp[r].coords, Mcp[r].values, D)
                A[r] = [ sums_results[d][1] * (Mcpr_sums[r])**(1/D-1) * ( total ** (-1/D) ) for d in range(D) ]
    
            ## Mcp is not guaranteed to sum to 1.
            ## However, Pcp needs to be normalized, so the total**(-1/D) factor
            ## above takes the place of the explicit normalization below.
            ## Normalize A[r]
            ## for r in range(R_cp):
            ##   A[r] /= total**(1/D)
    
            # Checking the normalization
            # print( secp.sparse_CP_total_sum(A) )
    
            # update Q: evaluate each rank-1 term at every observed coordinate
            for r in range(R_cp):
                for n in range(N):
                    Qcp[r].values[n] = math.prod( A[r][d][ T.coords[n][d] ] for d in range(D) )
    
            # update the low-rank CP tensor
            Pcp.values = sum( Qcp[r].values for r in range(R_cp) )

        if eta_train != 0:
            ###########################
            ## M Step for Train
            ###########################
    
            ## update cores
            # Each entry G[d][rdm1, jd, rd] is replaced by the sum, over the observed
            # coordinates whose d-th index is jd, of
            #   eta_train * (T/P) * GR-part * (old entry) * GL-part.
            # GL / GR are only refreshed after all cores have been updated.
            # NOTE(review): `[rd] + idx[d+1:]` and `idx[0:d] + [rdm1]` assume idx is a
            # Python list (list concatenation) -- confirm in setr.get_idx_d_id.
            for d in range(D):
                if d == 0:
                    # Since GR[-1] is not a sparse tensor but the scalar value "1",
                    # we need the following special-case procedure:
                    for rdm1, jd, rd in product( range(1), range(J[d]), range(R_train[d])):
                        G[d][rdm1, jd, rd] = \
                        sum( eta_train * T_over_P.coord_to_value[ *idx ] \
                              * 1 \
                              * G[d][rdm1, jd, rd] \
                              * GL[d].coord_to_value[ *([rd] + idx[d+1:] ) ] \
                              for idx in idx_d_id[d,jd] )
                        
                elif d == D-1:
                     # Since GL[D-1] is not a sparse tensor but the scalar value "1",
                     # we need the following special-case procedure:
                     for rdm1, jd, rd in product( range(R_train[d-1]), range(J[d]), range(1) ) :
                        G[d][rdm1, jd, rd] = \
                        sum( eta_train * T_over_P.coord_to_value[ *idx ] \
                             * GR[d-1].coord_to_value[ *(idx[0:d] + [rdm1]) ] \
                             * G[d][rdm1, jd, rd] \
                             * 1 \
                             for idx in idx_d_id[d,jd] )
                         
                else: # d = 1, 2, ..., D-2
                    for rdm1, jd, rd in product( range(R_train[d-1]), range(J[d]), range(R_train[d]) ) :
                        G[d][rdm1, jd, rd] = \
                        sum( eta_train * T_over_P.coord_to_value[ *idx ] \
                             * GR[d-1].coord_to_value[ *(idx[0:d] + [rdm1]) ] \
                             * G[d][rdm1, jd, rd] \
                             * GL[d].coord_to_value[ *([rd] + idx[d+1:] ) ] \
                             for idx in idx_d_id[d,jd] )
            
            ## Normalize G: every slice G[d][:, :, rd] is rescaled to sum to 1
            ## (the last core has a single slice, rd = 0).
            for d in range(D):
                if d != D - 1:
                    for rd in range(R_train[d]):
                        G[d][:,:,rd] /= np.sum( G[d][:,:,rd] )
                else:
                    G[d][:,:,0] /= np.sum( G[d][:,:,0] )
    
            
            # To check whether the normalization is satisfied
            # print( np.sum( ut.train_from_cores(G) ) )
    
            # Refresh the partial contractions with the updated cores.
            GL = setr.get_sparse_train_L(coords_L, G)
            GR = setr.get_sparse_train_R(coords_R, G)
    
            # update the low-rank train tensor
            Ptrain.values = GL[-1].values
            
        ###########################
        ## M Step for Weights 
        ###########################

        # Per-component statistics; a component whose weight is 0 keeps its
        # previous statistic (0 if it was inactive from the start).
        if mix_update_rule == 0:
            if eta_cp != 0:
                Ccp    = eta_cp    * np.sum( T_over_P.values * Pcp.values )
            if eta_train != 0:
                Ctrain = eta_train * np.sum( T_over_P.values * Ptrain.values )
            if eta_noise != 0:
                Cnoise = eta_noise / AbsOmegaI * np.sum( T_over_P.values )
        else:
            if eta_cp != 0:
                Ccp    = np.sum( T_over_P.values * Pcp.values )
            if eta_train != 0:
                Ctrain = np.sum( T_over_P.values * Ptrain.values )
            if eta_noise != 0:
                Cnoise = 1.0 / AbsOmegaI * np.sum( T_over_P.values )

        # Renormalize so that the three weights sum to 1.
        eta_cp    = Ccp    / (Ccp + Ctrain + Cnoise)
        eta_train = Ctrain / (Ccp + Ctrain + Cnoise)
        eta_noise = Cnoise / (Ccp + Ctrain + Cnoise)

        ## E-Step
        ##########################
        # update mixture tensor
        ##########################
        P.values = eta_cp * Pcp.values + eta_train * Ptrain.values + eta_noise / AbsOmegaI

        #T_over_P = sp_tensor.Sp_tensor( T.coords,  T.values / P.values, J)
        T_over_P.values = T.values / P.values
        # Keep the coordinate -> value lookup used by the train update in sync with .values.
        T_over_P.update_coord_to_value()
        
        if verbose and n_iter > 0:
            if n_iter % verbose_interval == 0:
                # Since both P and T are normalized,
                # NL is also expected to decrease monotonically.
                nl_error = utils.NL(T.values, P.values)
                print(n_iter, nl_error )
                print("mix ratio:", eta_cp, eta_train, eta_noise)
                if prev_error_nl < nl_error:
                    print("NL error is not monotonically decreasing...")
                prev_error_nl = nl_error

        # Convergence check: stop when the average per-iteration change of the
        # NL error since the previous check falls below tol. On the first check
        # the previous value is inf, so the loop never stops there.
        if n_iter > 3 and n_iter % conv_check_interval == 0:
            nl_error = utils.NL(T.values, P.values) 
            res = abs( prev_error_nl_for_conv - nl_error ) / conv_check_interval  
            if res < tol:
                break
            else:
                prev_error_nl_for_conv = nl_error

    return A, G, eta_cp, eta_train, eta_noise

def init_eta(model):
    """Draw initial mixture weights for the active components of ``model``.

    Parameters
    ----------
    model : sequence of three 0/1 flags ``(CP, Train, Noise)``
        ``1`` means the component is active, ``0`` means it is inactive.
        E.g. ``model=(1, 0, 1)`` is the CP-with-noise model. Supported
        combinations are ``(1,1,1)``, ``(1,1,0)``, ``(1,0,1)``, ``(0,1,1)``,
        ``(1,0,0)`` and ``(0,1,0)``; a noise-only or empty model is rejected.
        Any length-3 sequence (tuple, list, ...) is accepted.

    Returns
    -------
    (eta_cp, eta_train, eta_noise)
        Inactive components get ``0``. A single active component gets ``1.0``.
        Two or three active components get ``np.random.rand`` draws rescaled
        to sum to 1, assigned in (CP, Train, Noise) order.

    Raises
    ------
    ValueError
        If ``model`` is not one of the supported combinations.
    """
    supported = (
        (1, 1, 1),  # cp + train + noise
        (1, 1, 0),  # cp + train
        (1, 0, 1),  # cp + noise
        (0, 1, 1),  # train + noise
        (1, 0, 0),  # cp
        (0, 1, 0),  # train
    )

    try:
        flags = tuple(model)
    except TypeError:
        raise ValueError(f"Invalid model: {model!r}") from None

    if flags not in supported:
        raise ValueError(
            f"Invalid model: {model!r}; expected one of {supported}"
        )

    active = [pos for pos, flag in enumerate(flags) if flag]
    etas = [0, 0, 0]

    if len(active) == 1:
        # A single component carries all the weight; no random draw is made.
        etas[active[0]] = 1.0
    else:
        # One draw per active component, normalised so the weights sum to 1.
        weights = np.random.rand(len(active))
        weights /= np.sum(weights)
        for pos, weight in zip(active, weights):
            etas[pos] = weight

    eta_cp, eta_train, eta_noise = etas
    return eta_cp, eta_train, eta_noise
 

def mix_values_idxs(A, G, eta_cp, eta_train, eta_noise, idxs):
    """Evaluate the fitted mixture at the coordinates ``idxs``.

    The result is ``eta_cp * Pcp + eta_train * Ptrain + eta_noise / prod(J)``
    evaluated entry-wise, where the mode sizes ``J`` are read from the second
    axis of each train core in ``G``.

    Parameters
    ----------
    A : mapping/sequence of CP factors, indexed as ``A[r][d][index]``.
    G : list of train cores (one array per mode).
    eta_cp, eta_train, eta_noise : mixture weights.
    idxs : coordinates to evaluate, indexed as ``idxs[n][d]``; as the printed
        note says, they should be sorted.

    Returns
    -------
    The reconstructed values, one per entry of ``idxs``.
    """
    print("NOTE: idxs should be sorted otherwise the evaluation would be incorrect")
    n_points = len(idxs)
    n_modes = len(G)
    mode_sizes = [ np.shape(G[mode])[1] for mode in range(n_modes) ]
    dense_cell_count = np.prod(np.array(mode_sizes, dtype=np.float64))

    ## Train part: the last left-contraction holds the train values at idxs
    left_coords = setr.get_coords_L(idxs)
    left_contractions = setr.get_sparse_train_L(left_coords, G)
    train_part = left_contractions[-1]

    ## CP part: sum of rank-1 terms, only evaluated when the component is active
    n_components = len(A)
    cp_part = sp_tensor.Sp_tensor(
        idxs, np.zeros(n_points), mode_sizes, normalize=False, check_empty=False
    )
    if eta_cp != 0:
        for n in range(n_points):
            coord = idxs[n]
            cp_part.values[n] = sum(
                math.prod( A[r][mode][ coord[mode] ] for mode in range(n_modes) )
                for r in range(n_components)
            )

    ## Weighted combination with the uniform noise density
    return eta_cp * cp_part.values + eta_train * train_part.values + eta_noise / dense_cell_count

def eval_EMMix(A, G, eta_cp, eta_train, eta_noise, gt):
    """Score a fitted mixture against a ground-truth sparse tensor.

    The mixture defined by ``A``, ``G`` and the three weights is evaluated at
    ``gt.coords`` with ``mix_values_idxs`` and compared with ``gt.values``
    through ``utils.NL``.

    Parameters
    ----------
    A, G, eta_cp, eta_train, eta_noise :
        Model parameters, in the form returned by ``EMMix_sparse``.
    gt : object exposing ``coords`` and ``values``.

    Returns
    -------
    The value of ``utils.NL(gt.values, reconstruction)``.
    """
    reconst = mix_values_idxs(A, G, eta_cp, eta_train, eta_noise, gt.coords)
    return utils.NL(gt.values, reconst)
