import torch
import torch.utils.data
import time
import argparse
import numpy as np
import os
import inspect
import scipy.io

# 导入新的模型字典和与 exp_pipe.py 相同的工具
from model_dict_adaptive import get_model
from utils.testloss import TestLoss
from utils.normalizer import UnitTransformer

def get_benchmark_args():
    """Parse the command-line options of the Pipe steady-state benchmark.

    The options are declared in a table and registered in order, so the
    generated ``--help`` output lists them in the order they appear below.

    Returns:
        argparse.Namespace: the options parsed from ``sys.argv``.
    """
    option_table = (
        # Checkpoint locations (both mandatory).
        ('--model_path_base',
         {'type': str, 'required': True,
          'help': '基准模型 (例如 Transolver_2D) 的检查点路径'}),
        ('--model_path_optimized',
         {'type': str, 'required': True,
          'help': '优化模型 (例如 PipeAdaptiveTransolver) 的检查点路径'}),

        # Which model class each checkpoint belongs to.
        ('--model_name_base',
         {'type': str, 'default': 'Transolver_Structured_Mesh_2D',
          'help': '基准模型的类型名称'}),
        ('--model_name_optimized',
         {'type': str, 'default': 'PipeAdaptiveTransolver',
          'help': '优化模型的类型名称'}),

        # Dataset and task settings.
        ('--data_path',
         {'type': str, 'required': True,
          'help': '包含 Pipe .npy 数据集的目录路径'}),
        ('--batch_size',
         {'type': int, 'default': 16, 'help': '推理时的批处理大小'}),

        # Hardware and timing control.
        ('--gpu',
         {'type': str, 'default': '0', 'help': '要使用的GPU索引'}),
        ('--warmup_runs',
         {'type': int, 'default': 10,
          'help': '在正式计时前，用多少个批次来进行预热'}),

        # Architecture hyper-parameters shared by both models.
        # The dashed flags are stored as n_hidden / n_layers / n_head.
        ('--n-hidden', {'type': int, 'default': 128, 'help': '隐藏层维度'}),
        ('--n-layers', {'type': int, 'default': 8, 'help': '层数'}),
        ('--n-head', {'type': int, 'default': 8, 'help': "注意力头数"}),
        ('--mlp_ratio', {'type': int, 'default': 2, 'help': 'MLP膨胀比例'}),
        ('--dropout', {'type': float, 'default': 0.0, 'help': 'Dropout率'}),
        ('--slice_num',
         {'type': int, 'default': 64, 'help': 'Physics-Attention切片数'}),
        ('--ref', {'type': int, 'default': 8}),
        ('--unified_pos', {'type': int, 'default': 0}),

        # Only meaningful for the adaptive model.
        ('--capacity_ratios',
         {'type': float, 'nargs': '+', 'default': None,
          'help': '[自适应模型专用] 每个递归层的保留比例列表'}),
    )

    cli = argparse.ArgumentParser(description='Transolver Benchmark for Pipe Flow')
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def load_model_from_checkpoint(model_path, model_name, cli_args, device, s1, s2, fun_dim=0):
    """Instantiate a model class by name and load its weights from a checkpoint.

    The constructor keyword arguments are taken from ``cli_args`` plus a few
    fixed task settings (``space_dim=2``, ``out_dim=1``, ``H=s1``, ``W=s2`` and
    ``fun_dim``), then filtered down to the names the constructor actually
    declares.

    Args:
        model_path: path of the checkpoint file holding the state dict.
        model_name: model type name, passed to ``get_model`` as ``args.model``.
        cli_args: parsed command-line namespace supplying hyper-parameters.
        device: device the model is moved to and the checkpoint is mapped to.
        s1, s2: grid size, forwarded as ``H`` and ``W``.
        fun_dim: function input dimension (0 for the Pipe task).

    Returns:
        The model, on ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
    """
    print(f"--- 正在加载模型: {os.path.basename(model_path)} (指定类型: {model_name}) ---")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"模型文件未找到: {model_path}")

    temp_args = argparse.Namespace(model=model_name)
    model_class = get_model(temp_args).Model

    model_params = vars(cli_args).copy()
    constructor_params = inspect.signature(model_class.__init__).parameters

    # Head-count name adaptation, in both directions. The benchmark CLI stores
    # the value as `n_head`, so a constructor that calls it `n_heads` would
    # otherwise have the value filtered out and silently fall back to its own
    # default.
    if 'n_head' in constructor_params and 'n_heads' in model_params:
        model_params['n_head'] = model_params['n_heads']
    elif ('n_heads' in constructor_params
          and 'n_heads' not in model_params
          and 'n_head' in model_params):
        model_params['n_heads'] = model_params['n_head']

    # Fixed task settings override anything of the same name from the CLI.
    model_params.update({
        'space_dim': 2,
        'fun_dim': fun_dim,
        'out_dim': 1,
        'H': s1, 'W': s2
    })

    # Drop every argument the model constructor does not declare.
    final_model_params = {key: val for key, val in model_params.items() if key in constructor_params}

    print(f"  > 正在使用以下参数创建模型: {list(final_model_params.keys())}")
    model = model_class(**final_model_params).to(device)

    # NOTE: torch.load deserializes with pickle unless weights-only loading is
    # in effect; only point this at checkpoints from a trusted source.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print(f"模型 '{model_name}' 加载成功。")
    return model

