from functools import partial
import warnings
import torch
from typing import Optional
from torch import nn
import torch.nn.functional as F
from torch.nn.functional import scaled_dot_product_attention
from mmengine.model import BaseModule
from mmcv.cnn.bricks.drop import build_dropout
from .layer_scale import LayerScale

# flash-attn is an optional dependency. When importing it fails for any
# reason, both entry points are bound to None and ``_HAS_FLASH_ATT`` is False
# so that callers can branch on availability instead of on the import itself.
try:
    from flash_attn import flash_attn_func, flash_attn_qkvpacked_func

    _HAS_FLASH_ATT = True
except Exception:
    # Deliberately broad: every import-time failure is treated as
    # "flash-attn not available" rather than propagated.
    flash_attn_func = None
    flash_attn_qkvpacked_func = None
    _HAS_FLASH_ATT = False


def _to_flash_shape(q, k, v):
    """Check the rank of ``q``, ``k`` and ``v`` and hand them back untouched.

    Despite the name, no tensor is reshaped or transposed here: the function
    only rejects inputs whose rank cannot be a per-head layout. Tensors are
    checked in the order ``q``, ``k``, ``v`` and the first offending one
    raises.

    Args:
        q: Query tensor (anything exposing ``ndim``).
        k: Key tensor.
        v: Value tensor.

    Returns:
        tuple: ``(q, k, v)``, the very same objects that were passed in.

    Raises:
        RuntimeError: If any input has 3 or 5 dimensions.
    """
    # Ranks that are rejected, mapped to the error text reported for them.
    # Every other rank (including the expected 4) is passed through.
    rejected_ranks = {
        3: "Please provide q/k/v as (B, L, H, D) for flash_attn wrapper.",
        5: "Unexpected q/k/v ndim",
    }

    for tensor in (q, k, v):
        problem = rejected_ranks.get(tensor.ndim)
        if problem is not None:
            raise RuntimeError(problem)

    return q, k, v


def flash_compatible_attention(
    q,
    k,
    v,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    use_qkvpacked: bool = False,
):
    """Run attention through flash-attn when possible, else through torch.

    The torch path (``F.scaled_dot_product_attention``) is used when
    flash-attn is not importable, when ``q`` is not fp16/bf16, when an
    ``attn_mask`` is given (the flash calls made here carry no mask, so
    honouring the mask requires the torch path), or when the flash call
    raises.

    NOTE(review): the two backends do not share a layout. This wrapper feeds
    flash-attn tensors it treats as ``(B, L, H, D)``, while the torch call is
    handed ``q``/``k``/``v`` exactly as received. Callers should confirm that
    the layout they pass is valid for the backend that will actually run.

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor.
        attn_mask: Optional mask forwarded to the torch implementation.
            Supplying it forces the torch path.
        dropout_p: Attention dropout probability.
        is_causal: Whether to apply causal masking. Honoured on both paths.
        scale: Optional softmax scale (a multiplier on the ``q @ k^T``
            logits). ``None`` selects each backend's default.
        use_qkvpacked: If True, stack q/k/v along a new dim 2 and call
            ``flash_attn_qkvpacked_func`` instead of ``flash_attn_func``.

    Returns:
        The attention output produced by whichever backend ran.

    Raises:
        Whatever ``F.scaled_dot_product_attention`` raises for the given
        arguments (for instance, torch may reject combining ``attn_mask``
        with ``is_causal=True`` -- TODO confirm for the torch version in use).
    """

    def torch_attention():
        # ``scale`` is only forwarded when explicitly set, so the default
        # case does not depend on the torch call accepting that keyword.
        extra = {} if scale is None else {"scale": scale}
        return F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            **extra,
        )

    if (
        not _HAS_FLASH_ATT
        or q.dtype not in (torch.float16, torch.bfloat16)
        or attn_mask is not None
    ):
        return torch_attention()

    def ensure_BLHD(x):
        """Align a 4-D ``x`` with the layout of ``q`` and make it contiguous."""
        if x.ndim == 4 and x.shape[1] == q.shape[1] and x.shape[2] != q.shape[2]:
            # dim 1 already agrees with q: keep as is (treated as B, L, H, D).
            return x
        if x.ndim == 4 and x.shape[1] != q.shape[1] and x.shape[2] == q.shape[1]:
            # dims 1 and 2 look swapped relative to q: transpose them.
            return x.permute(0, 2, 1, 3).contiguous()
        return x.contiguous()

    qf = ensure_BLHD(q)
    kf = ensure_BLHD(k)
    vf = ensure_BLHD(v)

    try:
        if use_qkvpacked:
            # Pack into a single tensor with q/k/v stacked on dim 2.
            qkv = torch.stack([qf, kf, vf], dim=2)
            return flash_attn_qkvpacked_func(
                qkv, dropout_p=dropout_p, softmax_scale=scale, causal=is_causal
            )
        return flash_attn_func(
            qf,
            kf,
            vf,
            dropout_p=dropout_p,
            softmax_scale=scale,
            causal=is_causal,
        )
    except Exception as err:
        # Best-effort: any flash failure degrades to the torch path, but the
        # reason is surfaced instead of being swallowed silently.
        warnings.warn(
            f"flash_attn call failed ({err!r}); falling back to "
            "torch.nn.functional.scaled_dot_product_attention.",
            RuntimeWarning,
            stacklevel=2,
        )
        return torch_attention()


def scaled_dot_product_attention_pyimpl(
    query, key, value, attn_mask=None, dropout_p=0.0, scale=None, is_causal=False
):
    """Plain-PyTorch scaled dot-product attention.

    Computes ``softmax(query @ key^T * scale + mask) @ value`` over the last
    two dimensions.

    Args:
        query: Tensor of shape ``(..., Lq, D)``.
        key: Tensor of shape ``(..., Lk, D)``.
        value: Tensor of shape ``(..., Lk, Dv)``.
        attn_mask: Optional mask broadcastable to ``(..., Lq, Lk)``. A boolean
            mask marks positions that may be attended to with ``True``; any
            other dtype is added to the attention logits.
        dropout_p: Dropout probability applied to the attention weights.
            Dropout is always active here, so pass ``0.0`` for evaluation.
        scale: Multiplier applied to the ``query @ key^T`` logits. ``None``
            selects ``1 / sqrt(D)``.
        is_causal: If True, position ``i`` only attends to keys ``j <= i``.
            When ``attn_mask`` is given as well, both restrictions apply.

    Returns:
        Tensor of shape ``(..., Lq, Dv)``.
    """
    # ``scale`` is a multiplier; only ``None`` (not 0) selects the default.
    scale_factor = query.size(-1) ** -0.5 if scale is None else scale
    attn_weight = (query @ key.transpose(-2, -1)) * scale_factor

    neg_inf = float("-inf")
    if is_causal:
        # Lower-triangular "keep" mask built on the same device as the logits.
        causal_keep = torch.ones(
            query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
        ).tril(diagonal=0)
        attn_weight = attn_weight.masked_fill(~causal_keep, neg_inf)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # True = keep; everything else is excluded from the softmax.
            attn_weight = attn_weight.masked_fill(~attn_mask, neg_inf)
        else:
            attn_weight = attn_weight + attn_mask

    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    return attn_weight @ value


class MultiheadAttention(BaseModule):
    """Multi-head Attention Module.

    This module implements multi-head attention that supports different input
    dims and embed dims. And it also supports a shortcut from ``value``, which
    is useful if input dims is not the same with embed dims.

    Args:
        embed_dims (int): The embedding dimension. Must be divisible by
            ``num_heads``.
        num_heads (int): Parallel attention heads.
        input_dims (int, optional): The input dimension, and if None,
            use ``embed_dims``. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        dropout_layer (dict): The dropout config before adding the shortcut.
            Defaults to ``dict(type='Dropout', drop_prob=0.)``.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. The value is used as the multiplier
            on the ``q @ k^T`` logits. Defaults to None.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        v_shortcut (bool): Add a shortcut from value to output. It's usually
            used if ``input_dims`` is different from ``embed_dims``.
            Defaults to False.
        use_layer_scale (bool): Whether to use layer scale. Defaults to False.
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 0.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Raises:
        ValueError: If ``embed_dims`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        embed_dims,
        num_heads,
        input_dims=None,
        attn_drop=0.0,
        proj_drop=0.0,
        dropout_layer=dict(type="Dropout", drop_prob=0.0),
        qkv_bias=True,
        qk_scale=None,
        proj_bias=True,
        v_shortcut=False,
        use_layer_scale=False,
        layer_scale_init_value=0.0,
        init_cfg=None,
    ):
        super().__init__(init_cfg=init_cfg)

        # Fail at construction time; otherwise the per-head reshape in
        # ``forward`` is what breaks, far away from the misconfiguration.
        if embed_dims % num_heads != 0:
            raise ValueError(
                f"embed_dims ({embed_dims}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self.input_dims = input_dims or embed_dims
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.v_shortcut = v_shortcut

        self.head_dims = embed_dims // num_heads
        default_scale = self.head_dims**-0.5
        # Effective multiplier on the q @ k^T logits.
        self.scale = default_scale if qk_scale is None else qk_scale
        # The attention call below applies ``head_dims ** -0.5`` on its own.
        # A custom ``qk_scale`` is realised by pre-multiplying q with the
        # ratio between the requested and the built-in scale, which keeps the
        # override independent of any ``scale`` keyword support.
        self._q_prescale = None if qk_scale is None else qk_scale / default_scale
        self.scaled_dot_product_attention = scaled_dot_product_attention

        self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.out_drop = build_dropout(dropout_layer)

        if use_layer_scale:
            warnings.warn(
                "The `use_layer_scale` in `MultiheadAttention` will "
                "be deprecated. Please use `layer_scale_init_value` "
                "to control whether using layer scale or not."
            )

        if use_layer_scale or (layer_scale_init_value > 0):
            # ``use_layer_scale=True`` with a zero init value falls back to 1e-5.
            layer_scale_init_value = layer_scale_init_value or 1e-5
            self.gamma1 = LayerScale(
                embed_dims, layer_scale_init_value=layer_scale_init_value
            )
        else:
            self.gamma1 = nn.Identity()

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x (torch.Tensor): Input of shape ``(B, N, input_dims)``.

        Returns:
            torch.Tensor: Output of shape ``(B, N, embed_dims)``.
        """
        B, N, _ = x.shape
        # (B, N, 3 * C) -> (3, B, num_heads, N, head_dims)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dims)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        if self._q_prescale is not None:
            # Turns the built-in ``head_dims ** -0.5`` scaling into ``qk_scale``.
            q = q * self._q_prescale

        # Attention dropout is only active in training mode.
        attn_drop = self.attn_drop if self.training else 0.0
        x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
        # Merge heads: (B, num_heads, N, head_dims) -> (B, N, embed_dims)
        x = x.transpose(1, 2).reshape(B, N, self.embed_dims)

        x = self.proj(x)
        x = self.out_drop(self.gamma1(self.proj_drop(x)))

        if self.v_shortcut:
            # Merge the heads of v the same way as the attention output. For
            # a single head this equals ``v.squeeze(1)``; for several heads it
            # avoids the shape mismatch ``squeeze`` would leave behind.
            x = v.transpose(1, 2).reshape(B, N, self.embed_dims) + x
        return x

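    # NOTE: everything commented out below is an earlier CrossMultiheadAttention
    # variant built on ``scaled_dot_product_attention``. It is inactive and has
    # been superseded by the CrossMultiheadAttention class defined later in
    # this file.
    # class CrossMultiheadAttention(BaseModule):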
    # """
    # Multi-head Attention that supports self-attention and cross-attention.

    # Usage:
    #     # self-attention (original usage)
    #     out = attn(x)  # x: [B, N, C_in], C_in defaults to embed_dims

    #     # cross-attention
    #     out = attn(query, kv=kv_tensor)  # query: [B, Nq, Cq], kv_tensor: [B, Nk, Ckv]
    # """

    # def __init__(
    #     self,
    #     embed_dims,
    #     num_heads,
    #     input_dims=None,
    #     kv_input_dims: Optional[int] = None,
    #     attn_drop=0.0,
    #     proj_drop=0.0,
    #     dropout_layer=dict(type="Dropout", drop_prob=0.0),
    #     qkv_bias=True,
    #     qk_scale=None,
    #     proj_bias=True,
    #     v_shortcut=False,
    #     use_layer_scale=False,
    #     layer_scale_init_value=0.0,
    #     init_cfg=None,
    # ):
    #     super(CrossMultiheadAttention, self).__init__(init_cfg=init_cfg)

    #     self.input_dims = input_dims or embed_dims
    #     self.kv_input_dims = kv_input_dims or self.input_dims
    #     self.embed_dims = embed_dims
    #     self.num_heads = num_heads
    #     self.v_shortcut = v_shortcut

    #     assert embed_dims % num_heads == 0, "embed_dims must be divisible by num_heads"
    #     self.head_dims = embed_dims // num_heads

    #     if qk_scale is not None:
    #         self.scaled_dot_product_attention = partial(
    #             scaled_dot_product_attention, scale=self.head_dims**-0.5
    #         )
    #     else:
    #         self.scaled_dot_product_attention = scaled_dot_product_attention

    #     # separate projections for query and kv (supports different input dims)
    #     # q: Cq -> embed_dims
    #     self.q_proj = nn.Linear(self.input_dims, embed_dims, bias=qkv_bias)
    #     # kv: Ckv -> 2 * embed_dims (k and v)
    #     self.kv_proj = nn.Linear(self.kv_input_dims, embed_dims * 2, bias=qkv_bias)

    #     self.attn_drop = attn_drop
    #     self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
    #     self.proj_drop = nn.Dropout(proj_drop)

    #     self.out_drop = build_dropout(dropout_layer)

    #     if use_layer_scale or (layer_scale_init_value > 0):
    #         layer_scale_init_value = layer_scale_init_value or 1e-5
    #         self.gamma1 = LayerScale(
    #             embed_dims, layer_scale_init_value=layer_scale_init_value
    #         )
    #     else:
    #         self.gamma1 = nn.Identity()

    #     # For v_shortcut: if kv_input_dims != embed_dims we need projection to embed_dims
    #     if v_shortcut:
    #         if self.kv_input_dims != self.embed_dims:
    #             self.proj_v_shortcut = nn.Linear(
    #                 self.kv_input_dims, self.embed_dims, bias=False
    #             )
    #         else:
    #             self.proj_v_shortcut = None
    #     else:
    #         self.proj_v_shortcut = None

    # def forward(
    #     self, x: torch.Tensor, kv: Optional[torch.Tensor] = None
    # ) -> torch.Tensor:
    #     """
    #     Args:
    #         x: Query tensor, shape [B, Nq, Cq] where Cq == self.input_dims (or compatible)
    #         kv: Optional Key/Value tensor, shape [B, Nk, Ckv].
    #             If None, use x for k and v (self-attention).
    #     Returns:
    #         Tensor of shape [B, Nq, embed_dims]
    #     """
    #     Bq, Nq, _ = x.shape

    #     # project queries
    #     q = self.q_proj(x)  # [B, Nq, embed_dims]
    #     q = q.reshape(Bq, Nq, self.num_heads, self.head_dims).permute(0, 2, 1, 3)
    #     # q: [B, num_heads, Nq, head_dims]

    #     # select kv input (self-attn if kv is None)
    #     kv_input = x if kv is None else kv
    #     Bk, Nk, _ = kv_input.shape
    #     assert Bk == Bq, "Batch size of query and kv must match"

    #     # project kv -> get k and v
    #     kv_proj = self.kv_proj(kv_input)  # [B, Nk, 2*embed_dims]
    #     kv_proj = kv_proj.reshape(Bk, Nk, 2, self.num_heads, self.head_dims).permute(
    #         2, 0, 3, 1, 4
    #     )
    #     # kv_proj: [2, B, num_heads, Nk, head_dims]
    #     k, v = kv_proj[0], kv_proj[1]  # each: [B, num_heads, Nk, head_dims]

    #     # attention
    #     attn_drop_p = self.attn_drop if self.training else 0.0
    #     out = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop_p)
    #     # out: [B, num_heads, Nq, head_dims]

    #     # merge heads
    #     out = out.transpose(1, 2).reshape(
    #         Bq, Nq, self.embed_dims
    #     )  # [B, Nq, embed_dims]

    #     # output projection
    #     out = self.proj(out)
    #     out = self.out_drop(self.gamma1(self.proj_drop(out)))

    #     # optional v_shortcut: add a shortcut from (projected) value to output.
    #     if self.v_shortcut:
    #         # construct a value shortcut in the original per-token ordering
    #         # v currently has shape [B, num_heads, Nk, head_dims]
    #         # we want a [B, Nq, embed_dims] to add to out
    #         # simplest and reasonable choice: use kv_input projected to embed_dims
    #         if self.proj_v_shortcut is not None:
    #             # project kv_input -> [B, Nk, embed_dims]
    #             v_short = self.proj_v_shortcut(kv_input)
    #         else:
    #             # kv_input already has embed_dims, use it directly
    #             v_short = kv_input
    #         # If Nq != Nk (query length differs), try broadcasting if Nk==1 or match lengths
    #         if v_short.shape[1] == Nq:
    #             out = out + v_short
    #         elif v_short.shape[1] == 1:
    #             out = out + v_short.repeat(1, Nq, 1)
    #         else:
    #             # lengths differ and cannot broadcast: fallback to no shortcut (or mean-pool)
    #             pooled = v_short.mean(dim=1, keepdim=True).repeat(1, Nq, 1)
    #             out = out + pooled

    #     return out


class CrossMultiheadAttention(BaseModule):
    """Multi-head attention supporting both self- and cross-attention.

    Usage:
        # self-attention (original usage)
        out = attn(x)  # x: [B, N, C_in], C_in defaults to embed_dims

        # cross-attention
        out = attn(query, kv=kv_tensor)  # query: [B, Nq, Cq], kv_tensor: [B, Nk, Ckv]

    Side effect: while ``save_attn_maps`` is True, every forward pass whose
    key/value length is 3 and whose query length is a perfect square writes
    head-averaged attention maps as PNG files into ``attn_map_dir``.

    Args:
        embed_dims (int): The embedding dimension. Must be divisible by
            ``num_heads``.
        num_heads (int): Parallel attention heads.
        input_dims (int, optional): Query input dimension. ``None`` means
            ``embed_dims``.
        kv_input_dims (int, optional): Key/value input dimension. ``None``
            means ``input_dims``.
        attn_drop (float): Dropout probability on the attention weights.
        proj_drop (float): Dropout probability after the output projection.
        dropout_layer (dict): Config of the dropout applied to the output.
        qkv_bias (bool): Whether the q and kv projections have a bias.
        qk_scale (float, optional): Multiplier on the ``q @ k^T`` logits.
            ``None`` means ``head_dims ** -0.5``.
        proj_bias (bool): Whether the output projection has a bias.
        v_shortcut (bool): Whether to add a shortcut built from the key/value
            input to the output.
        use_layer_scale (bool): Whether to use layer scale.
        layer_scale_init_value (float): Init value of the layer scale.
        init_cfg (dict, optional): Initialization config.
        save_attn_maps (bool): Whether ``forward`` may export attention maps
            (see above). Defaults to True.
        attn_map_dir (str): Directory the attention maps are written to.
            Defaults to ``"output/attention_map"``.
    """

    def __init__(
        self,
        embed_dims,
        num_heads,
        input_dims=None,
        kv_input_dims: Optional[int] = None,
        attn_drop=0.0,
        proj_drop=0.0,
        dropout_layer=dict(type="Dropout", drop_prob=0.0),
        qkv_bias=True,
        qk_scale=None,
        proj_bias=True,
        v_shortcut=False,
        use_layer_scale=False,
        layer_scale_init_value=0.0,
        init_cfg=None,
        save_attn_maps: bool = True,
        attn_map_dir: str = "output/attention_map",
    ):
        super().__init__(init_cfg=init_cfg)

        self.input_dims = input_dims or embed_dims
        self.kv_input_dims = kv_input_dims or self.input_dims
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.v_shortcut = v_shortcut
        self.save_attn_maps = save_attn_maps
        self.attn_map_dir = attn_map_dir

        assert embed_dims % num_heads == 0, "embed_dims must be divisible by num_heads"
        self.head_dims = embed_dims // num_heads

        # A custom scale can be supplied through ``qk_scale``.
        if qk_scale is not None:
            self.scale = qk_scale
        else:
            self.scale = self.head_dims**-0.5

        # Separate projections for query and key/value so that the two inputs
        # may have different feature dimensions.
        # q: Cq -> embed_dims
        self.q_proj = nn.Linear(self.input_dims, embed_dims, bias=qkv_bias)
        # kv: Ckv -> 2 * embed_dims (k and v)
        self.kv_proj = nn.Linear(self.kv_input_dims, embed_dims * 2, bias=qkv_bias)

        self.attn_drop_p = attn_drop
        self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.out_drop = build_dropout(dropout_layer)

        if use_layer_scale or (layer_scale_init_value > 0):
            # ``use_layer_scale=True`` with a zero init value falls back to 1e-5.
            layer_scale_init_value = layer_scale_init_value or 1e-5
            self.gamma1 = LayerScale(
                embed_dims, layer_scale_init_value=layer_scale_init_value
            )
        else:
            self.gamma1 = nn.Identity()

        # The value shortcut adds the raw kv input to the output, so it needs
        # a projection whenever kv_input_dims differs from embed_dims.
        if v_shortcut and self.kv_input_dims != self.embed_dims:
            self.proj_v_shortcut = nn.Linear(
                self.kv_input_dims, self.embed_dims, bias=False
            )
        else:
            self.proj_v_shortcut = None

        # Note: attention is computed by the in-class implementation below
        # (not by the fused scaled_dot_product_attention) because the
        # attention weights themselves are needed.

    def _scaled_dot_product_attention_with_weights(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
    ):
        """Hand-written scaled dot-product attention returning the weights.

        Args:
            q: [B, num_heads, Nq, head_dim]
            k: [B, num_heads, Nk, head_dim]
            v: [B, num_heads, Nk, head_dim]
            attn_mask: Optional, broadcastable to [B, num_heads, Nq, Nk].
                A boolean mask keeps positions that are True; any other
                dtype is added to the logits.
            dropout_p: Dropout probability on the attention weights; only
                applied while the module is in training mode.

        Returns:
            tuple: ``(out, attn_weights)`` with ``out`` of shape
            [B, num_heads, Nq, head_dim] and ``attn_weights`` of shape
            [B, num_heads, Nq, Nk]. The returned weights are the ones used
            for ``out``, i.e. taken after dropout when dropout is applied.
        """
        # [B, num_heads, Nq, Nk]
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if attn_mask is not None:
            # The mask is assumed to be either an additive mask (values such
            # as -inf / 0) or a boolean keep-mask.
            if attn_mask.dtype == torch.bool:
                attn_scores = attn_scores.masked_fill(~attn_mask, float("-inf"))
            else:
                attn_scores = attn_scores + attn_mask

        attn_weights = F.softmax(attn_scores, dim=-1)

        if dropout_p > 0.0 and self.training:
            attn_weights = F.dropout(attn_weights, p=dropout_p)

        # [B, num_heads, Nq, head_dim]
        out = torch.matmul(attn_weights, v)
        return out, attn_weights

    @torch.no_grad()
    def _maybe_save_attention_maps(
        self,
        attn_weights: torch.Tensor,
        q_len: int,
        kv_len: int,
        batch_size: int,
        save_dir: str = "output/attention_map",
    ):
        """Render and save one attention map per key/value token.

        Nothing happens unless ``kv_len == 3`` and ``q_len`` is a perfect
        square. Otherwise the head-averaged weights towards each kv token are
        reshaped to a square image, min-max normalised and written to
        ``save_dir`` as PNG. If matplotlib cannot be imported, a warning is
        issued and the export is skipped.

        Args:
            attn_weights: [B, num_heads, Nq, Nk]
            q_len: Nq
            kv_len: Nk
            batch_size: B
            save_dir: Output directory; created when missing.
        """
        if kv_len != 3:
            return

        import math
        import os

        # Only square query grids can be laid out as an image. ``isqrt`` is
        # exact for integers, unlike ``int(math.sqrt(...))``.
        side = math.isqrt(q_len)
        if side * side != q_len:
            return

        try:
            import matplotlib.pyplot as plt
        except ImportError:
            # The export is optional and must never break the forward pass.
            warnings.warn(
                "matplotlib is not available; skipping attention map export."
            )
            return

        # [B, num_heads, Nq, Nk] -> [B, Nq, Nk] (average over heads)
        attn_cpu = attn_weights.mean(dim=1).float().detach().cpu().numpy()

        os.makedirs(save_dir, exist_ok=True)
        # File names continue from the number of files already present.
        file_counter = sum(
            1
            for entry in os.listdir(save_dir)
            if os.path.isfile(os.path.join(save_dir, entry))
        )

        for b in range(batch_size):
            for k_idx in range(kv_len):
                # Attention of every query position towards this kv token,
                # laid out as a [side, side] image.
                attn_map = attn_cpu[b, :, k_idx].reshape(side, side)

                # Min-max normalise to [0, 1] for contrast; a constant map is
                # left untouched.
                min_val = attn_map.min()
                max_val = attn_map.max()
                if max_val > min_val:
                    attn_norm = (attn_map - min_val) / (max_val - min_val)
                else:
                    attn_norm = attn_map

                fname = f"{file_counter:06d}_b{b}_k{k_idx}.png"
                file_counter += 1
                fpath = os.path.join(save_dir, fname)

                fig = plt.figure()
                try:
                    plt.imshow(attn_norm, cmap="viridis")
                    plt.axis("off")
                    plt.tight_layout(pad=0)
                    plt.savefig(fpath, bbox_inches="tight", pad_inches=0)
                finally:
                    # Always release the figure, even if saving failed.
                    plt.close(fig)

    def forward(
        self, x: torch.Tensor, kv: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Attend from ``x`` to ``kv`` (or to ``x`` itself).

        Args:
            x: Query tensor, shape [B, Nq, Cq] where Cq == self.input_dims.
            kv: Optional key/value tensor, shape [B, Nk, Ckv].
                If None, ``x`` is used for k and v (self-attention).

        Returns:
            Tensor of shape [B, Nq, embed_dims].
        """
        Bq, Nq, _ = x.shape

        # Project queries: [B, Nq, embed_dims] -> [B, num_heads, Nq, head_dims]
        q = self.q_proj(x)
        q = q.reshape(Bq, Nq, self.num_heads, self.head_dims).permute(0, 2, 1, 3)

        # Select the kv input (self-attention when kv is None).
        kv_input = x if kv is None else kv
        Bk, Nk, _ = kv_input.shape
        assert Bk == Bq, "Batch size of query and kv must match"

        # Project kv: [B, Nk, 2*embed_dims] -> [2, B, num_heads, Nk, head_dims]
        kv_proj = self.kv_proj(kv_input)
        kv_proj = kv_proj.reshape(Bk, Nk, 2, self.num_heads, self.head_dims).permute(
            2, 0, 3, 1, 4
        )
        k, v = kv_proj[0], kv_proj[1]  # each: [B, num_heads, Nk, head_dims]

        # Attention, keeping the weights for the optional export below.
        attn_drop_p = self.attn_drop_p if self.training else 0.0
        out, attn_weights = self._scaled_dot_product_attention_with_weights(
            q, k, v, attn_mask=None, dropout_p=attn_drop_p
        )
        # out: [B, num_heads, Nq, head_dims]

        # Optional visualisation. The helper runs under ``torch.no_grad`` and
        # applies its own conditions (Nk == 3 and Nq a perfect square).
        if self.save_attn_maps:
            self._maybe_save_attention_maps(
                attn_weights=attn_weights,
                q_len=Nq,
                kv_len=Nk,
                batch_size=Bq,
                save_dir=self.attn_map_dir,
            )

        # Merge heads: [B, num_heads, Nq, head_dims] -> [B, Nq, embed_dims]
        out = out.transpose(1, 2).reshape(Bq, Nq, self.embed_dims)

        # Output projection.
        out = self.proj(out)
        out = self.out_drop(self.gamma1(self.proj_drop(out)))

        # Optional shortcut built from the kv input in token order,
        # projected to embed_dims when the dimensions differ.
        if self.v_shortcut:
            if self.proj_v_shortcut is not None:
                v_short = self.proj_v_shortcut(kv_input)  # [B, Nk, embed_dims]
            else:
                v_short = kv_input
            if v_short.shape[1] != Nq:
                # Token counts differ: mean-pool over the kv tokens and let
                # the single pooled token broadcast over all Nq positions
                # (for Nk == 1 the mean is that token itself).
                v_short = v_short.mean(dim=1, keepdim=True)
            out = out + v_short

        return out
