import numpy as np
import math
from functools import reduce
import os
import sys
import time

import importlib

"""
utils
"""

def axis_skip_dl(N, d, l):
    """Return every axis of an N-way array except ``d`` and ``l``, in order.

    The tuple can be passed as ``axis=`` to ``np.sum`` to collapse an N-way
    array down to its (d, l) plane.
    """
    excluded = (d, l)
    return tuple(axis for axis in range(N) if axis not in excluded)


"""
Marginalizations
"""
def get_X_from_samples(coords, values, I, small_value=1.0e-8):
    """Build normalized pairwise (second-order) marginals from sparse samples.

    Parameters
    ----------
    coords : sequence of integer index tuples, one per sample;
        ``coords[s][d]`` is the index of sample ``s`` along mode ``d``.
    values : sample weights, aligned with ``coords``.
    I : tensor size, one entry per mode.
    small_value : written into cells that received no mass, before
        normalization, so later ``log`` / division steps never see an exact 0.

    Returns
    -------
    dict
        ``X[d, l]`` is an ``(I[d], I[l])`` array summing to 1 for ``d < l``;
        ``X[d, l]`` is ``np.nan`` for ``d >= l``.
    """
    # Take the mode count from I rather than coords[0], so an empty sample
    # list does not raise IndexError.
    N = len(I)
    coords = np.asarray(coords)
    values = np.asarray(values, dtype=float)
    if coords.size == 0:
        # No samples: keep a 2-D (0, N) integer shape so column slicing works.
        coords = np.zeros((0, N), dtype=np.intp)

    X = {}
    for d in range(N):
        for l in range(N):
            if d < l:
                Xdl = np.zeros((I[d], I[l]))
                # np.add.at accumulates repeated coordinates (plain fancy-index
                # += would keep only one of them), in sample order.
                np.add.at(Xdl, (coords[:, d], coords[:, l]), values)
                # avoid log(0) later on
                Xdl[Xdl == 0.0] = small_value
                X[d, l] = Xdl / np.sum(Xdl)
            else:
                X[d, l] = np.nan
    return X

def get_second_margnalizedX(T):
    """Return the normalized pairwise (second-order) marginals of a dense tensor.

    For each pair ``d < l`` the tensor is summed over every other axis and the
    resulting ``(I[d], I[l])`` matrix is divided by its total. Keys with
    ``d >= l`` hold ``np.nan`` placeholders.
    """
    shape = np.shape(T)
    n_modes = len(shape)

    X = {}
    for d in range(n_modes):
        for l in range(n_modes):
            if d >= l:
                X[d, l] = np.nan
                continue
            others = tuple(ax for ax in range(n_modes) if ax != d and ax != l)
            pair = np.sum(T, axis=others)
            X[d, l] = pair / np.sum(pair)
    return X

def get_random_margnalizedX(I=None):
    """Return random pairwise "marginals" laid out like get_second_margnalizedX.

    Parameters
    ----------
    I : tensor size, one entry per mode. It is optional only for backward
        compatibility with callers that set a module-level ``I``. If neither
        is available, a TypeError is raised.

    Returns
    -------
    dict
        ``X[d, l]`` is a uniform-random ``(I[d], I[l])`` array normalized to
        sum to 1 for ``d < l``; ``X[d, l]`` is ``np.nan`` for ``d >= l``.
    """
    if I is None:
        # The zero-argument form read module-level names that this module
        # never defines; honour an injected ``I`` if one exists.
        I = globals().get("I")
        if I is None:
            raise TypeError("get_random_margnalizedX() requires the tensor size I")
    N = len(I)

    X = {}
    for d in range(N):
        for l in range(N):
            if d < l:
                X[d, l] = np.random.rand(I[d], I[l])
                X[d, l] = X[d, l] / np.sum(X[d, l])
            else:
                X[d, l] = np.nan
    return X

"""
Reconstruction and evaluation
"""

