#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Filename: benchmark.py
# (Final complete version - AMP removed; FP32 benchmarking only)

import sys
import os
import time
import json
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset

# Make sure the project directory is on sys.path.
# Assumes benchmark.py lives in the same directory as train_p.py, models/, data_utils_fix.py, etc., or in the project root.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Import all required project-local modules
from args import get_args
from data_utils_fix import get_segment_dataset, get_model, get_loss_func, collate_segment_batch
from utils import get_seed, get_num_params

def warmup(models, loader, device):
    """Warm up the device with a few untimed forward passes (no AMP).

    The first batch of ``loader`` is moved to ``device`` and every model in
    ``models`` is run on it 10 times under ``torch.no_grad()``. If the loader
    yields no batch, a warning is printed and nothing else happens.

    Args:
        models: Mapping whose values are callables invoked as
            ``model(graph, coords, u_p, inputs_f)``.
        loader: Iterable yielding 6-tuples
            ``(g, coords, u_p, inputs_f, features, targets)``; only the first
            batch is used and its last element is ignored.
        device: Target device (``torch.device`` or anything accepted by
            ``torch.device``).

    Raises:
        Exception: Any error raised while preparing the batch or running a
            model is reported on stdout and re-raised unchanged.
    """
    no_batch = object()  # sentinel: distinguishes "loader is empty" from a real batch
    try:
        batch = next(iter(loader), no_batch)
        if batch is no_batch:
            print("Warning: DataLoader is empty, skipping warmup.")
            return

        g_batch, coords_batch, u_p_batch, inputs_f_batch, features_batch, _ = batch
        # Move the batch to the target device
        g_batch = g_batch.to(device)
        coords_batch = coords_batch.to(device)
        u_p_batch = u_p_batch.to(device)
        inputs_f_batch = inputs_f_batch.to(device)
        features_batch = features_batch.to(device)

        # features_batch is unpacked as 4-D; index 0 of the second axis is used
        # as the initial state (presumably the time axis — TODO confirm).
        _, _, _, c_state = features_batch.shape
        initial_state = features_batch[:, 0, :, :]
        g_batch.ndata['x'] = initial_state.reshape(-1, c_state)

        print("Warming up GPU with a sample batch...")
        with torch.no_grad():
            for _ in tqdm(range(10), desc="Warming up", leave=False):
                for model in models.values():
                    # Fresh copies per call, presumably because the models may
                    # modify their inputs in place — TODO confirm.
                    model(g_batch.clone(), coords_batch.clone(), u_p_batch.clone(), inputs_f_batch)

        # torch.cuda.synchronize() fails for non-CUDA devices, so only sync when
        # running on a GPU, and sync the device that was actually used.
        if torch.device(device).type == 'cuda':
            torch.cuda.synchronize(device)
        print("GPU warmup complete.")
    except Exception as e:
        print(f"An error occurred during warmup: {e}")
        raise

def evaluate_rollout_accuracy(model, loader, metric_func, device, rollout_length):
    """Measure the per-step error of an autoregressive rollout (no AMP).

    For every batch, the rollout starts from ``features[:, 0]`` and feeds each
    prediction back in as the next state. After each step the third value
    returned by ``metric_func(graph, prediction, flat_target)`` is stored as
    that step's error.

    Args:
        model: Callable invoked as ``model(graph, coords, u_p, inputs_f)`` and
            returning a 2-tuple whose first element is the prediction.
        loader: Iterable of 6-tuples
            ``(g, coords, u_p, inputs_f, features, targets)``.
        metric_func: Callable returning a 3-tuple; its last element must be
            accepted by ``np.atleast_1d``.
        device: Device the batch and the error buffer are placed on.
        rollout_length: Number of autoregressive steps to evaluate.

    Returns:
        A NumPy array with one row per sample and ``rollout_length`` columns,
        or an empty array if the loader produced no batches.
    """
    per_batch_curves = []

    for batch in tqdm(loader, desc="Evaluating Accuracy", leave=False):
        graph, coords, u_p, inputs_f, features, targets = batch

        # Move the batch to the target device
        graph = graph.to(device)
        coords = coords.to(device, non_blocking=True)
        u_p = u_p.to(device, non_blocking=True)
        features = features.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        inputs_f = inputs_f.to(device)

        batch_size, _, num_nodes, num_channels = features.shape
        state = features[:, 0, :, :]
        errors = torch.zeros((batch_size, rollout_length), device=device)

        with torch.no_grad():
            for step in range(rollout_length):
                # The model reads the current state from the graph's node data
                graph.ndata['x'] = state.reshape(-1, num_channels)
                prediction, _ = model(graph, coords, u_p, inputs_f)

                flat_target = targets[:, step, :, :].reshape(-1, num_channels)
                _, _, step_metric = metric_func(graph, prediction, flat_target)
                errors[:, step] = torch.from_numpy(np.atleast_1d(step_metric)).to(device)

                # Feed the prediction back in, except after the final step
                if step < rollout_length - 1:
                    state = prediction.view(batch_size, num_nodes, num_channels)

        per_batch_curves.append(errors.cpu())

    if per_batch_curves:
        return torch.cat(per_batch_curves, dim=0).numpy()
    return np.array([])

