import numpy as np
import utils_sum as us
import utils
import utils_Tucker as Tucker
import math
import os
import sys
import sp_tensor
from itertools import product
import importlib
sys.path.append("../../data")
import dataset_info

# Re-execute the helper modules so that edits made to them take effect
# without restarting the interpreter (presumably for interactive/notebook use).
importlib.reload(utils)
importlib.reload(us)

import pathlib
# NOTE(review): the resolved path is discarded, so this statement has no
# effect when run as a script; its value is only visible if evaluated
# interactively (e.g. as the last expression of a notebook cell).
pathlib.Path().resolve()

def EMTucker_sparse(T, rnk, learn_noise=False, verbose=True, max_iter=100, 
                    verbose_interval=1, noise_update_rule=0, tol=1.0e-4,
                    conv_check_interval=10):
    """Fit a nonnegative Tucker decomposition to a sparse tensor with EM.

    The model value at a nonzero coordinate n is
    ``P(n) = (1 - noise) * sum_r Q[r](n) + noise / AbsOmegaI`` with
    ``Q[r](n) = G[r] * prod_d A[d][i_d(n), r_d]`` and ``AbsOmegaI`` the number
    of cells of the full (dense) tensor. Q, M and P are stored only on the
    nonzero coordinates of ``T``.

    Parameters
    ----------
    T : sp_tensor.Sp_tensor
        Observed sparse tensor. ``T.normalize()`` is called on it, so the
        caller's object is modified in place.
    rnk : sequence of int
        Tucker rank per mode; ``len(rnk)`` must equal ``T.tensor_dim``.
    learn_noise : bool
        If True the noise weight is re-estimated every iteration; otherwise
        it is set to 0 in every iteration.
    verbose : bool
        If True, print setup messages and ``(n_iter, noise, NL error)`` every
        ``verbose_interval`` iterations.
    max_iter : int
        The loop runs for ``n_iter`` in ``0 .. max_iter`` (inclusive) unless
        the convergence test stops it earlier.
    verbose_interval : int
        Print period (in iterations) for the progress line.
    noise_update_rule : int
        0 selects ``noise * term1 / (noise * term1 + (1 - noise) * term2)``;
        any other value selects ``term1 / (term1 + term2)``.
    tol : float
        Stop when ``|NL_prev - NL| / conv_check_interval < tol`` at a check.
    conv_check_interval : int
        Convergence is tested when ``n_iter > 3`` and
        ``n_iter % conv_check_interval == 0``.

    Returns
    -------
    G : numpy.ndarray
        Core tensor of shape ``rnk``, normalized to sum to 1.
    A : list of numpy.ndarray
        ``A[d]`` has shape ``(tensor_size[d], rnk[d])``; each column is
        normalized to sum to 1.
    noise : float
        Noise weight after the last iteration (0 when ``learn_noise`` is False).
    """
    assert len(rnk) == T.tensor_dim, "rank need to be D-th dim vector"
    
    T.normalize()
    if verbose:
        print("normalized done")
    
    tensor_size = T.tensor_size
    tensor_dim  = T.tensor_dim
    N = T.nnz
    R1R2R3 = list( range(Rd) for Rd in rnk )
    # Every rank multi-index (r_1, ..., r_D), in itertools.product order.
    rank_tuples = list( product( *R1R2R3 ) )
    # Number of cells of the full tensor; the noise mass is spread over all of them.
    AbsOmegaI = math.prod(tensor_size)
    # Nonzero coordinates as an (nnz, D) index array so that each mode's
    # indices can be gathered at once. Assumes T.coords[n][d] is the d-th
    # index of the n-th nonzero (the layout a per-entry loop would rely on).
    coords = np.asarray(T.coords, dtype=np.intp).reshape(N, tensor_dim)
    
    # Initialization. Random draws happen in a fixed order (A, G, P, Q, noise);
    # keep it so that seeded runs stay reproducible.
    A = [ np.random.rand( tensor_size[d], rnk[d] ) for d in range(tensor_dim) ] # Dense matrices
    G = np.random.rand( *rnk )
    P = sp_tensor.Sp_tensor( T.coords, np.random.rand(N), tensor_size, normalize=False )
    P.values = sparse_Tucker_from_GA_values(G, A, T.coords)
    
    Q = { r1r2r3 : sp_tensor.Sp_tensor(T.coords, np.random.rand(N), tensor_size, check_empty=False)
          for r1r2r3 in rank_tuples }
    M = { r1r2r3 : sp_tensor.Sp_tensor(T.coords, Q[r1r2r3].values * T.values / P.values, tensor_size, check_empty=False)
          for r1r2r3 in rank_tuples }
    
    # Per rank tuple and per mode: the mode-wise sums of M[r] used in the A update.
    sumsM_results = { r1r2r3 : { d : [] for d in range(tensor_dim) } for r1r2r3 in rank_tuples }
    if verbose:
        print("initialized done")
    
    prev_error_nl = np.inf
    prev_error_nl_for_conv = 1.0e+10
    noise = np.random.rand(1)[0]
    for n_iter in range(max_iter+1):
        
        # update Q:  Q[r](n) = G[r] * prod_d A[d][i_d(n), r_d]
        # With dense G and A, Q could be dense, but only its values on
        # T.coords are needed, so Q is kept as a sparse tensor. The factor
        # product is accumulated mode by mode in the same order as math.prod.
        for r1r2r3 in rank_tuples:
            factor_prod = np.ones(N)
            for d in range(tensor_dim):
                factor_prod = factor_prod * A[d][coords[:, d], r1r2r3[d]]
            Q[r1r2r3].values[:] = G[r1r2r3] * factor_prod

        # update noise
        if learn_noise:
            term1 = 1.0 / AbsOmegaI * np.sum( T.values / P.values ) 
            term2 = np.sum(sum( M[r1r2r3].values for r1r2r3 in rank_tuples ))
            if noise_update_rule == 0:
                noise = noise * term1 / (noise * term1 + (1-noise) * term2)
            else:
                noise = term1 / (term1 + term2)
        else:
            noise = 0
        
        # update P: mixture of the Tucker model and uniform noise.
        # When noise == 0 this equals
        #   P.values = sparse_Tucker_from_GA_values(G, A, T.coords)
        P.values = (1-noise) * sum( Q[r1r2r3].values for r1r2r3 in rank_tuples ) + noise / AbsOmegaI
        
        # update M:  M[r](n) = Q[r](n) * T(n) / P(n)
        for r1r2r3 in rank_tuples:
            M[r1r2r3].values = Q[r1r2r3].values * T.values / P.values
            
        # update G, then normalize it to sum to 1
        for r1r2r3 in rank_tuples:
            G[r1r2r3] = np.sum(M[r1r2r3].values)
        G /= np.sum(G)
            
        # update A
        # A[d] is a dense matrix indexed as A[d][i_d, r_d], where
        #   d   in range(tensor_dim)       is the tensor mode,
        #   r_d in range(rnk[d])           is the d-th rank index,
        #   i_d in range(tensor_size[d])   is the d-th tensor index.
        for r1r2r3 in rank_tuples:
            tmp_results = us.reduce_sum_each_dim(M[r1r2r3].coords, M[r1r2r3].values, tensor_dim, sort=True)
            for d in range(tensor_dim):
                # NOTE(review): presumably the sums of M[r] per index of mode d.
                # The assignment to A[d][:, rd] below needs one entry for every
                # index 0..tensor_size[d]-1; confirm this still holds when a
                # slice of T contains no nonzeros.
                sumsM_results[r1r2r3][d] = tmp_results[d][1]
                
        for d in range(tensor_dim):
            for rd in range(rnk[d]):
                # presumably the rank tuples whose d-th entry is rd (see utils)
                indices_rnk = utils.get_rnk_indices_for_sum(d, rd, rnk)
                A[d][:,rd]  = sum( sumsM_results[r1r2r3][d] for r1r2r3 in product(*indices_rnk) )
                
            # normalize each column of A[d] to sum to 1
            # NOTE(review): a column that sums to 0 becomes NaN here.
            for rd in range(rnk[d]):
                A[d][:,rd] /= np.sum( A[d][:,rd] )
        
        # To check that the normalization is satisfied:
        # print( sparse_Tucker_total_sum(G,A) )
        
        # P was computed before this iteration's G/A update, so the NL error
        # below refers to the parameters from the previous iteration.
        if verbose and n_iter > 0:
            if n_iter % verbose_interval == 0:
                # Since both P and T are normalized, 
                # NL is also monotonically decreasing.
                nl_error = utils.NL(T.values, P.values) 
                print( n_iter, noise, nl_error )
                if prev_error_nl < nl_error:
                    print("NL error is not monotonically decreasing...")
                prev_error_nl = nl_error
                
        if n_iter > 3 and n_iter % conv_check_interval == 0:
            nl_error = utils.NL(T.values, P.values) 
            res = abs( prev_error_nl_for_conv - nl_error ) / conv_check_interval  
            if res < tol:
                break
            else:
                prev_error_nl_for_conv = nl_error
                
    return G, A, noise

