#!/usr/bin/env python
#-*- coding:utf-8 _*-

import torch
import torch.utils.data
import time
import argparse
import numpy as np
import os
import copy
from einops import rearrange, repeat
from torch.utils.data import DataLoader, TensorDataset

# --- 1. 导入所有必要的模块 ---
# 假设OFormer的模块和我们自己的模块都在可导入的路径下
from nn_module.encoder_module import SpatialTemporalEncoder2D
from nn_module.decoder_module import PointWiseDecoder2D, PointWiseDecoder2D_Adaptive
from loss_fn import rel_l2norm_loss
from utils import load_checkpoint, ensure_dir

# ==============================================================================
# 2. 命令行参数定义
# ==============================================================================
def get_benchmark_args():
    """Create the command-line parser for the OFormer vs. AdaptiveOFormer benchmark.

    The parser is returned unparsed so the caller decides when to parse.
    Architecture flags are specified manually so both models are built with the
    same shapes; they must match the checkpoints that will be loaded.

    Returns:
        argparse.ArgumentParser: parser with path, task, hardware/performance and
        architecture options.
    """
    parser = argparse.ArgumentParser(
        description='OFormer vs. AdaptiveOFormer Benchmark (Manual Config & Diagnostics)'
    )

    # Required file-system locations: the two checkpoints and the dataset.
    required_paths = [
        ('--path_base', '基准OFormer检查点的路径'),
        ('--path_optimized', '自适应OFormer检查点的路径'),
        ('--dataset_path', '数据集 .npy 文件路径'),
    ]
    for flag, help_text in required_paths:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    # Integer options, registered in display order.
    integer_options = [
        # task
        ('--in_seq_len', 10, '输入序列长度'),
        ('--out_seq_len', 40, '输出序列长度'),
        ('--test_seq_num', 200, '测试集中的序列数量'),
        ('--batch_size', 16, '批处理大小'),
        # hardware / performance
        ('--gpu', 0, 'GPU ID'),
        ('--warmup_runs', 10, '预热运行的rollout迭代次数'),
        # encoder architecture
        ('--in_channels', 12, '输入特征通道数 (例如, in_seq_len+2)'),
        ('--encoder_emb_dim', 64, 'Encoder嵌入维度'),
        ('--out_seq_emb_dim', 128, 'Encoder输出嵌入维度'),
        ('--encoder_depth', 5, 'Encoder深度'),
        ('--encoder_heads', 1, 'Encoder注意力头数'),
        # decoder architecture
        ('--decoder_emb_dim', 256, 'Decoder嵌入维度'),
        ('--out_channels', 1, '输出特征通道数'),
        ('--out_step', 1, 'Decoder单次rollout步长'),
        ('--propagator_depth', 4, 'Propagator深度或自适应递归深度'),
        ('--fourier_frequency', 8, '傅里叶特征频率'),
    ]
    for flag, default_value, help_text in integer_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)

    # Adaptive-decoder only: a list of keep ratios, one per recursion level.
    parser.add_argument(
        '--capacity_ratios', nargs='+', type=float, default=None,
        help='[自适应] 每层递归的保留比例列表',
    )

    return parser

# ==============================================================================
# 3. 辅助函数
# ==============================================================================

