import torch
import torch.nn as nn
import torch.nn.functional as F

class TemperatureScaledAttention(nn.Module):
    """``nn.MultiheadAttention`` wrapper that divides the query by a temperature.

    The query tensor is multiplied by ``1 / tau`` before it is handed to the
    wrapped ``nn.MultiheadAttention``. ``tau`` is either a trainable scalar
    parameter or a fixed scalar buffer.

    NOTE(review): the scaling is applied to the *input* of the wrapped
    module's query projection. That projection is affine (bias enabled by
    default), so this is an exact softmax temperature only while the query
    in-projection bias is zero -- confirm this is the intended behaviour.

    Args:
        embed_dim: Embedding size forwarded to ``nn.MultiheadAttention``.
        num_heads: Number of attention heads.
        dropout: Attention dropout probability forwarded to the wrapped module.
        batch_first: If true, batched inputs are ``(B, L, E)``; otherwise
            ``(L, B, E)``.
        init_tau: Initial temperature value.
        learnable: If true, ``tau`` is an ``nn.Parameter``; otherwise it is a
            non-persistent buffer.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, batch_first=True,
                 init_tau=1.0, learnable=True):
        super().__init__()
        self.mha = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=batch_first
        )
        if learnable:
            self.tau = nn.Parameter(torch.tensor(float(init_tau)))
        else:
            # Registered with persistent=False, so a fixed tau follows
            # .to()/.double() but is not written to the state_dict.
            self.register_buffer("tau", torch.tensor(float(init_tau)), persistent=False)

    def _prep_attn_mask(self, attn_mask, B, H, L, S, device, dtype):
        """Normalise ``attn_mask`` to a layout ``nn.MultiheadAttention`` accepts.

        Args:
            attn_mask: ``None``, a ``(L, S)`` mask, a per-sample ``(B, L, S)``
                mask, or a per-head ``(B*H, L, S)`` mask.
            B: Batch size (1 for unbatched inputs).
            H: Number of attention heads.
            L: Target (query) sequence length.
            S: Source (key) sequence length.
            device: Device the mask is moved to.
            dtype: Dtype that floating-point masks are cast to. Non-floating
                masks (e.g. bool) keep their dtype.

        Returns:
            ``None``, the ``(L, S)`` mask, or a ``(B*H, L, S)`` mask.

        Raises:
            RuntimeError: If the mask shape is not one of the supported forms.
        """
        if attn_mask is None:
            return None

        ndim = attn_mask.dim()
        shape_ok = (
            ndim in (2, 3)
            and tuple(attn_mask.shape[-2:]) == (L, S)
            and (ndim == 2 or attn_mask.size(0) in (B, B * H))
        )
        if not shape_ok:
            raise RuntimeError(
                f"Unsupported attn_mask shape {tuple(attn_mask.shape)}; "
                f"expected (L,S) or (B,L,S) or (B*H,L,S) with L={L}, S={S}, B={B}, H={H}."
            )

        if attn_mask.is_floating_point():
            # Additive float masks are summed with the attention logits, so
            # they must share the query's dtype.
            attn_mask = attn_mask.to(device=device, dtype=dtype)
        else:
            attn_mask = attn_mask.to(device=device)

        if ndim == 3 and attn_mask.size(0) == B:
            # Per-sample mask: give every head of a sample the same mask.
            # repeat_interleave keeps heads of one sample contiguous, which is
            # the (B*H, L, S) ordering the wrapped module expects.
            attn_mask = attn_mask.repeat_interleave(H, dim=0)
        return attn_mask

    def forward(self, query, key, value, attn_mask=None, need_weights=True):
        """Run attention with the query scaled by ``1 / tau``.

        Args:
            query: ``(B, L, E)`` if ``batch_first`` else ``(L, B, E)``, or an
                unbatched ``(L, E)`` tensor.
            key: Tensor laid out like ``query`` with source length ``S``.
            value: Tensor laid out like ``key``.
            attn_mask: Optional mask; see ``_prep_attn_mask`` for accepted shapes.
            need_weights: Forwarded to the wrapped ``nn.MultiheadAttention``.

        Returns:
            Whatever the wrapped ``nn.MultiheadAttention`` returns (per the
            PyTorch documentation, a tuple of the attention output and the
            attention weights, the latter ``None`` if ``need_weights`` is false).

        Raises:
            ValueError: If ``query`` is neither 2-D nor 3-D.
            RuntimeError: If ``attn_mask`` has an unsupported shape.
        """
        if query.dim() == 3:
            # Honour the layout the wrapped module was configured with.
            batch_dim, seq_dim = (0, 1) if self.mha.batch_first else (1, 0)
            B = query.size(batch_dim)
            L = query.size(seq_dim)
            S = key.size(seq_dim)
        elif query.dim() == 2:
            # Unbatched (L, E) input: behaves like a batch of one.
            B, L, S = 1, query.size(0), key.size(0)
        else:
            raise ValueError(
                f"query must be 2-D (unbatched) or 3-D (batched), got {query.dim()}-D."
            )
        H = self.mha.num_heads

        attn_mask = self._prep_attn_mask(
            attn_mask, B=B, H=H, L=L, S=S, device=query.device, dtype=query.dtype
        )

        # Clamp so a tau that drifts to (or below) zero cannot cause a
        # division by zero.
        scale = 1.0 / torch.clamp(self.tau, min=1e-6)
        q_scaled = query * scale

        return self.mha(q_scaled, key, value, attn_mask=attn_mask, need_weights=need_weights)