import torch
import argparse
import torch.nn as nn
import numpy as np
from fvcore.nn import FlopCountAnalysis, flop_count_table

# 确保导入了两个版本的Decoder
from nn_module.decoder_module import PointWiseDecoder2D, PointWiseDecoder2D_Adaptive

def get_args():
    """Parse command-line options for the NS2D decoder FLOP comparison.

    Returns:
        argparse.Namespace holding:
          * input-shape options: ``batch_size``, ``resolution``,
            ``out_seq_len`` (total rollout length) and ``out_step``
            (chunk predicted per rollout step);
          * architecture options: ``out_seq_emb_dim``, ``decoder_emb_dim``,
            ``out_channels``, ``propagator_depth``;
          * ``capacity_ratios``: required list of floats for the adaptive
            decoder;
          * ``gpu``: CUDA device index to use when CUDA is available
            (defaults to 3, the index previously hard-coded in the script).
    """
    parser = argparse.ArgumentParser("Decoder FLOPs Counter for NS2D Task")

    # --- Input tensor shape ---
    parser.add_argument('--batch_size', type=int, default=1, help="批处理大小")
    parser.add_argument('--resolution', type=int, default=64, help="空间分辨率 (例如 64x64)")
    parser.add_argument('--out_seq_len', type=int, default=20, help="Rollout 的总步数")
    parser.add_argument('--out_step', type=int, default=1, help="每次 rollout 预测的块大小")

    # --- Model architecture ---
    parser.add_argument('--out_seq_emb_dim', type=int, default=192, help="来自Encoder的隐状态z的维度")
    parser.add_argument('--decoder_emb_dim', type=int, default=384, help="Decoder的工作维度")
    parser.add_argument('--out_channels', type=int, default=1, help="输出通道数")
    parser.add_argument('--propagator_depth', type=int, default=10, help="Propagator的深度")

    # --- Adaptive-decoder specific ---
    parser.add_argument('--capacity_ratios', nargs='+', type=float, required=True, help="自适应Decoder的容量比例列表")

    # --- Runtime ---
    # Previously hard-coded as cuda:3; kept as the default for compatibility.
    parser.add_argument('--gpu', type=int, default=3, help="CUDA 可用时使用的设备编号")

    return parser.parse_args()

class GetEmbeddingWrapper(nn.Module):
    """Expose ``decoder.get_embedding`` as the ``forward`` of a standalone module.

    Only the bound ``get_embedding`` method is stored (not the decoder itself),
    so the decoder's sub-modules are not registered as children of this
    wrapper.
    """

    def __init__(self, decoder):
        super().__init__()
        embedding_fn = decoder.get_embedding
        self.get_embedding = embedding_fn

    def forward(self, z, pos1, pos2):
        # Pure delegation: arguments are forwarded in the same order.
        embedded = self.get_embedding(z, pos1, pos2)
        return embedded

# class PropagateWrapper(nn.Module):
#     def __init__(self, decoder):
#         super().__init__()
#         self.propagate = decoder.propagate

#     def forward(self, h, pos):
#         return self.propagate(h, pos)

# class DecodeWrapper(nn.Module):
#     def __init__(self, decoder):
#         super().__init__()
#         self.decode = decoder.decode

#     def forward(self, h):
#         return self.decode(h)
class PropagateWrapper(nn.Module):
    """Standalone module that replays the decoder's propagate loop for FLOP analysis.

    Each entry of ``decoder.propagator`` is unpacked as a ``(norm, ffn)`` pair.
    Every layer applies the residual update
    ``h = ffn(cat(norm(h), pos, dim=-1)) + h``.
    This re-implements the decoder's ``propagate`` logic rather than calling
    it, so it must be kept in sync with the real implementation.
    """

    def __init__(self, decoder):
        super().__init__()
        self.propagator = decoder.propagator

    def forward(self, z, pos):
        hidden = z
        for norm_layer, feed_forward in self.propagator:
            conditioned = torch.cat((norm_layer(hidden), pos), dim=-1)
            hidden = feed_forward(conditioned) + hidden
        return hidden
class DecodeWrapper(nn.Module):
    """Standalone module for analysing the decoder's decode step in isolation.

    It depends only on ``decoder.to_out``; per the original author's note,
    ``decode`` reduces to applying that module.
    """

    def __init__(self, decoder):
        super().__init__()
        output_head = decoder.to_out
        self.to_out = output_head

    def forward(self, z):
        decoded = self.to_out(z)
        return decoded