def sparse_Tucker_from_GA_with_noise(G, A, noise, idxs):
    """Evaluate the noisy Tucker model at each index in ``idxs``.

    Each value is ``(1 - noise) * sparse_Tucker_from_GA(G, A, idx) + noise / n_cells``,
    where ``n_cells`` is the product of the factor row counts
    ``np.shape(A[m])[0]`` over the ``np.ndim(G)`` modes.

    Returns
    -------
    numpy.ndarray
        One value per entry of ``idxs``, in the same order.
    """
    n_modes = np.ndim(G)
    dims = np.array([np.shape(A[m])[0] for m in range(n_modes)], dtype=np.float64)
    # The cell count is formed as a float64 product of the mode sizes.
    n_cells = math.prod(dims)
    clean = np.array([sparse_Tucker_from_GA(G, A, idx) for idx in idxs])
    return (1 - noise) * clean + noise / n_cells

def sparse_Tucker_from_GA_values(G, A, idxs):
    """Return a list with the Tucker model value at each index in ``idxs``.

    Each entry is ``sparse_Tucker_from_GA(G, A, idx)``; the output order
    follows ``idxs``.
    """
    return list(map(lambda idx: sparse_Tucker_from_GA(G, A, idx), idxs))

def sparse_Tucker_from_GA(G, A, idx):
    """Evaluate the Tucker model at a single cell.

    Computes ``sum_r G[r] * prod_d A[d][idx[d], r_d]``, where ``r`` ranges over
    all multi-indices of ``G.shape``.

    Parameters
    ----------
    G : numpy.ndarray
        Core tensor; its shape ``(R_1, ..., R_D)`` is the Tucker rank.
    A : sequence of numpy.ndarray
        Factor matrices indexed as ``A[d][i_d, r_d]``. Only the first
        ``G.shape[d]`` columns of ``A[d]`` are used.
    idx : sequence of int
        Cell index ``(i_1, ..., i_D)``.

    Returns
    -------
    numpy.float64
        The model value at ``idx``.
    """
    rnk = G.shape
    # Outer product of the selected factor rows, built mode by mode so that
    # weights[r] == (A[0][i_0, r_0] * A[1][i_1, r_1]) * ... in the same order
    # as an element-wise product accumulated over d.
    weights = np.ones(())
    for d in range(len(rnk)):
        weights = np.multiply.outer(weights, A[d][idx[d], :rnk[d]])
    return np.sum(G * weights)

