"""
简单推理脚本：用 Mass-Aware 模型（12层ESM-2）在 NovoBench 上做一次 De Novo 生成
----------------------------------------------------------------------

流程：
1. 从 jingbo02/NovoBench 取一条样本（默认 test split 第 0 条）
2. 用 ESM tokenizer 把目标序列编码（这里只用来确定 max_length，真正生成从全 mask 开始）
3. 根据 precursor_mz 和 charge 计算 precursor neutral mass
4. 构建 MassAwareDiffusionProteinLanguageModel（使用12层ESM-2架构）
5. 加载训练好的checkpoint
6. 从全 mask 序列开始，用 generate() + precursor_neutral_mass 生成一条肽段，并打印出来
此代码仅为简要示例，并非真实使用版本，完整版本将在后续提供
"""

# 自动设置 PYTHONPATH（如果未设置）
import os
import sys
from pathlib import Path

# Directory containing this script, and the `src/` tree expected next to it.
SCRIPT_DIR = Path(__file__).parent.absolute()
SRC_DIR = SCRIPT_DIR / "src"

# Make `src/` importable when it exists and is not already on sys.path.
if SRC_DIR.exists() and str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))
    # Mirror the change into the PYTHONPATH environment variable as well.
    # os.pathsep is the platform's search-path separator, and the separator is
    # omitted when PYTHONPATH was empty/unset so that no empty entry is
    # appended to the search path.
    _existing_pythonpath = os.environ.get("PYTHONPATH", "")
    os.environ["PYTHONPATH"] = (
        f"{SRC_DIR}{os.pathsep}{_existing_pythonpath}"
        if _existing_pythonpath
        else str(SRC_DIR)
    )

# Compatibility shim: make sure torch.library.register_fake is available
# before the libraries imported below are loaded.
import torch
if not hasattr(torch.library, 'register_fake'):
    try:
        # First attempt: pull an implementation from the internal module
        # (may be needed on some PyTorch builds).
        from torch._library.fake_impl import register_fake_impl
        torch.library.register_fake = register_fake_impl
    except (ImportError, AttributeError):
        # Second attempt: alias the earlier public name of this API.
        # NOTE(review): torch.library.impl_abstract is understood to be the
        # pre-2.4 name of register_fake with a compatible call signature --
        # confirm against the PyTorch release notes for the targeted version.
        _impl_abstract = getattr(torch.library, 'impl_abstract', None)
        if _impl_abstract is not None:
            torch.library.register_fake = _impl_abstract
        # If neither is available, carry on; the later import error is left
        # to surface naturally so that it can be debugged.

import argparse
import csv
import os

from datasets import load_dataset
from transformers import AutoTokenizer

# 只导入我们需要的类，尽量避免触发 byprot 的全局 import_modules
from byprot.models.dplm.dplm import DiffusionProteinLanguageModel
from byprot.models.dplm.dplm_mass_aware import (
    MassAwareDiffusionProteinLanguageModel,
    MassAwareDPLMConfig,
)


# Proton mass constant; main() subtracts it from precursor_mz before
# multiplying by the charge to obtain the neutral mass (printed in Da).
PROTON_MASS = 1.007276466812

# Root URL under which the NovoBench parquet files are fetched directly from
# the HuggingFace repository, bypassing the dataset's own builder schema.
NOVOBENCH_BASE_URL = "https://huggingface.co/datasets/jingbo02/NovoBench/resolve/main"


