import math
import types
import torch
from torch import nn
import torch.nn.functional as F

class CapiWrapper(nn.Module):
    """
    A thin linear-probing wrapper around an already-constructed CAPI backbone.

    - The backbone is passed in (e.g. one obtained via ``torch.hub``); this
      class does not load anything itself.
    - A ``head`` attribute (``nn.Linear(embed_dim, num_classes)`` by default)
      is applied to the backbone features selected by ``features``.

    Args:
        capi_model: Backbone whose forward returns a 3-tuple
            ``(global_repr, registers, feature_map)``.
        num_classes: Output size of the linear head.
        features: ``'cls'`` feeds ``global_repr`` to the head; any other value
            feeds ``feature_map`` flattened to ``[B, N, D]`` (so the head
            produces one output per flattened position).
        embed_dim: Input size of the linear head; must match the last
            dimension of the selected backbone features.
    """
    def __init__(self, capi_model: nn.Module, num_classes: int, features: str, embed_dim: int = 1024):
        super().__init__()
        self.capi_model = capi_model  # the backbone, e.g. from Torch Hub
        # Simple linear probe by default; replace with nn.Identity (or any
        # module) if the head is meant to be provided externally.
        self.head = nn.Linear(embed_dim, num_classes)
        self.features = features

    def forward(self, x: torch.Tensor, return_backbone_features: bool = False):
        """
        Run the backbone and apply ``self.head`` to the selected features.

        Args:
            x: Input passed unchanged to ``capi_model``.
            return_backbone_features: When True, also return the exact tensor
                that was fed to the head.

        Returns:
            ``out`` or, if ``return_backbone_features`` is True,
            ``(out, backbone_features)``.
        """
        global_repr, _registers, feature_map = self.capi_model(x)
        if self.features == 'cls':
            backbone_features = global_repr
        else:
            # Collapse every dim between batch and channel: [B, ..., D] -> [B, N, D].
            # `reshape` rather than `view`, so a non-contiguous feature map from
            # the backbone (e.g. after a transpose/permute) does not raise.
            backbone_features = feature_map.reshape(
                feature_map.size(0), -1, feature_map.size(-1)
            )
        out = self.head(backbone_features)
        if return_backbone_features:
            return out, backbone_features
        return out

def enable_attention_capture(capi_model):
    """
    Swap in an explicit (non-fused) forward on every CAPI attention module so
    that its softmax weights are recorded.

    After each call, a patched module holds the attention probabilities of
    that call in ``module._last_attn`` with shape ``[B, H, Nq, Nk]``. Calls
    without a full ``(context, context_coords)`` pair attend to ``x`` itself
    (self-attention); calls with both attend to the context (cross-attention).

    The attention class is taken from the model itself, as the type of
    ``capi_model.encoder.blocks[0].residual1.fn``; every submodule that is an
    instance of that class is patched in place.

    Note: ``_last_attn`` is stored without ``detach()``, so when autograd is
    enabled it still references the graph of the most recent forward.

    Args:
        capi_model: CAPI backbone to patch (modified in place).

    Returns:
        The same ``capi_model`` object.

    Raises:
        RuntimeError: if the reference attention module cannot be reached, or
            if no module of its type is found.
    """
    try:
        reference_attn = capi_model.encoder.blocks[0].residual1.fn
    except Exception as err:
        raise RuntimeError("Could not locate Attention modules in CAPI model") from err
    attention_cls = type(reference_attn)

    def _capturing_forward(self, x, coords, context=None, context_coords=None):
        # Keys/values come from the context only when both context tensors are
        # given; otherwise fall back to self-attention over x.
        if context is not None and context_coords is not None:
            kv_source, kv_coords = context, context_coords
        else:
            kv_source, kv_coords = x, coords

        batch, n_queries, dim = x.shape
        n_heads = self.num_heads
        head_dim = dim // n_heads

        def split_heads(tokens, n_tokens):
            # [B, N, D] -> [B, H, N, Dh]
            return tokens.reshape(batch, n_tokens, n_heads, head_dim).transpose(1, 2)

        queries = split_heads(self.q_proj(x), n_queries)
        keys = split_heads(self.k_proj(kv_source), -1)
        values = split_heads(self.v_proj(kv_source), -1)

        # Rotary embeddings, broadcast over the head axis.
        queries = self.rope(queries, coords[:, None, :, :])
        keys = self.rope(keys, kv_coords[:, None, :, :])

        # Materialize the attention matrix so its weights can be kept.
        logits = (queries @ keys.transpose(-2, -1)) / math.sqrt(head_dim)  # [B,H,Nq,Nk]
        weights = F.softmax(logits, dim=-1)
        self._last_attn = weights

        mixed = (weights @ values).transpose(1, 2).reshape(batch, n_queries, dim)  # [B,Nq,D]
        return self.proj(mixed)

    targets = [mod for mod in capi_model.modules() if isinstance(mod, attention_cls)]
    if not targets:
        raise RuntimeError("No Attention modules found to patch.")
    for mod in targets:
        # Instance attribute shadows the class-level forward for this module only.
        mod.forward = types.MethodType(_capturing_forward, mod)

    return capi_model