class _RolloutWrapper(nn.Module):
    """Trace-friendly adapter that runs ``decoder.rollout`` for FLOP counting.

    The integer step argument is stored on the module instead of being passed
    to ``forward``. fvcore traces the model with ``torch.jit``, whose tracer
    only accepts tensors (and tuples/lists/dicts of tensors) as inputs and
    rejects a bare ``int``.

    Only ``output[0]`` of ``rollout`` is returned. Per the original code, the
    remaining elements are diagnostics that the tracer cannot handle.
    """

    def __init__(self, decoder, steps):
        super().__init__()
        self.decoder = decoder
        self.steps = int(steps)

    def forward(self, z, pos, pos_in):
        output = self.decoder.rollout(z, pos, self.steps, pos_in)
        return output[0]


def _analyze_base_decoder(decoder, dummy_inputs, rollout_steps, device):
    """Return the rollout FLOPs of a ``PointWiseDecoder2D``, printing a breakdown.

    ``get_embedding`` is counted once, as the sum of its three sub-modules.
    ``propagate`` and ``decode`` are each counted for one step and then
    multiplied by ``rollout_steps``.
    """
    dummy_z, dummy_pos, dummy_h = dummy_inputs
    batch_size, num_points = dummy_pos.shape[0], dummy_pos.shape[1]
    half_channels = decoder.latent_channels // 2

    # a. get_embedding = coordinate_projection -> decoding_transformer -> expand_feat,
    #    analysed piecewise with dummy intermediates of the assumed shape
    #    (B, num_points, latent_channels // 2). TODO confirm against the decoder.
    flops_coord_proj = FlopCountAnalysis(decoder.coordinate_projection, (dummy_pos,)).total()

    dummy_x_after_coord_proj = torch.randn(batch_size, num_points, half_channels, device=device)
    flops_cross_attn = FlopCountAnalysis(
        decoder.decoding_transformer,
        (dummy_x_after_coord_proj, dummy_z, dummy_pos, dummy_pos),
    ).total()

    # The first decoding_transformer input (presumably the queries) is derived
    # from pos. The attention output is therefore sized by the query points
    # (dummy_pos), not by the latent length (dummy_z).
    dummy_after_cross_attn = torch.randn(batch_size, num_points, half_channels, device=device)
    flops_expand_feat = FlopCountAnalysis(decoder.expand_feat, (dummy_after_cross_attn,)).total()

    flops_get_emb = flops_coord_proj + flops_cross_attn + flops_expand_feat
    print(f"  > (合计) get_embedding FLOPs: {flops_get_emb / 1e9:.4f} GFLOPs")

    # b. propagate and decode, each for a single rollout step.
    flops_prop = FlopCountAnalysis(PropagateWrapper(decoder), (dummy_h, dummy_pos)).total()
    flops_decode = FlopCountAnalysis(DecodeWrapper(decoder), (dummy_h,)).total()

    print(f"  > 单次 propagate FLOPs:         {flops_prop / 1e9:.4f} GFLOPs")
    print(f"  > 单次 decode FLOPs:           {flops_decode / 1e9:.4f} GFLOPs")

    return flops_get_emb + (flops_prop + flops_decode) * rollout_steps


def _analyze_adaptive_decoder(decoder, dummy_inputs, rollout_len):
    """Return the full-rollout FLOPs of a ``PointWiseDecoder2D_Adaptive``.

    Also prints fvcore's per-module table. ``rollout_len`` is forwarded
    unchanged as the step argument of ``decoder.rollout``.
    """
    dummy_z, dummy_pos, _ = dummy_inputs
    flops_analyzer = FlopCountAnalysis(
        _RolloutWrapper(decoder, rollout_len),
        (dummy_z, dummy_pos, dummy_pos),
    )

    print("详细FLOPs分布:")
    print(flop_count_table(flops_analyzer))
    return flops_analyzer.total()


