import math
import torch # type: ignore
import torch.nn.functional as F # type: ignore
from torch import nn # type: ignore
from typing import Optional, Tuple, Callable, Dict # type: ignore
from transformers.models.llama.configuration_llama import LlamaConfig # type: ignore
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig 
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
    Qwen2VLAttention,
    repeat_kv,
    apply_multimodal_rotary_pos_emb,
)# type: ignore
from transformers.models.llama.modeling_llama import LlamaAttention, repeat_kv, apply_rotary_pos_emb # type: ignore
from transformers.generation import GenerationMixin # type: ignore

class ClearSightGenerationMixin(GenerationMixin):
    """Generation mixin that keeps ``key_position`` in the per-step model inputs.

    During generation the model inputs are rebuilt through
    ``prepare_inputs_for_generation``. This override defers to the parent
    implementation and then, whenever the caller passed ``key_position``,
    writes it into the resulting inputs so it reaches the attention layers
    even if the parent implementation did not forward it.
    """

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        """Return the parent's model inputs, plus ``key_position`` if supplied.

        Args:
            input_ids: token ids, forwarded unchanged to the parent.
            past_key_values: cache, forwarded unchanged to the parent.
            attention_mask: mask, forwarded unchanged to the parent.
            **kwargs: extra generation kwargs. They are forwarded to the
                parent, and ``kwargs["key_position"]`` (even when it is
                ``None``) is copied into the returned inputs.
        """
        parent = super()
        model_inputs = parent.prepare_inputs_for_generation(
            input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            **kwargs,
        )
        if "key_position" not in kwargs:
            return model_inputs
        # Explicitly (re)insert it; the parent may or may not have kept it.
        model_inputs["key_position"] = kwargs["key_position"]
        return model_inputs

def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    q_len: int,
    enh_para: float,
    sup_para: float,
    key_position,
    dropout: float = 0.0,
    **kwargs,
):
    """Eager attention with ClearSight-style reweighting of pre-softmax scores.

    The raw scores ``query @ key^T * scaling`` are modified before masking
    and softmax:

    * Scores toward the image key span ``[image_start, image_end]``
      (inclusive) are multiplied by ``enh_para``.
    * Scores toward every key before ``image_start`` are multiplied by
      ``sup_para``.

    The rows affected depend on ``q_len``:

    * If ``q_len`` exceeds ``image_end + 1``, only query rows from index
      ``image_end + 1`` onward are reweighted.
    * Otherwise every query row is reweighted, for example on single-token
      decode steps.

    Args:
        module: attention module. ``module.num_key_value_groups`` is used to
            repeat K/V heads, and ``module.training`` gates dropout.
        query: query states. The adapters in this file pass
            ``(batch, heads, q_len, head_dim)``.
        key: key states before K/V head repetition.
        value: value states before K/V head repetition.
        attention_mask: additive mask. It is sliced to the key length and
            added *after* reweighting. ``None`` means no mask is added.
        scaling: multiplier applied to ``query @ key^T``.
        q_len: number of query positions in this call.
        enh_para: factor for scores toward the image span.
        sup_para: factor for scores toward positions before the image.
        key_position: mapping with ``"image_start"`` and ``"image_end"``
            (inclusive) key indices.
        dropout: dropout probability, applied only when ``module.training``.
        **kwargs: accepted and ignored.

    Returns:
        ``(attn_output, attn_weights)``:

        * ``attn_output`` is already transposed to
          ``(batch, q_len, heads, head_dim)``; callers must not transpose it
          again.
        * ``attn_weights`` are the post-softmax (and post-dropout)
          probabilities in ``query.dtype``.

    Raises:
        ValueError: if ``key_position`` is ``None``.
    """
    if key_position is None:
        # Fail loudly instead of with "'NoneType' object is not subscriptable".
        raise ValueError(
            "eager_attention_forward requires key_position with "
            "'image_start' and 'image_end' entries"
        )

    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling

    # Image span in key positions; image_end is inclusive.
    sys_len = key_position["image_start"]
    img_stop = key_position["image_end"] + 1  # == sys_len + image length

    # NOTE(review): the prefill branch treats query row i as absolute
    # position i, which presumably holds only when no cached prefix precedes
    # the queries (e.g. not for chunked prefill) -- confirm.
    row_start = img_stop if q_len > img_stop else 0

    # In-place on slice views: the column ranges are disjoint, so order is irrelevant.
    attn_weights[:, :, row_start:, sys_len:img_stop].mul_(enh_para)
    attn_weights[:, :, row_start:, :sys_len].mul_(sup_para)

    # NOTE(review): with attention_mask=None no causal mask is applied here;
    # confirm the decoder always supplies an explicit mask (e.g. by running
    # with attn_implementation="eager").
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

