#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Filename: benchmark_SS.py
# (SS for Single Step)
# 用于评估 GNOT 和 SR_GNOT_single_step 等单步预测模型的性能。

import sys
import os
import time
import json
import numpy as np
import torch
from tqdm import tqdm

# 确保项目路径在 sys.path 中
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# --- 关键：从 data_utils.py 导入，而不是 data_utils_fix.py ---
from args import get_args
from data_utils import get_dataset, get_model, get_loss_func, MIODataLoader
from utils import get_seed, get_num_params

def warmup(models, loader, device, num_iters=10):
    """Run untimed forward passes of every model so later timings exclude startup costs.

    The first batch of ``loader`` is moved to ``device`` and fed ``num_iters``
    times to each model in ``models`` under ``torch.no_grad()``.

    Args:
        models: mapping of name -> model, each called as ``model(g, u_p, inputs_f)``.
        loader: iterable yielding ``(g_batch, u_p_batch, inputs_f_batch)`` tuples.
        device: target device (``torch.device`` or anything ``torch.device`` accepts).
        num_iters: number of warm-up passes per model (default 10).

    Raises:
        Any exception other than an empty loader is printed and re-raised.
    """
    device = torch.device(device)
    try:
        # Take a single batch from the loader; an empty loader raises StopIteration.
        g_batch, u_p_batch, inputs_f_batch = next(iter(loader))

        g_batch = g_batch.to(device)
        u_p_batch = u_p_batch.to(device)
        # inputs_f_batch is not necessarily tensor-like; only move it if it supports .to().
        if hasattr(inputs_f_batch, 'to'):
            inputs_f_batch = inputs_f_batch.to(device)

        print("Warming up GPU with a sample batch...")
        for _ in tqdm(range(num_iters), desc="Warming up", leave=False):
            for model in models.values():
                with torch.no_grad():
                    # Fresh clones each pass, presumably to guard against a forward
                    # pass that modifies its inputs in place.
                    _ = model(g_batch.clone(), u_p_batch.clone(), inputs_f_batch)
        # torch.cuda.synchronize raises for non-CUDA devices (run_benchmark may fall
        # back to CPU); on CUDA, wait for the queued work of the *selected* device.
        if device.type == 'cuda':
            torch.cuda.synchronize(device)
        print("GPU warmup complete.")
    except StopIteration:
        print("Warning: DataLoader is empty, skipping warmup.")
    except Exception as e:
        print(f"An error occurred during warmup: {e}")
        raise