def evaluate_inference_speed(model, loader, device, rollout_length, num_runs):
    """Measure the rollout inference throughput of ``model`` (no AMP).

    Every batch of ``loader`` is rolled out ``rollout_length`` steps under
    ``torch.no_grad()``, and the whole loader is traversed ``num_runs`` times.
    Only the rollout loop is timed; moving the batch to the device is not.

    Args:
        model: Callable invoked as ``model(graph, coords, u_p, inputs_f)`` and
            returning a 2-tuple whose first element is the prediction.
        loader: Iterable of 6-tuples
            ``(g, coords, u_p, inputs_f, features, targets)``; targets are
            ignored.
        device: Target device (``torch.device`` or anything accepted by
            ``torch.device``).
        rollout_length: Number of autoregressive steps per batch.
        num_runs: Number of passes over ``loader``.

    Returns:
        Mean over all timed batches of ``batch_size * rollout_length / seconds``
        as a float, or ``0.0`` if no batch was timed.
    """
    # torch.cuda.synchronize() rejects non-CUDA devices, so only synchronize
    # when the benchmark actually runs on a GPU.
    on_cuda = torch.device(device).type == 'cuda'
    total_throughputs = []

    with torch.no_grad():
        for run in range(num_runs):
            pbar = tqdm(loader, desc=f"Evaluating Speed (Run {run+1}/{num_runs})", leave=False)
            for g_batch, coords_batch, u_p_batch, inputs_f_batch, features_batch, _ in pbar:
                # Move the batch to the target device
                g_batch = g_batch.to(device)
                coords_batch = coords_batch.to(device, non_blocking=True)
                u_p_batch = u_p_batch.to(device, non_blocking=True)
                features_batch = features_batch.to(device, non_blocking=True)
                inputs_f_batch = inputs_f_batch.to(device)

                B, _, N, C_state = features_batch.shape
                current_state = features_batch[:, 0, :, :]

                # Wait for pending transfers/kernels so they are not counted
                if on_cuda:
                    torch.cuda.synchronize(device=device)
                start_time = time.perf_counter()

                for k in range(rollout_length):
                    g_batch.ndata['x'] = current_state.reshape(-1, C_state)
                    pred_out, _ = model(g_batch, coords_batch, u_p_batch, inputs_f_batch)
                    # Feed the prediction back in, except after the final step
                    if k < rollout_length - 1:
                        current_state = pred_out.view(B, N, C_state)

                # CUDA kernels run asynchronously; wait for them before stopping the clock
                if on_cuda:
                    torch.cuda.synchronize(device=device)
                end_time = time.perf_counter()

                elapsed_time = end_time - start_time
                total_frames = B * rollout_length
                # Skip degenerate timings to avoid dividing by (almost) zero
                if elapsed_time > 1e-9:
                    total_throughputs.append(total_frames / elapsed_time)

    if not total_throughputs:
        return 0.0
    return float(np.mean(total_throughputs))

