import torch 
import torch.nn as nn

from torch import Tensor
import math
from typing import Optional, Tuple
from torch.nn import Dropout
from torch.nn.functional import softmax

class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention over sequence-first tensors.

    Inputs are laid out as ``(seq_len, bsz, embed_dim)``. Only the
    self-attention configuration is supported: ``kdim`` and ``vdim`` must
    equal ``embed_dim`` (enforced in ``__init__``). ``key`` and ``value`` may
    still be distinct tensors of that width, or omitted to attend over
    ``query`` itself.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        kdim: Key feature size, defaults to ``embed_dim`` (must equal it).
        vdim: Value feature size, defaults to ``embed_dim`` (must equal it).
        dropout: Dropout probability applied to the attention probabilities.
        bias: Whether the q/k/v/output projections carry a bias term.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = Dropout(dropout)
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * self.num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        # Queries are scaled by 1/sqrt(head_dim) so q.k logits don't grow with head size.
        self.scaling = self.head_dim ** -0.5

        assert self.qkv_same_dim, "self-attention requires q, k and v to have same sizes"

        self.k_proj = nn.Linear(self.kdim, self.embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, self.embed_dim, bias=bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise projection weights with Xavier-uniform.

        The q/k/v projections use a reduced gain of ``1/sqrt(2)``; the output
        projection uses the default gain and a zero bias. The q/k/v biases keep
        ``nn.Linear``'s default initialisation.
        """
        if self.qkv_same_dim:  # always true today; a cross-dim variant would need an else branch
            gain = 1 / math.sqrt(2)
            nn.init.xavier_uniform_(self.k_proj.weight, gain=gain)
            nn.init.xavier_uniform_(self.v_proj.weight, gain=gain)
            nn.init.xavier_uniform_(self.q_proj.weight, gain=gain)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def _split_heads(self, x: Tensor) -> Tensor:
        """Reshape ``(seq_len, bsz, embed_dim)`` to ``(bsz * num_heads, seq_len, head_dim)``."""
        seq_len, bsz, _ = x.size()
        return (
            x.contiguous()
            .view(seq_len, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )

    def forward(
        self,
        query: Tensor,
        key: Optional[Tensor],
        value: Optional[Tensor],
        attn_bias: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        """Compute multi-head attention.

        Args:
            query: ``(tgt_len, bsz, embed_dim)``.
            key: ``(src_len, bsz, embed_dim)``, or ``None`` to attend over ``query``.
            value: ``(src_len, bsz, embed_dim)``, or ``None`` to reuse ``key``.
            attn_bias: Additive logit bias with ``bsz * num_heads * tgt_len * src_len``
                elements in batch-major/head-minor order, e.g.
                ``(bsz, num_heads, tgt_len, src_len)``; or ``None``.
            key_padding_mask: ``(bsz, src_len)``; truthy entries mark key
                positions that receive ``-inf`` logits (zero attention).
            need_weights: Accepted for API compatibility; currently ignored.
            before_softmax: If true, return the raw logits and the projected
                values ``(bsz * num_heads, src_len, head_dim)`` instead.
            need_head_weights: Accepted for API compatibility; currently ignored.

        Returns:
            ``(attn, attn_weights)``. ``attn`` is ``(tgt_len, bsz, embed_dim)``.
            ``attn_weights`` holds the per-head softmax probabilities
            ``(bsz * num_heads, tgt_len, src_len)``, taken before dropout.
        """
        tgt_len, bsz, feature_dim = query.size()
        assert feature_dim == self.embed_dim, f"query's embed_dim {feature_dim} doesn't match model's embed_dim {self.embed_dim}"

        # Fall back to self-attention when key/value are omitted. This also
        # guarantees src_len is always defined.
        if key is None:
            key = query
        if value is None:
            value = key
        src_len = key.size(0)
        assert key.size() == (src_len, bsz, self.embed_dim)
        assert value.size() == (src_len, bsz, self.embed_dim)

        q = self._split_heads(self.q_proj(query) * self.scaling)  # bsz * num_heads, tgt_len, head_dim
        k = self._split_heads(self.k_proj(key))                   # bsz * num_heads, src_len, head_dim
        v = self._split_heads(self.v_proj(value))                 # bsz * num_heads, src_len, head_dim

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        attn_weights = torch.bmm(q, k.transpose(1, 2))  # bsz * num_heads, tgt_len, src_len
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_bias is not None:
            # reshape (not view) so non-contiguous bias tensors are accepted too.
            attn_weights = attn_weights + attn_bias.reshape(bsz * self.num_heads, tgt_len, src_len)

        if key_padding_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_probs = softmax(attn_weights, dim=-1)
        # Dropout only affects the value mixing. The returned probabilities stay
        # normalised, which matters in training mode.
        attn = torch.bmm(self.dropout_module(attn_probs), v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
        attn = self.out_proj(attn)  # tgt_len, bsz, embed_dim

        return attn, attn_probs
        

        







        