import torch
import torch.utils.data as Data
from torch import nn
import numpy as np
import torch.nn.init as init

# Default compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Random 400x400 matrix created at import time (consumes the global torch RNG).
# NOTE(review): not referenced by the code visible here — confirm whether it is still needed.
random_matrix = torch.randn(400, 400, device=device)

def get_attn_pad_mask(seq_q, seq_k, pad_idx=0):
    '''
    Build a key-padding mask for attention.

        seq_q:   [batch_size, len_q] query token ids (only its length is used)
        seq_k:   [batch_size, len_k] key token ids
        pad_idx: token id that marks padding (default 0)
        len_q and len_k may differ (e.g. src_len vs tgt_len).

    Returns a bool tensor [batch_size, len_q, len_k] that is True wherever the
    key token equals pad_idx, i.e. the positions that must be masked out.
    The result is an expanded view of a [batch_size, 1, len_k] tensor.
    '''
    _, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq() yields a bool tensor that autograd does not track, so the legacy
    # `.data` accessor is unnecessary.
    pad_attn_mask = seq_k.eq(pad_idx).unsqueeze(1)  # [batch_size, 1, len_k], True is masked
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

def get_attn_subsequence_mask(seq, device):
    """
    Causal (look-ahead) mask for a batch of target sequences.

        seq:    [batch_size, tgt_len]; only its shape is read
        device: device on which the mask is allocated

    Returns a uint8 tensor [batch_size, tgt_len, tgt_len] holding 1 strictly
    above the main diagonal (key position j > query position i) and 0 elsewhere.
    """
    n_batch, n_steps = seq.size()
    ones = torch.ones((n_batch, n_steps, n_steps), dtype=torch.uint8, device=device)
    return ones.triu(diagonal=1)

def attn_mask(X_input, device):
    '''
    Build the combined padding + causal self-attention mask for decoder input.

        X_input: [batch_size, tgt_len] token ids (id 0 is treated as PAD by get_attn_pad_mask)
        device:  device on which the causal (subsequence) mask is allocated

    Returns a bool tensor [batch_size, tgt_len, tgt_len]; True marks a key
    position a query must NOT attend to (a PAD key or a future position).
    '''
    dec_self_attn_pad_mask = get_attn_pad_mask(X_input, X_input)  # [batch_size, tgt_len, tgt_len], bool
    dec_self_attn_subsequence_mask = get_attn_subsequence_mask(X_input, device)  # [batch_size, tgt_len, tgt_len], uint8
    # Masked if the key is PAD OR lies in the future (same result as sum > 0).
    dec_self_attn_mask = torch.logical_or(dec_self_attn_pad_mask, dec_self_attn_subsequence_mask.bool())  # [batch_size, tgt_len, tgt_len]

    return dec_self_attn_mask


