from torch import nn
import torch
import torch.nn.functional as F
from pos_encodings import FixedPosEncTrig, LearnNonAugPosEnc, OneHotPosEnc, PosEnc, MAX_LEN

####### Notation follows the paper: https://arxiv.org/pdf/2402.02098
class LinTransformer(nn.Module):
    """Stack of ``TransformerLayer`` blocks with optional I/O linear layers and positional encoding.

    Args:
        model_dim: Feature dimension of the input and output tokens.
        qk_dim: Accepted for interface compatibility; not forwarded (each layer
            receives ``io_layer_dim`` as its ``qk_dim``).
        no_layers: Number of ``TransformerLayer`` blocks.
        n_att_heads: Number of attention heads per layer.
        lin_att: Forwarded to each layer (``lin_att``).
        lyr_norm: Forwarded to each layer (LayerNorm after the residual add).
        pos_enc_type: ``None`` or ``PosEnc.FIXED_TRIG_FUNCS`` /
            ``PosEnc.LEARNABLE_NON_AUG``.
        extra_input_lin_layer: Prepend ``nn.Linear(model_dim, io_layer_dim)``.
        extra_output_lin_layer: Append ``nn.Linear(io_layer_dim, model_dim)``.
        io_layer_dim: Width used by the layers when ``extra_input_lin_layer`` is set.
            Defaults to ``model_dim`` when ``None``; ignored otherwise.
        projection: Forwarded to each layer.
        att_init_scale: Forwarded to each layer's attention initialisation.
        device: Device on which parameters are created.

    Raises:
        ValueError: If ``pos_enc_type`` is neither ``None`` nor a supported ``PosEnc``.
    """

    def __init__(self, model_dim, qk_dim, no_layers, n_att_heads, lin_att=False,
                 lyr_norm=True, pos_enc_type=None, extra_input_lin_layer=False, extra_output_lin_layer=False,
                 io_layer_dim=None, projection=False, att_init_scale=None, device='cpu'):
        super(LinTransformer, self).__init__()

        self.model_dim = model_dim
        # Width seen by the transformer layers; only the optional input layer changes it.
        self.io_layer_dim = model_dim
        self.pos_enc_type = pos_enc_type
        self.extra_input_lin_layer = extra_input_lin_layer
        self.extra_output_lin_layer = extra_output_lin_layer
        self.device = device

        if self.pos_enc_type is not None:
            # NOTE(review): the encoder is built with model_dim but applied after the
            # optional input layer (i.e. to io_layer_dim features) -- confirm these agree.
            if self.pos_enc_type == PosEnc.FIXED_TRIG_FUNCS:
                self.pos_encoder = FixedPosEncTrig(model_dim, device=device)
            elif self.pos_enc_type == PosEnc.LEARNABLE_NON_AUG:
                self.pos_encoder = LearnNonAugPosEnc(device=device)
            else:
                raise ValueError("Positional encoding not suported.")

        if self.extra_input_lin_layer:
            # Fall back to model_dim instead of failing inside nn.Linear on None.
            self.io_layer_dim = model_dim if io_layer_dim is None else io_layer_dim
            self.input_layer = nn.Linear(model_dim, self.io_layer_dim, bias=True, device=device)

        learn_non_aug_pe = (pos_enc_type == PosEnc.LEARNABLE_NON_AUG)
        self.layers = nn.ModuleList(
            [TransformerLayer(self.io_layer_dim, self.io_layer_dim, n_att_heads, lin_att, lyr_norm,
                              att_init_scale, projection, learn_non_aug_pe, device)
             for _ in range(no_layers)]
        )

        if self.extra_output_lin_layer:
            self.output_layer = nn.Linear(self.io_layer_dim, model_dim, bias=True, device=device)

    @staticmethod
    def _right_mult_weight(linear):
        """Return ``linear.weight`` as a detached right-multiplication matrix of shape (1, in, out)."""
        # nn.Linear stores its weight as (out_features, in_features) and computes x @ W^T,
        # so the matrix that right-multiplies the tokens is W^T. Reshaping the stored
        # weight to (in, out) instead would scramble its entries whenever in != out.
        return linear.weight.detach().transpose(-2, -1).unsqueeze(0)

    def forward(self, X):  # receives a data block
        """Run the model on the data block ``X``.

        Returns:
            ``(X, transf_params)``: the transformed block and a dict holding
            ``"input_embed"`` (if the input layer is enabled), ``"lyr_1"`` ...
            ``"lyr_<no_layers>"`` (each layer's params), and ``"output_embed"``
            (if the output layer is enabled), in that insertion order.
        """
        transf_params = {}
        if self.extra_input_lin_layer:
            X = self.input_layer(X)
            transf_params["input_embed"] = self._right_mult_weight(self.input_layer)

        if self.pos_enc_type is not None:
            X = self.pos_encoder(X)

        for idx, layer in enumerate(self.layers, start=1):
            X, layer_params = layer(X)
            transf_params[f"lyr_{idx}"] = layer_params

        if self.extra_output_lin_layer:
            X = self.output_layer(X)
            transf_params["output_embed"] = self._right_mult_weight(self.output_layer)

        return X, transf_params
    

