import torch
import torch.utils.data
import time
import argparse
import numpy as np
import os
from einops import rearrange, repeat
from torch.utils.data import DataLoader, TensorDataset
from scipy.io import loadmat
from tqdm import tqdm

# --- 1. 导入所有必要的模块 ---
# 导入1D版本的Encoder和两个版本的Decoder
from nn_module.encoder_module import Encoder1D
from nn_module.decoder_module import PointWiseDecoder1D, AdaptivePointWiseDecoder1D
from loss_fn import rel_loss
from utils import load_checkpoint

# ==============================================================================
# 2. 命令行参数定义
# ==============================================================================
def get_benchmark_args(parser):
    """
    Register every command-line option of the Burgers benchmark on ``parser``.

    The options are added in place and the same parser object is returned,
    so the call can be chained.
    """
    # Required string options: the two checkpoints being compared and the dataset.
    required_paths = (
        ('--path_base', '基准OFormer (PointWiseDecoder1D) 检查点的路径'),
        ('--path_optimized', '自适应OFormer (AdaptivePointWiseDecoder1D) 检查点的路径'),
        ('--dataset_path', '数据集 .mat 文件路径'),
    )
    for flag, description in required_paths:
        parser.add_argument(flag, type=str, required=True, help=description)

    # Optional integer options, registered in this order:
    #   task settings (kept in line with train_burgers.py),
    #   hardware / timing settings,
    #   model architecture (must match the values used when the models were trained).
    integer_options = (
        ('--train_seq_num', 1000, '用于计算统计量的训练集样本数'),
        ('--test_seq_num', 100, '测试集中的序列数量'),
        ('--resolution', 2048, '空间分辨率'),
        ('--batch_size', 16, '批处理大小'),
        ('--gpu', 0, 'GPU ID'),
        ('--warmup_runs', 10, '预热运行的次数'),
        ('--encoder_emb_dim', 96, 'Encoder/Decoder 嵌入维度'),
        ('--encoder_depth', 4, 'Encoder 深度'),
        ('--propagator_depth', 3, 'Propagator 深度 (对应原始模型的decoding_depth)'),
    )
    for flag, default_value, description in integer_options:
        parser.add_argument(flag, type=int, default=default_value, help=description)

    # Adaptive decoder only: one or more keep ratios, always required on the command line.
    parser.add_argument(
        '--capacity_ratios',
        nargs='+',
        type=float,
        required=True,
        help='[自适应] Decoder每层递归的保留比例列表',
    )

    return parser

# ==============================================================================
# 3. 辅助函数
# ==============================================================================
def build_model_from_args(opt, is_adaptive):
    """
    Build the (encoder, decoder) pair for the Burgers task from a config object.

    Args:
        opt: Object exposing ``encoder_emb_dim``, ``encoder_depth``, ``resolution``,
            ``propagator_depth`` and, for the adaptive decoder, ``capacity_ratios``.
        is_adaptive: When true an ``AdaptivePointWiseDecoder1D`` is built,
            otherwise a ``PointWiseDecoder1D``.

    Returns:
        Tuple ``(encoder, decoder)``.
    """
    emb_dim = opt.encoder_emb_dim
    grid_res = opt.resolution

    encoder = Encoder1D(
        2,  # input channels: u(x,0) plus the x coordinate
        in_emb_dim=emb_dim,
        out_seq_emb_dim=emb_dim,  # for Burgers the input and output widths are equal
        depth=opt.encoder_depth,
        res=grid_res,
    )

    # Keyword arguments both decoder variants have in common.
    shared_kwargs = dict(latent_channels=emb_dim, out_channels=1, res=grid_res)

    if not is_adaptive:
        print("    > 正在构建: 原始解码器 (PointWiseDecoder1D)")
        baseline_decoder = PointWiseDecoder1D(
            decoding_depth=opt.propagator_depth,
            **shared_kwargs,
        )
        return encoder, baseline_decoder

    print("    > 正在构建: 自适应解码器 (AdaptivePointWiseDecoder1D)")
    adaptive_decoder = AdaptivePointWiseDecoder1D(
        propagator_depth=opt.propagator_depth,
        capacity_ratios=opt.capacity_ratios,
        **shared_kwargs,
    )
    return encoder, adaptive_decoder

