import torch
import torch.nn as nn
import numpy as np
import argparse
from tqdm import tqdm
import time
import os

# --- Import model definitions, the loss function, and data-loading utilities ---
from nn_module.encoder_module import SpatialEncoder2D
from nn_module.decoder_module import PointWiseDecoder2DSimple, AdaptivePointWiseDecoder2D_SteadyState
from loss_fn import pointwise_rel_l2norm_loss
from torch.utils.data import DataLoader, TensorDataset

# --- Fix every RNG seed so runs are repeatable (this matters little for pure inference) ---
_SEED = 1
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(_SEED)

def get_arguments(parser):
    """Register this inference script's command-line options on ``parser``.

    The dataset options and the model-architecture options must be given the
    same values that were used when the evaluated checkpoint was trained.

    Args:
        parser: an ``argparse.ArgumentParser`` to populate.

    Returns:
        The same ``parser`` object, with all options added.
    """
    # Core options: which checkpoint to evaluate and which model family it is.
    core_options = [
        ('--ckpt_path', {'type': str, 'required': True,
                         'help': 'Path to the model checkpoint (.ckpt) file.'}),
        ('--adaptive', {'action': 'store_true',
                        'help': 'Specify if the checkpoint is for an ADAPTIVE model.'}),
    ]
    # Dataset options (must match training).
    dataset_options = [
        ('--data_path', {'type': str, 'required': True,
                         'help': 'Base path to Pipe .npy files'}),
        ('--ntest', {'type': int, 'default': 200,
                     'help': 'Number of testing samples'}),
        ('--downsample_res', {'type': int, 'default': 8,
                              'help': 'Downsampling ratio for the grid'}),
        ('--batch_size', {'type': int, 'default': 16,
                          'help': 'Batch size for inference'}),
    ]
    # Model-architecture options (must match the loaded model's training run).
    architecture_options = [
        ('--propagator_depth', {'type': int, 'default': 4,
                                'help': 'Depth of the refinement/propagator network'}),
        ('--capacity_ratios', {'type': float, 'nargs': '+', 'default': None,
                               'help': 'Capacity ratios for the adaptive model'}),
        ('--final_keep_ratio', {'type': float, 'default': 0.25,
                                'help': 'Final keep ratio for default capacity decay'}),
    ]

    for flag, spec in core_options + dataset_options + architecture_options:
        parser.add_argument(flag, **spec)
    return parser

def build_model(args, res: int):
    """Rebuild the encoder/decoder architecture the checkpoint was trained with.

    The layer configuration mirrors train_pipe.py and should stay in sync with
    it so that the checkpoint's state dicts load into these modules.

    Args:
        args: parsed options; reads ``adaptive`` and ``propagator_depth``, and
            for the adaptive model also ``capacity_ratios`` and
            ``final_keep_ratio``.
        res: resolution value passed to the encoder and to the dense
            (non-adaptive) decoder.

    Returns:
        An ``(encoder, decoder)`` tuple of freshly initialised modules.
    """
    encoder = SpatialEncoder2D(2, 96, 256, 4, 6, res=res, use_ln=True)

    if not args.adaptive:
        # Dense model: the propagator depth is used as the refinement depth.
        return encoder, PointWiseDecoder2DSimple(
            latent_channels=256, out_channels=1,
            refinement_depth=args.propagator_depth,
            res=res, scale=0.5, heads=4, dim_head=32)

    ratios = args.capacity_ratios
    if ratios is None:
        # Default schedule: ``propagator_depth`` evenly spaced ratios going
        # from 1.0 down to ``final_keep_ratio``.
        ratios = np.linspace(1.0, args.final_keep_ratio, args.propagator_depth).tolist()

    return encoder, AdaptivePointWiseDecoder2D_SteadyState(
        latent_channels=256, out_channels=1,
        propagator_depth=args.propagator_depth,
        capacity_ratios=ratios,
        scale=0.5, heads=4, dim_head=32)