def reconst_CPD_lamb_A(lam, As):
    """Build the dense low-CP-rank tensor from weights and factor matrices.

    ``P[i_0, ..., i_{N-1}] = sum_r lam[r] * prod_d As[d][i_d, r]``, where
    ``As[d]`` has shape ``(I[d], R)`` and ``lam`` holds the R weights.
    """
    n_modes = len(As)
    shape = [np.shape(As[m])[0] for m in range(n_modes)]
    rank = np.shape(As[0])[1]

    P = np.zeros(shape)
    for r in range(rank):
        # rank-one term: outer product of the r-th column of every factor
        columns = [As[m][:, r] for m in range(n_modes)]
        rank_one = reduce(np.multiply.outer, columns)
        P += lam[r] * rank_one
    return P

def sparse_CPD_from_A_indices(A, lamb, indices):
    """Evaluate the CP model ``sum_r lamb[r] * prod_d A[d][idx[d], r]`` at many indices.

    Parameters
    ----------
    A : mapping/sequence of factor matrices; ``A[d]`` has shape ``(I[d], R)``.
    lamb : the R component weights.
    indices : sequence of integer index tuples (one per point), or an
        ``(S, N)`` integer array.

    Returns
    -------
    np.ndarray
        ``(S,)`` float array of model values, one per index (empty for no
        indices).
    """
    tensor_dim = len(A)
    rnk = np.shape(A[0])[1]
    indices = np.asarray(indices)
    if indices.size == 0:
        return np.zeros(0)

    # prod[n, r] = prod_d A[d][indices[n, d], r], accumulated mode by mode.
    # This replaces a per-point, per-rank Python loop with vectorized gathers.
    prod = np.ones((len(indices), rnk))
    for d in range(tensor_dim):
        prod = prod * np.asarray(A[d])[indices[:, d], :]
    return prod @ np.asarray(lamb, dtype=float)

def sparse_CPD_from_A(A, lamb, idx):
    """Evaluate the CP model at a single index tuple ``idx``.

    Returns ``sum_r lamb[r] * prod_d A[d][idx[d], r]``. Each rank term is
    first stored in a float array, then the terms are summed with Python's
    ``sum``.
    """
    rank = np.shape(A[0])[1]
    n_modes = len(A)
    terms = np.fromiter(
        (lamb[r] * math.prod(A[d][idx[d], r] for d in range(n_modes))
         for r in range(rank)),
        dtype=float,
        count=rank,
    )
    return sum(terms)


"""
Objective functions
"""

def objective2(X, A, lamb):
    """Sum of the generalized-KL misfits over all mode pairs ``d < l``.

    Each term compares ``X[d, l]`` with ``A[d] @ diag(lamb) @ A[l].T`` via
    objective2_d_l.
    """
    n_modes = len(A)
    total = 0.0
    for d in range(n_modes):
        for l in range(d + 1, n_modes):
            # NOTE: if A and lamb are appropriately normalized, each model
            # matrix A[d] @ diag(lamb) @ A[l].T sums to 1.
            total += objective2_d_l(d, l, X, A, lamb)
    return total
    
def objective2_d_l(d, l, X, A, lamb):
    """Generalized KL divergence between ``X[d, l]`` and ``A[d] @ diag(lamb) @ A[l].T``.

    Requires ``d < l``: only upper-triangle keys of ``X`` hold arrays (the
    others are ``np.nan`` placeholders).
    """
    assert d < l, "l shoud be larger than d"
    target = X[d, l]
    model = A[d] @ np.diag(lamb) @ A[l].T
    return KL_divergence(target, model)

def objective1_l(l, X, A, lamb):
    """Sum of the pairwise KL terms that depend on the factor ``A[l]``.

    Mode ``l`` appears in pairs ``(d, l)`` with ``d < l`` and in pairs
    ``(l, d)`` with ``d > l``. Both kinds are included, so this is the part
    of objective2 that changes when ``A[l]`` changes. (Previously only the
    ``d < l`` pairs were counted, so for ``l = 0`` the value was always 0.)
    It is used for the verbose trace in optimize_Al.
    """
    N = len(A)
    cost = 0.0
    for d in range(N):
        if d < l:
            cost += objective2_d_l(d, l, X, A, lamb)
        elif d > l:
            # objective2_d_l expects the smaller mode index first
            cost += objective2_d_l(l, d, X, A, lamb)
    return cost
    