def build_model_from_args(opt, is_adaptive):
    """Build an (encoder, decoder) pair from the parsed command-line options.

    Args:
        opt: parsed argparse namespace (see ``get_benchmark_args``) providing the
            encoder/decoder architecture fields read below.
        is_adaptive: when True build ``PointWiseDecoder2D_Adaptive`` with every
            decoder option; otherwise build the baseline ``PointWiseDecoder2D``
            with only the options its constructor declares.

    Returns:
        tuple: ``(encoder, decoder)``, freshly constructed (not moved to any device,
        no weights loaded).
    """
    import inspect

    encoder = SpatialTemporalEncoder2D(
        opt.in_channels, opt.encoder_emb_dim, opt.out_seq_emb_dim,
        opt.encoder_heads, opt.encoder_depth
    )

    # Superset of options; the adaptive decoder accepts all of them.
    decoder_kwargs = {
        'latent_channels': opt.decoder_emb_dim, 'out_channels': opt.out_channels,
        'out_steps': opt.out_step, 'propagator_depth': opt.propagator_depth,
        'scale': opt.fourier_frequency, 'dropout': 0.0,
        'encoder_heads': opt.encoder_heads, 'capacity_ratios': opt.capacity_ratios,
        'out_seq_emb_dim': opt.out_seq_emb_dim
    }

    if is_adaptive:
        decoder = PointWiseDecoder2D_Adaptive(**decoder_kwargs)
    else:
        # Drop options the baseline decoder does not declare. Use the real call
        # signature rather than `__init__.__code__.co_varnames`: co_varnames also
        # lists the constructor's *local variables* (and a decorator wrapper's
        # names), so an unsupported option sharing a name with a local would slip
        # through and raise TypeError.
        signature = inspect.signature(PointWiseDecoder2D.__init__)
        keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        accepted_names = {
            name for name, param in signature.parameters.items()
            if param.kind in keyword_kinds
        }
        original_decoder_kwargs = {k: v for k, v in decoder_kwargs.items() if k in accepted_names}
        decoder = PointWiseDecoder2D(**original_decoder_kwargs)

    return encoder, decoder

def run_inference(encoder, decoder, data_batch, args, device):
    """Run a warmed-up, timed encoder + rollout pass and return timing and outputs.

    Args:
        encoder: module called as ``encoder(in_seq_with_pos, input_pos)``.
        decoder: module exposing ``rollout(z, prop_pos, out_seq_len, input_pos)``.
        data_batch: ``(in_seq, target)`` pair; only ``in_seq`` is used. ``in_seq`` is
            rearranged with ``'b t n -> b n t'``, so it must be 3-D with the last
            axis holding flattened spatial points that form a square grid.
        args: namespace providing ``warmup_runs`` and ``out_seq_len``.
        device: torch device (or device string). CUDA devices are timed with CUDA
            events; any other device falls back to ``time.perf_counter``.

    Returns:
        tuple: ``(elapsed_time_s, pred_out, diagnostics)``, where ``elapsed_time_s``
        is the duration in seconds of the single timed pass (warm-up excluded),
        ``pred_out`` is the rollout prediction, and ``diagnostics`` is the second
        element when ``rollout`` returns a tuple, otherwise ``None``.

    Raises:
        ValueError: if the number of spatial points is not a perfect square.
    """
    encoder.eval()
    decoder.eval()
    in_seq, _ = data_batch
    in_seq = in_seq.to(device)

    # The spatial points must tile a square grid; validate instead of silently
    # truncating sqrt(), which would build a grid of the wrong size.
    num_points = in_seq.shape[-1]
    grid_h = grid_w = int(round(np.sqrt(num_points)))
    if grid_h * grid_w != num_points:
        raise ValueError(
            f"run_inference expects a square spatial grid, got {num_points} points"
        )

    x0, y0 = np.meshgrid(np.linspace(0, 1, grid_h), np.linspace(0, 1, grid_w))
    grid_raw = torch.from_numpy(np.stack([x0, y0], axis=-1)).float()
    # Flatten the (h, w) grid into one point axis: (1, h*w, 2).
    grid = rearrange(grid_raw, 'h w c -> () (h w) c')

    # The same coordinates serve as input and propagation positions.
    input_pos = prop_pos = repeat(grid, '() n c -> b n c', b=in_seq.shape[0]).to(device)
    in_seq = rearrange(in_seq, 'b t n -> b n t')
    # Append the (x, y) coordinates to every point's time series: (b, n, t + 2).
    in_seq_with_pos = torch.cat((in_seq, input_pos), dim=-1)

    use_cuda_timing = torch.device(device).type == 'cuda'
    if use_cuda_timing:
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

    with torch.no_grad():
        print("    > 正在预热...")
        for _ in range(args.warmup_runs):
            z = encoder(in_seq_with_pos, input_pos)
            _ = decoder.rollout(z, prop_pos, args.out_seq_len, input_pos)
        if use_cuda_timing:
            # Make sure no warm-up kernels leak into the timed window.
            torch.cuda.synchronize(device)

        print("    > 正在计时...")
        if use_cuda_timing:
            start_event.record()
        else:
            start_time = time.perf_counter()

        z = encoder(in_seq_with_pos, input_pos)
        model_output = decoder.rollout(z, prop_pos, args.out_seq_len, input_pos)

        if use_cuda_timing:
            end_event.record()
            # elapsed_time() is only valid once both events have completed.
            torch.cuda.synchronize(device)
            elapsed_time_s = start_event.elapsed_time(end_event) / 1000.0
        else:
            elapsed_time_s = time.perf_counter() - start_time

    # The adaptive decoder may return (prediction, diagnostics); the baseline a bare tensor.
    if isinstance(model_output, tuple):
        pred_out, diagnostics = model_output
    else:
        pred_out, diagnostics = model_output, None

    return elapsed_time_s, pred_out, diagnostics

