import utils_Tucker as Tucker
import utils_train as train
import numpy as np
import utils
import sp_tensor
from itertools import product
import importlib
importlib.reload(sp_tensor)

import sys
sys.path.append("../../data")
import dataset_info

def EMTrain_sparse(T, R, learn_noise=False, verbose=True, max_iter=1, 
                   verbose_interval=1, noise_update_rule=0, tol=1.0e-4,
                   conv_check_interval=10):
    T.normalize()
    print("normalized done")

    J = T.tensor_size
    D = T.tensor_dim
    assert len(R) + 1 == D, "rank is invalid"
    AbsOmegaI = np.prod(J)
    
    idx_d_id = get_idx_d_id(T.coords, D, J)

    ## Initialize cores G
    G = [ np.array([]) for d in range(D) ]
    G[0] = np.random.rand(1, J[0], R[0])
    for d in range(1,D-1):
        G[d] = np.random.rand(R[d-1], J[d], R[d])
    G[D-1] = np.random.rand(R[D-2], J[D-1], 1)
    
    coords_L = get_coords_L(T.coords)
    coords_R = get_coords_R(T.coords)
    GR = get_sparse_train_R(coords_R, G)
    GL = get_sparse_train_L(coords_L, G)
    P = GL[-1]

    prev_error_nl = np.inf
    prev_error_nl_for_conv = 1.0e+10
    T_over_P = sp_tensor.Sp_tensor( T.coords,  T.values / P.values, J)
    noise = np.random.rand(1)[0]
    for n_iter in range(max_iter+1):
        
        ## update cores
        for d in range(D):
            if d == 0:
                # Since GR[-1] is not sparse tensor but sclaer value "1", 
                # we need exceptional procedure as follows:
                for rdm1, jd, rd in product( range(1), range(J[d]), range(R[d])):
                    G[d][rdm1, jd, rd] = \
                    sum( T_over_P.coord_to_value[ *idx ] \
                          * 1 \
                          * G[d][rdm1, jd, rd] \
                          * GL[d].coord_to_value[ *([rd] + idx[d+1:] ) ] \
                          for idx in idx_d_id[d,jd] )
                    
            elif d == D-1:
                 # Since GL[D-1] is not sparse tensor but sclaer value "1", 
                 # we need exceptional procedure as follows:
                 for rdm1, jd, rd in product( range(R[d-1]), range(J[d]), range(1) ) :
                    G[d][rdm1, jd, rd] = \
                    sum( T_over_P.coord_to_value[ *idx ] \
                         * GR[d-1].coord_to_value[ *(idx[0:d] + [rdm1]) ] \
                         * G[d][rdm1, jd, rd] \
                         * 1 \
                         for idx in idx_d_id[d,jd] )
                     
            else: # d = 1, 2, ..., D-2
                for rdm1, jd, rd in product( range(R[d-1]), range(J[d]), range(R[d]) ) :
                    G[d][rdm1, jd, rd] = \
                    sum( T_over_P.coord_to_value[ *idx ] \
                         * GR[d-1].coord_to_value[ *(idx[0:d] + [rdm1]) ] \
                         * G[d][rdm1, jd, rd] \
                         * GL[d].coord_to_value[ *([rd] + idx[d+1:] ) ] \
                         for idx in idx_d_id[d,jd] )
        
        ## Normalizer G
        for d in range(D):
            if d != D - 1:
                for rd in range(R[d]):
                    G[d][:,:,rd] /= np.sum( G[d][:,:,rd] )
            else:
                G[d][:,:,0] /= np.sum( G[d][:,:,0] )

        ## noise update
        if learn_noise:
            term1 = 1.0 / AbsOmegaI * np.sum( T_over_P.values )
            term2 = np.sum( T_over_P.values * GL[-1].values )
            if noise_update_rule == 0:
                noise = noise * term1 / ( noise * term1 + (1-noise)*term2 )
            else:
                noise = term1 / ( term1 + term2 )
        else:
            noise = 0

        GL = get_sparse_train_L(coords_L, G)
        GR = get_sparse_train_R(coords_R, G)

        
        ## Update P
        # NOTE: P = GL[-1]
        P.values = (1 - noise) * GL[-1].values + noise / AbsOmegaI 
        
        T_over_P = sp_tensor.Sp_tensor( T.coords,  T.values / P.values, J)

        # To check if the normalization is satisified
        # print( np.sum( train.train_from_cores(G) ) )

        if verbose and n_iter > 0:
            if n_iter % verbose_interval == 0:
                # Since both P and T are normalized, 
                # NL is also monotonically decreasing.
                nl_error = utils.NL(T.values, P.values)
                # f_error  = np.linalg.norm( T.values - P.values ) / np.linalg.norm(T.values)
                print(n_iter, noise, nl_error )
                if prev_error_nl < nl_error:
                    print("NL error is not monotonically decreasing...")
                prev_error_nl = nl_error
                
        if n_iter > 3 and n_iter % conv_check_interval == 0:
            nl_error = utils.NL(T.values, P.values) 
            res = abs( prev_error_nl_for_conv - nl_error ) / conv_check_interval  
            if res < tol:
                break
            else:
                prev_error_nl_for_conv = nl_error
         

    return G, noise

