import torch
import torch.utils.data
import time
import argparse
import numpy as np
import os
import inspect

# Import the adaptive model registry and the same loss utility that exp_airfoil.py uses
from model_dict_adaptive import get_model
from utils.testloss import TestLoss

def get_benchmark_args():
    """Parse the command-line options of the airfoil benchmark.

    The options are declared in a single ordered table so that every flag,
    its type and its default can be read at a glance; they are registered in
    table order.

    Returns:
        argparse.Namespace: the parsed options (checkpoint paths, model
        names, dataset directory, batch/warm-up settings and the model
        hyper-parameters forwarded to the model constructors).
    """
    option_table = (
        # Checkpoints and the model names they belong to.
        ('--model_path_base', {'type': str, 'required': True}),
        ('--model_path_optimized', {'type': str, 'required': True}),
        ('--model_name_base', {'type': str, 'default': 'Transolver_Structured_Mesh_2D'}),
        ('--model_name_optimized', {'type': str, 'default': 'PipeAdaptiveTransolver'}),
        # Dataset location and run-time settings.
        ('--data_path', {'type': str, 'required': True}),
        ('--batch_size', {'type': int, 'default': 8}),
        ('--gpu', {'type': str, 'default': '0'}),
        ('--warmup_runs', {'type': int, 'default': 10}),
        # Model hyper-parameters (argparse stores '--n-hidden' as 'n_hidden', etc.).
        ('--n-hidden', {'type': int, 'default': 128}),
        ('--n-layers', {'type': int, 'default': 8}),
        ('--n-head', {'type': int, 'default': 8}),
        ('--mlp_ratio', {'type': int, 'default': 1}),
        ('--dropout', {'type': float, 'default': 0.0}),
        ('--slice_num', {'type': int, 'default': 64}),
        ('--ref', {'type': int, 'default': 8}),
        ('--unified_pos', {'type': int, 'default': 0}),
        ('--capacity_ratios', {'type': float, 'nargs': '+'}),
    )

    cli = argparse.ArgumentParser(description='Transolver Benchmark for Airfoil Flow')
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def load_model_from_checkpoint(model_path, model_name, cli_args, device, s1, s2, fun_dim=0):
    """Build the named model, load its weights from disk and switch it to eval mode.

    Constructor keyword arguments are taken from ``cli_args`` plus a few fixed
    values (``space_dim=2``, ``out_dim=1``, ``H=s1``, ``W=s2`` and ``fun_dim``);
    only the names that appear in the model's ``__init__`` signature are passed.

    Args:
        model_path: path of the checkpoint handed to ``torch.load``.
        model_name: name passed to ``get_model`` (as ``Namespace(model=...)``).
        cli_args: parsed CLI namespace providing candidate hyper-parameters.
        device: device the model is moved to and the checkpoint is mapped to.
        s1: value forwarded as the ``H`` constructor argument.
        s2: value forwarded as the ``W`` constructor argument.
        fun_dim: value forwarded as the ``fun_dim`` constructor argument.

    Returns:
        The instantiated model with the checkpoint loaded, in eval mode.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
    """
    checkpoint_file = os.path.basename(model_path)
    print(f"--- 正在加载模型: {checkpoint_file} ({model_name}) ---")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"{model_path} not found")

    net_cls = get_model(argparse.Namespace(model=model_name)).Model
    candidates = dict(vars(cli_args))
    accepted = inspect.signature(net_cls.__init__).parameters

    # Bridge the naming difference when the namespace says 'n_heads' but the
    # constructor expects 'n_head'.
    if 'n_heads' in candidates and 'n_head' in accepted:
        candidates['n_head'] = candidates['n_heads']

    # Fixed problem settings override anything that came from the CLI.
    candidates['space_dim'] = 2
    candidates['fun_dim'] = fun_dim
    candidates['out_dim'] = 1
    candidates['H'] = s1
    candidates['W'] = s2

    # Drop every entry the constructor does not know about.
    init_kwargs = {}
    for name, value in candidates.items():
        if name in accepted:
            init_kwargs[name] = value

    net = net_cls(**init_kwargs)
    net = net.to(device)
    state = torch.load(model_path, map_location=device)
    net.load_state_dict(state)
    net.eval()
    print(f"模型 '{model_name}' 加载成功。")
    return net

