import torch
import torch.nn as nn
import numpy as np
import argparse
from tqdm import tqdm
import time
import os

# --- 导入模型定义、损失函数和数据加载逻辑 ---
from nn_module.encoder_module import SpatialEncoder2D
from nn_module.decoder_module import PointWiseDecoder2DSimple, AdaptivePointWiseDecoder2D_SteadyState
from loss_fn import pointwise_rel_l2norm_loss
from torch.utils.data import DataLoader, TensorDataset

# --- Fix the RNG seeds for reproducibility (has little effect during pure inference) ---
_SEED = 1
# Order matters only for readability: NumPy first, then the torch CPU and CUDA generators.
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(_SEED)
del _seed_fn

def get_arguments(parser):
    """Register the command-line options used by this script on ``parser``.

    The dataset and architecture options are documented (in their help texts)
    as having to match the values used when the checkpoint was trained.

    Args:
        parser: An ``argparse.ArgumentParser``; arguments are added in place.

    Returns:
        The same ``parser`` object, so calls can be chained.
    """
    # --- Core options ---
    parser.add_argument('--ckpt_path', type=str, required=True, help='Path to the model checkpoint (.ckpt) file.')
    parser.add_argument('--adaptive', action='store_true', help='Specify if the checkpoint is for an ADAPTIVE model.')

    # --- Dataset options (must match training) ---
    parser.add_argument('--data_path', type=str, required=True, help='Base path to NACA_Cylinder .npy files')
    # The test split starts right after the first `ntrain` samples, and those
    # samples also provide the normalization statistics. Default keeps the
    # previously hard-coded value.
    parser.add_argument('--ntrain', type=int, default=1000,
                        help='Number of training samples preceding the test split; used for normalization '
                             'statistics and to locate the test set (must match training)')
    parser.add_argument('--ntest', type=int, default=200, help='Number of testing samples')
    parser.add_argument('--downsample_x', type=int, default=10, help='Downsampling ratio for the x-axis (must match training)')
    parser.add_argument('--downsample_y', type=int, default=2, help='Downsampling ratio for the y-axis (must match training)')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')

    # --- Model architecture options (must match the loaded model's training configuration) ---
    parser.add_argument('--propagator_depth', type=int, default=4, help='Depth of the refinement/propagator network')
    parser.add_argument('--capacity_ratios', type=float, nargs='+', default=None, help='Capacity ratios for the adaptive model (must match training)')
    parser.add_argument('--final_keep_ratio', type=float, default=0.25, help='Final keep ratio for default capacity decay (must match training)')

    return parser

def build_model(args, res_for_model: int):
    """Construct the (encoder, decoder) pair for the requested model variant.

    The hyper-parameters are intended to mirror train_naca.py.

    Args:
        args: Parsed CLI namespace; reads ``adaptive``, ``propagator_depth``,
            ``capacity_ratios`` and ``final_keep_ratio``.
        res_for_model: Resolution passed as ``res`` to the encoder and to the
            non-adaptive decoder.

    Returns:
        Tuple ``(encoder, decoder)`` of freshly initialised modules.
    """
    encoder = SpatialEncoder2D(2, 96, 256, 4, 6, res=res_for_model, use_ln=True)

    # Settings shared by both decoder variants.
    shared_kwargs = dict(latent_channels=256, out_channels=1, scale=0.5, heads=4, dim_head=32)

    if not args.adaptive:
        dense_decoder = PointWiseDecoder2DSimple(
            refinement_depth=args.propagator_depth,
            res=res_for_model,
            **shared_kwargs)
        return encoder, dense_decoder

    # Fall back to a linear decay 1.0 -> final_keep_ratio when no explicit
    # ratios were supplied on the command line (described in the original
    # code as the same default logic used at training time).
    ratios = args.capacity_ratios
    if ratios is None:
        ratios = np.linspace(1.0, args.final_keep_ratio, args.propagator_depth).tolist()

    adaptive_decoder = AdaptivePointWiseDecoder2D_SteadyState(
        propagator_depth=args.propagator_depth,
        capacity_ratios=ratios,
        **shared_kwargs)
    return encoder, adaptive_decoder