class AttnAdapter(LlamaAttention):
    """``LlamaAttention`` whose forward pass uses ``eager_attention_forward``.

    That function applies ClearSight-style image-span score reweighting.

    Args:
        config: Llama configuration forwarded to ``LlamaAttention``.
        layer_idx: layer index, used for KV-cache updates.
        enh_para: multiplier for attention scores toward image tokens.
        sup_para: multiplier for attention scores toward tokens before the image.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int, enh_para: float, sup_para: float):
        super().__init__(config, layer_idx=layer_idx)
        self.enh_para = enh_para
        self.sup_para = sup_para

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value=None,
        cache_position: Optional[torch.LongTensor] = None,
        key_position: Optional[Dict[str, int]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Project, apply RoPE, update the cache, and run reweighted attention.

        Args:
            hidden_states: input of shape ``(batch, seq, hidden)``.
            position_embeddings: ``(cos, sin)`` passed to ``apply_rotary_pos_emb``.
            attention_mask: additive mask forwarded to ``eager_attention_forward``.
            past_key_value: cache object exposing ``.update(...)``, or ``None``
                to skip caching.
            cache_position: forwarded to the cache via ``cache_kwargs``.
            key_position: ``{"image_start", "image_end"}`` indices required by
                ``eager_attention_forward``.
            **kwargs: forwarded to ``eager_attention_forward``.

        Returns:
            ``(attn_output, attn_weights)``: ``attn_output`` is the ``o_proj``
            output of shape ``(batch, seq, hidden)``, and ``attn_weights`` are
            the post-softmax probabilities.
        """
        # Unpacking also enforces a 3-D (batch, seq, hidden) input.
        _, q_len, _ = hidden_states.size()
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attn_output, attn_weights = eager_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            q_len=q_len,
            enh_para=self.enh_para,
            sup_para=self.sup_para,
            key_position=key_position,
            **kwargs,
        )

        # eager_attention_forward already returns (batch, seq, heads, head_dim).
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

class AttnAdapterQwen(Qwen2VLAttention):
    """``Qwen2VLAttention`` whose forward pass uses ``eager_attention_forward``.

    That function applies ClearSight-style image-span score reweighting.

    Args:
        config: Qwen2-VL configuration forwarded to ``Qwen2VLAttention``.
        layer_idx: layer index, used for KV-cache updates.
        enh_para: multiplier for attention scores toward image tokens.
        sup_para: multiplier for attention scores toward tokens before the image.
    """

    def __init__(
        self,
        config: Qwen2VLConfig,
        layer_idx: Optional[int] = None,
        enh_para: float = 1.15,
        sup_para: float = 0.95,
    ):
        super().__init__(config, layer_idx=layer_idx)
        self.enh_para = enh_para
        self.sup_para = sup_para

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value=None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, ...]] = None,
        key_position: Optional[Dict[str, int]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Project, apply multimodal RoPE, update the cache, and run reweighted attention.

        Args:
            hidden_states: input of shape ``(batch, seq, hidden)``.
            attention_mask: additive mask forwarded to ``eager_attention_forward``.
            position_ids: accepted for signature compatibility; not read here.
            past_key_value: cache object exposing ``.update(...)``, or ``None``
                to skip caching.
            output_attentions: when ``False``, the returned weights are ``None``.
            use_cache: accepted for signature compatibility; not read here.
            cache_position: forwarded to the cache via ``cache_kwargs``.
            position_embeddings: ``(cos, sin)`` for
                ``apply_multimodal_rotary_pos_emb``. Required, since it is
                unpacked unconditionally.
            key_position: ``{"image_start", "image_end"}`` indices required by
                ``eager_attention_forward``.
            **kwargs: accepted and ignored.

        Returns:
            ``(attn_output, attn_weights_or_None, past_key_value)``.
            ``attn_output`` has shape ``(batch, seq, hidden)``.
        """
        bsz, q_len, _ = hidden_states.size()

        # 1) project + reshape to (batch, heads, seq, head_dim)
        query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # 2) rotary (multimodal)
        cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states,
            key_states,
            cos,
            sin,
            self.rope_scaling["mrope_section"],
        )

        # 3) caching
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # 4) reweighted attention (dropout is gated by self.training inside)
        scaling = 1.0 / math.sqrt(self.head_dim)
        attn_output, attn_weights = eager_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            scaling=scaling,
            q_len=q_len,
            enh_para=self.enh_para,
            sup_para=self.sup_para,
            key_position=key_position,
            dropout=self.attention_dropout,
        )

        # 5) final o_proj. eager_attention_forward already returns
        # (batch, seq, heads, head_dim); transposing again here would mix heads
        # and positions whenever q_len > 1.
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
