import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class AttentionBlock(nn.Module):
    """Transformer-style block: multi-head attention followed by a feed-forward network.

    Each sub-layer is wrapped in a residual connection followed by LayerNorm
    (post-norm ordering). A sinusoidal positional encoding is added to the
    query before attention.

    Args:
        hidden_dim: Embedding size of query/key/value and of the block output.
            Must be divisible by ``num_heads`` (required by nn.MultiheadAttention).
        num_heads: Number of attention heads.
        dropout: Dropout probability used inside attention, inside the
            feed-forward network, and on each residual branch.
    """

    def __init__(self, hidden_dim, num_heads=4, dropout=0.1):
        super().__init__()

        # Multi-head attention; batch_first=True means inputs are (batch, seq, embed).
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # LayerNorms after the attention and feed-forward residuals, plus the
        # dropout applied to each residual branch.
        self.layer_norm1 = nn.LayerNorm(hidden_dim)
        self.layer_norm2 = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

        # Position-wise feed-forward network with a 4x hidden expansion.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )

    def forward(self, query, key, value, mask=None):
        """Run attention of ``query`` over ``key``/``value``, then the feed-forward network.

        Args:
            query: Tensor of shape (batch, query_len, hidden_dim).
            key: Tensor of shape (batch, key_len, hidden_dim).
            value: Tensor of shape (batch, key_len, hidden_dim).
            mask: Optional padding mask of shape (batch, key_len) in which truthy
                entries mark valid (non-padded) key positions. It is inverted
                before being passed as ``key_padding_mask``, whose convention is
                True = ignore. Non-bool masks (e.g. 0/1 integer tensors) are
                converted to bool first.

        Returns:
            Tensor of shape (batch, query_len, hidden_dim).
        """
        # Sinusoidal positional encoding is added to the query only.
        # NOTE(review): key/value receive no positional information here; if this
        # block is used for self-attention, confirm that is intended.
        seq_len, d_model = query.size(1), query.size(2)
        query = query + self._get_positional_encoding(
            seq_len, d_model, query.device, query.dtype
        )

        # nn.MultiheadAttention expects True at positions to IGNORE, so invert the
        # "True = valid" mask. Casting to bool first makes 0/1 integer masks work
        # (bitwise ~ on an int tensor would yield -1/-2, not a valid mask).
        key_padding_mask = None if mask is None else ~mask.to(torch.bool)

        attn_output, _ = self.attention(
            query=query,
            key=key,
            value=value,
            key_padding_mask=key_padding_mask,
        )

        # Residual connection + LayerNorm around attention.
        query = self.layer_norm1(query + self.dropout(attn_output))

        # Residual connection + LayerNorm around the feed-forward network.
        ffn_output = self.ffn(query)
        return self.layer_norm2(query + self.dropout(ffn_output))

    def _get_positional_encoding(self, seq_len, d_model, device, dtype=None):
        """Build a sinusoidal positional encoding.

        Even feature indices hold ``sin(pos * w_i)`` and odd indices hold
        ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``. Odd
        ``d_model`` is supported: the last (even-index) column gets a sine
        term with no matching cosine column.

        Args:
            seq_len: Number of positions to encode.
            d_model: Feature dimension of the encoding.
            device: Device on which to allocate the tensor.
            dtype: Optional result dtype. Values are computed in float32 and then
                cast. When None, torch's default float dtype is used.

        Returns:
            Tensor of shape (1, seq_len, d_model), broadcastable over the batch.
        """
        positions = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        # One frequency per sin/cos pair: ceil(d_model / 2) values.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, device=device, dtype=torch.float32)
            * (-math.log(10000.0) / d_model)
        )
        angles = positions * div_term  # (seq_len, ceil(d_model / 2))

        pos_encoding = torch.zeros(1, seq_len, d_model, device=device)
        pos_encoding[0, :, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns; drop the extra frequency
        # when d_model is odd, otherwise the assignment has a shape mismatch.
        pos_encoding[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])

        return pos_encoding if dtype is None else pos_encoding.to(dtype)