import numpy as np
import TLSM_method as Tm
from Hamming_error import hungarian_hamming_error
from spectral_aggregation import mean_adj
import rpy2.robjects as robjects
import rpy2.robjects.numpy2ri
import tensorly as tl
from hosvd_tucker import hosvd_tucker


# Run tensorly on its NumPy backend.
tl.set_backend('numpy')
# Turn on rpy2's NumPy conversion before any arrays are handed to R.
rpy2.robjects.numpy2ri.activate()
# Load the R helper scripts (paths are relative to the working directory).
# NOTE(review): presumably these define GetCluster and speck, which are called
# through robjects.r later in this script -- confirm.
robjects.r.source('LSM.r')
robjects.r.source('Codes_Spectral_Matrix.r')

# Layer names; a layer's position in this list is its index along the third
# axis of the adjacency tensor.
Layer = ['lunch', 'facebook', 'coauthor', 'leisure', 'work']


# Raw edge-list lines (line terminators are kept).
with open('aucs_edgelist.txt') as edge_file:
    lines = list(edge_file)
# Raw node-list lines (line terminators are kept).
with open('aucs_nodelist.txt') as node_file:
    lines_node = list(node_file)

# ---------------- build the multilayer adjacency tensor ----------------
# Each edge-list line is "<node_a>,<node_b>,<layer>".
U_1 = []  # first endpoint of every edge
U_2 = []  # second endpoint of every edge
R = []    # layer name of every edge
for line in lines:
    if not line.strip():
        # Skip blank lines (e.g. an empty line at the end of the file).
        continue
    fields = line.split(',')
    U_1.append(fields[0])
    U_2.append(fields[1])
    R.append(fields[2].rstrip())

# Sorted list of distinct node names; a node's position in this list is its
# row/column index in A.
Uni_User = np.unique(U_1 + U_2).tolist()
# Name -> index lookup so each edge costs O(1) instead of a linear list scan.
node_index = {name: idx for idx, name in enumerate(Uni_User)}

# A[i, j, l] == 1 iff nodes i and j are linked in layer l. The depth follows
# len(Layer) so the tensor stays in sync with the layer list.
A = np.zeros([len(Uni_User), len(Uni_User), len(Layer)])
for node_a, node_b, layer_name in zip(U_1, U_2, R):
    row = node_index[node_a]
    col = node_index[node_b]
    # Layer.index raises ValueError on a layer name that is not in Layer.
    depth = Layer.index(layer_name)
    # Edges are undirected, so both (row, col) and (col, row) are set.
    A[row, col, depth] = 1
    A[col, row, depth] = 1

# ---------------- ground-truth community labels ----------------
# Each node-list line is "<node>,<group>..."; only the first two characters
# of the second field are kept as the group tag.
Users = []
Group = []
for line in lines_node:
    if not line.strip():
        # Skip blank lines (e.g. an empty line at the end of the file).
        continue
    fields = line.split(',')
    Users.append(fields[0])
    Group.append(fields[1][0:2])

# Node name -> group tag. setdefault keeps the first occurrence of a name,
# which is what a list.index lookup would have returned.
group_of = {}
for user, tag in zip(Users, Group):
    group_of.setdefault(user, tag)

# Integer code for every group tag. 'G8' and 'NA' deliberately share code 0.
group_code = {
    'G1': 1,
    'G2': 2,
    'G3': 3,
    'G4': 4,
    'G5': 5,
    'G6': 6,
    'G7': 7,
    'G8': 0,
    'NA': 0,
}

# Encode labels in the same node order as Uni_User (the axis order of A).
label_codes = []
for u in Uni_User:
    if u not in group_of:
        raise ValueError('node %r is missing from the node list' % (u,))
    tag = group_of[u]
    if tag not in group_code:
        raise ValueError('unknown group tag %r for node %r' % (tag, u))
    label_codes.append(group_code[tag])
Label = np.array(label_codes, dtype=int)
"""
Remark: there is only one node, node 19, whose label is 'G8'. Degree of node 19 is 2. 
The eigen-gap plot shows that there are 8 communities, so we classifies node 19 to the outlier community.
"""

# n: size of the first axis of A (nodes); M: size of the third axis (layers).
n, M = A.shape[0], A.shape[2]

# Fix NumPy's global seed before any method runs.
# NOTE(review): presumably the methods below draw from NumPy's global RNG -- confirm.
np.random.seed(1)

# --------------------- self ---------------------
# Build the TLSM object on the adjacency tensor.
# NOTE(review): the second argument (8) is presumably the number of
# communities, matching the 8 passed to the other methods -- confirm.
self = Tm.TLSM(A, 8)
# NOTE(review): Num_ite / lambda_n look like an iteration cap and a tuning
# parameter of TLSM -- confirm against TLSM_method.
self.Num_ite = 1000
self.lambda_n = 0.001
# training() returns the labels that are scored against the ground truth below.
Label_YM = self.training()
# hungarian_hamming_error takes a single list [estimated labels, true labels].
ZYM_error = hungarian_hamming_error([Label_YM, Label])
print(ZYM_error)

# ---------------------- mean-adj ---------------------------
# mean_adj takes a single list [adjacency tensor, 8]; its output is scored
# against the ground-truth labels.
psi_hat_meanadj = mean_adj([A, 8])
meanadj_error = hungarian_hamming_error([psi_hat_meanadj, Label])
print(meanadj_error)


# ----------------------- Tucker ------------------------
# hosvd_tucker uses the same [adjacency tensor, 8] calling convention.
psi_hat_tucker = hosvd_tucker([A, 8])
tucker_error = hungarian_hamming_error([psi_hat_tucker, Label])
print(tucker_error)

# ----------------------- LSE ------------------------------
# Bring the layer axis to the front: A_lei[m, i, j] == A[i, j, m], stored as
# a fresh C-ordered array of shape (M, n, n).
A_lei = A.transpose(2, 0, 1).copy()
# GetCluster is invoked through R; its output is shifted down by one.
# NOTE(review): presumably because R returns 1-based cluster ids -- confirm.
lse_clusters_r = robjects.r.GetCluster(A_lei, 8)
psi_hat_lei = np.array(lse_clusters_r) - 1
lse_error = hungarian_hamming_error([psi_hat_lei, Label])
print(lse_error)

# ---------------------- SPECK -------------------------------
# One adjacency slice per layer, in layer order.
X = [A[:, :, layer] for layer in range(M)]
# speck is invoked through R; its output is shifted down by one.
# NOTE(review): presumably because R returns 1-based cluster ids -- confirm.
speck_clusters_r = robjects.r.speck(X, n, 8)
psi_hat_speck = np.array(speck_clusters_r) - 1
speck_error = hungarian_hamming_error([psi_hat_speck, Label])
print(speck_error)

print("The number of missing classifying node of self, LSE, MASE, HOSVD-Tucker, and SPECK are:")
print(ZYM_error * n, lse_error*n, meanadj_error*n, tucker_error*n, speck_error*n, 'respectively')


"""
Note: speck is not stable, the average "speck_error * n" over 100 independent experiment is 17.81.
lse is not stable, the average "lse_error *n" over 100 independent experiment is 21.
"""
