import math
import types
from typing import Optional, Tuple, List
from collections import deque
import torch
import torch.nn as nn
try:
    from minigpt4.models.modeling_llama import apply_rotary_pos_emb as minigpt4_apply_rotary_pos_emb
except ImportError:
    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as minigpt4_apply_rotary_pos_emb

from Qwen_VL.modeling_qwen import apply_rotary_pos_emb as qwen_apply_rotary_pos_emb

from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as llama_apply_rotary_pos_emb

def llama_new_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Attention forward that can boost pre-softmax logits on a configured key span.

    Intended to be bound onto an attention module (``llama_head_guide`` does
    this with ``types.MethodType``) exposing ``q_proj``/``k_proj``/``v_proj``/
    ``o_proj``, ``rotary_emb``, ``num_heads``, ``head_dim``, ``hidden_size``
    and ``layer_idx``.

    The extra step runs only when the module has an ``aggregation`` attribute
    (only its presence is checked) and exactly one query token is processed.
    It then adds ``alpha * signal`` to the logits of key positions
    ``[img_start_idx, img_end_idx)`` -- presumably the image tokens -- for the
    heads whose strongest logit outside that span exceeds the strongest one
    inside it by less than 5.5.

    Args:
        hidden_states: Input of shape (batch, q_len, hidden).
        attention_mask: Optional additive mask of shape
            (batch, 1, q_len, kv_seq_len).
        position_ids: Forwarded to the rotary embedding helper.
        past_key_value: Optional cache object providing ``get_usable_length``
            and ``update``.
        output_attentions: Return the post-softmax weights when True.
        use_cache: Accepted for interface compatibility; not read here.

    Returns:
        ``(attn_output, attn_weights or None, past_key_value)``.

    Raises:
        ValueError: If a cache is given but ``layer_idx`` is None, or if the
            logits, mask, or attention output have unexpected sizes.
    """
    bsz, q_len, _ = hidden_states.size()
    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = llama_apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs
        )
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
        self.head_dim
    )

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights + attention_mask
        # Keep masked logits at the dtype minimum so they cannot overflow.
        attn_weights = torch.max(
            attn_weights,
            torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device),
        )

    # ----------------------- guided logit adjustment -----------------------
    if hasattr(self, "aggregation") and q_len == 1:
        img_start = self.img_start_idx
        img_end = self.img_end_idx

        # Logits of the single query row, split around the configured span.
        visual_logits = attn_weights[:, :, -1, img_start:img_end]  # (B, H, K_img)
        text_logits_pre = attn_weights[:, :, -1, :img_start]
        text_logits_post = attn_weights[:, :, -1, img_end:]

        # Gate per head: strongest logit outside the span vs. inside it.
        max_vis_val, _ = visual_logits.max(dim=-1, keepdim=True)  # (B, H, 1)
        max_text_val = torch.tensor(-1e4, device=attn_weights.device)
        if text_logits_pre.shape[-1] > 0:
            max_text_val = torch.max(max_text_val, text_logits_pre.max(dim=-1, keepdim=True)[0])
        if text_logits_post.shape[-1] > 0:
            max_text_val = torch.max(max_text_val, text_logits_post.max(dim=-1, keepdim=True)[0])

        diff = max_text_val - max_vis_val
        threshold_cutoff = 5.5
        should_intervene = (diff < threshold_cutoff).float()  # 1.0 where the boost applies
        dynamic_alpha = self.alpha * should_intervene

        enhancement_signal = torch.zeros_like(visual_logits)
        epsilon = 1e-6

        # For every span position keep the k largest |logit| values across
        # heads. max(1, ...) avoids an empty head axis (NaN mean/var) when
        # num_heads is very small.
        k_heads = max(1, int(self.num_heads * 0.5))
        top_abs, _ = torch.topk(visual_logits.abs(), k=k_heads, dim=1)  # (B, k, K_img)
        top_abs = top_abs.to(torch.float32)
        top_abs[~torch.isfinite(top_abs)] = 0.0

        mean_attn = top_abs.mean(dim=1)  # (B, K_img)
        var_attn = top_abs.var(dim=1, unbiased=False)  # (B, K_img)
        var_max = torch.clamp(var_attn.max(dim=-1, keepdim=True).values, min=epsilon)  # (B, 1)
        var_norm = var_attn / var_max

        # High mean magnitude and low cross-head variance give a larger boost.
        consistency_score = mean_attn + 0.4 * (1 - var_norm)  # (B, K_img)
        enhancement_signal += consistency_score.unsqueeze(1)  # broadcast over heads

        # Optional: reuse a stored per-step map. `window_history` is not set
        # anywhere in this function; it must be attached externally.
        if hasattr(self, "window_history") and len(self.window_history) > 0:
            window_size = len(self.window_history)
            curr_total_len = attn_weights.shape[-1]
            history_start_idx = curr_total_len - 1 - window_size
            if history_start_idx >= 0:
                # Pick, per head, the history entry aligned with the most
                # attended of the last `window_size` keys before the current one.
                text_window_logits = attn_weights[:, :, -1, history_start_idx : curr_total_len - 1]
                best_history_idx = text_window_logits.argmax(dim=-1)
                history_stack = torch.stack(list(self.window_history), dim=0)
                gather_idx = best_history_idx.view(1, bsz, self.num_heads, 1).expand(
                    -1, -1, -1, img_end - img_start
                )
                retrieved_history = history_stack.gather(0, gather_idx).squeeze(0)

                enhancement_signal += 0.25 * retrieved_history

        # Final = original + (dynamic_alpha * signal), written back in place.
        attn_weights[:, :, -1, img_start:img_end] = (
            visual_logits + (dynamic_alpha * enhancement_signal)
        )

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
        query_states.dtype
    )
    attn_output = torch.matmul(attn_weights, value_states)
    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value

def llama_head_guide(model, guided_layer_range, aggregation, alpha, img_start_idx, img_end_idx):
    """Bind ``llama_new_forward`` onto selected attention layers of ``model``.

    Args:
        model: Object exposing ``model.layers[i].self_attn``.
        guided_layer_range: A one-element sequence is used as the list of
            layer indices itself; otherwise elements 0 and 1 are the
            half-open ``[start, stop)`` range of layer indices.
        aggregation: Stored on each layer as ``aggregation``.
        alpha: Stored on each layer as ``alpha``.
        img_start_idx: Stored on each layer as ``img_start_idx``.
        img_end_idx: Stored on each layer as ``img_end_idx``.
    """
    if len(guided_layer_range) == 1:
        selected_layers = guided_layer_range
    else:
        first, stop = guided_layer_range[0], guided_layer_range[1]
        selected_layers = list(range(first, stop))

    for layer_no in selected_layers:
        attn_module = model.model.layers[layer_no].self_attn
        attn_module.img_start_idx = img_start_idx
        attn_module.img_end_idx = img_end_idx
        attn_module.aggregation = aggregation
        attn_module.alpha = alpha
        # Replace the bound forward of this instance only.
        attn_module.forward = types.MethodType(llama_new_forward, attn_module)

def minigpt4_new_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    layer_idx: int = -1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Attention forward (tuple KV cache) that can boost logits on a configured key span.

    Intended to be bound onto an attention module exposing ``q_proj``/
    ``k_proj``/``v_proj``/``o_proj``, ``rotary_emb``, ``num_heads``,
    ``head_dim`` and ``hidden_size``.

    The extra step runs only when the module has an ``aggregation`` attribute
    (only its presence is checked) and exactly one query token is processed.
    It adds ``alpha * signal`` to the pre-softmax logits of key positions
    ``[img_start_idx, img_end_idx)`` -- presumably the image tokens -- for
    every head, without any gating.

    Args:
        hidden_states: Input of shape (batch, q_len, hidden).
        attention_mask: Optional additive mask of shape
            (batch, 1, q_len, kv_seq_len).
        position_ids: Forwarded to the rotary embedding helper.
        past_key_value: Optional ``(key, value)`` tuple from earlier steps.
        output_attentions: Return the post-softmax weights when True.
        use_cache: Return the concatenated ``(key, value)`` tuple when True.
        layer_idx: Accepted for interface compatibility; not read here.

    Returns:
        ``(attn_output, attn_weights or None, (key, value) or None)``.

    Raises:
        ValueError: If the logits, mask, or attention output have unexpected
            sizes.
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = minigpt4_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    if past_key_value is not None:
        # Reuse cached keys/values from earlier steps.
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        # The message now reports the same 4-tuple that the check compares.
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights + attention_mask
        # Keep masked logits at the dtype minimum so they cannot overflow.
        attn_weights = torch.max(
            attn_weights,
            torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device),
        )

    # ----------------------- guided logit adjustment -----------------------
    if hasattr(self, "aggregation") and q_len == 1:
        img_start = self.img_start_idx
        img_end = self.img_end_idx

        # Logits of the single query row restricted to the configured span.
        visual_logits = attn_weights[:, :, -1, img_start:img_end]  # (B, H, K_img)

        # NOTE: the previous code also computed a text-vs-image gate
        # (max_text - max_vis < 6.0) but never applied it. The boost therefore
        # stays ungated here so that outputs are unchanged.
        dynamic_alpha = self.alpha

        enhancement_signal = torch.zeros_like(visual_logits)
        epsilon = 1e-6

        # For every span position keep the k largest |logit| values across
        # heads. max(1, ...) avoids an empty head axis (NaN mean/var) when
        # num_heads is very small.
        k_heads = max(1, int(self.num_heads * 0.4))
        top_abs, _ = torch.topk(visual_logits.abs(), k=k_heads, dim=1)  # (B, k, K_img)
        top_abs = top_abs.to(torch.float32)
        top_abs[~torch.isfinite(top_abs)] = 0.0

        mean_attn = top_abs.mean(dim=1)  # (B, K_img)
        var_attn = top_abs.var(dim=1, unbiased=False)  # (B, K_img)
        var_max = torch.clamp(var_attn.max(dim=-1, keepdim=True).values, min=epsilon)  # (B, 1)
        var_norm = var_attn / var_max

        # High mean magnitude and low cross-head variance give a larger boost.
        consistency_score = mean_attn + 0.4 * (1 - var_norm)  # (B, K_img)
        enhancement_signal += consistency_score.unsqueeze(1)  # broadcast over heads

        # Optional: reuse a stored per-step map. `window_history` is not set
        # anywhere in this function; it must be attached externally.
        if hasattr(self, "window_history") and len(self.window_history) > 0:
            window_size = len(self.window_history)
            curr_total_len = attn_weights.shape[-1]
            history_start_idx = curr_total_len - 1 - window_size
            if history_start_idx >= 0:
                # Fixed: this slice used the undefined name `history_start_idpx`.
                text_window_logits = attn_weights[:, :, -1, history_start_idx : curr_total_len - 1]
                best_history_idx = text_window_logits.argmax(dim=-1)
                history_stack = torch.stack(list(self.window_history), dim=0)
                gather_idx = best_history_idx.view(1, bsz, self.num_heads, 1).expand(
                    -1, -1, -1, img_end - img_start
                )
                retrieved_history = history_stack.gather(0, gather_idx).squeeze(0)

                enhancement_signal += 0.25 * retrieved_history

        # Final = original + (alpha * signal), written back in place.
        attn_weights[:, :, -1, img_start:img_end] = (
            visual_logits + (dynamic_alpha * enhancement_signal)
        )

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )
    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)
    if not output_attentions:
        attn_weights = None
    return attn_output, attn_weights, past_key_value

def minigpt4_head_guide(llm_model, guided_layer_range, aggregation, alpha, img_start_idx, img_end_idx):
    """Bind ``minigpt4_new_forward`` onto selected attention layers of ``llm_model``.

    Layers are looked up at ``llm_model.model.layers`` first, then at
    ``llm_model.layers``. If neither exists the call does nothing.

    Args:
        llm_model: Model (or inner model) holding the decoder layers.
        guided_layer_range: A one-element sequence is used as the list of
            layer indices itself; otherwise elements 0 and 1 are the
            half-open ``[start, stop)`` range of layer indices.
        aggregation: Stored on each layer as ``aggregation``.
        alpha: Stored on each layer as ``alpha``.
        img_start_idx: Stored on each layer as ``img_start_idx``.
        img_end_idx: Stored on each layer as ``img_end_idx``.
    """
    if len(guided_layer_range) == 1:
        selected_layers = guided_layer_range
    else:
        selected_layers = list(range(guided_layer_range[0], guided_layer_range[1]))

    # Resolve the layer container once; the lookup does not depend on the index.
    if hasattr(llm_model, 'model') and hasattr(llm_model.model, 'layers'):
        decoder_layers = llm_model.model.layers
    elif hasattr(llm_model, 'layers'):
        decoder_layers = llm_model.layers
    else:
        return

    for layer_no in selected_layers:
        attn_module = decoder_layers[layer_no].self_attn
        attn_module.img_start_idx = img_start_idx
        attn_module.img_end_idx = img_end_idx
        attn_module.aggregation = aggregation
        attn_module.alpha = alpha
        # Replace the bound forward of this instance only.
        attn_module.forward = types.MethodType(minigpt4_new_forward, attn_module)

from Qwen_VL.modeling_qwen import apply_rotary_pos_emb

def qwen_new_forward(
    self,
    hidden_states: Optional[Tuple[torch.FloatTensor]],
    rotary_pos_emb: Optional[List[torch.Tensor]] = None,
    registered_causal_mask: Optional[torch.Tensor] = None,
    layer_past: Optional[Tuple[torch.Tensor]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
):
    """Attention forward (fused QKV projection) that can boost logits on a configured key span.

    Intended to be bound onto an attention module exposing ``c_attn``,
    ``c_proj``, ``split_size``, ``_split_heads``, ``_merge_heads``,
    ``num_heads``, ``head_dim``, ``use_logn_attn``, ``logn_tensor``,
    ``scale_attn_weights`` and ``attn_dropout``.

    The extra step runs only when the module has an ``aggregation`` attribute
    (only its presence is checked) and exactly one query token is processed.
    It then adds ``alpha * signal`` to the pre-softmax logits of key positions
    ``[img_start_idx, img_end_idx)`` -- presumably the image tokens -- for the
    heads whose strongest logit outside that span exceeds the strongest one
    inside it by less than 6.0.

    Args:
        hidden_states: Input of shape (batch, q_len, hidden).
        rotary_pos_emb: Optional rotary tables; the last ``q_len`` positions
            are applied to both query and key.
        registered_causal_mask: Accepted but not applied here. Any causal
            masking must already be part of ``attention_mask``.
        layer_past: Optional ``(key, value)`` with sequence on dim 1.
        attention_mask: Optional additive mask broadcastable to the logits.
        head_mask: Optional multiplicative mask applied after softmax/dropout.
        encoder_hidden_states: Accepted for interface compatibility; unused.
        encoder_attention_mask: Accepted for interface compatibility; unused.
        output_attentions: Append the attention weights to the result when True.
        use_cache: Return the concatenated ``(key, value)`` when True.

    Returns:
        ``(attn_output, present)`` or ``(attn_output, present, attn_weights)``,
        where ``present`` is ``None`` unless ``use_cache`` is set.
    """
    bsz, q_len, _ = hidden_states.size()

    mixed_x_layer = self.c_attn(hidden_states)
    query, key, value = mixed_x_layer.split(self.split_size, dim=2)

    query = self._split_heads(query, self.num_heads, self.head_dim)
    key = self._split_heads(key, self.num_heads, self.head_dim)
    value = self._split_heads(value, self.num_heads, self.head_dim)

    # --- Rotary positional embeddings ---
    if rotary_pos_emb is not None:
        cur_len = query.shape[1]
        rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
        rotary_pos_emb = (rotary_pos_emb,) * 2
        q_pos_emb, k_pos_emb = rotary_pos_emb
        query = qwen_apply_rotary_pos_emb(query, q_pos_emb)
        key = qwen_apply_rotary_pos_emb(key, k_pos_emb)

    # --- KV cache update (sequence axis is dim 1 at this point) ---
    if layer_past is not None:
        past_key, past_value = layer_past[0], layer_past[1]
        key = torch.cat((past_key, key), dim=1)
        value = torch.cat((past_value, value), dim=1)

    present = (key, value) if use_cache else None

    # --- LogN attention scaling (inference only) ---
    if self.use_logn_attn and not self.training:
        if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
            self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
        seq_start = key.size(1) - query.size(1)
        seq_end = key.size(1)
        logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
        query = query * logn_tensor.expand_as(query)

    # --- (B, S, H, D) -> (B, H, S, D) ---
    query = query.permute(0, 2, 1, 3)
    key = key.permute(0, 2, 1, 3)
    value = value.permute(0, 2, 1, 3)

    # [B, H, Q_len, D] @ [B, H, D, K_len]
    attn_weights = torch.matmul(query, key.transpose(-1, -2))

    if self.scale_attn_weights:
        attn_weights = attn_weights / torch.full(
            [],
            value.size(-1) ** 0.5,
            dtype=attn_weights.dtype,
            device=attn_weights.device,
        )

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # ----------------------- guided logit adjustment -----------------------
    if hasattr(self, "aggregation") and q_len == 1:
        img_start = self.img_start_idx
        img_end = self.img_end_idx

        # Logits of the single query row, split around the configured span.
        visual_logits = attn_weights[:, :, -1, img_start:img_end]  # (B, H, K_img)
        text_logits_pre = attn_weights[:, :, -1, :img_start]
        text_logits_post = attn_weights[:, :, -1, img_end:]

        # Gate per head: strongest logit outside the span vs. inside it.
        max_vis_val, _ = visual_logits.max(dim=-1, keepdim=True)  # (B, H, 1)
        max_text_val = torch.tensor(-1e4, device=attn_weights.device)
        if text_logits_pre.shape[-1] > 0:
            max_text_val = torch.max(max_text_val, text_logits_pre.max(dim=-1, keepdim=True)[0])
        if text_logits_post.shape[-1] > 0:
            max_text_val = torch.max(max_text_val, text_logits_post.max(dim=-1, keepdim=True)[0])

        diff = max_text_val - max_vis_val
        threshold_cutoff = 6.0
        should_intervene = (diff < threshold_cutoff).float()  # 1.0 where the boost applies
        dynamic_alpha = self.alpha * should_intervene

        enhancement_signal = torch.zeros_like(visual_logits)
        epsilon = 1e-6

        # For every span position keep the k largest |logit| values across
        # heads. max(1, ...) avoids an empty head axis (NaN mean/var) when
        # num_heads is very small.
        k_heads = max(1, int(self.num_heads * 0.65))
        top_abs, _ = torch.topk(visual_logits.abs(), k=k_heads, dim=1)  # (B, k, K_img)
        top_abs = top_abs.to(torch.float32)
        top_abs[~torch.isfinite(top_abs)] = 0.0

        mean_attn = top_abs.mean(dim=1)  # (B, K_img)
        var_attn = top_abs.var(dim=1, unbiased=False)  # (B, K_img)
        var_max = torch.clamp(var_attn.max(dim=-1, keepdim=True).values, min=epsilon)  # (B, 1)
        var_norm = var_attn / var_max

        # High mean magnitude and low cross-head variance give a larger boost.
        consistency_score = mean_attn + 0.5 * (1 - var_norm)  # (B, K_img)
        enhancement_signal += consistency_score.unsqueeze(1)  # broadcast over heads

        # Optional: reuse a stored per-step map. `window_history` is not set
        # anywhere in this function; it must be attached externally.
        if hasattr(self, "window_history") and len(self.window_history) > 0:
            window_size = len(self.window_history)
            curr_total_len = attn_weights.shape[-1]
            history_start_idx = curr_total_len - 1 - window_size
            if history_start_idx >= 0:
                # Pick, per head, the history entry aligned with the most
                # attended of the last `window_size` keys before the current one.
                text_window_logits = attn_weights[:, :, -1, history_start_idx : curr_total_len - 1]
                best_history_idx = text_window_logits.argmax(dim=-1)
                history_stack = torch.stack(list(self.window_history), dim=0)
                gather_idx = best_history_idx.view(1, bsz, self.num_heads, 1).expand(
                    -1, -1, -1, img_end - img_start
                )
                retrieved_history = history_stack.gather(0, gather_idx).squeeze(0)

                enhancement_signal += 0.01 * retrieved_history

        # Final = original + (dynamic_alpha * signal), written back in place.
        attn_weights[:, :, -1, img_start:img_end] = (
            visual_logits + (dynamic_alpha * enhancement_signal)
        )

    # --- Softmax, dropout, optional head mask ---
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.type(value.dtype)
    attn_weights = self.attn_dropout(attn_weights)

    if head_mask is not None:
        attn_weights = attn_weights * head_mask

    # --- Output projection ---
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2)

    context_layer = self._merge_heads(
        attn_output, self.num_heads, self.head_dim
    )
    attn_output = self.c_proj(context_layer)
    outputs = (attn_output, present)
    if output_attentions:
        outputs += (attn_weights,)
    return outputs

def qwen_head_guide(model, guided_layer_range, aggregation, alpha, img_start_idx, img_end_idx):
    """Bind ``qwen_new_forward`` onto selected attention layers of a Qwen-style model.

    Blocks are looked up at ``model.transformer.h`` first, then at ``model.h``.
    If neither exists an error is printed and nothing is patched. Indices at
    or beyond the number of blocks are reported and skipped.

    Args:
        model: Model (or inner transformer) holding the blocks.
        guided_layer_range: A one-element sequence is used as the list of
            layer indices itself; otherwise elements 0 and 1 are the
            half-open ``[start, stop)`` range of layer indices.
        aggregation: Stored on each block's ``attn`` as ``aggregation``.
        alpha: Stored on each block's ``attn`` as ``alpha``.
        img_start_idx: Stored on each block's ``attn`` as ``img_start_idx``.
        img_end_idx: Stored on each block's ``attn`` as ``img_end_idx``.
    """
    if len(guided_layer_range) == 1:
        layer_list = guided_layer_range
    else:
        layer_list = list(range(guided_layer_range[0], guided_layer_range[1]))

    print(f"Injecting Qwen Head Guide: Layers {layer_list}, Alpha {alpha}, Img Range [{img_start_idx}, {img_end_idx}]")

    # Locate the block container: full LM-head model first, bare transformer second.
    blocks = None
    if hasattr(model, "transformer") and hasattr(model.transformer, "h"):
        blocks = model.transformer.h
    elif hasattr(model, "h"):
        blocks = model.h
    if blocks is None:
        print("Error: Could not find layers in Qwen model structure.")
        return

    for i in layer_list:
        if not i < len(blocks):
            print(f"Warning: Layer index {i} out of bounds (max {len(blocks)-1})")
            continue

        attn_module = blocks[i].attn
        attn_module.img_start_idx = img_start_idx
        attn_module.img_end_idx = img_end_idx
        attn_module.aggregation = aggregation
        attn_module.alpha = alpha
        # Replace the bound forward of this instance only.
        attn_module.forward = types.MethodType(qwen_new_forward, attn_module)

def shikra_new_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    layer_idx: Optional[int] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Attention forward supporting tuple and object KV caches, with an optional logit boost.

    Intended to be bound onto an attention module exposing ``q_proj``/
    ``k_proj``/``v_proj``/``o_proj``, ``rotary_emb``, ``num_heads``,
    ``head_dim`` and ``hidden_size``.

    Cache handling:
      * An object with ``get_usable_length``/``update`` is used when a layer
        index is available, from ``layer_idx`` or ``self.layer_idx``.
      * A ``(key, value)`` tuple is concatenated on the sequence axis.
      * With no cache and ``use_cache=True`` a fresh ``(key, value)`` tuple is
        returned, so tuple-based decoding has a cache to continue from.

    The extra step runs only when the module has an ``aggregation`` attribute
    (only its presence is checked) and exactly one query token is processed.
    It then adds ``alpha * signal`` to the pre-softmax logits of key positions
    ``[img_start_idx, img_end_idx)`` -- presumably the image tokens -- for the
    heads whose strongest logit outside that span exceeds the strongest one
    inside it by less than 5.0.

    Args:
        hidden_states: Input of shape (batch, q_len, hidden).
        attention_mask: Optional additive mask of shape
            (batch, 1, q_len, kv_seq_len).
        position_ids: Forwarded to the rotary embedding helper.
        past_key_value: ``None``, a ``(key, value)`` tuple, or a cache object.
        output_attentions: Return the post-softmax weights when True.
        use_cache: Whether a tuple cache should be returned on the tuple path.
        layer_idx: Overrides ``self.layer_idx`` for cache-object calls.
        **kwargs: Ignored.

    Returns:
        ``(attn_output, attn_weights or None, past_key_value)``.

    Raises:
        ValueError: If the mask or attention output have unexpected sizes.
    """
    bsz, q_len, _ = hidden_states.size()

    # 1. Projections
    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

    # 2. Total key/value length including cached positions
    kv_seq_len = key_states.shape[-2]
    current_layer_idx = layer_idx if layer_idx is not None else getattr(self, 'layer_idx', None)

    if past_key_value is not None:
        if hasattr(past_key_value, "get_usable_length") and current_layer_idx is not None:
            # Cache object (Transformers >= 4.36)
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, current_layer_idx)
        elif isinstance(past_key_value, tuple):
            # Legacy tuple cache
            kv_seq_len += past_key_value[0].shape[-2]

    # 3. Rotary embeddings
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = llama_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    # 4. Update cache content
    if past_key_value is None:
        # Legacy tuple caching starts on the first call. Previously nothing was
        # returned here, so the tuple cache was never created.
        if use_cache:
            past_key_value = (key_states, value_states)
    elif hasattr(past_key_value, "update") and current_layer_idx is not None:
        cache_kwargs = {"sin": sin, "cos": cos}
        key_states, value_states = past_key_value.update(key_states, value_states, current_layer_idx, cache_kwargs)
    elif isinstance(past_key_value, tuple):
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states) if use_cache else None

    # 5. Attention logits
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights + attention_mask
        # Keep masked logits at the dtype minimum so they cannot overflow.
        attn_weights = torch.max(
            attn_weights,
            torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device),
        )

    # ----------------------- guided logit adjustment -----------------------
    if hasattr(self, "aggregation") and q_len == 1:
        img_start = self.img_start_idx
        img_end = self.img_end_idx

        # Logits of the single query row, split around the configured span.
        visual_logits = attn_weights[:, :, -1, img_start:img_end]  # (B, H, K_img)
        text_logits_pre = attn_weights[:, :, -1, :img_start]
        text_logits_post = attn_weights[:, :, -1, img_end:]

        # Gate per head: strongest logit outside the span vs. inside it.
        max_vis_val, _ = visual_logits.max(dim=-1, keepdim=True)  # (B, H, 1)
        max_text_val = torch.tensor(-1e4, device=attn_weights.device)
        if text_logits_pre.shape[-1] > 0:
            max_text_val = torch.max(max_text_val, text_logits_pre.max(dim=-1, keepdim=True)[0])
        if text_logits_post.shape[-1] > 0:
            max_text_val = torch.max(max_text_val, text_logits_post.max(dim=-1, keepdim=True)[0])

        diff = max_text_val - max_vis_val
        threshold_cutoff = 5.0
        should_intervene = (diff < threshold_cutoff).float()  # 1.0 where the boost applies
        dynamic_alpha = self.alpha * should_intervene

        enhancement_signal = torch.zeros_like(visual_logits)
        epsilon = 1e-6

        # For every span position keep the k largest |logit| values across
        # heads. max(1, ...) avoids an empty head axis (NaN mean/var) when
        # num_heads is very small.
        k_heads = max(1, int(self.num_heads * 0.7))
        top_abs, _ = torch.topk(visual_logits.abs(), k=k_heads, dim=1)  # (B, k, K_img)
        top_abs = top_abs.to(torch.float32)
        top_abs[~torch.isfinite(top_abs)] = 0.0

        mean_attn = top_abs.mean(dim=1)  # (B, K_img)
        var_attn = top_abs.var(dim=1, unbiased=False)  # (B, K_img)
        var_max = torch.clamp(var_attn.max(dim=-1, keepdim=True).values, min=epsilon)  # (B, 1)
        var_norm = var_attn / var_max

        # High mean magnitude and low cross-head variance give a larger boost.
        consistency_score = mean_attn + 0.4 * (1 - var_norm)  # (B, K_img)
        enhancement_signal += consistency_score.unsqueeze(1)  # broadcast over heads

        # Optional: reuse a stored per-step map. `window_history` is not set
        # anywhere in this function; it must be attached externally.
        if hasattr(self, "window_history") and len(self.window_history) > 0:
            window_size = len(self.window_history)
            curr_total_len = attn_weights.shape[-1]
            history_start_idx = curr_total_len - 1 - window_size
            if history_start_idx >= 0:
                # Pick, per head, the history entry aligned with the most
                # attended of the last `window_size` keys before the current one.
                text_window_logits = attn_weights[:, :, -1, history_start_idx : curr_total_len - 1]
                best_history_idx = text_window_logits.argmax(dim=-1)
                history_stack = torch.stack(list(self.window_history), dim=0)
                gather_idx = best_history_idx.view(1, bsz, self.num_heads, 1).expand(
                    -1, -1, -1, img_end - img_start
                )
                retrieved_history = history_stack.gather(0, gather_idx).squeeze(0)

                enhancement_signal += 0.25 * retrieved_history

        # Final = original + (dynamic_alpha * signal), written back in place.
        attn_weights[:, :, -1, img_start:img_end] = (
            visual_logits + (dynamic_alpha * enhancement_signal)
        )

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError("`attn_output` size mismatch.")

    attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value