def load_pipe_data(args, ntrain: int = 1000):
    """Load the Pipe test split plus normalisation statistics (logic as in train_pipe.py).

    The first ``ntrain`` samples form the training split; here they are used
    only to compute normalisation statistics. The last ``args.ntest`` samples
    form the test split.

    Args:
        args: namespace providing ``data_path``, ``ntest``, ``downsample_res``
            and ``batch_size``.
        ntrain: number of leading samples that make up the training split
            (default 1000, the value this function has always used).

    Returns:
        ``(test_dataloader, y_mean, y_std, s)``:
        - ``test_dataloader`` yields ``(x, y)`` batches in file order; ``x``
          holds coordinates normalised with training statistics and ``y`` the
          raw (un-normalised) targets, each sample flattened to ``s * s``
          points when the stored grid is 129x129;
        - ``y_mean`` / ``y_std``: per-channel target statistics of the training
          split, for de-normalising model predictions;
        - ``s``: number of points per side after downsampling.

    Raises:
        ValueError: if ``ntest``, ``ntrain`` or ``downsample_res`` is not
            positive, or if the test split would overlap the training split.
    """
    input_x_path = os.path.join(args.data_path, 'Pipe_X.npy')
    input_y_path = os.path.join(args.data_path, 'Pipe_Y.npy')
    output_q_path = os.path.join(args.data_path, 'Pipe_Q.npy')

    ntest = args.ntest
    r = args.downsample_res
    # Validate before loading: ``t[-0:]`` would silently select the WHOLE
    # dataset for ntest=0, and a non-positive stride makes the slicing invalid.
    if ntest <= 0:
        raise ValueError(f"ntest must be a positive integer, got {ntest}")
    if ntrain <= 0:
        raise ValueError(f"ntrain must be a positive integer, got {ntrain}")
    if r <= 0:
        raise ValueError(f"downsample_res must be a positive integer, got {r}")

    # Points kept per side when taking every r-th point of a 129-point axis.
    s = (129 - 1) // r + 1

    input_x = torch.from_numpy(np.load(input_x_path)).float()
    input_y = torch.from_numpy(np.load(input_y_path)).float()
    # Stack the two coordinate arrays on a new trailing axis of size 2.
    coords = torch.stack([input_x, input_y], dim=-1)
    # Keep index 0 along axis 1 of Pipe_Q and add a trailing channel axis.
    # NOTE(review): presumably axis 1 indexes physical fields and 0 is the
    # target field — confirm against the dataset description.
    output = torch.from_numpy(np.load(output_q_path)[:, 0]).float().unsqueeze(-1)

    num_samples = coords.shape[0]
    if ntrain + ntest > num_samples:
        # Otherwise the "test" samples would include samples from the training
        # split, and the reported accuracy would be misleading.
        raise ValueError(
            f"ntrain ({ntrain}) + ntest ({ntest}) exceeds the {num_samples} "
            "available samples; the test split would overlap the training split"
        )

    def _grid_to_points(t, channels):
        # Take every r-th point per side, crop to s x s, flatten grid -> points.
        return t[:, ::r, ::r][:, :s, :s].reshape(t.shape[0], -1, channels)

    x_train = _grid_to_points(coords[:ntrain], 2)
    y_train = _grid_to_points(output[:ntrain], 1)
    x_test = _grid_to_points(coords[-ntest:], 2)
    y_test = _grid_to_points(output[-ntest:], 1)

    # Per-channel statistics from the training split only, reduced over
    # samples and points.
    x_mean, x_std = x_train.mean(dim=[0, 1]), x_train.std(dim=[0, 1])
    y_mean, y_std = y_train.mean(dim=[0, 1]), y_train.std(dim=[0, 1])

    # Only inputs are normalised; targets are returned raw so the caller can
    # compare de-normalised predictions against them.
    x_test = (x_test - x_mean) / x_std

    test_dataloader = DataLoader(TensorDataset(x_test, y_test), batch_size=args.batch_size, shuffle=False)

    return test_dataloader, y_mean, y_std, s

# --- Main program ---
def _forward(encoder, decoder, x):
    """Run one end-to-end encoder -> decoder pass on a batch of query points.

    The same points serve as the encoder input positions and as the decoder's
    propagation/query positions. Returns the first element of the decoder's
    output pair (a prediction in normalised target units; ``main``
    de-normalises it with ``y_mean`` / ``y_std``).
    """
    input_pos = prop_pos = x
    z = encoder(input_pos, input_pos)
    x_out, _ = decoder(z, prop_pos, input_pos)
    return x_out