def get_idx_d_id(idx, tensor_dim, tensor_size):
    """Group the coordinate rows of *idx* by (mode, index).

    Args:
        idx: 2-D array-like of coordinates, one row per observed entry.
        tensor_dim: number of modes D (columns of *idx* to inspect).
        tensor_size: sequence of mode sizes; mode ``d`` has indices
            ``0 .. tensor_size[d]-1``.

    Returns:
        dict mapping ``(d, i)`` to the list of rows (each a Python list)
        whose d-th entry equals ``i``, in original row order. Every
        ``(d, i)`` with ``0 <= i < tensor_size[d]`` is present, possibly
        with an empty list. Rows whose entry is outside that range are
        ignored.
    """
    # Pre-create every key so the result (and its key order) matches the
    # full (d, i) grid even for indices that never occur.
    idx_d_id = {
        (d, i): [] for d in range(tensor_dim) for i in range(tensor_size[d])
    }
    # One pass over the rows instead of one boolean-mask scan per (d, i).
    for row in np.asarray(idx).tolist():
        for d in range(tensor_dim):
            bucket = idx_d_id.get((d, row[d]))
            if bucket is not None:
                # Store an independent copy per bucket so no list object is
                # shared between keys.
                bucket.append(list(row))
    return idx_d_id

def get_coords_R(coords):  # ( --> d )
    """Return the distinct coordinate prefixes of *coords*, shortest first.

    Element ``d`` of the result is ``np.unique(coords[:, :d+1], axis=0)``:
    the sorted, de-duplicated rows restricted to modes ``0 .. d``.
    """
    n_modes = coords[0].shape[0]
    prefixes = []
    for width in range(1, n_modes + 1):
        prefixes.append(np.unique(coords[:, :width], axis=0))
    return prefixes

def get_coords_L(coords):  # ( <-- d )
    """Return the distinct coordinate suffixes of *coords*, shortest first.

    Element ``d`` of the result is the sorted, de-duplicated set of rows
    restricted to the last ``d + 1`` modes (``np.unique(..., axis=0)``).
    """
    n_modes = coords[0].shape[0]
    suffixes = []
    for width in range(1, n_modes + 1):
        suffixes.append(np.unique(coords[:, n_modes - width:], axis=0))
    return suffixes

def get_sparse_train_R(coords_R, G): # ( --> d)
    D = len(G)
    tensor_size = [ np.shape(G[d])[1] for d in range(D) ]
    R = [ np.shape(G[d])[2] for d in range(D) ]

    GR = {}
    ## d = -1
    GR[-1] = 1
    
    GR[0]  = sp_tensor.dense_to_sparse(np.squeeze(G[0]))
    for d in range(1,D):
        coords_Rd  = coords_R[d]
        GRd_coords = [ [idx for idx in coords_rd] + [rd] for coords_rd in coords_Rd for rd in range(R[d]) ]
        GRd_values = [ sum( GR[d-1].coord_to_value[ *idxr[0:d] + [rdm1] ] * G[d][rdm1, idxr[-2], idxr[-1]] for rdm1 in range(R[d-1]) ) for idxr in GRd_coords ]
        tensor_size_GRd = tensor_size[0:d+1] + [R[d]]
        GR[d] = sp_tensor.Sp_tensor( np.array(GRd_coords), np.array(GRd_values), tensor_size_GRd, check_empty=False )

    # GR[D-1] will be the same tensor as the reconst. 
    # i.e., GR[D-1] == train_from_cores(G)
    return GR

