# utils/timm_vit_attn_patch.py
import math, types, torch
import torch.nn as nn
import torch.nn.functional as F

def _iter_timm_blocks(trunk: nn.Module):
    """
    Lazily yield the transformer blocks found under `trunk.blocks`.

    Only direct children of the `.blocks` container are considered, so both
    nn.Sequential and nn.ModuleList work. A child counts as a transformer
    block when it exposes both an `attn` and a `norm1` attribute; anything
    else is skipped.

    Raises:
        RuntimeError: if `trunk` has no `.blocks` attribute (raised when the
            generator is first advanced, not when it is created).
    """
    container = getattr(trunk, "blocks", None)
    if container is None:
        raise RuntimeError("timm trunk has no `.blocks`")
    # `.children()` walks one level only, independent of the container class.
    yield from (
        child
        for child in container.children()
        if hasattr(child, "attn") and hasattr(child, "norm1")
    )

def _timm_attn_forward_with_capture(self, x: torch.Tensor, attn_mask=None, return_attn: bool=False):
    """
    Replacement for timm Attention.forward that stores attention weights
    as self._last_attn with shape [B, H, N, N].

    The explicit matmul + softmax path is always used so that the weights
    exist as a tensor. Optional submodules (`q_norm`, `k_norm`, `attn_drop`,
    `norm`, `proj_drop`) are applied only when the module defines them, so
    the patched forward keeps matching modules that use them.

    Args:
        x: Input tokens, [B, N, C].
        attn_mask: Optional mask broadcastable to [B, H, N, N]. A
            non-floating (bool / integer) mask marks positions to exclude:
            nonzero entries are filled with -inf before the softmax. A
            floating-point mask is added to the attention logits.
        return_attn: If True, also return the attention weights.

    Returns:
        The projected output, or `(output, attn)` when `return_attn` is
        True. `attn` is the same tensor stored in `self._last_attn`: it is
        taken after `attn_drop` and is still attached to the autograd graph.
    """
    B, N, _ = x.shape
    H = self.num_heads

    qkv = self.qkv(x)
    # Take the head width from the projection output rather than the input
    # width, so modules whose attention width differs from C still reshape.
    Dh = qkv.shape[-1] // (3 * H)
    qkv = qkv.reshape(B, N, 3, H, Dh).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)                            # each [B,H,N,Dh]

    # Per-head query/key normalisation, when the module has it.
    q_norm = getattr(self, "q_norm", None)
    if q_norm is not None:
        q = q_norm(q)
    k_norm = getattr(self, "k_norm", None)
    if k_norm is not None:
        k = k_norm(k)

    scale = getattr(self, "scale", Dh ** -0.5)
    attn = (q * scale) @ k.transpose(-2, -1)           # [B,H,N,N]
    if attn_mask is not None:
        if attn_mask.is_floating_point():
            # Additive mask: 0 keeps a position, -inf removes it.
            attn = attn + attn_mask.to(attn.dtype)
        else:
            # masked_fill only accepts boolean masks; nonzero == masked out.
            attn = attn.masked_fill(attn_mask.bool(), float("-inf"))
    attn = F.softmax(attn, dim=-1)
    if hasattr(self, "attn_drop"):
        attn = self.attn_drop(attn)
    self._last_attn = attn

    y = attn @ v                                       # [B,H,N,Dh]
    y = y.transpose(1, 2).reshape(B, N, H * Dh)
    # Optional normalisation between attention output and projection.
    post_norm = getattr(self, "norm", None)
    if post_norm is not None:
        y = post_norm(y)
    y = self.proj(y)
    if hasattr(self, "proj_drop"):
        y = self.proj_drop(y)
    return (y, attn) if return_attn else y

def enable_timm_vit_attention_capture(trunk: nn.Module):
    """
    Rebind `forward` on the attention module of every block in a timm ViT
    trunk to the capturing implementation, so that after a forward pass each
    block exposes its attention weights as `.attn._last_attn`.

    The patch is applied per instance (a bound method assigned onto the
    attention module), and the same `trunk` object is returned.

    Raises:
        RuntimeError: if the trunk has no `.blocks`, or if none of its
            blocks could be patched.
    """
    patched = 0
    # enumerate(start=1) leaves `patched` holding the number of blocks seen.
    for patched, vit_block in enumerate(_iter_timm_blocks(trunk), start=1):
        attn_module = vit_block.attn
        attn_module.forward = types.MethodType(
            _timm_attn_forward_with_capture, attn_module
        )
    if not patched:
        raise RuntimeError("No attention modules were patched in timm trunk.")
    return trunk

@torch.no_grad()
def get_all_block_attentions_timm(trunk: nn.Module):
    """
    Collect attentions from all blocks after a forward.

    Reads `.attn._last_attn` from every block yielded for `trunk` and stacks
    them along a new leading layer axis. Stacking requires all blocks to
    have stored tensors of the same shape.

    Returns: Tensor [L, B, H, N, N]

    Raises:
        RuntimeError: if a block has no stored attention (no forward pass
            since patching), or if the trunk yields no attention blocks.
    """
    attns = []
    for i, blk in enumerate(_iter_timm_blocks(trunk)):
        A = getattr(blk.attn, "_last_attn", None)
        if A is None:
            raise RuntimeError(f"Block {i} has no stored attention. Run a forward after patching.")
        attns.append(A)
    if not attns:
        # torch.stack rejects an empty list with a generic message; fail
        # with something that points at the actual cause instead.
        raise RuntimeError("No attention blocks found in timm trunk.")
    return torch.stack(attns, dim=0)

@torch.no_grad()
def get_last_block_attention_timm(trunk: nn.Module):
    """
    Return the attention stored by the last block after a forward.

    Every block is still checked for a stored attention, but only the last
    block's tensor is copied; the per-layer tensors are not stacked just to
    index the final entry.

    Returns: Tensor [B, H, N, N], a detached copy independent of the tensor
    kept on the module.

    Raises:
        RuntimeError: if a block has no stored attention (no forward pass
            since patching), or if the trunk yields no attention blocks.
    """
    last = None
    for i, blk in enumerate(_iter_timm_blocks(trunk)):
        A = getattr(blk.attn, "_last_attn", None)
        if A is None:
            raise RuntimeError(f"Block {i} has no stored attention. Run a forward after patching.")
        last = A
    if last is None:
        raise RuntimeError("No attention blocks found in timm trunk.")
    # Detach + clone so callers get a standalone, grad-free tensor rather
    # than an alias of the module's stored attention.
    return last.detach().clone()
