import torch
import torch.nn as nn
import numpy as np
import argparse
from fvcore.nn import FlopCountAnalysis, flop_count_str, parameter_count
from functools import partial

# Import the decoder models under analysis
from nn_module.decoder_module import PointWiseDecoder2DSimple, AdaptivePointWiseDecoder2D_SteadyState

# ==============================================================================
# 1. Fix warnings: add FLOP counters for custom operations to fvcore
# ==============================================================================
def add_handler(inputs, outputs):
    """Count ``aten::add`` as one FLOP per element of its first output.

    Args:
        inputs: Graph values consumed by the op (unused here).
        outputs: Graph values produced by the op; the first one is expected
            to expose ``.type().sizes()``.

    Returns:
        int: Product of the first output's sizes (1 for an empty size list),
        or 0 when no sizes can be obtained.
    """
    try:
        output_sizes = outputs[0].type().sizes()
    except (AttributeError, RuntimeError):
        # The value's type does not provide sizes; there is nothing to count.
        output_sizes = None
    if output_sizes is None:
        return 0
    # Force an integer product: np.prod([]) would otherwise yield float 1.0,
    # and a NumPy scalar would be returned instead of a plain int.
    return int(np.prod(output_sizes, dtype=np.int64))

def gelu_handler(inputs, outputs):
    """Count ``aten::gelu`` as one FLOP per element of its first input.

    Args:
        inputs: Graph values consumed by the op; the first one is expected to
            expose ``.type().sizes()``.
        outputs: Graph values produced by the op (unused here).

    Returns:
        int: Product of the first input's sizes (1 for an empty size list),
        or 0 when no sizes can be obtained.
    """
    try:
        input_sizes = inputs[0].type().sizes()
    except (AttributeError, RuntimeError):
        # The value's type does not provide sizes; there is nothing to count.
        input_sizes = None
    if input_sizes is None:
        return 0
    # Force an integer product: np.prod([]) would otherwise yield float 1.0,
    # and a NumPy scalar would be returned instead of a plain int.
    return int(np.prod(input_sizes, dtype=np.int64))

# Map each operator name to the handler that counts its FLOPs; main() passes
# this mapping to FlopCountAnalysis.set_op_handle(**custom_ops).
custom_ops = {}
custom_ops["aten::add"] = add_handler
custom_ops["aten::gelu"] = gelu_handler

# ==============================================================================
# 2. Wrapper modules (work around JIT tracer errors)
# ==============================================================================
class DensePropagatorWrapper(nn.Module):
    """Expose a sequence of propagator blocks as a single module for fvcore.

    The wrapped object is iterated in ``forward``; each block is called as
    ``block(h, pos_embedding)`` and its result is handed to the next block.
    """

    def __init__(self, propagator_blocks):
        super().__init__()
        self.propagator = propagator_blocks

    def forward(self, hidden_states, pos_embedding):
        """Return the hidden states after applying every block in order."""
        for layer in self.propagator:
            hidden_states = layer(hidden_states, pos_embedding)
        return hidden_states

class AdaptivePropagatorWrapper(nn.Module):
    """Expose an adaptive propagator to fvcore with a single return value.

    The wrapped module returns a pair; only its first element is forwarded.
    The second element (described by the original author as a diagnostics
    dict) is discarded so that fvcore never sees it.
    """

    def __init__(self, adaptive_propagator_module):
        super().__init__()
        self.propagator = adaptive_propagator_module

    def forward(self, hidden_states, pos_embedding, z_context):
        """Call the wrapped propagator and return the first item of its pair."""
        outputs = self.propagator(hidden_states, pos_embedding, z_context)
        # Two-element unpacking is intentional: any other arity raises.
        refined, _diagnostics = outputs
        return refined