def evaluate_accuracy(model, loader, metric_func, device):
    """Evaluate the model's single-step prediction error over ``loader``.

    For every batch, the prediction and the target ``g_batch.ndata['y']`` are
    brought to the same ``(num_nodes, output_dim)`` layout before being passed
    to ``metric_func``. The third element of its return value is collected.

    Args:
        model: called as ``model(g, u_p, inputs_f)``; if it returns a tuple,
            only the first element is used as the prediction.
        loader: iterable yielding ``(g_batch, u_p_batch, inputs_f_batch)``.
        metric_func: called as ``metric_func(g, pred, target)``; must return a
            3-tuple whose last element is the metric for the batch.
        device: device the batches are moved to.

    Returns:
        ``np.mean`` of the per-batch metrics, or ``np.nan`` for an empty loader.
        NOTE(review): batches are weighted equally, so a smaller final batch
        (drop_last=False) counts as much as a full one.

    Raises:
        RuntimeError: if the target's element count does not match the
            prediction's ``num_nodes * output_dim``.
    """
    all_metrics = []
    pbar = tqdm(loader, desc="Evaluating Accuracy", leave=False)

    with torch.no_grad():
        for g_batch, u_p_batch, inputs_f_batch in pbar:
            g_batch = g_batch.to(device)
            u_p_batch = u_p_batch.to(device)
            if hasattr(inputs_f_batch, 'to'):
                inputs_f_batch = inputs_f_batch.to(device)

            pred_out = model(g_batch, u_p_batch, inputs_f_batch)
            # Some models return a tuple; the prediction is its first element.
            if isinstance(pred_out, tuple):
                pred_out = pred_out[0]

            target_raw = g_batch.ndata['y']

            # The layout is inferred from the prediction. A 1-D prediction
            # (one value per node) is treated as output_dim == 1 instead of
            # failing on pred_out.shape[1].
            num_nodes = pred_out.shape[0]
            output_dim = pred_out.shape[1] if pred_out.dim() > 1 else 1

            # reshape (not view): view also fails on non-contiguous tensors, which
            # would be misreported below as an element-count mismatch.
            try:
                target = target_raw.reshape(num_nodes, output_dim)
            except RuntimeError as e:
                print("\n--- FATAL SHAPE ERROR ---")
                print(f"Failed to reshape target tensor.")
                print(f"pred_out.shape: {pred_out.shape}")
                print(f"target_raw.shape: {target_raw.shape}")
                print(f"Error: {e}")
                print("This usually happens if the total number of elements in the target tensor")
                print("does not match the expected number from the prediction tensor.")
                print("-------------------------\n")
                raise

            # Defensive: drop any extra singleton dimensions from the prediction.
            pred = pred_out.reshape(num_nodes, output_dim)

            _, _, metric = metric_func(g_batch, pred, target)
            # np.mean cannot convert CUDA tensors; move tensor metrics to host numpy.
            if isinstance(metric, torch.Tensor):
                metric = metric.detach().cpu().numpy()
            all_metrics.append(metric)

    if not all_metrics:
        return np.nan
    return np.mean(all_metrics)


def evaluate_inference_speed(model, loader, device, num_runs):
    """Measure the inference throughput (samples per second) of ``model``.

    Each of ``num_runs`` runs times one full sweep over ``loader``, including
    moving every batch to ``device``. It divides the number of samples seen
    (``g_batch.batch_size`` summed over batches) by the elapsed
    ``time.perf_counter`` time. On a CUDA device, pending GPU work is
    synchronized before starting and before stopping the clock.

    Args:
        model: called as ``model(g, u_p, inputs_f)``; its output is discarded.
        loader: iterable yielding ``(g_batch, u_p_batch, inputs_f_batch)``.
        device: target device (``torch.device`` or anything ``torch.device`` accepts).
        num_runs: number of timed sweeps over ``loader``.

    Returns:
        Mean throughput over runs with a measurable elapsed time, or ``0.0``
        if there were none.
    """
    device = torch.device(device)
    # torch.cuda.synchronize rejects non-CUDA devices, and run_benchmark may fall
    # back to CPU, so only insert the barrier for CUDA devices.
    needs_sync = device.type == 'cuda'
    total_throughputs = []

    with torch.no_grad():
        for run in range(num_runs):
            total_samples = 0

            if needs_sync:
                torch.cuda.synchronize(device)
            start_time = time.perf_counter()

            pbar = tqdm(loader, desc=f"Evaluating Speed (Run {run+1}/{num_runs})", leave=False)
            for g_batch, u_p_batch, inputs_f_batch in pbar:
                g_batch = g_batch.to(device)
                u_p_batch = u_p_batch.to(device)
                if hasattr(inputs_f_batch, 'to'):
                    inputs_f_batch = inputs_f_batch.to(device)

                _ = model(g_batch, u_p_batch, inputs_f_batch)

                # For a batched DGL graph, batch_size is the number of graphs in the batch.
                total_samples += g_batch.batch_size

            if needs_sync:
                torch.cuda.synchronize(device)
            elapsed_time = time.perf_counter() - start_time

            # Skip runs too fast to time meaningfully (avoids division by ~0).
            if elapsed_time > 1e-9:
                total_throughputs.append(total_samples / elapsed_time)

    if not total_throughputs:
        return 0.0
    return np.mean(total_throughputs)

