from typing import Optional
import torch
import torch.nn.functional as F
from diffusers.models.attention import Attention
from einops import rearrange
from torch import nn
import math
from torch.nn.attention import SDPBackend, sdpa_kernel
import qdiff.s2quant.globalvar as globalvar


class EnhanceCogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.

    In addition to the regular attention output, every call computes a per-key importance map over the non-text
    (video) tokens and records it through ``globalvar.add_attn_map``.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def scaled_dot_product_attention_to_list(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        *,
        chunk_size: int = 2048,
        min_value: float = 0.5,
        max_value: float = 1.0,
    ) -> None:
        """Compute a per-key attention importance map and append it to the global attention-map list.

        For every key position, the softmax attention probability it receives is summed over all heads and all
        query positions. The resulting ``(batch, key_len)`` scores are min-max normalized per batch element and
        linearly mapped into ``[min_value, max_value]``. They are then stored with a trailing singleton dimension,
        ``(batch, key_len, 1)``, via ``globalvar.add_attn_map``. If a row has zero range (e.g. a single key or
        perfectly uniform attention), every entry of that row maps to ``min_value``.

        Args:
            query: Tensor of shape ``(batch, heads, query_len, head_dim)``.
            key: Tensor of shape ``(batch, heads, key_len, head_dim)``.
            chunk_size: Number of query rows processed at once. Bounds the temporary
                ``(batch, heads, chunk_size, key_len)`` score tensor; lower it if memory is tight.
            min_value: Value assigned to the least-attended key.
            max_value: Value assigned to the most-attended key.

        Raises:
            ValueError: If ``chunk_size < 1`` or ``min_value > max_value``.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        if min_value > max_value:
            raise ValueError(f"min_value ({min_value}) must not exceed max_value ({max_value})")

        scale_factor = 1 / math.sqrt(query.size(-1))
        batch_size, _, query_len, _ = query.shape
        key_len = key.size(-2)
        key_t = key.transpose(-2, -1)  # loop-invariant view

        # Accumulate in at least float32: a single key can receive up to heads * query_len probability mass,
        # which can overflow float16 and loses precision in bfloat16.
        acc_dtype = torch.promote_types(query.dtype, torch.float32)
        attn_weight = torch.zeros(batch_size, key_len, device=query.device, dtype=acc_dtype)

        # Process queries in chunks so the full (query_len x key_len) score matrix is never materialized.
        for start_idx in range(0, query_len, chunk_size):
            end_idx = min(start_idx + chunk_size, query_len)
            # (batch, heads, chunk, key_len) scaled scores for this block of queries.
            chunk_attn = query[:, :, start_idx:end_idx] @ key_t
            chunk_attn.mul_(scale_factor)
            chunk_attn = torch.softmax(chunk_attn, dim=-1)
            # Sum over heads and query rows -> (batch, key_len): total attention received by each key.
            attn_weight += chunk_attn.sum(dim=(1, 2), dtype=acc_dtype)
            # Drop the chunk before the next iteration allocates its replacement, lowering peak memory.
            del chunk_attn

        # Return cached blocks to the driver once, rather than syncing on every chunk.
        torch.cuda.empty_cache()

        # Per-batch min-max normalization to [0, 1].
        tensor_min = attn_weight.amin(dim=-1, keepdim=True)
        tensor_max = attn_weight.amax(dim=-1, keepdim=True)
        # Clamp the range away from zero so a constant row yields 0 instead of NaN (0 / 0).
        value_range = (tensor_max - tensor_min).clamp_min_(torch.finfo(attn_weight.dtype).tiny)
        attn_weight.sub_(tensor_min).div_(value_range)

        # Linearly map [0, 1] -> [min_value, max_value]; the clamp absorbs floating-point round-off.
        attn_weight.mul_(max_value - min_value).add_(min_value).clamp_(min_value, max_value)

        globalvar.add_attn_map(attn_weight.unsqueeze(-1))

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Run joint text+video attention and record an attention map for the non-text tokens.

        Args:
            attn: Attention module supplying the q/k/v/out projections, optional q/k norms and the head count.
            hidden_states: Tensor of shape ``(batch, seq, dim)``, concatenated after the text tokens
                (presumably the video tokens — confirm against the caller).
            encoder_hidden_states: Text tokens of shape ``(batch, text_seq, dim)``. Required.
            attention_mask: Optional mask, prepared via ``attn.prepare_attention_mask``.
            image_rotary_emb: Optional rotary embedding, applied to the non-text part of the query (and of the
                key, unless ``attn.is_cross_attention``).

        Returns:
            ``(hidden_states, encoder_hidden_states)``: the attention output split back into its non-text and
            text parts, in that order.
        """
        text_seq_length = encoder_hidden_states.size(1)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # encoder_hidden_states is required (its size is read above), so its shape is used directly.
        batch_size, sequence_length, _ = encoder_hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE only to the non-text tokens, which follow the text tokens in the sequence.
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        # Record the attention map restricted to non-text queries and keys.
        self.scaled_dot_product_attention_to_list(query[:, :, text_seq_length:], key[:, :, text_seq_length:])

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )

        return hidden_states, encoder_hidden_states
    