import types, torch
import torch.nn as nn

def _iter_resblocks(vit: nn.Module):
    """Yield every block in ``vit.transformer.resblocks``.

    Raises:
        AssertionError: on first iteration, if ``vit`` has no ``transformer``
            attribute or its ``resblocks`` is not an ``nn.ModuleList``. This is
            raised explicitly rather than via ``assert`` so the check still runs
            under ``python -O``. The exception type is kept for backward
            compatibility with callers of the original ``assert``.
    """
    tr = getattr(vit, "transformer", None)
    # getattr(None, ...) safely yields None, so a missing transformer also fails the check.
    blocks = getattr(tr, "resblocks", None)
    if tr is None or not isinstance(blocks, nn.ModuleList):
        raise AssertionError("No transformer.resblocks")
    yield from blocks

def _rab_attention_with_capture(self, q_x, k_x=None, v_x=None, attn_mask=None):
    """Drop-in replacement for a block's ``attention`` that also records weights.

    Calls ``self.attn`` (an ``nn.MultiheadAttention``) with ``need_weights=True``
    and ``average_attn_weights=False`` so per-head attention probabilities are
    returned. They are stored on ``self.attn._last_attn`` as a 4-D tensor
    ``[B, H, L, S]``.

    Args:
        self: the block whose ``attn`` module performs the attention.
        q_x: query input.
        k_x: key input; defaults to ``q_x`` (self-attention).
        v_x: value input; defaults to ``q_x`` (self-attention).
        attn_mask: optional mask forwarded to ``self.attn``. Floating-point
            (additive) masks are cast to ``q_x.dtype``, mirroring the cast in
            upstream OpenCLIP's ``attention``, so a mask in a different float
            dtype does not mismatch the query inside MHA. Boolean masks are
            passed through unchanged.

    Returns:
        The attention output (first element of ``self.attn``'s result).

    Note:
        ``average_attn_weights`` requires PyTorch >= 1.11.
    """
    if k_x is None:
        k_x = q_x
    if v_x is None:
        v_x = q_x
    if attn_mask is not None and attn_mask.is_floating_point():
        attn_mask = attn_mask.to(q_x.dtype)
    out, w = self.attn(
        q_x, k_x, v_x,
        need_weights=True,
        average_attn_weights=False,  # keep per-head weights
        attn_mask=attn_mask,
    )
    # With average_attn_weights=False, MHA returns [B, H, L, S] for batched input
    # and [H, L, S] for unbatched input, whatever batch_first is. Add the missing
    # batch dim rather than reshaping with q_x.shape[0], which is the sequence
    # length for unbatched (and seq-first) inputs.
    if w.dim() == 3:
        w = w.unsqueeze(0)
    self.attn._last_attn = w  # stash on the attention module
    return out

def _rab_forward_with_capture(self, q_x, k_x=None, v_x=None, attn_mask=None):
    """Replacement ``forward`` for a patched residual attention block.

    Pre-norm residual computation: ``x = q_x + ls_1(attention(ln_1(q_x)))``,
    then ``x + ls_2(mlp(ln_2(x)))``. Attention goes through
    ``_rab_attention_with_capture`` so the weights are stored on ``self.attn``.

    ``k_x``/``v_x`` are normalised with ``ln_1_kv`` only when the block defines
    that attribute. Otherwise they are discarded (set to None), and the
    attention falls back to self-attention on the normalised ``q_x``.

    NOTE(review): expects OpenCLIP-style blocks exposing ln_1, ls_1, ln_2, mlp
    and ls_2 — confirm for other CLIP variants.
    """
    has_kv_norm = hasattr(self, "ln_1_kv")
    if has_kv_norm and k_x is not None:
        k_x = self.ln_1_kv(k_x)
    else:
        k_x = None
    if has_kv_norm and v_x is not None:
        v_x = self.ln_1_kv(v_x)
    else:
        v_x = None

    # Attention sub-layer (pre-norm) plus residual.
    attn_out = _rab_attention_with_capture(
        self, q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask
    )
    hidden = q_x + self.ls_1(attn_out)

    # MLP sub-layer (pre-norm) plus residual.
    return hidden + self.ls_2(self.mlp(self.ln_2(hidden)))

def enable_clip_attention_capture(vit: nn.Module):
    """Patch the residual attention blocks of a CLIP/OpenCLIP vision transformer in place.

    Every block in ``vit.transformer.resblocks`` whose ``attn`` is an
    ``nn.MultiheadAttention`` gets instance-level ``attention`` and ``forward``
    methods. After each forward, its per-head attention weights are stored on
    ``blk.attn._last_attn``. Blocks with any other ``attn`` are left untouched.

    Returns:
        ``vit`` itself (patched in place), for chaining.

    Raises:
        RuntimeError: if no block qualified for patching.
        Any error raised by ``_iter_resblocks`` when ``vit`` lacks
        ``transformer.resblocks``.
    """
    targets = [
        blk for blk in _iter_resblocks(vit)
        if isinstance(getattr(blk, "attn", None), nn.MultiheadAttention)
    ]
    if not targets:
        raise RuntimeError("No MultiheadAttention blocks found to patch.")
    for blk in targets:
        # Binding as instance attributes shadows the class methods for this block only.
        blk.attention = types.MethodType(_rab_attention_with_capture, blk)
        blk.forward = types.MethodType(_rab_forward_with_capture, blk)
    return vit

@torch.no_grad()
def get_all_block_attentions_clip(vit: nn.Module):
    """Collect the attention weights stored by every block during the latest forward.

    Returns:
        A tensor ``[L, B, H, N, N]``, where ``L`` is the number of blocks. It is
        each block's ``attn._last_attn`` in ``resblocks`` order, stacked along a
        new leading dimension.

    Raises:
        RuntimeError: if any block has no stored attention (e.g. no forward pass
            has run since patching).
    """
    def _stored_weights(index, blk):
        weights = getattr(blk.attn, "_last_attn", None)
        if weights is None:
            raise RuntimeError(f"Block {index} has no stored attention. Run a forward after patching.")
        return weights  # [B,H,N,N]

    per_block = [_stored_weights(index, blk) for index, blk in enumerate(_iter_resblocks(vit))]
    return torch.stack(per_block, dim=0)

@torch.no_grad()
def get_last_block_attention_clip(vit: nn.Module):
    """Return the attention weights stored by the final block, shaped ``[B, H, N, N]``.

    Only the last block is read. The previous approach stacked every block's
    weights into a fresh ``[L, B, H, N, N]`` tensor and indexed ``[-1]``, which
    copied all layers and kept that whole buffer alive through the returned view.

    The result is a detached view of the tensor stored on the block (no copy), so
    in-place edits to it also modify the stored weights.

    Raises:
        RuntimeError: if there are no blocks, or the last block has no stored
            attention (e.g. no forward pass has run since patching).
    """
    blocks = list(_iter_resblocks(vit))
    if not blocks:
        raise RuntimeError("No transformer blocks found; nothing to read attention from.")
    last_idx = len(blocks) - 1
    w = getattr(blocks[last_idx].attn, "_last_attn", None)
    if w is None:
        raise RuntimeError(f"Block {last_idx} has no stored attention. Run a forward after patching.")
    # detach() matches the no-grad result of the former torch.stack path without copying.
    return w.detach()
