from typing import Callable, Optional, Tuple, Any, Dict, Callable
import torch
import torch.library
import torch.utils.checkpoint
from transformers.cache_utils import Cache, StaticCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.processing_utils import Unpack
import transformers.models.llama.modeling_llama as modeling_llama
import transformers.models.qwen2.modeling_qwen2 as modeling_qwen2
from flash_attn import flash_attn_with_kvcache
import os

# Monkey Patch to capture Attention Queries for TargetKVSDDraftModel
# Model: Llama3 and Qwen2

"""
CaptureAttentionContext (full version)

- Supports Llama and Qwen2 family models from HuggingFace Transformers
- Works with multi-GPU device_map="auto" (Accelerate dispatch) by patching
  *instance-level* forward methods for attention modules, not class-level.
- Captures RoPE-applied Q (q_rope) for selected layers/heads defined on model:
    - model.important_layers : list[int] or 1D tensor[int]
    - model.important_heads  : shape [num_important_layers, k] LongTensor
- Stores results on model.latest_captured_rope_queries
- Optional: capture_all_queries=True keeps full sequence queries; otherwise last token only
- Moves captured tensors to cuda:0 by default for safe stacking

Usage:
    ctx = CaptureAttentionContext(model, capture_all_queries=False, move_to="cuda:0")
    with ctx as m:
        # ensure these exist on m
        # m.important_layers = ...
        # m.important_heads = ...
        # m._capture_enabled = True   # capture is gated on this flag; __enter__ does not set it
        out = m(...)
    q = model.latest_captured_rope_queries
"""

