import torch
import numpy as np
import time
from argparse import ArgumentParser
from torch.utils.data import DataLoader

# 导入所有需要的模型组件，与 main.py 保持一致
from models import (
    IPOTBasicPreprocessor,
    IPOTEncoder,
    IPOTProcessor,
    IPOTDecoder,
    EncoderProcessorDecoder
)
from models.ipot.ipot_processor_adapt import IPOTProcessorAdapt

# 导入所有需要的数据集类，与 main.py 保持一致
from pde_datasets.datasets import *

def load_test_data(args):
    """Build the test DataLoader for ``args.data_name`` and fill in model channel counts.

    Per the original author's note, this mirrors the data-loading branches of main.py.
    Besides returning the loader, it mutates ``args`` in place: it sets dataset path
    attributes for the chosen dataset and always sets ``input_channel``,
    ``pos_channel`` and ``output_channel`` (the navierstokes branch also overrides
    ``T_start``/``T_in``/``T_out``).

    Args:
        args: argparse namespace; reads ``data_name``, ``batch_size`` and the
            dataset-specific sizing options (``nx``, ``sub``, ``n_train``, ...).

    Returns:
        ``(test_loader, args)`` where ``test_loader`` iterates the test split in
        order (``shuffle=False``) and ``args`` is the same, updated namespace.

    Raises:
        ValueError: if ``args.data_name`` has no loading branch here.
    """
    root = './data/'
    name = args.data_name
    print(f"\n正在为 '{args.data_name}' 加载测试数据...")

    if name == 'burgers':
        args.datapath = root + 'burgers_data_R10.mat'
        dataset = Burgers(args.datapath, nx=args.nx, sub=args.sub, n_test=args.n_test)
        channels = (1, 1, 1)
    elif name == 'darcyflow':
        args.datapath_test = root + 'piececonst_r421_N1024_smooth2.mat'
        # Simplified version of main.py's subsampling rule (per the original comment).
        # NOTE(review): with the CLI defaults (nx=8192, sub=8) this evaluates to 0;
        # confirm DarcyFlow accepts that, or pass --nx/--sub suited to the 421 grid.
        darcy_sub = int(421 / args.nx * args.sub) if args.nx != 0 else 1
        dataset = DarcyFlow(datapath=args.datapath_test, nx=421, sub=darcy_sub, num=args.n_test)
        channels = (1, 2, 1)
    elif name == 'navierstokes_Ve-3':
        args.datapath = root + 'ns_V1e-3_N5000_T50.mat'
        # Fixed time-window settings; these override any CLI --T_in/--T_out values.
        args.T_start = 0
        args.T_in = 10
        args.T_out = 10
        dataset = NavierStokes(
            datapath=args.datapath,
            nx=args.nx,
            sub=args.sub,
            T_start=args.T_start,
            T_in=args.T_in,
            T_out=args.T_out,
            n_test=args.n_test,
            is_train=False,
        )
        channels = (10, 2, 1)
    elif name == 'airfoil':
        args.input1_path = root + 'NACA_Cylinder_X.npy'
        args.input2_path = root + 'NACA_Cylinder_Y.npy'
        args.output_path = root + 'NACA_Cylinder_Q.npy'
        dataset = Airfoil(
            input1_path=args.input1_path,
            input2_path=args.input2_path,
            output_path=args.output_path,
            n_train=args.n_train,
            n_test=args.n_test,
        )
        channels = (2, 2, 1)
    elif name == 'pipe':
        # Unlike the other branches, pipe reads its location from --data_path.
        dataset = PipeDataset(
            data_path=args.data_path,
            n_total=args.n_total,
            n_train=args.n_train,
            downsample_x=args.downsamplex,
            downsample_y=args.downsampley,
            is_train=False,
        )
        channels = (4, 2, 1)
    else:
        raise ValueError(f"Data loading for '{args.data_name}' is not implemented.")

    args.input_channel, args.pos_channel, args.output_channel = channels

    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
    print(f"成功加载测试数据集，包含 {len(dataset)} 个样本。")
    return loader, args


