# benchmark.py (Advanced Version - Supports Asymmetric Architectures)

import torch
import torch.utils.data
import time
import argparse
import numpy as np
import os
import inspect
import scipy.io as scio
from tqdm import tqdm

# 导入模型字典和与训练脚本相同的损失函数
from model_dict_adaptive import get_model
from utils.testloss import TestLoss

def get_benchmark_args():
    """
    Build and parse the command line for the base-vs-optimized benchmark.

    The two models each get their own architecture flags (suffix ``-base`` /
    ``-optimized``), so checkpoints with different widths, depths, head counts,
    slice counts and MLP ratios can be compared against each other.

    Returns:
        argparse.Namespace holding the parsed arguments.
    """
    cli = argparse.ArgumentParser(
        description='Advanced Transolver Benchmark (Supports Asymmetric Arch)'
    )

    # Checkpoint locations (both required).
    cli.add_argument('--model_path_base', type=str, required=True,
                     help='基准模型的检查点路径')
    cli.add_argument('--model_path_optimized', type=str, required=True,
                     help='优化模型的检查点路径')

    # Model class names, resolved through get_model.
    cli.add_argument('--model_name_base', type=str, default='Transolver_Structured_Mesh_2D',
                     help='基准模型的类型')
    cli.add_argument('--model_name_optimized', type=str, default='StructuredAdaptiveTransolver',
                     help='优化模型的类型')

    # Data and rollout configuration.
    cli.add_argument('--data_path', type=str, required=True,
                     help='NavierStokes .mat 文件的完整路径')
    cli.add_argument('--time_steps', type=int, default=10,
                     help='自回归预测的步长 (T_pred)')
    cli.add_argument('--batch_size', type=int, default=4,
                     help='推理时的批处理大小')

    # Hardware selection and warm-up length.
    cli.add_argument('--gpu', type=str, default='0',
                     help='要使用的GPU索引')
    cli.add_argument('--warmup_runs', type=int, default=5,
                     help='在正式计时前，用多少个自回归步来进行预热')

    # Per-model integer architecture flags as (flag, default, help); base model first.
    per_model_arch_flags = (
        ('--n-hidden-base', 128, '基准模型的隐藏层维度'),
        ('--n-layers-base', 8, '基准模型的层数'),
        ('--n-head-base', 8, "基准模型的注意力头数"),
        ('--slice_num-base', 32, '基准模型的Slice数'),
        ('--mlp_ratio-base', 1, '基准模型的MLP膨胀比例'),
        ('--n-hidden-optimized', 128, '优化模型的隐藏层维度'),
        ('--n-layers-optimized', 8, '优化模型的层数'),
        ('--n-head-optimized', 8, "优化模型的注意力头数"),
        ('--slice_num-optimized', 32, '优化模型的Slice数'),
        ('--mlp_ratio-optimized', 1, '优化模型的MLP膨胀比例'),
    )
    for flag, default_value, help_text in per_model_arch_flags:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)

    # Only consumed by the optimized (adaptive) model.
    cli.add_argument('--capacity_ratios', type=float, nargs='+',
                     help='[优化模型专用] 每个递归层的保留比例列表')

    # Hyper-parameters shared by both models.
    for flag, value_type, default_value in (
        ('--ref', int, 8),
        ('--unified_pos', int, 1),
        ('--dropout', float, 0.0),
    ):
        cli.add_argument(flag, type=value_type, default=default_value)

    return cli.parse_args()

