import torch
import torch.utils.data
import time
import argparse
import numpy as np
import os
import inspect
import scipy.io as scio
from einops import rearrange
import torch.nn.functional as F

# 导入新的模型字典和与 exp_darcy.py 相同的工具
from model_dict_adaptive import get_model
from utils.testloss import TestLoss
from utils.normalizer import UnitTransformer

# --- 与 exp_darcy.py 相同的辅助函数 ---
def central_diff(x: torch.Tensor, dx, resolution):
    """Approximate first spatial derivatives of a field on a square grid.

    The flattened field is viewed as a ``resolution x resolution`` image and
    filtered with the central-difference stencil ``[-0.5, 0, 0.5] / dx`` along
    each grid axis. Every channel is differentiated independently, so inputs
    with any number of channels are supported.

    Args:
        x: Tensor of shape ``(batch, resolution * resolution, channels)`` with
            a floating-point dtype.
        dx: Grid spacing used to scale the stencil.
        resolution: Number of grid points along each side of the square grid.

    Returns:
        ``(grad_x, grad_y)``, each of shape
        ``(batch, resolution * resolution, channels)``. ``grad_x`` differences
        along the last (fastest-varying) grid axis and ``grad_y`` along the
        first one.

    Raises:
        ValueError: If ``x`` is not 3-dimensional.
        RuntimeError: If the number of points is not ``resolution ** 2``.
    """
    if x.dim() != 3:
        raise ValueError(
            f"central_diff expects a (batch, points, channels) tensor, got shape {tuple(x.shape)}"
        )
    batch, _, channels = x.shape

    # (b, h*w, c) -> (b, c, h, w) so the field can be fed to conv2d.
    grid = x.reshape(batch, resolution, resolution, channels).permute(0, 3, 1, 2)

    stencil = torch.tensor([-0.5, 0.0, 0.5], dtype=x.dtype, device=x.device) / dx
    # One copy of the stencil per channel; groups=channels makes the
    # convolution depthwise so channels are never mixed.
    kernel_x = stencil.view(1, 1, 1, 3).repeat(channels, 1, 1, 1)
    kernel_y = stencil.view(1, 1, 3, 1).repeat(channels, 1, 1, 1)

    # padding='same' zero-pads the border, so values on the outermost
    # rows/columns are differences against an implicit zero neighbour.
    grad_x = F.conv2d(grid, kernel_x, padding='same', groups=channels)
    grad_y = F.conv2d(grid, kernel_y, padding='same', groups=channels)

    # (b, c, h, w) -> (b, h*w, c), matching the input layout.
    grad_x = grad_x.permute(0, 2, 3, 1).reshape(batch, resolution * resolution, channels)
    grad_y = grad_y.permute(0, 2, 3, 1).reshape(batch, resolution * resolution, channels)
    return grad_x, grad_y

def get_benchmark_args():
    """Parse the command-line options of the Darcy benchmark script.

    Three options are mandatory: the two checkpoint paths and the data
    directory. Note that ``--n-hidden``, ``--n-layers`` and ``--n-head`` are
    spelled with hyphens on the command line; argparse exposes them as
    ``n_hidden``, ``n_layers`` and ``n_head`` on the returned namespace.

    Returns:
        argparse.Namespace: The parsed options, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description='Transolver Benchmark for Darcy Flow')

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        # Checkpoints, model selection and data location.
        ('--model_path_base', dict(type=str, required=True)),
        ('--model_path_optimized', dict(type=str, required=True)),
        ('--model_name_base', dict(type=str, default='Transolver_Structured_Mesh_2D')),
        ('--model_name_optimized', dict(type=str, default='PipeAdaptiveTransolver')),
        ('--data_path', dict(type=str, required=True)),
        # Benchmark run settings.
        ('--batch_size', dict(type=int, default=8)),
        ('--gpu', dict(type=str, default='0')),
        ('--warmup_runs', dict(type=int, default=10)),
        # Darcy-specific data settings.
        ('--downsample', dict(type=int, default=5, help='Darcy特有的下采样率')),
        ('--ntrain', dict(type=int, default=1000, help='Darcy特有的训练样本数')),
        # Model hyper-parameters.
        ('--n-hidden', dict(type=int, default=128)),
        ('--n-layers', dict(type=int, default=8)),
        ('--n-head', dict(type=int, default=8)),
        ('--mlp_ratio', dict(type=int, default=1)),
        ('--dropout', dict(type=float, default=0.0)),
        ('--slice_num', dict(type=int, default=64)),
        ('--ref', dict(type=int, default=8)),
        ('--unified_pos', dict(type=int, default=1)),
        ('--capacity_ratios', dict(type=float, nargs='+')),
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    return parser.parse_args()

def load_model_from_checkpoint(model_path, model_name, cli_args, device, s, fun_dim):
    """Build the model named ``model_name`` and load its weights from disk.

    The constructor keyword arguments are taken from ``cli_args``, overlaid
    with fixed values (``space_dim=2``, ``out_dim=1``, ``H=W=s`` and the given
    ``fun_dim``), and then restricted to the names that appear in the model
    class's ``__init__`` signature. Everything else is silently dropped.

    NOTE(review): the ``n_heads`` -> ``n_head`` alias only fires when
    ``cli_args`` has an ``n_heads`` attribute; the parser in this file defines
    ``--n-head`` only, so with that parser the alias is inert.
    NOTE(review): ``torch.load`` unpickles the file; only load checkpoints
    from a trusted source.

    Args:
        model_path: Path of the checkpoint file passed to ``torch.load``.
        model_name: Name handed to ``get_model`` to select the model module.
        cli_args: Namespace whose attributes are candidate constructor kwargs.
        device: Device the model is moved to and the weights are mapped onto.
        s: Value passed as both ``H`` and ``W``.
        fun_dim: Value passed as ``fun_dim``.

    Returns:
        The constructed model, with weights loaded, in eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    print(f"--- 正在加载模型: {os.path.basename(model_path)} ({model_name}) ---")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"{model_path} not found")

    model_class = get_model(argparse.Namespace(model=model_name)).Model
    accepted = inspect.signature(model_class.__init__).parameters

    # Start from a copy of the CLI values so the caller's namespace is untouched.
    candidates = dict(vars(cli_args))
    if 'n_heads' in candidates and 'n_head' in accepted:
        candidates['n_head'] = candidates['n_heads']
    # Fixed settings take precedence over same-named CLI values.
    candidates.update(space_dim=2, fun_dim=fun_dim, out_dim=1, H=s, W=s)

    # Forward only what the constructor actually declares.
    init_kwargs = {name: value for name, value in candidates.items() if name in accepted}
    model = model_class(**init_kwargs).to(device)

    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    print(f"模型 '{model_name}' 加载成功。")
    return model