def load_naca_data(args):
    """Load and preprocess the NACA_Cylinder test split.

    The preprocessing is intended to match train_naca.py:
    - The first ``ntrain`` samples are used only to compute normalization
      statistics.
    - The following ``ntest`` samples form the test set.
    - Each sample's grid is strided by ``(downsample_x, downsample_y)`` and
      flattened to a point list.

    Args:
        args: Namespace with ``data_path``, ``ntest``, ``downsample_x``,
            ``downsample_y`` and ``batch_size``. ``ntrain`` is optional and
            defaults to 1000 (the previously hard-coded value).

    Returns:
        Tuple ``(test_dataloader, y_mean, y_std, res_for_model)``:
        - ``test_dataloader`` yields ``(x_norm, y)``. ``x_norm`` has shape
          ``(B, s1*s2, 2)`` and is normalized with training statistics.
          ``y`` has shape ``(B, s1*s2, 1)`` and is left un-normalized.
        - ``y_mean`` / ``y_std`` are training-target statistics of shape ``(1,)``.
        - ``res_for_model`` is the strided x-resolution ``s1``.

    Raises:
        ValueError: If ``ntest < 1``, ``ntrain < 1``, or the data files hold
            fewer than ``ntrain + ntest`` samples.
    """
    input_x_path = os.path.join(args.data_path, 'NACA_Cylinder_X.npy')
    input_y_path = os.path.join(args.data_path, 'NACA_Cylinder_Y.npy')
    output_q_path = os.path.join(args.data_path, 'NACA_Cylinder_Q.npy')

    ntrain = getattr(args, 'ntrain', 1000)  # offset of the test split
    ntest = args.ntest
    if ntrain < 1:
        raise ValueError(f"ntrain must be >= 1, got {ntrain}")
    if ntest < 1:
        raise ValueError(f"ntest must be >= 1, got {ntest}")

    r1, r2 = args.downsample_x, args.downsample_y
    s1 = int(((221 - 1) / r1) + 1)  # strided x-axis resolution
    s2 = int(((51 - 1) / r2) + 1)  # strided y-axis resolution

    input_x = torch.from_numpy(np.load(input_x_path)).float()
    input_y = torch.from_numpy(np.load(input_y_path)).float()
    coords = torch.stack([input_x, input_y], dim=-1)
    # Index 4 along axis 1 of Q (the "5th channel"); trailing singleton channel dim added.
    output = torch.from_numpy(np.load(output_q_path)[:, 4]).float().unsqueeze(-1)

    n_available = coords.shape[0]
    if ntrain + ntest > n_available:
        raise ValueError(
            f"Requested ntrain + ntest = {ntrain + ntest} samples, "
            f"but the data files contain only {n_available}.")

    def _grid_slice(t, start, stop):
        # Take samples [start, stop), stride the grid, flatten to (n, points, channels).
        return t[start:stop, ::r1, ::r2][:, :s1, :s2].reshape(stop - start, -1, t.shape[-1])

    # Normalization statistics come from the training portion only.
    x_train = _grid_slice(coords, 0, ntrain)
    y_train = _grid_slice(output, 0, ntrain)

    x_test = _grid_slice(coords, ntrain, ntrain + ntest)
    y_test = _grid_slice(output, ntrain, ntrain + ntest)

    x_mean, x_std = x_train.mean(dim=[0, 1]), x_train.std(dim=[0, 1])
    y_mean, y_std = y_train.mean(dim=[0, 1]), y_train.std(dim=[0, 1])

    x_test_norm = (x_test - x_mean) / x_std

    test_dataloader = DataLoader(TensorDataset(x_test_norm, y_test), batch_size=args.batch_size, shuffle=False)

    # The grid is not square; the x-resolution is used as SpatialEncoder2D's `res`.
    res_for_model = s1

    return test_dataloader, y_mean, y_std, res_for_model

