"""Attention extraction utilities (Q/K dump per layer).

This module patches all attention layers (LLaMA, Qwen3) to capture Q/K after
RoPE during the prefill step (q_len > 1). Each layer saves one pair of tensors:
    q: [B, Hq, S, D], k: [B, Hk, S, D], cast to float32 on CPU for analysis.

Each tensor is written to its own Torch .pt file as a single-key dict:
    <save_dir>/q_layer{idx:02d}.pt holds {'q': q} and
    <save_dir>/k_layer{idx:02d}.pt holds {'k': k}.

Use `patch_model_for_extraction` to install the hooks prior to running a single
generate call (e.g., max_new_tokens=1) to trigger prefill across all layers.
"""

from __future__ import annotations

import os
from typing import Callable, Tuple

import torch
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two parts of the last dimension, negating the trailing part.

    The last dimension is cut at ``x.shape[-1] // 2``. The result is the
    negated trailing slice followed by the untouched leading slice, so the
    output has the same shape as ``x``.
    """
    midpoint = x.shape[-1] // 2
    leading, trailing = x[..., :midpoint], x[..., midpoint:]
    return torch.cat([trailing.neg(), leading], dim=-1)


def _apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, *, unsqueeze_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to a query/key pair.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; a new axis is inserted at ``unsqueeze_dim`` so it
            broadcasts against ``q`` and ``k``.
        sin: Sine table, treated the same way as ``cos``.
        unsqueeze_dim: Position of the inserted broadcast axis.

    Returns:
        The rotated ``(q, k)`` tuple, each computed as
        ``t * cos + _rotate_half(t) * sin``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states: torch.Tensor) -> torch.Tensor:
        return (states * cos_b) + (_rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)


def _save_qk_pair(q: torch.Tensor, k: torch.Tensor, save_dir: str, layer_idx: int) -> None:
    """Write one layer's Q and K tensors to ``save_dir`` as float32 CPU tensors.

    Two files are produced, each holding a single-key dict:
    ``q_layer{layer_idx:02d}.pt`` -> ``{"q": ...}`` and
    ``k_layer{layer_idx:02d}.pt`` -> ``{"k": ...}``.

    Args:
        q: Query tensor to dump.
        k: Key tensor to dump.
        save_dir: Output directory; created if it does not exist.
        layer_idx: Layer index used (zero-padded to two digits) in file names.
    """
    os.makedirs(save_dir, exist_ok=True)
    for key, tensor in (("q", q), ("k", k)):
        # Move to the host first and upcast there, so no temporary float32
        # copy is allocated on the source device before the transfer.
        host_tensor = tensor.detach().to("cpu").to(torch.float32)
        torch.save({key: host_tensor}, os.path.join(save_dir, f"{key}_layer{layer_idx:02d}.pt"))


def _get_llama_forward_capturer(attn_module, save_dir: str, layer_idx: int) -> Callable:
    """Build a replacement ``forward`` for a LLaMA attention module that dumps Q/K.

    The returned callable always delegates to the module's original
    ``forward`` with the arguments it received. Before delegating, and only
    for a multi-token call (``q_len > 1``) while
    ``attn_module._extracted_once`` is falsy, it recomputes the query/key
    projections, applies RoPE when cos/sin are available, and writes the
    pair via ``_save_qk_pair``. All other calls (decode steps, layers already
    captured) skip the recomputation entirely.

    Args:
        attn_module: Attention module exposing ``q_proj``, ``k_proj``,
            ``head_dim`` and ``forward``.
        save_dir: Fallback output directory, used when the module has no
            ``_extract_save_dir`` attribute.
        layer_idx: Layer index used in the output file names.

    Returns:
        The wrapping forward function. It is not installed here; the caller
        assigns it to the module.
    """
    original_forward = attn_module.forward

    def _capture(hidden_states, position_ids, position_embeddings) -> None:
        """Recompute post-RoPE Q/K for this call and write them to disk."""
        bsz, q_len, _ = hidden_states.size()
        # Let view() infer the head count from the projection width instead
        # of looking up num_heads / num_key_value_heads on the module/config.
        head_shape = (bsz, q_len, -1, attn_module.head_dim)
        q = attn_module.q_proj(hidden_states).view(head_shape).transpose(1, 2)
        k = attn_module.k_proj(hidden_states).view(head_shape).transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
        elif hasattr(attn_module, "rotary_emb"):
            # Fallback for modules that carry their own rotary embedding.
            cos, sin = attn_module.rotary_emb(k, position_ids)
        else:
            cos = sin = None

        # Without cos/sin the tensors are saved un-rotated.
        if cos is not None and sin is not None:
            q, k = _apply_rotary_pos_emb(q, k, cos, sin)

        # Allow dynamic save_dir override per sample.
        target_dir = getattr(attn_module, "_extract_save_dir", save_dir)
        _save_qk_pair(q, k, target_dir, layer_idx)

    def custom_forward(
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
        position_embeddings=None,
        **kwargs,
    ):
        # Save only once during prefill for this layer per sample. Checking
        # up front avoids duplicating the projections on every decode step.
        is_prefill = hidden_states.shape[1] > 1
        if is_prefill and not getattr(attn_module, "_extracted_once", False):
            # The duplicated projections are only dumped, never
            # back-propagated, so keep them out of the autograd graph.
            with torch.no_grad():
                _capture(hidden_states, position_ids, position_embeddings)
            attn_module._extracted_once = True

        return original_forward(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )

    return custom_forward


def _get_qwen_forward_capturer(attn_module, save_dir: str, layer_idx: int) -> Callable:
    """Build a replacement ``forward`` for a Qwen3 attention module that dumps Q/K.

    The returned callable always delegates to the module's original
    ``forward`` with the arguments it received. Before delegating, and only
    for a multi-token call (``q_len > 1``) while
    ``attn_module._extracted_once`` is falsy, it recomputes the normalized
    query/key projections (``q_norm`` / ``k_norm``), applies RoPE when
    cos/sin are available, and writes the pair via ``_save_qk_pair``. All
    other calls (decode steps, layers already captured) skip the
    recomputation entirely.

    Args:
        attn_module: Attention module exposing ``q_proj``, ``k_proj``,
            ``q_norm``, ``k_norm``, ``head_dim`` and ``forward``.
        save_dir: Fallback output directory, used when the module has no
            ``_extract_save_dir`` attribute.
        layer_idx: Layer index used in the output file names.

    Returns:
        The wrapping forward function. It is not installed here; the caller
        assigns it to the module.
    """
    original_forward = attn_module.forward

    def _capture(hidden_states, position_ids, position_embeddings) -> None:
        """Recompute post-norm, post-RoPE Q/K for this call and write them to disk."""
        # Keep the leading dims, split the projection width into heads.
        hidden_shape = (*hidden_states.shape[:-1], -1, attn_module.head_dim)
        q = attn_module.q_norm(attn_module.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        k = attn_module.k_norm(attn_module.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
        elif hasattr(attn_module, "rotary_emb"):
            # Fallback for modules that carry their own rotary embedding.
            cos, sin = attn_module.rotary_emb(k, position_ids)
        else:
            cos = sin = None

        # Without cos/sin the tensors are saved un-rotated.
        if cos is not None and sin is not None:
            q, k = _apply_rotary_pos_emb(q, k, cos, sin)

        # Allow dynamic save_dir override per sample.
        target_dir = getattr(attn_module, "_extract_save_dir", save_dir)
        _save_qk_pair(q, k, target_dir, layer_idx)

    def custom_forward(
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
        position_embeddings=None,
        **kwargs,
    ):
        # Save only once during prefill for this layer per sample. Checking
        # up front avoids duplicating the projections on every decode step.
        _, q_len, _ = hidden_states.size()
        if q_len > 1 and not getattr(attn_module, "_extracted_once", False):
            # The duplicated projections are only dumped, never
            # back-propagated, so keep them out of the autograd graph.
            with torch.no_grad():
                _capture(hidden_states, position_ids, position_embeddings)
            attn_module._extracted_once = True

        return original_forward(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )

    return custom_forward


def patch_model_for_extraction(model, save_dir: str) -> None:
    """Patch all supported attention layers to dump Q/K during prefill.

    Each ``LlamaAttention`` / ``Qwen3Attention`` found at
    ``model.model.layers[i].self_attn`` gets its ``forward`` replaced by a
    capturing wrapper; the pre-patch ``forward`` is kept on the module as
    ``_orig_forward``. Layers with any other attention class are left
    untouched.

    Calling this again on an already-patched model does not stack a second
    wrapper: it only updates the target directory and re-arms the
    once-per-sample flag.

    Args:
        model: HF CausalLM model instance with `.model.layers` list.
        save_dir: Directory to write q/k files.
    """
    # (attention class, label used in the log line, capturer factory)
    supported = (
        (LlamaAttention, "LLaMA", _get_llama_forward_capturer),
        (Qwen3Attention, "Qwen3", _get_qwen_forward_capturer),
    )
    for i, layer in enumerate(model.model.layers):
        attn = layer.self_attn
        for attn_cls, label, make_capturer in supported:
            if not isinstance(attn, attn_cls):
                continue
            # Wrap only once. Re-wrapping would nest capturers and overwrite
            # _orig_forward with the already-patched forward.
            if not hasattr(attn, "_orig_forward"):
                attn._orig_forward = attn.forward
                attn.forward = make_capturer(attn, save_dir, i)
                print(f"Patched {label} attention layer {i} for extraction")
            attn._extract_save_dir = save_dir
            attn._extracted_once = False
            break


def set_extraction_dir_and_reset(model, save_dir: str) -> None:
    """Point every layer's Q/K dump at ``save_dir`` and re-arm the capture.

    Creates ``save_dir`` if needed, then, for each entry of
    ``model.model.layers``, stores the directory on ``self_attn`` as
    ``_extract_save_dir`` and clears ``_extracted_once``.

    Call this before each new sample when extracting multiple samples.
    """
    os.makedirs(save_dir, exist_ok=True)
    attention_modules = (decoder_layer.self_attn for decoder_layer in model.model.layers)
    for module in attention_modules:
        module._extract_save_dir = save_dir
        module._extracted_once = False


