import logging
import math
from typing import Optional, Tuple
import warnings
import torch
from collections import defaultdict
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn.modules import TransformerDecoder as OrgTransformerDecoder
from torch.nn.modules import TransformerDecoderLayer  as OrgTransformerDecoderLayer
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.functional import scaled_dot_product_attention, _in_projection_packed, _in_projection, linear
from torch.overrides import (has_torch_function, handle_torch_function)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# String forms of four special tokens. Presumably begin-of-sequence,
# end-of-sequence, padding and mask markers -- confirm against the tokenizer /
# vocabulary code that consumes them (not visible here).
BOS, EOS, PAD, MASK = '[BOS]', '[EOS]', '[PAD]', '[MASK]'

class TransformerDecoderLayer(OrgTransformerDecoderLayer):
    """Decoder layer that returns its attention weights next to its output.

    The computation follows the parent ``torch.nn.TransformerDecoderLayer``
    (self-attention, attention over ``memory``, feed-forward; pre- or
    post-norm depending on ``norm_first``), but ``forward`` returns a
    ``(output, attention_dict)`` pair instead of a bare tensor.

    Args:
        d_model, nhead, dim_feedforward, dropout, activation, layer_norm_eps,
        batch_first, norm_first, device, dtype: forwarded unchanged to the
            parent constructor.
        output_attention: passed as ``need_weights`` to both attention
            modules. When False the attention modules are asked not to
            compute weights and the dict values are whatever they return in
            that mode (``None`` for ``torch.nn.MultiheadAttention``).
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation=F.relu,
                 layer_norm_eps=0.00001,
                 batch_first=False,
                 norm_first=False,
                 device=None,
                 dtype=None,
                 output_attention=False) -> None:
        # Everything after (d_model, nhead) goes by keyword: the parent's
        # positional parameter list is not stable across PyTorch releases
        # (newer releases insert ``bias`` before ``device``), so passing
        # ``device``/``dtype`` positionally can bind them to the wrong
        # parameters and silently build the layer without bias terms.
        super().__init__(d_model,
                         nhead,
                         dim_feedforward=dim_feedforward,
                         dropout=dropout,
                         activation=activation,
                         layer_norm_eps=layer_norm_eps,
                         batch_first=batch_first,
                         norm_first=norm_first,
                         device=device,
                         dtype=dtype)
        # Assigned after the parent constructor, as nn.Module requires.
        self.output_attention = output_attention

        factory_kwargs = {'device': device, 'dtype': dtype}
        # NOTE(review): the parent constructor already builds both attention
        # modules with these same arguments; they are re-created here only to
        # keep the original module structure and creation order.
        self.multihead_attn = MultiheadAttention(d_model,
                                                 nhead,
                                                 dropout=dropout,
                                                 batch_first=batch_first,
                                                 **factory_kwargs)
        self.self_attn = MultiheadAttention(d_model,
                                            nhead,
                                            dropout=dropout,
                                            batch_first=batch_first,
                                            **factory_kwargs)

    def forward(self,
                tgt: Tensor,
                memory: Tensor,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tuple[Tensor, dict]:
        """Run the layer and return ``(output, attention_dict)``.

        Args:
            tgt: input to the layer (queries of both attention blocks).
            memory: keys/values of the cross-attention block.
            tgt_mask: attention mask for the self-attention block.
            memory_mask: attention mask for the cross-attention block.
            tgt_key_padding_mask: key padding mask for self-attention.
            memory_key_padding_mask: key padding mask for cross-attention.

        Returns:
            A tuple ``(x, weights)`` where ``x`` is the layer output and
            ``weights`` is a dict with keys ``'cross_attn'`` and
            ``'decoder_self_attn'`` holding the weights returned by the
            respective attention module.
        """
        x = tgt
        if self.norm_first:
            # Pre-norm: normalise the input of every sub-block and add the
            # residual afterwards.
            sa_block_x, sa_block_attention = self._sa_block(
                self.norm1(x), tgt_mask, tgt_key_padding_mask)
            x = x + sa_block_x
            mha_block_x, mha_block_attention = self._mha_block(
                self.norm2(x), memory, memory_mask, memory_key_padding_mask)
            x = x + mha_block_x
            x = x + self._ff_block(self.norm3(x))
        else:
            # Post-norm: add the residual first, then normalise.
            sa_block_x, sa_block_attention = self._sa_block(
                x, tgt_mask, tgt_key_padding_mask)
            x = self.norm1(x + sa_block_x)
            mha_block_x, mha_block_attention = self._mha_block(
                x, memory, memory_mask, memory_key_padding_mask)
            x = self.norm2(x + mha_block_x)
            x = self.norm3(x + self._ff_block(x))

        return x, {'cross_attn': mha_block_attention,
                   'decoder_self_attn': sa_block_attention}

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor],
                  key_padding_mask: Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
        """Self-attend over ``x``; return ``(dropout1(output), weights)``."""
        x, attention_weight = self.self_attn(x, x, x,
                                             attn_mask=attn_mask,
                                             key_padding_mask=key_padding_mask,
                                             need_weights=self.output_attention)
        return self.dropout1(x), attention_weight

    # multihead attention block
    def _mha_block(self, x: Tensor, mem: Tensor, attn_mask: Optional[Tensor],
                   key_padding_mask: Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
        """Attend from ``x`` over ``mem``; return ``(dropout2(output), weights)``."""
        x, attention_weight = self.multihead_attn(
            x,
            mem,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=self.output_attention)

        return self.dropout2(x), attention_weight


class TransformerDecoder(OrgTransformerDecoder):
    """Stack of decoder layers that can collect per-layer attention weights.

    Each layer is expected to return an ``(output, attention_dict)`` pair
    when called, which is what this class unpacks in ``forward``.

    Args:
        decoder_layer: layer instance handed to the parent constructor.
        num_layers: number of layers, handed to the parent constructor.
        norm: optional module applied to the final output.
        output_attention: when True, ``forward`` returns the weights gathered
            from every layer; when False it returns ``None`` in their place.
            NOTE: this flag is not propagated to the layers -- whether a
            layer actually produces weights is decided by the layer itself.
    """

    def __init__(self,
                 decoder_layer,
                 num_layers,
                 norm=None,
                 output_attention=False):
        # Initialise nn.Module machinery before assigning attributes.
        super().__init__(decoder_layer, num_layers, norm)
        self.output_attention = output_attention

    def forward(self,
                tgt: Tensor,
                memory: Tensor,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[dict]]:
        """Pass ``tgt`` through every layer in turn.

        Args:
            tgt: input to the first layer.
            memory: passed unchanged to every layer.
            tgt_mask, memory_mask, tgt_key_padding_mask,
            memory_key_padding_mask: passed unchanged to every layer.

        Returns:
            ``(output, weights)``. ``weights`` is a ``defaultdict(list)``
            mapping each key of the layers' attention dicts to a list with
            one entry per layer (in layer order) when ``output_attention`` is
            True, otherwise ``None``.
        """
        output = tgt
        # Reset on every call; also kept on the instance so the weights of
        # the most recent forward pass stay reachable from the module.
        self.attention_weights = defaultdict(list)

        for mod in self.layers:
            output, attn_dict = mod(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask)
            if self.output_attention:
                for key, weights in attn_dict.items():
                    self.attention_weights[key].append(weights)

        if self.norm is not None:
            output = self.norm(output)

        if self.output_attention:
            return output, self.attention_weights
        return output, None


class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position table to token embeddings.

    The table is precomputed once with shape ``(maxlen, 1, emb_size)`` and
    registered as the buffer ``pos_embedding``. Even feature columns hold
    sines and odd columns hold cosines of ``position * frequency``.

    Args:
        emb_size: size of the last (feature) dimension; may be even or odd.
        dropout: dropout probability applied to the sum in ``forward``.
        maxlen: number of positions precomputed in the table.
    """

    def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
        super().__init__()
        # One frequency per sine column: ceil(emb_size / 2) values.
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) /
                        emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        angles = pos * den

        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(angles)
        # For odd emb_size there is one cosine column fewer than sine
        # columns; drop the surplus frequency so the shapes match instead of
        # raising. For even emb_size the slice is a no-op.
        pos_embedding[:, 1::2] = torch.cos(angles[:, :emb_size // 2])
        # Insert a singleton axis so the table broadcasts over dimension 1
        # of the input.
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        """Return ``dropout(token_embedding + positions)``.

        Dimension 0 of ``token_embedding`` is used as the position index: the
        first ``token_embedding.size(0)`` rows of the table are added and
        broadcast over dimension 1.
        """
        positions = self.pos_embedding[:token_embedding.size(0), :]
        return self.dropout(token_embedding + positions)

class TokenEmbedding(nn.Module):
    """Embedding lookup whose output is scaled by ``sqrt(emb_size)``.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_size: width of each embedding vector; also determines the scale.
    """

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.emb_size = emb_size
        self.embedding = nn.Embedding(vocab_size, emb_size)

    def forward(self, tokens: Tensor):
        """Cast ``tokens`` to int64, look them up, and scale the vectors."""
        indices = tokens.long()
        vectors = self.embedding(indices)
        scale = math.sqrt(self.emb_size)
        return vectors * scale