class TransformerLayer(nn.Module):
    """Attention-only transformer layer: attention, residual add, optional LayerNorm.

    There is no MLP sub-block; the returned params report ``"mlp": None``.

    Args:
        model_dim: Token feature dimension.
        qk_dim: Stored as ``self.qk_dim``; not passed to the attention module.
        n_att_heads: Number of attention heads.
        lin_att: Forwarded as ``linear`` to the attention module.
        lyr_norm: Apply ``nn.LayerNorm(model_dim)`` after the residual connection.
        att_init_scale: Forwarded to the attention initialisation.
        projection: Forwarded to the attention module.
        learn_non_aug_pe: Accepted for interface compatibility; currently unused.
        device: Device on which parameters are created.
    """

    def __init__(self, model_dim, qk_dim, n_att_heads, lin_att=False, lyr_norm=True,
                 att_init_scale=None, projection=False, learn_non_aug_pe=False, device='cpu'):
        super(TransformerLayer, self).__init__()

        self.qk_dim = qk_dim
        self.io_layer_dim = model_dim
        self.lyr_norm = lyr_norm
        self.device = device

        self.attention = MaskedAttentionTFGD(projection, model_dim, n_att_heads,
                                             att_init_scale, linear=lin_att, device=device)

        # Layer normalization (only created when requested)
        if self.lyr_norm:
            self.norm1 = nn.LayerNorm(model_dim, device=device)

    def forward(self, X):
        """Return ``(output, {"att": attention_params, "mlp": None})`` for input ``X``."""
        # Invoke the module (not .forward) so registered forward hooks are honoured.
        att_output, att_params = self.attention(X)

        # Skip connection, then optional layer norm.
        Z = att_output + X
        if self.lyr_norm:
            Z = self.norm1(Z)

        return Z, {"att": att_params, "mlp": None}


