import math
import torch
import torch.nn.functional as F
from torch import nn
from typing import Optional, Tuple
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention, repeat_kv, apply_rotary_pos_emb

import torch
import math

import torch
import math

import torch
import math

def boost_text2image_attention(
    attn_weights: torch.Tensor,
    SYS_LEN: int = 35,
    IMG_LEN: int = 576,
    scale_up: float = 4,
    scale_down: float = 0.25,
    scale_text: float = 2,
    scale_sys: float = 1,
    query_rows: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
    """Re-weight what selected query rows attend to, favouring one image region.

    For every batch element:

    1) Average the post-softmax attention of the selected query rows over all
       heads and rows, giving one score per image patch.
    2) Take the top 25% patches by that score and compute their centroid on
       the square patch grid; also mark the bottom 25% patches.
    3) Multiply a fixed square region around the centroid (side
       ``isqrt(IMG_LEN // 4)``, i.e. 12x12 = 144 patches for 576) by
       ``scale_up``, clipped to stay inside the grid.
    4) Multiply the bottom 25% patches by ``scale_down``. This overrides step 3
       for patches that fall in both sets.
    5) Scale the whole key axis:
         * system tokens ``[0, SYS_LEN)``                  -> ``scale_sys``
         * image tokens  ``[SYS_LEN, SYS_LEN + IMG_LEN)``  -> per-patch scale
         * text tokens   ``[SYS_LEN + IMG_LEN, kv_len)``   -> ``scale_text``
    6) Renormalise the affected query rows so each sums to 1 over all keys.

    Args:
        attn_weights: attention probabilities of shape
            ``(bsz, n_head, q_len, kv_len)``. They are modified in place.
        SYS_LEN: number of leading (system) tokens before the image tokens.
        IMG_LEN: number of image tokens. Must be a perfect square.
        scale_up: multiplier for the region around the top-25% centroid.
        scale_down: multiplier for the bottom-25% patches.
        scale_text: multiplier for keys after the image span.
        scale_sys: multiplier for keys before the image span.
        query_rows: query row indices to modify. Negative indices are allowed.
            Defaults to ``[-1]``, the last query only.

    Returns:
        ``attn_weights``, the same tensor, modified in place.

    Raises:
        ValueError: if ``IMG_LEN`` is not a perfect square.
    """
    bsz, n_head, q_len, kv_len = attn_weights.shape
    img_start = SYS_LEN
    img_end = SYS_LEN + IMG_LEN

    # Skip short queries (e.g. single-token decoding steps). Also skip when the
    # key axis does not contain the complete image span, e.g. text-only
    # prompts: there is nothing to rescale, and the image-slice writes below
    # would fail with a shape mismatch.
    if q_len <= SYS_LEN or kv_len < img_end:
        return attn_weights

    grid_size = math.isqrt(IMG_LEN)  # 24 for 576
    if grid_size * grid_size != IMG_LEN:
        raise ValueError("IMG_LEN must be a perfect square")
    quart = IMG_LEN // 4  # 144 for 576
    region_dim = math.isqrt(quart)  # 12 for 576
    half = region_dim // 2  # 6 for 576

    rows = [-1] if query_rows is None else list(query_rows)
    device = attn_weights.device

    for b in range(bsz):
        # 1) Per-patch score, averaged over heads and selected query rows -> (IMG_LEN,)
        avg_attn = attn_weights[b, :, rows, img_start:img_end].mean(dim=(0, 1))
        sorted_idx = torch.argsort(avg_attn, descending=True)

        # 3 & 4) Per-patch scale for the image span.
        scale_img = torch.ones((IMG_LEN,), device=device)
        if quart > 0:  # with quart == 0 there are no top/bottom patches
            top_idxs = sorted_idx[:quart]
            cr = int(round(float((top_idxs // grid_size).float().mean())))
            cc = int(round(float((top_idxs % grid_size).float().mean())))

            # Region origin, clipped so the whole region stays on the grid.
            r0 = max(0, min(cr - half, grid_size - region_dim))
            c0 = max(0, min(cc - half, grid_size - region_dim))
            # The 2-D view shares storage with scale_img, so this writes into it.
            scale_img.view(grid_size, grid_size)[
                r0:r0 + region_dim, c0:c0 + region_dim
            ] = scale_up

            # Explicit start index: `sorted_idx[-quart:]` would select every
            # patch if quart were 0.
            scale_img[sorted_idx[IMG_LEN - quart:]] = scale_down

        # 5) Scale vector for the whole key axis (length kv_len).
        full_scale = torch.ones((kv_len,), device=device)
        full_scale[:SYS_LEN] = scale_sys
        full_scale[img_start:img_end] = scale_img
        full_scale[img_end:kv_len] = scale_text

        attn_weights[b, :, rows, :kv_len] *= full_scale.view(1, 1, kv_len)

        # 6) Renormalise the affected rows over all keys (system + image + text).
        block = attn_weights[b, :, rows, :kv_len]
        norm = block.sum(dim=-1, keepdim=True).clamp_min(1e-9)
        attn_weights[b, :, rows, :kv_len] = block / norm

    return attn_weights



# def boost_text2image_attention(
#     attn_weights: torch.Tensor,
#     SYS_LEN: int = 35,
#     IMG_LEN: int = 576,
#     scale_up: float = 5,
#     scale_down: float = 0.5,      # < 1.0 for suppression
# ) -> torch.Tensor:
#     """
#     Globally rescale text-to-image attention:
#       • top   25 % patches  →  scale_up
#       • mid   25 % patches  →  1
#       • lower 50 % patches  →  scale_down
#     In-place operation; returns the same tensor.
#     """
#     bsz, n_head, q_len, kv_len = attn_weights.shape

#     if q_len <= 1:
#         return attn_weights
    
#     img_start = SYS_LEN
#     img_end   = SYS_LEN + IMG_LEN
#     txt_start = img_end

#     assert kv_len >= img_end, "kv_len too short"
#     # query_rows = range(q_len) if q_len <= txt_start else range(txt_start, q_len)
#     query_rows = [txt_start + 6, txt_start + 7, -1]
#     # query_rows = [-1]

#     top_quarter  = IMG_LEN // 4
#     mid_quarter  = IMG_LEN // 2                          # next 25 %

#     for b in range(bsz):
#         # mean over heads & all text queries → (576,)
#         avg_attn = attn_weights[b, :, query_rows, img_start:img_end].mean(dim=(0, 1))
#         sorted_idx = torch.argsort(avg_attn, descending=True)

#         scale_vec = torch.ones(IMG_LEN, device=attn_weights.device)
#         scale_vec[sorted_idx[:top_quarter]] = scale_up
#         scale_vec[sorted_idx[top_quarter:top_quarter + mid_quarter]] = 1.0
#         scale_vec[sorted_idx[top_quarter + mid_quarter:]] = scale_down
#         scale_vec = scale_vec.view(1, 1, IMG_LEN)        # broadcast

#         attn_weights[b, :, query_rows, img_start:img_end] *= scale_vec

#         # renormalise affected rows
#         rows = attn_weights[b, :, query_rows]
#         row_sum = rows.sum(dim=-1, keepdim=True).clamp_min(1e-9)
#         attn_weights[b, :, query_rows] = rows / row_sum

#     return attn_weights

def compute_spatial_entropy(
    attn_weights: torch.Tensor,
    SYS_LEN: int = 35,
    IMG_LEN: int = 576,
    grid_size: int = 24,
    n_blocks_row: int = 3,
    n_blocks_col: int = 3,
    top_ratio: float = 0.25,
    eps: float = 1e-12,
    head_idx: Optional[int] = 24,
) -> torch.Tensor:
    """Spatial entropy of the most-attended image tokens, per batch element.

    The last query row's attention over the image tokens is taken from one
    head (``head_idx``) or averaged over all heads (``head_idx=None``). The
    top ``top_ratio`` fraction of patches is selected. The
    ``grid_size x grid_size`` patch grid is split into
    ``n_blocks_row x n_blocks_col`` equal blocks, and the Shannon entropy (in
    nats) of the share of selected patches per block is returned. Remainder
    rows/columns that do not fill a whole block are ignored.

    Args:
        attn_weights: attention tensor of shape ``(bsz, n_head, q_len, kv_len)``.
        SYS_LEN: number of non-image tokens before the image tokens.
        IMG_LEN: number of image tokens. Must equal ``grid_size ** 2``.
        grid_size: side length of the square patch grid.
        n_blocks_row: number of blocks along the row axis.
        n_blocks_col: number of blocks along the column axis.
        top_ratio: fraction of patches to select, e.g. 0.25 for the top 25%.
        eps: small constant guarding the division and ``log(0)``.
        head_idx: head to read the attention from, or ``None`` to average over
            all heads. The default 24 preserves the historical behaviour.

    Returns:
        Float tensor of shape ``(bsz,)`` with one entropy per sample.

    Raises:
        ValueError: if ``grid_size ** 2 != IMG_LEN``, or if the key axis is
            shorter than ``SYS_LEN + IMG_LEN``.
    """
    bsz, n_head, q_len, kv_len = attn_weights.shape
    img_end = SYS_LEN + IMG_LEN
    if grid_size * grid_size != IMG_LEN:
        raise ValueError(
            f"grid_size**2 ({grid_size * grid_size}) must equal IMG_LEN ({IMG_LEN})"
        )
    if kv_len < img_end:
        raise ValueError(
            f"kv_len ({kv_len}) is shorter than SYS_LEN + IMG_LEN ({img_end})"
        )

    # Attention of the last query row over the image tokens: (bsz, n_head, IMG_LEN).
    img_attn = attn_weights[:, :, -1, SYS_LEN:img_end]
    avg_attn = img_attn.mean(dim=1) if head_idx is None else img_attn[:, head_idx]

    # Boolean mask of the top-k patches per sample: (bsz, IMG_LEN).
    k = int(IMG_LEN * top_ratio)
    topk_idx = torch.topk(avg_attn, k=k, dim=1, largest=True).indices
    topk_mask = torch.zeros_like(avg_attn, dtype=torch.bool)
    topk_mask.scatter_(1, topk_idx, True)
    grid_mask = topk_mask.view(bsz, grid_size, grid_size)

    # Count selected patches per block, in row-major block order. Cropping
    # drops remainder rows/cols when grid_size is not divisible by the block count.
    block_h = grid_size // n_blocks_row
    block_w = grid_size // n_blocks_col
    cropped = grid_mask[:, : n_blocks_row * block_h, : n_blocks_col * block_w]
    counts = (
        cropped.reshape(bsz, n_blocks_row, block_h, n_blocks_col, block_w)
        .sum(dim=(2, 4))
        .reshape(bsz, -1)
        .float()
    )  # (bsz, n_blocks_row * n_blocks_col)

    # Normalise to a distribution per sample, then compute the Shannon entropy.
    probs = counts / (counts.sum(dim=1, keepdim=True) + eps)
    return -(probs * torch.log(probs + eps)).sum(dim=1)


class AttnAdapter(LlamaAttention):
    """LlamaAttention variant that post-processes the attention probabilities.

    ``forward`` re-implements the eager attention path, computing Q/K/V,
    rotary embedding, the KV cache, scaled dot-product scores, the mask and the
    softmax. Between the softmax and the ``attn_weights @ V`` product it
    inserts two hooks:

    * ``compute_spatial_entropy``, run only when ``q_len > 1``. Its result is
      stored in ``self.spatial_entropy``.
    * ``boost_text2image_attention``, called on every forward pass. Whether it
      changes anything is decided inside that helper.

    NOTE(review): this relies on parent-class attributes (``pretraining_tp``,
    ``rotary_emb``, ``num_heads``, ``num_key_value_heads``,
    ``num_key_value_groups``, ``head_dim``, ``hidden_size``) and on the
    ``rotary_emb(x, seq_len=...)`` / ``apply_rotary_pos_emb(..., position_ids)``
    call signatures of the transformers version it was written against.
    Re-check on upgrade.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # Most recent spatial entropy from a forward with q_len > 1; None until then.
        # It is not updated on q_len == 1 steps, so it keeps the last multi-token value.
        self.spatial_entropy = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Self-attention forward with attention-weight hooks after the softmax.

        Args:
            hidden_states: ``(bsz, q_len, hidden)`` input; unpacked via ``.size()``.
            attention_mask: additive mask, must be ``(bsz, 1, q_len, kv_seq_len)``
                when given.
            position_ids: positions passed to ``apply_rotary_pos_emb``.
            past_key_value: optional ``(key, value)`` cache, concatenated on dim 2.
            output_attentions: if False, the returned attention weights are None.
            use_cache: if True, the returned cache is the updated ``(key, value)``.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value or None)``. The
            returned weights are the post-hook ones actually used for ``attn_output``.

        Raises:
            ValueError: on unexpected attention-weight, mask, or output shapes.
        """
        bsz, q_len, _ = hidden_states.size()

        if self.pretraining_tp > 1:
            # Tensor-parallel-style projection: split the weight matrices and
            # concatenate the partial projections.
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
            query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Total key length = new tokens + cached tokens; used for the rotary tables.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        # Cache is stored before repeat_kv, i.e. with num_key_value_heads heads.
        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32 for the softmax, then cast back to the query dtype
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        # ENHANCE VISUAL INPUT -----------------------
        # Both hooks are called with their default SYS_LEN / IMG_LEN arguments.
        # NOTE(review): with its defaults, compute_spatial_entropy reads head 24
        # and expects at least SYS_LEN + IMG_LEN keys. Text-only or short
        # multi-token prompts may therefore fail there; confirm that only
        # image prompts reach this layer.
        if q_len > 1:
            spatial_entropy = compute_spatial_entropy(attn_weights)
            self.spatial_entropy = spatial_entropy
        # Rescales (and renormalises) selected rows of attn_weights in place;
        # the helper decides internally whether q_len is long enough to act.
        attn_weights = boost_text2image_attention(attn_weights)
        # -------------------------------------------

        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # -> (bsz, q_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.pretraining_tp > 1:
            # Split-and-sum output projection, mirroring the split input projection above.
            attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value