def warmup_model(model, warmup_loader, warmup_runs, device):
    """Run up to ``warmup_runs`` untimed forward passes of ``model``.

    The passes run in eval mode with gradient tracking disabled and their
    outputs are discarded.

    Args:
        model: callable model exposing ``eval()``; invoked as ``model(x, None)``.
        warmup_loader: iterable yielding ``(pos, fx, y)`` batches; only ``pos``
            is used.
        warmup_runs: maximum number of batches to push through the model.
        device: device the input batch is moved to.
    """
    model.eval()
    with torch.no_grad():
        for i, (pos, fx, y) in enumerate(warmup_loader):
            if i >= warmup_runs:
                break
            _ = model(pos.to(device), None)

    # torch.cuda.synchronize raises for non-CUDA devices, so only wait for
    # outstanding kernels when the model actually runs on a GPU. This keeps
    # the CPU fallback chosen by the main script usable.
    if torch.device(device).type == 'cuda':
        torch.cuda.synchronize(device=device)

def run_inference(model, test_loader, device, metric_func):
    """Time the forward passes of ``model`` over ``test_loader`` and average the loss.

    On a CUDA device the forward pass is timed with CUDA events, synchronising
    before and after each batch. On any other device wall-clock time from
    ``time.perf_counter`` is used, because CUDA events and
    ``torch.cuda.synchronize`` are not usable there.

    Args:
        model: callable model exposing ``eval()``; invoked as ``model(x, None)``.
            It may return a tensor or a tuple whose first element is the
            prediction.
        test_loader: iterable yielding ``(pos, fx, y)`` batches; ``fx`` is unused.
        device: device on which inference runs.
        metric_func: callable ``metric_func(prediction, target)`` returning a
            scalar tensor that is treated as the mean loss of the batch.

    Returns:
        tuple: ``(total_time_s, avg_loss)`` where ``total_time_s`` is the summed
        forward-pass time in seconds and ``avg_loss`` is the sample-weighted
        mean of the per-batch losses. ``avg_loss`` is ``nan`` when the loader
        yields no samples.
    """
    model.eval()
    on_cuda = torch.device(device).type == 'cuda'
    total_time_s, total_loss, num_samples = 0.0, 0.0, 0
    if on_cuda:
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

    with torch.no_grad():
        for pos, fx, y in test_loader:
            x, y = pos.to(device), y.to(device)
            batch_size = x.size(0)
            num_samples += batch_size

            if on_cuda:
                # Drain pending work (e.g. the host-to-device copies) so it is
                # not attributed to the forward pass.
                torch.cuda.synchronize(device=device)
                start_event.record()
            else:
                wall_start = time.perf_counter()

            # Some models return (prediction, extras...); keep the prediction only.
            model_output = model(x, None)
            out = model_output[0] if isinstance(model_output, tuple) else model_output

            if on_cuda:
                end_event.record()
                torch.cuda.synchronize(device=device)
                # elapsed_time reports milliseconds.
                total_time_s += start_event.elapsed_time(end_event) / 1000.0
            else:
                total_time_s += time.perf_counter() - wall_start

            # No de-normalisation is applied here; the loss is computed on the
            # raw model output.
            loss = metric_func(out.squeeze(-1), y)
            # Weight the batch mean by the batch size so a smaller final batch
            # does not skew the overall average.
            total_loss += loss.item() * batch_size

    if num_samples == 0:
        # Nothing was evaluated: avoid a ZeroDivisionError and report "no value".
        return total_time_s, float('nan')
    return total_time_s, total_loss / num_samples