class CaptureAttentionContext:
    """
    Context manager that monkey-patches attention *instances* inside a Transformers model
    to capture RoPE-applied queries (q_rope) for important layers/heads.

    This design is robust under device_map="auto" because Accelerate often wraps/dispatches
    module.forward at the instance level; class-level patches can be bypassed.

    Capture is gated on ``model._capture_enabled``: ``__enter__`` does not set it and
    ``__exit__`` resets it to False, so the caller must set it to a truthy value around the
    forward passes whose queries should be recorded.

    While the context is active, each patched attention call
      * (when ``capture_queries``) recomputes q = q_proj(hidden_states), applies RoPE, keeps
        the heads listed in ``model.important_heads`` for that layer (last token only unless
        ``capture_all_queries``) and appends the result to ``model.latest_captured_rope_queries``;
      * (when ``measure_latency`` and the input is on CUDA) times the original forward with
        CUDA events and appends the elapsed time to ``model.latest_attention_latencies``.

    On normal exit the original ``forward`` attributes are restored and the captured list is
    stacked to (Step, Layer, B, Hsel, D) in last-token mode, or (Step, Layer, B, Hsel, S, D)
    otherwise. The reshape assumes each step visits every important layer exactly once, in order.
    """

    def __init__(
        self,
        model,
        capture_all_queries: bool = False,
        move_to: Optional[str] = "cuda:0",
        verbose: bool = False,
        measure_latency: bool = False,
        capture_queries: bool = True,
    ):
        """
        Args:
            model: Transformers model (must expose ``.config``) whose attention modules are patched.
            capture_all_queries: keep every token's query instead of only the last one.
            move_to: device captured tensors are moved to; None keeps them where they were computed.
            verbose: print progress/diagnostic messages.
            measure_latency: time each original attention forward with CUDA events.
            capture_queries: record RoPE-applied queries at all.

        Raises:
            ValueError: the model has no ``.config`` or is neither a Llama nor a Qwen2 model.
            ImportError: the matching Transformers modeling module is unavailable.
        """
        self.model = model
        self.capture_all_queries = bool(capture_all_queries)
        self.move_to = move_to  # e.g., "cuda:0" or None to keep on original device
        self.verbose = bool(verbose)
        self.measure_latency = bool(measure_latency)
        self.capture_queries = bool(capture_queries)

        self.model_type: str = ""
        self.layer_to_idx: Dict[int, int] = {}
        # module -> (forward seen before patching, whether it was an instance attribute)
        self._orig_forward_by_module: Dict[torch.nn.Module, Tuple[Callable, bool]] = {}

        # Detect model type
        config = getattr(self.model, "config", None)
        if config is None:
            raise ValueError("model must have a .config (Transformers model expected).")

        model_type = str(getattr(config, "model_type", "")).lower()
        architectures = getattr(config, "architectures", []) or []
        is_llama = (model_type == "llama") or any("Llama" in str(a) for a in architectures)
        is_qwen2 = (model_type == "qwen2") or any("Qwen2" in str(a) for a in architectures)

        if is_llama:
            if modeling_llama is None:
                raise ImportError("Detected Llama model but could not import transformers.models.llama.modeling_llama")
            self.model_type = "llama"
            self._attn_classnames = {
                # Transformers variants
                "LlamaAttention",
                "LlamaSdpaAttention",
                "LlamaFlashAttention2",
            }
        elif is_qwen2:
            if modeling_qwen2 is None:
                raise ImportError("Detected Qwen2 model but could not import transformers.models.qwen2.modeling_qwen2")
            self.model_type = "qwen2"
            self._attn_classnames = {
                # Transformers variants
                "Qwen2Attention",
                "Qwen2SdpaAttention",
                "Qwen2FlashAttention2",
            }
        else:
            raise ValueError(f"Unsupported model type: {model_type} (architectures={architectures}).")

    def _log(self, msg: str):
        """Print ``msg`` only when the context was created with verbose=True."""
        if self.verbose:
            print(msg)

    def _get_important_layers(self):
        """Return ``model.important_layers`` as a list of Python ints."""
        dm = self.model
        if not hasattr(dm, "important_layers"):
            raise AttributeError("model must have attribute: important_layers (list[int] or 1D tensor).")
        imp = dm.important_layers
        if torch.is_tensor(imp):
            imp_list = imp.detach().cpu().tolist()
        else:
            imp_list = list(imp)
        return [int(x) for x in imp_list]

    def _get_heads_for_layer(self, layer_idx: int, q_device: torch.device) -> torch.Tensor:
        """Return the 1-D int64 head indices for absolute ``layer_idx``, placed on ``q_device``."""
        dm = self.model
        if not hasattr(dm, "important_heads"):
            raise AttributeError("model must have attribute: important_heads (tensor of head indices).")

        # Map absolute layer index -> row in important_heads
        row = self.layer_to_idx[layer_idx]
        head_indices = dm.important_heads[row]

        # Ensure 1D LongTensor indices for index_select
        if not torch.is_tensor(head_indices):
            head_indices = torch.tensor(head_indices, dtype=torch.long, device=q_device)
        else:
            head_indices = head_indices.to(device=q_device)
            if head_indices.dtype != torch.long:
                head_indices = head_indices.long()

        if head_indices.dim() != 1:
            head_indices = head_indices.reshape(-1)

        return head_indices

    def _should_patch_module(self, mod: torch.nn.Module) -> bool:
        # Patch by class name to avoid import / type identity issues across versions
        return mod.__class__.__name__ in self._attn_classnames

    def _apply_rope(self, q: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor]):
        """Apply the model family's rotary embedding to ``q`` (the key slot is discarded)."""
        cos, sin = position_embeddings
        if self.model_type == "llama":
            q_rope, _ = modeling_llama.apply_rotary_pos_emb(q, q, cos, sin)
        else:
            q_rope, _ = modeling_qwen2.apply_rotary_pos_emb(q, q, cos, sin)
        return q_rope

    def _capture_query(self, mod: torch.nn.Module, hidden_states: torch.Tensor, position_embeddings) -> None:
        """Append the selected q_rope heads of ``mod`` to ``model.latest_captured_rope_queries``.

        Does nothing when ``mod`` is not an important layer or exposes no ``head_dim``.
        """
        layer_idx = getattr(mod, "layer_idx", None)
        if layer_idx is None or int(layer_idx) not in self.layer_to_idx:
            return
        head_dim = getattr(mod, "head_dim", None)
        if head_dim is None:
            # head_dim is required to split q_proj's output into heads; skip only the capture
            # (the original forward still runs, and is still timed when requested).
            return

        # The captured tensor is detached anyway, so don't build an autograd graph for this
        # extra projection.
        with torch.no_grad():
            input_shape = hidden_states.shape[:-1]  # (B, S)
            hidden_shape = (*input_shape, -1, head_dim)
            q = mod.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)  # (B, H, S, D)
            q_rope = self._apply_rope(q, position_embeddings)
            head_indices = self._get_heads_for_layer(int(layer_idx), q_rope.device)
            if not self.capture_all_queries:
                q_rope = q_rope[:, :, -1:, :]  # (B, H, 1, D)
            out_tensor = q_rope.index_select(1, head_indices).detach()  # (B, Hsel, S_or_1, D)

        # Move to a single device for safe stacking later (optional)
        if self.move_to is not None:
            out_tensor = out_tensor.to(self.move_to)
        self.model.latest_captured_rope_queries.append(out_tensor)

    def _make_patched_forward(self, mod: torch.nn.Module, orig_fwd: Callable) -> Callable:
        """
        Build the instance-level forward for ``mod``: optionally capture q_rope, then run
        (and optionally time) ``orig_fwd``. Uses *args/**kwargs to stay signature-robust
        across Transformers versions.
        """
        dm = self.model

        def patched_forward(*args, **kwargs):
            # Try kwargs first (Transformers often calls with kwargs), then the positional order:
            # hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, ...
            hidden_states = kwargs.get("hidden_states", None)
            position_embeddings = kwargs.get("position_embeddings", None)
            if hidden_states is None and len(args) >= 1:
                hidden_states = args[0]
            if position_embeddings is None and len(args) >= 2:
                position_embeddings = args[1]

            if (
                self.capture_queries
                and getattr(dm, "_capture_enabled", False)
                and (hidden_states is not None)
                and (position_embeddings is not None)
            ):
                self._capture_query(mod, hidden_states, position_embeddings)

            measure = self.measure_latency and hidden_states is not None and hidden_states.is_cuda
            if not measure:
                return orig_fwd(*args, **kwargs)

            device = hidden_states.device
            latency_start = torch.cuda.Event(enable_timing=True)
            latency_end = torch.cuda.Event(enable_timing=True)
            with torch.cuda.device(device):
                torch.cuda.synchronize(device)  # Ensure all previous operations are complete
                latency_start.record()

            # Always continue the original forward
            out = orig_fwd(*args, **kwargs)

            with torch.cuda.device(device):
                latency_end.record()
                latency_end.synchronize()
            dm.latest_attention_latencies.append(latency_start.elapsed_time(latency_end))
            return out

        return patched_forward

    def _stack_and_reshape(self, captured_list, num_layers: int):
        """Stack captured tensors and fold them into (Step, Layer, ...).

        Returns None for an empty list, and the unreshaped stack (with a warning) when the
        number of items is not a multiple of ``num_layers``.
        """
        if not captured_list:
            return None

        stacked = torch.stack(captured_list)  # (Total, B, Hsel, QLen, D) or (Total, B, Hsel, S, D)

        total_items = stacked.shape[0]
        if total_items % num_layers != 0:
            self._log(f"[CaptureAttentionContext] WARNING: captured {total_items} items not divisible by num_layers={num_layers}. "
                      "Returning unshaped stacked tensor.")
            return stacked

        num_steps = total_items // num_layers
        reshaped = stacked.view(num_steps, num_layers, *stacked.shape[1:])  # (Step, Layer, B, Hsel, QLen, D)

        # If we only captured last token, squeeze QLen==1 dimension
        # In capture_all_queries mode, keep full QLen dimension
        if (not self.capture_all_queries) and reshaped.shape[4] == 1:
            reshaped = reshaped.squeeze(4)  # (Step, Layer, B, Hsel, D)

        return reshaped

    def __enter__(self):
        dm = self.model
        self._log(f"\n[CaptureAttentionContext] Enter: {getattr(dm.config, '_name_or_path', '(unknown)')}")
        self._log(f"[CaptureAttentionContext] model_type={self.model_type} capture_all_queries={self.capture_all_queries} move_to={self.move_to}")

        # Fresh buffers for this context. _capture_enabled is intentionally NOT set here:
        # capture only happens while the caller has set dm._capture_enabled to a truthy value.
        dm.latest_captured_rope_queries = []
        dm.latest_attention_latencies = [] if self.measure_latency else None

        # Build layer mapping (only if capturing queries)
        if self.capture_queries:
            imp_layers_list = self._get_important_layers()
            self._log(f"[CaptureAttentionContext] important_layers={imp_layers_list}")
            self.layer_to_idx = {layer_idx: i for i, layer_idx in enumerate(imp_layers_list)}
        else:
            self.layer_to_idx = {}

        # Patch attention instances
        patched_cnt = 0
        for mod in dm.modules():
            if not self._should_patch_module(mod):
                continue
            # Save current (possibly wrapped) forward and remember whether it lived on the
            # instance, so __exit__ can restore the exact original lookup.
            orig_fwd = mod.forward
            had_instance_forward = "forward" in mod.__dict__
            self._orig_forward_by_module[mod] = (orig_fwd, had_instance_forward)
            mod.forward = self._make_patched_forward(mod, orig_fwd)
            patched_cnt += 1

        self._log(f"[CaptureAttentionContext] patched attention instances: {patched_cnt}")
        if patched_cnt == 0:
            # Provide helpful diagnostics: show likely attention-ish module classnames
            self._log("[CaptureAttentionContext] WARNING: patched_cnt=0. "
                      "Your Transformers version may use different attention class names. "
                      "Consider scanning dm.named_modules() for classnames containing 'Attention'/'Attn'.")
        return dm

    def __exit__(self, exc_type, exc_val, exc_tb):
        dm = self.model

        # Restore all patched forwards
        for mod, (orig_fwd, had_instance_forward) in self._orig_forward_by_module.items():
            if had_instance_forward:
                # The instance already had its own forward (for example a dispatch wrapper):
                # put that exact callable back.
                mod.forward = orig_fwd
            else:
                # forward came from the class: drop our override instead of pinning a bound
                # method on the instance, so later class-level patches still take effect.
                mod.__dict__.pop("forward", None)
        self._orig_forward_by_module.clear()

        dm._capture_enabled = False

        if exc_type is not None:
            # Don't risk masking the in-flight exception with a stacking/reshape error.
            self._log("[CaptureAttentionContext] Exit with exception; captured queries left unstacked.")
            return None

        # Reshape captured list -> tensor
        num_layers = len(self._get_important_layers()) if hasattr(dm, "important_layers") else 0
        if num_layers == 0:
            self._log("[CaptureAttentionContext] Exit: important_layers empty; nothing to reshape.")
            return None

        dm.latest_captured_rope_queries = self._stack_and_reshape(dm.latest_captured_rope_queries, num_layers)
        self._log(f"[CaptureAttentionContext] Exit: latest_captured_rope_queries is None? {dm.latest_captured_rope_queries is None}")
        if dm.latest_captured_rope_queries is not None:
            self._log(f"[CaptureAttentionContext] Exit: latest_captured_rope_queries shape={tuple(dm.latest_captured_rope_queries.shape)}")
        return None

