import torch
from argparse import ArgumentParser
import pandas as pd

# --- 核心依赖 ---
# 请确保已经安装 fvcore: pip install fvcore
from fvcore.nn import FlopCountAnalysis, flop_count_str

# --- 导入需要分析的模型模块 ---
# 假设此脚本与 models 文件夹在同一目录下
from models.ipot.ipot_processor import IPOTProcessor
from models.ipot.ipot_processor_adapt import IPOTProcessorAdapt


def build_processor(args, processor_type: str):
    """Build and return only the processor module selected by ``processor_type``.

    Args:
        args: Namespace-like object providing ``self_per_cross_attn``,
            ``latent_channel``, ``self_heads_num``, ``self_heads_channel`` and
            ``ff_mult``. The ``'adaptive'`` type additionally reads
            ``reduction_schedule``.
        processor_type: ``'standard'`` builds an ``IPOTProcessor``;
            ``'adaptive'`` builds an ``IPOTProcessorAdapt``.

    Returns:
        The newly constructed processor instance.

    Raises:
        ValueError: If ``processor_type`` is neither ``'standard'`` nor
            ``'adaptive'``, or if the length of ``args.reduction_schedule``
            differs from ``args.self_per_cross_attn``.
    """
    if processor_type == 'standard':
        print("正在构建 Standard IPOT Processor...")
        return IPOTProcessor(
            self_per_cross_attn=args.self_per_cross_attn,
            latent_channel=args.latent_channel,
            self_heads_num=args.self_heads_num,
            self_heads_channel=args.self_heads_channel,
            ff_mult=args.ff_mult,
        )
    elif processor_type == 'adaptive':
        print("正在构建 Adaptive IPOT Processor...")
        # Raise explicitly instead of using ``assert``: assertions are removed
        # when Python runs with -O, which would let a mismatched schedule
        # through to the processor constructor unchecked.
        if len(args.reduction_schedule) != args.self_per_cross_attn:
            raise ValueError(
                f"Reduction schedule length ({len(args.reduction_schedule)}) must match the number of processor layers ({args.self_per_cross_attn})."
            )
        return IPOTProcessorAdapt(
            self_per_cross_attn=args.self_per_cross_attn,
            latent_channel=args.latent_channel,
            self_heads_num=args.self_heads_num,
            self_heads_channel=args.self_heads_channel,
            ff_mult=args.ff_mult,
            reduction_schedule=args.reduction_schedule
        )
    else:
        raise ValueError(f"未知的处理器类型: {processor_type}")


def profile_processor_flops(processor_model, input_tensor):
    """Measure the FLOPs that fvcore reports for ``processor_model``.

    Args:
        processor_model: Model handed to ``FlopCountAnalysis``.
        input_tensor: Supplied to the analysis as the model's single
            positional input (wrapped in a one-element tuple).

    Returns:
        A ``(gflops, detailed_report)`` tuple: the total FLOP count divided
        by 1e9, and the report string produced by ``flop_count_str``.
    """
    analysis = FlopCountAnalysis(processor_model, inputs=(input_tensor,))

    # Both warning categories are switched OFF here (note the ``False``
    # arguments): unsupported-operator and uncalled-module warnings are
    # silenced rather than reported.
    analysis.unsupported_ops_warnings(False)
    analysis.uncalled_modules_warnings(False)

    # Convert the raw FLOP count to GFLOPs.
    billions_of_flops = analysis.total() / 1e9

    # Textual breakdown generated from the same analysis object.
    breakdown_text = flop_count_str(analysis)

    return billions_of_flops, breakdown_text