def KL_divergence(P, T):
    """Generalized KL divergence ``D(P || T) = sum P log(P/T) - sum P + sum T``.

    Cells where ``P == 0`` contribute 0 to the log term, following the usual
    ``0 * log 0 = 0`` convention. Evaluating them literally gives
    ``0 * log(0) = nan``, which poisons the whole sum, e.g. for dense
    tensors or marginals with empty cells. ``P`` and ``T`` are broadcast
    against each other for the log term; the two plain sums use the inputs
    as given.
    """
    P = np.asarray(P, dtype=float)
    T = np.asarray(T, dtype=float)
    Pb, Tb = np.broadcast_arrays(P, T)
    support = Pb != 0
    log_term = np.sum(Pb[support] * np.log(Pb[support] / Tb[support]))
    return log_term - np.sum(P) + np.sum(T)

def NLL(P, T):
    """Cross-entropy ``-sum(P * log T)`` of the model values ``T`` weighted by ``P``."""
    log_model = np.log(T)
    weighted = np.multiply(P, log_model)
    return -np.sum(weighted)

"""
Updates of the factor matrices A
"""

def get_V(X, A, lamb):
    """Return, for each pair ``d < l``, ``V[d, l] = (X[d, l] / Q_dl).T @ A[d]``.

    ``Q_dl = A[d] @ diag(lamb) @ A[l].T`` is the model for ``X[d, l]``, and
    each ``V[d, l]`` has shape ``(I[l], F)``. Keys with ``d >= l`` map to
    ``np.nan``.
    """
    weights = np.diag(lamb)
    n_modes = len(A)
    rank = np.shape(A[0])[1]
    sizes = [np.shape(A[m])[0] for m in range(n_modes)]

    V = {}
    for d in range(n_modes):
        for l in range(n_modes):
            if d >= l:
                V[d, l] = np.nan
                continue
            ratio = X[d, l] / (A[d] @ weights @ A[l].T)
            V[d, l] = ratio.T @ A[d]
            assert np.shape(V[d, l]) == (sizes[l], rank)
    return V
    
def get_diff_Al(l, X, A, lamb):
    """Return the (shifted) gradient of the pairwise KL objective w.r.t. ``A[l]``.

    Mode ``l`` enters the objective in two ways:

    * pairs ``(d, l)`` with ``d < l``: it indexes the columns of
      ``Q = A[d] @ diag(lamb) @ A[l].T``, contributing
      ``-lamb[f] * ((X[d, l] / Q).T @ A[d])[i, f]``;
    * pairs ``(l, m)`` with ``l < m``: it indexes the rows of
      ``Q = A[l] @ diag(lamb) @ A[m].T``, contributing
      ``-lamb[f] * ((X[l, m] / Q) @ A[m])[i, f]``.

    Previously only the first kind was summed. That made the gradient for
    ``l = 0`` identically zero, so ``A[0]`` was never updated. The exact
    gradient also has a ``+lamb[f] * sum(A[other][:, f])`` term per pair.
    When the other factors have L1-normalized columns, this term is the same
    for every row of column ``f``, so it cancels in the column
    renormalization done by update_Al and is omitted here.

    Returns
    -------
    np.ndarray
        Array with the shape of ``A[l]``, i.e. ``(I[l], F)``.
    """
    N = len(A)
    D = np.diag(lamb)
    diff_Al = np.zeros(np.shape(A[l]))
    for d in range(N):
        if d < l:
            # X[d, l] is (I[d], I[l]); mode l is the column index.
            Ydl = X[d, l] / (A[d] @ D @ A[l].T)
            diff_Al -= Ydl.T @ A[d]
        elif d > l:
            # X[l, d] is (I[l], I[d]); mode l is the row index.
            Yld = X[l, d] / (A[l] @ D @ A[d].T)
            diff_Al -= Yld @ A[d]
    # scale column f by lamb[f]
    return diff_Al * np.asarray(lamb)[np.newaxis, :]