def warmup_model(model, warmup_loader, warmup_runs, device):
    """Run a few forward passes so later timing excludes one-time start-up cost.

    Only the forward pass is executed (no loss, no gradients). At most
    ``warmup_runs`` batches are taken from ``warmup_loader``; fewer if the
    loader is shorter.

    Args:
        model: module called as ``model(x, None)`` (steady-state call form).
        warmup_loader: iterable yielding ``(pos, fx, y)`` batches; only ``pos``
            is used.
        warmup_runs: maximum number of batches to run.
        device: device the inputs are moved to.

    Returns:
        None.
    """
    model.eval()
    with torch.no_grad():
        # zip() stops as soon as the range is exhausted, so no batch beyond
        # `warmup_runs` is fetched from the loader.
        for _, (pos, _fx, _y) in zip(range(warmup_runs), warmup_loader):
            x = pos.to(device)
            _ = model(x, None)  # steady-state call form

    # The script falls back to CPU when CUDA is unavailable; synchronizing a
    # non-CUDA device would raise, so only wait for the GPU when there is one.
    if torch.device(device).type == 'cuda':
        torch.cuda.synchronize(device=device)

def run_inference(model, test_loader, device, metric_func, y_normalizer):
    """Run inference over a whole loader, measuring forward time and mean error.

    Only the forward pass is timed; moving the batch to ``device`` and the
    error computation are outside the timed region. On a CUDA device the time
    is measured with CUDA events (bracketed by synchronization); on any other
    device wall-clock time is used instead, so the benchmark also works with
    the script's CPU fallback.

    Args:
        model: module called as ``model(x, None)``; may return a tensor or a
            tuple whose first element is the prediction.
        test_loader: iterable yielding ``(pos, fx, y)`` batches; ``fx`` is
            ignored.
        device: device the inputs and targets are moved to.
        metric_func: callable ``(prediction, target) -> scalar tensor`` giving
            the per-batch mean error.
        y_normalizer: object whose ``decode`` maps predictions back to the
            original scale of ``y``.

    Returns:
        tuple: ``(total_time_s, avg_loss)`` where ``total_time_s`` is the summed
        forward time in seconds and ``avg_loss`` is the sample-weighted mean of
        the per-batch errors (NaN if the loader yielded no samples).
    """
    model.eval()
    use_cuda = torch.device(device).type == 'cuda'

    total_time_s = 0.0
    total_loss = 0.0
    num_samples = 0

    if use_cuda:
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

    with torch.no_grad():
        for pos, _fx, y in test_loader:
            x, y = pos.to(device), y.to(device)
            batch_size = x.size(0)
            num_samples += batch_size

            if use_cuda:
                # Drain pending work so it is not attributed to this batch.
                torch.cuda.synchronize(device=device)
                start_event.record()
            else:
                cpu_start = time.perf_counter()

            model_output = model(x, None)
            # Some models return (prediction, extras...); keep the prediction.
            out = model_output[0] if isinstance(model_output, tuple) else model_output

            if use_cuda:
                end_event.record()
                torch.cuda.synchronize(device=device)
                # elapsed_time() reports milliseconds.
                total_time_s += start_event.elapsed_time(end_event) / 1000.0
            else:
                total_time_s += time.perf_counter() - cpu_start

            # --- Accuracy (with de-normalization) ---
            # The prediction is mapped back to the original scale; y is used
            # as delivered by the loader and is expected to be un-normalized.
            out = y_normalizer.decode(out.squeeze(-1))

            loss = metric_func(out, y)
            # Weight the batch mean by its size so a short last batch does not
            # skew the overall average.
            total_loss += loss.item() * batch_size

    if num_samples == 0:
        # Nothing was evaluated; report NaN instead of dividing by zero.
        return total_time_s, float('nan')
    return total_time_s, total_loss / num_samples

