from Data_generation import *
import TLSM_method as Tm
from multiprocessing import Pool
import time
import numpy as np
import tensorly as tl
import pickle
tl.set_backend('numpy')


def save(filename, data):
    """Pickle ``data`` into the file at ``filename``.

    The file is opened in binary write mode, so an existing file at that
    path is overwritten. Returns ``None``.
    """
    with open(filename, 'wb') as stream:
        pickle.Pickler(stream).dump(data)


def load(filename):
    """Return the object unpickled from the file at ``filename``.

    NOTE: unpickling can execute arbitrary code, so this should only be
    pointed at files that come from a trusted source.
    """
    with open(filename, 'rb') as stream:
        return pickle.load(stream)


# --- Data-generation settings ------------------------------------------
# These are packed, together with n and a seed, into the configuration
# list handed to lgm_theta in the driver below; their exact meaning is
# defined by Data_generation (not visible here).
M = 5
K = 2  # also passed as the second argument of Tm.TLSM
R = 2
s_n = 0.1
scalar = 1
sigma = 0

# --- Estimation settings -----------------------------------------------
N = 1                    # third argument of TLSM.cross_validation
fraction_training = 0.9  # third argument of Tm.TLSM for the tuning fit
iter_tune = 60           # iteration argument of TLSM.cross_validation
iter_learn = 300         # assigned to TLSM.Num_ite for the final fits
# Candidate lambda values offered to cross-validation.
Lambda = [
    1e-5,
    1e-4,
    1e-3,
    1e-2,
    0.1,
]

# --- Experiment layout -------------------------------------------------
Repetition = 50           # replications per value of n
seed = range(Repetition)  # one seed per replication
n_list = [50, 100, 200, 400]  # values of n the driver loops over


def multi_processing(parameter):
    """Pool worker: fit TLSM with a cross-validated ``lambda_n``.

    ``parameter`` is a sequence whose first element is the data object
    given to ``Tm.TLSM``. A first model, built with ``fraction_training``,
    selects the penalty via ``cross_validation`` over ``Lambda``. A second
    model is then built without ``fraction_training`` and trained for
    ``iter_learn`` iterations using that penalty.

    Returns ``mmd(Identity, [alpha, alpha, beta])`` computed from the
    trained model; the driver treats this as the estimate of theta.
    """
    observed = parameter[0]

    # Tuning pass: pick the penalty by cross-validation.
    tuner = Tm.TLSM(observed, K, fraction_training)
    chosen_lambda = tuner.cross_validation(Lambda, iter_tune, N)

    # Final pass: a fresh model configured with the selected penalty.
    model = Tm.TLSM(observed, K)
    model.lambda_n = chosen_lambda
    model.Num_ite = iter_learn
    model.training()  # the value returned by training() is not used here

    core = model.Identity
    factors = [model.alpha, model.alpha, model.beta]
    return mmd(core, factors)


def multi_processing_0(parameter):
    """Pool worker: fit TLSM with ``lambda_n`` fixed to 0 (no tuning).

    ``parameter`` is a sequence whose first element is the data object
    given to ``Tm.TLSM``. The model is trained for ``iter_learn``
    iterations.

    Returns ``mmd(Identity, [alpha, alpha, beta])`` computed from the
    trained model; the driver treats this as the estimate of theta.
    """
    observed = parameter[0]

    model = Tm.TLSM(observed, K)
    model.lambda_n = 0
    model.Num_ite = iter_learn
    model.training()  # the value returned by training() is not used here

    core = model.Identity
    factors = [model.alpha, model.alpha, model.beta]
    return mmd(core, factors)


def f_norm_lose(parameter):
    """Return the scaled norm of the difference between two arrays.

    ``parameter[0]`` and ``parameter[1]`` are arrays with at least three
    axes. The result is ``np.linalg.norm(a - b)`` divided by
    ``shape[0] * sqrt(shape[2])``, with the shape taken from the first
    array.
    """
    reference = parameter[0]
    estimate = parameter[1]
    dims = reference.shape
    size_first, size_third = dims[0], dims[2]
    scale = size_first * np.sqrt(size_third)
    gap = np.linalg.norm(reference - estimate)
    return gap / scale


if __name__ == '__main__':
    # Driver: for every n in n_list, run `Repetition` replications in
    # batches through a process pool and report the error of the
    # cross-validated fit and of the lambda=0 fit.
    n_workers = 10   # worker processes in the pool
    batch_size = 10  # replications dispatched per pool.map call

    # Derive the batches from Repetition instead of hard-coding 5 x 10, so
    # the seeds used always match the sqrt(Repetition) denominator below
    # and indexing can never run past the end of `seed`.
    seed_batches = [seed[start:start + batch_size]
                    for start in range(0, Repetition, batch_size)]

    time_start = time.time()
    # The context manager terminates the workers if any step raises, so a
    # failed run does not leave worker processes behind.
    with Pool(n_workers) as pool:
        for n in n_list:
            err = []
            err_0 = []
            for batch in seed_batches:
                configuration = [[n, M, K, R, sigma, s_n, scalar, s]
                                 for s in batch]
                # generate data according to the latent graph embedding model
                data = pool.map(lgm_theta, configuration)
                # Each generated item is indexed as [0] -> input for the
                # fit, [1] -> reference value the estimate is compared to.
                parameterlist = [[sample[0]] for sample in data]

                # Fit with lambda chosen by cross-validation.
                theta_hat_list = pool.map(multi_processing, parameterlist)
                theta_pair_list = [[sample[1], theta_hat]
                                   for sample, theta_hat
                                   in zip(data, theta_hat_list)]
                err += pool.map(f_norm_lose, theta_pair_list)

                # Fit with lambda fixed to 0 on the same data.
                theta_hat_list = pool.map(multi_processing_0, parameterlist)
                theta_pair_list = [[sample[1], theta_hat]
                                   for sample, theta_hat
                                   in zip(data, theta_hat_list)]
                err_0 += pool.map(f_norm_lose, theta_pair_list)

            print(err)
            err = np.array(err)
            # The second printed value is std / sqrt(Repetition).
            print("Average error and standard deviation with lambda=tuning",
                  err.mean(), err.std() / np.sqrt(Repetition))
            print(err_0)
            err_0 = np.array(err_0)
            print("Average error and standard deviation with lambda=0",
                  err_0.mean(), err_0.std() / np.sqrt(Repetition))
            # Elapsed time is cumulative since the start of the script.
            time_end = time.time()
            print("Running time:", time_end - time_start)
        # Normal shutdown: stop accepting work and wait for the workers.
        pool.close()
        pool.join()
