"""
Memory-Augmented Attention for Mem-VLM

This module implements the memory-augmented attention mechanism that:
1. Stores evicted KV states to L2 memory bank
2. Retrieves relevant memories based on query similarity
3. Concatenates retrieved memories with current KV states for attention

Key Innovation: Training-free retrieval-augmented attention
"""

import math
from typing import Optional, Tuple, List, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F

from lmms_eval.caching.memory_bank import LayerMemoryBank


class MemoryAugmentedAttentionWrapper:
    """
    Adds memory-augmented attention to an existing attention layer's KV states.

    ``process_kv_cache`` is meant to be called with a layer's key/value/query
    tensors before attention is computed. It:

    1. Moves the oldest non-sink tokens into ``memory_bank`` once the KV length
       exceeds ``max_cache_length`` (the newest ``local_tokens`` are kept).
    2. Retrieves the ``memory_top_k`` memory slots most relevant to the last
       query token.
    3. Prepends the retrieved slots to the returned keys/values.

    The wrapper never touches a cache object; it only returns new tensors.
    NOTE(review): if the caller does not write the trimmed tensors back to its
    cache, the same tokens are evicted (and stored) again on the next call —
    confirm how callers use the result.
    """

    def __init__(
        self,
        memory_bank: "LayerMemoryBank",
        max_cache_length: int = 2048,
        memory_top_k: int = 5,
        memory_chunk_size: int = 256,
        memory_pooling: str = 'mean',
        sink_tokens: int = 4,  # number of leading "attention sink" tokens kept
        local_tokens: int = 256,  # number of most recent tokens never evicted
    ):
        """
        Initialize the memory-augmented attention wrapper.

        Args:
            memory_bank: explicit long-term (L2) memory bank, indexed by layer.
            max_cache_length: KV length above which eviction is triggered.
            memory_top_k: number of memory slots retrieved per call.
            memory_chunk_size: number of tokens pooled into one memory slot.
            memory_pooling: pooling mode forwarded to ``memory_bank.add_memory``.
            sink_tokens: number of leading tokens that are never evicted.
            local_tokens: number of most recent tokens that are never evicted.
        """
        self.memory_bank = memory_bank
        self.max_cache_length = max_cache_length
        self.memory_top_k = memory_top_k
        self.memory_chunk_size = memory_chunk_size
        self.memory_pooling = memory_pooling
        self.sink_tokens = sink_tokens
        self.local_tokens = local_tokens

        # Counters reported by get_stats().
        self.total_evictions = 0
        self.total_retrievals = 0

    def process_kv_cache(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        query_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Evict old KV states to memory, retrieve memories and concatenate them.

        Args:
            layer_idx: index of the layer whose memory is used.
            key_states: keys ``[batch_size, num_kv_heads, seq_len, head_dim]``.
            value_states: values, same shape as ``key_states``.
            query_states: queries ``[batch_size, num_heads, query_len, head_dim]``.

        Returns:
            ``(key_states, value_states, memory_mask)``. Keys/values are laid
            out as ``[retrieved memories, sink tokens, remaining tokens]``.
            ``memory_mask`` is an all-zero ``[batch_size, 1, query_len,
            num_memories]`` tensor (memories fully visible), or ``None`` when
            nothing was retrieved.
        """
        bsz, _, seq_len, _ = key_states.shape

        # Step 1: move old KV states into memory once the cache is too long.
        if seq_len > self.max_cache_length:
            evict_len = seq_len - self.max_cache_length + self.memory_chunk_size
            evict_start = self.sink_tokens
            # Evict positions [sink_tokens, evict_len), but never reach into
            # the newest ``local_tokens`` positions (which include the tokens
            # being processed right now).
            evict_end = min(evict_len, seq_len - self.local_tokens)

            if evict_end > evict_start:
                self._store_to_memory(
                    layer_idx,
                    key_states[:, :, evict_start:evict_end, :],
                    value_states[:, :, evict_start:evict_end, :],
                )

                # Keep [0:sink_tokens] + [evict_end:].
                key_states = torch.cat([
                    key_states[:, :, :evict_start, :],
                    key_states[:, :, evict_end:, :],
                ], dim=2)
                value_states = torch.cat([
                    value_states[:, :, :evict_start, :],
                    value_states[:, :, evict_end:, :],
                ], dim=2)

                self.total_evictions += 1

        # Step 2: retrieve relevant memories.
        retrieved_keys, retrieved_values = self._retrieve_from_memory(
            layer_idx, query_states
        )

        # Retrieved tensors are [num_kv_heads, num_slots, head_dim]: the slot
        # count lives in dim 1 (dim 2 is head_dim and is never 0).
        if retrieved_keys is None or retrieved_keys.shape[1] == 0:
            return key_states, value_states, None

        # Step 3: broadcast over the batch and prepend to the KV states.
        retrieved_keys = retrieved_keys.unsqueeze(0).expand(bsz, -1, -1, -1)
        retrieved_values = retrieved_values.unsqueeze(0).expand(bsz, -1, -1, -1)
        retrieved_keys = retrieved_keys.to(key_states.device, dtype=key_states.dtype)
        retrieved_values = retrieved_values.to(value_states.device, dtype=value_states.dtype)

        key_states = torch.cat([retrieved_keys, key_states], dim=2)
        value_states = torch.cat([retrieved_values, value_states], dim=2)

        self.total_retrievals += 1

        # Additive mask for the memory columns: all zeros = fully visible.
        memory_mask = torch.zeros(
            bsz, 1, query_states.shape[2], retrieved_keys.shape[2],
            device=key_states.device, dtype=key_states.dtype,
        )
        return key_states, value_states, memory_mask

    def _store_to_memory(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
    ):
        """
        Write KV states to the memory bank in chunks of ``memory_chunk_size``.

        Only batch element 0 is stored. NOTE(review): with batch_size > 1 the
        other elements' evicted tokens are dropped without being stored.
        """
        keys = key_states[0]  # [num_kv_heads, seq_len, head_dim]
        values = value_states[0]

        seq_len = keys.shape[1]
        for start in range(0, seq_len, self.memory_chunk_size):
            end = min(start + self.memory_chunk_size, seq_len)
            self.memory_bank.add_memory(
                layer_idx=layer_idx,
                key_states=keys[:, start:end, :],
                value_states=values[:, start:end, :],
                pooling=self.memory_pooling,
            )

    def _retrieve_from_memory(
        self,
        layer_idx: int,
        query_states: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Retrieve the top-k memory slots for ``layer_idx``.

        The last query token of batch element 0 is used as the retrieval query.
        Returns ``(None, None)`` when the layer's memory is empty.
        """
        if self.memory_bank[layer_idx].is_empty:
            return None, None

        query = query_states[0, :, -1:, :]  # [num_heads, 1, head_dim]

        # GQA: keep the first query head of each group so there is one query
        # per KV head.
        num_kv_heads = self.memory_bank[layer_idx].num_heads
        if query.shape[0] != num_kv_heads:
            group_size = query.shape[0] // num_kv_heads
            query = query[::group_size, :, :]  # [num_kv_heads, 1, head_dim]

        return self.memory_bank.retrieve(
            layer_idx=layer_idx,
            query_state=query,
            top_k=self.memory_top_k,
        )

    def get_stats(self) -> dict:
        """Return eviction/retrieval counters and the memory bank size."""
        return {
            "total_evictions": self.total_evictions,
            "total_retrievals": self.total_retrievals,
            "total_memories": self.memory_bank.total_memories,
        }

    def reset_stats(self):
        """Reset the eviction/retrieval counters."""
        self.total_evictions = 0
        self.total_retrievals = 0


def create_memory_augmented_forward(
    original_forward: Callable,
    memory_wrapper: MemoryAugmentedAttentionWrapper,
    layer_idx: int,
):
    """
    Build a drop-in replacement for an attention module's ``forward``.

    The returned callable accepts the usual attention-forward arguments and
    forwards all of them, unchanged and by keyword, to ``original_forward``.
    ``memory_wrapper`` and ``layer_idx`` are accepted but not used by the
    returned function, so no memory is stored or retrieved through this path.

    Args:
        original_forward: the attention forward function to delegate to.
        memory_wrapper: memory wrapper for the layer (currently unused).
        layer_idx: index of the layer (currently unused).

    Returns:
        The forwarding function.
    """
    def memory_augmented_forward(
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ):
        # Pure pass-through: every argument goes to the original forward as-is.
        explicit_args = {
            "hidden_states": hidden_states,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_value": past_key_value,
            "output_attentions": output_attentions,
            "use_cache": use_cache,
            "cache_position": cache_position,
            "position_embeddings": position_embeddings,
        }
        return original_forward(**explicit_args, **kwargs)

    return memory_augmented_forward


class MemoryAugmentedQwen2VLAttention(nn.Module):
    """
    Qwen2-VL attention layer augmented with an explicit long-term KV memory.

    Reuses the projection layers of an existing Qwen2-VL attention module and
    adds three steps around scaled-dot-product attention:

    1. Once the KV cache is longer than ``max_cache_length``, whole chunks of
       the oldest non-sink tokens are written to ``memory_bank`` and dropped
       from this call's keys/values (the newest ``local_tokens`` positions are
       never evicted). The attention mask is sliced the same way so its
       columns stay aligned with the kept keys.
    2. The ``memory_top_k`` memory slots most similar to the last query token
       are retrieved.
    3. The retrieved slots are prepended to the keys/values and made visible
       to every query.

    ``past_key_value`` itself is never trimmed, so the cache keeps growing
    across decoding steps. To avoid writing the same tokens to memory on every
    step, the layer tracks how far into the current sequence it has already
    stored (``_stored_upto``) and only ever evicts tokens that are in memory.

    NOTE(review): only batch element 0 is written to memory and used as the
    retrieval query, so all batch elements share one memory.
    """

    def __init__(
        self,
        original_attention: nn.Module,
        memory_bank: "LayerMemoryBank",
        layer_idx: int,
        max_cache_length: int = 2048,
        memory_top_k: int = 5,
        memory_chunk_size: int = 256,
        memory_pooling: str = 'mean',
        sink_tokens: int = 4,
        local_tokens: int = 256,
        enable_memory: bool = True,
    ):
        """
        Args:
            original_attention: Qwen2-VL attention module whose projections,
                rotary embedding and config are reused.
            memory_bank: memory bank indexed by layer index.
            layer_idx: index of this decoder layer.
            max_cache_length: KV length above which eviction is triggered.
            memory_top_k: number of memory slots retrieved per call.
            memory_chunk_size: number of tokens pooled into one memory slot.
            memory_pooling: pooling mode forwarded to ``memory_bank.add_memory``.
            sink_tokens: number of leading tokens that are never evicted.
            local_tokens: number of most recent tokens that are never evicted.
            enable_memory: when False the layer behaves like plain attention.

        Raises:
            ValueError: if ``memory_chunk_size`` is not positive.
        """
        super().__init__()
        if memory_chunk_size <= 0:
            raise ValueError(
                f"memory_chunk_size must be a positive integer, got {memory_chunk_size}"
            )

        # The wrapped attention module.
        self.original_attention = original_attention
        self.layer_idx = layer_idx

        # Memory configuration.
        self.memory_bank = memory_bank
        self.max_cache_length = max_cache_length
        self.memory_top_k = memory_top_k
        self.memory_chunk_size = memory_chunk_size
        self.memory_pooling = memory_pooling
        self.sink_tokens = sink_tokens
        self.local_tokens = local_tokens
        self.enable_memory = enable_memory

        # Cache position up to which tokens of the current sequence have been
        # written to memory, and the KV length seen on the previous call.
        self._stored_upto = 0
        self._last_seq_len = 0

        # Configuration copied from the wrapped module.
        self.num_heads = original_attention.num_heads
        self.num_kv_heads = original_attention.num_key_value_heads
        self.head_dim = original_attention.head_dim
        self.hidden_size = original_attention.hidden_size
        self.num_key_value_groups = original_attention.num_key_value_groups

        # Projection layers are shared (not copied) with the wrapped module.
        self.q_proj = original_attention.q_proj
        self.k_proj = original_attention.k_proj
        self.v_proj = original_attention.v_proj
        self.o_proj = original_attention.o_proj
        self.rotary_emb = original_attention.rotary_emb

        # Remaining attributes used by forward().
        self.config = original_attention.config
        self.attention_dropout = original_attention.attention_dropout
        self.rope_scaling = original_attention.rope_scaling

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        custom_kv = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Memory-augmented attention forward pass.

        Final key layout is ``[custom_kv prefix, retrieved memories, sink
        tokens, kept tokens]``; the attention mask is built in the same order.
        Attention weights are never returned (second output is always None).
        """
        from transformers.models.qwen2_vl.modeling_qwen2_vl import (
            apply_multimodal_rotary_pos_emb,
            repeat_kv,
        )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Rotary embeddings; fall back to computing them here when the caller
        # does not pass precomputed ones (as upstream Qwen2-VL attention does).
        if position_embeddings is None:
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        kv_len = key_states.shape[-2]
        if attention_mask is not None:
            # Align mask columns with the (un-evicted) KV positions.
            attention_mask = attention_mask[:, :, :, :kv_len]

        def _ensure_mask(mask):
            # Once the key layout is modified we need an explicit mask; build
            # the causal mask for the original [q_len, kv_len] layout.
            if mask is not None:
                return mask
            return self._build_causal_mask(
                bsz, q_len, kv_len, cache_position,
                query_states.device, query_states.dtype,
            )

        # ========== Memory augmentation ==========
        if self.enable_memory:
            self._sync_sequence_state(kv_len, q_len)

            key_states, value_states, evicted = self._evict_to_memory(
                key_states, value_states
            )
            if evicted is not None:
                attention_mask = self._drop_mask_columns(
                    _ensure_mask(attention_mask), *evicted
                )

            key_states, value_states, memory_len = self._prepend_retrieved(
                key_states, value_states, query_states
            )
            if memory_len > 0:
                attention_mask = _ensure_mask(attention_mask)
                attention_mask = torch.cat(
                    [self._visible_mask_like(attention_mask, memory_len), attention_mask],
                    dim=-1,
                )
        # =========================================

        # Optional fixed prefix KV (e.g. a precomputed context).
        if custom_kv is not None:
            custom_k = custom_kv.key_cache[self.layer_idx].expand(bsz, -1, -1, -1)
            custom_v = custom_kv.value_cache[self.layer_idx].expand(bsz, -1, -1, -1)
            key_states = torch.cat([custom_k, key_states], dim=-2).contiguous()
            value_states = torch.cat([custom_v, value_states], dim=-2).contiguous()
            attention_mask = _ensure_mask(attention_mask)
            attention_mask = torch.cat(
                [self._visible_mask_like(attention_mask, custom_k.shape[-2]), attention_mask],
                dim=-1,
            )

        # Repeat KV heads for grouped-query attention.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]

        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        is_causal = causal_mask is None and q_len > 1

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

    def _process_memory(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        query_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Store, retrieve and concatenate memories (mask-agnostic entry point).

        Args:
            key_states: [batch_size, num_kv_heads, seq_len, head_dim]
            value_states: [batch_size, num_kv_heads, seq_len, head_dim]
            query_states: [batch_size, num_heads, query_len, head_dim]

        Returns:
            ``(key_states, value_states, memory_mask)`` where ``memory_mask``
            is an all-zero ``[batch_size, 1, query_len, num_memories]`` tensor,
            or ``None`` when nothing was retrieved. Callers must slice their own
            attention mask for evicted positions; ``forward`` does this itself.
        """
        self._sync_sequence_state(key_states.shape[2], query_states.shape[2])
        key_states, value_states, _ = self._evict_to_memory(key_states, value_states)
        key_states, value_states, memory_len = self._prepend_retrieved(
            key_states, value_states, query_states
        )

        memory_mask = None
        if memory_len > 0:
            memory_mask = torch.zeros(
                key_states.shape[0], 1, query_states.shape[2], memory_len,
                device=key_states.device, dtype=key_states.dtype,
            )
        return key_states, value_states, memory_mask

    def _sync_sequence_state(self, seq_len: int, query_len: int):
        """Reset the stored-position mark when a new sequence starts.

        A new sequence is detected when the whole KV length was produced by
        this call (empty cache before it) or when the KV length shrank.
        """
        if seq_len <= query_len or seq_len < self._last_seq_len:
            self._stored_upto = 0
        self._last_seq_len = seq_len

    def _evict_to_memory(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[int, int]]]:
        """
        Store whole chunks of old tokens in memory and drop stored tokens.

        Returns the trimmed keys/values and the evicted column range
        ``(start, end)``, or ``None`` when nothing was dropped.
        """
        seq_len = key_states.shape[2]
        if seq_len <= self.max_cache_length:
            return key_states, value_states, None

        evict_len = seq_len - self.max_cache_length + self.memory_chunk_size
        if evict_len <= self.sink_tokens:
            return key_states, value_states, None

        # Desired end of the eviction range, protecting the newest tokens.
        target_end = min(self.sink_tokens + evict_len, seq_len - self.local_tokens)

        # Store only tokens not written yet, in whole chunks.
        store_start = max(self.sink_tokens, self._stored_upto)
        num_new = (target_end - store_start) // self.memory_chunk_size * self.memory_chunk_size
        if num_new > 0:
            store_end = store_start + num_new
            self._store_to_memory(
                key_states[:, :, store_start:store_end, :],
                value_states[:, :, store_start:store_end, :],
            )
            self._stored_upto = store_end

        # Only drop tokens that are actually in memory.
        evict_end = min(self._stored_upto, seq_len)
        if evict_end <= self.sink_tokens:
            return key_states, value_states, None

        key_states = torch.cat([
            key_states[:, :, :self.sink_tokens, :],
            key_states[:, :, evict_end:, :],
        ], dim=2)
        value_states = torch.cat([
            value_states[:, :, :self.sink_tokens, :],
            value_states[:, :, evict_end:, :],
        ], dim=2)
        return key_states, value_states, (self.sink_tokens, evict_end)

    def _prepend_retrieved(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        query_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Prepend retrieved memory slots to the KV states.

        Returns the new keys/values and the number of prepended slots (0 when
        nothing was retrieved).
        """
        retrieved_keys, retrieved_values = self._retrieve_from_memory(query_states)
        # Retrieved tensors are [num_kv_heads, num_slots, head_dim].
        if retrieved_keys is None or retrieved_keys.shape[1] == 0:
            return key_states, value_states, 0

        bsz = key_states.shape[0]
        retrieved_keys = retrieved_keys.unsqueeze(0).expand(bsz, -1, -1, -1)
        retrieved_values = retrieved_values.unsqueeze(0).expand(bsz, -1, -1, -1)
        retrieved_keys = retrieved_keys.to(key_states.device, dtype=key_states.dtype)
        retrieved_values = retrieved_values.to(value_states.device, dtype=value_states.dtype)

        key_states = torch.cat([retrieved_keys, key_states], dim=2)
        value_states = torch.cat([retrieved_values, value_states], dim=2)
        return key_states, value_states, retrieved_keys.shape[2]

    @staticmethod
    def _build_causal_mask(
        bsz: int,
        q_len: int,
        kv_len: int,
        cache_position: Optional[torch.Tensor],
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Additive causal mask ``[bsz, 1, q_len, kv_len]`` (0 = visible).

        Query positions come from ``cache_position`` when it matches
        ``q_len``, otherwise the queries are assumed to be the last ``q_len``
        KV positions.
        """
        if cache_position is not None and cache_position.numel() == q_len:
            q_pos = cache_position.to(device).reshape(q_len, 1)
        else:
            q_pos = torch.arange(kv_len - q_len, kv_len, device=device).reshape(q_len, 1)
        k_pos = torch.arange(kv_len, device=device).reshape(1, kv_len)
        mask = torch.zeros(q_len, kv_len, device=device, dtype=dtype)
        mask = mask.masked_fill(k_pos > q_pos, torch.finfo(dtype).min)
        return mask[None, None].expand(bsz, 1, q_len, kv_len)

    @staticmethod
    def _drop_mask_columns(mask: torch.Tensor, start: int, end: int) -> torch.Tensor:
        """Remove mask columns ``[start, end)`` to mirror KV eviction."""
        return torch.cat([mask[..., :start], mask[..., end:]], dim=-1)

    @staticmethod
    def _visible_mask_like(mask: torch.Tensor, length: int) -> torch.Tensor:
        """Fully-visible mask block of ``length`` columns matching ``mask``.

        Uses ``True`` for boolean masks and ``0`` for additive float masks.
        """
        shape = (mask.shape[0], mask.shape[1], mask.shape[2], length)
        if mask.dtype == torch.bool:
            return torch.ones(shape, dtype=torch.bool, device=mask.device)
        return torch.zeros(shape, dtype=mask.dtype, device=mask.device)

    def _store_to_memory(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
    ):
        """Write KV states (batch element 0 only) to memory in chunks."""
        keys = key_states[0]  # [num_kv_heads, seq_len, head_dim]
        values = value_states[0]

        seq_len = keys.shape[1]
        for start in range(0, seq_len, self.memory_chunk_size):
            end = min(start + self.memory_chunk_size, seq_len)
            self.memory_bank.add_memory(
                layer_idx=self.layer_idx,
                key_states=keys[:, start:end, :],
                value_states=values[:, start:end, :],
                pooling=self.memory_pooling,
            )

    def _retrieve_from_memory(
        self,
        query_states: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Retrieve top-k slots using the last query token of batch element 0."""
        if self.memory_bank[self.layer_idx].is_empty:
            return None, None

        query = query_states[0, :, -1:, :]  # [num_heads, 1, head_dim]

        # GQA: keep the first query head of each group (one per KV head).
        num_kv_heads = self.memory_bank[self.layer_idx].num_heads
        if query.shape[0] != num_kv_heads:
            group_size = query.shape[0] // num_kv_heads
            query = query[::group_size, :, :]

        return self.memory_bank.retrieve(
            layer_idx=self.layer_idx,
            query_state=query,
            top_k=self.memory_top_k,
        )


def patch_model_with_memory(
    model,
    memory_bank: LayerMemoryBank,
    max_cache_length: int = 2048,
    memory_top_k: int = 5,
    memory_chunk_size: int = 256,
    memory_pooling: str = 'mean',
    sink_tokens: int = 4,
    local_tokens: int = 256,
):
    """
    Replace every decoder layer's ``self_attn`` with a memory-augmented one.

    Each layer in ``model.model.layers`` gets a
    ``MemoryAugmentedQwen2VLAttention`` that wraps its current attention
    module. All layers share the same ``memory_bank`` and memory settings. The
    model is modified in place.

    Args:
        model: Qwen2-VL conditional-generation model (expects ``model.model.layers``).
        memory_bank: memory bank shared by all layers.
        max_cache_length, memory_top_k, memory_chunk_size, memory_pooling,
        sink_tokens, local_tokens: memory settings passed to every layer.

    Returns:
        The same ``model`` object, after patching.
    """
    shared_settings = dict(
        memory_bank=memory_bank,
        max_cache_length=max_cache_length,
        memory_top_k=memory_top_k,
        memory_chunk_size=memory_chunk_size,
        memory_pooling=memory_pooling,
        sink_tokens=sink_tokens,
        local_tokens=local_tokens,
    )

    decoder_layers = model.model.layers
    for idx, decoder_layer in enumerate(decoder_layers):
        # The current attention module is wrapped, then swapped in place.
        decoder_layer.self_attn = MemoryAugmentedQwen2VLAttention(
            original_attention=decoder_layer.self_attn,
            layer_idx=idx,
            **shared_settings,
        )

    return model
