from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
# Process-wide side effect at import time: turns off the FlashAttention
# backend of F.scaled_dot_product_attention (the kernel entry point used by
# MultiHeadAttention in this module), so SDPA dispatches to one of the
# remaining backends for every caller in the process, not just this module.
# NOTE(review): the motivation is not recorded here — confirm why flash SDP
# must be disabled, and consider scoping it (e.g. an SDPA backend context
# manager around the attention call) instead of mutating global state.
torch.backends.cuda.enable_flash_sdp(False)
# Earlier einsum-based MultiHeadAttention, disabled by wrapping it in a string
# literal. The string is evaluated and discarded (it is not the module
# docstring because it follows other statements), so this class is never
# defined. The active MultiHeadAttention is defined later in this module.
# Note that this legacy version's constructor signature
# (q_dim, kv_dim, qk_out_dim, v_out_dim, output_dim, heads, dropout) differs
# from the active class's (embed_dim, kv_dim, num_heads, dropout).
'''
class MultiHeadAttention(nn.Module):
    
    def __init__(
        self,
        q_dim:      int,
        kv_dim:     int,
        qk_out_dim: Optional[int] = None,
        v_out_dim:  Optional[int] = None,
        output_dim: Optional[int] = None,
        heads:      int = 1,
        dropout:    float = 0.0
        ):
        
        super().__init__()

        if qk_out_dim is None:
            qk_out_dim = q_dim
        if v_out_dim is None:
            v_out_dim  = qk_out_dim
        if output_dim is None:
            output_dim = v_out_dim

        self.heads       = heads
        self.qk_head_dim = qk_out_dim // heads
        self.v_head_dim  = v_out_dim // heads

        self.qeury = nn.Linear(q_dim, qk_out_dim)
        self.key   = nn.Linear(kv_dim, qk_out_dim)
        self.value = nn.Linear(kv_dim, v_out_dim)

        self.projection = nn.Linear(v_out_dim, output_dim)
        self.dropout    = nn.Dropout(dropout)
    
    def forward(
        self,
        x_q: torch.Tensor,
        x_kv: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
        ):
        
        batch = x_q.shape[0]
        query_len, key_len, value_len = x_q.shape[1], x_kv.shape[1], x_kv.shape[1]

        queries = self.qeury(x_q)
        keys    = self.key(x_kv)
        values  = self.value(x_kv)
        

        # [N, len, embed_size] --> [N, len, heads, head_dim]
        queries = queries.reshape(batch, query_len, self.heads, self.qk_head_dim)
        keys    = keys.reshape(batch, key_len, self.heads, self.qk_head_dim)
        values  = values.reshape(batch, value_len, self.heads, self.v_head_dim)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, heads_dim),
        # keys shape: (N, key_len, heads, heads_dim)
        # energy: (N, heads, query_len, key_len)

        if attention_mask is not None:
            energy = energy.masked_fill(attention_mask == 0, float("-1e20"))
        attention = torch.softmax(energy / (self.qk_head_dim ** (1 / 2)), dim=3)
        # attention shape: (N, heads, query_len, key_len)

        attention = self.dropout(attention)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            batch, query_len, self.heads * self.v_head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, heads_dim)

        out = self.projection(out)
        # (N, query_len, embed_size)
        return out
'''