def measure_inference_time(model, data_loader, device, args, num_runs=100, num_warmup=20):
    """Measure the average forward-pass latency of ``model`` on one test batch.

    Only the first batch of ``data_loader`` is used; it is reused for every run.
    The call convention follows the original note about ``helpers.py``:
    time-dependent datasets are called as ``model(x, mesh, args.T_out)``,
    all others as ``model(x, mesh)``.

    Args:
        model: torch module to benchmark; it is moved to ``device`` and set to eval mode.
        data_loader: loader whose batches unpack as ``(x, y)``; ``y`` is ignored.
            Its ``dataset`` must expose a ``mesh`` tensor (used as the decoder query).
        device: ``torch.device`` or device string. CUDA synchronization is done only
            for CUDA devices, and on that exact device (not the current default one).
        args: namespace providing ``data_name`` and ``T_out``.
        num_runs: number of timed forward passes.
        num_warmup: number of untimed forward passes executed before timing.

    Returns:
        ``(avg_ms, std_ms, throughput)``, where throughput is samples per second
        based on the batch size of ``x``. Returns ``(0, 0, 0)`` when the loader is
        empty or ``num_runs < 1`` (the caller treats 0 as "benchmark failed").
    """
    if num_runs < 1:
        print("错误: num_runs 必须为正整数，无法执行基准测试。")
        return 0, 0, 0

    dev = device if isinstance(device, torch.device) else torch.device(device)
    use_cuda = dev.type == 'cuda'

    model.to(dev)
    model.eval()

    try:
        x, _ = next(iter(data_loader))  # y is the target, not a model input
    except StopIteration:
        print("错误: DataLoader 为空，无法执行基准测试。")
        return 0, 0, 0
    x = x.to(dev)
    # The dataset's mesh is the query input for the decoder.
    mesh = data_loader.dataset.mesh.to(dev)

    # Mirrors main.py's branching on time-dependent datasets.
    time_dependent = args.data_name in (
        'navierstokes_Ve-3', 'navierstokes_Ve-4', 'navierstokes_Ve-5',
        'shallowwater', 'era5_temperature',
    )
    extra_args = (args.T_out,) if time_dependent else ()

    def _sync():
        # CUDA kernels are asynchronous; wait on the device actually used so the
        # wall-clock interval covers the full forward pass. No-op on CPU.
        if use_cuda:
            torch.cuda.synchronize(dev)

    timings = []
    with torch.no_grad():
        print(f"  正在进行 {num_warmup} 次预热运行...")
        for _ in range(num_warmup):
            model(x, mesh, *extra_args)

        print(f"  正在进行 {num_runs} 次计时运行...")
        for _ in range(num_runs):
            _sync()
            start_time = time.perf_counter()
            model(x, mesh, *extra_args)
            _sync()
            timings.append((time.perf_counter() - start_time) * 1000)

    avg_time_ms = float(np.mean(timings))
    std_dev_ms = float(np.std(timings))
    batch_size = x.size(0)
    throughput = batch_size / (avg_time_ms / 1000) if avg_time_ms > 0 else float('inf')
    return avg_time_ms, std_dev_ms, throughput


