from Data_generation import lgm      # latent graph embedding model
import TLSM_method as Tm
from Hamming_error import *
from multiprocessing import Pool
import time
import numpy as np
from spectral_aggregation import mean_adj
import rpy2.robjects as robjects
import rpy2.robjects.numpy2ri
import tensorly as tl
from hosvd_tucker import hosvd_tucker
# Use NumPy arrays as the tensorly backend for the unfold/fold calls in the
# main block.
tl.set_backend('numpy')
# Turn on rpy2's NumPy <-> R conversion so NumPy arrays can be handed to the
# R functions called through robjects.r.
rpy2.robjects.numpy2ri.activate()
# Load the R scripts into the embedded R session. GetCluster and speck, which
# the main block calls via robjects.r, are presumably defined in these files
# -- TODO confirm which script provides which function.
robjects.r.source('LSM.r')
robjects.r.source('Codes_Spectral_Matrix.r')


# Simulation parameters forwarded to lgm() as
# [n, M, K, R, sigma, s_n, scalar, seed]. K is also the second argument given
# to every clustering method compared in the main block (presumably the number
# of communities -- TODO confirm); the exact meaning of the remaining values is
# defined by Data_generation.lgm.
K = 4
R = 4
sigma = 1.5
s_n = 0.15
scalar = 2.5
# Third argument passed to TLSM.cross_validation in multi_processing().
N = 2

# Candidate regularisation values handed to TLSM.cross_validation.
# NOTE: the full grid previously used was 10 ** np.linspace(-8, -2, 7); only a
# single candidate is kept here.
Lambda = [1e-8]

# Training fraction for the tuning fit, iteration count passed to
# cross_validation, and iteration count (Num_ite) for the final fit.
fraction_training = 0.9
iter_tune = 0
iter_learn = 300

# `batch` is both the worker-pool size and the number of replications run per
# round; `number_of_batch` is the number of rounds.
batch = 10
number_of_batch = 5
# The main block draws seed[i * batch + nr] for every round/replication and
# collects one error per replication, so the total must be exactly
# batch * number_of_batch. Deriving it keeps the seed range and the sample
# count consistent when either knob is changed.
Repetition = batch * number_of_batch
seed = range(Repetition)

# Values of n and M swept by the main block (each adjacency tensor is refolded
# to shape (M, n, n) there).
nlist = [100]
Mlist = [4]


def multi_processing(parameter):
    """Pick a lambda by cross-validation, then fit a TLSM with it.

    The arguments arrive packed in one positional sequence so the function
    can be used directly with ``Pool.map``. In order, ``parameter`` holds:

    0. the data object given to ``Tm.TLSM`` as its first argument,
    1. the second ``Tm.TLSM`` argument (``K`` at the call site),
    2. the list of candidate lambda values,
    3. the training fraction used for the tuning model,
    4. the iteration count passed to ``cross_validation``,
    5. the iteration count stored as ``Num_ite`` on the final model,
    6. the third ``cross_validation`` argument (``N`` at the call site).

    Returns whatever ``Tm.TLSM.training()`` returns for the final model.
    """
    (adjacency, n_groups, lambda_grid, train_fraction,
     tune_iters, learn_iters, cv_arg) = (parameter[pos] for pos in range(7))

    # First model: built with the training fraction, used only for tuning.
    tuner = Tm.TLSM(adjacency, n_groups, train_fraction)
    chosen_lambda = tuner.cross_validation(lambda_grid, tune_iters, cv_arg)

    # Second model: built without the fraction argument, configured with the
    # selected lambda and the learning iteration budget, then trained.
    model = Tm.TLSM(adjacency, n_groups)
    model.lambda_n = chosen_lambda
    model.Num_ite = learn_iters
    return model.training()