# ==============================================================================
# 4. 主执行逻辑
# ==============================================================================

if __name__ == "__main__":
    parser = get_benchmark_args()
    args = parser.parse_args()

    # Validate up front so misconfiguration surfaces as a readable CLI error
    # rather than a deep traceback. The device below is hard-coded to CUDA.
    if not torch.cuda.is_available():
        parser.error('CUDA is not available, but this benchmark runs on a CUDA device.')
    if not 0 <= args.gpu < torch.cuda.device_count():
        parser.error(f'--gpu {args.gpu} is out of range; '
                     f'{torch.cuda.device_count()} CUDA device(s) are visible.')
    # `data[..., -0:]` would silently select *every* sequence, so require a positive count.
    if args.test_seq_num <= 0:
        parser.error('--test_seq_num must be a positive integer.')

    device = torch.device(f'cuda:{args.gpu}')
    torch.cuda.set_device(device)

    # --- Load data ---
    print('--- 正在准备数据 ---')
    data = np.load(args.dataset_path)
    # The array is consumed as 't h w n' (time first, sequence index last); the
    # last `test_seq_num` sequences form the test split.
    required_steps = args.in_seq_len + args.out_seq_len
    if data.shape[0] < required_steps:
        parser.error(f'dataset has {data.shape[0]} time steps, but in_seq_len + out_seq_len '
                     f'= {required_steps} are required.')
    x_test = data[:args.in_seq_len, ..., -args.test_seq_num:]
    y_test = data[args.in_seq_len:args.in_seq_len+args.out_seq_len, ..., -args.test_seq_num:]
    x_test = rearrange(torch.from_numpy(x_test).float(), 't h w n -> n t (h w)')
    y_test = rearrange(torch.from_numpy(y_test).float(), 't h w n -> n t (h w)')

    test_dataloader = DataLoader(TensorDataset(x_test, y_test), batch_size=args.batch_size, shuffle=False)
    # Only the first batch is benchmarked.
    initial_data_batch = next(iter(test_dataloader))
    _, gt_out = initial_data_batch
    # NOTE(review): gt_out stays in (batch, time, points) layout, while run_inference
    # feeds the model (batch, points, time). Confirm the rollout output layout matches
    # gt_out before trusting the error numbers below.
    gt_out = gt_out.to(device)
    # The first batch can be smaller than --batch_size when fewer test sequences exist.
    actual_batch_size = gt_out.shape[0]
    print(f"已准备好一个批次的测试数据，批大小: {actual_batch_size}")
    print("-" * 50)

    # --- Load and benchmark the baseline model ---
    print("\n--- [1/2] 正在评估基准OFormer ---")
    ckpt_base = load_checkpoint(args.path_base, map_location=device)
    encoder_base, decoder_base = build_model_from_args(args, is_adaptive=False)
    encoder_base.load_state_dict(ckpt_base['encoder'])
    decoder_base.load_state_dict(ckpt_base['decoder'])
    encoder_base.to(device)
    decoder_base.to(device)

    base_time, base_pred, _ = run_inference(encoder_base, decoder_base, initial_data_batch, args, device)
    base_error = rel_l2norm_loss(base_pred, gt_out).item()
    print("基准模型评估完成。")
    print("-" * 50)

    # --- Load and benchmark the adaptive model ---
    print("\n--- [2/2] 正在评估自适应OFormer ---")
    ckpt_optimized = load_checkpoint(args.path_optimized, map_location=device)
    encoder_optimized, decoder_optimized = build_model_from_args(args, is_adaptive=True)
    encoder_optimized.load_state_dict(ckpt_optimized['encoder'])
    decoder_optimized.load_state_dict(ckpt_optimized['decoder'])
    encoder_optimized.to(device)
    decoder_optimized.to(device)

    optimized_time, optimized_pred, optimized_diagnostics = run_inference(encoder_optimized, decoder_optimized, initial_data_batch, args, device)
    optimized_error = rel_l2norm_loss(optimized_pred, gt_out).item()
    print("优化模型评估完成。")
    print("-" * 50)

    # --- Report results ---
    print("\n" + "=" * 60)
    print("                 OFormer模型性能与精度对比")
    print("=" * 60)
    print(f"批处理大小: {actual_batch_size}, 输入步长: {args.in_seq_len}, 预测步长: {args.out_seq_len}")
    print("-" * 60)

    print(f"基准模型 (OFormer):")
    print(f"  - 总耗时: {base_time:.4f} 秒")
    print(f"  - 相对L2误差: {base_error:.6f}")

    print(f"\n优化模型 (StructuredRecursiveOFormer):")
    print(f"  - 总耗时: {optimized_time:.4f} 秒")
    print(f"  - 相对L2误差: {optimized_error:.6f}")

    print("-" * 60)

    # Guard against division by zero in both ratios.
    if optimized_time > 0 and base_time > 0:
        speedup = base_time / optimized_time
        print(f"🚀 加速比 (基准 / 优化): {speedup:.2f}x")

    if base_error > 0:
        accuracy_change = ((optimized_error - base_error) / base_error) * 100
        print(f"📉 精度变化: {accuracy_change:+.2f}%")

    # Per-layer report, only when the adaptive decoder returned diagnostics.
    if optimized_diagnostics:
        print("-" * 60)
        print("📊 自适应Decoder逐层诊断报告 (单步推理):")

        active_tokens = optimized_diagnostics.get('active_tokens_per_layer', [])
        layer_times = optimized_diagnostics.get('time_per_layer_ms', [])
        total_block_time = sum(layer_times)

        print(f"  {'Layer':<10} | {'Active Tokens':<15} | {'Keep Ratio':<15} | {'Layer Time (ms)':<18} | {'Time %':<10}")
        print(f"  {'-'*8} | {'-'*13} | {'-'*13} | {'-'*16} | {'-'*8}")

        # Keep ratios are reported relative to the first layer's token count.
        initial_tokens = active_tokens[0] if active_tokens else 0

        if initial_tokens > 0:
            for i, (count, layer_time) in enumerate(zip(active_tokens, layer_times)):
                ratio = (count / initial_tokens) * 100
                time_percent = (layer_time / total_block_time) * 100 if total_block_time > 0 else 0
                print(f"  Block {i:<5} | {int(count):<15} | {ratio:<14.1f}% | {layer_time:<18.4f} | {time_percent:<9.1f}%")

            print(f"  {'-'*8} | {'-'*13} | {'-'*13} | {'-'*16} | {'-'*8}")
            print(f"  Total Propagator Time: {total_block_time:.4f} ms")

    print("=" * 60)