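"""High-level model patchers for sparse attention.

Provides functions that walk ``model.model.layers`` and install a custom
attention forward on each supported attention module, using a per-layer
attention manager:

- `patch_model_for_sparse_attention`: PQ-based sparse decoding via
  `SparseAttentionManager` (Llama and Qwen3 attention).
- `patch_model_for_mask_topk`: mask-topk decoding via
  `MaskTopKAttentionManager` (Llama attention only).
- `patch_model_for_mask_topk_recall`: recall-aware mask-topk decoding via
  `MaskTopKRecallAttentionManager` (Llama attention only).
"""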

from __future__ import annotations

from typing import Optional

from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention

from ..sparse_attention.manager import SparseAttentionManager
from ..sparse_attention.mask_topk import MaskTopKAttentionManager, MaskTopKRecallAttentionManager
from .llama_patch import get_custom_llama_forward, get_custom_llama_forward_mask_topk
from .qwen_patch import get_custom_qwen_forward


def patch_model_for_sparse_attention(model, *, top_token_ratio: float = 0.05) -> None:
    """Patch supported attention modules with PQ-based sparse decoding (legacy).

    Walks ``model.model.layers``. For every layer whose ``self_attn`` is a
    ``LlamaAttention`` or a ``Qwen3Attention``, this function:

    - builds a dedicated ``SparseAttentionManager`` on that layer's
      device/dtype;
    - replaces the module's ``forward`` with the matching custom forward.

    Layers with any other attention type are left untouched.

    Args:
        model: A model exposing ``model.model.layers``, where each layer has a
            ``self_attn`` attribute.
        top_token_ratio: Forwarded unchanged to every ``SparseAttentionManager``.

    Raises:
        RuntimeError: If a supported attention module has neither an
            ``o_proj.weight`` nor any parameters, so its device/dtype cannot
            be determined.
    """
    # (attention class, custom-forward factory, label used in the log line).
    # Order matters: it mirrors the original isinstance priority.
    handlers = (
        (LlamaAttention, get_custom_llama_forward, "Llama"),
        (Qwen3Attention, get_custom_qwen_forward, "Qwen3"),
    )
    for i, layer in enumerate(model.model.layers):
        attn = layer.self_attn
        match = next(
            ((factory, label) for cls, factory, label in handlers if isinstance(attn, cls)),
            None,
        )
        if match is None:
            continue  # unsupported attention type: leave this layer as-is
        make_forward, label = match

        # Prefer o_proj's weight for device/dtype; fall back to the module's
        # first parameter if o_proj (or its weight) is absent.
        try:
            weight = attn.o_proj.weight
        except AttributeError:
            param = next(attn.parameters(), None)
            if param is None:
                raise RuntimeError(
                    f"Cannot determine device/dtype for attention layer {i}: "
                    "it has no o_proj.weight and no parameters"
                ) from None
            layer_device, layer_dtype = param.device, param.dtype
        else:
            layer_device, layer_dtype = weight.device, weight.dtype

        manager = SparseAttentionManager(model.config, layer_dtype, layer_device, top_token_ratio=top_token_ratio)
        attn.forward = make_forward(attn, manager)
        print(f"Patched {label} attention layer {i} (pq_sparse)")


def patch_model_for_mask_topk(model, *, sparsity_ratio: float) -> None:
    """Patch supported attention modules with mask-topk decoding (no PQ).

    Only ``LlamaAttention`` layers are patched. Layers with any other
    attention type in ``model.model.layers`` are left untouched, so calling
    this on an unsupported model patches nothing. Each patched layer:

    - gets its own ``MaskTopKAttentionManager`` built on that layer's
      device/dtype;
    - has that manager stored as ``layer.mask_topk_manager`` for external
      control.

    Args:
        model: A model exposing ``model.model.layers``, where each layer has a
            ``self_attn`` attribute.
        sparsity_ratio: Forwarded unchanged to every ``MaskTopKAttentionManager``.

    Raises:
        RuntimeError: If a Llama attention module has neither an
            ``o_proj.weight`` nor any parameters, so its device/dtype cannot
            be determined.
    """
    for i, layer in enumerate(model.model.layers):
        attn = layer.self_attn
        if not isinstance(attn, LlamaAttention):
            continue

        # Prefer o_proj's weight for device/dtype; fall back to the module's
        # first parameter if o_proj (or its weight) is absent.
        try:
            weight = attn.o_proj.weight
        except AttributeError:
            param = next(attn.parameters(), None)
            if param is None:
                raise RuntimeError(
                    f"Cannot determine device/dtype for attention layer {i}: "
                    "it has no o_proj.weight and no parameters"
                ) from None
            layer_device, layer_dtype = param.device, param.dtype
        else:
            layer_device, layer_dtype = weight.device, weight.dtype

        manager = MaskTopKAttentionManager(model.config, layer_dtype, layer_device, sparsity_ratio=sparsity_ratio)
        attn.forward = get_custom_llama_forward_mask_topk(attn, manager)
        # keep a handle to manager for external control if needed
        layer.mask_topk_manager = manager  # type: ignore[attr-defined]
        print(f"Patched Llama attention layer {i} (mask_topk)")


def patch_model_for_mask_topk_recall(model, *, base_sparsity_ratio: float, recall_ratio: float) -> None:
    """Patch attention with recall-aware mask-topk selection.

    Only ``LlamaAttention`` layers are patched. Layers with other attention
    types are left untouched. Each patched layer:

    - gets its own ``MaskTopKRecallAttentionManager`` built on that layer's
      device/dtype;
    - has its ``forward`` replaced by the same mask-topk custom forward used
      by ``patch_model_for_mask_topk``;
    - has the manager stored as ``layer.mask_topk_manager``.

    Args:
        model: A model exposing ``model.model.layers``, where each layer has a
            ``self_attn`` attribute.
        base_sparsity_ratio: Fraction of tokens constituting the base top-K
            set. Required keyword argument; there is no default.
        recall_ratio: Fraction of the base top-K retained; the remainder is
            substituted by the next-best tokens.

    Raises:
        RuntimeError: If a Llama attention module has neither an
            ``o_proj.weight`` nor any parameters, so its device/dtype cannot
            be determined.
    """
    for i, layer in enumerate(model.model.layers):
        attn = layer.self_attn
        if not isinstance(attn, LlamaAttention):
            continue

        # Prefer o_proj's weight for device/dtype; fall back to the module's
        # first parameter if o_proj (or its weight) is absent.
        try:
            weight = attn.o_proj.weight
        except AttributeError:
            param = next(attn.parameters(), None)
            if param is None:
                raise RuntimeError(
                    f"Cannot determine device/dtype for attention layer {i}: "
                    "it has no o_proj.weight and no parameters"
                ) from None
            layer_device, layer_dtype = param.device, param.dtype
        else:
            layer_device, layer_dtype = weight.device, weight.dtype

        manager = MaskTopKRecallAttentionManager(
            model.config, layer_dtype, layer_device,
            base_sparsity_ratio=base_sparsity_ratio, recall_ratio=recall_ratio,
        )
        # Reuse the same custom forward; manager exposes select_topk_indices
        attn.forward = get_custom_llama_forward_mask_topk(attn, manager)
        layer.mask_topk_manager = manager  # type: ignore[attr-defined]
        print(f"Patched Llama attention layer {i} (mask_topk_recall)")