def warmup_model(model, warmup_loader, warmup_runs, device):
    """Run a few untimed forward passes before benchmarking.

    At most ``warmup_runs`` batches are taken from ``warmup_loader``; if the
    loader has fewer batches, all of them are used once. Outputs are discarded.

    Args:
        model: Module called as ``model(x, fx=...)``; it is put in eval mode.
        warmup_loader: Iterable yielding ``(x, fx, y)`` batches; ``y`` is ignored.
        warmup_runs: Maximum number of batches to run.
        device: Device the inputs are moved to.
    """
    model.eval()
    with torch.no_grad():
        for i, (x, fx, _) in enumerate(warmup_loader):
            if i >= warmup_runs:
                break
            x, fx = x.to(device), fx.to(device)
            # fx gains a trailing feature axis, matching the timed inference call.
            model(x, fx=fx.unsqueeze(-1))

    # Wait for queued GPU work to finish. Skipped on CPU, where the script's
    # device fallback would otherwise make torch.cuda.synchronize raise.
    if torch.device(device).type == 'cuda':
        torch.cuda.synchronize(device=device)

def run_inference(model, test_loader, device, metric_func, y_normalizer, dx, s):
    """Time the model's forward passes over a loader and compute its mean loss.

    Only the forward call is timed. On a CUDA device the time is measured with
    CUDA events bracketed by synchronisation; on any other device it is
    measured with ``time.perf_counter``. Decoding and the metric are outside
    the timed region.

    Args:
        model: Module called as ``model(x, fx=...)``. If it returns a tuple,
            the first element is used as the prediction.
        test_loader: Iterable yielding ``(x, fx, y)`` batches.
        device: Device the batches are moved to.
        metric_func: Callable ``(prediction, target) -> scalar tensor``. The
            result is weighted by batch size, which assumes it is a per-batch
            mean.
        y_normalizer: Object whose ``decode`` maps predictions back to the
            target scale.
        dx: Unused; kept for call compatibility.
        s: Unused; kept for call compatibility.

    Returns:
        ``(total_time_s, avg_loss)``: accumulated forward time in seconds and
        the sample-weighted mean of ``metric_func`` over the loader.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    on_cuda = torch.device(device).type == 'cuda'
    if on_cuda:
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

    total_time_s, total_loss, num_samples = 0.0, 0.0, 0

    with torch.no_grad():
        for x, fx, y in test_loader:
            x, fx, y = x.to(device), fx.to(device), y.to(device)
            batch_size = x.size(0)
            num_samples += batch_size

            if on_cuda:
                # Drain pending work so only this forward pass is measured.
                torch.cuda.synchronize(device=device)
                start_event.record()
            else:
                start_time = time.perf_counter()

            model_output = model(x, fx=fx.unsqueeze(-1))

            if on_cuda:
                end_event.record()
                torch.cuda.synchronize(device=device)
                # elapsed_time reports milliseconds.
                total_time_s += start_event.elapsed_time(end_event) / 1000.0
            else:
                total_time_s += time.perf_counter() - start_time

            # Some models return (prediction, extras); keep the prediction only.
            out = model_output[0] if isinstance(model_output, tuple) else model_output
            out_decoded = y_normalizer.decode(out.squeeze(-1))

            # Only the plain metric is evaluated here; no gradient term is included.
            loss = metric_func(out_decoded, y)
            total_loss += loss.item() * batch_size

    if num_samples == 0:
        raise ValueError("test_loader yielded no samples; cannot compute an average loss")

    avg_loss = total_loss / num_samples
    return total_time_s, avg_loss

# ==============================================================================
# 主流程
# ==============================================================================
if __name__ == "__main__":
    args = get_benchmark_args()
    # Select the visible GPU(s) before the first CUDA query below.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # --- 1. Data loading and preprocessing ---
    print("--- 1. Loading and Processing Darcy Dataset ---")
    ntrain = args.ntrain
    ntest = 200
    r = args.downsample
    # Side length of the grid after taking every r-th point of the 421-point axis.
    s = int(((421 - 1) / r) + 1)
    h = s
    dx = 1.0 / s

    def _subsampled_field(mat, key, count):
        # First `count` samples, every r-th grid point, cropped to s x s, one row per sample.
        field = mat[key][:count, ::r, ::r][:, :s, :s]
        return torch.from_numpy(field.reshape(count, -1)).float()

    train_data = scio.loadmat(os.path.join(args.data_path, 'piececonst_r421_N1024_smooth1.mat'))
    x_train_coeff = _subsampled_field(train_data, 'coeff', ntrain)
    y_train_sol = _subsampled_field(train_data, 'sol', ntrain)

    test_data = scio.loadmat(os.path.join(args.data_path, 'piececonst_r421_N1024_smooth2.mat'))
    x_test_coeff = _subsampled_field(test_data, 'coeff', ntest)
    y_test_sol = _subsampled_field(test_data, 'sol', ntest)

    # Normalizers are built from the training split; only the test inputs are
    # encoded, while targets stay on their original scale.
    x_normalizer = UnitTransformer(x_train_coeff)
    y_normalizer = UnitTransformer(y_train_sol)
    x_test_coeff_norm = x_normalizer.encode(x_test_coeff)
    x_normalizer.to(device)
    y_normalizer.to(device)

    # Grid coordinates in [0, 1]^2, flattened to (1, s*s, 2) and repeated per test sample.
    axis = np.linspace(0, 1, s)
    grid_i, grid_j = np.meshgrid(axis, axis, indexing='ij')
    coords = np.stack([grid_i.ravel(), grid_j.ravel()], axis=-1)
    pos = torch.tensor(coords, dtype=torch.float).unsqueeze(0)
    pos_test = pos.repeat(ntest, 1, 1)

    test_set = torch.utils.data.TensorDataset(pos_test, x_test_coeff_norm, y_test_sol)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False)
    print(f"将使用 {ntest} 个测试样本进行评估。")

    # --- 2. Load both models ---
    model_base = load_model_from_checkpoint(
        args.model_path_base, args.model_name_base, args, device, s, fun_dim=1)
    model_optimized = load_model_from_checkpoint(
        args.model_path_optimized, args.model_name_optimized, args, device, s, fun_dim=1)
    print("-" * 50)

    metric_func = TestLoss(size_average=True)

    # --- 3. Warm-up, then timed evaluation ---
    print("开始基准测试...")
    for candidate in (model_base, model_optimized):
        warmup_model(candidate, test_loader, args.warmup_runs, device)
    print("模型预热完成。")

    total_time_base, avg_error_base = run_inference(
        model_base, test_loader, device, metric_func, y_normalizer, dx, s)
    total_time_optimized, avg_error_optimized = run_inference(
        model_optimized, test_loader, device, metric_func, y_normalizer, dx, s)

    # --- 4. Report ---
    rule = "=" * 60
    print("\n" + rule)
    print("模型性能与精度对比基准测试结果 (Darcy 数据集)")
    print(rule)
    print(f"基准模型 ({args.model_name_base}):")
    print(f"  - 总推理耗时: {total_time_base:.4f} 秒")
    print(f"  - 平均相对L2误差: {avg_error_base:.6f}")
    print(f"\n优化模型 ({args.model_name_optimized}):")
    print(f"  - 总推理耗时: {total_time_optimized:.4f} 秒")
    print(f"  - 平均相对L2误差: {avg_error_optimized:.6f}")
    print("-" * 60)
    # Ratios are only printed when their denominators are positive.
    if total_time_base > 0 and total_time_optimized > 0:
        speedup = total_time_base / total_time_optimized
        print(f"🚀 加速比 (基准 / 优化): {speedup:.2f}x")
    if avg_error_base > 0:
        accuracy_change = ((avg_error_optimized - avg_error_base) / avg_error_base) * 100
        print(f"📉 精度变化 (误差增加/减少): {accuracy_change:+.2f}%")
    print(rule)