from torch import nn
import torch
import torch.nn.functional as F
from pos_encodings import FixedPosEncTrig, LearnNonAugPosEnc, OneHotPosEnc, PosEnc, MAX_LEN



####### Notations preserved from those in this paper: https://arxiv.org/pdf/2402.02098
class Transformer(nn.Module):
    """Stack of ``TransformerLayer`` blocks with optional positional encoding
    and optional linear input/output maps.

    Args:
        model_dim: width of the input tokens (and of the output tokens).
        qk_dim: accepted for API compatibility; not used by this class (each
            layer receives ``io_layer_dim`` as its query/key width).
        ff_dim: forwarded to every ``TransformerLayer``.
        no_layers: number of stacked layers.
        n_att_heads: attention heads per layer.
        att_mlp, lin_att, gating, skip_conn, lyr_norm, att_init_scale,
        projection: forwarded unchanged to every ``TransformerLayer``.
        pos_enc_type: a ``PosEnc`` member, or ``None`` for no positional
            encoding. Only ``FIXED_TRIG_FUNCS`` and ``LEARNABLE_NON_AUG`` are
            supported.
        extra_input_lin_layer: apply ``nn.Linear(model_dim, io_layer_dim)``
            before the layers.
        extra_output_lin_layer: apply ``nn.Linear(io_layer_dim, model_dim)``
            after the layers.
        io_layer_dim: width the layers operate at; defaults to ``model_dim``
            when ``None``.
        device: device the sub-modules are created on.

    Raises:
        ValueError: if ``pos_enc_type`` is neither ``None`` nor supported.
    """

    def __init__(self, model_dim, qk_dim, ff_dim, no_layers, n_att_heads, att_mlp=True, lin_att=False,
                 gating=False, skip_conn=True, lyr_norm=True, pos_enc_type=PosEnc.FIXED_TRIG_FUNCS,
                 extra_input_lin_layer=False, extra_output_lin_layer=False, io_layer_dim=None, projection=False, att_init_scale=None, device='cpu'):
        super(Transformer, self).__init__()

        self.pos_enc_type = pos_enc_type
        self.model_dim = model_dim
        self.extra_input_lin_layer = extra_input_lin_layer
        self.extra_output_lin_layer = extra_output_lin_layer
        # The layers are sized by io_layer_dim; a None width would crash inside
        # nn.Linear, so fall back to working directly at model_dim.
        self.io_layer_dim = model_dim if io_layer_dim is None else io_layer_dim
        self.projection = projection
        self.inner_dim = model_dim
        # Set when the learnable positional pattern lives inside the attention
        # modules rather than being added to the inputs.
        self.learn_non_aug_pe = False
        self.device = device

        if self.pos_enc_type is not None:
            if self.pos_enc_type == PosEnc.FIXED_TRIG_FUNCS:
                # NOTE(review): sized by model_dim but applied after the optional
                # input layer, whose output width is io_layer_dim — confirm that
                # this is intended when the two differ.
                self.pos_encoder = FixedPosEncTrig(model_dim, device=self.device)
            elif self.pos_enc_type == PosEnc.LEARNABLE_NON_AUG:
                self.pos_encoder = LearnNonAugPosEnc(device=self.device)
                self.learn_non_aug_pe = True
            else:
                raise ValueError("Positional encoding not suported.")

        if self.extra_input_lin_layer:
            self.input_layer = nn.Linear(self.model_dim, self.io_layer_dim, bias=True, device=self.device)

        self.layers = nn.ModuleList(
            [TransformerLayer(model_dim=self.inner_dim,
                              qk_dim=self.io_layer_dim,
                              ff_dim=ff_dim,
                              n_att_heads=n_att_heads,
                              att_mlp=att_mlp,
                              lin_att=lin_att,
                              gating=gating,
                              skip_conn=skip_conn,
                              lyr_norm=lyr_norm,
                              att_init_scale=att_init_scale,
                              extra_input_lin_layer=self.extra_input_lin_layer,
                              extra_output_lin_layer=extra_output_lin_layer,
                              io_layer_dim=self.io_layer_dim,
                              projection=self.projection,
                              learn_non_aug_pe=self.learn_non_aug_pe,
                              device=self.device)
             for _ in range(no_layers)]
        )

        if self.extra_output_lin_layer:
            self.output_layer = nn.Linear(self.io_layer_dim, self.model_dim, bias=True, device=self.device)

    def forward(self, X):
        """Run a data block through the whole stack.

        Returns:
            ``(X, transf_params)``. ``transf_params`` maps ``"lyr_<i>"``
            (1-based) to each layer's parameter dict. It also holds
            ``"input_embed"`` / ``"output_embed"`` (detached weight copies)
            when the corresponding extra linear layer is enabled.
        """
        transf_params = {}
        if self.extra_input_lin_layer:
            X = self.input_layer(X)
            # NOTE(review): nn.Linear stores its weight as (out, in) =
            # (io_layer_dim, model_dim). This reshape (not a transpose) keeps
            # the element layout only when both dims are equal — confirm intent.
            transf_params["input_embed"] = (
                self.input_layer.weight.detach()
                .reshape(1, self.model_dim, self.io_layer_dim)
                .transpose(-2, -1)
            )

        if self.pos_enc_type is not None:
            X = self.pos_encoder(X)

        for idx, layer in enumerate(self.layers, start=1):
            X, layer_params = layer(X)
            transf_params[f"lyr_{idx}"] = layer_params

        if self.extra_output_lin_layer:
            X = self.output_layer(X)
            # Same reshape-then-transpose convention as "input_embed" above.
            transf_params["output_embed"] = (
                self.output_layer.weight.detach()
                .reshape(1, self.io_layer_dim, self.model_dim)
                .transpose(-2, -1)
            )

        return X, transf_params
    

