# utils/dinov_attn_capture.py
import math
import types
import torch
import torch.nn.functional as F
from torch import Tensor
from typing import List, Tuple

def _forward_attn_with_capture(self, x, is_causal: bool = False):
    """Replacement ``forward`` for a dense attention module that records its attention map.

    Intended to be bound onto an attention module whose forward has the
    signature ``forward(self, x, is_causal=False)``. The attention is computed
    explicitly (no fused kernel) so that the softmax weights can be stored.

    Args:
        x: Input of shape ``[B, N, C]``.
        is_causal: When true, position ``i`` may only attend to positions ``<= i``.

    Returns:
        Tensor of shape ``[B, N, C]`` after the output projection.

    Side effects:
        Stores the softmax attention weights, shape ``[B, H, N, N]``, on
        ``self._last_attn``. The stored map is taken *before* attention
        dropout so it is always a proper distribution over keys.
    """
    B, N, C = x.shape
    H = self.num_heads
    Dh = C // H

    qkv = self.qkv(x).reshape(B, N, 3, H, Dh)
    # Split into q/k/v of shape [B, N, H, Dh] and move heads in front: [B, H, N, Dh].
    q, k, v = (t.transpose(1, 2) for t in torch.unbind(qkv, dim=2))

    # Honour a module-provided scale, otherwise use the standard 1/sqrt(Dh).
    scale = getattr(self, "scale", Dh ** -0.5)

    attn = (q * scale) @ k.transpose(-2, -1)       # [B, H, N, N]
    if is_causal:
        # Strict upper triangle = "future" keys, which must receive zero weight.
        future = torch.triu(torch.ones(N, N, device=attn.device, dtype=torch.bool), diagonal=1)
        attn = attn.masked_fill(future, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    self._last_attn = attn

    # Keep parity with the unpatched module: apply attention dropout when the
    # module defines it (it is the identity in eval mode or with p == 0).
    attn_drop = getattr(self, "attn_drop", None)
    if callable(attn_drop):
        attn = attn_drop(attn)

    y = attn @ v                                   # [B, H, N, Dh]
    y = y.transpose(1, 2).reshape(B, N, C)
    y = self.proj(y)
    y = self.proj_drop(y)
    return y

# RoPE-related functions:
def rope_rotate_half(x: Tensor) -> Tensor:
    """Swap the two halves of the last dimension, negating the one moved to the front.

    Example for a last dimension of size 6::

        x:   [ x0  x1  x2  x3  x4  x5]
        out: [-x3 -x4 -x5  x0  x1  x2]
    """
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat((back.neg(), front), dim=-1)

def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
    """Rotate ``x`` by the angles encoded in ``sin`` / ``cos``.

    Computes ``x * cos + rope_rotate_half(x) * sin`` element-wise. The expected
    layout of the last dimension (size D) is::

        x:   [x0,     x1,   x2,   x3,   x4,   x5]
        sin: [sin0, sin1, sin2, sin0, sin1, sin2]
        cos: [cos0, cos1, cos2, cos0, cos1, cos2]
    """
    in_phase = x * cos
    quadrature = rope_rotate_half(x) * sin
    return in_phase + quadrature

def apply_rope(self, q: Tensor, k: Tensor, rope: Tensor | Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
    """Apply rotary position embedding to the trailing tokens of ``q`` and ``k``.

    ``rope`` unpacks into ``(sin, cos)``. The last ``sin.shape[-2]`` positions
    along dim ``-2`` of ``q`` / ``k`` are rotated; any leading positions (the
    "prefix") are passed through untouched. The arithmetic runs in the dtype
    of ``sin`` and each result is cast back to the dtype its input had.

    Returns:
        ``(q, k)`` with the same shapes and dtypes as the inputs.
    """
    sin, cos = rope
    work_dtype = sin.dtype
    # Number of leading tokens that carry no rotary embedding.
    n_prefix = q.shape[-2] - sin.shape[-2]
    assert n_prefix >= 0

    def _rotate_tail(t: Tensor) -> Tensor:
        # Rotate everything after the prefix, then restore the caller's dtype.
        out_dtype = t.dtype
        t = t.to(dtype=work_dtype)
        untouched = t[:, :, :n_prefix, :]
        rotated = rope_apply(t[:, :, n_prefix:, :], sin, cos)  # [B, head, hw, D//head]
        merged = torch.cat((untouched, rotated), dim=-2)        # [B, head, N, D//head]
        return merged.to(dtype=out_dtype)

    return _rotate_tail(q), _rotate_tail(k)

def _forward_mem_eff_with_capture(self, x, attn_bias=None, rope=None):
    """Replacement ``forward`` for a memory-efficient attention module that records its attention map.

    Intended to be bound onto a module whose forward has the signature
    ``forward(self, x, attn_bias=None[, rope=None])``. The fused kernel is
    bypassed and attention is computed explicitly so the weights can be stored.

    Args:
        x: Input of shape ``[B, N, C]``.
        attn_bias: Optional mask/bias broadcastable to ``[B, H, N, N]``.
            A boolean tensor is treated as a keep-mask (``True`` = may attend);
            any other tensor is added to the attention logits. Objects that
            cannot be applied this way (e.g. structured block-diagonal masks)
            are ignored with a ``RuntimeWarning``.
        rope: Optional rotary embedding, forwarded to ``apply_rope``.

    Returns:
        Tensor of shape ``[B, N, C]`` after the output projection.

    Side effects:
        Stores the softmax attention weights, shape ``[B, H, N, N]``, on
        ``self._last_attn``.
    """
    import warnings

    B, N, C = x.shape
    H = self.num_heads
    Dh = C // H

    qkv = self.qkv(x).reshape(B, N, 3, H, Dh)
    # Split into q/k/v of shape [B, N, H, Dh] and move heads in front: [B, H, N, Dh].
    q, k, v = (t.transpose(1, 2) for t in torch.unbind(qkv, dim=2))

    if rope is not None:
        # Prefer the module's own implementation; modules that do not define
        # one fall back to the module-level ``apply_rope`` of this file.
        rope_fn = getattr(self, "apply_rope", None)
        if rope_fn is not None:
            q, k = rope_fn(q, k, rope)
        else:
            q, k = apply_rope(self, q, k, rope)

    scale = Dh ** -0.5
    attn = (q * scale) @ k.transpose(-2, -1)       # [B, H, N, N]
    if attn_bias is not None:
        try:
            if torch.is_tensor(attn_bias) and attn_bias.dtype == torch.bool:
                # Adding a bool cast to 0/1 would not mask anything; block the
                # disallowed positions explicitly instead.
                biased = attn.masked_fill(~attn_bias, float("-inf"))
            else:
                biased = attn + attn_bias.to(attn.dtype)
        except (AttributeError, TypeError, ValueError, RuntimeError) as err:
            # Best effort: keep running, but never drop a mask silently.
            warnings.warn(
                f"attn_bias of type {type(attn_bias).__name__} could not be applied "
                f"and was ignored ({err}); captured attention is unmasked.",
                RuntimeWarning,
                stacklevel=2,
            )
        else:
            attn = biased
    attn = F.softmax(attn, dim=-1)
    self._last_attn = attn

    y = attn @ v                                   # [B, H, N, Dh]
    y = y.transpose(1, 2).reshape(B, N, C)
    y = self.proj(y)
    y = self.proj_drop(y)
    return y

def enable_dino_attention_capture(backbone):
    """
    Patch the attention modules of `backbone` so they store `._last_attn` during forward.

    Every submodule exposing `qkv`, `num_heads` and `forward` gets its `forward`
    replaced (on the instance) by one of the capture stubs of this file:
      - `_forward_attn_with_capture` when the existing forward declares an
        `is_causal` parameter;
      - `_forward_mem_eff_with_capture` otherwise.

    Returns the same `backbone` object (patched in place).

    Raises RuntimeError when the backbone contains no Block-like module (one
    with both `.attn` and `.norm1`) or when no attention module was patched.
    """
    # Sanity check: the backbone must contain at least one transformer Block.
    has_block = any(
        hasattr(m, "attn") and hasattr(m, "norm1") for m in backbone.modules()
    )
    if not has_block:
        raise RuntimeError("Could not find a Block with an `.attn` submodule.")

    n_patched = 0
    for m in backbone.modules():
        if not (hasattr(m, "qkv") and hasattr(m, "num_heads") and hasattr(m, "forward")):
            continue
        # Heuristic: pick the stub from the *declared parameters* of the current
        # forward. `co_varnames` lists parameters first and locals afterwards,
        # so slice it to avoid being fooled by a local variable named `is_causal`.
        code = getattr(m.forward, "__code__", None)
        if code is not None:
            params = code.co_varnames[: code.co_argcount + code.co_kwonlyargcount]
        else:
            params = ()
        if "is_causal" in params:
            # Looks like Attention.forward(self, x, is_causal=False)
            stub = _forward_attn_with_capture
        else:
            # Looks like MemEffAttention.forward(self, x, attn_bias=None)
            stub = _forward_mem_eff_with_capture
        m.forward = types.MethodType(stub, m)
        n_patched += 1

    if n_patched == 0:
        raise RuntimeError("No attention modules were patched. Check the backbone structure.")
    return backbone

def get_last_block_attention(backbone):
    """
    Fetch the attention map stored by the final Block of `backbone`.

    `backbone.blocks` may be a flat ModuleList of Blocks or a ModuleList of
    nested ModuleLists (chunked layout); entries lacking `.attn` / `.norm1`
    (e.g. placeholder modules inside a chunk) are skipped.

    Returns the `_last_attn` tensor of the last Block, [B, H, N, N].
    Raises RuntimeError if there is no `.blocks`, no Block is found, or no
    attention has been captured on the last Block yet.
    """
    container = getattr(backbone, "blocks", None)
    if container is None:
        raise RuntimeError("Backbone has no `.blocks`")

    # Depth-first, order-preserving flattening with an explicit stack.
    found = []
    pending = [container]
    while pending:
        node = pending.pop()
        if isinstance(node, torch.nn.ModuleList):
            # Push children reversed so the first child is visited first.
            pending.extend(reversed(list(node)))
        elif hasattr(node, "attn") and hasattr(node, "norm1"):
            found.append(node)

    if not found:
        raise RuntimeError("Could not collect any Blocks with `.attn`.")

    captured = getattr(found[-1].attn, "_last_attn", None)
    if captured is None:
        raise RuntimeError("No attention captured yet. Run a forward pass after patching.")
    return captured

def get_all_block_attentions(backbone):
    """
    Stack the attention maps stored by every Block of `backbone`.

    Returns a Tensor [L, B, H, N, N] where
      L = number of blocks
      B = batch size
      H = number of heads
      N = sequence length

    Raises RuntimeError if there is no `.blocks`, no Block is found, or some
    Block has not captured attention yet.
    """
    container = getattr(backbone, "blocks", None)
    if container is None:
        raise RuntimeError("Backbone has no `.blocks`")

    def _flatten(node, out):
        # Walk nested ModuleLists in order, keeping only Block-like leaves.
        if not isinstance(node, torch.nn.ModuleList):
            if hasattr(node, "attn") and hasattr(node, "norm1"):
                out.append(node)
            return out
        for child in node:
            _flatten(child, out)
        return out

    ordered_blocks = _flatten(container, [])
    if not ordered_blocks:
        raise RuntimeError("Could not collect any Blocks with `.attn`.")

    per_block = [getattr(b.attn, "_last_attn", None) for b in ordered_blocks]  # each [B, H, N, N]
    for idx, captured in enumerate(per_block):
        if captured is None:
            raise RuntimeError(
                f"No attention captured yet in block {idx}. "
                "Run a forward pass after patching."
            )

    return torch.stack(per_block, dim=0)  # [L, B, H, N, N]