# ==============================================================================
# 3. Main analysis function
# ==============================================================================
def main():
    """Compare propagator FLOPs of the dense and the adaptive decoder.

    Parses the command-line options, builds both decoders, extracts their
    propagator sub-modules, counts FLOPs with fvcore for one batch of random
    inputs, and prints the absolute and relative reduction.

    Raises:
        ValueError: If the dataset name has no point-count/resolution preset.
    """
    cli = argparse.ArgumentParser(
        description="FLOPs analysis for decoder propagator modules."
    )
    cli.add_argument("--dataset", type=str, required=True, choices=["pipe", "naca"])
    cli.add_argument("--propagator_depth", type=int, default=8)
    cli.add_argument("--latent_channels", type=int, default=256)
    cli.add_argument("--capacity_ratios", type=float, nargs="+", default=None)
    cli.add_argument("--final_keep_ratio", type=float, default=0.25)
    opt = cli.parse_args()

    # Per-dataset (number of points, resolution passed to the decoders).
    presets = {"pipe": (289, 17), "naca": (598, 23)}
    if opt.dataset not in presets:
        # argparse ``choices`` already rejects other names; kept as a safeguard.
        raise ValueError("Invalid dataset choice.")
    num_points, dummy_res = presets[opt.dataset]

    print(f"--- Analyzing FLOPs for [{opt.dataset.upper()}] dataset ---")
    print(f"Number of points (N): {num_points}")
    print(f"Propagator depth (D): {opt.propagator_depth}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    hidden = torch.randn(1, num_points, opt.latent_channels).to(device)
    positions = torch.randn(1, num_points, 2).to(device)
    z_context = hidden.detach()

    # Constructor arguments common to both decoders.
    shared_kwargs = dict(
        latent_channels=opt.latent_channels,
        out_channels=1,
        res=dummy_res,
        scale=0.5,
        heads=4,
        dim_head=32,
    )

    # --- Instantiate the decoders and isolate the modules to analyze ---
    # NOTE(review): nothing here calls .eval() on the modules before tracing;
    # confirm that counting FLOPs in their default mode is intended.
    dense_decoder = PointWiseDecoder2DSimple(
        refinement_depth=opt.propagator_depth, **shared_kwargs
    )
    dense_propagator = dense_decoder.refinement_blocks.to(device)

    # Without explicit ratios, interpolate linearly from 1.0 down to
    # final_keep_ratio with one entry per propagator layer.
    if opt.capacity_ratios is not None:
        capacity_ratios = opt.capacity_ratios
    else:
        capacity_ratios = np.linspace(
            1.0, opt.final_keep_ratio, opt.propagator_depth
        ).tolist()
    adaptive_decoder = AdaptivePointWiseDecoder2D_SteadyState(
        propagator_depth=opt.propagator_depth,
        capacity_ratios=capacity_ratios,
        **shared_kwargs,
    )
    adaptive_propagator = adaptive_decoder.propagator.to(device)

    def count_flops(module, example_inputs):
        """Return fvcore's total FLOP count with the custom handlers applied."""
        analysis = FlopCountAnalysis(module, example_inputs)
        analysis.set_op_handle(**custom_ops)
        return analysis.total()

    # --- Count FLOPs through the wrappers ---
    baseline_flops = count_flops(
        DensePropagatorWrapper(dense_propagator), (hidden, positions)
    )
    adaptive_flops = count_flops(
        AdaptivePropagatorWrapper(adaptive_propagator),
        (hidden, positions, z_context),
    )

    # --- Print the report ---
    rule = "=" * 50
    print("\n" + rule)
    print("           FLOPs Analysis Results")
    print(rule)
    print(f"  Baseline (Dense) Propagator FLOPs: {baseline_flops / 1e9:.4f} GFLOPs")
    print(f"  Adaptive Propagator FLOPs:         {adaptive_flops / 1e9:.4f} GFLOPs")
    print("-" * 50)

    reduction_abs = baseline_flops - adaptive_flops
    if baseline_flops > 0:
        reduction_ratio = (reduction_abs / baseline_flops) * 100
    else:
        # Avoid dividing by zero when the baseline reports no FLOPs.
        reduction_ratio = 0.0

    print(f"  Absolute FLOPs Reduction: {reduction_abs / 1e9:.4f} GFLOPs")
    print(f"  Relative FLOPs Reduction: {reduction_ratio:.2f}%")
    print(rule)

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()