import inspect
import math
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn


from diffusers.models.attention_processor import Attention
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import deprecate, is_torch_xla_available, logging
from diffusers.utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available
from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None

class CogVideoXXFormersAttnProcessor2_0:
    r"""
    Processor for implementing memory efficient attention using xFormers for the CogVideoX model.

    The text tokens (``encoder_hidden_states``) and video tokens (``hidden_states``) are concatenated
    into one joint sequence (text first). Rotary embeddings are applied to the video part of the query
    (and of the key, for self-attention). ``xformers.ops.memory_efficient_attention`` is run over the
    joint sequence, and the projected output is split back into its video and text parts.

    Args:
        attention_op (`Callable`, *optional*):
            Operator forwarded as ``op=`` to ``xformers.ops.memory_efficient_attention``. ``None`` lets
            xFormers pick an operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        if not is_xformers_available():
            raise ImportError("xformers is required for CogVideoXXFormersAttnProcessor2_0")
        self.attention_op = attention_op

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run joint text+video attention with xFormers.

        Args:
            attn: The ``Attention`` module. It supplies ``to_q``/``to_k``/``to_v``/``to_out``, the optional
                ``norm_q``/``norm_k``, ``heads``, ``scale``, ``is_cross_attention`` and
                ``prepare_attention_mask``.
            hidden_states: Video tokens, assumed (batch, video_seq, dim). They are placed after the text tokens.
            encoder_hidden_states: Text tokens, assumed (batch, text_seq, dim).
            attention_mask: Optional mask, normalized via ``attn.prepare_attention_mask`` and passed to
                xFormers as ``attn_bias``.
            image_rotary_emb: Optional rotary embedding applied only to positions after the text tokens.

        Returns:
            ``(hidden_states, encoder_hidden_states)``: the video part and the text part of the output,
            in that order.
        """
        text_seq_length = encoder_hidden_states.size(1)
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        batch_size, sequence_length, _ = hidden_states.shape

        if attention_mask is not None:
            # Prepare attention mask in the format expected by xformers.
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # xformers does not broadcast a singleton query dimension, so expand it explicitly to
            # (batch*heads, query_tokens, key_tokens). Here query_tokens == sequence_length
            # (presumably prepare_attention_mask returns (batch*heads, 1 or seq, key_tokens) — verify).
            attention_mask = attention_mask.expand(-1, sequence_length, -1)

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads*head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE to the video positions only; the first text_seq_length positions are text tokens.
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        # Fold heads into the batch dim for xformers' 3-D layout:
        # (batch, heads, seq, head_dim) -> (batch*heads, seq, head_dim)
        query = query.reshape(batch_size * attn.heads, -1, head_dim).contiguous()
        key = key.reshape(batch_size * attn.heads, -1, head_dim).contiguous()
        value = value.reshape(batch_size * attn.heads, -1, head_dim).contiguous()

        # memory_efficient_attention needs q, k and v to share one dtype. They can diverge, presumably
        # because norm_q/norm_k produce float32 under autocast while to_v yields bf16/fp16. Align q/k
        # with v instead of asserting a fixed dtype, so bf16, fp16 and fp32 runs all work.
        query = query.to(value.dtype)
        key = key.to(value.dtype)
        if attention_mask is not None:
            # A tensor attn_bias must match the query dtype as well.
            attention_mask = attention_mask.to(query.dtype)

        # Apply xformers attention
        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )

        # (batch*heads, seq, head_dim) -> (batch, seq, heads*head_dim)
        hidden_states = hidden_states.reshape(batch_size, attn.heads, -1, head_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Undo the initial concatenation: text tokens come first, video tokens after.
        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states


if __name__ == "__main__":
    # Demo: generate one video per prompt with CogVideoX-2b using the xFormers processor above.
    # Needs CUDA, xformers, and the checkpoint at ckpt/CogVideoX-2b.
    from diffusers import CogVideoXPipeline
    from diffusers.utils import export_to_video

    prompt = [
        "An elderly gentleman, with a serene expression, sits at the water's edge, a steaming cup of tea by his side. He is engrossed in his artwork, brush in hand, as he renders an oil painting on a canvas that's propped up against a small, weathered table. The sea breeze whispers through his silver hair, gently billowing his loose-fitting white shirt, while the salty air adds an intangible element to his masterpiece in progress. The scene is one of tranquility and inspiration, with the artist's canvas capturing the vibrant hues of the setting sun reflecting off the tranquil sea.",
        "A young woman, with a serene expression, sits at the water's edge, a steaming cup of tea by her side. She is engrossed in her artwork, brush in hand, as she renders an oil painting on a canvas that's propped up against a small, weathered table. The sea breeze whispers through her silver hair, gently billowing her loose-fitting white shirt, while the salty air adds an intangible element to her masterpiece in progress. The scene is one of tranquility and inspiration, with the artist's canvas capturing the vibrant hues of the setting sun reflecting off the tranquil sea.",
        "A young man, with a serene expression, sits at the water's edge, a steaming cup of tea by his side. He is engrossed in his artwork, brush in hand, as he renders an oil painting on a canvas that's propped up against a small, weathered table. The sea breeze whispers through his silver hair, gently billowing his loose-fitting white shirt, while the salty air adds an intangible element to his masterpiece in progress. The scene is one of tranquility and inspiration, with the artist's canvas capturing the vibrant hues of the setting sun reflecting off the tranquil sea.",
    ]

    pipe = CogVideoXPipeline.from_pretrained(
        "ckpt/CogVideoX-2b",
        torch_dtype=torch.bfloat16,
    )
    # Install the xFormers processor on the transformer's attention layers.
    pipe.transformer.set_attn_processor(CogVideoXXFormersAttnProcessor2_0())

    # Memory-saving options provided by diffusers.
    pipe.enable_model_cpu_offload()
    pipe.vae.enable_tiling()

    with torch.autocast("cuda", dtype=torch.bfloat16):
        video = pipe(
            prompt=prompt,
            num_videos_per_prompt=1,
            num_inference_steps=10,
            num_frames=16,
            guidance_scale=6,
            generator=torch.Generator(device="cuda").manual_seed(42),
        ).frames

    # One output file per generated video.
    for i, frames in enumerate(video):
        export_to_video(frames, f"output_xformers_{i}.mp4", fps=8)

    # Disabled parity check: compares this processor against diffusers' CogVideoXAttnProcessor2_0 on
    # random CUDA inputs. It is kept inside a string literal, so it does not run; move the code out of
    # the string to execute it.
    """
    import torch
    from diffusers.models.attention import Attention
    from diffusers.models.attention_processor import CogVideoXAttnProcessor2_0

    # Test configuration
    batch_size = 3
    hidden_dim = 512
    num_heads = 8
    head_dim = hidden_dim // num_heads

    # Fix: Adjust sequence lengths to match the expected concatenation
    text_seq_length = 4  # Length of encoder_hidden_states
    image_seq_length = 16  # Length of hidden_states
    total_seq_length = text_seq_length + image_seq_length  # Total length after concatenation

    # Create test inputs and move to CUDA
    hidden_states = torch.randn(batch_size, image_seq_length, hidden_dim).cuda()
    encoder_hidden_states = torch.randn(batch_size, text_seq_length, hidden_dim).cuda()
    # Create attention mask
    attention_mask = torch.ones(batch_size, total_seq_length, total_seq_length).cuda()
    attention_mask = torch.triu(attention_mask, diagonal=1)
    image_rotary_emb = None

    # Create attention module and move to CUDA
    attention = Attention(
        query_dim=hidden_dim,
        heads=num_heads,
        dim_head=head_dim,
        cross_attention_dim=hidden_dim,
    ).cuda()

    # Test both processors
    processor1 = CogVideoXAttnProcessor2_0()
    processor2 = CogVideoXXFormersAttnProcessor2_0()

    # Run both processors
    out1 = processor1(attention, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb)
    out2 = processor2(attention, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb)

    # Compare outputs
    print("Output shapes match:", out1[0].shape == out2[0].shape)
    print("Hidden states max difference:", torch.max(torch.abs(out1[0] - out2[0])).item())
    print("Encoder states max difference:", torch.max(torch.abs(out1[1] - out2[1])).item())
    """