class SelfAttention(nn.Module):
    """Pre-norm multi-head self-attention: LayerNorm, then attention of the
    normalized sequence over itself.

    The output has the same feature size as the input (``q_dim``).

    Args:
        q_dim: Feature size of the input sequence (and of the output).
        qk_out_dim: Kept for backward compatibility. Must be ``None`` or equal
            to ``q_dim``: the ``MultiHeadAttention`` in this module projects
            queries and keys to its ``embed_dim``, which is ``q_dim`` here.
        v_out_dim: Same restriction as ``qk_out_dim`` (values are also
            projected to ``q_dim``).
        heads: Number of attention heads; ``MultiHeadAttention`` requires it
            to divide ``q_dim``.
        dropout: Attention dropout probability.

    Raises:
        ValueError: If ``qk_out_dim`` or ``v_out_dim`` is given and differs
            from ``q_dim``.
    """

    def __init__(
        self,
        q_dim: int,
        qk_out_dim: Optional[int] = None,
        v_out_dim: Optional[int] = None,
        heads: int = 1,
        dropout: float = 0.0):

        super().__init__()

        # The active MultiHeadAttention has no separate qk/v projection sizes,
        # so refuse configurations it would silently ignore.
        for name, value in (("qk_out_dim", qk_out_dim), ("v_out_dim", v_out_dim)):
            if value is not None and value != q_dim:
                raise ValueError(
                    f"SelfAttention: {name}={value} is not supported; "
                    f"MultiHeadAttention projects queries, keys and values "
                    f"to q_dim={q_dim}"
                )

        self.norm = nn.LayerNorm(q_dim)
        # Keyword names match MultiHeadAttention(embed_dim, kv_dim, num_heads,
        # dropout); keys/values come from the same sequence, hence kv_dim=q_dim.
        self.attention = MultiHeadAttention(
            embed_dim=q_dim,
            kv_dim=q_dim,
            num_heads=heads,
            dropout=dropout,
        )

    def forward(
        self,
        x_q: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Normalize ``x_q`` and let it attend to itself.

        ``attention_mask`` is forwarded unchanged to ``MultiHeadAttention``.
        """
        x_q = self.norm(x_q)
        # The normalized sequence serves as query, key and value.
        return self.attention(x_q, x_q, x_q, attention_mask=attention_mask)


class CrossAttention(nn.Module):
    """Pre-norm multi-head cross-attention from a query sequence to a
    key/value sequence.

    ``x_q`` and ``x_kv`` are normalized by separate LayerNorms, then ``x_q``
    attends to ``x_kv``, which is used as both keys and values. The output has
    the query feature size ``q_dim``.

    Args:
        q_dim: Feature size of the query sequence (and of the output).
        kv_dim: Feature size of the key/value sequence.
        qk_out_dim: Kept for backward compatibility. Must be ``None`` or equal
            to ``q_dim``: the ``MultiHeadAttention`` in this module projects
            queries and keys to its ``embed_dim``, which is ``q_dim`` here.
        v_out_dim: Same restriction as ``qk_out_dim`` (values are also
            projected to ``q_dim``).
        heads: Number of attention heads; ``MultiHeadAttention`` requires it
            to divide ``q_dim``.
        dropout: Attention dropout probability.

    Raises:
        ValueError: If ``qk_out_dim`` or ``v_out_dim`` is given and differs
            from ``q_dim``.
    """

    def __init__(
        self,
        q_dim: int,
        kv_dim: int,
        qk_out_dim: Optional[int] = None,
        v_out_dim: Optional[int] = None,
        heads: int = 1,
        dropout: float = 0.0,
        ):

        super().__init__()

        # The active MultiHeadAttention has no separate qk/v projection sizes,
        # so refuse configurations it would silently ignore.
        for name, value in (("qk_out_dim", qk_out_dim), ("v_out_dim", v_out_dim)):
            if value is not None and value != q_dim:
                raise ValueError(
                    f"CrossAttention: {name}={value} is not supported; "
                    f"MultiHeadAttention projects queries, keys and values "
                    f"to q_dim={q_dim}"
                )

        self.q_norm = nn.LayerNorm(q_dim)
        self.kv_norm = nn.LayerNorm(kv_dim)
        # Keyword names match MultiHeadAttention(embed_dim, kv_dim, num_heads,
        # dropout).
        self.attention = MultiHeadAttention(
            embed_dim=q_dim,
            kv_dim=kv_dim,
            num_heads=heads,
            dropout=dropout,
        )

    def forward(
        self,
        x_q: torch.Tensor,
        x_kv: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Normalize both inputs and let ``x_q`` attend to ``x_kv``.

        ``attention_mask`` is forwarded unchanged to ``MultiHeadAttention``.
        """
        x_q = self.q_norm(x_q)
        x_kv = self.kv_norm(x_kv)
        # MultiHeadAttention.forward takes (query, key, value); the key/value
        # sequence supplies both keys and values.
        return self.attention(x_q, x_kv, x_kv, attention_mask=attention_mask)
    
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``F.scaled_dot_product_attention``.

    Queries have feature size ``embed_dim``; keys and values have feature size
    ``kv_dim``. All three are linearly projected to ``embed_dim`` and split
    into ``num_heads`` heads of size ``embed_dim // num_heads``. They are then
    attended with scaled dot-product attention, the heads are merged, and the
    result goes through a final ``embed_dim -> embed_dim`` projection.

    Args:
        embed_dim: Query feature size; also the output feature size.
        kv_dim: Key/value feature size.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        dropout: Dropout probability applied to the attention weights, only
            while the module is in training mode.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, kv_dim, num_heads, dropout=0.1):
        super().__init__()
        # Raise instead of assert so the check survives `python -O`.
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.kv_dim = kv_dim

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(kv_dim, embed_dim)
        self.v_proj = nn.Linear(kv_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Only `self.dropout.p` is used: the probability is handed to SDPA,
        # which applies dropout to the attention weights internally.
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, len, embed_dim) to (batch, num_heads, len, head_dim)."""
        batch_size, length, _ = x.size()
        return x.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, attention_mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: Tensor of shape (batch, q_len, embed_dim).
            key: Tensor of shape (batch, k_len, kv_dim).
            value: Tensor of shape (batch, k_len, kv_dim); its length must
                match ``key``.
            attention_mask: Optional mask broadcastable to
                (batch, num_heads, q_len, k_len). Boolean masks mark positions
                that may be attended with ``True``. Floating masks (same dtype
                as the query) are added to the attention scores. Masks of any
                other dtype (e.g. 0/1 integer padding masks) are treated as
                "attend where non-zero".

        Returns:
            Tensor of shape (batch, q_len, embed_dim).
        """
        batch_size, q_length, _ = query.size()

        # Project, then split into heads: (batch, num_heads, len, head_dim).
        q = self._split_heads(self.q_proj(query))
        k = self._split_heads(self.k_proj(key))
        v = self._split_heads(self.v_proj(value))

        # SDPA only accepts boolean or floating masks; map integer masks to the
        # boolean "keep where non-zero" form instead of failing.
        if attention_mask is not None and not (
            attention_mask.dtype == torch.bool or attention_mask.is_floating_point()
        ):
            attention_mask = attention_mask != 0

        # SDPA applies dropout whenever dropout_p > 0, regardless of
        # train/eval mode, so gate it on the module's training flag.
        dropout_p = self.dropout.p if self.training else 0.0
        attn_output = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, dropout_p=dropout_p)

        # Merge heads back: (batch, num_heads, q_len, head_dim) -> (batch, q_len, embed_dim).
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, q_length, self.embed_dim)
        return self.out_proj(attn_output)