def sparse_Tucker_to_dense(G, A):
    """Reconstruct the full dense tensor ``G x_1 A[0] x_2 ... x_D A[D-1]``.

    Parameters
    ----------
    G : numpy.ndarray
        Core tensor of shape ``(R_1, ..., R_D)``.
    A : sequence of numpy.ndarray
        Factor matrices; ``A[d]`` has ``I_d`` rows, and only its first
        ``R_d`` columns are used.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(I_1, ..., I_D)`` whose entry at ``idx``
        equals ``sum_r G[r] * prod_d A[d][idx[d], r_d]``.
    """
    rnk = G.shape
    dense_Tucker = np.asarray(G, dtype=np.float64)
    for d in range(len(rnk)):
        factor = np.asarray(A[d])[:, :rnk[d]]
        # Contract the rank axis d of the running tensor with the factor's
        # rank axis; tensordot puts the new I_d axis first, so move it back
        # to position d.
        dense_Tucker = np.moveaxis(
            np.tensordot(factor, dense_Tucker, axes=([1], [d])), 0, d
        )
    return dense_Tucker

def sparse_Tucker_total_sum(G, A):
    """Return the sum of the Tucker model over every cell of the full tensor.

    Summing ``G[r] * prod_d A[d][i_d, r_d]`` over all cells factorizes as
    ``sum_r G[r] * prod_d colsum(A[d])[r_d]``, so the dense tensor never has
    to be enumerated. With ``G`` summing to 1 and every column of each
    ``A[d]`` summing to 1, the result is 1.

    Parameters
    ----------
    G : numpy.ndarray
        Core tensor of shape ``(R_1, ..., R_D)``.
    A : sequence of numpy.ndarray
        Factor matrices; all rows and the first ``R_d`` columns of ``A[d]``
        are used.

    Returns
    -------
    numpy.float64
        The total sum.
    """
    rnk = G.shape
    weights = np.ones(())
    for d in range(len(rnk)):
        col_sums = np.sum(np.asarray(A[d])[:, :rnk[d]], axis=0)
        weights = np.multiply.outer(weights, col_sums)
    return np.sum(G * weights)
