import numpy as np
from itertools import product

def get_low_train_rank_tensor(R, J):
    """Generate a random dense tensor from randomly drawn tensor-train cores.

    One core of shape ``(r_prev, J[d], r_next)`` is drawn per mode with
    ``np.random.rand``; the boundary ranks on the far left and far right are
    fixed to 1. The cores are then expanded by ``train_from_cores``.

    Args:
        R: Sequence of the ``len(J) - 1`` internal bond sizes. May be empty
            when ``J`` has a single mode.
        J: Sequence of mode sizes of the tensor to generate.

    Returns:
        The value returned by ``train_from_cores`` for the generated cores.

    Raises:
        AssertionError: If ``len(R) + 1 != len(J)``.
    """
    D = len(J)
    assert len(R) + 1 == D, "rank is invalid"

    # Pad the bond sizes with the boundary ranks so that every core, including
    # the first and last, follows the same shape rule. This also makes the
    # single-mode case (R empty) work instead of failing on R[0].
    bonds = [1, *R, 1]

    # Cores are drawn in mode order, so the sequence of RNG calls matches a
    # first/middle/last special-cased construction.
    cores = [np.random.rand(bonds[d], J[d], bonds[d + 1]) for d in range(D)]

    return train_from_cores(cores)

def train_from_cores(cores):
    """Expand tensor-train cores into the dense tensor, one entry at a time.

    Each core is indexed as ``core[left_bond, mode_index, right_bond]``. For
    every multi-index the matching slice of each core is contracted from left
    to right, and the resulting single value is stored in the output.

    Args:
        cores: Sequence of 3-D arrays. The chain of slices must collapse to
            exactly one element, otherwise ``.item()`` raises.

    Returns:
        Array whose shape is the list of the cores' middle-axis sizes.
    """
    mode_sizes = [c.shape[1] for c in cores]
    dense = np.zeros(mode_sizes)

    # One row per multi-index of the output tensor.
    n_modes = len(mode_sizes)
    multi_indices = np.indices(mode_sizes).reshape(n_modes, -1).T

    trailing = cores[1:]
    for multi_index in multi_indices:
        # Start from the first core's slice, flattened to a vector.
        vec = cores[0][:, multi_index[0], :].reshape(-1)
        # Absorb the remaining cores one bond at a time.
        for j, core in zip(multi_index[1:], trailing):
            vec = np.tensordot(vec, core[:, j, :], axes=([0], [0]))
        dense[tuple(multi_index)] = vec.item()

    return dense
    
def train_from_cores_idx(cores,idx):
    """Evaluate a single entry of the tensor represented by ``cores``.

    The entry is the product of the core slices selected by ``idx``, taken
    from the first row of the first core to the first column of the last
    core. It is computed as a left-to-right chain of vector-matrix products
    rather than by enumerating every combination of bond indices, so the cost
    grows linearly with the number of cores instead of with the product of
    all bond sizes. A train with a single core is supported.

    Args:
        cores: Sequence of 3-D arrays indexed as
            ``core[left_bond, mode_index, right_bond]``.
        idx: Sequence with one mode index per core.

    Returns:
        The scalar tensor entry at ``idx``.

    Raises:
        IndexError: If ``idx`` has fewer entries than ``cores`` or an index
            is out of range for its core.
        ValueError: If adjacent cores have incompatible bond sizes.
    """
    tensor_dim = len(cores)

    # Row vector over the first bond: only left-boundary index 0 is used.
    vec = cores[0][0, idx[0], :]

    # Each step contracts the current bond with the next core's slice.
    for d in range(1, tensor_dim):
        vec = np.dot(vec, cores[d][:, idx[d], :])

    # Only right-boundary index 0 of the last core contributes.
    return vec[0]

def get_train_R(cores):
    """Build the left-to-right partial contractions of the cores.

    Entry ``d`` of the result is the contraction of cores ``0..d`` and is
    indexed as ``[i_0, i_1, ..., i_d, r_d]`` (mode indices first, open right
    bond last). Entry ``-1`` is the constant array ``[1]``. The last entry,
    once squeezed, is expected to equal the full reconstruction of the train.

    Args:
        cores: Sequence of 3-D arrays indexed as
            ``core[left_bond, mode_index, right_bond]``; only left-boundary
            index 0 of the first core is used.

    Returns:
        Dict mapping ``-1, 0, ..., len(cores) - 1`` to the partial
        contractions described above.
    """
    n_cores = len(cores)

    # Seed with the boundary entry and the first core's left-boundary slice.
    partials = {-1: np.array([1]), 0: cores[0][0, :, :]}

    running = partials[0]
    for d in range(1, n_cores):
        # axes=1 contracts the open right bond with the next core's left bond.
        running = np.tensordot(running, cores[d], axes=1)
        partials[d] = running

    return partials

def get_train_L(cores):
    """Build the right-to-left partial contractions of the cores.

    Entry ``d`` of the result is the contraction of cores ``d+1..D-1`` and is
    indexed as ``[r_d, i_{d+1}, ..., i_{D-1}]`` (open left bond first, then
    mode indices). Entry ``D-1`` is the constant array ``[1]``. Entry ``-1``
    absorbs every core and, once squeezed, is expected to equal the full
    reconstruction of the train.

    Args:
        cores: Sequence of 3-D arrays indexed as
            ``core[left_bond, mode_index, right_bond]``.

    Returns:
        Dict mapping ``len(cores) - 1, ..., 0, -1`` to the partial
        contractions described above.
    """
    n_cores = len(cores)
    partials = {n_cores - 1: np.array([1])}

    # Walk from the last core down to the first. Stopping at k = -1 also
    # produces the full reconstruction; stop at k = 0 if that is not needed.
    for k in range(n_cores - 2, -2, -1):
        # Contract core k+1's right bond with the leading bond of the
        # contraction already built to its right.
        partials[k] = np.tensordot(cores[k + 1], partials[k + 1], axes=[[2], [0]])

    return partials