def process_and_print_results(results, args):
    """Aggregate per-model results, print a summary table and return the aggregates.

    Args:
        results: Mapping ``model_name -> {'error_curves': ..., 'throughput_fps': ...}``
            where ``error_curves`` is a (possibly empty) sequence of per-sample
            error curves.
        args: Object with a ``model_paths`` sequence. The file stem of its
            first entry names the baseline model used for the speedup column.

    Returns:
        Mapping ``model_name -> dict`` with keys ``mean_accumulated_error``,
        ``mean_final_error``, ``throughput_fps``, ``mean_error_curve`` and
        ``speedup``. A model without error curves gets NaN errors and an empty
        curve. Returns ``{}`` when ``results`` is empty.
    """
    if not results:
        return {}

    base_model_name = os.path.splitext(os.path.basename(args.model_paths[0]))[0]

    processed = {}
    for model_name, data in results.items():
        error_curves = np.asarray(data['error_curves'])

        if error_curves.size == 0:
            # Keep mean_curve an ndarray so .tolist() below works for models
            # that produced no error curves.
            mean_curve = np.array([])
            mean_accumulated_error = mean_final_error = np.nan
        else:
            # Average over samples first, then summarise the resulting curve
            mean_curve = np.mean(error_curves, axis=0)
            mean_accumulated_error = np.mean(mean_curve)
            mean_final_error = mean_curve[-1]

        processed[model_name] = {
            'mean_accumulated_error': float(mean_accumulated_error),
            'mean_final_error': float(mean_final_error),
            'throughput_fps': float(data['throughput_fps']),
            'mean_error_curve': mean_curve.tolist()
        }

    # Speedup is relative to the baseline model; fall back to 1.0 when the
    # baseline is missing or its throughput is effectively zero.
    base_throughput = processed.get(base_model_name, {}).get('throughput_fps', 1.0)
    if base_throughput < 1e-9:
        base_throughput = 1.0

    for entry in processed.values():
        entry['speedup'] = entry['throughput_fps'] / base_throughput

    header = f"{'Model':<45} | {'Mean Accum. Error':<20} | {'Mean Final Error':<20} | {'Throughput (fps)':<20} | {'Speedup':<10}"
    print("\n" + "="*len(header))
    print(" " * (len(header)//2 - 10) + "BENCHMARK SUMMARY")
    print("="*len(header))
    print(header)
    print("-"*len(header))

    for model_name, data in processed.items():
        # Attach the 'x' to the number before padding, so the column reads
        # "1.00x" rather than "1.00      x".
        speedup_text = f"{data['speedup']:.2f}x"
        print(f"{model_name:<45} | {data['mean_accumulated_error']:<20.4f} | {data['mean_final_error']:<20.4f} | {data['throughput_fps']:<20.2f} | {speedup_text:<10}")

    print("="*len(header))
    return processed

def run_benchmark(args):
    """Run the whole benchmark: load data and models, warm up, evaluate, report.

    Args:
        args: Parsed arguments. Attributes read here: ``seed``, ``gpu``,
            ``model_paths``, ``rollout_length``, ``benchmark_batch_size``,
            ``num_workers``, ``speed_test_runs``, ``save_path`` and, if
            present, ``no_cuda`` and ``y_normalizer``. ``segment_length`` and
            ``normalizer`` are set on ``args`` as a side effect.

    Returns:
        None. Results are printed and, when ``args.save_path`` is set, written
        to that path as JSON.

    Raises:
        Exception: Any error raised while loading a checkpoint is reported on
            stdout and re-raised unchanged.
    """
    get_seed(args.seed, printout=True)

    use_cuda = torch.cuda.is_available() and not getattr(args, 'no_cuda', False)
    device = torch.device(f'cuda:{args.gpu}') if use_cuda else torch.device("cpu")
    if not args.model_paths:
        print("错误: 未提供 --model_paths 参数。请指定要测试的模型文件。")
        return

    # --- 1. Load data ---
    print("\n--- 1. Loading Benchmark Data ---")
    # Presumably one initial frame plus rollout_length target frames per
    # segment — confirm against get_segment_dataset.
    args.segment_length = args.rollout_length + 1
    _, test_dataset = get_segment_dataset(args)

    test_loader = DataLoader(
        test_dataset,
        batch_size=args.benchmark_batch_size,
        shuffle=False,
        collate_fn=collate_segment_batch,
        num_workers=args.num_workers
    )
    print(f"Data loaded. Number of test segments: {len(test_dataset)}")

    y_normalizer = getattr(args, 'y_normalizer', None)
    args.normalizer = y_normalizer.to(device) if y_normalizer is not None else None
    metric_func = get_loss_func(name='rel2', args=args, normalizer=args.normalizer)

    # --- 2. Load models ---
    print("\n--- 2. Loading Models ---")
    models = {}
    for model_path in args.model_paths:
        try:
            model_name = os.path.splitext(os.path.basename(model_path))[0]
            print(f"Loading model '{model_name}' from {model_path}...")
            # NOTE(review): the checkpoint stores a pickled 'args' object, so
            # torch.load unpickles arbitrary objects here. Only load trusted
            # checkpoints.
            checkpoint = torch.load(model_path, map_location=device)
            model_args_from_ckpt = checkpoint['args']

            # Build the model from the arguments saved with the checkpoint
            model = get_model(model_args_from_ckpt).to(device)
            model.load_state_dict(checkpoint['model'])
            model.eval()
            models[model_name] = model
            print(f" -> Model loaded successfully. Params: {get_num_params(model)/1e6:.2f}M")
        except Exception as e:
            print(f"Error loading model from {model_path}: {e}")
            raise

    if not models:
        print("No models were loaded. Exiting.")
        return

    # --- 3. Warm-up ---
    print("\n--- 3. Warming Up GPU ---")
    warmup(models, test_loader, device)

    # --- 4. Core benchmark ---
    print("\n--- 4. Running Benchmark (in FP32 mode) ---")
    results = {}
    for model_name, model in models.items():
        print(f"\nBenchmarking model: {model_name}")

        error_curves = evaluate_rollout_accuracy(
            model, test_loader, metric_func, device, args.rollout_length
        )

        avg_throughput = evaluate_inference_speed(
            model, test_loader, device, args.rollout_length, args.speed_test_runs
        )

        results[model_name] = {
            'error_curves': error_curves.tolist(),
            'throughput_fps': avg_throughput
        }

    # --- 5. Summarise and report ---
    print("\n--- 5. Benchmark Results ---")
    processed_results = process_and_print_results(results, args)

    if args.save_path:
        save_dir = os.path.dirname(args.save_path)
        if save_dir:
            # exist_ok avoids the check-then-create race of a separate exists() test
            os.makedirs(save_dir, exist_ok=True)
        with open(args.save_path, 'w', encoding='utf-8') as f:
            json.dump(processed_results, f, indent=4)
        print(f"\nDetailed results saved to {args.save_path}")

if __name__ == '__main__':
    # Parse the arguments defined in args.py and hand them straight to the
    # benchmark entry point.
    run_benchmark(get_args())