# Monkey patch: replace LlamaAttention.forward with a FlashAttention-2 + StaticCache fast path
# (this patch does not capture queries; CaptureAttentionContext does that).

def apply_llama3_attention_monkey_patch(cache_implementation: str):
    """
    Install the FlashAttention-2 / StaticCache fast path on ``LlamaAttention``.

    Stores ``cache_implementation`` on the class (the patched forward compares it with
    "static_cache"), makes ``llama_attention_monkey_patch_installed`` the class-level forward,
    and replaces ``StaticCache.update`` with ``custom_op_static_cache_update``.
    Both replacements are class-level attribute assignments.
    """
    print("Applying Llama Attention Optimization Patch...")
    llama_attn = modeling_llama.LlamaAttention
    llama_attn.cache_implementation = cache_implementation
    llama_attn.forward = llama_attention_monkey_patch_installed
    print("Patch applied: LlamaAttention.forward is now optimized.")

    print("Applying StaticCache compiler-friendly patch...")
    StaticCache.update = custom_op_static_cache_update
    print("Patch applied: StaticCache.update now uses custom ops.")

# Wrap flash_attn_with_kvcache as a torch.library custom op. NOTE(review): presumably this is so
# torch.compile treats the kernel as opaque instead of tracing into flash-attn — confirm.
# The op schema is inferred from the parameter names/annotations below, so keep them stable.
op_name = "specdecodes::flash_attn_wrapper"
@torch.library.custom_op(op_name, mutates_args=())
def secure_flash_attn_wrapper(q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, 
                              k: torch.Tensor | None, v: torch.Tensor | None, 
                              cache_seqlens: torch.Tensor | None, 
                              softmax_scale: float, causal: bool) -> torch.Tensor:
    """Call flash_attn_with_kvcache and return only the attention output tensor.

    Args:
        q: queries; the callers in this file pass a (Batch, Seq, Heads, Dim) view.
        k_cache, v_cache: key/value caches; callers pass (Batch, Max_Seq, Heads, Dim) views
            of the StaticCache buffers.
        k, v: optional new keys/values forwarded to flash-attn (callers pass None).
        cache_seqlens: forwarded as cache_seqlens (callers pass an int32 tensor of shape (Batch,)).
        softmax_scale: forwarded as softmax_scale (callers pass self.scaling).
        causal: forwarded as causal (callers pass True).

    Returns:
        The first element when flash_attn_with_kvcache returns a tuple, otherwise its result.

    NOTE(review): mutates_args=() declares that no input is modified, which matches the callers
    (k=None, v=None). Per the flash-attn docs, passing k/v makes the kernel append them into
    k_cache/v_cache in place, which this declaration would not reflect — confirm before using it.
    """
    
    # Local import of the same function imported at module level.
    from flash_attn.flash_attn_interface import flash_attn_with_kvcache
    
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k, v=v, 
        cache_seqlens=cache_seqlens, 
        softmax_scale=softmax_scale, causal=causal
    )
    
    # Some flash-attn configurations return a tuple; only the attention output is exposed.
    if isinstance(out, tuple):
        return out[0]
    return out