def analyze_decoder_flops(decoder, dummy_inputs, rollout_steps, description, device, out_seq_len=None):
    """Count and print the FLOPs of one complete decoder rollout.

    Args:
        decoder: a ``PointWiseDecoder2D`` or ``PointWiseDecoder2D_Adaptive``.
        dummy_inputs: tuple ``(z, pos, h)`` of dummy tensors. ``z`` is the
            encoder latent and ``pos`` the point coordinates. ``h`` is the
            hidden state fed to the propagate/decode breakdown (base decoder
            only).
        rollout_steps: number of propagate+decode iterations. It is used to
            scale the base decoder's per-step FLOPs and in the report line.
        description: label printed in the section header.
        device: device on which intermediate dummy tensors are created.
        out_seq_len: value passed as the step argument to the adaptive
            decoder's ``rollout``. Defaults to ``args.out_seq_len`` from the
            script's module-level ``args`` (the original behavior).

    Returns:
        Total FLOPs (as counted by fvcore) for the whole rollout.

    Raises:
        TypeError: if ``decoder`` is not one of the supported types.
        ValueError: if ``out_seq_len`` is omitted for an adaptive decoder and
            no module-level ``args`` exists (e.g. when imported as a library).
    """
    print(f"\n--- 正在分析: {description} ---")

    # Test the adaptive class first. If it subclasses PointWiseDecoder2D, the
    # base-class isinstance check would otherwise capture it too.
    if isinstance(decoder, PointWiseDecoder2D_Adaptive):
        if out_seq_len is None:
            try:
                out_seq_len = args.out_seq_len
            except NameError as err:
                raise ValueError(
                    "out_seq_len must be passed when no module-level `args` is defined"
                ) from err
        total_flops = _analyze_adaptive_decoder(decoder, dummy_inputs, out_seq_len)
    elif isinstance(decoder, PointWiseDecoder2D):
        total_flops = _analyze_base_decoder(decoder, dummy_inputs, rollout_steps, device)
    else:
        raise TypeError("不支持的Decoder类型")

    print(f"  > 完整 Rollout ({rollout_steps} 步) 的总 FLOPs: {total_flops / 1e9:.4f} GFLOPs")
    return total_flops
    

if __name__ == '__main__':
    args = get_args()

    # Select the CUDA device. Index 3 was the original hard-coded choice; a
    # `--gpu` option overrides it when get_args() defines one. Fall back to GPU 0
    # if the requested index does not exist, since set_device() would fail.
    if torch.cuda.is_available():
        gpu_index = getattr(args, 'gpu', 3)
        if not 0 <= gpu_index < torch.cuda.device_count():
            gpu_index = 0
        device = torch.device(f'cuda:{gpu_index}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')
    print(f"将在设备 {device} 上进行分析。")

    if len(args.capacity_ratios) != args.propagator_depth:
        raise ValueError("capacity_ratios 列表长度与 propagator_depth 不匹配!")
    if args.out_step <= 0:
        # Guards the integer division that computes rollout_steps below.
        raise ValueError("out_step 必须为正整数!")

    # --- 1. Build both decoders ---
    print("--- 正在构建模型 ---")
    decoder_base = PointWiseDecoder2D(
        latent_channels=args.decoder_emb_dim, out_channels=args.out_channels,
        out_steps=args.out_step, propagator_depth=args.propagator_depth,
    )

    decoder_adapt = PointWiseDecoder2D_Adaptive(
        latent_channels=args.decoder_emb_dim, out_channels=args.out_channels,
        out_steps=args.out_step, propagator_depth=args.propagator_depth,
        capacity_ratios=args.capacity_ratios, out_seq_emb_dim=args.out_seq_emb_dim
    )
    # eval(): measure the inference-time forward pass (e.g. dropout disabled).
    decoder_base.to(device).eval()
    decoder_adapt.to(device).eval()
    print("模型构建完毕。")

    # --- 2. Dummy inputs (created on the target device; randn tensors do not require grad) ---
    B = args.batch_size
    N = args.resolution * args.resolution  # points in a resolution x resolution grid

    dummy_z = torch.randn(B, N, args.out_seq_emb_dim, device=device)
    dummy_pos = torch.randn(B, N, 2, device=device)
    dummy_h = torch.randn(B, N, args.decoder_emb_dim, device=device)

    dummy_inputs = (dummy_z, dummy_pos, dummy_h)
    rollout_steps = args.out_seq_len // args.out_step

    # --- 3. Run the analysis ---
    total_flops_base = analyze_decoder_flops(decoder_base, dummy_inputs, rollout_steps, "基准解码器 (PointWiseDecoder2D)", device)
    total_flops_adapt = analyze_decoder_flops(decoder_adapt, dummy_inputs, rollout_steps, "自适应解码器 (PointWiseDecoder2D_Adaptive)", device)

    # --- 4. Summary report ---
    print("\n" + "="*60)
    print("                 Decoder FLOPs 对比总结")
    print("="*60)
    print(f"基准模型总 FLOPs:     {total_flops_base / 1e9:.4f} GFLOPs")
    print(f"自适应模型总 FLOPs: {total_flops_adapt / 1e9:.4f} GFLOPs")
    print("-" * 60)

    if total_flops_base > 0:
        savings_ratio = (total_flops_base - total_flops_adapt) / total_flops_base
        print(f"✅ 计算量节省比例: {savings_ratio:.2%}")
    print("="*60)