def build_model_from_main(args, processor_type):
    """Assemble an IPOT ``EncoderProcessorDecoder`` following main.py's construction.

    All configuration checks run before any sub-module is instantiated, so an invalid
    CLI combination fails fast. No checkpoint is loaded: the returned model has
    whatever initialization the project modules apply.

    Args:
        args: namespace of architecture hyper-parameters (``num_bands``,
            ``max_resolution``, ``num_latents``, ``latent_channel``, head counts/sizes,
            ``ff_mult``, scales, ``position_encoding_type``, ``reduction_schedule``) plus
            ``input_channel``/``pos_channel``/``output_channel`` as filled in by the
            data-loading step.
        processor_type: ``'standard'`` builds ``IPOTProcessor``; ``'adaptive'`` builds
            ``IPOTProcessorAdapt`` with ``args.reduction_schedule``.

    Returns:
        The assembled ``EncoderProcessorDecoder``.

    Raises:
        NotImplementedError: if ``args.model_type`` is not ``'ipot'`` or
            ``args.position_encoding_type`` is not ``'pos2fourier'``.
        ValueError: if ``processor_type`` is unknown, or (adaptive only) the length of
            ``args.reduction_schedule`` differs from ``args.self_per_cross_attn``.
    """
    if args.model_type != "ipot":
        raise NotImplementedError("Only 'ipot' model_type is supported.")

    position_encoding_type = args.position_encoding_type
    if position_encoding_type != 'pos2fourier':
        raise NotImplementedError(
            f"Position encoding type '{position_encoding_type}' not supported in benchmark script.")

    if processor_type not in ('standard', 'adaptive'):
        raise ValueError(f"Unknown processor type: {processor_type}")

    # Explicit raise instead of `assert`, so the check is not stripped by `python -O`.
    if processor_type == 'adaptive' and len(args.reduction_schedule) != args.self_per_cross_attn:
        raise ValueError(
            f"Reduction schedule length ({len(args.reduction_schedule)}) must match "
            f"the number of processor layers ({args.self_per_cross_attn}).")

    input_channel = args.input_channel
    pos_channel = args.pos_channel
    num_bands = args.num_bands
    max_resolution = args.max_resolution
    num_latents = args.num_latents
    latent_channel = args.latent_channel
    self_per_cross_attn = args.self_per_cross_attn
    cross_heads_num = args.cross_heads_num
    self_heads_num = args.self_heads_num
    cross_heads_channel = args.cross_heads_channel
    self_heads_channel = args.self_heads_channel
    ff_mult = args.ff_mult
    latent_init_scale = args.latent_init_scale
    output_scale = args.output_scale
    output_channel = args.output_channel

    # Width of the pos2fourier features; used for both the encoder input
    # (appended to the raw input channels) and the decoder query.
    fourier_channel = 2 * sum(num_bands) + len(num_bands)

    # Preprocessor (a fresh kwargs dict per module, as in the original).
    ipot_input_preprocessor = IPOTBasicPreprocessor(
        position_encoding_type=position_encoding_type,
        in_channel=input_channel,
        pos_channel=pos_channel,
        pos2fourier_position_encoding_kwargs=dict(
            num_bands=num_bands,
            max_resolution=max_resolution,
        ),
    )

    # Encoder
    ipot_encoder = IPOTEncoder(
        input_channel=input_channel + fourier_channel,
        num_latents=num_latents,
        latent_channel=latent_channel,
        cross_heads_num=cross_heads_num,
        cross_heads_channel=cross_heads_channel,
        latent_init_scale=latent_init_scale,
    )

    # Processor (processor_type already validated above)
    if processor_type == 'standard':
        print("正在构建 Standard IPOT Processor...")
        ipot_processor = IPOTProcessor(
            self_per_cross_attn=self_per_cross_attn,
            latent_channel=latent_channel,
            self_heads_num=self_heads_num,
            self_heads_channel=self_heads_channel,
            ff_mult=ff_mult,
        )
    else:
        print("正在构建 Adaptive IPOT Processor...")
        ipot_processor = IPOTProcessorAdapt(
            self_per_cross_attn=self_per_cross_attn,
            latent_channel=latent_channel,
            self_heads_num=self_heads_num,
            self_heads_channel=self_heads_channel,
            ff_mult=ff_mult,
            reduction_schedule=args.reduction_schedule,
        )

    # Decoder
    ipot_decoder = IPOTDecoder(
        output_channel=output_channel,
        query_channel=fourier_channel,
        latent_channel=latent_channel,
        cross_heads_num=cross_heads_num,
        cross_heads_channel=cross_heads_channel,
        ff_mult=ff_mult,
        output_scale=output_scale,
        position_encoding_type=position_encoding_type,
        pos2fourier_position_encoding_kwargs=dict(
            num_bands=num_bands,
            max_resolution=max_resolution,
        ),
    )

    model = EncoderProcessorDecoder(
        encoder=ipot_encoder,
        processor=ipot_processor,
        decoder=ipot_decoder,
        input_preprocessor=ipot_input_preprocessor,
    )

    print("模型构建成功。")
    return model


