import torch.nn as nn
from fast_attention import *

class pretrain_loss(nn.Module):
    """Weighted MSE that scores the last flattened element separately.

    Both inputs are flattened to 1-D. The mean squared error over every
    element except the last is blended with the squared error of the last
    element as ``(1 - weight) * head + weight * tail``.
    """

    def __init__(self, weight):
        super().__init__()
        # Blend factor given to the last-element term.
        self.weight = weight

    def forward(self, y, y_std):
        """Return the blended loss between ``y`` and ``y_std``.

        Args:
            y: Tensor; flattened internally.
            y_std: Tensor with as many elements as ``y``; presumably the
                reference values -- confirm against the caller.

        Returns:
            Scalar tensor holding the weighted sum of the two MSE terms.
        """
        flat_y = y.reshape(-1)
        flat_ref = y_std.reshape(-1)

        # One criterion serves both terms; MSELoss holds no state.
        mse = nn.MSELoss()
        head_term = mse(flat_y[:-1], flat_ref[:-1])
        tail_term = mse(flat_y[-1], flat_ref[-1])

        return (1 - self.weight) * head_term + self.weight * tail_term

class GDloss(nn.Module):
    """Negative, scaled inner product between two 2-D tensors.

    The loss is ``-sum(y_std * y_pred) / (constant * lr)``. When ``norm``
    is ``True`` both inputs are normalised along dim 1 with
    ``F.normalize`` and the result is additionally divided by the number
    of rows.
    """

    def __init__(self, constant=1, lr=1, norm=False):
        super().__init__()
        self.constant = constant
        self.lr = lr
        self.norm = norm

    def forward(self, y_std, y_pred):
        """Compute the loss.

        Args:
            y_std: Tensor of shape [demon_num, dim].
            y_pred: Tensor of shape [demon_num, dim].

        Returns:
            Scalar tensor with the negated, scaled sum of the
            element-wise product.
        """
        num_rows = y_std.shape[0]
        factor = 1 / (self.constant * self.lr)

        # Kept as an explicit comparison with True so that non-bool flag
        # values behave exactly as they always have.
        if self.norm == True:
            y_std = F.normalize(y_std, dim=1)
            y_pred = F.normalize(y_pred, dim=1)
            factor = factor / num_rows

        inner_product = (y_std * y_pred).sum(dim=[0, 1])
        return -inner_product * factor

class GDloss_Regular(nn.Module):
    """Negative scaled inner product plus a squared-norm penalty on ``W``.

    The data term is ``-sum(y_std * y_pred) / (constant * lr)`` (with
    optional row normalisation and division by the row count when
    ``norm`` is ``True``). The penalty added to it is
    ``alpha * ||W||^2 / (2 * lr)``.
    """

    def __init__(self, constant=1, lr=1, norm=False, alpha=0.1):
        super().__init__()
        self.constant = constant
        self.lr = lr
        self.norm = norm
        self.alpha = alpha

    def forward(self, y_std, y_pred, W):
        """Compute the regularised loss.

        Args:
            y_std: Tensor of shape [demon_num, dim].
            y_pred: Tensor of shape [demon_num, dim].
            W: Tensor whose norm (as computed by ``torch.norm``) is
                penalised.

        Returns:
            Scalar tensor: data term plus regularisation term.
        """
        num_rows = y_std.shape[0]
        factor = 1 / (self.constant * self.lr)

        # Kept as an explicit comparison with True so that non-bool flag
        # values behave exactly as they always have.
        if self.norm == True:
            y_std = F.normalize(y_std, dim=1)
            y_pred = F.normalize(y_pred, dim=1)
            factor = factor / num_rows

        inner_product = (y_std * y_pred).sum(dim=[0, 1])
        data_term = -inner_product * factor

        penalty = self.alpha * torch.square(torch.norm(W)) / (2 * self.lr)
        return data_term + penalty


