"""
CogVideoX Inference with Adaptive Mask Generation

Initialize and configure adaptive sparse attention for CogVideoX
(uses the functional interface in SVPG.svpg.attn_mask).
"""

import torch
from diffusers import CogVideoXPipeline
from diffusers.models.attention_processor import Attention
from SVPG.models.cogvideox.attention_processor import CogVideoXAttnAdaptiveProcessor2_0
from SVPG.models.cogvideox.sparse_pipeline import CogVideoXAdaptiveSparsePipeline
from SVPG.svpg.attn_mask import create_adaptive_mask_state
from typing import Dict, List, Tuple, Optional

def reset_peak_gpu_stats(device: Optional[torch.device] = None):
    """
    Reset the CUDA peak-memory counters; no-op when CUDA is unavailable.

    Args:
        device: device whose counters are reset; ``None`` resets the current device.
    """
    if not torch.cuda.is_available():
        return
    target_index = None if device is None else device.index
    torch.cuda.reset_peak_memory_stats(target_index)

def print_peak_gpu_stats(prefix: str = "", device: Optional[torch.device] = None):
    """
    Print peak allocated / reserved CUDA memory (in GiB) since the last reset.

    When CUDA is unavailable, prints ``prefix + " GPU not available"`` instead.

    Args:
        prefix: text placed at the start of the printed line.
        device: CUDA device to query; ``None`` (or a device without an explicit
            index) queries the current device.
    """
    if not torch.cuda.is_available():
        print(prefix + " GPU not available")
        return
    # Wait for outstanding kernels on the device actually being queried,
    # not just on whatever the current device happens to be.
    torch.cuda.synchronize(device)
    index = None if device is None else device.index
    alloc = torch.cuda.max_memory_allocated(index)
    reserved = torch.cuda.max_memory_reserved(index)
    gib = 1024 ** 3
    print(f"{prefix} peak_alloc={alloc / gib:.3f} GiB, peak_reserved={reserved / gib:.3f} GiB")