# --- 主程序 ---
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="End-to-end inference test for NACA_Cylinder models")
    parser = get_arguments(parser)  # registers all CLI options on `parser`
    opt = parser.parse_args()

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    # 1. Load data
    test_dataloader, y_mean, y_std, res_for_model = load_naca_data(opt)
    y_mean, y_std = y_mean.to(device), y_std.to(device)
    num_batches = len(test_dataloader)
    if num_batches == 0:
        # Guards the averages/throughput below against division by zero.
        raise SystemExit("No test batches to evaluate; check --ntest.")
    print(f"Test data loaded. Number of batches: {num_batches}. Model resolution: {res_for_model}")

    # 2. Build the model architecture and load the weights
    if not os.path.exists(opt.ckpt_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {opt.ckpt_path}")

    encoder, decoder = build_model(opt, res_for_model)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(opt.ckpt_path, map_location=device)
    encoder.load_state_dict(ckpt['encoder'])
    decoder.load_state_dict(ckpt['decoder'])
    encoder.to(device)
    decoder.to(device)
    print(f"Models loaded successfully from '{opt.ckpt_path}'")

    encoder.eval()
    decoder.eval()

    # 3. Run inference and measure accuracy / latency
    total_loss = 0.0
    total_time_ms = 0.0

    # --- Warm-up: a few untimed passes on the first batch ---
    print("Warming up GPU...")
    with torch.no_grad():
        warmup_batch = next(iter(test_dataloader), None)
        if warmup_batch is not None:
            x_warmup = warmup_batch[0].to(device)
            for _ in range(5):
                z = encoder(x_warmup, x_warmup)
                _, _ = decoder(z, x_warmup, x_warmup)
    if use_cuda:
        torch.cuda.synchronize()  # make sure warm-up work is finished before timing

    print("Starting inference test...")
    if use_cuda:
        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)

    with torch.no_grad():
        for x, y in tqdm(test_dataloader, desc="Inference"):
            x, y = x.to(device), y.to(device)

            # Start the clock immediately before the forward pass on both
            # devices. Previously, the CPU path timed a second redundant call.
            if use_cuda:
                starter.record()
            else:
                start_time = time.perf_counter()

            input_pos = prop_pos = x
            z = encoder(input_pos, input_pos)
            x_out, _ = decoder(z, prop_pos, input_pos)

            if use_cuda:
                ender.record()
                torch.cuda.synchronize()  # elapsed_time needs both events completed
                batch_time = starter.elapsed_time(ender)
            else:
                batch_time = (time.perf_counter() - start_time) * 1000

            total_time_ms += batch_time

            # Undo the target normalization before comparing with raw targets.
            x_out = x_out * y_std + y_mean
            loss = pointwise_rel_l2norm_loss(x_out, y)
            total_loss += loss.item()

    # 4. Compute and print final results
    # NOTE(review): this is the unweighted mean of per-batch losses. If
    # pointwise_rel_l2norm_loss averages over the batch, a smaller final batch
    # is over-weighted. Confirm this matches how the metric is reported elsewhere.
    avg_loss = total_loss / num_batches
    avg_time_per_batch = total_time_ms / num_batches
    total_samples = len(test_dataloader.dataset)
    throughput = total_samples / (total_time_ms / 1000) if total_time_ms > 0 else float('inf')

    print("\n" + "="*50)
    print("      NACA_Cylinder Inference Test Results")
    print("="*50)
    print(f"  Model Type:          {'ADAPTIVE' if opt.adaptive else 'DEEP (DENSE)'}")
    print(f"  Checkpoint:          {opt.ckpt_path}")
    print("-" * 50)
    print(f"  Accuracy (Avg. Loss): {avg_loss:.6f}")
    print("-" * 50)
    print(f"  Avg. Time per Batch: {avg_time_per_batch:.2f} ms")
    print(f"  Throughput:          {throughput:.2f} samples/sec")
    print("="*50)