def update_Al(l, X, A, lamb, alpha):
    """Apply one multiplicative (exponentiated-gradient) step to ``A[l]``.

    ``A[l]`` is multiplied elementwise by ``exp(-alpha * gradient)``. Each
    column is then rescaled to unit L1 norm. The container ``A`` is updated
    in place, and the new ``A[l]`` is also returned.
    """
    gradient = get_diff_Al(l, X, A, lamb)
    A[l] = A[l] * np.exp(-alpha * gradient)
    column_l1 = np.abs(A[l]).sum(axis=0)
    A[l] = A[l] / column_l1[np.newaxis, :]
    return A[l]

def optimize_Al(l, X, A, lamb, alpha, max_iter, verbose=False):
    """Run ``max_iter`` update_Al steps on ``A[l]`` and return the result.

    When ``verbose`` is true, objective1_l is printed after every step.
    """
    for _step in range(max_iter):
        A[l] = update_Al(l, X, A, lamb, alpha)
        if not verbose:
            continue
        print(objective1_l(l, X, A, lamb))
    return A[l]

"""
Updates of the weights lamb
"""
def get_diff_lamb(X, A, lamb):
    """Return the (shifted) gradient of the pairwise KL objective w.r.t. ``lamb``.

    For every pair ``d < l`` with ``Q = A[d] @ diag(lamb) @ A[l].T`` and
    ``Y = X[d, l] / Q``, entry ``f`` accumulates
    ``-(A[d].T @ Y @ A[l])[f, f]``. The exact gradient also has a
    ``+sum(A[d][:, f]) * sum(A[l][:, f])`` term per pair. When the factor
    columns are L1-normalized, this term equals 1 for every ``f`` and cancels
    in the renormalization done by update_lamb.

    Returns
    -------
    np.ndarray
        ``(F,)`` array.
    """
    N = len(A)
    F = np.shape(A[0])[1]
    D = np.diag(lamb)

    diff_lamb = np.zeros(F)
    for d in range(N):
        for l in range(d + 1, N):
            Ydl = X[d, l] / (A[d] @ D @ A[l].T)
            # diag(A[d].T @ Ydl @ A[l]) without forming the full F x F matrix
            diff_lamb -= np.sum(A[d] * (Ydl @ A[l]), axis=0)
    return diff_lamb
    
def update_lamb(X, A, lamb, alpha):
    """Apply one exponentiated-gradient step to ``lamb``, renormalized to sum to 1.

    A new array is returned; the input ``lamb`` is not modified.
    """
    step = np.exp(-alpha * get_diff_lamb(X, A, lamb))
    scaled = lamb * step
    return scaled / np.sum(scaled)

def optimize_lamb(X, A, lamb, alpha, max_iter, verbose=False):
    """Run ``max_iter`` update_lamb steps and return the final weights.

    When ``verbose`` is true, objective2 is printed after every step.
    """
    current = lamb
    for _step in range(max_iter):
        current = update_lamb(X, A, current, alpha)
        if verbose:
            print(objective2(X, A, current))
    return current


"""
Solvers (dense and sparse input)
"""

