#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Filename: benchmark_core.py
# Micro-benchmark script that measures and compares the inference performance
# (throughput and speedup) of the core multi-layer network of GNOT models.

import sys
import os
import time
import json
import numpy as np
import torch
from tqdm import tqdm

# Make sure the project directory is on sys.path so the local imports below resolve.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from args import get_args
from data_utils import get_dataset, get_model, MIODataLoader
from utils import get_seed, get_num_params

def warmup(model, loader, device):
    """Run a few untimed forward passes before benchmarking.

    Takes the first batch from ``loader``, moves it to ``device`` and calls
    ``model`` five times under ``torch.no_grad()``. On a CUDA device the
    function then waits for the queued GPU work to finish. The intent is that
    CUDA kernels are already compiled and loaded when timing starts.

    Warmup is best-effort: an empty loader or any other failure is reported on
    stdout and never propagated to the caller.

    Args:
        model: Callable invoked as ``model(g_batch, u_p_batch, inputs_f_batch)``.
        loader: Iterable yielding ``(g_batch, u_p_batch, inputs_f_batch)`` tuples.
        device: Target device (a ``torch.device`` or a device string).
    """
    print("Warming up GPU...")
    try:
        # A single batch is enough; it is reused for every warmup pass.
        g_batch, u_p_batch, inputs_f_batch = next(iter(loader))
        g_batch, u_p_batch = g_batch.to(device), u_p_batch.to(device)
        # The third element is only moved when it supports ``.to``.
        if hasattr(inputs_f_batch, 'to'):
            inputs_f_batch = inputs_f_batch.to(device)

        for _ in range(5):
            with torch.no_grad():
                _ = model(g_batch, u_p_batch, inputs_f_batch)

        # torch.cuda.synchronize raises for non-CUDA devices, which previously
        # made a successful CPU warmup get reported as an error.
        if torch.device(device).type == 'cuda':
            torch.cuda.synchronize(device=device)
        print("Warmup complete.")
    except StopIteration:
        print("Warning: DataLoader is empty, skipping warmup.")
    except Exception as e:
        # Deliberately broad: a failed warmup must not abort the benchmark.
        print(f"An error occurred during warmup: {e}")

def evaluate_core_module_speed(model, loader, device, num_runs):
    """
    Measure the throughput of the model's core block with CUDA events.

    A pair of CUDA timing events is passed to the model on every forward call
    through the ``timer_events`` keyword; the model's forward is expected to
    record them around its core computation. Per-batch elapsed times are
    summed over one full pass of ``loader`` to obtain a per-run throughput,
    and the per-run throughputs are averaged over ``num_runs`` passes.

    Returns the mean core throughput in samples per second, or 0.0 when no
    run produced a measurable core time.
    """
    per_run_throughputs = []

    with torch.no_grad():
        for run_idx in range(num_runs):
            # One pair of timing events per run, reused for every batch.
            tic = torch.cuda.Event(enable_timing=True)
            toc = torch.cuda.Event(enable_timing=True)
            event_pair = (tic, toc)

            samples_seen = 0
            core_ms_total = 0.0

            progress = tqdm(
                loader,
                desc=f"Core Speed Test (Run {run_idx+1}/{num_runs})",
                leave=False,
            )
            for g_batch, u_p_batch, inputs_f_batch in progress:
                # Move the batch to the target device.
                g_batch = g_batch.to(device)
                u_p_batch = u_p_batch.to(device)
                if hasattr(inputs_f_batch, 'to'):
                    inputs_f_batch = inputs_f_batch.to(device)

                # The events are handed to the model; its output is discarded.
                model(g_batch, u_p_batch, inputs_f_batch, timer_events=event_pair)

                # Wait for the GPU to finish before reading the elapsed time
                # between the two events (reported in milliseconds).
                torch.cuda.synchronize(device=device)
                core_ms_total += tic.elapsed_time(toc)

                samples_seen += g_batch.batch_size

            # Convert the accumulated milliseconds to seconds for this run.
            core_sec_total = core_ms_total / 1000.0

            # Skip runs with no measurable time to avoid dividing by ~zero.
            if core_sec_total > 1e-9:
                per_run_throughputs.append(samples_seen / core_sec_total)

    # Average over the runs that produced a measurement.
    if not per_run_throughputs:
        return 0.0
    return np.mean(per_run_throughputs)