# Fake implementation registered for the op: reports an output with q's shape, dtype and
# device without running the kernel (used by torch.library when running on fake tensors).
@secure_flash_attn_wrapper.register_fake
def _(q, k_cache, v_cache, k, v, cache_seqlens, softmax_scale, causal):
    return torch.empty_like(q)

def llama_attention_monkey_patch_installed(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Replacement for ``LlamaAttention.forward`` installed by ``apply_llama3_attention_monkey_patch``.

    Projects q/k/v, applies RoPE and updates ``past_key_value``. When the model uses
    ``flash_attention_2``, ``cache_implementation == "static_cache"`` and both a cache and
    ``cache_position`` are given, attention is computed directly against the full cache buffers
    through ``secure_flash_attn_wrapper``; otherwise the standard Transformers attention
    interface is used.

    Returns:
        (attn_output, attn_weights); attn_weights is None on the fast path.

    NOTE(review): the fast path ignores ``attention_mask`` and uses ``cache_position[-1] + 1`` as
    the valid cache length for every batch row, i.e. it assumes unpadded batches whose new
    tokens end at the same position — confirm for the callers.
    """
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    cos, sin = position_embeddings
    query_states, key_states = modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # FA2 optimized attention over the static KV cache. cache_implementation is read with
    # getattr (as in the Qwen2 variant) so this forward does not raise AttributeError on a
    # class that was not prepared by apply_llama3_attention_monkey_patch.
    if (
        self.config._attn_implementation == "flash_attention_2"
        and getattr(self, "cache_implementation", None) == "static_cache"
        and past_key_value is not None
        and cache_position is not None
        and flash_attn_with_kvcache is not None
    ):
        batch_size = hidden_states.shape[0]
        # Valid cache length after the update, identical for every batch row.
        cache_seqlens = (cache_position[-1] + 1).repeat(batch_size).to(dtype=torch.int32)

        # StaticCache: (Batch, Heads, Max_Seq, Dim); flash_attn_with_kvcache expects
        # (Batch, Max_Seq, Heads, Dim) for the caches and (Batch, Seq, Heads, Dim) for q.
        attn_output = secure_flash_attn_wrapper(
            q=query_states.transpose(1, 2),
            k_cache=key_states.transpose(1, 2),
            v_cache=value_states.transpose(1, 2),
            k=None,
            v=None,
            cache_seqlens=cache_seqlens,
            causal=True,
            softmax_scale=self.scaling,
        )

        # (Batch, Seq, Heads, Dim) -> (Batch, Seq, Heads * Dim)
        attn_output = attn_output.reshape(*input_shape, -1)
        return self.o_proj(attn_output), None

    attention_interface: Callable = modeling_llama.eager_attention_forward
    if self.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        **kwargs,
    )

    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights

# Monkey patch: replace Qwen2Attention.forward with a FlashAttention-2 + StaticCache fast path
# (this patch does not capture queries; CaptureAttentionContext does that).

def qwen2_attention_monkey_patch_installed(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Replacement for ``Qwen2Attention.forward`` installed by ``apply_qwen2_attention_monkey_patch``.

    Projects q/k/v, applies RoPE and updates ``past_key_value``. When the model uses
    ``flash_attention_2``, ``cache_implementation == "static_cache"``, the layer has no
    sliding window, and both a cache and ``cache_position`` are given, attention is computed
    directly against the full cache buffers through ``secure_flash_attn_wrapper``. Otherwise
    the standard Transformers attention interface is used (with ``self.sliding_window``).

    Returns:
        (attn_output, attn_weights); attn_weights is None on the fast path.

    NOTE(review): the fast path ignores ``attention_mask`` and uses ``cache_position[-1] + 1`` as
    the valid cache length for every batch row, i.e. it assumes unpadded batches whose new
    tokens end at the same position — confirm for the callers.
    """
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    cos, sin = position_embeddings
    query_states, key_states = modeling_qwen2.apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # FA2 optimized attention over the static KV cache. The wrapper exposes no window
    # argument, so sliding-window layers must take the standard path (which forwards
    # self.sliding_window); otherwise they would silently get full causal attention.
    if (
        self.config._attn_implementation == "flash_attention_2"
        and getattr(self, "cache_implementation", None) == "static_cache"
        and getattr(self, "sliding_window", None) is None
        and past_key_value is not None
        and cache_position is not None
        and flash_attn_with_kvcache is not None
    ):
        batch_size = hidden_states.shape[0]
        # Valid cache length after the update, identical for every batch row.
        cache_seqlens = (cache_position[-1] + 1).repeat(batch_size).to(dtype=torch.int32)

        # StaticCache: (Batch, Heads, Max_Seq, Dim); flash_attn_with_kvcache expects
        # (Batch, Max_Seq, Heads, Dim) for the caches and (Batch, Seq, Heads, Dim) for q.
        attn_output = secure_flash_attn_wrapper(
            q=query_states.transpose(1, 2),
            k_cache=key_states.transpose(1, 2),
            v_cache=value_states.transpose(1, 2),
            k=None,
            v=None,
            cache_seqlens=cache_seqlens,
            causal=True,
            softmax_scale=self.scaling,
        )

        # (Batch, Seq, Heads, Dim) -> (Batch, Seq, Heads * Dim)
        attn_output = attn_output.reshape(*input_shape, -1)
        return self.o_proj(attn_output), None

    attention_interface: Callable = modeling_qwen2.eager_attention_forward
    if self.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        sliding_window=self.sliding_window,
        **kwargs,
    )

    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights

