"""
NovoBench MS De Novo datamodule
--------------------------------

基于 HuggingFace 上的 [jingbo02/NovoBench](https://huggingface.co/datasets/jingbo02/NovoBench)
数据集构建 PyTorch Dataset，
为 DPLM / MassAwareDPLM 提供训练与验证数据。

每条样本包含：
- MS2 谱图：mz_array, intensity_array → tensor [max_peaks, 2]
- MS1 信息：precursor_mz, charge → 计算 precursor neutral mass
- 目标序列：modified_sequence / sequence → token ids
"""

from typing import Dict, List, Optional, Tuple

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

try:
    from datasets import load_dataset
except ImportError as e:  # pragma: no cover
    raise ImportError(
        "需要安装 `datasets` 库以加载 jingbo02/NovoBench 数据集：\n"
        "  pip install datasets"
    ) from e


# Proton mass (Da); used to turn the precursor m/z into a neutral mass.
PROTON_MASS = 1.007276466812

# Base URL for reading the parquet files directly, which sidesteps the schema
# cast problems of the NovoBench dataset builder.
NOVOBENCH_BASE_URL = "https://huggingface.co/datasets/jingbo02/NovoBench/resolve/main"


class NovoBenchDataset(Dataset):
    """NovoBench (nine_species subset) samples for Mass-Aware DPLM.

    Each item bundles a tokenized peptide sequence, a fixed-size MS2 spectrum
    tensor with a validity mask, and MS1 precursor information (precursor m/z,
    charge and the derived neutral mass).
    """

    def __init__(
        self,
        split: str,
        tokenizer_name: str,
        max_length: int = 100,
        max_peaks: int = 200,
        cache_dir: Optional[str] = None,
    ):
        """
        Args:
            split: "train" / "validation" / "test".
            tokenizer_name: Name or path passed to
                ``AutoTokenizer.from_pretrained`` (e.g. an ESM tokenizer).
            max_length: Tokenized length; sequences are padded/truncated to
                this many tokens, special tokens included.
            max_peaks: Maximum number of peaks kept per spectrum.
            cache_dir: HF datasets cache directory, forwarded to
                ``load_dataset``.

        Raises:
            ValueError: If ``split`` is not one of the supported names.
        """
        super().__init__()
        self.split = split
        self.max_length = max_length
        self.max_peaks = max_peaks
        # The nine_species subset only ships train/test parquet files, so the
        # "validation" split reuses the test file (mainly to monitor val loss).
        fname_map = {
            "train": "data/nine_species/train.parquet",
            "validation": "data/nine_species/test.parquet",
            "test": "data/nine_species/test.parquet",
        }
        if split not in fname_map:
            raise ValueError(f"Unsupported split: {split}")
        url = f"{NOVOBENCH_BASE_URL}/{fname_map[split]}"
        # Read with the generic parquet builder instead of the schema bundled
        # with jingbo02/NovoBench. cache_dir was previously accepted but never
        # used; forward it so callers can control where the data is cached.
        self.ds = load_dataset(
            "parquet", data_files={"data": url}, cache_dir=cache_dir
        )["data"]

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def __len__(self) -> int:
        return len(self.ds)

    @staticmethod
    def _first_not_none(row, keys, default):
        """Return the first non-None ``row[key]`` for ``key`` in ``keys``, else ``default``.

        Nullable parquet columns can yield ``None`` even when the key exists,
        so a plain ``row.get(key, default)`` is not enough before ``float()`` /
        ``int()`` conversion.
        """
        for key in keys:
            value = row.get(key)
            if value is not None:
                return value
        return default

    def _build_spectrum(
        self,
        mz_array: Optional[List[float]],
        intensity_array: Optional[List[float]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build a ``[max_peaks, 2]`` spectrum tensor and its peak mask.

        Column 0 holds m/z values and column 1 holds intensities. Only the
        first ``min(len(mz), len(intensity), max_peaks)`` peaks are kept, and
        the remaining rows are zero-filled.

        Returns:
            (spectrum, mask): ``spectrum`` is float32 ``[max_peaks, 2]``;
            ``mask`` is bool ``[max_peaks]`` with True marking real peaks.
        """
        # Treat missing arrays as empty spectra instead of crashing on len(None).
        if mz_array is None:
            mz_array = []
        if intensity_array is None:
            intensity_array = []

        n = min(len(mz_array), len(intensity_array), self.max_peaks)

        mz = torch.zeros(self.max_peaks, dtype=torch.float32)
        intensity = torch.zeros(self.max_peaks, dtype=torch.float32)

        if n > 0:
            mz[:n] = torch.tensor(mz_array[:n], dtype=torch.float32)
            intensity[:n] = torch.tensor(intensity_array[:n], dtype=torch.float32)

        spectrum = torch.stack([mz, intensity], dim=-1)  # [max_peaks, 2]
        mask = torch.zeros(self.max_peaks, dtype=torch.bool)
        mask[:n] = True
        return spectrum, mask

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one sample as a dict of tensors.

        Keys:
            input_ids / attention_mask: tokenizer output padded/truncated to
                ``max_length`` (shape ``[max_length]``).
            targets: the same tensor as ``input_ids``.
            spectrum: float32 ``[max_peaks, 2]`` (m/z, intensity) pairs.
            spectrum_mask: bool ``[max_peaks]``, True for real peaks.
            precursor_mz: float32 scalar.
            charge: int64 scalar, at least 1.
            neutral_mass: float32 scalar, ``(precursor_mz - PROTON_MASS) * charge``.
        """
        row = self.ds[idx]

        # 1) Sequence: prefer modified_sequence, fall back to sequence.
        # NOTE(review): confirm the tokenizer handles any modification
        # annotations that modified_sequence may contain.
        seq = row.get("modified_sequence") or row.get("sequence")
        if seq is None:
            seq = ""

        encoded = self.tokenizer(
            seq,
            add_special_tokens=True,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dim of 1; drop it.
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)

        # 2) MS1 -> neutral mass. "charge" takes precedence over
        # "precursor_charge"; None values fall through to the next option.
        precursor_mz = float(self._first_not_none(row, ("precursor_mz",), 0.0))
        charge = int(self._first_not_none(row, ("charge", "precursor_charge"), 1))
        # Guard against missing / non-positive charge values.
        charge = max(charge, 1)
        # [M + zH]^{z+}  =>  M = z * (m/z - m_proton)
        neutral_mass = (precursor_mz - PROTON_MASS) * charge

        # 3) MS2 spectrum.
        spectrum, spec_mask = self._build_spectrum(
            row.get("mz_array"), row.get("intensity_array")
        )

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            # Language-modeling target. Padding positions are not masked here;
            # presumably the training loss ignores pad tokens — TODO confirm.
            "targets": input_ids,
            "spectrum": spectrum,  # [max_peaks, 2]
            "spectrum_mask": spec_mask,  # [max_peaks]
            "precursor_mz": torch.tensor(precursor_mz, dtype=torch.float32),
            "charge": torch.tensor(charge, dtype=torch.int64),
            "neutral_mass": torch.tensor(neutral_mass, dtype=torch.float32),
        }


def collate_novobench(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Stack per-sample tensor dicts into one batched dict.

    Keys and their order come from the first sample. Every value is stacked
    along a new leading batch dimension.

    Raises:
        TypeError: If a field of the first sample is not a ``torch.Tensor``.
    """
    first_sample = batch[0]
    batched: Dict[str, torch.Tensor] = {}
    for name, probe in first_sample.items():
        if not isinstance(probe, torch.Tensor):
            raise TypeError(f"Unsupported field type for key={name}")
        column = [sample[name] for sample in batch]
        batched[name] = torch.stack(column, dim=0)
    return batched


def build_novobench_dataloaders(
    tokenizer_name: str,
    max_length: int = 100,
    max_peaks: int = 200,
    batch_size: int = 8,
    num_workers: int = 2,
    cache_dir: Optional[str] = None,
) -> Tuple[DataLoader, DataLoader]:
    """Create the NovoBench train and validation DataLoaders.

    Both datasets share the same tokenizer / length / peak / cache settings.
    The train loader shuffles and the validation loader does not. Both use
    ``collate_novobench`` and pinned memory.

    Returns:
        (train_loader, val_loader)
    """
    shared_dataset_args = dict(
        tokenizer_name=tokenizer_name,
        max_length=max_length,
        max_peaks=max_peaks,
        cache_dir=cache_dir,
    )
    train_set = NovoBenchDataset(split="train", **shared_dataset_args)
    valid_set = NovoBenchDataset(split="validation", **shared_dataset_args)

    def make_loader(dataset, shuffle):
        # Identical loader settings except for shuffling.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=collate_novobench,
        )

    return make_loader(train_set, True), make_loader(valid_set, False)


