from diffusers.models.attention_processor import Attention
from diffusers.models.attention import AttentionModuleMixin
from diffusers import WanPipeline
from diffusers.models.transformers.transformer_wan import WanTransformerBlock
from SVPG.models.wan.attention_processor import WanAttnAdaptiveProcessor
from SVPG.models.wan.sparse_pipeline import WanAdaptiveSparsePipeline
from SVPG.svpg.attn_mask import create_adaptive_mask_state


def replace_wan_attention(
    pipe,
    height,
    width,
    num_frames,
    max_seq_length,
    model_type,
    dense_layers=0,
    dense_timesteps=0,
    warmup_steps=12,
    top_k=10,
    predict_T=10,
    thereshold=1.5e-4,
    use_cuda=True,
    backend="sparse_sageattn",  # "flashinfer" or "sparse_sageattn" or "none"
    block_size=128,
    sparse_type="adaptive",
    predict_all=False,
) -> None:
    """Replace the self-attention processors of a Wan pipeline with adaptive sparse ones.

    The function has three effects, all of them in place:

    * It configures ``WanAttnAdaptiveProcessor`` through *class* attributes, so
      the settings are shared by every processor instance in the process.
    * For every ``WanTransformerBlock`` found under ``pipe.transformer`` it
      installs a new ``WanAttnAdaptiveProcessor`` on ``attn1`` and attaches a
      ``warmup_state`` list holding two freshly created adaptive mask states.
    * It rebinds ``WanPipeline.__call__`` to
      ``WanAdaptiveSparsePipeline.__call__``. This patches the class, so it
      affects every ``WanPipeline`` instance, not only ``pipe``.

    Args:
        pipe: Wan pipeline instance. Must expose ``vae_scale_factor_temporal``,
            ``vae_scale_factor_spatial`` and ``transformer`` (with ``config``
            and ``blocks``).
        height: Video height in pixels.
        width: Video width in pixels.
        num_frames: Number of video frames requested (before temporal
            compression).
        max_seq_length: Maximum text sequence length; forwarded to the mask
            state as ``text_token_num``.
        model_type: Model identifier forwarded unchanged to
            ``create_adaptive_mask_state``.
        dense_layers: Number of layers that use dense attention (default 0,
            i.e. every layer is sparse).
        dense_timesteps: Number of timesteps that use dense attention
            (default 0).
        warmup_steps: Number of steps in the warm-up phase (default 12).
        top_k: Number of brightest lines to select (default 10).
        predict_T: Forwarded unchanged to ``create_adaptive_mask_state``.
        thereshold: Forwarded unchanged to ``create_adaptive_mask_state``
            (the misspelling is part of the existing keyword interface).
        use_cuda: Whether to use CUDA acceleration (default True).
        backend: Acceleration backend: ``"flashinfer"``, ``"sparse_sageattn"``
            or ``"none"`` (default ``"sparse_sageattn"``).
        block_size: Block size used by the backend (default 128).
        sparse_type: Sparsity type, ``"adaptive"`` or ``"dense"``.
        predict_all: Forwarded unchanged to ``create_adaptive_mask_state``.

    Returns:
        None. The warm-up states are reachable afterwards through
        ``block.attn1.processor.warmup_state`` on each transformer block.
    """
    transformer = pipe.transformer
    config = transformer.config

    # Token geometry: number of frames after temporal compression/patching and
    # number of tokens per frame after spatial compression/patching.
    num_frames = 1 + num_frames // (pipe.vae_scale_factor_temporal * config.patch_size[0])
    mod_value = pipe.vae_scale_factor_spatial * config.patch_size[1]
    print(f"Mod value (spatial scaling * patch size): {mod_value}")
    print(f"Number of frames after scaling: {num_frames}")
    frame_size = int(height // mod_value) * int(width // mod_value)

    # Total number of video tokens in the attention sequence.
    video_token_num = frame_size * num_frames

    # Heads per layer; falls back to 30 when the config does not declare it.
    num_heads = getattr(config, "num_attention_heads", 30)

    # Head count summed over all layers (only reported in the banner below).
    num_layers = len(transformer.blocks)
    total_heads = num_heads * num_layers

    print(f"\n{'='*60}")
    print("Initializing Adaptive Mask Generator for Wan")
    print(f"{'='*60}")
    print(f"Video dimensions: {height}x{width}, {num_frames} frames")
    print(f"Frame size (tokens per frame): {frame_size}")
    print(f"Total video tokens: {video_token_num}")
    print(f"Max text sequence length: {max_seq_length}")
    print(f"Number of heads per layer: {num_heads}")
    print(f"Number of layers: {num_layers}")
    print(f"Total heads: {total_heads}")
    print(f"Warmup steps: {warmup_steps}")
    print(f"Top-K lines: {top_k}")
    print(f"Dense layers: {dense_layers}")
    print(f"Dense timesteps: {dense_timesteps}")
    print(f"Sparse type: {sparse_type}")
    print(f"Backend: {backend}")
    print(f"Block size: {block_size}")
    print(f"{'='*60}\n")

    # Configure the processor class. These are class attributes, so they apply
    # to every WanAttnAdaptiveProcessor instance, including ones created by an
    # earlier call to this function.
    AttnModule = WanAttnAdaptiveProcessor
    AttnModule.dense_block = dense_layers
    AttnModule.dense_timestep = dense_timesteps
    AttnModule.current_step = 0
    AttnModule.warmup_steps = warmup_steps
    AttnModule.top_k = top_k
    AttnModule.use_cuda = use_cuda
    AttnModule.backend = backend
    AttnModule.block_size = block_size
    AttnModule.sparse_type = sparse_type

    def _new_mask_state():
        """Build one adaptive mask state from the shared geometry/settings."""
        return create_adaptive_mask_state(
            model_type=model_type,
            num_heads=num_heads,
            video_token_num=video_token_num,
            text_token_num=max_seq_length,
            tokens_per_frame=frame_size,
            warmup_steps=warmup_steps,
            top_k=top_k,
            predict_T=predict_T,
            thereshold=thereshold,
            block_size=block_size,
            predict_all=predict_all,
            use_block_attention=True,
            use_cuda=use_cuda,
        )

    # Each processor carries two independent mask states.
    # NOTE(review): presumably one per classifier-free-guidance branch
    # (conditional / unconditional) — confirm against WanAttnAdaptiveProcessor.
    states_per_processor = 2

    print("Replacing attention processors...")
    # Pass 1: tag the currently installed processor of each block with its
    # position in `transformer.blocks`, so pass 2 can recover the layer index
    # while walking the module tree.
    for layer_idx, block in enumerate(transformer.blocks):
        block.attn1.processor.layer_idx = layer_idx

    # Pass 2: install the adaptive processor on every WanTransformerBlock and
    # give it fresh warm-up states.
    replaced_count = 0
    for _, module in transformer.named_modules():
        if not isinstance(module, WanTransformerBlock):
            continue
        layer_idx = module.attn1.processor.layer_idx
        module.attn1.set_processor(AttnModule(layer_idx))
        module.attn1.processor.warmup_state = [
            _new_mask_state() for _ in range(states_per_processor)
        ]
        replaced_count += 1

    print(f"Replaced {replaced_count} attention processors\n")

    # Route every WanPipeline call through the sparse pipeline implementation.
    # This rebinds the method on the class, i.e. it is a process-wide patch.
    WanPipeline.__call__ = WanAdaptiveSparsePipeline.__call__
    