def run_inference(encoder, decoder, dataloader, args, device, norm_stats):
    """
    Run warm-up passes followed by one timed evaluation pass over ``dataloader``.

    Timing uses CUDA events when ``device`` is a CUDA device and
    ``time.perf_counter`` otherwise, so the function also works on the CPU
    fallback. The timed region is the whole evaluation loop, which includes the
    host/device copies of every batch.

    Args:
        encoder: Called as ``encoder(x_with_pos, input_pos)``; must provide ``eval()``.
        decoder: Called as ``decoder(z, input_pos, input_pos)``; must provide
            ``eval()``. It may return either the prediction or a
            ``(prediction, diagnostics)`` tuple.
        dataloader: Iterable of ``(x, y)`` batches; must yield at least one batch.
        args: Namespace providing ``resolution`` and ``warmup_runs``.
        device: Device (or device string) on which inference runs.
        norm_stats: Dict holding tensors under ``x_mean``, ``x_std``, ``y_mean``
            and ``y_std``.

    Returns:
        Tuple ``(elapsed_seconds, predictions, ground_truths, diagnostics)``.
        Predictions are de-normalised with ``y_std``/``y_mean`` and, like the
        ground truths, concatenated along dim 0 on the CPU. ``diagnostics`` is
        the first non-``None`` diagnostics object returned by the decoder, or
        ``None`` if there was none.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    device = torch.device(device)
    # CUDA events and torch.cuda.synchronize raise on a CPU-only run.
    on_cuda = device.type == 'cuda'

    encoder.eval()
    decoder.eval()

    x_mean, x_std = norm_stats['x_mean'].to(device), norm_stats['x_std'].to(device)
    y_mean, y_std = norm_stats['y_mean'].to(device), norm_stats['y_std'].to(device)

    # Uniform coordinates on [0, 1], shaped (1, resolution, 1) so they can be tiled per batch.
    gridx = torch.tensor(np.linspace(0, 1, args.resolution), dtype=torch.float32).reshape(1, args.resolution, 1).to(device)

    def forward(x):
        # Shared by warm-up and the timed loop so both exercise the exact same path.
        x_norm = (x - x_mean) / x_std
        input_pos = gridx.repeat(x.shape[0], 1, 1)
        x_with_pos = torch.cat((x_norm, input_pos), dim=-1)
        z = encoder(x_with_pos, input_pos)
        return decoder(z, input_pos, input_pos)

    print("    > 正在预热...")
    warmup_batch = next(iter(dataloader), None)
    if warmup_batch is None:
        raise ValueError("run_inference requires a dataloader that yields at least one batch")
    x_warm, _ = warmup_batch
    x_warm = x_warm.to(device)  # loop-invariant: copy the warm-up batch once
    with torch.no_grad():
        for _ in range(args.warmup_runs):
            forward(x_warm)
    if on_cuda:
        # Make sure no warm-up kernel is still running when the clock starts.
        torch.cuda.synchronize(device)

    print("    > 正在计时并评估...")
    all_preds, all_gts = [], []
    diagnostics_report = None

    if on_cuda:
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
    else:
        start_time = time.perf_counter()

    with torch.no_grad():
        for x, y in tqdm(dataloader, desc="  Inferencing"):
            x, y = x.to(device), y.to(device)

            model_output = forward(x)
            if isinstance(model_output, tuple):
                pred_out_norm, diagnostics = model_output
                # Only the first reported diagnostics object is kept.
                if diagnostics_report is None:
                    diagnostics_report = diagnostics
            else:
                pred_out_norm = model_output

            pred_out_unnorm = pred_out_norm * y_std + y_mean

            all_preds.append(pred_out_unnorm.cpu())
            all_gts.append(y.cpu())

    if on_cuda:
        end_event.record()
        torch.cuda.synchronize(device)
        # elapsed_time reports milliseconds.
        elapsed_time_s = start_event.elapsed_time(end_event) / 1000.0
    else:
        elapsed_time_s = time.perf_counter() - start_time

    return elapsed_time_s, torch.cat(all_preds, dim=0), torch.cat(all_gts, dim=0), diagnostics_report

# ==============================================================================
# 4. 主执行逻辑
# ==============================================================================
def _evaluate_checkpoint(checkpoint_path, is_adaptive, args, device, dataloader, norm_stats):
    """
    Load one checkpoint, rebuild the matching model and evaluate it on ``dataloader``.

    The model and its predictions are local to this call, so they are no longer
    referenced once it returns and the next model is loaded.

    Args:
        checkpoint_path: Path handed to ``load_checkpoint``; the loaded object is
            indexed with ``'encoder'`` and ``'decoder'`` to obtain the state dicts.
        is_adaptive: Forwarded to ``build_model_from_args`` to pick the decoder type.
        args: Parsed benchmark arguments.
        device: Device the model is moved to and evaluated on.
        dataloader: Test batches forwarded to ``run_inference``.
        norm_stats: Normalisation statistics forwarded to ``run_inference``.

    Returns:
        Tuple ``(elapsed_seconds, relative_l2_error, diagnostics)``.
    """
    checkpoint = load_checkpoint(checkpoint_path, map_location=device)
    encoder, decoder = build_model_from_args(args, is_adaptive=is_adaptive)
    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder'])
    encoder.to(device)
    decoder.to(device)

    elapsed, pred, gt, diagnostics = run_inference(encoder, decoder, dataloader, args, device, norm_stats)
    error = rel_loss(pred, gt, p=2).item()
    return elapsed, error, diagnostics


def _print_layer_diagnostics(diagnostics):
    """Print the per-layer token/time table from the adaptive decoder's diagnostics dict."""
    print("-" * 60)
    print("📊 自适应Decoder逐层诊断报告:")
    active_tokens = diagnostics.get('active_tokens_per_layer', [])
    layer_times = diagnostics.get('time_per_layer_ms', [])
    total_block_time = sum(layer_times)
    if not total_block_time > 0:
        return

    separator = f"  {'-'*8} | {'-'*13} | {'-'*13} | {'-'*16} | {'-'*8}"
    print(f"  {'Layer':<10} | {'Active Tokens':<15} | {'Keep Ratio':<15} | {'Layer Time (ms)':<18} | {'Time %':<10}")
    print(separator)
    # Keep ratios are reported relative to the first layer's token count.
    initial_tokens = active_tokens[0] if active_tokens else 0
    if initial_tokens > 0:
        for i, (count, layer_time) in enumerate(zip(active_tokens, layer_times)):
            ratio = (count / initial_tokens) * 100
            time_percent = (layer_time / total_block_time) * 100
            print(f"  Block {i:<5} | {int(count):<15} | {ratio:<14.1f}% | {layer_time:<18.4f} | {time_percent:<9.1f}%")
        print(separator)
        print(f"  Total Propagator Time: {total_block_time:.4f} ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="OFormer vs. AdaptiveOFormer Benchmark for 1D Burgers Equation")
    parser = get_benchmark_args(parser)
    args = parser.parse_args()

    # The subsampling stride below is derived from 2**13 grid points
    # (presumably the native grid size of the .mat data -- confirm against the dataset).
    full_resolution = 2 ** 13
    if not 1 <= args.resolution <= full_resolution:
        parser.error(f"--resolution must be between 1 and {full_resolution}, got {args.resolution}")
    if args.test_seq_num < 1:
        parser.error(f"--test_seq_num must be at least 1, got {args.test_seq_num}")

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"将在设备 {device} 上进行测试。")
    if torch.cuda.is_available():
        torch.cuda.set_device(device)

    print('--- 正在准备数据和归一化统计量 ---')
    data = loadmat(args.dataset_path)
    sub = full_resolution // args.resolution
    x_data = data['a'][:, ::sub]
    y_data = data['u'][:, ::sub]

    # Fail with a readable message instead of a reshape error further down.
    for key, arr in (('a', x_data), ('u', y_data)):
        if arr.shape[0] < args.test_seq_num or arr.shape[1] != args.resolution:
            parser.error(
                f"dataset field '{key}' has shape {arr.shape} after subsampling by {sub}; "
                f"need at least {args.test_seq_num} rows and exactly {args.resolution} columns"
            )

    # Only the last test_seq_num samples are evaluated. No training tensors are
    # built because the normalisation statistics below are fixed constants.
    x_test = torch.from_numpy(x_data[-args.test_seq_num:, :].reshape(args.test_seq_num, args.resolution, 1)).float()
    y_test = torch.from_numpy(y_data[-args.test_seq_num:, :].reshape(args.test_seq_num, args.resolution, 1)).float()

    # Identity normalisation (mean 0, std 1): inputs and outputs are used as-is.
    norm_stats = {
        'x_mean': torch.tensor(0.0, dtype=torch.float32),
        'x_std': torch.tensor(1.0, dtype=torch.float32),
        'y_mean': torch.tensor(0.0, dtype=torch.float32),
        'y_std': torch.tensor(1.0, dtype=torch.float32),
    }
    print("统计量设置完毕 (假设不归一化)。")

    test_dataloader = DataLoader(TensorDataset(x_test, y_test), batch_size=args.batch_size, shuffle=False)

    # --- Baseline model ---
    print("\n--- [1/2] 正在评估基准OFormer (Burgers) ---")
    base_time, base_error, _ = _evaluate_checkpoint(
        args.path_base, False, args, device, test_dataloader, norm_stats
    )
    print("基准模型评估完成。")

    # --- Adaptive model ---
    print("\n--- [2/2] 正在评估自适应OFormer (Burgers) ---")
    optimized_time, optimized_error, optimized_diagnostics = _evaluate_checkpoint(
        args.path_optimized, True, args, device, test_dataloader, norm_stats
    )
    print("优化模型评估完成。")

    # --- Report ---
    print("\n" + "=" * 60)
    print("      OFormer (Burgers) 模型性能与精度对比")
    print("=" * 60)
    print(f"分辨率: {args.resolution}, 批处理大小: {args.batch_size}")
    print("-" * 60)
    print("基准模型 (PointWiseDecoder1D):")
    print(f"  - 总耗时: {base_time:.4f} 秒")
    print(f"  - 相对L2误差: {base_error:.6f}")
    print("\n优化模型 (AdaptivePointWiseDecoder1D):")
    print(f"  - 总耗时: {optimized_time:.4f} 秒")
    print(f"  - 相对L2误差: {optimized_error:.6f}")
    print("-" * 60)
    # Ratios are only printed when their denominator is meaningful.
    if optimized_time > 0 and base_time > 0:
        speedup = base_time / optimized_time
        print(f"🚀 加速比 (基准 / 优化): {speedup:.2f}x")
    if base_error > 0:
        accuracy_change = ((optimized_error - base_error) / base_error) * 100
        print(f"📉 精度变化: {accuracy_change:+.2f}%")
    if optimized_diagnostics:
        _print_layer_diagnostics(optimized_diagnostics)
    print("=" * 60)