def _timed_forward(encoder, decoder, x, cuda_events=None):
    """Run ``_forward`` on ``x`` and return ``(prediction, elapsed_ms)``.

    Args:
        cuda_events: optional ``(start, end)`` pair of timing-enabled
            ``torch.cuda.Event`` objects. When given, the pass is timed on the
            GPU; otherwise ``time.perf_counter`` brackets the pass on the host.
    """
    if cuda_events is not None:
        starter, ender = cuda_events
        starter.record()
        x_out = _forward(encoder, decoder, x)
        ender.record()
        # Kernel launches are asynchronous: wait for all queued GPU work to
        # finish before reading the elapsed time between the two events.
        torch.cuda.synchronize()
        return x_out, starter.elapsed_time(ender)  # milliseconds

    # CPU fallback: the clock must start BEFORE the forward pass, otherwise the
    # measurement is ~0 ms and the reported throughput is meaningless.
    start_time = time.perf_counter()
    x_out = _forward(encoder, decoder, x)
    return x_out, (time.perf_counter() - start_time) * 1000  # milliseconds


def main():
    """Evaluate a Pipe checkpoint: report the mean test loss and inference timing."""
    parser = get_arguments(argparse.ArgumentParser(
        description="End-to-end inference test for Pipe dataset models"))
    opt = parser.parse_args()

    # Fail fast on a bad checkpoint path before spending time loading the dataset.
    if not os.path.exists(opt.ckpt_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {opt.ckpt_path}")

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    # 1. Load the test data and the target normalisation statistics.
    test_dataloader, y_mean, y_std, res = load_pipe_data(opt)
    y_mean, y_std = y_mean.to(device), y_std.to(device)
    num_batches = len(test_dataloader)
    print(f"Test data loaded. Number of batches: {num_batches}")
    if num_batches == 0:
        # Every average below divides by the number of batches.
        raise SystemExit("Test set is empty; nothing to evaluate.")

    # 2. Build the model architecture and load the trained weights.
    encoder, decoder = build_model(opt, res=res)

    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code; only load checkpoints from trusted sources. weights_only=True would
    # be safer if the checkpoint holds only tensors/containers — confirm first.
    ckpt = torch.load(opt.ckpt_path, map_location=device)
    encoder.load_state_dict(ckpt['encoder'])
    decoder.load_state_dict(ckpt['decoder'])
    encoder.to(device)
    decoder.to(device)
    print(f"Models loaded successfully from '{opt.ckpt_path}'")

    encoder.eval()
    decoder.eval()

    # 3. Warm-up: a few untimed passes on the first batch, so one-time costs
    # (kernel loading, allocator growth) do not pollute the measurements.
    print("Warming up GPU...")
    with torch.no_grad():
        x_warmup, _ = next(iter(test_dataloader))
        x_warmup = x_warmup.to(device)
        for _ in range(5):
            _forward(encoder, decoder, x_warmup)

    # 4. Timed inference and accuracy evaluation.
    print("Starting inference test...")
    cuda_events = None
    if use_cuda:
        cuda_events = (torch.cuda.Event(enable_timing=True),
                       torch.cuda.Event(enable_timing=True))

    total_loss = 0.0
    total_time_ms = 0.0
    with torch.no_grad():
        for x, y in tqdm(test_dataloader, desc="Inference"):
            # Host-to-device copies are deliberately outside the timed region.
            x, y = x.to(device), y.to(device)

            x_out, batch_time = _timed_forward(encoder, decoder, x, cuda_events)
            total_time_ms += batch_time

            x_out = x_out * y_std + y_mean  # de-normalise predictions
            loss = pointwise_rel_l2norm_loss(x_out, y)
            # NOTE(review): this is an average of per-batch losses, so a smaller
            # final batch weighs as much as a full one. If the loss is a
            # per-sample mean, weighting by x.size(0) would give the exact
            # per-sample average — confirm against train_pipe.py's evaluation.
            total_loss += loss.item()

    # 5. Aggregate and report.
    avg_loss = total_loss / num_batches
    avg_time_per_batch = total_time_ms / num_batches
    total_samples = len(test_dataloader.dataset)
    # Guard against a zero timer reading rather than dividing by zero.
    throughput = total_samples / (total_time_ms / 1000) if total_time_ms > 0 else float('inf')  # samples / sec

    print("\n" + "="*50)
    print("           Inference Test Results")
    print("="*50)
    print(f"  Model Type:          {'ADAPTIVE' if opt.adaptive else 'DEEP (DENSE)'}")
    print(f"  Checkpoint:          {opt.ckpt_path}")
    print("-" * 50)
    print(f"  Accuracy (Avg. Loss): {avg_loss:.6f}")
    print("-" * 50)
    print(f"  Avg. Time per Batch: {avg_time_per_batch:.2f} ms")
    print(f"  Throughput:          {throughput:.2f} samples/sec")
    print("="*50)


if __name__ == '__main__':
    main()