class MaskedAttentionTFGD(nn.Module):
    """Multi-head linear attention with a sparse (P, Q) parameterisation.

    Each head h owns ``P[h]`` of shape (model_dim, 1) and ``Q[h]`` of shape
    (model_dim - 1, model_dim). Both are zero-padded to (model_dim, model_dim):
    ``P`` fills only the last column and ``Q`` leaves the last row zero. Per head,
    the forward pass computes

        (1 / L) * (X Q_h X^T M) X P_h

    where ``M`` is the (L, L) identity with its last diagonal entry zeroed. The
    results are summed over heads. No softmax or sqrt-dim scaling is applied.

    Args:
        projection: Stored on the module; not used by ``forward``.
        model_dim: Token feature dimension d.
        n_att_heads: Number of heads H.
        att_init_scale: Gain for Xavier-normal initialisation; ``None`` means 1.0.
        linear: Stored on the module; not used by ``forward``.
        device: Device on which parameters are created.
    """

    def __init__(self, projection, model_dim, n_att_heads, att_init_scale, linear=False, device='cpu'):

        super(MaskedAttentionTFGD, self).__init__()
        self.projection = projection
        self.n_heads = n_att_heads
        self.model_dim = model_dim
        self.linear = linear
        self.device = device

        self.P = nn.Parameter(torch.empty(size=(self.n_heads, model_dim, 1), device=device))
        self.Q = nn.Parameter(torch.empty(size=(self.n_heads, model_dim - 1, model_dim), device=device))
        # Prj is initialised and reported in the params dict, but forward does not use it.
        self.Prj = nn.Parameter(torch.empty(size=(self.n_heads, model_dim, model_dim), device=device))

        # xavier_normal_ multiplies its std by `gain`, so None (LinTransformer's default)
        # must map to the neutral gain 1.0 instead of raising a TypeError.
        gain = 1.0 if att_init_scale is None else att_init_scale
        torch.nn.init.xavier_normal_(self.Prj, gain=gain)
        torch.nn.init.xavier_normal_(self.P, gain=gain)
        torch.nn.init.xavier_normal_(self.Q, gain=gain)

    def forward(self, X):
        """Apply the masked linear attention.

        Args:
            X: Tensor of shape (batch, seq_len, model_dim) -- the einsum subscripts
                require exactly three dimensions.

        Returns:
            ``(result, params)``. ``result`` has the same shape as ``X``. ``params``
            holds head-0 diagnostics (``W_QK``, ``W_V``, ``b_top``, ``A``) and ``proj``.
        """
        # TODO: the 1/sqrt(model_dim - 1) scaling was removed on purpose; compare with/without.
        seq_len = X.shape[-2]
        # Match X's device/dtype so the einsums still work after .to(...) / .double().
        mask = torch.zeros(seq_len, seq_len, device=X.device, dtype=X.dtype)
        mask[:seq_len - 1, :seq_len - 1] = torch.eye(seq_len - 1, device=X.device, dtype=X.dtype)

        P_padded = F.pad(self.P, (self.model_dim - 1, 0), value=0.)  # P in the last column
        Q_padded = F.pad(self.Q, (0, 0, 0, 1), value=0.)  # zero last row

        query_key = torch.einsum('Bid,HdD,BDj->BHij', X, Q_padded, X.transpose(-2, -1))
        masked = torch.einsum('BHik,kj->BHij', query_key, mask)
        atten_scores = (1 / seq_len) * torch.einsum('BHpj,Bjk,Hki->BHpi', masked, X, P_padded)

        result = torch.sum(atten_scores, dim=1)  # sum over heads

        params = {"W_QK": Q_padded[0, :, :],
                  "W_V": P_padded[0, :, :],
                  "b_top": self.P[0, :, :],
                  "A": self.Q[0, :, :],
                  "QK": None,
                  "att": None,  # with softmax, if exists, otherwise same as QK up to scaling
                  "pos_att": None,
                  "PA": None,
                  "proj": self.Prj}

        return result, params
    
        