def process_and_print_results(results, args):
    """Summarize raw throughputs, print the result table and return the summary.

    The baseline is the model named after the first entry of
    ``args.model_paths`` (file name without directory or extension). Each
    model's speedup is its throughput divided by the baseline throughput. If
    the baseline is missing from ``results`` or its throughput is effectively
    zero, 1.0 is used as the divisor instead.

    Args:
        results: Mapping ``model_name -> {'core_throughput_sps': number}``.
        args: Object with a non-empty ``model_paths`` sequence.

    Returns:
        Mapping ``model_name -> {'core_throughput_sps': float, 'speedup': float}``,
        or an empty dict when ``results`` is empty (nothing is printed then).
    """
    if not results:
        return {}

    base_model_name = os.path.splitext(os.path.basename(args.model_paths[0]))[0]

    # float() turns numpy scalars into plain floats so the summary can be
    # serialized with json.
    processed = {
        model_name: {'core_throughput_sps': float(data['core_throughput_sps'])}
        for model_name, data in results.items()
    }

    base_throughput = processed.get(base_model_name, {}).get('core_throughput_sps', 1.0)
    if base_throughput < 1e-9:
        # Guard against dividing by a zero / near-zero baseline.
        base_throughput = 1.0

    for entry in processed.values():
        entry['speedup'] = entry['core_throughput_sps'] / base_throughput

    header = f"{'Model':<60} | {'Core Throughput (samples/s)':<30} | {'Core Speedup':<15}"
    heavy_rule = "=" * len(header)
    print("\n" + heavy_rule)
    print(" " * (len(header)//2 - 17) + "CORE MODULE MICRO-BENCHMARK SUMMARY")
    print(heavy_rule)
    print(header)
    print("-" * len(header))

    for model_name, entry in processed.items():
        # Attach the "x" suffix before padding; padding the number first left
        # the suffix stranded at the far end of the column ("2.50           x").
        speedup_text = f"{entry['speedup']:.2f}x"
        print(f"{model_name:<60} | {entry['core_throughput_sps']:<30.2f} | {speedup_text:<15}")

    print(heavy_rule)
    return processed

def run_benchmark(args):
    """Run the full micro-benchmark: load data and models, warm up, time, report.

    Steps: seed the RNGs, pick the device, build the test loader, load every
    checkpoint in ``args.model_paths``, warm up with the first model, measure
    the core throughput of each model, print the summary table and optionally
    write it as JSON to ``args.save_path``.

    Args:
        args: Options namespace. Reads ``seed``, ``gpu`` (only when CUDA is
            used), ``model_paths``, ``benchmark_batch_size``,
            ``speed_test_runs`` and ``save_path``; ``no_cuda`` is optional.
            The namespace is also forwarded to ``get_dataset``.

    Returns:
        None. Returns early (after printing a message) when no model paths
        were supplied.

    NOTE(review): evaluate_core_module_speed times with CUDA events and calls
    torch.cuda.synchronize, so a run that falls back to the CPU device is
    expected to fail in the timing step - confirm whether CPU should be
    rejected up front.
    """
    get_seed(args.seed, printout=True)

    use_cuda = torch.cuda.is_available() and not getattr(args, 'no_cuda', False)
    device = torch.device(f'cuda:{args.gpu}' if use_cuda else "cpu")
    print(f"Using device: {device}")

    if not args.model_paths:
        print("错误: 未提供 --model_paths 参数。")
        return

    # --- 1. Load data ---
    print("\n--- 1. Loading Data ---")
    _, test_dataset = get_dataset(args)
    test_loader = MIODataLoader(test_dataset, batch_size=args.benchmark_batch_size, shuffle=False)
    print(f"Data loaded. Number of test samples: {len(test_dataset)}")

    # --- 2. Load models ---
    print("\n--- 2. Loading Models ---")
    models = {}
    for model_path in args.model_paths:
        # Models are keyed by checkpoint file name without directory/extension.
        model_name_from_path = os.path.splitext(os.path.basename(model_path))[0]
        print(f"Loading model '{model_name_from_path}' from {model_path}...")
        # NOTE(review): torch.load unpickles the file, and the checkpoint is
        # expected to carry a non-tensor 'args' entry; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(model_path, map_location=device)
        # The architecture is rebuilt from the args stored in the checkpoint,
        # not from the benchmark's own command-line args.
        model_args_from_ckpt = checkpoint['args']
        model = get_model(model_args_from_ckpt).to(device)
        model.load_state_dict(checkpoint['model'])
        model.eval()
        models[model_name_from_path] = model
        print(f" -> Model loaded. Params: {get_num_params(model)/1e6:.2f}M")

    if not models:
        return

    # --- 3. Warmup ---
    # Warming up with the first loaded model only.
    warmup(next(iter(models.values())), test_loader, device)

    # --- 4. Core micro-benchmark ---
    print("\n--- 4. Running Core Module Micro-Benchmark ---")
    results = {}
    for model_name, model in models.items():
        print(f"\nBenchmarking core module of: {model_name}")

        core_throughput = evaluate_core_module_speed(
            model, test_loader, device, args.speed_test_runs
        )

        results[model_name] = {
            'core_throughput_sps': core_throughput
        }

    # --- 5. Summarize and report ---
    print("\n--- 5. Benchmark Results ---")
    processed_results = process_and_print_results(results, args)

    if args.save_path:
        save_dir = os.path.dirname(args.save_path)
        if save_dir:
            # exist_ok avoids the check-then-create race of a separate
            # os.path.exists test.
            os.makedirs(save_dir, exist_ok=True)
        with open(args.save_path, 'w') as f:
            json.dump(processed_results, f, indent=4)
        print(f"\nDetailed results saved to {args.save_path}")
if __name__ == '__main__':
    args = get_args()

    # Back-fill benchmark-specific options that the parsed namespace does not
    # already carry; values already present on args are left untouched.
    fallback_options = (
        ('model_paths', []),
        ('save_path', None),
        ('speed_test_runs', 5),
        ('benchmark_batch_size', 16),
    )
    for option_name, fallback in fallback_options:
        if hasattr(args, option_name):
            continue
        setattr(args, option_name, fallback)

    run_benchmark(args)