# ==============================================================================
# Main entry point
# ==============================================================================
if __name__ == "__main__":
    # Script entry point: load the airfoil test split, load both checkpoints,
    # warm both models up, time them on the same loader and print a comparison.
    args = get_benchmark_args()
    # Restrict the visible GPUs to the requested one(s); the selected GPU is
    # then addressed as cuda:0 below.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # --- 1. Data loading and processing (kept in line with exp_airfoil.py) ---
    print("--- 1. Loading and Processing Airfoil Dataset ---")
    INPUT_X = os.path.join(args.data_path, 'NACA_Cylinder_X.npy')
    INPUT_Y = os.path.join(args.data_path, 'NACA_Cylinder_Y.npy')
    OUTPUT_Sigma = os.path.join(args.data_path, 'NACA_Cylinder_Q.npy')

    # The first `ntrain` samples are skipped; the next `ntest` are evaluated.
    ntrain = 1000
    ntest = 200
    # Passed to the model constructors as H and W.
    s1, s2 = 221, 51

    inputX = torch.from_numpy(np.load(INPUT_X)).float()
    inputY = torch.from_numpy(np.load(INPUT_Y)).float()
    # Pair the X and Y arrays along a new trailing axis of size 2.
    input_coords = torch.stack([inputX, inputY], dim=-1)

    # Keep only index 4 along axis 1 of the Q array as the prediction target.
    output_q = torch.from_numpy(np.load(OUTPUT_Sigma)[:, 4]).float()

    # Only the test split is needed for benchmarking; flatten each sample to
    # (points, 2) coordinates and (points,) targets.
    x_test = input_coords[ntrain:ntrain + ntest].reshape(ntest, -1, 2)
    y_test = output_q[ntrain:ntrain + ntest].reshape(ntest, -1)

    # The coordinates are supplied twice so each batch unpacks as (pos, fx, y),
    # which is what warmup_model and run_inference expect.
    test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, x_test, y_test),
                                              batch_size=args.batch_size, shuffle=False)
    print(f"将使用 {ntest} 个测试样本进行评估。")

    # --- 2. Load both models ---
    model_base = load_model_from_checkpoint(args.model_path_base, args.model_name_base, args, device, s1, s2, fun_dim=0)
    model_optimized = load_model_from_checkpoint(args.model_path_optimized, args.model_name_optimized, args, device, s1, s2, fun_dim=0)
    print("-" * 50)

    metric_func = TestLoss(size_average=True)

    # --- 3. Warm-up and measurement ---
    print("开始基准测试...")
    warmup_model(model_base, test_loader, args.warmup_runs, device)
    warmup_model(model_optimized, test_loader, args.warmup_runs, device)
    print("模型预热完成。")

    total_time_base, avg_error_base = run_inference(model_base, test_loader, device, metric_func)
    total_time_optimized, avg_error_optimized = run_inference(model_optimized, test_loader, device, metric_func)

    # --- 4. Report ---
    print("\n" + "="*60)
    print("模型性能与精度对比基准测试结果 (Airfoil 数据集)")
    print("="*60)
    print(f"基准模型 ({args.model_name_base}):")
    print(f"  - 总推理耗时: {total_time_base:.4f} 秒")
    print(f"  - 平均相对L2误差: {avg_error_base:.6f}")
    print(f"\n优化模型 ({args.model_name_optimized}):")
    # Label fixed: the original string contained a stray character ("耗し")
    # instead of "耗时", unlike the matching line for the base model.
    print(f"  - 总推理耗时: {total_time_optimized:.4f} 秒")
    print(f"  - 平均相对L2误差: {avg_error_optimized:.6f}")
    print("-"*60)
    # Ratios are only reported when the denominators are positive.
    if total_time_optimized > 0 and total_time_base > 0:
        speedup = total_time_base / total_time_optimized
        print(f"🚀 加速比 (基准 / 优化): {speedup:.2f}x")
    if avg_error_base > 0:
        # Relative change of the error in percent; positive means the
        # optimized model has a larger error than the base model.
        accuracy_change = ((avg_error_optimized - avg_error_base) / avg_error_base) * 100
        print(f"📉 精度变化 (误差增加/减少): {accuracy_change:+.2f}%")
    print("="*60)