from typing import Optional

import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected into queries, keys and values, split into
    ``num_heads`` heads of size ``embed_dim // num_heads``, attended per head,
    and the concatenated head outputs are projected back to ``embed_dim``.
    """

    def __init__(self, embed_dim: int, num_heads: int):
        """
        Args:
            embed_dim: Dimension of the embedding vectors. Must be positive.
            num_heads: Number of attention heads. Must be positive and divide
                ``embed_dim`` evenly.

        Raises:
            ValueError: If ``embed_dim`` or ``num_heads`` is not positive, or if
                ``embed_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()

        # Validate with explicit exceptions rather than ``assert``: asserts are
        # stripped under ``python -O``, ``num_heads == 0`` would otherwise surface
        # as a ZeroDivisionError from the modulo, and a negative ``num_heads``
        # would pass the divisibility check and only fail later inside ``view``.
        if embed_dim <= 0 or num_heads <= 0:
            raise ValueError(
                "embed_dim and num_heads must be positive, got "
                f"embed_dim={embed_dim}, num_heads={num_heads}"
            )
        if embed_dim % num_heads != 0:
            raise ValueError("Embedding dimension must be divisible by num_heads")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Input projections producing queries, keys and values. Attribute names
        # and creation order are part of the state_dict layout / init RNG order.
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

        # Output projection merging the concatenated per-head outputs.
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass of multi-head self-attention.

        Args:
            x: Input tensor of shape (batch_size, sequence_length, embed_dim).
            mask: Optional tensor broadcastable to
                (batch_size, num_heads, sequence_length, sequence_length).
                Positions where ``mask`` is True (non-zero) may be attended to;
                positions where it is False are excluded. Examples: a
                (sequence_length, sequence_length) lower-triangular mask for
                causal attention, or a (batch_size, 1, 1, sequence_length) key
                padding mask. A query row whose keys are all masked receives
                uniform weights instead of NaN. Defaults to None (no masking).

        Returns:
            Output tensor of shape (batch_size, sequence_length, embed_dim).
        """
        batch_size, seq_length, _ = x.size()

        # 1) Project the input; each result is (batch_size, seq_length, embed_dim).
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # 2) Split the last dimension into heads:
        #    (batch_size, seq_length, embed_dim) -> (batch_size, num_heads, seq_length, head_dim)
        Q = Q.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        # 3) Scaled dot-product scores: (batch_size, num_heads, seq_length, seq_length).
        #    Dividing by sqrt(head_dim) keeps the score magnitudes independent of head size.
        attention_scores = torch.matmul(Q, K.transpose(-1, -2)) / (self.head_dim ** 0.5)

        if mask is not None:
            # Use the dtype's most negative finite value rather than -inf so that a
            # fully masked row softmaxes to uniform weights instead of NaN.
            attention_scores = attention_scores.masked_fill(
                ~mask.to(torch.bool), torch.finfo(attention_scores.dtype).min
            )

        # 4) Normalize the scores over the key axis into attention weights.
        attention_weights = torch.softmax(attention_scores, dim=-1)

        # 5) Weighted sum of values: (batch_size, num_heads, seq_length, head_dim).
        out = torch.matmul(attention_weights, V)

        # 6) Merge heads back to (batch_size, seq_length, embed_dim) and project.
        #    transpose() yields a non-contiguous tensor, hence contiguous() before view().
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_length, self.embed_dim)
        return self.fc_out(out)


# Example usage:
if __name__ == "__main__":
    # Demo: a batch of 2 sequences, 5 tokens each, with 16-dimensional
    # embeddings, attended by 4 heads (16 / 4 = 4 dimensions per head).
    batch_size, seq_length, embed_dim, num_heads = 2, 5, 16, 4

    sample = torch.randn(batch_size, seq_length, embed_dim)
    attn = MultiHeadSelfAttention(embed_dim=embed_dim, num_heads=num_heads)

    # Self-attention preserves the input shape, so this prints [2, 5, 16].
    print("Output shape:", attn(sample).shape)