def process_and_print_results(results, args):
    """Summarize per-model benchmark results, print them as a table, and return them.

    The model whose name is the file stem of ``args.model_paths[0]`` is the
    baseline. Every model's ``speedup`` is its throughput divided by the
    baseline throughput. If the baseline is missing, or its throughput is
    below 1e-9, the divisor is 1.0.

    Args:
        results: mapping of model name -> dict with ``'mean_error'`` and
            ``'throughput_sps'`` entries.
        args: namespace providing ``model_paths`` (only read when ``results``
            is non-empty).

    Returns:
        Mapping of model name -> ``{'mean_error', 'throughput_sps', 'speedup'}``
        with float values; ``{}`` when ``results`` is empty.
    """
    if not results:
        return {}

    baseline_name = os.path.splitext(os.path.basename(args.model_paths[0]))[0]

    summary = {
        name: {
            'mean_error': float(entry['mean_error']),
            'throughput_sps': float(entry['throughput_sps']),  # samples per second
        }
        for name, entry in results.items()
    }

    baseline = summary.get(baseline_name, {}).get('throughput_sps', 1.0)
    # Guard against dividing by a zero (or vanishing) baseline throughput.
    baseline = 1.0 if baseline < 1e-9 else baseline

    for entry in summary.values():
        entry['speedup'] = entry['throughput_sps'] / baseline

    header = " | ".join((
        f"{'Model':<60}",
        f"{'Mean Rel2 Error':<20}",
        f"{'Throughput (samples/s)':<25}",
        f"{'Speedup':<10}",
    ))
    width = len(header)
    heavy_rule = "=" * width

    print("\n" + heavy_rule)
    print(" " * (width // 2 - 14) + "SINGLE-STEP BENCHMARK SUMMARY")
    print(heavy_rule)
    print(header)
    print("-" * width)
    for name, entry in summary.items():
        print(f"{name:<60} | {entry['mean_error']:<20.4f} | {entry['throughput_sps']:<25.2f} | {entry['speedup']:<10.2f}x")
    print(heavy_rule)

    return summary

def _unique_model_name(stem, taken):
    """Return ``stem``, or ``stem_2``, ``stem_3``, ... if it is already a key/member of ``taken``."""
    name = stem
    suffix = 2
    while name in taken:
        name = f"{stem}_{suffix}"
        suffix += 1
    return name


def _load_checkpoint(model_path, device):
    """Load the checkpoint at ``model_path`` with tensors mapped onto ``device``.

    The checkpoint stores a non-tensor ``args`` object next to the weights.
    PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which
    refuses such objects, so full unpickling is requested explicitly whenever
    this PyTorch version supports the keyword.

    SECURITY: full unpickling can execute arbitrary code; only benchmark
    checkpoints from trusted sources.
    """
    import inspect  # local: only needed for this PyTorch-version capability check

    load_kwargs = {'map_location': device}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(model_path, **load_kwargs)


def run_benchmark(args):
    """Run the whole single-step benchmark described by ``args``.

    Steps:
      1. Seed the RNGs and pick the device (``cuda:<args.gpu>``, or CPU when
         CUDA is unavailable or ``args.no_cuda`` is set).
      2. Load the test split and build a non-shuffled loader.
      3. Load every checkpoint in ``args.model_paths``.
      4. Warm up, then measure accuracy and throughput per model.
      5. Print the summary table and, if ``args.save_path`` is set, write it as JSON.

    Args:
        args: namespace read for ``seed``, ``gpu``, optional ``no_cuda``,
            ``model_paths``, ``benchmark_batch_size``, ``speed_test_runs`` and
            ``save_path``, and forwarded to ``get_dataset``/``get_loss_func``.
            ``args.normalizer`` is overwritten with the test set's
            ``y_normalizer`` (moved to the device) or ``None``.

    Raises:
        Re-raises any exception hit while loading a checkpoint, after printing it.
    """
    get_seed(args.seed, printout=True)

    use_cuda = torch.cuda.is_available() and not getattr(args, 'no_cuda', False)
    device = torch.device(f'cuda:{args.gpu}' if use_cuda else "cpu")
    print(f"Using device: {device}")

    if not args.model_paths:
        print("错误: 未提供 --model_paths 参数。请指定要测试的模型文件路径。")
        return

    # --- 1. Load data ---
    print("\n--- 1. Loading Benchmark Data (from data_utils.py) ---")
    _, test_dataset = get_dataset(args)

    test_loader = MIODataLoader(
        test_dataset,
        batch_size=args.benchmark_batch_size,
        shuffle=False,
        drop_last=False,
    )
    print(f"Data loaded. Number of test samples: {len(test_dataset)}")

    y_normalizer = getattr(test_dataset, 'y_normalizer', None)
    args.normalizer = y_normalizer.to(device) if y_normalizer is not None else None
    metric_func = get_loss_func(name='rel2', args=args, normalizer=args.normalizer)

    # --- 2. Load models ---
    print("\n--- 2. Loading Models ---")
    models = {}
    for model_path in args.model_paths:
        stem = os.path.splitext(os.path.basename(model_path))[0]
        # Checkpoints sharing a file name (e.g. a/best.pt, b/best.pt) must not
        # overwrite each other; the first keeps its stem so it remains the baseline.
        model_name = _unique_model_name(stem, models)
        if model_name != stem:
            print(f"Warning: model name '{stem}' is already in use; "
                  f"reporting {model_path} as '{model_name}'.")
        try:
            print(f"Loading model '{model_name}' from {model_path}...")
            checkpoint = _load_checkpoint(model_path, device)
            model_args_from_ckpt = checkpoint['args']
            print(model_args_from_ckpt)
            model = get_model(model_args_from_ckpt).to(device)
            model.load_state_dict(checkpoint['model'])
            model.eval()
            models[model_name] = model
            print(f" -> Model loaded successfully. Params: {get_num_params(model)/1e6:.2f}M")
        except Exception as e:
            print(f"Error loading model from {model_path}: {e}")
            raise

    if not models:
        print("No models were loaded. Exiting.")
        return

    # --- 3. Warm up ---
    print("\n--- 3. Warming Up GPU ---")
    warmup(models, test_loader, device)

    # --- 4. Benchmark ---
    print("\n--- 4. Running Benchmark (in FP32 mode) ---")
    results = {}
    for model_name, model in models.items():
        print(f"\nBenchmarking model: {model_name}")
        mean_error = evaluate_accuracy(model, test_loader, metric_func, device)
        avg_throughput = evaluate_inference_speed(
            model, test_loader, device, args.speed_test_runs
        )
        results[model_name] = {
            'mean_error': mean_error,
            'throughput_sps': avg_throughput,
        }

    # --- 5. Report ---
    print("\n--- 5. Benchmark Results ---")
    processed_results = process_and_print_results(results, args)

    if args.save_path:
        save_dir = os.path.dirname(args.save_path)
        if save_dir:
            # exist_ok avoids the race between an exists() check and makedirs().
            os.makedirs(save_dir, exist_ok=True)
        with open(args.save_path, 'w', encoding='utf-8') as f:
            json.dump(processed_results, f, indent=4)
        print(f"\nDetailed results saved to {args.save_path}")

if __name__ == '__main__':
    args = get_args()

    # Benchmark-only options. Each is filled in only when args.py does not
    # already define an attribute of that name.
    benchmark_defaults = dict(
        model_paths=[],
        save_path=None,
        speed_test_runs=3,
        benchmark_batch_size=16,
    )
    missing_options = {
        option: fallback
        for option, fallback in benchmark_defaults.items()
        if not hasattr(args, option)
    }
    for option, fallback in missing_options.items():
        setattr(args, option, fallback)

    run_benchmark(args)