def main(args):
    """Benchmark the standard vs. adaptive IPOT processor and print a comparison.

    Loads the test split via ``load_test_data`` (which also fills in channel counts on
    ``args``). It then builds both model variants with ``build_model_from_main`` and
    times each with ``measure_inference_time``, honouring both ``--num_runs`` and
    ``--num_warmup``. Finally it prints a summary table of latency, latency std dev,
    throughput and speed-up.
    """
    device = torch.device(f"cuda:{args.gpu_num}" if torch.cuda.is_available() else "cpu")
    print(f"基准测试将在设备: {device} 上运行")

    test_loader, args = load_test_data(args)

    # --- Baseline model (standard processor) ---
    print("\n--- 正在评测基准模型 (Standard Processor) ---")
    baseline_model = build_model_from_main(args, 'standard')
    base_avg, base_std, base_tput = measure_inference_time(
        baseline_model, test_loader, device, args,
        num_runs=args.num_runs, num_warmup=args.num_warmup)

    # --- Adaptive model (adaptive processor) ---
    print("\n--- 正在评测自适应模型 (Adaptive Processor) ---")
    adaptive_model = build_model_from_main(args, 'adaptive')
    adapt_avg, adapt_std, adapt_tput = measure_inference_time(
        adaptive_model, test_loader, device, args,
        num_runs=args.num_runs, num_warmup=args.num_warmup)

    # --- Final report; a 0 average signals that a measurement could not run ---
    if base_avg > 0 and adapt_avg > 0:
        ratio = base_avg / adapt_avg
        improvement = (ratio - 1) * 100
        print("\n" + "="*50)
        print(" " * 15 + "基准测试总结")
        print("="*50)
        print(f"{'指标':<25} {'基准模型':<12} {'自适应模型':<12}")
        print("-"*50)
        print(f"{'平均推理时间 (ms)':<25} {base_avg:<12.2f} {adapt_avg:<12.2f}")
        print(f"{'推理时间标准差 (ms)':<25} {base_std:<12.2f} {adapt_std:<12.2f}")
        print(f"{'吞吐量 (样本/秒)':<25} {base_tput:<12.2f} {adapt_tput:<12.2f}")
        print("-"*50)
        print(f"加速比: {ratio:.2f}x")
        print(f"性能提升: {improvement:.2f}%")
        print("="*50)
    else:
        print("\n评测未能成功完成，无法生成总结报告。")


if __name__ == '__main__':
    cli = ArgumentParser(description='IPOT 模型架构推理速度基准测试')
    add = cli.add_argument

    # Dataset selection and sizing (only architecture + data options are needed).
    add('--data_name', type=str, default='burgers', help='数据集名称')
    add('--n_train', type=int, default=1000)
    add('--n_test', type=int, default=100)
    add('--n_total', type=int, default=1100)
    add('--nx', type=int, default=8192)
    add('--sub', type=int, default=8)
    add('--data_path', type=str, default='./data/', help='存放数据集文件的根目录路径')
    add('--downsamplex', type=int, default=4, help='[Pipe] x方向的下采样因子')
    add('--downsampley', type=int, default=4, help='[Pipe] y方向的下采样因子')

    # Model architecture (kept in sync with main.py, per the original note).
    add('--model_type', type=str, default='ipot')
    add('--num_bands', type=int, default=[64], nargs='+')
    add('--max_resolution', type=int, default=[64], nargs='+')
    add('--num_latents', type=int, default=128)
    add('--latent_channel', type=int, default=64)
    add('--self_per_cross_attn', type=int, default=4)
    add('--cross_heads_num', type=int, default=1)
    add('--self_heads_num', type=int, default=4)
    add('--cross_heads_channel', type=int, default=64)
    add('--self_heads_channel', type=int, default=None)
    add('--ff_mult', type=int, default=2)
    add('--latent_init_scale', type=float, default=0.02)
    add('--output_scale', type=float, default=0.1)
    add('--position_encoding_type', type=str, default="pos2fourier")
    add('--reduction_schedule', type=float, nargs='+', default=[1.0, 0.9, 0.8, 0.7], help='自适应处理器的削减策略。')

    # Time-series options.
    add('--T_in', type=int, default=10, help='输入时间步长')
    add('--T_out', type=int, default=10, help='输出时间步长 (用于模型调用)')

    # Benchmark control.
    add('--gpu_num', type=str, default="0", help='使用的GPU ID。')
    add('--batch_size', type=int, default=1, help='推理时的批次大小。')
    add('--num_runs', type=int, default=100, help='用于平均的计时运行次数。')
    add('--num_warmup', type=int, default=20, help='预热运行次数。')

    main(cli.parse_args())