# ==============================================================================
# Main flow
# ==============================================================================
if __name__ == "__main__":
    args = get_benchmark_args()
    # Select the visible GPU before the first CUDA query below.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # --- 1. Data loading and processing (meant to mirror exp_pipe.py) ---
    print("--- 1. Loading and Processing Pipe Dataset ---")
    INPUT_X = os.path.join(args.data_path, 'Pipe_X.npy')
    INPUT_Y = os.path.join(args.data_path, 'Pipe_Y.npy')
    OUTPUT_Sigma = os.path.join(args.data_path, 'Pipe_Q.npy')

    ntrain = 1000
    ntest = 200
    N = 1200  # size of the sample pool the train/test split is taken from

    inputX = np.load(INPUT_X)
    inputX = torch.tensor(inputX, dtype=torch.float)
    inputY = np.load(INPUT_Y)
    inputY = torch.tensor(inputY, dtype=torch.float)
    # Keep only the first N samples so the test split below is the tail of the
    # N-sample pool rather than the tail of the whole file. This is a no-op
    # when the file holds exactly N samples.
    # NOTE(review): assumed to match the split used in exp_pipe.py - confirm.
    input_coords = torch.stack([inputX, inputY], dim=-1)[:N]

    output_q = np.load(OUTPUT_Sigma)[:, 0]
    output_q = torch.tensor(output_q, dtype=torch.float)[:N]

    # Grid size taken from the data (129x129 according to the original note).
    s1, s2 = input_coords.shape[1], input_coords.shape[2]

    # Train / test split: first ntrain samples vs. last ntest samples.
    x_train_coords = input_coords[:ntrain].reshape(ntrain, -1, 2)
    y_train_data = output_q[:ntrain].reshape(ntrain, -1)
    x_test_coords = input_coords[-ntest:].reshape(ntest, -1, 2)
    y_test_data = output_q[-ntest:].reshape(ntest, -1)

    # Normalizers (UnitTransformer) are built from the training split only;
    # no training happens in this script.
    x_normalizer = UnitTransformer(x_train_coords)
    y_normalizer = UnitTransformer(y_train_data)

    # Only the input coordinates are normalized; y_test_data stays in its
    # original scale because the error is computed on de-normalized outputs.
    x_test_normalized = x_normalizer.encode(x_test_coords)

    # Move the normalizers to the compute device.
    x_normalizer.to(device)
    y_normalizer.to(device)

    # DataLoader item layout: (input coordinates, placeholder, ground truth).
    # The coordinates are reused as the placeholder, which the models ignore.
    test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_normalized, x_test_normalized, y_test_data),
                                              batch_size=args.batch_size, shuffle=False)
    print(f"将使用 {ntest} 个测试样本进行评估，批大小为 {args.batch_size}")

    # --- 2. Load the models ---
    # The adaptive model cannot be built without its capacity ratios.
    if args.model_name_optimized == 'PipeAdaptiveTransolver' and args.capacity_ratios is None:
        raise ValueError("--capacity_ratios a_ratios 是自适应模型必需的参数。")

    model_base = load_model_from_checkpoint(args.model_path_base, args.model_name_base, args, device, s1, s2, fun_dim=0)
    model_optimized = load_model_from_checkpoint(args.model_path_optimized, args.model_name_optimized, args, device, s1, s2, fun_dim=0)
    print("-" * 50)

    metric_func = TestLoss(size_average=True)  # TestLoss provides the L2 error

    # --- 3. Warm-up and measurement ---
    print(f"开始基准测试 (每次测试前独立预热 {args.warmup_runs} 个批次)...")

    # Warm-up. Iterating an empty loader never raises StopIteration, so the
    # empty case is detected explicitly instead of through an except clause.
    print("\n--- 正在预热模型 ... ---")
    if len(test_loader) == 0:
        print("警告：测试数据加载器为空，跳过预热。")
    else:
        warmup_model(model_base, test_loader, args.warmup_runs, device)
        warmup_model(model_optimized, test_loader, args.warmup_runs, device)
        print("模型预热完成。")

    # --- Timed runs ---
    total_time_base, avg_error_base = run_inference(model_base, test_loader, device, metric_func, y_normalizer)
    total_time_optimized, avg_error_optimized = run_inference(model_optimized, test_loader, device, metric_func, y_normalizer)

    # --- 4. Report ---
    print("\n" + "=" * 60)
    print(f"                 模型性能与精度对比基准测试结果 (Pipe 数据集)")
    print("=" * 60)

    print(f"基准模型 ({args.model_name_base}):")
    print(f"  - 总推理耗时 (在 {ntest} 个样本上): {total_time_base:.4f} 秒")
    print(f"  - 平均相对L2误差: {avg_error_base:.6f}")

    print(f"\n优化模型 ({args.model_name_optimized}):")
    print(f"  - 总推理耗时 (在 {ntest} 个样本上): {total_time_optimized:.4f} 秒")
    print(f"  - 平均相对L2误差: {avg_error_optimized:.6f}")

    print("-" * 60)

    # Ratios are only reported when their denominators are meaningful.
    if total_time_optimized > 0 and total_time_base > 0:
        speedup = total_time_base / total_time_optimized
        print(f"🚀 加速比 (基准 / 优化): {speedup:.2f}x")

    if avg_error_base > 0:
        accuracy_change = ((avg_error_optimized - avg_error_base) / avg_error_base) * 100
        print(f"📉 精度变化 (误差增加/减少): {accuracy_change:+.2f}%")

    print("=" * 60)