def summarize_errors(errors):
    """Return ``(mean, standard_error)`` for a sequence of error values.

    The standard error is the population standard deviation divided by the
    square root of the number of values actually supplied, so it stays
    correct even when that count differs from the module-level
    ``Repetition`` constant. An empty input yields ``(nan, nan)``.
    """
    values = np.asarray(errors)
    if values.size == 0:
        return float('nan'), float('nan')
    return values.mean(), values.std() / np.sqrt(values.size)


if __name__ == '__main__':
    # Driver: for every (n, M) setting, generate `batch * number_of_batch`
    # data sets, run each clustering method on them, and print the per-run
    # errors plus their mean and standard error.
    #
    # The context manager guarantees the worker processes are torn down even
    # if one of the map calls raises; on the normal path the pool is still
    # closed and joined explicitly before leaving the block.
    with Pool(int(batch)) as pool:
        for n in nlist:
            for M in Mlist:
                time_start = time.time()
                assignment_self = []
                assignment_lei = []
                assignment_meanadj = []
                assignment_speck = []
                assignment_tucker = []
                for i in range(number_of_batch):
                    # One configuration per replication in this round, each
                    # with its own seed.
                    configuration = [
                        [n, M, K, R, sigma, s_n, scalar, seed[i * batch + nr]]
                        for nr in range(batch)
                    ]
                    data = pool.map(lgm, configuration)

                    # Element 0 of each lgm() result is fed to the methods;
                    # element 1 is what every estimate is compared against.
                    tlsm_args = [
                        [sample[0], K, Lambda, fraction_training,
                         iter_tune, iter_learn, N]
                        for sample in data
                    ]
                    psi_hat_self = pool.map(multi_processing, tlsm_args)

                    # The two baselines below only take the first two entries.
                    baseline_args = [args[0:2] for args in tlsm_args]
                    psi_hat_meanadj = pool.map(mean_adj, baseline_args)
                    psi_hat_tucker = pool.map(hosvd_tucker, baseline_args)

                    for sample, est_self, est_meanadj, est_tucker in zip(
                            data, psi_hat_self, psi_hat_meanadj, psi_hat_tucker):
                        tensor, reference = sample[0], sample[1]

                        # NOTE(review): this pair is stored as
                        # [reference, estimate] while every other method uses
                        # [estimate, reference]; kept as-is -- confirm that
                        # hamming_error is symmetric in its two entries.
                        assignment_self.append([reference, est_self])

                        # Unfold along mode 2 and refold as mode 0 so the
                        # array handed to R has shape (M, n, n).
                        A_lei = tl.unfold(tensor, 2)
                        A_lei = tl.fold(A_lei, 0, (M, n, n))
                        # Labels coming back from R are shifted down by one
                        # (presumably 1-based -> 0-based).
                        psi_hat_lei = np.array(robjects.r.GetCluster(A_lei, K)) - 1
                        assignment_lei.append([psi_hat_lei, reference])

                        assignment_meanadj.append([est_meanadj, reference])
                        assignment_tucker.append([est_tucker, reference])

                        # speck receives the M frontal slices as a list.
                        X = [tensor[:, :, m] for m in range(M)]
                        psi_hat_speck = np.array(robjects.r.speck(X, n, K)) - 1
                        assignment_speck.append([psi_hat_speck, reference])

                print("The results for (n, m) = ", (n, M))
                reports = [
                    ("Average error and standard deviation by self =  ",
                     assignment_self),
                    ("Average error and standard deviation by lei =  ",
                     assignment_lei),
                    ("Average error and standard deviation by mean adjacency matrix =  ",
                     assignment_meanadj),
                    ("Average error and standard deviation by hosvd-tucker =  ",
                     assignment_tucker),
                    ("Average error and standard deviation by spectral kernel =  ",
                     assignment_speck),
                ]
                for headline, assignments in reports:
                    errors = pool.map(hamming_error, assignments)
                    print(errors)
                    mean_error, standard_error = summarize_errors(errors)
                    print(headline, mean_error, standard_error)
                time_end = time.time()
                print("Running time:", time_end - time_start)
        pool.close()
        pool.join()