def replace_cogvideox_attention(
    pipe,
    height,
    width,
    num_frames,
    max_seq_length,
    model_type,
    dense_layers=0,
    dense_timesteps=0,
    warmup_steps=12,
    top_k=10,
    predict_T=10,
    thereshold=1.5e-4,
    use_cuda=True,
    backend="sparse_sageattn",  # "flashinfer" or "sparse_sageattn" or "none"
    block_size=128,
    sparse_type="adaptive",
    predict_all=False,
):
    """
    Replace the attention processors of a CogVideoX pipeline with adaptive sparse attention.

    The pipeline is modified in place, and two process-wide side effects occur:
      * The configuration (dense layers/timesteps, warmup steps, top-k, CUDA flag,
        backend, block size, sparse type) is written as *class* attributes of
        ``CogVideoXAttnAdaptiveProcessor2_0``, so it is shared by every instance
        of that processor class.
      * ``CogVideoXPipeline.__call__`` is replaced by
        ``CogVideoXAdaptiveSparsePipeline.__call__`` for every CogVideoXPipeline.

    Each transformer block's ``attn1`` receives a new
    ``CogVideoXAttnAdaptiveProcessor2_0(layer_idx)``. Its ``warmup_state`` attribute
    is a list of two independently created, identically configured states from
    ``create_adaptive_mask_state``.

    Args:
        pipe: CogVideoX pipeline instance.
        height: video height in pixels.
        width: video width in pixels.
        num_frames: number of video frames in pixel space; converted to latent
            frames as ``1 + (num_frames - 1) // pipe.vae_scale_factor_temporal``.
        max_seq_length: number of text tokens (passed as ``text_token_num``).
        model_type: forwarded to ``create_adaptive_mask_state``.
        dense_layers: number of layers using dense attention (default 0: all sparse).
        dense_timesteps: number of timesteps using dense attention (default 0).
        warmup_steps: number of warm-up steps (default 12).
        top_k: number of brightest lines to select (default 10).
        predict_T: forwarded to ``create_adaptive_mask_state`` (default 10).
        thereshold: forwarded to ``create_adaptive_mask_state`` under the same
            (misspelled) keyword; kept for compatibility (default 1.5e-4).
        use_cuda: whether to use CUDA acceleration (default True).
        backend: "flashinfer", "sparse_sageattn" or "none" (default "sparse_sageattn").
        block_size: block size used by the backend (default 128).
        sparse_type: sparsity type, "adaptive" or "dense" (default "adaptive").
        predict_all: forwarded to ``create_adaptive_mask_state`` (default False).

    Returns:
        None. Per-layer mask state lives on each replaced processor's ``warmup_state``.
    """
    # Latent geometry: temporal VAE compression for frames; spatial VAE
    # compression times patch size for tokens per frame.
    latent_frames = 1 + (num_frames - 1) // pipe.vae_scale_factor_temporal
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size
    frame_size = int(height // mod_value) * int(width // mod_value)
    video_token_num = frame_size * latent_frames

    # Fall back to 30 heads when the config does not expose num_attention_heads.
    num_heads = getattr(pipe.transformer.config, "num_attention_heads", 30)
    num_layers = len(pipe.transformer.transformer_blocks)
    total_heads = num_heads * num_layers

    separator = "=" * 60
    print(f"\n{separator}")
    print("Initializing Adaptive Mask Generator for CogVideoX")
    print(separator)
    print(f"Video dimensions: {height}x{width}, {latent_frames} frames")
    print(f"Frame size (tokens per frame): {frame_size}")
    print(f"Total video tokens: {video_token_num}")
    print(f"Max text sequence length: {max_seq_length}")
    print(f"Number of heads per layer: {num_heads}")
    print(f"Number of layers: {num_layers}")
    print(f"Total heads: {total_heads}")
    print(f"Warmup steps: {warmup_steps}")
    print(f"Top-K lines: {top_k}")
    print(f"Dense layers: {dense_layers}")
    print(f"Dense timesteps: {dense_timesteps}")
    print(f"Sparse type: {sparse_type}")
    print(f"Backend: {backend}")
    print(f"Block size: {block_size}")
    print(f"{separator}\n")

    # Configuration is stored on the processor CLASS, i.e. shared by all instances.
    AttnModule = CogVideoXAttnAdaptiveProcessor2_0
    AttnModule.dense_block = dense_layers
    AttnModule.dense_timestep = dense_timesteps
    AttnModule.warmup_steps = warmup_steps
    AttnModule.top_k = top_k
    AttnModule.use_cuda = use_cuda
    AttnModule.backend = backend
    AttnModule.block_size = block_size
    AttnModule.sparse_type = sparse_type

    # Arguments shared by every create_adaptive_mask_state call below.
    state_kwargs = dict(
        model_type=model_type,
        num_heads=num_heads,
        video_token_num=video_token_num,
        text_token_num=max_seq_length,
        tokens_per_frame=frame_size,
        warmup_steps=warmup_steps,
        top_k=top_k,
        predict_T=predict_T,
        thereshold=thereshold,
        block_size=block_size,
        predict_all=predict_all,
        use_cuda=use_cuda,
    )

    print("Replacing attention processors...")
    # Tag each block's current attn1 processor with its layer index; the tag is
    # what marks a module as a replacement target in the next loop.
    for layer_idx, block in enumerate(pipe.transformer.transformer_blocks):
        if hasattr(block, "attn1") and hasattr(block.attn1, "processor"):
            block.attn1.processor.layer_idx = layer_idx

    replaced_count = 0
    for _, module in pipe.transformer.named_modules():
        if isinstance(module, Attention) and hasattr(module.processor, "layer_idx"):
            module.set_processor(AttnModule(module.processor.layer_idx))
            # Two independent, identically configured states per layer; how the
            # processor selects between them is defined in the processor class.
            module.processor.warmup_state = [
                create_adaptive_mask_state(**state_kwargs) for _ in range(2)
            ]
            replaced_count += 1
            # Reset then report: the printed "peak" reflects memory in use right
            # after this layer's states were created.
            reset_peak_gpu_stats()
            print_peak_gpu_stats(prefix="After setting layer_idx:", device=torch.device("cuda"))
    print(f"Replaced {replaced_count} attention processors\n")

    # Global monkeypatch: affects every CogVideoXPipeline instance in the process.
    CogVideoXPipeline.__call__ = CogVideoXAdaptiveSparsePipeline.__call__
    

def get_adaptive_mask_statistics(warmup_state: dict):
    """
    Summarize an adaptive-mask warm-up state.

    Args:
        warmup_state: state dictionary providing ``num_heads``, ``seq_len``,
            ``warmup_steps``, ``top_k``, ``block_configs`` and (read only when
            ``num_heads`` > 0) ``current_steps``, a mapping from head index to a
            step count.

    Returns:
        dict with ``num_heads``, ``map_size`` (the ``seq_len`` value),
        ``warmup_steps``, ``top_k``, ``num_blocks`` (0 when ``block_configs`` is
        falsy) and ``heads_completed_warmup`` (heads in ``range(num_heads)``
        whose recorded step count is >= ``warmup_steps``).
    """
    head_count = warmup_state['num_heads']
    seq_len = warmup_state['seq_len']
    warmup_steps = warmup_state['warmup_steps']
    top_k = warmup_state['top_k']
    block_configs = warmup_state['block_configs']
    block_count = len(block_configs) if block_configs else 0

    finished = 0
    for head in range(head_count):
        # Looked up lazily so a missing key only matters when there are heads.
        steps_by_head = warmup_state['current_steps']
        if head not in steps_by_head:
            continue
        if steps_by_head[head] >= warmup_steps:
            finished += 1

    return {
        'num_heads': head_count,
        'map_size': seq_len,
        'warmup_steps': warmup_steps,
        'top_k': top_k,
        'num_blocks': block_count,
        'heads_completed_warmup': finished,
    }


def print_adaptive_mask_statistics(warmup_state: dict):
    """
    Print a formatted summary of an adaptive-mask warm-up state.

    Args:
        warmup_state: state dictionary accepted by ``get_adaptive_mask_statistics``.
    """
    s = get_adaptive_mask_statistics(warmup_state)
    rule = '=' * 60
    report = [
        f"\n{rule}",
        "Adaptive Mask Generator Statistics",
        rule,
        f"Number of heads: {s['num_heads']}",
        f"Map size: {s['map_size']}x{s['map_size']}",
        f"Warmup steps: {s['warmup_steps']}",
        f"Top-K lines: {s['top_k']}",
        f"Number of blocks: {s['num_blocks']}",
        f"Heads completed warmup: {s['heads_completed_warmup']}/{s['num_heads']}",
        f"{rule}\n",
    ]
    # One print of newline-joined lines emits exactly the same bytes as one print per line.
    print("\n".join(report))


# 示例用法
if __name__ == "__main__":
    # Running this module directly only prints a usage example; nothing is executed.
    print("""
    示例用法:

    import torch
    from diffusers import CogVideoXPipeline
    from radial_attn.models.cogvideox1.inference import replace_cogvideox_attention

    # 加载pipeline
    pipe = CogVideoXPipeline.from_pretrained(
        "THUDM/CogVideoX-2b",
        torch_dtype=torch.float16
    )
    pipe = pipe.to("cuda")

    # 替换为自适应稀疏attention
    # (modifies pipe in place and returns None; per-layer mask states are
    #  stored on each replaced processor's `warmup_state` list)
    replace_cogvideox_attention(
        pipe=pipe,
        height=480,
        width=720,
        num_frames=49,
        max_seq_length=226,     # number of text tokens
        model_type=MODEL_TYPE,  # value expected by create_adaptive_mask_state
        dense_layers=0,
        dense_timesteps=0,
        warmup_steps=12,
        top_k=10,
        use_cuda=True,
        backend="flashinfer",  # or "sparse_sageattn"
        block_size=128,
        sparse_type="adaptive"
    )

    # 生成视频
    video = pipe(
        prompt="A beautiful sunset over the ocean",
        num_inference_steps=50,
        guidance_scale=6.0
    ).frames
    """)