class TransformerLayer(nn.Module):
    """One transformer block: causal masked attention, then an optional ReLU MLP.

    Each sub-block is optionally followed by a residual connection
    (``skip_conn``) and then a LayerNorm (``lyr_norm``). The attention, MLP
    and LayerNorms are all sized by ``io_layer_dim``, which falls back to
    ``model_dim`` when ``None``. ``qk_dim``, ``ff_dim``,
    ``extra_input_lin_layer`` and ``extra_output_lin_layer`` are stored on
    the instance but do not size any sub-module here.

    Args:
        model_dim: nominal token width; used as the fallback for ``io_layer_dim``.
        qk_dim: stored only.
        ff_dim: stored only (the MLP is ``io_layer_dim -> io_layer_dim -> io_layer_dim``).
        n_att_heads: number of attention heads.
        att_mlp: add the two-layer ReLU MLP after attention.
        lin_att, gating, att_init_scale, projection, learn_non_aug_pe:
            forwarded to ``MaskedAttention``.
        skip_conn: add residual connections around both sub-blocks.
        lyr_norm: apply LayerNorm after each (residual) sub-block.
        io_layer_dim: width of the attention/MLP/LayerNorm sub-modules.
        device: device the sub-modules are created on.
    """

    def __init__(self, model_dim, qk_dim, ff_dim, n_att_heads, att_mlp=True,
                 lin_att=False, gating=False, skip_conn=True, lyr_norm=True,
                 att_init_scale=None, extra_input_lin_layer=False, extra_output_lin_layer=False,
                 io_layer_dim=None, projection=False, learn_non_aug_pe=False, device='cpu'):
        super(TransformerLayer, self).__init__()

        # A None width would fail inside nn.Linear / nn.LayerNorm; without an
        # explicit io_layer_dim, the layer works directly at model_dim.
        if io_layer_dim is None:
            io_layer_dim = model_dim

        self.model_dim = model_dim
        self.qk_dim = qk_dim
        self.ff_dim = ff_dim
        self.skip_conn = skip_conn
        self.lyr_norm = lyr_norm
        self.att_mlp = att_mlp
        self.extra_input_lin_layer = extra_input_lin_layer
        self.extra_output_lin_layer = extra_output_lin_layer
        self.io_layer_dim = io_layer_dim
        self.projection = projection
        self.device = device

        # Attention uses io_layer_dim both as token width and as per-head q/k width.
        self.attention = MaskedAttention(projection, io_layer_dim, io_layer_dim, n_att_heads,
                                         linear=lin_att, gating=gating, att_init_scale=att_init_scale,
                                         learn_non_aug_pe=learn_non_aug_pe, device=device)

        # Feed-forward network (square, io_layer_dim wide).
        if self.att_mlp:
            self.W_F1 = nn.Linear(io_layer_dim, io_layer_dim, bias=True, device=self.device)
            self.W_F2 = nn.Linear(io_layer_dim, io_layer_dim, bias=True, device=self.device)

        # Layer normalization after each sub-block.
        if self.lyr_norm:
            self.norm1 = nn.LayerNorm(io_layer_dim, device=self.device)
            self.norm2 = nn.LayerNorm(io_layer_dim, device=self.device)

    def forward(self, X):
        """Apply attention (+ residual, norm) and optionally the MLP (+ residual, norm).

        Returns:
            ``(X_prime, layer_params)`` where ``layer_params`` is
            ``{"att": <attention params dict>, "mlp": None}``.
        """
        # Call the module (not .forward) so any registered hooks run.
        att_output, att_params = self.attention(X)

        Z = att_output + X if self.skip_conn else att_output
        if self.lyr_norm:
            Z = self.norm1(Z)

        X_prime = Z
        if self.att_mlp:
            H = self.W_F2(F.relu(self.W_F1(Z)))
            X_prime = H + Z if self.skip_conn else H
            if self.lyr_norm:
                X_prime = self.norm2(X_prime)

        layer_params = {"att": att_params,
                        "mlp": None}

        return X_prime, layer_params

        
