"""
训练 Mass-Aware DPLM 做质谱 De Novo
===================================

组件：
- Backbone: `DiffusionProteinLanguageModel`（DPLM）
- Mass-aware wrapper: `MassAwareDiffusionProteinLanguageModel`
- 数据：HuggingFace `jingbo02/NovoBench`

特点：
- 训练阶段：与原始 DPLM 相同的掩码扩散损失（不在 loss 中硬性加入质量约束）。
- 推理阶段：通过 `precursor_neutral_mass` 在生成时施加质量硬约束。
注意：此代码仅为简要示例，并非正式使用版本；完整版本将在后续提供。
"""

# 自动设置 PYTHONPATH（如果未设置）
import os
import sys
from pathlib import Path

# Absolute directory containing this script.
SCRIPT_DIR = Path(__file__).parent.absolute()
# Local source tree next to the script; presumably where the `byprot`
# package imported below lives when it is not installed — confirm.
SRC_DIR = SCRIPT_DIR / "src"


def prepend_to_pythonpath(entry, current=""):
    """Return a PYTHONPATH value with *entry* placed first.

    Entries are joined with ``os.pathsep`` (``:`` on POSIX, ``;`` on Windows),
    and no separator is added when *current* is empty: a trailing empty entry
    is commonly interpreted as the current directory, which would silently
    put the working directory on child processes' import path.

    Args:
        entry: Path string to put at the front.
        current: Existing PYTHONPATH value ("" when unset).

    Returns:
        The combined PYTHONPATH string.
    """
    if not current:
        return entry
    return f"{entry}{os.pathsep}{current}"


# If src/ exists and is not already on the import path, prepend it both for
# this process (sys.path) and for child processes that inherit the
# environment (PYTHONPATH).
if SRC_DIR.exists() and str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))
    os.environ["PYTHONPATH"] = prepend_to_pythonpath(
        str(SRC_DIR), os.environ.get("PYTHONPATH", "")
    )

# 兼容性修复：必须在导入任何可能间接导入 torchvision 的库之前执行
# torchvision 在导入时会使用 torch.library.register_fake，需要提前确保其可用
import torch

# Compatibility shim for torch.library.register_fake. Per the note above, this
# must run before importing anything that may indirectly import torchvision,
# which uses torch.library.register_fake at import time.
# The original intent is that touching register_fake "initializes" it so the
# later torchvision import does not fail.
# NOTE(review): the first branch is a plain attribute read whose result is
# discarded; it is not evident from here that it initializes anything.
if hasattr(torch.library, 'register_fake'):
    # Already available: read the attribute once (result discarded; this binds
    # the module-level name `_`).
    _ = torch.library.register_fake
else:
    # Not available on this torch version: try to alias it from the private
    # torch._library.fake_impl module.
    # NOTE(review): the existence of `register_fake_impl` in that private
    # module is not verified here — confirm against the installed torch.
    try:
        from torch._library.fake_impl import register_fake_impl
        torch.library.register_fake = register_fake_impl
    except (ImportError, AttributeError):
        # Second attempt: import the private module itself and alias the
        # function only if the attribute is present.
        # NOTE(review): this largely repeats the attempt above. If the module
        # imports but lacks `register_fake_impl`, nothing is patched AND no
        # warning is emitted.
        try:
            import torch._library.fake_impl as fake_impl
            if hasattr(fake_impl, 'register_fake_impl'):
                torch.library.register_fake = fake_impl.register_fake_impl
        except (ImportError, AttributeError):
            # Give up: warn (do not raise) and continue without the patch.
            import warnings
            warnings.warn(
                "torch.library.register_fake 不可用，某些功能可能无法正常工作。"
                "建议升级 PyTorch 或检查版本兼容性。"
            )

import argparse
import os
from pathlib import Path

from torch.utils.tensorboard import SummaryWriter

