import TLSM_fun as tf
import numpy as np
from tensorly.tenalg import multi_mode_dot as mmd
from sklearn.cluster import KMeans
from sklearn.preprocessing import normalize


class TLSM(object):
    """Latent-space model for a multi-layer network, fitted by projected gradient descent.

    The linear predictor is the rank-R CP tensor
    ``Theta[i, j, m] = sum_r alpha[i, r] * alpha[j, r] * beta[m, r]``
    (built as ``Identity x alpha x alpha x beta`` with a superdiagonal
    ``Identity``), and entry ``A[i, j, m]`` is modelled as Bernoulli with
    probability ``P = s_n / (1 + exp(-Theta))``.  The fitted objective is the
    negative log-likelihood over each unordered vertex pair (upper triangle,
    diagonal included) plus ``lambda_n / n`` times the k-means inertia of the
    rows of ``alpha``; the k-means labels are the estimated communities.

    With ``tuning_scheme == 'cross_validation'`` only the entries selected by
    the random symmetric mask ``B`` are used for fitting, and the remaining
    upper-triangular entries are scored by :meth:`Get_held_out_value`.
    """

    def __init__(self, A, K=4, frac_training=1):
        """Store the data and set default hyper-parameters.

        Args:
            A: multi-layer adjacency tensor; ``A.shape[1]`` is the number of
                vertices and ``A.shape[2]`` the number of layers (assumed
                shape ``(n, n, M)`` -- it is multiplied elementwise with
                ``(n, n, M)`` masks).
            K: number of communities; also used as the embedding dimension R.
            frac_training: probability that an entry is kept for training.
        """
        self.A = A                                  # Multi-layer network adjacency tensor
        self.K = K                                  # Number of communities
        self.frac_training = frac_training          # fraction of training samples
        self.n = self.A.shape[1]                    # number of vertices
        self.M = self.A.shape[2]                    # the number of layers
        self.R = K                                  # the embedding dimension
        # Superdiagonal R x R x R tensor, so mmd(Identity, [a, a, b]) is a CP tensor.
        self.Identity = np.zeros([self.R] * 3)
        diag = np.arange(self.R)
        self.Identity[diag, diag, diag] = 1
        self._resample_mask()                       # sets B, B0 and s_n
        self.theta_max = 200                        # Theta is clipped to [-theta_max, theta_max]
        self.max_alpha_norm = 100                   # cap on row norms of alpha at check iterations
        self.phi_nM = (self.n + 1) * self.n * self.M / 2   # number of upper-triangular entries
        self.eta = 50                               # learning rate (shrunk by 0.75 on a rejected step)
        self.initial_iter = 50                      # iterations passed to tf.transformed_tucker
        self.tuning_scheme = 'no_tuning'
        self.lambda_n = 1e-8                        # weight of the k-means penalty
        self.alpha = np.random.normal(0, 1, (self.n, self.R))
        self.beta = np.random.normal(0, 1, (self.M, self.R))
        self.centers = np.random.normal(0, 1, (self.K, self.R))
        self.loss = 10**6
        self.Num_ite = 2000                         # maximum number of pgd iterations
        self.check = 40                             # a loss check happens every `check` iterations
        self.KM_dist = 0.                           # k-means inertia from the last Get_CL call
        self.ABCL_Set = []                          # accepted checkpoints [alpha, beta, centers, labels]
        self.Loss_crite = [self.loss]               # accepted losses (first entry is a sentinel)
        self.labels = np.zeros(self.n)
        self.i = 0                                  # current iteration index

    # ------------------------------------------------------------------ helpers

    def _upper_mask(self):
        """Return an ``(n, n, 1)`` float mask of the upper triangle, diagonal included."""
        return np.triu(np.ones((self.n, self.n)))[:, :, np.newaxis]

    def _resample_mask(self):
        """Draw a fresh training mask and recompute everything that depends on it.

        Sets ``B`` (symmetric 0/1 mask whose upper-triangular entries are kept
        with probability ``frac_training``), ``B0`` (the upper triangle of
        ``B``, so each unordered pair is counted once) and ``s_n``.
        """
        self.B = np.random.binomial(1, self.frac_training, (self.n, self.n, self.M))
        for m in range(self.M):
            # mirror the strict upper triangle so every layer of B is symmetric
            self.B[:, :, m] = np.triu(self.B[:, :, m]) + np.triu(self.B[:, :, m], 1).T
        self.B0 = self.B * self._upper_mask()       # float, like the original zeros-filled array
        s = (self.A * self.B).sum(axis=(1, 2)) / (self.n * self.M * self.frac_training)
        self.s_n = np.max(s) + np.min(s)

    def _held_out_mask(self):
        """Return the indicator of upper-triangular entries NOT used for training."""
        return (1 - self.B0) * self._upper_mask()

    def _clipped_theta(self, alpha, beta):
        """Return ``Theta = Identity x alpha x alpha x beta`` clipped to +/- ``theta_max``."""
        Theta = mmd(self.Identity, [alpha, alpha, beta])
        # clipping keeps exp(-Theta) finite in all the formulas below
        return np.clip(Theta, -self.theta_max, self.theta_max)

    def _negative_log_likelihood(self, alpha, beta, mask, phi):
        """Return the masked Bernoulli negative log-likelihood divided by ``phi``.

        ``X`` is the odds ``P / (1 - P)`` of ``P = s_n * sigmoid(Theta)``, so
        ``log(1 + X) - A * log(X)`` equals ``-(A log P + (1 - A) log(1 - P))``.
        """
        X = self.s_n / (1 - self.s_n + np.exp(-self._clipped_theta(alpha, beta)))
        return ((np.log(1 + X) - self.A * np.log(X)) * mask).sum() / phi

    def _save_checkpoint(self):
        """Append independent copies of (alpha, beta, centers, labels) to ``ABCL_Set``.

        Copies are required: pgd() updates alpha and beta in place, which would
        otherwise silently overwrite the stored checkpoint.
        """
        self.ABCL_Set.append([np.copy(self.alpha), np.copy(self.beta),
                              np.copy(self.centers), np.copy(self.labels)])

    def _restore_checkpoint(self):
        """Reset (alpha, beta, centers, labels) to copies of the last checkpoint."""
        self.alpha, self.beta, self.centers, self.labels = (
            np.copy(x) for x in self.ABCL_Set[-1])

    # ------------------------------------------------------------ public API

    def Get_likelihood(self, alpha, beta, phi):
        """Return the penalised training loss.

        The loss is the negative log-likelihood over ``B0`` divided by ``phi``,
        plus ``(lambda_n / n) * KM_dist``.  Note that ``KM_dist`` is the inertia
        from the most recent :meth:`Get_CL` call; it is not recomputed here.
        """
        value = self._negative_log_likelihood(alpha, beta, self.B0, phi)
        return value + (self.lambda_n / self.n) * self.KM_dist

    def Get_held_out_value(self, alpha, beta, phi):
        """Return the negative log-likelihood of the held-out entries divided by ``phi``."""
        return self._negative_log_likelihood(alpha, beta, self._held_out_mask(), phi)

    def Get_link_prediciton_error(self, alpha, beta):
        """Return the mismatch rate of one random draw of A on the held-out entries.

        A_hat is sampled as Bernoulli(P) with ``P = s_n * sigmoid(Theta)``, so
        the result is random.  With no held-out entries the ratio is 0/0.
        """
        P = self.s_n / (1 + np.exp(-self._clipped_theta(alpha, beta)))
        A_hat = np.random.binomial(1, P)
        held_out = self._held_out_mask()
        return np.abs((A_hat - self.A) * held_out).sum() / held_out.sum()

    # Correctly spelled alias; the original name is kept for existing callers.
    Get_link_prediction_error = Get_link_prediciton_error

    def Get_estimated_ablation_lambda(self, alpha, beta, phi):
        """Return the training negative log-likelihood (over ``B0``, / ``phi``) divided by ``KM_dist``."""
        negative_log_likelihood = self._negative_log_likelihood(alpha, beta, self.B0, phi)
        return negative_log_likelihood / self.KM_dist

    def Get_CL(self):
        """Run k-means on the rows of ``self.alpha``.

        Side effect: stores the inertia in ``self.KM_dist``.

        Returns:
            ``(cluster_centers, labels)`` from the fitted ``KMeans``.
        """
        kmeans = KMeans(n_clusters=self.K).fit(self.alpha)
        self.KM_dist = kmeans.inertia_
        return kmeans.cluster_centers_, kmeans.labels_

    def initialization(self):
        """Initialise (alpha, beta) from ``tf.transformed_tucker`` and reset the fit history."""
        # Every fit starts from a clean history, so repeated training() calls
        # (as made by cross_validation) never compare against a previous run.
        self.ABCL_Set = []
        self.Loss_crite = [10**6]
        if self.tuning_scheme == 'cross_validation':
            data = self.A * self.B      # hide held-out entries from the initialiser
        else:
            data = self.A
        self.alpha, self.beta = tf.transformed_tucker(data, self.R, self.initial_iter)
        self.centers, self.labels = self.Get_CL()
        # Same normalisation (phi_nM) as the loss checks in pgd() and as the
        # gradient, so the first accept/reject decision is like-for-like.
        self.loss = self.Get_likelihood(self.alpha, self.beta, self.phi_nM)
        self._save_checkpoint()
        self.Loss_crite.append(self.loss)

    def pgd(self):
        """Take one gradient step on (alpha, beta) and refresh the k-means clustering.

        On every ``check``-th iteration (except iteration 0), the rows of alpha
        are capped at norm ``max_alpha_norm``, the columns of beta are rescaled
        to norm ``sqrt(M)``, and the penalised loss is compared with the last
        accepted value.  If it increased, the last checkpoint is restored and
        ``eta`` is multiplied by 0.75.  Otherwise the new state is checkpointed.
        """
        Theta = self._clipped_theta(self.alpha, self.beta)
        exp_neg = np.exp(-Theta)
        P = self.s_n / (1 + exp_neg)
        # T = d(negative log-likelihood) / d(Theta), entrywise.
        T = (exp_neg / (1 - self.s_n + exp_neg)) * (P - self.A)
        if self.tuning_scheme == 'cross_validation':
            T *= self.B                 # only training entries contribute

        # Contract T with the factor matrices; only "diagonal" slices are needed.
        Q = mmd(T, [self.alpha, self.beta], modes=[1, 2], transpose=True)
        X_ab = np.einsum('irr->ir', Q)          # (n, R)
        Q = mmd(T, [self.alpha, self.alpha], modes=[0, 1], transpose=True)
        X_aa = np.einsum('rrm->mr', Q)          # (M, R), up to a transpose as in the paper
        Q = mmd(T, [self.beta], modes=[2], transpose=True)
        X_b = np.einsum('iir->ir', Q)           # (n, R)
        T_diag = np.einsum('iim->im', T)        # (n, M): self-pair entries T[i, i, m]

        # The diagonal terms correct for the likelihood counting each unordered pair once.
        penalty_grad = 2 * (self.lambda_n / self.n) * (self.alpha - self.centers[self.labels])
        gradient_alpha = (X_ab + X_b * self.alpha) / self.phi_nM + penalty_grad
        gradient_beta = (X_aa + T_diag.T @ (self.alpha * self.alpha)) / (2 * self.phi_nM)
        self.alpha -= self.eta * gradient_alpha
        self.beta -= self.eta * gradient_beta
        self.centers, self.labels = self.Get_CL()

        if (self.i % self.check == 0) and (self.i != 0):
            # Projection step: cap row norms of alpha (fmax leaves NaN rows untouched).
            row_norms = np.linalg.norm(self.alpha, axis=1)
            self.alpha *= (self.max_alpha_norm
                           / np.fmax(row_norms, self.max_alpha_norm))[:, np.newaxis]
            self.beta = normalize(self.beta.T).T * np.sqrt(self.M)
            temp_loss = self.Get_likelihood(self.alpha, self.beta, self.phi_nM)
            if temp_loss > self.loss:
                self._restore_checkpoint()
                self.centers, self.labels = self.Get_CL()
                self.eta *= 0.75
                print('Learning rate change')
            else:
                self.loss = float(temp_loss)
                self.Loss_crite.append(temp_loss)
                self._save_checkpoint()

    def training(self):
        """Fit the model.

        Runs up to ``Num_ite`` pgd steps and stops early when the relative loss
        decrease between the last two accepted checks drops below 1e-6.

        Returns:
            The held-out loss when ``tuning_scheme == 'cross_validation'``,
            otherwise the k-means community labels.
        """
        self.initialization()
        for i in range(self.Num_ite):
            self.i = i
            self.pgd()
            if (i % self.check == 0) and (i != 0):
                previous, current = self.Loss_crite[-2], self.Loss_crite[-1]
                if (previous - current) / previous < 10 ** -6:
                    print('break')
                    break
        if self.tuning_scheme == 'cross_validation':
            phi_0 = self.phi_nM * (1 - self.frac_training) / self.frac_training
            return self.Get_held_out_value(self.alpha, self.beta, phi_0)
        return self.labels

    def cross_validation(self, lambda_list, number_of_iterations, N):
        """Select ``lambda_n`` by repeated random hold-out.

        For each of ``N`` repetitions, a fresh training mask is drawn and the
        model is fitted (``number_of_iterations`` pgd steps at most) for every
        candidate in ``lambda_list``.  Each candidate fit starts from the same
        learning rate.

        Side effects: leaves ``tuning_scheme = 'cross_validation'``,
        ``Num_ite = number_of_iterations`` and ``lambda_n`` set to the last
        candidate.

        Returns:
            The candidate with the smallest mean held-out loss.

        Raises:
            ValueError: if ``frac_training`` is not strictly between 0 and 1
                (the held-out set would be empty or the training set undefined).
        """
        if not 0 < self.frac_training < 1:
            raise ValueError(
                'cross_validation needs 0 < frac_training < 1, got %r' % (self.frac_training,))
        self.Num_ite = number_of_iterations
        self.tuning_scheme = 'cross_validation'
        eta_start = self.eta    # pgd() shrinks eta; reset it so candidates are comparable
        held_out_value = np.zeros((N, len(lambda_list)))
        for repetition in range(N):
            self._resample_mask()
            for h, lambda_candidate in enumerate(lambda_list):
                self.lambda_n = lambda_candidate
                self.eta = eta_start
                held_out_value[repetition, h] = self.training()
        index = held_out_value.mean(axis=0).argmin()
        return lambda_list[index]
