from Data_generation import lgm
import TLSM_method as Tm
from Hamming_error import *
from multiprocessing import Pool
import time
import numpy as np
import tensorly as tl
import pickle
tl.set_backend('numpy')


def save(filename, data):
    """Serialize ``data`` with pickle and write it to ``filename``.

    The file is opened in binary write mode, so an existing file at that
    path is overwritten. Returns ``None``.
    """
    with open(filename, 'wb') as sink:
        pickle.Pickler(sink).dump(data)


def load(filename):
    """Read ``filename`` in binary mode and return the unpickled object.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    produced by a trusted source (e.g. this project's own ``save``).
    """
    with open(filename, 'rb') as source:
        return pickle.Unpickler(source).load()


# Experiment configuration.
# K, R, sigma and scalar are forwarded to ``lgm`` as part of each data
# configuration; K is additionally handed to ``Tm.TLSM`` by the worker.
# NOTE(review): the statistical meaning of each value is defined by
# Data_generation.lgm / TLSM_method -- confirm there.
K = 3
R = 3
sigma = 1
scalar = 2.5

# Value assigned to the estimator's ``Num_ite`` attribute before training.
iter_learn = 300

# Number of replications; also the denominator of the reported standard error.
Repetition = 50
# One seed per replication, passed to ``lgm`` as the last configuration entry.
seed = range(Repetition)

# Number of worker processes in the multiprocessing pool.
core = 10

# First two entries of every ``lgm`` configuration.
n = 300
M = 5


def multi_processing(parameter):
    """Pool worker: build a ``Tm.TLSM`` estimator, train it, return the result.

    Parameters
    ----------
    parameter : sequence
        ``parameter[0]`` and ``parameter[1]`` are passed positionally to
        ``Tm.TLSM``; ``parameter[2]`` is stored on the estimator as
        ``Num_ite`` before training starts.

    Returns
    -------
    Whatever ``Tm.TLSM.training()`` returns.
    """
    # Read all three entries up front so a malformed parameter fails before
    # the estimator is constructed.
    first_arg, second_arg, n_iter = parameter[0], parameter[1], parameter[2]
    estimator = Tm.TLSM(first_arg, second_arg)
    estimator.Num_ite = n_iter
    return estimator.training()


if __name__ == '__main__':
    # Driver: for every replication seed, generate a data set with ``lgm``,
    # fit the TLSM estimator on it, score the estimate with
    # ``hungarian_hamming_error``, then report and pickle the errors.

    # Replications are processed in rounds of this many tasks per pool.map
    # call. Rounds are derived from ``Repetition`` so the number of runs
    # always matches the standard-error denominator used below.
    batch_size = 10

    pool = Pool(int(core))
    try:
        assignment_self = []
        for start in range(0, Repetition, batch_size):
            batch_seeds = seed[start:start + batch_size]
            configuration = [
                [n, M, K, R, sigma, 0.3, scalar, run_seed]
                for run_seed in batch_seeds
            ]
            # generate data according to the latent graph embedding model
            data = pool.map(lgm, configuration)
            # The first element of each generated sample is the estimator input.
            parameterlist = [[sample[0], K, iter_learn] for sample in data]
            psi_hat_self = pool.map(multi_processing, parameterlist)
            # Pair the second element of each sample with its estimate for
            # scoring (presumably the true assignment -- confirm against lgm).
            assignment_self.extend(
                [sample[1], psi_hat]
                for sample, psi_hat in zip(data, psi_hat_self)
            )
        Error_self = pool.map(hungarian_hamming_error, assignment_self)
    finally:
        # Always release the worker processes, even if a map call raised.
        pool.close()
        pool.join()

    print(Error_self)
    Error_self = np.array(Error_self)
    print("Average error and standard deviation by self =  ", Error_self.mean(),
          Error_self.std() / np.sqrt(Repetition))
    save("ablation_vector_sigma1.pkl", Error_self)