def main(args):
    """Profile the standard and adaptive processors and print a comparison.

    Prints the experiment configuration, builds each processor with
    ``build_processor``, measures it with ``profile_processor_flops`` on one
    shared random input, then prints a summary table and the absolute and
    relative GFLOPs reduction. With ``args.verbose`` set, the detailed
    report of each processor is printed as well.

    Args:
        args: Parsed command-line namespace. This function reads
            ``num_latents``, ``latent_channel``, ``self_per_cross_attn``,
            ``reduction_schedule`` and ``verbose``; the namespace is also
            forwarded to ``build_processor``.
    """
    rule = "=" * 50

    def banner(title, indent, lead=""):
        # A title line framed by two rules; ``lead`` carries any blank
        # lines that should precede the upper rule.
        print(lead + rule)
        print(" " * indent + title)
        print(rule)

    banner("IPOT Processor FLOPs 分析", 10, lead="\n")
    print("实验配置:")
    print(f"  - Latent Tokens (num_latents): {args.num_latents}")
    print(f"  - Latent Channel (latent_channel): {args.latent_channel}")
    print(f"  - Processor Layers (self_per_cross_attn): {args.self_per_cross_attn}")
    print(f"  - Reduction Schedule: {args.reduction_schedule}")
    print(rule + "\n")

    # One shared dummy input of shape (1, num_latents, latent_channel);
    # the leading 1 is the batch size, so the figures are per sample.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"分析将在设备: {device} 上进行 (注意: FLOPs统计与设备无关)")
    dummy_input = torch.randn(1, args.num_latents, args.latent_channel).to(device)

    # Baseline first, adaptive second; both go through identical steps.
    stages = (
        ('standard', "正在分析 Standard Processor..."),
        ('adaptive', "正在分析 Adaptive Processor..."),
    )
    gflops_by_kind = {}
    report_by_kind = {}
    for kind, progress_message in stages:
        processor = build_processor(args, kind).to(device)
        processor.eval()
        print(progress_message)
        gflops_by_kind[kind], report_by_kind[kind] = profile_processor_flops(processor, dummy_input)
        print("分析完成。\n")

    standard_gflops = gflops_by_kind['standard']
    adaptive_gflops = gflops_by_kind['adaptive']

    # Savings of the adaptive variant relative to the baseline; the guard
    # avoids dividing by a zero baseline.
    reduction_gflops = standard_gflops - adaptive_gflops
    if standard_gflops > 0:
        reduction_percent = (reduction_gflops / standard_gflops) * 100
    else:
        reduction_percent = 0

    summary_df = pd.DataFrame({
        'Processor Type': ['Standard (Baseline)', 'Adaptive'],
        'Total GFLOPs': [f"{value:.4f}" for value in (standard_gflops, adaptive_gflops)],
    })

    banner("分析总结", 18)
    print(summary_df.to_string(index=False))
    print("-" * 50)
    print(f"计算量减少 (GFLOPs): {reduction_gflops:.4f}")
    print(f"计算量减少比例: {reduction_percent:.2f}%")
    print(rule)

    if not args.verbose:
        return

    detailed_sections = (
        ("详细报告: Standard Processor", 'standard'),
        ("详细报告: Adaptive Processor", 'adaptive'),
    )
    for title, kind in detailed_sections:
        banner(title, 10, lead="\n\n")
        print(report_by_kind[kind])


if __name__ == '__main__':
    # Script entry point: declare the CLI, parse it, and run the analysis.
    parser = ArgumentParser(description='精确分析 IPOT Processor 模块的理论计算量 (FLOPs)')

    # Integer-valued model architecture options as (flag, default, help).
    # They are registered in this order so the --help listing is unchanged.
    integer_options = (
        ('--num_latents', 256, '隐变量(tokens)的数量'),
        ('--latent_channel', 64, '隐变量的特征维度'),
        ('--self_per_cross_attn', 8, '处理器中的层数'),
        ('--self_heads_num', 4, '自注意力头的数量'),
        ('--self_heads_channel', None, '每个自注意力头的维度'),
        ('--ff_mult', 2, '前馈网络中间层的扩展倍数'),
    )
    for flag, default_value, help_text in integer_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)

    # One float per processor layer; the default has eight entries, matching
    # the default of --self_per_cross_attn.
    parser.add_argument(
        '--reduction_schedule',
        type=float,
        nargs='+',
        default=[1.0, 1.0, 1.0, 0.1, 0.1, 0.1, 0.1, 0.1],
        help='自适应处理器的 Token 保留比例策略。长度必须等于 --self_per_cross_attn',
    )

    # Opt-in flag for printing the detailed reports.
    parser.add_argument(
        '--verbose',
        action='store_true',
        help='如果设置，将打印每个子模块的详细FLOPs报告',
    )

    args = parser.parse_args()

    main(args)