class MultiHeadAttention(nn.Module):
    '''
    Single-head attention whose QK and VO circuits are augmented by a fixed
    identity matrix I (d_model x d_model) scaled by learnable scalars:

        scores = Q K^T / sqrt(d_k) + alpha * (x_q I x_k^T with its i == j entries zeroed)
        values = W_O(W_V x_v) + beta * x_v I^T

    Despite the class name there is only one head (no head split / concat).
    args must provide d_k, d_v, d_model, vo_add and (for layer_idx != 0) qk_add.
    '''

    def __init__(self, args, layer_idx):
        super(MultiHeadAttention, self).__init__()

        self.d_k = args.d_k
        self.d_v = args.d_v
        self.d_model = args.d_model
        self.layer_idx = layer_idx

        self.W_Q = nn.Linear(args.d_model, args.d_k, bias=False)
        self.W_K = nn.Linear(args.d_model, args.d_k, bias=False)
        self.W_V = nn.Linear(args.d_model, args.d_v, bias=False)
        self.W_O = nn.Linear(args.d_v, args.d_model, bias=False)
        self.layernorm = nn.LayerNorm(args.d_model)

        # Fixed, non-trainable identity. A non-persistent buffer follows
        # model.to()/.cuda()/.half() together with the parameters, while the
        # state_dict keys stay exactly as before (the buffer is not saved).
        self.register_buffer('identity_matrix', torch.eye(self.d_model), persistent=False)

        # Layer 0 starts with no QK identity contribution (alpha = 0); alpha and
        # beta are trainable parameters in every layer.
        qk_init = 0 if self.layer_idx == 0 else args.qk_add
        self.alpha = nn.Parameter(torch.tensor(qk_init, dtype=torch.float32, device=self.W_Q.weight.device))
        self.beta = nn.Parameter(torch.tensor(args.vo_add, dtype=torch.float32, device=self.W_Q.weight.device))

    def forward(self, input_Q, input_K, input_V, attn_mask):
        '''
        input_Q:   [batch_size, len_q, d_model]
        input_K:   [batch_size, len_k, d_model]
        input_V:   [batch_size, len_v(=len_k), d_model]
        attn_mask: bool, broadcastable to [batch_size, len_q, len_k]; True = masked out

        Returns (layernorm(attention_output + input_Q), attention weights
        [batch_size, len_q, len_k]).
        '''
        residual = input_Q

        Q = self.W_Q(input_Q)  # Q: [batch_size, len_q, d_k]
        K = self.W_K(input_K)  # K: [batch_size, len_k, d_k]
        V = self.W_V(input_V)  # V: [batch_size, len_v(=len_k), d_v]

        # Learned scores: [batch_size, len_q, len_k]
        attn = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)

        # Identity term on the raw inputs. Broadcasting the [d_model, d_model]
        # matrix over (batch, seq) avoids a per-position batched matmul and does
        # not tie the sequence length of K/V to that of Q.
        QIK = torch.matmul(torch.matmul(input_Q, self.identity_matrix), input_K.transpose(-1, -2))
        # Zero the i == j entries (same values as subtracting the diagonal, but
        # also valid for non-square len_q x len_k score matrices).
        same_pos = torch.eye(QIK.size(-2), QIK.size(-1), dtype=torch.bool, device=QIK.device)
        QIK = QIK.masked_fill(same_pos, 0.0)
        attn = attn + self.alpha * QIK

        # Fill positions where the mask is True with a large negative score.
        masked_attn = attn.masked_fill(attn_mask, -1e9)
        softmax_attn = torch.softmax(masked_attn, dim=-1)

        # Identity-augmented values: W_O(W_V x) + beta * x I^T
        VO = self.W_O(V) + self.beta * torch.matmul(input_V, self.identity_matrix.transpose(-1, -2))
        output = torch.matmul(softmax_attn, VO)  # [batch_size, len_q, d_model]

        return self.layernorm(output + residual), softmax_attn

class PoswiseFeedForwardNet(nn.Module):
    '''
    Position-wise feed-forward sub-layer: Linear -> ReLU -> Linear (no biases),
    wrapped in a residual connection followed by LayerNorm.

    args must provide d_model and d_feedforward. layer_idx is accepted for a
    uniform constructor signature but is not used.
    '''

    def __init__(self, args, layer_idx):
        super().__init__()
        d_model, d_ff = args.d_model, args.d_feedforward
        up_proj = nn.Linear(d_model, d_ff, bias=False)
        down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.fc = nn.Sequential(up_proj, nn.ReLU(), down_proj)
        self.layernorm = nn.LayerNorm(d_model)

    def forward(self, hidden_state):
        '''
        hidden_state: [batch_size, seq_len, d_model]
        returns:      [batch_size, seq_len, d_model]
        '''
        return self.layernorm(hidden_state + self.fc(hidden_state))