def load_model_from_checkpoint(model_path, model_name, cli_args, device, model_type):
    """
    Instantiate a model from CLI architecture arguments and load its checkpoint weights.

    Args:
        model_path: path to a checkpoint file holding the model's state_dict.
        model_name: model type name passed to ``get_model`` to resolve the model class.
        cli_args: namespace as produced by ``get_benchmark_args``. The shared keys
            ``ref`` / ``unified_pos`` / ``dropout`` are used when present; architecture
            keys are read from ``<param>_<model_type>`` (e.g. ``n_hidden_base``).
        device: device the model is moved to and the checkpoint is mapped onto.
        model_type: 'base' or 'optimized'; selects which per-model CLI arguments are
            read. 'optimized' additionally requires ``cli_args.capacity_ratios``.

    Returns:
        The constructed model, switched to eval mode, with the checkpoint loaded.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
        AttributeError: if a per-model architecture argument is missing from ``cli_args``.
        ValueError: if ``model_type`` is 'optimized' and no capacity_ratios were given.
    """
    print(f"--- 正在加载 {model_type.upper()} 模型: {os.path.basename(model_path)} ({model_name}) ---")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"模型文件未找到: {model_path}")

    # Hyper-parameters shared by both models.
    model_params = {
        p: getattr(cli_args, p) for p in ('ref', 'unified_pos', 'dropout') if hasattr(cli_args, p)
    }

    # Architecture parameter -> stem of its CLI flag as declared in get_benchmark_args.
    # Those flags mix '-' and '_' (e.g. --n-hidden-base vs --slice_num-base), so the stem
    # cannot be derived mechanically; an explicit table keeps the error message accurate.
    arch_flag_stems = {
        'n_hidden': 'n-hidden',
        'n_layers': 'n-layers',
        'n_head': 'n-head',
        'slice_num': 'slice_num',
        'mlp_ratio': 'mlp_ratio',
    }
    for param, flag_stem in arch_flag_stems.items():
        namespace_key = f"{param}_{model_type}"
        if not hasattr(cli_args, namespace_key):
            raise AttributeError(
                f"命令行参数缺失: 未找到 --{flag_stem}-{model_type}。请为 {model_type} 模型提供此参数。"
            )
        model_params[param] = getattr(cli_args, namespace_key)

    # The optimized (adaptive) model additionally needs per-layer capacity ratios.
    if model_type == 'optimized':
        capacity_ratios = getattr(cli_args, 'capacity_ratios', None)
        if capacity_ratios is None:
            raise ValueError(f"参数缺失: --capacity_ratios 是 {model_name} 必需的参数。")
        model_params['capacity_ratios'] = capacity_ratios

    # Fixed task parameters: H = W = 64, fun_dim = T_in = 10, out_dim = step = 1.
    # NOTE(review): these must agree with the data layout built by the caller.
    h = 64
    T_in = 10
    step = 1
    model_params.update({
        'space_dim': 2, 'fun_dim': T_in, 'out_dim': step,
        'H': h, 'W': h,
    })

    model_class = get_model(argparse.Namespace(model=model_name)).Model
    constructor_params = inspect.signature(model_class.__init__).parameters

    # Some model classes call the head count `n_heads` instead of `n_head`. Rename it so
    # the configured value is not silently dropped by the signature filter below (the
    # model would otherwise fall back to its default head count).
    if ('n_heads' in constructor_params and 'n_head' not in constructor_params
            and 'n_head' in model_params):
        model_params['n_heads'] = model_params.pop('n_head')

    final_model_params = {key: val for key, val in model_params.items() if key in constructor_params}
    ignored_params = sorted(set(model_params) - set(final_model_params))
    if ignored_params:
        # Make it visible when a configured value (e.g. capacity_ratios) is not used.
        print(f"  > 注意: {model_name} 的构造函数不接受以下参数，已忽略: {ignored_params}")

    print(f"  > 正在使用以下参数创建模型: {final_model_params}")
    model = model_class(**final_model_params).to(device)

    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted
    # sources (PyTorch >= 2.6 defaults to weights_only=True, older releases do not).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print(f"模型 '{model_name}' 加载成功。")
    return model

def warmup_model(model, warmup_batch, warmup_runs, device):
    """
    Run a few untimed autoregressive steps so one-off start-up costs are not
    counted in the timed runs.

    Args:
        model: model called as ``model(x, fx=fx)``; it may return a tensor or a
            tuple whose first element is the one-step prediction.
        warmup_batch: ``(x, fx_in, target)`` batch; the target is ignored.
        warmup_runs: number of autoregressive steps to execute (0 is allowed).
        device: device the inputs are moved to. CUDA and non-CUDA devices are
            both supported.
    """
    model.eval()
    x_batch, fx, _ = warmup_batch
    x_batch, fx = x_batch.to(device), fx.to(device)

    with torch.no_grad():
        for _ in range(warmup_runs):
            model_output = model(x_batch, fx=fx)
            im = model_output[0] if isinstance(model_output, tuple) else model_output
            # Slide the input window: drop its first channel, append the new prediction.
            fx = torch.cat((fx[..., 1:], im), dim=-1)

    # Wait for queued GPU work to finish. torch.cuda.synchronize raises when CUDA
    # is unavailable, so it is only called for CUDA devices.
    dev = device if isinstance(device, torch.device) else torch.device(device)
    if dev.type == 'cuda':
        torch.cuda.synchronize(device=device)

