import rpy2.robjects as robjects
import rpy2.robjects.numpy2ri
import numpy as np
import tensorly as tl
from tensorly.tenalg import multi_mode_dot as mmd
import pickle

# Use plain NumPy arrays as TensorLy's tensor backend.
tl.set_backend('numpy')
# Enable rpy2's NumPy <-> R conversion so arrays can be passed straight to R
# functions and R results come back as array-like objects.
rpy2.robjects.numpy2ri.activate()
# Evaluate LSM.r in the embedded R session (path is relative to the current
# working directory). Presumably this defines GetCluster, which is called
# through robjects.r further down -- confirm against LSM.r.
robjects.r.source('LSM.r')


def save(filename, my_data):
    """Pickle *my_data* into the file at *filename*.

    The file is opened in binary write mode, so an existing file at that
    path is overwritten. Nothing is returned.
    """
    with open(filename, 'wb') as sink:
        pickle.Pickler(sink).dump(my_data)


def load(filename):
    """Return the object unpickled from the file at *filename*.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    that come from a trusted source.
    """
    with open(filename, 'rb') as source:
        return pickle.load(source)


# --- Load the raw data ------------------------------------------------------
# Tensor is indexed as (row node, column node, layer): the sums below reduce
# axes (0, 1) to get one value per layer and axes (1, 2) to get one per row.
Tensor = np.load('wat_edge.npy')
Node_set = np.load('wat_node.npy')
# Manual override of one label -- presumably repairs a name that was garbled
# in the data file; confirm against the source data.
Node_set[58] = "Cote d'Ivoire"
# Binarize: every positive entry becomes 1.
Tensor[Tensor > 0] = 1


# --- Keep the M layers with the largest total entry sum ----------------------
M = 32

layer_deg = Tensor.sum(axis=(0, 1))
# Ascending argsort, reversed, truncated: the indices of the top-M layers.
index = list(layer_deg.argsort()[::-1][:M])
NT = Tensor[:, :, index]

# --- Keep nodes whose average row sum per retained layer exceeds 9 -----------
deg = NT.sum(axis=(1, 2))
ind = np.flatnonzero(deg/M > 9)
# Node_set = Node_set[ind]
# (Node_set is intentionally left unfiltered above, so its entries still
# follow the original node order rather than the rows/columns of NT.)
NT = NT[ind][:, ind, :]


# The string literal below is a disabled alternative to the node filter: it
# thresholds row sums of the full Tensor (divided by 364 -- presumably the
# total number of layers, confirm) and also subsets Node_set. It is a bare
# string expression, so it has no effect at runtime.
"""
deg = Tensor.sum(axis=(1, 2))
ind = np.where(deg/364 > 9)[0]
Node_set = Node_set[ind]
NT = Tensor[ind]
NT = NT[:, ind, :]
"""

# --- Experiment settings ------------------------------------------------------
K = 6                          # number of clusters passed to GetCluster
fraction_of_training = 0.8     # Bernoulli probability that an entry is kept for training
number_of_iter_learn = 1800    # not referenced anywhere else in this script
Repetition = 50                # number of random train/held-out splits

# Number of nodes that survived the filtering step.
n = NT.shape[0]
# Re-arrange NT from (n, n, M) into a layer-first (M, n, n) array: unfold
# along the layer mode, then fold the matrix back with the layer axis first.
A_lei = tl.fold(tl.unfold(NT, 2), 0, (M, n, n))

# Repeated random hold-out evaluation of link prediction.
#
# Each replication:
#   1. draws a symmetric 0/1 training mask B for every layer,
#   2. clusters the nodes with the R function GetCluster on the masked data,
#   3. estimates a block-mean core tensor and the implied probabilities P,
#   4. samples a binary prediction A_hat from P,
#   5. records the mean absolute error over the held-out upper-triangular
#      entries.
link_prediciton_error = []
for i in range(Repetition):
    # Bernoulli mask, symmetrized per layer by mirroring the upper triangle
    # (diagonal included) onto the lower one. np.triu acts on the last two
    # axes of a 3-D array, i.e. on every (n, n) layer slice at once.
    B = np.random.binomial(1, fraction_of_training, (M, n, n))
    B = np.triu(B) + np.swapaxes(np.triu(B, 1), 1, 2)

    # Indicator of the independent random variables in A that were observed:
    # the upper triangle (with diagonal) of the training mask, as floats.
    B0 = np.triu(B).astype(float)

    # GetCluster returns 1-based labels; shift to 0-based and cast to int so
    # they can be used as NumPy indices below.
    psi_hat_lei = (np.array(robjects.r.GetCluster(A_lei * B, K)) - 1).astype(int)
    Z = np.eye(K)[psi_hat_lei]                        # one-hot membership rows

    ind_list = [np.where(psi_hat_lei == k)[0] for k in range(K)]

    # Core[m, k1, k2] is the mean entry of layer m over the (k1, k2) block.
    # Blocks that involve an empty cluster keep their initial value of 0.
    # NOTE(review): the block means use the full A_lei, not only the training
    # entries (A_lei * B) -- confirm that this is intended.
    Core = np.zeros((M, K, K))
    for k1 in range(K):
        for k2 in range(K):
            block_size = len(ind_list[k1]) * len(ind_list[k2])
            if block_size != 0:
                block = A_lei[:, ind_list[k1], :][:, :, ind_list[k2]]
                Core[:, k1, k2] = block.sum(axis=(1, 2)) / block_size

    # Expand the core back to node level and sample a binary prediction.
    P = mmd(Core, [Z]*2, modes=[1, 2])
    A_hat = np.random.binomial(1, P)

    # Held-out indicator: entries of the upper triangle (with diagonal) that
    # were NOT in the training mask. The triangle must be taken per layer,
    # i.e. over the last two axes of the (M, n, n) array; slicing the last
    # axis by layer index would mask the wrong entries.
    B0_tilde = np.triu(1 - B0)

    link_prediciton_error.append((np.abs((A_hat - A_lei) * B0_tilde)).sum() / B0_tilde.sum())

# Report the mean held-out error across replications and its standard error,
# computed as the (population) standard deviation divided by
# sqrt(number of replications). The replication count in the message is taken
# from Repetition so the text cannot drift from the actual setting.
errors = np.array(link_prediciton_error)
print(f"The averaged link prediction error by LSE over {Repetition} independent replications is:",
      errors.mean(), "with standard error is:",
      errors.std()/np.sqrt(Repetition))