class MaskedAttention(nn.Module):
    """Causal multi-head attention with an optional learned positional attention map.

    Per head h the scores are ``Q_h K_h^T / sqrt(qk_dim)`` under an additive causal
    mask. They go through a softmax, or, when ``linear`` is set, are only causally
    masked. When ``learn_non_aug_pe`` is set, a positional attention ``P`` multiplies
    ``A`` element-wise. ``P`` is built from a learned embedding of the token indices
    and shared by all heads. Head outputs are optionally right-multiplied by ``Prj[h]``
    and then summed over heads.

    Args:
        projection: If true, multiply each head's output by ``Prj[h]``.
        model_dim: Token feature dimension.
        qk_dim: Per-head query/key dimension.
        n_att_heads: Number of heads.
        linear: Skip the softmax and keep the causally-masked raw scores.
        att_init_scale: If ``None``, ``Prj`` is Xavier-uniform with gain 1.0 and the
            linear layers keep PyTorch's default init. Otherwise Q, K, V, Prj (and
            PQ, PK) weights are drawn from N(0, att_init_scale**2).
        learn_non_aug_pe: Enable the learned positional attention.
        max_len: Number of position embeddings (longest supported sequence).
        device: Device on which parameters are created.
    """

    def __init__(self, projection, model_dim, qk_dim, n_att_heads, linear=False,
                 att_init_scale=None, learn_non_aug_pe=False, max_len=MAX_LEN, device='cpu'):

        super(MaskedAttention, self).__init__()
        self.projection = projection
        self.n_heads = n_att_heads
        self.qk_dim = qk_dim
        self.model_dim = model_dim
        self.linear = linear
        self.learn_non_aug_pe = learn_non_aug_pe
        self.device = device
        # Size of each positional embedding vector.
        # TODO: this should be tied to a fixed max length rather than to model_dim.
        self.embed_dim = 5 * model_dim

        self.Q = nn.Linear(model_dim, qk_dim * self.n_heads, bias=False, device=device)
        self.K = nn.Linear(model_dim, qk_dim * self.n_heads, bias=False, device=device)
        self.V = nn.Linear(model_dim, model_dim * self.n_heads, bias=False, device=device)
        self.Prj = nn.Parameter(torch.empty(size=(self.n_heads, model_dim, model_dim), device=device),
                                requires_grad=True)
        # xavier_uniform_ multiplies by `gain`; map the default None to 1.0 instead of
        # raising a TypeError.
        torch.nn.init.xavier_uniform_(self.Prj, gain=1.0 if att_init_scale is None else att_init_scale)

        if self.learn_non_aug_pe:
            # One embedding per position index, up to max_len positions.
            self.pos_embedding = nn.Embedding(max_len, self.embed_dim, device=device)
            self.PQ = nn.Linear(self.embed_dim, qk_dim, bias=False, device=device)  # Q part of the positional attention
            self.PK = nn.Linear(self.embed_dim, qk_dim, bias=False, device=device)  # K part of the positional attention

        if att_init_scale is not None:
            torch.nn.init.normal_(self.Q.weight, 0.0, att_init_scale)
            torch.nn.init.normal_(self.K.weight, 0.0, att_init_scale)
            torch.nn.init.normal_(self.V.weight, 0.0, att_init_scale)
            torch.nn.init.normal_(self.Prj, 0.0, att_init_scale)
            if self.learn_non_aug_pe:
                torch.nn.init.normal_(self.PQ.weight, 0.0, att_init_scale)
                torch.nn.init.normal_(self.PK.weight, 0.0, att_init_scale)

    def forward(self, X):
        """Apply the attention to ``X`` of shape (batch, seq_len, model_dim).

        Returns:
            ``(result, params)``. ``result`` has shape (batch, seq_len, model_dim).
            ``params`` holds the per-head weights ``W_QK``/``W_V``, the maps
            ``QK``/``att``/``pos_att``/``PA`` and ``proj``.
        """
        batch_sz = X.shape[0]
        seq_len = X.shape[-2]
        # (batch, seq, heads * dim) -> (batch, heads, seq, dim)
        Q = self.Q(X).reshape(batch_sz, seq_len, self.n_heads, self.qk_dim).permute(0, 2, 1, 3)
        K = self.K(X).reshape(batch_sz, seq_len, self.n_heads, self.qk_dim).permute(0, 2, 1, 3)
        V = self.V(X).reshape(batch_sz, seq_len, self.n_heads, self.model_dim).permute(0, 2, 1, 3)

        # Additive causal mask (0 where attending is allowed, -inf otherwise), on X's device.
        causal_M_infty = nn.Transformer.generate_square_subsequent_mask(seq_len, device=X.device)
        scale_factor = self.qk_dim ** 0.5
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / scale_factor  # (batch, heads, seq, seq)

        if self.learn_non_aug_pe:
            positions = torch.arange(0, seq_len, device=X.device).expand(batch_sz, seq_len)
            pos_embed = self.pos_embedding(positions)  # (batch, seq, embed_dim)
            PQ = self.PQ(pos_embed)
            PK = self.PK(pos_embed)
            pos_scores = torch.matmul(PQ, PK.transpose(-2, -1)) / scale_factor

        if self.linear:
            # exp(0) = 1 and exp(-inf) = 0: keep the raw scores on allowed positions only.
            A = attn_scores * torch.exp(causal_M_infty)
        else:
            A = F.softmax(attn_scores + causal_M_infty, dim=-1)

        if self.learn_non_aug_pe:
            # unsqueeze(1) broadcasts the single positional map over the heads dimension.
            P = F.softmax(pos_scores + causal_M_infty, dim=-1).unsqueeze(1)
        else:
            P = torch.ones_like(A)

        PA = torch.mul(P, A)
        multihead_att_lyr = torch.matmul(PA, V)  # (batch, heads, seq, model_dim)
        if self.projection:
            multihead_att_lyr = torch.matmul(multihead_att_lyr, self.Prj)

        result = torch.sum(multihead_att_lyr, dim=1)  # sum over heads

        W_QK, W_V, flat_attn_scores, flat_A, flat_PA, flat_P = self.__get_params(False, attn_scores, A, PA, P)
        params = {"W_QK": W_QK,
                  "W_V": W_V,
                  "QK": flat_attn_scores,
                  "att": flat_A,  # with softmax, if exists, otherwise same as QK up to scaling
                  "pos_att": flat_P,
                  "PA": flat_PA,
                  "proj": self.Prj}

        return result, params

    def __get_params(self, flattened, attn_scores, A, PA, P):
        """Collect per-head weight matrices and (optionally head-flattened) attention maps.

        ``W_QK[h]`` and ``W_V[h]`` are the (model_dim, model_dim) matrices such that head
        h computes ``attn_scores = X W_QK[h] X^T / sqrt(qk_dim)`` and ``V_h = X W_V[h]``.
        With ``flattened`` set, the head axis is merged into the following axis.

        Returns:
            ``(W_QK, W_V, attn_scores, A, PA, P)``.
        """
        # nn.Linear stores (out, in) and computes x @ W^T; rows h*qk:(h+1)*qk belong to head h.
        W_q = self.Q.weight.detach().reshape(self.n_heads, self.qk_dim, self.model_dim)
        W_k = self.K.weight.detach().reshape(self.n_heads, self.qk_dim, self.model_dim)
        # Q_h K_h^T = X W_q_h^T W_k_h X^T, so the bilinear form is W_q_h^T W_k_h. Reshaping
        # to (heads, qk, qk) would only work when qk_dim == model_dim, and would give W_q_h W_k_h^T.
        W_QK = torch.matmul(W_q.transpose(-2, -1), W_k)
        W_V = self.V.weight.detach().reshape(self.n_heads, self.model_dim, self.model_dim).transpose(-2, -1)

        if flattened:
            # Flatten along heads
            W_QK = W_QK.reshape(W_QK.shape[0] * W_QK.shape[1], W_QK.shape[2])
            W_V = W_V.reshape(W_V.shape[0] * W_V.shape[1], W_V.shape[2])

            flat_attn_scores = attn_scores.reshape(attn_scores.shape[0],
                                                   attn_scores.shape[1] * attn_scores.shape[2],
                                                   attn_scores.shape[3])

            flat_A = A.reshape(A.shape[0], A.shape[1] * A.shape[2], A.shape[3])

            flat_PA = PA.reshape(PA.shape[0], PA.shape[1] * PA.shape[2], PA.shape[3])
            flat_P = P.reshape(P.shape[0], P.shape[1] * P.shape[2], P.shape[3])

            return W_QK, W_V, flat_attn_scores, flat_A, flat_PA, flat_P

        return W_QK, W_V, attn_scores, A, PA, P