# benchmark_pipe_adapt.py

import os
import argparse
import numpy as np
import torch
import torch.utils.data
import inspect

# ----------------------------------------------------
# 1. 导入项目模块
#    请确保这些模块在您的Python路径中
# ----------------------------------------------------
# 假设您的模型字典可以智能地处理自适应和基准模型
# 如果不行，您可能需要像之前讨论的那样导入两个不同的 get_model 函数
from model_dict_adaptive import get_model
from utils.testloss import TestLoss
from utils.normalizer import UnitTransformer

# ==============================================================================
#  2. 参数解析 (与原始benchmark脚本保持一致)
# ==============================================================================

def get_benchmark_args(argv=None):
    """Parse the command-line options of the Pipe white-box benchmark.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour callers had before this parameter existed.

    Returns:
        argparse.Namespace: The parsed options. Flags written with dashes
        (``--n-hidden``, ``--n-layers``, ``--n-head``) are exposed by
        argparse with underscores (``n_hidden``, ``n_layers``, ``n_head``).
    """
    parser = argparse.ArgumentParser(description='Transolver Core Block Benchmark for Pipe Flow')

    # --- Checkpoint paths (both required) ---
    parser.add_argument('--model_path_base', type=str, required=True, help='基准模型 (例如 Transolver_2D) 的检查点路径')
    parser.add_argument('--model_path_optimized', type=str, required=True, help='优化模型 (例如 PipeAdaptiveTransolver) 的检查点路径')

    # --- Model type names handed to the model registry ---
    parser.add_argument('--model_name_base', type=str, default='Transolver_Structured_Mesh_2D', help='基准模型的类型名称')
    parser.add_argument('--model_name_optimized', type=str, default='PipeAdaptiveTransolver', help='优化模型的类型名称')

    # --- Dataset / task options ---
    parser.add_argument('--data_path', type=str, required=True, help='包含 Pipe .npy 数据集的目录路径')
    parser.add_argument('--batch_size', type=int, default=16, help='推理时的批处理大小')

    # --- Hardware and warm-up control ---
    parser.add_argument('--gpu', type=str, default='0', help='要使用的GPU索引')
    parser.add_argument('--warmup_runs', type=int, default=10, help='在正式计时前，用多少个批次来进行预热')

    # --- Architecture hyper-parameters shared by both models ---
    parser.add_argument('--n-hidden', type=int, default=128)
    parser.add_argument('--n-layers', type=int, default=8)
    parser.add_argument('--n-head', type=int, default=8)
    parser.add_argument('--mlp_ratio', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--slice_num', type=int, default=64)
    parser.add_argument('--ref', type=int, default=8)
    parser.add_argument('--unified_pos', type=int, default=0)
    parser.add_argument('--capacity_ratios', type=float, nargs='+', default=None, help='[自适应模型专用]')

    return parser.parse_args(argv)

# ==============================================================================
#  3. 辅助函数 (复用自原始benchmark脚本)
# ==============================================================================

def load_model_from_checkpoint(model_path, model_name, cli_args, device, s1, s2, fun_dim=0):
    """Instantiate the model registered as ``model_name`` and load its weights.

    The constructor keyword arguments are assembled from ``cli_args`` plus a
    fixed set of task values (``space_dim=2``, ``out_dim=1``, ``H=s1``,
    ``W=s2``, ``fun_dim``), then filtered down to the names that actually
    appear in the model's ``__init__`` signature. Anything else is dropped.

    Args:
        model_path: Path to the checkpoint file.
        model_name: Name passed (as ``args.model``) to ``get_model``.
        cli_args: Namespace whose attributes supply hyper-parameters.
        device: Device the model is moved to and the checkpoint is mapped to.
        s1: Passed to the constructor as ``H``.
        s2: Passed to the constructor as ``W``.
        fun_dim: Passed to the constructor as ``fun_dim``.

    Returns:
        The constructed model, with weights loaded and set to eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    print(f"--- 正在加载模型: {os.path.basename(model_path)} (类型: {model_name}) ---")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"模型文件未找到: {model_path}")

    temp_args = argparse.Namespace(model=model_name)
    model_class = get_model(temp_args).Model

    model_params = vars(cli_args).copy()
    constructor_params = inspect.signature(model_class.__init__).parameters

    # Bridge the two spellings of the attention-head count. The original
    # mapping (n_heads -> n_head) is kept as-is for callers that supply
    # ``n_heads``.
    if 'n_head' in constructor_params and 'n_heads' in model_params:
        model_params['n_head'] = model_params['n_heads']
    # The benchmark CLI defines ``--n-head`` (dest ``n_head``), so a
    # constructor spelled ``n_heads`` would otherwise never receive the CLI
    # value and would silently use its own default.
    if ('n_heads' in constructor_params
            and 'n_heads' not in model_params
            and 'n_head' in model_params):
        model_params['n_heads'] = model_params['n_head']

    model_params.update({'space_dim': 2, 'fun_dim': fun_dim, 'out_dim': 1, 'H': s1, 'W': s2})

    # Keep only keyword arguments the constructor declares by name.
    final_model_params = {key: val for key, val in model_params.items() if key in constructor_params}

    print(f"  > 正在使用以下参数创建模型: {list(final_model_params.keys())}")
    model = model_class(**final_model_params).to(device)

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources. Consider ``weights_only=True`` where the
    # installed torch version and checkpoint format allow it.
    checkpoint = torch.load(model_path, map_location=device)
    # Accept both a wrapped checkpoint ({'model': state_dict, ...}) and a
    # bare state_dict.
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
    model.load_state_dict(state_dict)
    model.eval()
    print(f"模型 '{model_name}' 加载成功。")
    return model

def warmup_model(model, warmup_loader, warmup_runs, device):
    """Run forward passes only, to warm the model up before timing.

    Args:
        model: Module called as ``model(x, None)``.
        warmup_loader: Iterable yielding 3-tuples whose first element is the
            input tensor; the other two elements are ignored.
        warmup_runs: Maximum number of batches to push through the model.
        device: Device the inputs are moved to (``torch.device`` or string).
    """
    model.eval()
    with torch.no_grad():
        for i, (pos, _, _) in enumerate(warmup_loader):
            if i >= warmup_runs:
                break
            x = pos.to(device)
            _ = model(x, None)  # steady-state call: no extra function input
    # Wait for queued kernels to finish. torch.cuda.synchronize raises on a
    # CPU device / CUDA-less build, so skip it when running on CPU.
    if torch.device(device).type == 'cuda':
        torch.cuda.synchronize(device=device)

# ==============================================================================
#  4. 核心推理与计时函数 (新设计)
# ==============================================================================

def run_inference_with_internal_profiling(model, test_loader, device, metric_func, y_normalizer):
    """Evaluate ``model`` on ``test_loader`` and collect its self-reported time.

    No external timing is done here. A single ``profiler_dict`` is passed to
    every forward call, and the value found under ``'core_blocks_time'``
    after the loop is returned as the total core time.

    Args:
        model: Module called as ``model(x, None, profiler_dict=...)``. It may
            return a tensor or a tuple whose first element is the prediction.
        test_loader: Iterable of ``(pos, _, y)`` batches.
        device: Device the batch tensors are moved to.
        metric_func: Callable ``metric_func(pred, target)`` returning a scalar
            tensor. The per-sample average below assumes it returns a mean
            over the batch.
        y_normalizer: Object whose ``decode`` maps predictions back to the
            target scale.

    Returns:
        tuple: ``(total_core_time_s, avg_loss)``. ``total_core_time_s`` is
        ``0.0`` if the model never wrote ``'core_blocks_time'``. ``avg_loss``
        is NaN when the loader yielded no samples.
    """
    model.eval()

    profiler_dict = {}
    total_loss = 0.0
    num_samples = 0

    with torch.no_grad():
        for pos, _, y in test_loader:
            x, y = pos.to(device), y.to(device)
            batch_size = x.size(0)
            num_samples += batch_size

            # The model fills/accumulates profiler_dict internally.
            model_output = model(x, None, profiler_dict=profiler_dict)

            # Some models return (prediction, extras...); others just the tensor.
            out = model_output[0] if isinstance(model_output, tuple) else model_output

            # Accuracy: drop the trailing channel axis, undo normalization,
            # then weight the batch metric by the batch size.
            out = y_normalizer.decode(out.squeeze(-1))
            loss = metric_func(out, y)
            total_loss += loss.item() * batch_size

    # Total time as accumulated by the model itself across all batches.
    total_core_time_s = profiler_dict.get('core_blocks_time', 0.0)

    # An empty loader has no meaningful average; report NaN rather than
    # failing with ZeroDivisionError.
    if num_samples == 0:
        return total_core_time_s, float('nan')

    avg_loss = total_loss / num_samples
    return total_core_time_s, avg_loss

# ==============================================================================
#  5. 主流程
# ==============================================================================
if __name__ == "__main__":
    args = get_benchmark_args()
    # Select the visible GPU(s) before torch.cuda.is_available() is queried.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    # --- Dataset loading ---
    print("--- 1. Loading and Processing Pipe Dataset ---")
    data_dir = args.data_path
    coords_x_path = os.path.join(data_dir, 'Pipe_X.npy')
    coords_y_path = os.path.join(data_dir, 'Pipe_Y.npy')
    target_path = os.path.join(data_dir, 'Pipe_Q.npy')

    # Fixed split: the first `ntrain` samples feed the normalizer, the last
    # `ntest` samples are evaluated.
    ntrain, ntest = 1000, 200

    inputX = np.load(coords_x_path)
    inputY = np.load(coords_y_path)
    # Pair the two coordinate arrays along a new trailing axis.
    stacked_xy = np.stack([inputX, inputY], axis=-1)
    input_coords = torch.from_numpy(stacked_xy).float()
    # Only index 0 along axis 1 of the target file is used.
    target_all = np.load(target_path)
    output_q = torch.from_numpy(target_all[:, 0]).float()

    # Grid extents, forwarded to the model constructors.
    s1 = input_coords.shape[1]
    s2 = input_coords.shape[2]

    # Flatten the grid axes so each sample is a list of points.
    x_train_coords = input_coords[:ntrain].reshape(ntrain, -1, 2)  # not used further below
    y_train_data = output_q[:ntrain].reshape(ntrain, -1)
    x_test_coords = input_coords[-ntest:].reshape(ntest, -1, 2)
    y_test_data = output_q[-ntest:].reshape(ntest, -1)

    # Output normalizer built from the training targets.
    y_normalizer = UnitTransformer(y_train_data)
    y_normalizer.to(device)

    # Note: the input coordinates are passed to the models un-normalized.
    # The middle tensor of the dataset is a placeholder that the inference
    # loops ignore.
    test_dataset = torch.utils.data.TensorDataset(x_test_coords, x_test_coords, y_test_data)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
    )
    print(f"将使用 {ntest} 个测试样本进行评估，批大小为 {args.batch_size}")

    # --- Model loading ---
    model_base = load_model_from_checkpoint(
        args.model_path_base, args.model_name_base, args, device, s1, s2
    )
    model_optimized = load_model_from_checkpoint(
        args.model_path_optimized, args.model_name_optimized, args, device, s1, s2
    )
    print("-" * 50)

    metric_func = TestLoss(size_average=True)

    # --- Warm-up ---
    print(f"开始基准测试 (每次测试前独立预热 {args.warmup_runs} 个批次)...")

    # Best effort: a warm-up failure is reported but does not stop the run.
    try:
        print("\n--- 正在预热模型 ... ---")
        warmup_model(model_base, test_loader, args.warmup_runs, device)
        warmup_model(model_optimized, test_loader, args.warmup_runs, device)
        print("模型预热完成。")
    except Exception as e:
        print(f"预热期间发生错误: {e}")

    # --- Measured runs: baseline first, then the optimized model ---
    print("\n--- 正在执行核心模块性能测试 ... ---")
    core_time_base, avg_error_base = run_inference_with_internal_profiling(
        model_base, test_loader, device, metric_func, y_normalizer
    )
    core_time_optimized, avg_error_optimized = run_inference_with_internal_profiling(
        model_optimized, test_loader, device, metric_func, y_normalizer
    )

    # --- Report ---
    heavy_rule = "=" * 60
    print("\n" + heavy_rule)
    print("           核心计算模块性能与精度对比 (白盒-内部计时)")
    print(heavy_rule)

    print(f"基准模型 ({args.model_name_base}):")
    print(f"  - 核心模块总耗时: {core_time_base:.4f} 秒")
    print(f"  - 平均相对L2误差: {avg_error_base:.6f}")

    print(f"\n优化模型 ({args.model_name_optimized}):")
    print(f"  - 核心模块总耗时: {core_time_optimized:.4f} 秒")
    print(f"  - 平均相对L2误差: {avg_error_optimized:.6f}")

    print("-" * 60)

    # The speed-up is only reported when both models recorded a time.
    if core_time_base > 0 and core_time_optimized > 0:
        speedup = core_time_base / core_time_optimized
        print(f"🚀 核心算法加速比 (基准 / 优化): {speedup:.2f}x")

    # Relative error change in percent; a positive value means the optimized
    # model has the larger error.
    if avg_error_base > 0:
        relative_delta = (avg_error_optimized - avg_error_base) / avg_error_base
        accuracy_change = relative_delta * 100
        print(f"📉 精度变化 (误差增加/减少): {accuracy_change:+.2f}%")

    print(heavy_rule)