def run_inference(model, data_batch, time_steps, device, metric_func):
    """
    Roll the model out autoregressively for ``time_steps`` steps and time the rollout.

    Args:
        model: model called as ``model(x, fx=fx)``; it may return a tensor or a
            tuple whose first element is the one-step prediction.
        data_batch: ``(x, fx_in, true_trajectory)`` batch.
        time_steps: number of autoregressive steps; must be >= 1.
        device: device the batch is moved to. CUDA devices are timed with CUDA
            events; any other device falls back to ``time.perf_counter``.
        metric_func: callable ``(pred, target) -> tensor`` applied to the
            per-sample flattened prediction and target; its ``.item()`` is
            returned as the error.

    Returns:
        ``(elapsed_time_s, error)``: rollout time in seconds (the metric
        computation is not timed) and the scalar metric for this batch.

    Raises:
        ValueError: if ``time_steps`` < 1 (there would be no prediction to score).
    """
    if time_steps < 1:
        raise ValueError(f"time_steps must be >= 1, got {time_steps}")

    model.eval()
    x_batch, fx_in, true_trajectory = data_batch
    x_batch, fx, true_trajectory = x_batch.to(device), fx_in.to(device), true_trajectory.to(device)
    bsz = x_batch.shape[0]

    dev = device if isinstance(device, torch.device) else torch.device(device)
    use_cuda_events = dev.type == 'cuda'

    with torch.no_grad():
        if use_cuda_events:
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            # Drain previously queued work so it is not counted in the measurement.
            torch.cuda.synchronize(device=device)
            start_event.record()
        else:
            start_time = time.perf_counter()

        pred_list = []
        for _ in range(time_steps):
            model_output = model(x_batch, fx=fx)
            im = model_output[0] if isinstance(model_output, tuple) else model_output
            pred_list.append(im)
            # Slide the input window: drop its first channel, append the new prediction.
            fx = torch.cat((fx[..., 1:], im), dim=-1)

        if use_cuda_events:
            end_event.record()
            torch.cuda.synchronize(device=device)
            # Event.elapsed_time reports milliseconds.
            elapsed_time_s = start_event.elapsed_time(end_event) / 1000.0
        else:
            elapsed_time_s = time.perf_counter() - start_time

        pred = torch.cat(pred_list, -1)
        error = metric_func(pred.reshape(bsz, -1), true_trajectory.reshape(bsz, -1)).item()

    return elapsed_time_s, error

