import numpy as np
import math
from itertools import product
import sp_tensor

def re_order(T_train, T_valid, T_test):
    """Permute the modes of the train/valid/test tensors into ``get_MI`` order.

    The mode order is derived from ``T_train`` only, and the same permutation
    is applied to every split, so the three tensors stay index-compatible.

    Args:
        T_train: sparse tensor used to decide the mode order.
        T_valid: validation tensor with the same modes as ``T_train``.
        T_test: test tensor with the same modes as ``T_train``.

    Returns:
        Tuple ``(train, valid, test)`` of new ``sp_tensor.Sp_tensor`` objects.
        In each, the coordinate columns and the tensor size follow the chosen
        order. The validation and test tensors are built with
        ``check_empty=False``; the training tensor uses the constructor default.
    """
    mode_order, _ = get_MI(T_train)
    new_size = tuple(T_train.tensor_size[mode] for mode in mode_order)

    def permuted(tensor, **ctor_kwargs):
        # Reorder coordinate columns; values are carried over unchanged.
        return sp_tensor.Sp_tensor(
            tensor.coords[:, mode_order], tensor.values, new_size, **ctor_kwargs
        )

    return (
        permuted(T_train),
        permuted(T_valid, check_empty=False),
        permuted(T_test, check_empty=False),
    )

def kl_div(p, q, small_val=1.0e-10):
    """Return the KL divergence ``sum p * (log p - log q)`` of two 2-D arrays.

    Only cells with ``p > 0`` contribute, since ``0 * log 0`` is treated as 0.
    ``q`` is clamped from below at ``small_val`` so that ``log q`` stays finite.

    Args:
        p: 2-D array (unpacked via ``np.shape`` into rows and columns).
        q: array indexable with the same ``[row, col]`` positions as ``p``.
        small_val: lower bound applied to ``q`` before taking its log.

    Returns:
        The accumulated divergence. This is ``0`` (an int) when no cell of
        ``p`` is positive.
    """
    n_rows, n_cols = np.shape(p)
    total = 0
    # Row-major traversal, identical to a nested row/column loop.
    for row, col in product(range(n_rows), range(n_cols)):
        p_val = p[row, col]
        if not p_val > 0:
            continue
        q_val = max(q[row, col], small_val)
        total += p_val * math.log(p_val) - p_val * math.log(q_val)
    return total
    
def _mode_distributions(T):
    """Return per-mode marginal and pairwise joint distributions of ``T``'s values.

    ``p[d][x]`` is the fraction of the summed tensor value whose mode-``d``
    index is ``x``. ``p_joint[d, e][x, y]`` is the fraction whose mode-``d``
    index is ``x`` and whose mode-``e`` index is ``y``. Joints are built only
    for ``e <= d``; the remaining pairs are just transposes. Assumes the
    values are non-negative so these fractions are probabilities (TODO confirm).
    """
    J = T.tensor_size
    D = T.tensor_dim
    v = {d: np.zeros(J[d]) for d in range(D)}
    w = {(d, e): np.zeros((J[d], J[e])) for d in range(D) for e in range(d + 1)}
    for coord in T.coords:
        # Look the value up once per entry rather than once per mode (pair).
        value = T.coord_to_value[tuple(coord)]
        for d in range(D):
            v[d][coord[d]] += value
            for e in range(d + 1):
                w[d, e][coord[d], coord[e]] += value
    p = {d: v[d] / np.sum(v[d]) for d in range(D)}
    p_joint = {key: joint / np.sum(joint) for key, joint in w.items()}
    return p, p_joint


def _nmi_matrix(p, p_joint, D):
    """Return the symmetric ``(D, D)`` normalized mutual information matrix.

    ``MI[d, e]`` is ``kl_div`` between the joint distribution of modes ``d``
    and ``e`` and the product of their marginals. ``MI[d, d]`` is therefore
    the entropy of mode ``d``, up to ``kl_div``'s clamping of tiny ``q``.
    NMI divides ``MI[d, e]`` by the geometric mean of the two entropies. The
    diagonal stays 0, and any pair involving a zero-entropy mode gets NMI 0.
    """
    MI = np.zeros((D, D))
    for d in range(D):
        for e in range(d + 1):
            MI[d, e] = kl_div(p_joint[d, e], np.outer(p[d], p[e]))
    NMI = np.zeros((D, D))
    for d in range(D):
        # Every pair strictly below the diagonal, including e = d - 1. The
        # value is mirrored so any row can be scanned when growing the chain.
        for e in range(d):
            norm = MI[d, d] * MI[e, e]
            if norm > 0:  # a single-valued mode has zero entropy; avoid 0/0
                NMI[d, e] = NMI[e, d] = MI[d, e] / math.sqrt(norm)
    return NMI


def _most_related(NMI, anchor, candidates):
    """Return the candidate with the largest ``NMI[anchor, k]`` (first one on ties)."""
    return max(candidates, key=lambda k: NMI[anchor, k])


def _greedy_chain(NMI):
    """Order all modes as a chain grown from the highest-NMI pair of distinct modes.

    Each round first extends the right end with the unused mode most related
    to the current right end. It then extends the left end likewise. Every
    mode index appears exactly once in the result.
    """
    D = NMI.shape[0]
    if D < 2:
        return list(range(D))
    # Exclude the diagonal so the seed pair is always two different modes.
    masked = NMI.copy()
    np.fill_diagonal(masked, -np.inf)
    i, j = (int(k) for k in np.unravel_index(np.argmax(masked), masked.shape))
    chosen = [i, j]
    remaining = [k for k in range(D) if k != i and k != j]
    while remaining:
        j = _most_related(NMI, j, remaining)
        chosen.append(j)
        remaining.remove(j)
        if not remaining:
            break
        i = _most_related(NMI, i, remaining)
        chosen.insert(0, i)
        remaining.remove(i)
    return chosen


def get_MI(T):
    """Order the modes of sparse tensor ``T`` by normalized mutual information.

    Each mode is treated as a discrete variable distributed by the normalized
    tensor values along it. Pairwise NMI is computed between modes. A chain
    of modes is then grown greedily from the most strongly related pair.

    Args:
        T: sparse tensor exposing ``coords`` (rows of integer indices),
            ``tensor_size`` (per-mode sizes), ``tensor_dim`` (number of modes)
            and ``coord_to_value`` (indexed by a coordinate tuple).

    Returns:
        ``(chosen, NMI)``. ``chosen`` is a list holding each mode index
        ``0..D-1`` exactly once, in chain order. ``NMI`` is the symmetric
        ``(D, D)`` NMI matrix with a zero diagonal.
    """
    D = T.tensor_dim
    p, p_joint = _mode_distributions(T)
    NMI = _nmi_matrix(p, p_joint, D)
    chosen = _greedy_chain(NMI)
    return chosen, NMI