from byprot.models.dplm.dplm_mass_aware import (
    MassAwareDiffusionProteinLanguageModel,
    MassAwareDPLMConfig,
)
from byprot.datamodules.novobench_ms import build_novobench_dataloaders


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the training script from ``sys.argv``.

    Options cover the output location, backbone name, optimisation settings
    (batch size, epochs, learning rate, weight decay), data limits (peptide
    length, peaks per spectrum), DataLoader workers, device, and an optional
    checkpoint to initialise the backbone from.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    # Resolved once, up front; CUDA is preferred when it is available.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    # (flag, type, default, help) — registered in this order, which is also
    # the order shown by --help.
    option_specs = (
        ("--output_dir", str, "./outputs_ms_denovo", "保存 checkpoint 与日志的目录"),
        (
            "--model_name",
            str,
            "facebook/esm2_t12_35M_UR50D",
            "ESM-2 backbone 名称（12层ESM-2，默认使用35M参数版本）",
        ),
        ("--batch_size", int, 4, "batch size（建议先用小一点调试）"),
        ("--max_epochs", int, 1, "训练 epoch 数"),
        ("--lr", float, 1e-4, "学习率"),
        ("--weight_decay", float, 0.01, "权重衰减"),
        ("--max_length", int, 60, "最大肽段长度（token 数）"),
        ("--max_peaks", int, 200, "每条谱图最多保留的 peaks 数"),
        ("--num_workers", int, 2, "DataLoader num_workers（根据机器自行调整）"),
        ("--device", str, default_device, "训练设备：cuda / cpu"),
        (
            "--checkpoint",
            str,
            "",
            "可选：自己预训练的模型checkpoint路径（.pt或.ckpt文件）。如果不提供，则使用随机初始化的ESM-2。",
        ),
    )

    cli = argparse.ArgumentParser(
        description="Train Mass-Aware DPLM on NovoBench for MS De Novo"
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()


def _strip_wrapper_prefixes(state_dict):
    """Remove one leading wrapper prefix from every key of *state_dict*.

    Prefixes are tried in the order ``model.net.``, ``net.``, ``model.``, and
    at most one is removed per key, so ``model.net.x`` becomes ``x``. Keys
    without any of these prefixes are kept unchanged.
    """
    stripped = {}
    for key, value in state_dict.items():
        for prefix in ("model.net.", "net.", "model."):
            if key.startswith(prefix):
                key = key[len(prefix):]
                break
        stripped[key] = value
    return stripped


def _extract_net_state_dict(ckpt):
    """Pick the backbone state dict out of a loaded checkpoint object.

    Supported layouts, checked in this order:

    - ``{"state_dict": ...}``: wrapper prefixes are stripped.
    - ``{"model_state_dict": ...}``: the layout ``train()`` in this script
      saves (the state dict of the whole mass-aware model). Its keys are
      presumably prefixed with ``net.`` — confirm. Wrapper prefixes are
      stripped.
    - ``{"model": ...}``: used as-is.
    - any other dict: treated as a raw state dict.

    Raises:
        TypeError: if *ckpt* is not a dict.
    """
    if not isinstance(ckpt, dict):
        raise TypeError(f"不支持的checkpoint类型: {type(ckpt).__name__}")
    for key in ("state_dict", "model_state_dict"):
        if key in ckpt:
            return _strip_wrapper_prefixes(ckpt[key])
    if "model" in ckpt:
        return ckpt["model"]
    return ckpt


def _report_load_result(state_dict, result):
    """Print how many checkpoint tensors matched the model; warn if none did.

    *result* is the value returned by
    ``nn.Module.load_state_dict(..., strict=False)``, which exposes
    ``missing_keys`` and ``unexpected_keys``.
    """
    matched = len(state_dict) - len(result.unexpected_keys)
    print(
        f"  匹配参数: {matched}, 缺失: {len(result.missing_keys)}, "
        f"多余: {len(result.unexpected_keys)}"
    )
    if matched == 0:
        # strict=False would otherwise hide that nothing was loaded at all.
        print("[警告] checkpoint 中没有任何参数与模型匹配，模型权重仍为随机初始化")


def build_model(
    cfg: MassAwareDPLMConfig, model_name: str, checkpoint: str = ""
) -> MassAwareDiffusionProteinLanguageModel:
    """Build the mass-aware model on an ESM-2 backbone.

    - The backbone (``EsmForDPLM``) is constructed from ``model_name``'s config
      via ``AutoConfig.from_pretrained``; no DPLM checkpoint is used. The default
      ``esm2_t12_35M_UR50D`` has 12 layers.
    - If ``checkpoint`` names an existing file, its weights are loaded into the
      backbone non-strictly, and match statistics are printed.
    - Otherwise, including when the path does not exist (a warning is
      printed), the backbone keeps the weights it was constructed with
      (random initialisation).

    Args:
        cfg: Model config. ``cfg.net.dropout`` is used when ``cfg`` has ``net``,
            otherwise dropout 0.1.
        model_name: HuggingFace name/path of the ESM-2 config.
        checkpoint: Optional path to a ``.pt``/``.ckpt`` file.

    Returns:
        The ``MassAwareDiffusionProteinLanguageModel`` wrapping the backbone.
    """
    import torch

    from byprot.models.dplm.dplm import DiffusionProteinLanguageModel
    from byprot.models.dplm.modules.dplm_modeling_esm import EsmForDPLM
    from transformers import AutoConfig

    print(f"构建12层ESM-2模型: {model_name}")
    esm_config = AutoConfig.from_pretrained(model_name)
    esm_net = EsmForDPLM(esm_config, dropout=cfg.net.dropout if hasattr(cfg, 'net') else 0.1)

    if not checkpoint:
        print("使用随机初始化的ESM-2模型（12层）")
    elif not os.path.isfile(checkpoint):
        print(f"[警告] 未找到checkpoint文件: {checkpoint}，使用随机初始化")
    else:
        print(f"加载自己预训练的checkpoint: {checkpoint}")
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        # On torch >= 2.6 the default weights_only=True rejects checkpoints
        # written by train() in this script, because they also pickle the
        # config object. Such files need weights_only=False (trusted files
        # only) or torch.serialization.add_safe_globals.
        ckpt = torch.load(checkpoint, map_location='cpu')
        state_dict = _extract_net_state_dict(ckpt)
        result = esm_net.load_state_dict(state_dict, strict=False)
        print("Checkpoint加载完成")
        _report_load_result(state_dict, result)

    # DPLM wrapper around the backbone, then the mass-aware model built on the
    # wrapper's `net`.
    base = DiffusionProteinLanguageModel(cfg, net=esm_net)
    model = MassAwareDiffusionProteinLanguageModel(cfg, net=base.net)
    return model


def _batch_to_device(batch, device):
    """Return the four tensors ``compute_loss`` consumes, moved to *device*.

    The result has keys ``targets``, ``neutral_mass``, ``spectrum`` and
    ``spectrum_mask``. Any other keys in *batch* are not forwarded.
    """
    return {
        key: batch[key].to(device)
        for key in ("targets", "neutral_mass", "spectrum", "spectrum_mask")
    }


def _masked_diffusion_loss(logits, target, loss_mask, weight, loss_fn):
    """Compute the weighted, masked token cross-entropy averaged over masked positions.

    *loss_fn* must use ``reduction="none"`` so the per-token loss can be
    reshaped to ``loss_mask.shape``. A ``(B, 1)`` *weight* is expanded to the
    mask's shape. The ``1e-8`` avoids division by zero when the mask is empty.
    """
    token_loss = loss_fn(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
    token_loss = token_loss.reshape(loss_mask.shape)
    # Diffusion-step weights.
    if weight.dim() == 2 and weight.shape[1] == 1:
        weight = weight.expand_as(loss_mask)
    return (token_loss * loss_mask * weight).sum() / (loss_mask.sum() + 1e-8)


def _train_one_epoch(model, loader, optimizer, device, loss_fn, writer, global_step):
    """Run one training epoch.

    Logs ``train/loss`` every 10 global steps.

    Returns:
        ``(mean_batch_loss, updated_global_step)``. Batches are counted while
        iterating, so loaders without ``__len__`` also work.
    """
    model.train()
    running_loss, num_batches = 0.0, 0
    for batch in loader:
        # The spectrum/mass fields are forwarded for the mass-aware wrapper;
        # whether compute_loss uses them is not visible from here.
        logits, target, loss_mask, weight = model.compute_loss(
            batch=_batch_to_device(batch, device),
            weighting="constant",
        )
        loss = _masked_diffusion_loss(logits, target, loss_mask, weight, loss_fn)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        loss_value = loss.item()
        running_loss += loss_value
        num_batches += 1
        global_step += 1
        if global_step % 10 == 0:
            writer.add_scalar("train/loss", loss_value, global_step)
    return running_loss / max(1, num_batches), global_step


def _evaluate(model, loader, device, loss_fn):
    """Return the mean per-batch loss over *loader* in eval mode without gradients.

    NOTE(review): if ``compute_loss`` samples random timesteps/masks, this
    value is stochastic from epoch to epoch.
    """
    model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for batch in loader:
            logits, target, loss_mask, weight = model.compute_loss(
                batch=_batch_to_device(batch, device),
                weighting="constant",
            )
            total_loss += _masked_diffusion_loss(
                logits, target, loss_mask, weight, loss_fn
            ).item()
            num_batches += 1
    return total_loss / max(1, num_batches)


def train() -> None:
    """Train the mass-aware DPLM on NovoBench and checkpoint after every epoch.

    Reads options via ``parse_args()`` and logs ``train/loss`` (every 10 steps)
    and ``val/loss`` (per epoch) to TensorBoard under ``<output_dir>/tb``.
    After each epoch it writes ``checkpoint_epoch{N}.pt`` into ``output_dir``,
    containing the epoch, model and optimizer state dicts, and the config.
    The TensorBoard writer is closed even if training raises.
    """
    args = parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(args.output_dir, "tb"))
    try:
        device = torch.device(args.device)

        # -------------------------- 1. Data -------------------------- #
        train_loader, val_loader = build_novobench_dataloaders(
            tokenizer_name=args.model_name,
            max_length=args.max_length,
            max_peaks=args.max_peaks,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )

        # -------------------------- 2. Model ------------------------- #
        model_cfg = MassAwareDPLMConfig()
        # Override key hyper-parameters. Per the module docstring, the mass
        # constraint is applied at generation time, not in the training loss.
        model_cfg.num_diffusion_timesteps = 500
        model_cfg.mass_tolerance_ppm = 50.0
        model_cfg.min_aa_mass = 57.0
        model_cfg.max_aa_mass = 200.0
        model_cfg.enable_mass_constraint = True

        # Backbone architecture and name (also used for the tokenizer, etc.).
        model_cfg.net.arch_type = "esm"
        model_cfg.net.name = args.model_name
        model_cfg.net.dropout = 0.1

        model = build_model(model_cfg, model_name=args.model_name, checkpoint=args.checkpoint).to(device)

        optimizer = torch.optim.AdamW(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay
        )
        # Stepped once per epoch, so T_max is measured in epochs.
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.max_epochs
        )

        # Built once and reused for train and val. reduction="none" lets the
        # diffusion loss mask and weights be applied per token.
        # NOTE(review): for HF ESM-2 tokenizers id 0 is usually <cls> and
        # <pad> is 1 — confirm ignore_index=0 is intended (loss_mask already
        # restricts which positions count).
        loss_fn = torch.nn.CrossEntropyLoss(ignore_index=0, reduction="none")

        # ---------------------- 3. Training loop --------------------- #
        global_step = 0
        for epoch in range(args.max_epochs):
            avg_train_loss, global_step = _train_one_epoch(
                model, train_loader, optimizer, device, loss_fn, writer, global_step
            )
            lr_scheduler.step()
            print(f"[Epoch {epoch+1}] train_loss = {avg_train_loss:.4f}")

            avg_val_loss = _evaluate(model, val_loader, device, loss_fn)
            writer.add_scalar("val/loss", avg_val_loss, epoch + 1)
            print(f"[Epoch {epoch+1}] val_loss = {avg_val_loss:.4f}")

            ckpt_path = Path(args.output_dir) / f"checkpoint_epoch{epoch+1}.pt"
            torch.save(
                {
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "config": model_cfg,
                },
                ckpt_path,
            )
            print(f"Saved checkpoint to {ckpt_path}")
    finally:
        writer.close()


if __name__ == "__main__":
    # Script entry point: train() parses the CLI arguments and runs training.
    train()