def load_novobench_row(split: str = "test", index: int = 0):
    """Return one NovoBench record read through the generic parquet builder.

    Only the ``nine_species`` subset is used. Its files live in the dataset
    repository as ``data/nine_species/{train,validation,test}.parquet``
    (see the dataset card: https://huggingface.co/datasets/jingbo02/NovoBench).

    Args:
        split: One of ``"train"``, ``"validation"`` or ``"test"``.
        index: Position of the record inside the chosen split.

    Returns:
        The record at ``index`` as returned by indexing the loaded dataset.

    Raises:
        ValueError: If ``split`` is not one of the supported names.
        IndexError: If ``index`` is not smaller than the number of records.
    """
    parquet_paths = {
        name: f"data/nine_species/{name}.parquet"
        for name in ("train", "validation", "test")
    }
    relative_path = parquet_paths.get(split)
    if relative_path is None:
        raise ValueError(f"Unsupported split: {split}")

    parquet_url = f"{NOVOBENCH_BASE_URL}/{relative_path}"
    # The single remote file is registered under the ad-hoc split name "data".
    dataset = load_dataset("parquet", data_files={"data": parquet_url})["data"]

    n_rows = len(dataset)
    if index >= n_rows:
        raise IndexError(f"{split} 只有 {n_rows} 条样本，无法索引 {index}")
    return dataset[index]


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the inference script.

    Options are declared in a table and registered in order, so the generated
    ``--help`` output lists them in the same sequence as the table.

    Returns:
        The namespace produced by ``argparse`` with attributes ``model_name``,
        ``split``, ``sample_index``, ``max_length``, ``max_iter``, ``device``,
        ``checkpoint`` and ``save_path``.
    """
    # Prefer the GPU by default whenever torch reports one as available.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    option_specs = (
        ("--model_name", {
            "type": str,
            "default": "facebook/esm2_t12_35M_UR50D",
            "help": "ESM-2 backbone 名称（12层ESM-2，默认使用35M参数版本）",
        }),
        ("--split", {
            "type": str,
            "default": "test",
            "choices": ["train", "validation", "test"],
            "help": "从 NovoBench 的哪个 split 取样本",
        }),
        ("--sample_index", {
            "type": int,
            "default": 0,
            "help": "在该 split 中取第几条谱做测试",
        }),
        ("--max_length", {
            "type": int,
            "default": 60,
            "help": "生成时使用的最大肽段长度（token 数）",
        }),
        ("--max_iter", {
            "type": int,
            "default": 50,
            "help": "反向扩散迭代步数（越大越接近原 DPLM 设定）",
        }),
        ("--device", {
            "type": str,
            "default": default_device,
            "help": "cuda 或 cpu",
        }),
        # No argparse-level "required": an empty value is rejected by the caller.
        ("--checkpoint", {
            "type": str,
            "default": "",
            "help": "必需：训练好的模型checkpoint路径（train_ms_denovo.py 保存的 .pt 文件）。",
        }),
        ("--save_path", {
            "type": str,
            "default": "./outputs_ms_denovo/inference.csv",
            "help": "推理结果保存路径（CSV），会自动追加写入",
        }),
    )

    cli = argparse.ArgumentParser(description="Run one-step MS De Novo inference")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


# Column names written as the first row of a new results CSV file.
_RESULT_CSV_HEADER = (
    "split",
    "sample_index",
    "precursor_mz",
    "charge",
    "neutral_mass",
    "reference_sequence",
    "generated_sequence",
)


def _load_checkpoint_weights(model, checkpoint_path: str, device) -> None:
    """Load the weights stored at ``checkpoint_path`` into ``model`` in place.

    Three checkpoint layouts are accepted: a dict with a ``'model'`` entry, a
    dict with a ``'state_dict'`` entry (a leading ``'model.'`` is stripped
    from its keys), or a bare state dict.

    Args:
        model: Object exposing ``load_state_dict(state_dict, strict=False)``.
        checkpoint_path: Path of an existing checkpoint file.
        device: Passed to ``torch.load`` as ``map_location``.
    """
    print(f"加载训练好的checkpoint: {checkpoint_path}")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=device)

    if 'model' in ckpt:
        state_dict = ckpt['model']
    elif 'state_dict' in ckpt:
        prefix = 'model.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in ckpt['state_dict'].items()
        }
    else:
        state_dict = ckpt

    # strict=False tolerates key mismatches, so report them instead of letting
    # a non-matching checkpoint pass silently with untrained weights.
    outcome = model.load_state_dict(state_dict, strict=False)
    missing = list(getattr(outcome, "missing_keys", ()))
    unexpected = list(getattr(outcome, "unexpected_keys", ()))
    if missing or unexpected:
        print(
            f"警告: checkpoint 与模型参数不完全匹配 "
            f"(missing={len(missing)}, unexpected={len(unexpected)})"
        )
    print("Checkpoint加载完成")


def _append_result_row(save_path: str, row) -> None:
    """Append ``row`` to the CSV at ``save_path``, creating it if necessary.

    The header row is written first when the file does not exist yet or is
    empty. The parent directory is created only when the path has one, since
    ``os.makedirs("")`` raises for a bare file name.
    """
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    needs_header = (
        not os.path.isfile(save_path) or os.path.getsize(save_path) == 0
    )
    with open(save_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if needs_header:
            writer.writerow(_RESULT_CSV_HEADER)
        writer.writerow(row)


def main() -> None:
    """Run one mass-constrained generation on a NovoBench sample and log it.

    Steps: parse the CLI options, validate the checkpoint path, fetch one
    dataset row, derive the neutral mass, build the model, load the
    checkpoint, call ``generate()`` on an all-mask input, print the decoded
    sequence and append a result row to the CSV given by ``--save_path``.

    Raises:
        ValueError: If ``--checkpoint`` is empty or the row has no sequence.
        FileNotFoundError: If the checkpoint path is not an existing file.
    """
    args = parse_args()
    device = torch.device(args.device)

    print(f"使用设备: {device}")

    # Fail fast: reject a missing checkpoint before downloading data or
    # building the model.
    if not args.checkpoint:
        raise ValueError("必须提供--checkpoint参数，指定训练好的模型checkpoint路径")
    if not os.path.isfile(args.checkpoint):
        raise FileNotFoundError(f"未找到checkpoint文件: {args.checkpoint}")

    # ------------------------------------------------------------------ #
    # 1. Fetch one real sample from NovoBench
    # ------------------------------------------------------------------ #
    print("从 jingbo02/NovoBench (nine_species) 读取一条真实样本...")
    row = load_novobench_row(split=args.split, index=args.sample_index)

    seq = row.get("modified_sequence") or row.get("sequence")
    if not seq:
        raise ValueError("样本中缺少 modified_sequence / sequence 字段")
    precursor_mz = float(row["precursor_mz"])

    # Charge lookup order: "charge", then "precursor_charge", then 1.
    # A key that is present but None is treated the same as a missing key.
    raw_charge = row.get("charge")
    if raw_charge is None:
        raw_charge = row.get("precursor_charge")
    if raw_charge is None:
        raw_charge = 1
    charge = max(int(raw_charge), 1)
    neutral_mass = (precursor_mz - PROTON_MASS) * charge

    print(f"目标原始序列: {seq}")
    print(
        f"precursor_mz = {precursor_mz:.4f}, charge = {charge}, "
        f"neutral_mass ≈ {neutral_mass:.4f} Da"
    )

    # ------------------------------------------------------------------ #
    # 2. Tokenizer, used here only to determine the generation length
    # ------------------------------------------------------------------ #
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    enc = tokenizer(
        seq,
        add_special_tokens=True,
        truncation=True,
        max_length=args.max_length,
        return_tensors="pt",
    )
    seq_len = enc["input_ids"].shape[1]
    max_length = min(args.max_length, seq_len)
    print(f"使用 max_length = {max_length}")

    # ------------------------------------------------------------------ #
    # 3. Build the Mass-Aware model around an ESM-2 network
    # ------------------------------------------------------------------ #
    model_cfg = MassAwareDPLMConfig()
    # Override the key hyper-parameters.
    model_cfg.num_diffusion_timesteps = 500
    model_cfg.mass_tolerance_ppm = 50.0
    model_cfg.min_aa_mass = 57.0
    model_cfg.max_aa_mass = 200.0
    model_cfg.enable_mass_constraint = True

    # The network is created from the ESM-2 config alone (no DPLM checkpoint).
    from byprot.models.dplm.modules.dplm_modeling_esm import EsmForDPLM
    from transformers import AutoConfig

    print(f"构建12层ESM-2模型: {args.model_name}")
    esm_config = AutoConfig.from_pretrained(args.model_name)
    esm_net = EsmForDPLM(esm_config, dropout=0.1)

    # Wrap the network in the DPLM model, then hand that network to the
    # Mass-Aware model.
    base = DiffusionProteinLanguageModel(model_cfg, net=esm_net)
    model = MassAwareDiffusionProteinLanguageModel(model_cfg, net=base.net).to(device)

    _load_checkpoint_weights(model, args.checkpoint, device)
    model.eval()

    # ------------------------------------------------------------------ #
    # 4. Build the all-mask initial input and run generate()
    # ------------------------------------------------------------------ #
    batch_size = 1
    input_tokens = torch.full(
        (batch_size, max_length), model.mask_id, dtype=torch.long, device=device
    )
    precursor_neutral_mass = torch.tensor(
        [neutral_mass], dtype=torch.float32, device=device
    )

    print("开始生成（带质量硬约束）...")
    with torch.no_grad():
        gen_tokens = model.generate(
            input_tokens=input_tokens,
            tokenizer=tokenizer,
            max_iter=args.max_iter,
            temperature=1.0,
            sampling_strategy="gumbel_argmax",
            disable_resample=False,
            resample_ratio=0.25,
            precursor_neutral_mass=precursor_neutral_mass,
        )

    gen_seq = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0]
    print(f"生成序列: {gen_seq}")

    # ------------------------------------------------------------------ #
    # 5. Append the result to the CSV file
    # ------------------------------------------------------------------ #
    _append_result_row(
        args.save_path,
        [
            args.split,
            args.sample_index,
            f"{precursor_mz:.6f}",
            charge,
            f"{neutral_mass:.6f}",
            seq,
            gen_seq,
        ],
    )
    print(f"结果已追加保存到: {args.save_path}")


if __name__ == "__main__":
    main()