class DecoderLayer(nn.Module):
    '''
    One decoder block: identity-augmented self-attention followed by the
    position-wise feed-forward network. Each sub-layer applies its own
    residual connection and LayerNorm.
    '''

    def __init__(self, args, layer_idx):
        super().__init__()
        self.dec_self_attn = MultiHeadAttention(args, layer_idx)
        self.pos_ffn = PoswiseFeedForwardNet(args, layer_idx)

    def forward(self, hidden_state, dec_self_attn_mask):
        '''
            hidden_state:       [batch_size, tgt_len, d_model]
            dec_self_attn_mask: [batch_size, tgt_len, tgt_len]

        Returns the updated hidden state [batch_size, tgt_len, d_model] and the
        single-head attention weights [batch_size, tgt_len, tgt_len].
        '''
        attended, attn_weights = self.dec_self_attn(
            hidden_state, hidden_state, hidden_state, dec_self_attn_mask
        )
        return self.pos_ffn(attended), attn_weights


class Decoder(nn.Module):
    '''
    Stack of args.n_layers DecoderLayer blocks applied in sequence.
    '''

    def __init__(self, args, device):
        super().__init__()
        # Stored for reference; not read by this module's forward.
        self.device = device
        self.layers = nn.ModuleList(DecoderLayer(args, idx) for idx in range(args.n_layers))

    def forward(self, hidden_state, dec_self_attn_mask):
        '''
            hidden_state:       [batch_size, tgt_len, d_model] (embedded inputs)
            dec_self_attn_mask: [batch_size, tgt_len, tgt_len], True = masked

        Returns the final hidden state and a list holding one attention-weight
        tensor per layer, in layer order.
        '''
        attn_maps = []
        out = hidden_state
        for block in self.layers:
            out, weights = block(out, dec_self_attn_mask)
            attn_maps.append(weights)
        return out, attn_maps

class Embedding(nn.Module):
    '''
    Token embedding plus learned absolute position embedding.

    args must provide vocab_size, d_model and seq_len (the number of position
    embeddings); inputs longer than args.seq_len cannot be embedded.
    '''

    def __init__(self, args, device):
        super(Embedding, self).__init__()
        # Kept for backward compatibility; forward() follows the input's device.
        self.device = device
        self.tgt_emb = nn.Embedding(args.vocab_size, args.d_model)
        self.pos_emb = nn.Embedding(args.seq_len, args.d_model)

    def forward(self, X_input):
        '''
            X_input: [batch_size, seq_len] token ids
        Returns [batch_size, seq_len, d_model]: token embedding + position embedding.
        '''
        seq_len = X_input.size(1)
        # Allocate positions where the input lives rather than on the
        # constructor-time device, so the module keeps working after model.to(...).
        pos = torch.arange(seq_len, dtype=torch.long, device=X_input.device)
        pos = pos.unsqueeze(0).expand_as(X_input)  # [batch_size, seq_len]

        tgt_emb = self.tgt_emb(X_input)
        pos_emb = self.pos_emb(pos)
        return tgt_emb + pos_emb

class myGPT_single_head_add_identity(nn.Module):
    '''
    Decoder-only (GPT-style) language model: token + position embedding,
    a stack of single-head identity-augmented decoder layers, and a linear
    projection to vocabulary logits.
    '''

    def __init__(self, args, device):
        super(myGPT_single_head_add_identity, self).__init__()

        self.device = device
        self.embedding = Embedding(args, device)
        self.decoder = Decoder(args, device)
        self.projection = nn.Linear(args.d_model, args.vocab_size)
        # Disabled alternative initialization of the output projection:
        # nn.init.normal_(self.projection.weight, mean=0.0, std=(self.projection.weight.size(1))**(-1))

    def forward(self, X_input):
        """
            X_input: [batch_size, tgt_len] token ids

        Returns:
            logits flattened to [batch_size * tgt_len, vocab_size], and
            the list of per-layer attention weights from the decoder.
        """
        hidden_state = self.embedding(X_input)

        # Build the mask where the input lives, not on the constructor-time
        # device, so the model keeps working after model.to(...).
        dec_self_attn_mask = attn_mask(X_input, X_input.device)

        hidden_state, dec_self_attns = self.decoder(hidden_state, dec_self_attn_mask)

        dec_logits = self.projection(hidden_state)  # [batch_size, tgt_len, vocab_size]

        return dec_logits.view(-1, dec_logits.size(-1)), dec_self_attns
    