def data_transformation(tokens, demon_num, W_q, W_k, W_v, projection_matrix):
    """Build training/test tensors and an initial weight from tokens.

    Each of ``W_k``, ``W_q`` and ``W_v`` (2-D, indexed ``[Y, X]``) is
    applied to every row of ``tokens`` (2-D, indexed ``[N, X]``). The key
    and query projections are then passed through
    ``softmax_kernel_transformation`` (the boolean argument is presumably
    an "is query" flag -- confirm in ``fast_attention``).

    Returns:
        A tuple ``(train_x, train_y_std, test_x, D, W_0)`` where
        ``train_x`` is the first ``demon_num`` rows of the transformed
        keys, ``train_y_std`` is the first ``demon_num`` rows of the
        values, ``test_x`` is the remaining rows of the transformed
        queries, ``D`` is the last entry of
        ``noncausal_denominator(phi_Q, phi_K)``, and ``W_0`` is built from
        the last value row and last transformed key row, divided by ``D``.
    """

    def _project(weight):
        # [Y, X] weight applied to each [N, X] token row -> [N, Y].
        return torch.einsum("YX,NX->NY", weight, tokens)

    # The key transform runs before the query transform, as before.
    key_features = softmax_kernel_transformation(
        _project(W_k), False, projection_matrix
    )
    query_features = softmax_kernel_transformation(
        _project(W_q), True, projection_matrix
    )
    values = _project(W_v)

    denominator = noncausal_denominator(query_features, key_features)[-1]

    # Combine the last value row with the last key-feature row, scaled by
    # the denominator.
    last_value = values[-1:, :]
    last_key = key_features[-1:, :]
    initial_weight = torch.einsum("nd,nm->dm", last_value, last_key) / denominator

    return (
        key_features[:demon_num, :],
        values[:demon_num, :],
        query_features[demon_num:, :],
        denominator,
        initial_weight,
    )


def data_transformation_FFN(tokens, demon_num, W_q, W_k, W_v, projection_matrix, W_F):
    """Build training/test tensors and an initial weight, mapped by ``W_F``.

    Each of ``W_k``, ``W_q`` and ``W_v`` (2-D, indexed ``[Y, X]``) is
    applied to every row of ``tokens`` (2-D, indexed ``[N, X]``). The key
    and query projections are passed through
    ``softmax_kernel_transformation`` (the boolean argument is presumably
    an "is query" flag -- confirm in ``fast_attention``). The training
    targets and the initial weight are additionally multiplied by
    ``W_F`` (2-D, indexed ``[o, Y]``).

    Returns:
        A tuple ``(train_x, train_y_std, test_x, D, W_init)`` where
        ``train_x`` is the first ``demon_num`` rows of the transformed
        keys, ``train_y_std`` is the first ``demon_num`` value rows mapped
        through ``W_F``, ``test_x`` is the remaining rows of the
        transformed queries, ``D`` is the last entry of
        ``noncausal_denominator(phi_Q, phi_K)``, and ``W_init`` is ``W_F``
        applied to the weight built from the last value row and last
        transformed key row divided by ``D``.
    """

    def _project(weight):
        # [Y, X] weight applied to each [N, X] token row -> [N, Y].
        return torch.einsum("YX,NX->NY", weight, tokens)

    # The key transform runs before the query transform, as before.
    key_features = softmax_kernel_transformation(
        _project(W_k), False, projection_matrix
    )
    query_features = softmax_kernel_transformation(
        _project(W_q), True, projection_matrix
    )
    values = _project(W_v)

    # Map the training targets through W_F.
    targets = torch.einsum("oY,NY ->No", W_F, values[:demon_num, :])

    denominator = noncausal_denominator(query_features, key_features)[-1]

    last_value = values[-1:, :]
    last_key = key_features[-1:, :]
    base_weight = torch.einsum("nd,nm->dm", last_value, last_key) / denominator
    # Apply the same W_F map to the initial weight.
    mapped_weight = torch.einsum("od,dm->om", W_F, base_weight)

    return (
        key_features[:demon_num, :],
        targets,
        query_features[demon_num:, :],
        denominator,
        mapped_weight,
    )


class GD_Data(Dataset):
    """Dataset pairing rows of ``data_x`` with rows of ``data_y``.

    Both containers are indexed as ``[index, :]``, so they must be at
    least 2-D and support that indexing (e.g. tensors or arrays).
    """

    def __init__(self, data_x, data_y):
        # super(Dataset, self) would start the MRO lookup *after* Dataset
        # and skip its initialiser; the zero-argument form starts from
        # this class as intended.
        super().__init__()
        self.data_x = data_x
        self.data_y = data_y

    def __getitem__(self, index):
        """Return the ``(x_row, y_row)`` pair at ``index``."""
        return self.data_x[index, :], self.data_y[index, :]

    def __len__(self):
        """Return the number of rows in ``data_x``."""
        return self.data_x.shape[0]

