from Data_generation import lgm
import TLSM_method as Tm
from Hamming_error import *
from multiprocessing import Pool
import time
import numpy as np
import tensorly as tl
import pickle
tl.set_backend('numpy')


def save(filename, data):
    """Serialize ``data`` with pickle into ``filename``.

    The file is opened in binary write mode, so an existing file at that
    path is overwritten. The default pickle protocol is used.
    """
    with open(filename, mode='wb') as out_file:
        pickle.Pickler(out_file).dump(data)


def load(filename):
    """Read back and return the object pickled in ``filename``.

    Only use this on files this project wrote itself: unpickling data from
    an untrusted source can execute arbitrary code.
    """
    with open(filename, mode='rb') as in_file:
        return pickle.Unpickler(in_file).load()


# Experiment settings. n, M, K, R, sigma and scalar are forwarded to ``lgm``
# in the configuration list; K is also passed to ``Tm.TLSM``.
K = 3
R = 3
sigma = 1
scalar = 2.5
n = 300
M = 5
# Training iterations, assigned to each model's ``Num_ite`` attribute.
iter_learn = 300
# Number of independent repetitions; repetition i uses seed[i].
Repetition = 50
seed = range(Repetition)
# s_n values swept in each repetition (one result column per value).
s_nlist = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55]
# Alternative sweep over the larger values:
# s_nlist = [0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1]
# One pool worker / result column per s_n value.
core = len(s_nlist)


def multi_processing(parameter):
    """Build, configure and train one ``Tm.TLSM`` model.

    Written as a single-argument worker so it can be used with ``Pool.map``.

    Args:
        parameter: indexable whose first four items are, in order, the
            ``A`` and ``K`` constructor arguments for ``Tm.TLSM``, the value
            assigned to the model's ``s_n`` attribute, and the value assigned
            to its ``Num_ite`` attribute.

    Returns:
        Whatever ``TLSM.training()`` returns (in this script it is compared
        against ground truth by ``hungarian_hamming_error``).
    """
    a_input = parameter[0]
    k_value = parameter[1]
    sn_value = parameter[2]
    n_iterations = parameter[3]

    model = Tm.TLSM(a_input, k_value)
    # Hyperparameters are set as attributes after construction.
    model.s_n = sn_value
    model.Num_ite = n_iterations
    return model.training()


if __name__ == '__main__':
    # Rows: one per repetition; columns: one per s_n value in s_nlist.
    ablation_matrix = np.zeros((Repetition, core))
    time_start = time.time()
    # One worker per s_n value. The context manager guarantees the worker
    # processes are torn down even if a repetition raises part-way through.
    with Pool(core) as pool:
        for i in range(Repetition):
            # NOTE(review): the literal 0.3 is a positional lgm setting whose
            # meaning is defined in Data_generation -- confirm before changing.
            configuration = [n, M, K, R, sigma, 0.3, scalar, seed[i]]
            data = lgm(configuration)
            # Fit one model per s_n value in parallel; data[0] is the ``A``
            # argument for TLSM.
            parameterlist = [[data[0], K, s_n, iter_learn] for s_n in s_nlist]
            psi_hat_self = pool.map(multi_processing, parameterlist)
            # Score each estimate against data[1] (presumably the true
            # assignments -- verify in Data_generation).
            assignment_self = [[data[1], psi_hat] for psi_hat in psi_hat_self]
            Error_self = pool.map(hungarian_hamming_error, assignment_self)
            print(Error_self)
            ablation_matrix[i] = np.array(Error_self)
            time_end = time.time()
            # Cumulative wall-clock time since the experiment started.
            print("Running time:", time_end - time_start)
        save("ablation_matrix_additional_before_05.pkl", ablation_matrix)
        # Graceful shutdown; the context manager's terminate() is then a no-op.
        pool.close()
        pool.join()