def get_sparse_train_L(coords_L, G): #( d <-- )
    """Contract tensor-train cores right-to-left at observed suffixes only.

    Args:
        coords_L: list of ``D`` integer arrays. ``coords_L[k]`` holds the
            coordinate suffixes covering the last ``k + 1`` modes, as produced
            by ``get_coords_L``.
        G: list of ``D`` cores; core ``d`` has shape ``(R[d-1], J[d], R[d])``.

    Returns:
        dict ``GL`` with:
          * ``GL[D-1] = np.array([1])`` (empty right product);
          * for ``d = D-2, ..., 0``: a sparse tensor of size
            ``[R[d]] + J[d+1:]`` whose entry ``(r_d, i_{d+1}, ..., i_{D-1})``
            is ``sum_r G[d+1][r_d, i_{d+1}, r] * GL[d+1][r, i_{d+2}, ...]``;
          * ``GL[-1]``: the same contraction over all cores, of size
            ``[1] + J``. Its values are the reconstruction at the observed
            coordinates, with a leading dummy rank index 0.
    """
    D = len(G)
    tensor_size = [ np.shape(G[d])[1] for d in range(D) ]
    R = [ np.shape(G[d])[2] for d in range(D) ]

    GL = {}
    
    # d = D-1: empty right product (not a sparse tensor)
    GL[D-1] = np.array([1])

    # d = D-2: only the last core is involved; its right rank R[-1] is summed out
    coords_Ld  = coords_L[0]
    GLd_coords = [ [rd] + [idx for idx in coords_ld] for coords_ld in coords_Ld for rd in range(R[-2]) ]
    GLd_values = [ sum( G[D-1][ridx[0], ridx[1] ,rd] for rd in range(R[-1]) ) for ridx in GLd_coords ]
    tensor_size_GLd = [ R[-2] ] + tensor_size[-1:] 
    GL[D-2] = sp_tensor.Sp_tensor( np.array(GLd_coords), np.array(GLd_values), tensor_size_GLd, check_empty=False )

    # d = D-3, D-4, ..., 0, -1  (GL[D-k-2] is built from G[D-k-1] and GL[D-k-1])
    for k in range(1, D):
        coords_Ld  = coords_L[k]
        if k != D - 1:
            # Pair every observed suffix with every left rank r_{D-k-2}.
            GLd_coords = [ [rd] + [idx for idx in coords_ld] for coords_ld in coords_Ld for rd in range(R[-k-2]) ]
            tensor_size_GLd = [ R[-k-2] ] + tensor_size[-k-1:] 
        elif k == D - 1:
            # R[-D-1] is out of range (R has only D entries), so the left
            # rank of the full contraction is taken to be 1.
            GLd_coords = [ [rd] + [idx for idx in coords_ld] for coords_ld in coords_Ld for rd in range(1) ]
            tensor_size_GLd = [ 1 ] + tensor_size[-k-1:] 
            
        # Contract core D-k-1 with the previous right product over their shared rank.
        GLd_values = [ sum( G[D-k-1][ridx[0], ridx[1] ,rd] * GL[D-k-1].coord_to_value[ *([rd] + ridx[-k:]) ] for rd in range(R[-k-1]) ) for ridx in GLd_coords ]
        GL[D-k-2] = sp_tensor.Sp_tensor( np.array(GLd_coords), np.array(GLd_values), tensor_size_GLd, check_empty=False )
            
    ## For debug
    ## GL[D-2] is equivalent to G[D-1]
    ## GL[D-3] is np.tensordot( G[D-2], G[D-1],  axes=1 )
    ## GL[D-4] is np.tensordot( G[D-3], GL[D-3], axes=1 )
    ## GL[D-5] is np.tensordot( G[D-4], GL[D-4], axes=1 )
    ## ...
    ## GL[-1] is the same tensor as the reconstruction,
    ## i.e., GL[-1] == train_from_cores(G)
   
    return GL

def sparse_train_reconst(G, noise, coords):
    """Reconstruct the noisy tensor-train model at the observed coordinates.

    Args:
        G: list of cores; ``np.shape(G[d])[1]`` is the size of mode ``d``.
        noise: weight of the uniform noise component.
        coords: 2-D integer array of observed coordinates.

    Returns:
        ``(1 - noise) * model + noise / |Omega|``, where ``model`` holds the
        values of the full right-to-left contraction at the unique (sorted)
        coordinates, and ``|Omega|`` is the product of the mode sizes,
        computed in float64.
    """
    mode_sizes = [np.shape(core)[1] for core in G]
    n_cells = np.prod(np.array(mode_sizes, dtype=np.float64))
    right_products = get_sparse_train_L(get_coords_L(coords), G)
    clean_values = right_products[-1].values
    return (1 - noise) * clean_values + noise / n_cells

def sparse_train_from_cores_with_noise(G, noise):
    """NOTE(review): dead and broken definition.

    The 3-argument ``sparse_train_from_cores_with_noise(G, noise, idxs)``
    defined later in this module rebinds the same name when the module is
    imported, so this version is never reachable through the module.
    Its body also uses the undefined names ``idxs`` and ``total`` and would
    raise NameError if called. Presumably it was meant to sum the noisy
    reconstruction over some index set. TODO: confirm the intent, then
    rename and fix it, or delete it.
    """
    values_with_noise = sparse_train_from_cores_with_noise(G, noise, idxs)
    return np.sum(total)
    

def sparse_train_from_cores_with_noise(G, noise, idxs):
    """Evaluate the noisy tensor-train model at the given multi-indices.

    Each value is ``(1 - noise) * v + noise / |Omega|``, where ``v`` is
    ``train.train_from_cores_idx(G, idx)`` and ``|Omega|`` is the product of
    the mode sizes ``np.shape(G[d])[1]``, computed in float64.

    Args:
        G: list of tensor-train cores.
        noise: weight of the uniform noise component.
        idxs: iterable of multi-indices to evaluate.

    Returns:
        numpy array of noisy model values, in the order of ``idxs``.
    """
    mode_sizes = np.array([np.shape(core)[1] for core in G], dtype=np.float64)
    n_cells = np.prod(mode_sizes)
    clean = np.array([train.train_from_cores_idx(G, idx) for idx in idxs])
    return (1 - noise) * clean + noise / n_cells
    
