from typing import Any, Optional, Tuple
import torch
from torch import nn

class MHCA(nn.Module):
    """Multi-head class attention: one query token attends over a sequence.

    A query token (the learnable ``cls_token``, or one supplied by the caller)
    is projected to queries, the input sequence is projected to keys and
    values, and scaled dot-product attention pools the sequence into a single
    ``dim``-sized vector per batch element. No output projection is applied.

    Args:
        dim: Embedding size of the input tokens; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads (must be positive).
        qkv_bias: Whether the Q/K/V linear projections use a bias term.
        qk_scale: Multiplier applied to the query-key dot products. ``None``
            selects the standard ``head_dim ** -0.5``.

    Raises:
        ValueError: If ``num_heads`` is not positive or ``dim`` is not
            divisible by ``num_heads``.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int = 1,
            qkv_bias: bool = False,
            qk_scale: Optional[float] = None):
        super().__init__()
        # Validate up front: otherwise a bad configuration only surfaces later
        # as an opaque reshape error (or ZeroDivisionError / complex scale).
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Explicit None check so a deliberate qk_scale of 0.0 is honoured
        # instead of being silently replaced by the default.
        self.scale = qk_scale if qk_scale is not None else self.head_dim ** -0.5

        # Learnable query token u, used when forward() receives no `cls`.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim) * 0.02)

        # Q, K, V projections
        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)

    def forward(
            self,
            x: torch.Tensor,
            cls: Optional[torch.Tensor] = None,
            **_: Any) -> torch.Tensor:
        """Pool ``x`` into one vector per batch element via class attention.

        Args:
            x: Input tokens of shape ``(B, N, C)`` with ``C == dim``.
            cls: Optional query token replacing the learnable ``cls_token``;
                expected shape ``(B, 1, C)``.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            Tensor of shape ``(B, C)``: the attention-weighted sum of the value
            projections, with per-head outputs concatenated along the last axis.
        """
        B, N, C = x.shape
        H, D = self.num_heads, C // self.num_heads

        # Broadcast the shared learnable token across the batch when no query
        # token is supplied.
        cls_token = cls if cls is not None else self.cls_token.expand(B, -1, -1)

        # Project Q, K, V and split into heads.
        q = self.q_proj(cls_token).reshape(B, 1, H, D).permute(0, 2, 1, 3)  # [B, H, 1, D]
        k = self.k_proj(x).reshape(B, N, H, D).permute(0, 2, 1, 3)  # [B, H, N, D]
        v = self.v_proj(x).reshape(B, N, H, D).permute(0, 2, 1, 3)  # [B, H, N, D]

        # Scaled dot-product attention of the single query over all N tokens.
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, H, 1, N]
        attn = attn.softmax(dim=-1)

        # [B, H, 1, D] -> [B, C]: head outputs are concatenated head-major.
        return (attn @ v).reshape(B, C)