def main():
    """
    Benchmark the base and optimized models on the NavierStokes test split.

    Loads the last ``ntest`` samples from the .mat file, builds both models from
    their checkpoints, warms each model up once, then rolls both out
    autoregressively on every test batch. It reports the total inference time,
    the average error, the speedup and the relative error change.

    Raises:
        ValueError: if ``--time_steps`` < 1, or if the data has too few samples,
            too few time steps, or a spatial resolution other than h x h.
    """
    args = get_benchmark_args()
    # Set before any CUDA call below so the device mask takes effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # --- 1. Data loading (kept consistent with the training script) ---
    print("--- 1. Loading Data for Benchmark ---")
    ntest = 200
    T_in = 10
    T_pred = args.time_steps
    h = 64

    if T_pred < 1:
        raise ValueError(f"--time_steps 必须 >= 1，当前为 {T_pred}")

    u = scio.loadmat(args.data_path)['u']
    # Validate up front so a mismatch produces a clear message instead of a reshape error.
    # The slicing below treats u as (sample, x, y, time).
    if u.shape[0] < ntest:
        raise ValueError(f"数据集样本数不足: 需要至少 {ntest} 个样本，实际只有 {u.shape[0]} 个")
    if tuple(u.shape[1:3]) != (h, h):
        raise ValueError(f"数据空间分辨率为 {u.shape[1]}x{u.shape[2]}，与模型固定的 {h}x{h} 不一致")
    if T_in + T_pred > u.shape[-1]:
        raise ValueError(
            f"--time_steps 过大: T_in ({T_in}) + time_steps ({T_pred}) 超出数据的时间步数 ({u.shape[-1]})"
        )

    test_a = torch.from_numpy(u[-ntest:, :, :, :T_in].reshape(ntest, -1, T_in)).float()
    test_u = torch.from_numpy(u[-ntest:, :, :, T_in:T_in + T_pred].reshape(ntest, -1, T_pred)).float()

    # (1, h*h, 2) grid of normalized coordinates, repeated for every test sample.
    axis = np.linspace(0, 1, h)
    grid_x, grid_y = np.meshgrid(axis, axis, indexing='ij')
    pos = torch.tensor(np.c_[grid_x.ravel(), grid_y.ravel()], dtype=torch.float).unsqueeze(0)
    pos_test = pos.repeat(ntest, 1, 1)

    test_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(pos_test, test_a, test_u),
        batch_size=args.batch_size, shuffle=False,
    )
    print(f"将使用 {ntest} 个测试样本进行评估，批大小为 {args.batch_size}")

    # --- 2. Load models ---
    model_base = load_model_from_checkpoint(
        args.model_path_base, args.model_name_base, args, device, model_type='base'
    )
    model_optimized = load_model_from_checkpoint(
        args.model_path_optimized, args.model_name_optimized, args, device, model_type='optimized'
    )
    print("-" * 50)

    metric_func = TestLoss(size_average=False)

    # --- 3. Warm-up and benchmark ---
    print(f"开始基准测试 (计时前对每个模型各预热一次，共 {args.warmup_runs} 步)...")

    warmup_batch = next(iter(test_loader), None)
    if warmup_batch is None:
        print("警告：测试数据加载器为空，跳过预热。")
    else:
        print("\n--- 正在预热模型 ... ---")
        warmup_model(model_base, warmup_batch, args.warmup_runs, device)
        warmup_model(model_optimized, warmup_batch, args.warmup_runs, device)
        print("模型预热完成。")

    total_time_base, total_error_base = 0.0, 0.0
    total_time_optimized, total_error_optimized = 0.0, 0.0

    # Both models are evaluated on the same batch back to back, so they see identical inputs.
    for data_batch in tqdm(test_loader, desc="Benchmarking Batches"):
        base_time, base_error = run_inference(model_base, data_batch, args.time_steps, device, metric_func)
        total_time_base += base_time
        total_error_base += base_error

        opt_time, opt_error = run_inference(model_optimized, data_batch, args.time_steps, device, metric_func)
        total_time_optimized += opt_time
        total_error_optimized += opt_error

    # TestLoss(size_average=False) is expected to return a per-batch sum, hence the
    # division by the total sample count.
    avg_error_base = total_error_base / ntest
    avg_error_optimized = total_error_optimized / ntest

    # --- 4. Report results ---
    print("\n" + "=" * 60)
    print("                 模型性能与精度对比基准测试结果 (纯自回归评估)")
    print("=" * 60)

    print(f"基准模型 ({args.model_name_base}):")
    print(f"  - 总耗时 (在 {ntest} 个样本上): {total_time_base:.4f} 秒")
    print(f"  - 平均相对L2误差: {avg_error_base:.6f}")

    print(f"\n优化模型 ({args.model_name_optimized}):")
    print(f"  - 总耗时 (在 {ntest} 个样本上): {total_time_optimized:.4f} 秒")
    print(f"  - 平均相对L2误差: {avg_error_optimized:.6f}")

    print("-" * 60)

    if total_time_optimized > 0 and total_time_base > 0:
        speedup = total_time_base / total_time_optimized
        print(f"🚀 加速比 (基准 / 优化): {speedup:.2f}x")

    if avg_error_base > 0:
        accuracy_change = ((avg_error_optimized - avg_error_base) / avg_error_base) * 100
        print(f"📉 精度变化 (误差增加/减少): {accuracy_change:+.2f}%")

    print("=" * 60)


if __name__ == "__main__":
    main()