def CNMFOPT_dense(T, F, alpha=0.1, max_iter_outer=50, max_iter_inner=50, verbose=False, tol=1.0e-5, conv_check_interval=20):
    """Fit a rank-``F`` nonnegative CP model ``(lamb, A)`` to a dense tensor ``T``.

    ``T`` is normalized to sum to 1 and its pairwise marginals are computed.
    The weights ``lamb`` and factor matrices ``A[d]`` are then fitted to those
    marginals by alternating multiplicative updates with step ``alpha``: one
    block of ``max_iter_inner`` steps on ``lamb``, then on each ``A[l]``.

    Every ``conv_check_interval`` outer iterations (t = 0, interval, ...), the
    full reconstruction is compared with ``T`` by the generalized KL
    divergence. The loop stops early, for t > 2, once that loss changes by
    less than ``tol``. Progress is printed only when ``verbose`` is true.

    Initialization draws from the global ``np.random`` state, so seeding it
    beforehand makes runs reproducible.

    Returns
    -------
    lamb : ``(F,)`` weights normalized to sum to 1.
    A : dict mapping mode ``d`` to an ``(I[d], F)`` factor matrix with
        L1-normalized columns.
    """
    # We assume the input need to be normalized
    T = T / np.sum(T)

    # Tensor size (size of sample space)
    I = np.shape(T)
    # Tensor dim
    N = len(I)

    # second marginalized X
    X = get_second_margnalizedX(T)

    # Initialize CP factors
    A = {}
    for d in range(N):
        A[d] = np.random.rand(I[d], F)
        # normalize each column: np.sum(A[d], axis=0) should be (1,...,1)
        each_norm = np.linalg.norm(A[d], ord=1, axis=0)  # ord 1 means L1 norm
        A[d] = A[d] / each_norm[np.newaxis, :]

    # Initialize lambs
    lamb = np.random.rand(F)
    lamb = lamb / np.sum(lamb)

    prev_loss = 1.0e10
    for t in range(max_iter_outer):
        lamb = optimize_lamb(X, A, lamb, alpha, max_iter_inner)
        for l in range(N):
            A[l] = optimize_Al(l, X, A, lamb, alpha, max_iter_inner)

        if t % conv_check_interval == 0:
            P = reconst_CPD_lamb_A(lamb, A)
            loss = KL_divergence(T, P)
            if verbose:
                print(t, loss)

            if t > 2 and abs(prev_loss - loss) < tol:
                # Previously printed unconditionally; respect the verbose flag.
                if verbose:
                    print(prev_loss, loss)
                break
            prev_loss = loss

    return lamb, A

def CNMFOPT_sparse(coords, values, I, F, alpha=0.1, max_iter_outer=50, max_iter_inner=50, verbose=False, tol=1.0e-5, conv_check_interval=20, seed=None):
    """Fit a rank-``F`` nonnegative CP model to a tensor given as sparse samples.

    Parameters
    ----------
    coords : sequence of integer index tuples, one per observed entry.
    values : weights of the observed entries, aligned with ``coords``.
    I : tensor size (size of the sample space), one entry per mode.
    F : CP rank.
    alpha, max_iter_outer, max_iter_inner, verbose, tol, conv_check_interval :
        Same roles as in CNMFOPT_dense. The monitored loss here is the NLL of
        the normalized ``values`` under the model evaluated at ``coords``.
    seed : optional seed for the random initialization. When omitted, it is
        derived from the process id and the current time, so different
        processes get different seeds.

    The initialization uses a private ``np.random.RandomState``, so the global
    ``np.random`` state is left untouched.

    Returns
    -------
    lamb : ``(F,)`` weights normalized to sum to 1.
    A : dict mapping mode ``d`` to an ``(I[d], F)`` factor matrix with
        L1-normalized columns.
    """
    # Tensor dim
    N = len(I)

    # second marginalized X
    X = get_X_from_samples(coords, values, I)

    if seed is None:
        # pid + (milliseconds mod 1000); note % binds tighter than +.
        seed = os.getpid() + int(time.time() * 1000) % 1000
    # A private generator: same stream as np.random.seed(seed) + np.random.rand,
    # without clobbering the caller's global RNG state.
    rng = np.random.RandomState(seed)

    # Initialize CP factors
    A = {}
    for d in range(N):
        A[d] = rng.rand(I[d], F)
        # normalize each column: np.sum(A[d], axis=0) should be (1,...,1)
        each_norm = np.linalg.norm(A[d], ord=1, axis=0)  # ord 1 means L1 norm
        A[d] = A[d] / each_norm[np.newaxis, :]

    # Initialize lambs
    lamb = rng.rand(F)
    lamb = lamb / np.sum(lamb)

    # Empirical distribution over the observed entries (loop-invariant).
    values = np.asarray(values, dtype=float)
    target = values / np.sum(values)

    prev_loss = 1.0e10
    for t in range(max_iter_outer):
        lamb = optimize_lamb(X, A, lamb, alpha, max_iter_inner, verbose=False)
        for l in range(N):
            A[l] = optimize_Al(l, X, A, lamb, alpha, max_iter_inner)

        if t % conv_check_interval == 0:
            loss = NLL(target, sparse_CPD_from_A_indices(A, lamb, coords))
            if verbose:
                print(t, loss)

            if t > 2 and abs(prev_loss - loss) < tol:
                break
            prev_loss = loss

    return lamb, A