class MaskedAttention(nn.Module):
    """Multi-head causal attention whose per-head outputs are summed.

    Inputs are reshaped as ``(batch, seq_len, model_dim)``. Each head has
    query/key width ``qk_dim`` and value width ``model_dim``. When
    ``projection`` is set, each head's output is multiplied by its own
    ``(model_dim, model_dim)`` matrix ``Prj[h]`` before summing over heads.

    Args:
        projection: apply the per-head ``Prj`` matrices in ``forward``
            (the gated path always applies them).
        model_dim: token width and per-head value width.
        qk_dim: per-head query/key width. ``__get_params`` reshapes the Q/K
            weights into ``(n_heads, qk_dim, qk_dim)`` blocks, which requires
            ``qk_dim == model_dim``.
        n_att_heads: number of heads.
        linear: replace the softmax by multiplying the raw scores with
            ``exp(causal_mask)``.
        gating: route ``forward`` through ``__linear_attention_with_gating``.
        att_init_scale: if given, re-initialise Q, K, V, Prj (and PQ, PK)
            from ``N(0, att_init_scale)``.
        learn_non_aug_pe: learn a positional attention pattern. It is
            softmax-normalised, shared by all heads, and multiplies the
            content attention elementwise.
        max_len: rows of the positional embedding table, i.e. the longest
            sequence the positional path can index.
        device: device the parameters are created on. Runtime tensors follow
            the input's device instead.
    """

    def __init__(self, projection, model_dim, qk_dim, n_att_heads, linear=False, gating=False,
                 att_init_scale=None, learn_non_aug_pe=False, max_len=MAX_LEN, device='cpu'):

        super(MaskedAttention, self).__init__()
        self.projection = projection
        self.n_heads = n_att_heads
        self.qk_dim = qk_dim
        self.model_dim = model_dim
        self.linear = linear
        self.learn_non_aug_pe = learn_non_aug_pe
        self.device = device
        # Width of each learnable positional-embedding vector.
        # TODO: tie this to a fixed max length instead of a multiple of model_dim.
        self.embed_dim = 5 * model_dim
        self.gating = gating

        # Heads are packed along the output dimension of a single Linear each.
        self.Q = nn.Linear(model_dim, qk_dim * self.n_heads, bias=False, device=self.device)
        self.K = nn.Linear(model_dim, qk_dim * self.n_heads, bias=False, device=self.device)
        self.V = nn.Linear(model_dim, model_dim * self.n_heads, bias=False, device=self.device)
        # One (model_dim x model_dim) output projection per head.
        prjs = torch.empty(size=(self.n_heads, model_dim, model_dim), device=self.device)
        torch.nn.init.xavier_uniform_(prjs, gain=1.0, generator=None)
        self.Prj = nn.Parameter(prjs, requires_grad=True)

        if self.learn_non_aug_pe:
            # Embedding indexed by absolute position (0 .. max_len - 1).
            self.pos_embedding = nn.Embedding(max_len, self.embed_dim, device=self.device)
            # Query/key maps of the positional attention pattern.
            self.PQ = nn.Linear(self.embed_dim, qk_dim, bias=False, device=self.device)
            self.PK = nn.Linear(self.embed_dim, qk_dim, bias=False, device=self.device)

        if att_init_scale is not None:
            torch.nn.init.normal_(self.Q.weight, 0.0, att_init_scale)
            torch.nn.init.normal_(self.K.weight, 0.0, att_init_scale)
            torch.nn.init.normal_(self.V.weight, 0.0, att_init_scale)
            # In-place re-init of the already-registered parameter.
            torch.nn.init.normal_(self.Prj, 0.0, att_init_scale)
            if self.learn_non_aug_pe:
                torch.nn.init.normal_(self.PQ.weight, 0.0, att_init_scale)
                torch.nn.init.normal_(self.PK.weight, 0.0, att_init_scale)

    def forward(self, X):
        """Apply causal multi-head attention to ``X``.

        Returns:
            ``(result, params)``. ``result`` has shape
            ``(batch, seq_len, model_dim)``, with the heads summed. ``params``
            holds weights and attention maps for inspection.
        """
        if self.gating:
            return self.__linear_attention_with_gating(X)

        batch_sz = X.shape[0]
        seq_len = X.shape[1]
        # Follow the input's device so the module keeps working after .to(...).
        device = X.device
        # (batch, heads, seq_len, width)
        Q = self.Q(X).reshape(batch_sz, seq_len, self.n_heads, self.qk_dim).permute(0, 2, 1, 3)
        K = self.K(X).reshape(batch_sz, seq_len, self.n_heads, self.qk_dim).permute(0, 2, 1, 3)
        V = self.V(X).reshape(batch_sz, seq_len, self.n_heads, self.model_dim).permute(0, 2, 1, 3)

        # 0 on/below the diagonal, -inf above it.
        causal_M_infty = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device)
        scale_factor = K.size(-1) ** 0.5
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / scale_factor  # (batch, heads, seq_len, seq_len)

        if self.linear:
            # exp(mask) is a 0/1 causal mask: no softmax, just zero out the future.
            A = attn_scores * torch.exp(causal_M_infty)
        else:
            A = F.softmax(attn_scores + causal_M_infty, dim=-1)

        if self.learn_non_aug_pe:
            positions = torch.arange(0, seq_len, device=device).expand(batch_sz, seq_len)
            positions = self.pos_embedding(positions)
            PQ = self.PQ(positions)
            PK = self.PK(positions)
            pos_scores = torch.matmul(PQ, torch.transpose(PK, -2, -1)) / scale_factor
            # One positional pattern shared by all heads: add a heads dimension to broadcast.
            P = F.softmax(pos_scores + causal_M_infty, dim=-1).unsqueeze(1)
        else:
            P = torch.ones_like(A)

        PA = torch.mul(P, A)
        multihead_att_lyr = torch.matmul(PA, V)
        if self.projection:
            # (batch, heads, seq_len, model_dim) @ (heads, model_dim, model_dim)
            multihead_att_lyr = torch.matmul(multihead_att_lyr, self.Prj)

        result = torch.sum(multihead_att_lyr, dim=1)  # sum over heads

        W_QK, W_V, flat_attn_scores, flat_A, flat_PA, flat_P = self.__get_params(False, attn_scores, A, PA, P)
        params = {
                    "W_QK": W_QK,
                    "W_V": W_V,
                    "QK": flat_attn_scores,
                    "att": flat_A,  # with softmax, if exists, otherwise same as QK up to scaling
                    "pos_att": flat_P,
                    "PA": flat_PA,
                    "proj": self.Prj}

        return result, params

    def __linear_attention_with_gating(self, X):
        """Causal linear attention with a running state and a diagonal gate.

        For each head, the state ``S_t`` accumulates ``k_t^T v_t`` masked by
        the identity gate, so only the diagonal is kept. The output at step
        ``t`` is ``q_t @ S_t @ Prj[h]``. The per-head ``Prj`` is always
        applied here, regardless of ``self.projection``.

        Returns:
            ``(X_prime, params)``, with ``X_prime`` summed over heads.
        """
        batch_sz = X.shape[0]
        seq_len = X.shape[1]
        device = X.device
        Q = self.Q(X).reshape(batch_sz, seq_len, self.n_heads, self.qk_dim).permute(0, 2, 1, 3)
        K = self.K(X).reshape(batch_sz, seq_len, self.n_heads, self.qk_dim).permute(0, 2, 1, 3)
        V = self.V(X).reshape(batch_sz, seq_len, self.n_heads, self.model_dim).permute(0, 2, 1, 3)

        assert K.shape[-1] == V.shape[-1], "For gating to work we need K.shape[-1] == V.shape[-1]"
        gate = torch.eye(K.shape[-1], device=device)

        output = torch.zeros_like(Q)
        for h in range(self.n_heads):
            S_t = torch.zeros(batch_sz, K.shape[-1], V.shape[-1], device=device)
            for t in range(seq_len):
                k_t = K[:, h, t, :].unsqueeze(1)  # (batch, 1, d)
                v_t = V[:, h, t, :].unsqueeze(1)  # (batch, 1, d)
                S_t = S_t + torch.mul(gate, torch.matmul(k_t.transpose(-2, -1), v_t))
                S_t_proj = torch.matmul(S_t, self.Prj[h])
                output[:, h, t, :] = torch.matmul(Q[:, h, t, :].unsqueeze(1), S_t_proj).squeeze(1)

        X_prime = torch.sum(output, dim=1)  # sum over heads

        W_QK, W_V, _, _, _, _ = self.__get_params(False, None, None, None, None)
        # Scaled by the key width itself (not its square root as in forward()).
        qk_scores = torch.matmul(Q, K.transpose(-2, -1)) / K.size(-1)
        causal_mask = nn.Transformer.generate_square_subsequent_mask(K.size(2), device=device)
        params = {
                    "W_QK": W_QK,
                    "W_V": W_V,
                    "QK": qk_scores,
                    "att": torch.zeros(1, 1, seq_len, seq_len, device=device),  # no softmax map in this variant
                    "pos_att": torch.zeros(1, 1, seq_len, seq_len, device=device),  # no positional map in this variant
                    "PA": qk_scores * torch.exp(causal_mask),
                    "proj": self.Prj}

        return X_prime, params

    def __get_params(self, flattened, attn_scores, A, PA, P):
        """Collect per-head weight products and (optionally flattened) attention maps.

        ``W_QK[h] = Q_h @ K_h^T``, where ``Q_h`` and ``K_h`` are the head-``h``
        row blocks of the Q/K weights reshaped to ``(qk_dim, qk_dim)``; this
        requires ``qk_dim == model_dim``. ``W_V[h]`` is the head-``h`` V weight
        block, transposed. Weights are detached copies. With
        ``flattened=True``, the heads dimension is merged into the row
        dimension of every returned tensor.

        Returns:
            ``(W_QK, W_V, attn_scores, A, PA, P)``.
        """
        q_weight = self.Q.weight.detach()
        k_weight = self.K.weight.detach()
        v_weight = self.V.weight.detach()
        W_QK = torch.matmul(q_weight.reshape(self.n_heads, self.qk_dim, self.qk_dim),
                            k_weight.reshape(self.n_heads, self.qk_dim, self.qk_dim).transpose(-2, -1))
        W_V = v_weight.reshape(self.n_heads, self.model_dim, self.model_dim).transpose(-2, -1)

        if flattened:
            # Flatten along heads
            W_QK = W_QK.reshape(W_QK.shape[0] * W_QK.shape[1], W_QK.shape[2])
            W_V = W_V.reshape(W_V.shape[0] * W_V.shape[1], W_V.shape[2])

            flat_attn_scores = attn_scores.reshape(attn_scores.shape[0],
                                                   attn_scores.shape[1] * attn_scores.shape[2],
                                                   attn_scores.shape[3])

            flat_A = A.reshape(A.shape[0], A.shape[1] * A.shape[2], A.shape[3])

            flat_PA = PA.reshape(PA.shape[0], PA.shape[1] * PA.shape[2], PA.shape[3])
            flat_P = P.reshape(P.shape[0], P.shape[1] * P.shape[2], P.shape[3])

            return W_QK, W_V, flat_attn_scores, flat_A, flat_PA, flat_P

        return W_QK, W_V, attn_scores, A, PA, P