def apply_qwen2_attention_monkey_patch(cache_implementation: str):
    """
    Install the FlashAttention-2 / StaticCache fast path on ``Qwen2Attention``.

    Stores ``cache_implementation`` on the class (the patched forward compares it with
    "static_cache"), makes ``qwen2_attention_monkey_patch_installed`` the class-level forward,
    and replaces ``StaticCache.update`` with ``custom_op_static_cache_update``.
    When ``modeling_qwen2`` is None it only prints a skip message.
    """
    if modeling_qwen2 is not None:
        print("Applying Qwen2 Attention Optimization Patch...")
        qwen2_attn = modeling_qwen2.Qwen2Attention
        qwen2_attn.cache_implementation = cache_implementation
        qwen2_attn.forward = qwen2_attention_monkey_patch_installed
        print("Patch applied: Qwen2Attention.forward is now optimized.")

        print("Applying StaticCache compiler-friendly patch (for Qwen2)...")
        StaticCache.update = custom_op_static_cache_update
        print("Patch applied: StaticCache.update now uses custom ops.")
    else:
        print("Qwen2 modeling not available, skipping Qwen2 patch.")

# StaticCache update method (for torch compile optimization)

# In-place KV-cache scatter registered as a torch.library custom op. mutates_args declares
# that k_cache and v_cache are modified in place. The op schema is inferred from the
# parameter names/annotations below, so keep them stable.
cache_op_name = "specdecodes::update_kv_cache"
@torch.library.custom_op(cache_op_name, mutates_args={"k_cache", "v_cache"})
def update_kv_cache_op(
    k_cache: torch.Tensor, 
    v_cache: torch.Tensor, 
    key_states: torch.Tensor, 
    value_states: torch.Tensor, 
    layer_idx: int, 
    cache_position: torch.Tensor
) -> None:
    """Copy key_states/value_states into k_cache/v_cache at indices cache_position along dim 2.

    layer_idx is not used by the body; it is only part of the op signature.

    NOTE(review): index_copy_ expects cache_position to be an int64 index on the cache's device
    and the states to have the cache's dtype; callers are responsible for that.
    """
    
    k_cache.index_copy_(2, cache_position, key_states)
    v_cache.index_copy_(2, cache_position, value_states)
    return None