def shikra_head_guide(model, guided_layer_range, aggregation, alpha, img_start_idx, img_end_idx):
    """Bind ``shikra_new_forward`` onto selected attention layers of ``model``.

    Layers are looked up at ``model.model.layers`` first, then at
    ``model.layers``. If neither exists an error is printed and nothing is
    patched. Indices at or beyond the number of layers are skipped silently.

    Args:
        model: Model (or inner model) holding the decoder layers.
        guided_layer_range: A one-element sequence is used as the list of
            layer indices itself; otherwise elements 0 and 1 are the
            half-open ``[start, stop)`` range of layer indices.
        aggregation: Stored on each layer's ``self_attn`` as ``aggregation``.
        alpha: Stored on each layer's ``self_attn`` as ``alpha``.
        img_start_idx: Stored on each layer's ``self_attn`` as ``img_start_idx``.
        img_end_idx: Stored on each layer's ``self_attn`` as ``img_end_idx``.
    """
    if len(guided_layer_range) == 1:
        layer_list = guided_layer_range
    else:
        layer_list = list(range(guided_layer_range[0], guided_layer_range[1]))

    # Locate the layer container: causal-LM wrapper first, bare model second.
    decoder_layers = None
    if hasattr(model, "model") and hasattr(model.model, "layers"):
        decoder_layers = model.model.layers
    elif hasattr(model, "layers"):
        decoder_layers = model.layers
    if decoder_layers is None:
        print("Error: Could not find layers in Shikra model structure.")
        return

    layer_count = len(decoder_layers)
    for i in layer_list:
        if i >= layer_count:
            continue

        attn_module = decoder_layers[i].self_attn
        attn_module.img_start_idx = img_start_idx
        attn_module.img_end_idx = img_end_idx
        attn_module.aggregation = aggregation
        attn_module.alpha = alpha
        # Replace the bound forward of this instance only.
        attn_module.forward = types.MethodType(shikra_new_forward, attn_module)