# Fake implementation registered for the op: the op returns nothing, so there is no output
# to describe when running on fake tensors.
@update_kv_cache_op.register_fake
def _(k_cache, v_cache, key_states, value_states, layer_idx, cache_position):
    return None

def custom_op_static_cache_update(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[dict] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Write new key/value states into one StaticCache layer and return the full buffers.

    Installed as ``StaticCache.update`` by the apply_* helpers; the positional scatter is
    routed through the ``specdecodes::update_kv_cache`` custom op.

    Args:
        self: the StaticCache; ``self.key_cache`` / ``self.value_cache`` are indexed by layer.
        key_states, value_states: new states, written along dim 2 of the layer buffers.
        layer_idx: which layer's buffers to update.
        cache_kwargs: may contain ``"cache_position"`` (indices along dim 2, tensor or sequence).
            When it (or cache_kwargs itself) is missing/None, the new states overwrite the whole
            buffer, as in ``StaticCache.update``.

    Returns:
        (k_out, v_out): the layer's full key and value buffers, updated in place.
    """
    k_out = self.key_cache[layer_idx]
    v_out = self.value_cache[layer_idx]

    # index_copy_/copy_ need matching dtypes; the cache may be stored in a different dtype.
    key_states = key_states.to(k_out.dtype)
    value_states = value_states.to(v_out.dtype)

    # cache_kwargs defaults to None, so don't call .get on it unguarded.
    cache_position = (cache_kwargs or {}).get("cache_position")
    if cache_position is None:
        k_out.copy_(key_states)
        v_out.copy_(value_states)
        return k_out, v_out

    if not isinstance(cache_position, torch.Tensor):
        cache_position = torch.tensor(cache_position, device=key_states.device)
    # index_copy_ requires an int64 index on the same device as the buffer (no-op if already so).
    cache_position = cache_position.to(device=k_out.device, dtype=torch.long)

    update_kv_cache_op(k_out, v_out, key_states, value_states, layer_idx, cache_position)

    return k_out, v_out