"""
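Mass-aware DPLM for MS de novo peptide sequencing
-------------------------------------------------

Extends the standard DPLM with:
1. Mass-aware transitions: during reverse-diffusion sampling, amino-acid
   choices that cannot be reconciled with the precursor neutral mass are
   hard-pruned from the logits.
2. A training interface compatible with the original DPLM: training still uses
   the masked-diffusion objective. When a batch provides ``neutral_mass``, the
   same mass constraint is also applied to the training logits.

Notes:
- The mass constraint only uses MS1 information, namely the precursor neutral
  mass computed from precursor_mz / charge.
- MS2 conditioning is optional. When ``use_ms2_condition`` is enabled and
  spectra are supplied, a lightweight spectrum encoder (``SimpleSpectrumEncoder``)
  feeds the backbone through cross-attention.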
"""

from dataclasses import dataclass, field
from typing import Dict, Optional

import math
import torch
import torch.nn as nn
from omegaconf import OmegaConf

from byprot.models import register_model
from byprot.models.dplm.dplm import DPLMConfig, DiffusionProteinLanguageModel
from byprot.models.utils import (
    sample_from_categorical,
    stochastic_sample_from_categorical,
)


# Monoisotopic residue masses in Da, including the PTMs used by NovoBench.
AA_MASS: Dict[str, float] = dict(
    [
        # Canonical residues (cysteine carries a fixed carbamidomethyl group).
        ("G", 57.021464),
        ("A", 71.037114),
        ("S", 87.032028),
        ("P", 97.052764),
        ("V", 99.068414),
        ("T", 101.047678),
        ("C", 160.030649),  # 103.009185 + 57.021464 (carbamidomethyl)
        ("L", 113.084064),
        ("I", 113.084064),
        ("N", 114.042927),
        ("D", 115.026943),
        ("Q", 128.058578),
        ("K", 128.094963),
        ("E", 129.042593),
        ("M", 131.040485),
        ("H", 137.058912),
        ("F", 147.068414),
        ("R", 156.101111),
        ("Y", 163.063329),
        ("W", 186.079313),
        # NovoBench PTMs (matching the dataset card).
        ("M(ox)", 147.035400),  # M + 15.9949 (oxidation)
        ("N(deamidated)", 115.026943),  # N + 0.9840
        ("Q(deamidated)", 129.042594),  # Q + 0.9840
    ]
)

# Monoisotopic mass of water in Da (added once per peptide mass).
H2O_MASS = 18.010_56
# Proton mass in Da (used when converting a neutral mass to m/z).
PROTON_MASS = 1.007_276_466_812


@dataclass
class MassAwareDPLMConfig(DPLMConfig):
    """DPLMConfig extended with precursor-mass, MS2-conditioning and length-predictor options."""

    # --- MS1 precursor-mass constraint ---
    mass_tolerance_ppm: float = 50.0  # mass-matching tolerance, in ppm of the target mass
    min_aa_mass: float = 57.0  # per-residue lower bound (roughly glycine)
    max_aa_mass: float = 200.0  # per-residue upper bound (slack for PTMs)
    enable_mass_constraint: bool = True

    # --- MS2 conditioning (spectrum encoder) ---
    use_ms2_condition: bool = True
    ms2_n_layers: int = 2  # number of Transformer encoder layers
    ms2_n_heads: int = 8  # number of attention heads
    ms2_dropout: float = 0.1  # encoder dropout

    # --- Length-predictor MLP ---
    use_length_predictor: bool = True  # predict the peptide length with an MLP
    length_predictor_hidden_dim: int = 128  # MLP hidden width
    min_length: int = 5  # shortest allowed peptide length
    max_length: int = 60  # longest allowed peptide length


class SimpleSpectrumEncoder(nn.Module):
    """
    Lightweight MS2 spectrum encoder.

    - Input:  ``spectra`` of shape [B, N_peaks, 2] holding (m/z, intensity) pairs.
    - Output: features of shape [B, N_peaks, d_model], plus a boolean
      valid-peak mask of shape [B, N_peaks] (True = valid peak).

    Each peak is projected linearly to ``d_model`` and the peaks are run through
    a small ``nn.TransformerEncoder``. This lets the DPLM backbone cross-attend
    to MS2 information during training and inference without depending on an
    external spectrum model.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int = 8,
        dim_feedforward: int = 2048,
        n_layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        # Linear projection (m/z, intensity) -> d_model.
        self.input_proj = nn.Linear(2, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

    def forward(
        self,
        spectra: torch.Tensor,  # [B, N_peaks, 2]
        spectrum_mask: Optional[torch.Tensor] = None,  # [B, N_peaks], True = valid peak
    ):
        """
        Encode a batch of spectra.

        Args:
            spectra: [B, N_peaks, 2] peaks as (m/z, intensity).
            spectrum_mask: optional [B, N_peaks] valid-peak mask. Bool or 0/1
                numeric tensors are accepted; nonzero means "valid". When
                omitted, peaks with intensity > 0 are treated as valid.

        Returns:
            ``(feats, spectrum_mask)``: features of shape [B, N_peaks, d_model]
            and the boolean valid-peak mask.
        """
        x = self.input_proj(spectra)  # [B, N, d_model]

        if spectrum_mask is None:
            # By default, a peak is valid when its intensity is positive.
            spectrum_mask = spectra[..., 1] > 0  # [B, N]
        else:
            # Normalise to bool: `~` on an integer mask is a bitwise NOT
            # (1 -> -2), not a logical NOT, and would corrupt the padding mask.
            spectrum_mask = spectrum_mask.to(device=spectra.device, dtype=torch.bool)

        # TransformerEncoder's src_key_padding_mask uses True = ignore this key.
        key_padding_mask = ~spectrum_mask  # [B, N]
        # A row without any valid peak would have every key masked. That makes
        # the attention softmax produce NaN, which then propagates downstream.
        # Leave such rows unmasked so their features stay finite; the returned
        # mask still reports them as having no valid peak.
        empty_rows = ~spectrum_mask.any(dim=-1, keepdim=True)  # [B, 1]
        key_padding_mask = key_padding_mask & ~empty_rows

        feats = self.encoder(x, src_key_padding_mask=key_padding_mask)
        return feats, spectrum_mask


@register_model("dplm_mass_aware")
class MassAwareDiffusionProteinLanguageModel(DiffusionProteinLanguageModel):
    """
    Mass-aware DPLM.

    - Training: the same discrete masked-diffusion objective as the original
      DPLM. When the batch provides ``neutral_mass``, the precursor-mass
      constraint is also applied to the training logits. When it provides
      ``spectrum``, the backbone cross-attends to the encoded MS2 peaks.
    - Generation: at every decoder step, the logits are pruned by the
      precursor-mass feasibility check before sampling.
    """

    _default_cfg = MassAwareDPLMConfig()

    def __init__(self, cfg, net: Optional[nn.Module] = None):
        super().__init__(cfg, net)
        # Re-merge so that self.cfg also carries the mass-aware defaults.
        self._update_cfg(cfg)

        # token id -> residue mass lookups, built from the tokenizer vocabulary.
        self._init_aa_mass_mapping()

        # Optional MS2 encoder; the backbone cross-attends to its outputs.
        self.spectrum_encoder: Optional[SimpleSpectrumEncoder] = None
        if self.cfg.use_ms2_condition:
            net_config = getattr(self.net, "config", None)
            hidden_size = getattr(net_config, "hidden_size", None)
            if hidden_size is None:
                # Fallback when the backbone does not expose config.hidden_size.
                hidden_size = 512
            self.spectrum_encoder = SimpleSpectrumEncoder(
                d_model=hidden_size,
                n_head=self.cfg.ms2_n_heads,
                dim_feedforward=4 * hidden_size,
                n_layers=self.cfg.ms2_n_layers,
                dropout=self.cfg.ms2_dropout,
            )

        # Optional MLP: precursor neutral mass (a scalar) -> predicted length.
        self.length_predictor: Optional[nn.Module] = None
        if self.cfg.use_length_predictor:
            hidden_dim = self.cfg.length_predictor_hidden_dim
            self.length_predictor = nn.Sequential(
                nn.Linear(1, hidden_dim),  # input: precursor mass
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1),  # output: predicted length
            )

    # --------------------------------------------------------------------- #
    # Config / mass lookup helpers
    # --------------------------------------------------------------------- #
    def _update_cfg(self, cfg):
        """Merge ``cfg`` over the MassAwareDPLMConfig defaults (overrides the parent)."""
        self.cfg = OmegaConf.merge(self._default_cfg, cfg)

    def _init_aa_mass_mapping(self):
        """
        Build token-id -> residue-mass lookups from the tokenizer vocabulary.

        A token is mapped when its whitespace-stripped string is a key of
        ``AA_MASS``. This covers single-letter residues such as ``"A"`` and PTM
        notations such as ``"M(ox)"``. The method sets:

        - ``self.token_to_mass``: dict token_id -> mass, ``None`` for non-residues;
        - ``self.aa_mass_tensor``: float32 CPU tensor of masses (0 for non-residues);
        - ``self.valid_aa_mask``: bool CPU tensor, True for residue tokens.

        The tensors are kept as plain attributes rather than buffers, so
        ``model.half()`` / ``model.to(dtype)`` never rounds the masses.
        """
        vocab = self.tokenizer.get_vocab()
        # Size the tables so that every token id is addressable, even if the
        # ids are not contiguous in [0, len(vocab)).
        table_size = max(len(vocab), max(vocab.values(), default=-1) + 1)

        self.token_to_mass = {i: None for i in range(table_size)}
        for token_str, token_id in vocab.items():
            mass = AA_MASS.get(token_str.strip())
            if mass is not None:
                self.token_to_mass[token_id] = mass

        self.aa_mass_tensor = torch.zeros(table_size, dtype=torch.float32)
        self.valid_aa_mask = torch.zeros(table_size, dtype=torch.bool)
        for token_id, mass in self.token_to_mass.items():
            if mass is not None:
                self.aa_mass_tensor[token_id] = float(mass)
                self.valid_aa_mask[token_id] = True

    def _vocab_mass_table(self, vocab_size: int):
        """
        Return ``(masses, is_aa)`` aligned to a logits vocabulary of ``vocab_size``.

        Both are CPU tensors: ``masses`` is float64 and ``is_aa`` is bool. Ids
        beyond the tokenizer-derived table are treated as non-residues.
        """
        masses = torch.zeros(vocab_size, dtype=torch.float64)
        is_aa = torch.zeros(vocab_size, dtype=torch.bool)
        n = min(vocab_size, self.aa_mass_tensor.numel())
        masses[:n] = self.aa_mass_tensor[:n].to("cpu", torch.float64)
        is_aa[:n] = self.valid_aa_mask[:n].to("cpu")
        return masses, is_aa

    # ------------------------------------------------------------------ #
    # Backbone forward with optional MS2 conditioning (cross-attention)
    # ------------------------------------------------------------------ #
    def _run_net_with_ms2(
        self,
        input_ids: torch.Tensor,  # [B, L]
        spectra: Optional[torch.Tensor] = None,  # [B, N_peaks, 2]
        spectrum_mask: Optional[torch.Tensor] = None,  # [B, N_peaks]
    ):
        """
        Run ``self.net`` on ``input_ids``, optionally conditioned on MS2 peaks.

        When a spectrum encoder exists and ``spectra`` is given, the encoded
        peaks are passed as ``encoder_hidden_states``. The valid-peak mask
        (True = keep) is passed as ``encoder_attention_mask``. Otherwise both
        are ``None``. Returns the raw backbone output.
        """
        encoder_hidden_states = None
        encoder_attention_mask = None
        if self.spectrum_encoder is not None and spectra is not None:
            encoder_hidden_states, encoder_attention_mask = self.spectrum_encoder(
                spectra, spectrum_mask
            )
        return self.net(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )

    def _forward_with_ms2(
        self,
        input_ids: torch.Tensor,  # [B, L]
        spectra: Optional[torch.Tensor] = None,  # [B, N_peaks, 2]
        spectrum_mask: Optional[torch.Tensor] = None,  # [B, N_peaks]
    ) -> torch.Tensor:
        """
        Return the backbone logits, optionally conditioned on MS2 peaks.

        If ``spectra`` is provided (and the spectrum encoder is enabled), amino
        acid tokens cross-attend to the encoded peaks.
        """
        net_out = self._run_net_with_ms2(
            input_ids, spectra=spectra, spectrum_mask=spectrum_mask
        )
        return net_out["logits"]

    def predict_length_from_mass(
        self,
        precursor_neutral_mass: torch.Tensor,  # [B] or scalar
    ) -> torch.Tensor:
        """
        Predict the number of noise tokens from the precursor neutral mass.

        Without a length predictor, the heuristic is
        ``(mass - H2O) / 110 Da`` (an average residue mass). Otherwise the MLP
        is used; its input is cast to the MLP's dtype and device.

        Args:
            precursor_neutral_mass: [B] tensor, or a scalar tensor/number.

        Returns:
            Long tensor of lengths clipped to ``[min_length, max_length]``.
            Shape is [B] for [B] input. For scalar input, the heuristic path
            returns a 0-d tensor and the MLP path returns shape [1].
        """
        mass = torch.as_tensor(precursor_neutral_mass)

        if self.length_predictor is None:
            # Heuristic: assume an average residue mass of about 110 Da.
            avg_aa_mass = 110.0
            predicted_length = (mass - H2O_MASS) / avg_aa_mass
            predicted_length = torch.clamp(
                predicted_length,
                min=self.cfg.min_length,
                max=self.cfg.max_length,
            )
            return predicted_length.round().long()

        # Match the MLP's parameters so float64 / CPU inputs do not crash Linear.
        first_param = next(self.length_predictor.parameters())
        mass = mass.to(device=first_param.device, dtype=first_param.dtype)
        if mass.dim() == 0:
            mass_input = mass.reshape(1, 1)  # [1, 1]
        else:
            mass_input = mass.unsqueeze(-1)  # [B, 1]

        predicted_length = self.length_predictor(mass_input).squeeze(-1)  # [B] or [1]
        predicted_length = torch.clamp(
            predicted_length,
            min=self.cfg.min_length,
            max=self.cfg.max_length,
        )
        return predicted_length.round().long()

    @torch.no_grad()
    def calculate_peptide_mass(
        self,
        tokens: torch.Tensor,
        charge: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Approximate the neutral mass (or m/z) of each token sequence.

        The mass is the sum of residue masses plus one H2O. pad / bos / eos /
        mask / x tokens, tokens without a known residue mass, and ids outside
        the mass table all contribute 0.

        Args:
            tokens: [B, L] token ids.
            charge: optional [B] charges. When given, returns
                ``mass / charge + PROTON_MASS`` (charge clamped to >= 1).

        Returns:
            [B] float tensor of neutral masses (or m/z values).
        """
        device = tokens.device
        mass_table = self.aa_mass_tensor.to(device)
        table_size = mass_table.numel()

        in_table = (tokens >= 0) & (tokens < table_size)
        masses = mass_table[tokens.clamp(min=0, max=table_size - 1)]
        ignored = (
            ~in_table
            | (tokens == self.pad_id)
            | (tokens == self.bos_id)
            | (tokens == self.eos_id)
            | (tokens == self.mask_id)
            | (tokens == self.x_id)
        )
        masses = masses.masked_fill(ignored, 0.0)

        mass_sum = masses.sum(dim=1) + H2O_MASS

        if charge is not None:
            charge = charge.to(device).float().clamp(min=1.0)
            return mass_sum / charge + PROTON_MASS
        return mass_sum

    def _apply_mass_constraint_to_logits(
        self,
        logits: torch.Tensor,  # [B, L, V]
        output_tokens: torch.Tensor,  # [B, L]
        output_masks: torch.Tensor,  # [B, L] True = position still to be predicted
        precursor_neutral_mass: torch.Tensor,  # [B]
    ) -> torch.Tensor:
        """
        Mask out amino-acid tokens that are incompatible with the precursor mass.

        For each sample, let ``cur`` be the mass of its determined (not open)
        tokens plus one H2O, and ``k`` the number of open positions. Let
        ``tol = target * mass_tolerance_ppm / 1e6``. Placing a residue of mass
        ``m`` at an open position is feasible only if the mass left for the
        other ``k - 1`` open positions, ``target - cur - m``, lies within
        ``[(k-1) * min_aa_mass - tol, (k-1) * max_aa_mass + tol]`` and is not
        below ``-tol``. All open positions are filled in the same step, so the
        bound is the same for every open position of a sample.

        Only tokens with a known residue mass are constrained. Other tokens
        and determined positions are left untouched. If no residue is feasible
        for a sample (e.g. the precursor mass is inconsistent with the
        vocabulary), that sample is left unconstrained. This avoids an all
        ``-inf`` row, which would give NaN probabilities and infinite loss.

        Returns:
            A new logits tensor; ``logits`` itself is not modified.
        """
        if not self.cfg.enable_mass_constraint:
            return logits

        device = logits.device
        vocab_size = logits.size(-1)

        with torch.no_grad():
            open_mask = output_masks.to(device=device, dtype=torch.bool)  # [B, L]

            # Bounds are computed on CPU in float64 (B x V is small). This keeps
            # ppm-level precision on backends without float64 support.
            aa_mass, is_aa = self._vocab_mass_table(vocab_size)  # [V], [V]
            target = (
                torch.as_tensor(precursor_neutral_mass)
                .detach()
                .to("cpu", torch.float64)
                .reshape(-1)
            )  # [B]

            # Mass of the determined tokens (+H2O). Open positions are replaced
            # by pad so calculate_peptide_mass ignores them.
            determined = output_tokens.to(device).masked_fill(open_mask, self.pad_id)
            cur_mass = self.calculate_peptide_mass(determined).to("cpu", torch.float64)

            # Number of OTHER open positions that share the remaining mass.
            n_other_open = (open_mask.sum(dim=1) - 1).clamp(min=0).to("cpu", torch.float64)
            min_possible = (n_other_open * self.cfg.min_aa_mass)[:, None]  # [B, 1]
            max_possible = (n_other_open * self.cfg.max_aa_mass)[:, None]  # [B, 1]
            tol = (target * self.cfg.mass_tolerance_ppm / 1e6)[:, None]  # [B, 1]

            remaining = (target - cur_mass)[:, None] - aa_mass[None, :]  # [B, V]
            infeasible = (
                (remaining < -tol)  # overshoots the precursor mass
                | (remaining + tol < min_possible)  # too little left for the rest
                | (remaining - tol > max_possible)  # too much left for the rest
            ) & is_aa[None, :]

            # Leave a sample unconstrained when no residue would remain feasible.
            any_feasible = (is_aa[None, :] & ~infeasible).any(dim=-1, keepdim=True)
            infeasible = (infeasible & any_feasible).to(device)  # [B, V]

            banned = open_mask[:, :, None] & infeasible[:, None, :]  # [B, L, V]

        return logits.masked_fill(banned, -math.inf)

    # --------------------------------------------------------------------- #
    # Training: DPLM masked-diffusion loss with the MS1 mass constraint
    # --------------------------------------------------------------------- #
    def compute_loss(self, batch, weighting: str = "constant"):
        """
        Produce masked-diffusion training outputs with optional MS1/MS2 conditioning.

        Follows ``DiffusionProteinLanguageModel.compute_loss`` but keeps the
        intermediate ``x_t`` / ``loss_mask`` so the precursor-mass constraint
        can be applied. Batch keys:

        - ``"targets"``: [B, L] clean token ids.
        - ``"neutral_mass"`` (optional): [B] precursor neutral mass from MS1.
          When present (and ``enable_mass_constraint`` is set), the same
          feasibility mask used at generation time is applied to the logits of
          the positions in ``loss_mask``.
        - ``"spectrum"`` / ``"spectrum_mask"`` (optional): MS2 peaks and their
          valid-peak mask, consumed through cross-attention.

        With ``cfg.rdm_couple``, two coupled noise levels are sampled and the
        batch is duplicated. Targets, masses and spectra are repeated to match.

        Args:
            batch: dict with the keys above.
            weighting: ``"constant"`` or ``"linear"`` timestep weighting.

        Returns:
            ``(logits, target, loss_mask, weight)`` where ``weight`` is the
            per-sample timestep weight with a trailing singleton dimension.
        """
        target = batch["targets"]
        neutral_mass = batch.get("neutral_mass", None)  # [B] or None

        t1, t2 = torch.randint(
            1,
            self.cfg.num_diffusion_timesteps + 1,
            (2 * target.size(0),),
            device=target.device,
        ).chunk(2)

        # Fetched before noising so they can be repeated along with rdm_couple.
        spectra = batch.get("spectrum", None)
        spectrum_mask = batch.get("spectrum_mask", None)

        if self.cfg.rdm_couple:
            x_t, t, loss_mask = list(
                self.q_sample_coupled(
                    target,
                    t1,
                    t2,
                    maskable_mask=self.get_non_special_symbol_mask(target),
                ).values()
            )
            target = target.repeat(2, 1)
            if neutral_mass is not None:
                neutral_mass = neutral_mass.repeat(2)
            if spectra is not None:
                spectra = spectra.repeat(2, 1, 1)
            if spectrum_mask is not None:
                spectrum_mask = spectrum_mask.repeat(2, 1)
        else:
            x_t, t, loss_mask = list(
                self.q_sample(
                    target,
                    t1,
                    maskable_mask=self.get_non_special_symbol_mask(target),
                ).values()
            )

        # 1) Backbone logits, optionally cross-attending to MS2 peaks.
        logits = self._forward_with_ms2(
            x_t,
            spectra=spectra,
            spectrum_mask=spectrum_mask,
        )

        # 2) Mass constraint on the positions to be predicted (loss_mask True).
        if neutral_mass is not None and self.cfg.enable_mass_constraint:
            logits = self._apply_mass_constraint_to_logits(
                logits=logits,
                output_tokens=x_t,
                output_masks=loss_mask.bool(),
                precursor_neutral_mass=neutral_mass,
            )

        num_timesteps = self.cfg.num_diffusion_timesteps
        weight = {
            "linear": (num_timesteps - (t - 1)),
            "constant": num_timesteps * torch.ones_like(t),
        }[weighting][:, None].float() / num_timesteps

        return logits, target, loss_mask, weight

    # --------------------------------------------------------------------- #
    # Generation: forward_decoder with mass-aware transitions
    # --------------------------------------------------------------------- #
    def forward_decoder(
        self,
        prev_decoder_out,
        encoder_out=None,
        need_attn_weights: bool = False,
        partial_masks=None,
        sampling_strategy: str = "gumbel_argmax",
        disable_resample: bool = True,
        resample_ratio: float = 0.25,
        precursor_neutral_mass: Optional[torch.Tensor] = None,
        spectra: Optional[torch.Tensor] = None,
        spectrum_mask: Optional[torch.Tensor] = None,
    ):
        """
        Run one decoding step, applying the mass constraint before sampling.

        Mostly the same as the base class. Special-symbol logits are masked,
        the precursor-mass constraint is applied when ``precursor_neutral_mass``
        is given, and new tokens/scores are written into every non-special
        position.

        Args:
            prev_decoder_out: dict with output_tokens, output_scores, step,
                max_step, temperature and history.
            precursor_neutral_mass: [B] target neutral mass from MS1.
            spectra / spectrum_mask: optional MS2 conditioning.

        Returns:
            dict with the updated tokens/scores, step + 1, history, optional
            attentions and the backbone's ``last_hidden_state``.
        """
        output_tokens = prev_decoder_out["output_tokens"].clone()
        output_scores = prev_decoder_out["output_scores"].clone()
        step, max_step = prev_decoder_out["step"], prev_decoder_out["max_step"]
        temperature = prev_decoder_out["temperature"]
        history = prev_decoder_out["history"]

        output_masks = self.get_non_special_symbol_mask(
            output_tokens, partial_masks=partial_masks
        )

        net_out = self._run_net_with_ms2(
            output_tokens, spectra=spectra, spectrum_mask=spectrum_mask
        )
        logits = net_out["logits"]
        attentions = net_out.get("attentions", None) if need_attn_weights else None

        if logits.dtype != output_scores.dtype:
            logits = logits.type_as(output_scores)

        # Never sample special symbols.
        for special_id in (self.mask_id, self.x_id, self.pad_id, self.bos_id, self.eos_id):
            logits[..., special_id] = -math.inf

        # Mass constraint, only when precursor information is provided.
        # Inside `generate`, step runs 0..max_step-1, so the `step < max_step`
        # guard holds on every step there (including the last).
        if (
            precursor_neutral_mass is not None
            and self.cfg.enable_mass_constraint
            and step < max_step
        ):
            logits = self._apply_mass_constraint_to_logits(
                logits,
                output_tokens=output_tokens,
                output_masks=output_masks,
                precursor_neutral_mass=precursor_neutral_mass,
            )

        if sampling_strategy == "vanilla":
            _tokens, _scores = sample_from_categorical(
                logits, temperature=temperature
            )
        elif sampling_strategy == "argmax":
            _scores, _tokens = logits.max(-1)
        elif sampling_strategy == "gumbel_argmax":
            noise_scale = 1.0
            _tokens, _scores = stochastic_sample_from_categorical(
                logits, temperature=0.0, noise_scale=noise_scale
            )

            if not disable_resample:
                self.resample(
                    _tokens, _scores, ratio=resample_ratio, scale=1.0
                )
        else:
            raise NotImplementedError

        output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
        output_scores.masked_scatter_(output_masks, _scores[output_masks])

        history.append(output_tokens.clone())

        return dict(
            output_tokens=output_tokens,
            output_scores=output_scores,
            attentions=attentions,
            step=step + 1,
            max_step=max_step,
            history=history,
            hidden_states=net_out["last_hidden_state"],
        )

    def generate(
        self,
        input_tokens: torch.Tensor,
        tokenizer=None,
        max_iter: int = 20,
        temperature: float = 1.0,
        partial_masks: Optional[torch.Tensor] = None,
        sampling_strategy: str = "gumbel_argmax",
        disable_resample: bool = False,
        resample_ratio: float = 0.25,
        precursor_neutral_mass: Optional[torch.Tensor] = None,
        spectra: Optional[torch.Tensor] = None,
        spectrum_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Generate peptide sequences by reverse diffusion with the MS1 mass constraint.

        Args:
            input_tokens: [B, L] initial prompt (usually all mask tokens, or
                some known residues).
            max_iter: number of decoder steps; falls back to
                ``cfg.num_diffusion_timesteps`` when falsy.
            precursor_neutral_mass: [B] target neutral mass.
            spectra / spectrum_mask: optional MS2 conditioning.

        Returns:
            [B, L] generated token ids after ``max_iter`` steps.
        """
        max_iter = max_iter or self.cfg.num_diffusion_timesteps

        # 0) Encoding (no structure information here; kept as a placeholder).
        encoder_out = self.forward_encoder(input_tokens)

        # 1) Initialise the output tokens/scores.
        (
            initial_output_tokens,
            initial_output_scores,
        ) = self.initialize_output_tokens(
            input_tokens, encoder_out=encoder_out, partial_masks=partial_masks
        )
        prev_decoder_out = dict(
            output_tokens=initial_output_tokens,
            output_scores=initial_output_scores,
            output_masks=None,
            attentions=None,
            step=0,
            max_step=max_iter,
            history=[initial_output_tokens.clone()],
            temperature=temperature,
        )

        prev_decoder_out["output_masks"] = self.get_non_special_symbol_mask(
            prev_decoder_out["output_tokens"], partial_masks=partial_masks
        )

        for step in range(max_iter):
            with torch.no_grad():
                decoder_out = self.forward_decoder(
                    prev_decoder_out=prev_decoder_out,
                    encoder_out=encoder_out,
                    partial_masks=partial_masks,
                    sampling_strategy=sampling_strategy,
                    disable_resample=disable_resample,
                    resample_ratio=resample_ratio,
                    precursor_neutral_mass=precursor_neutral_mass,
                    spectra=spectra,
                    spectrum_mask=spectrum_mask,
                )

            prev_decoder_out.update(
                output_tokens=decoder_out["output_tokens"],
                output_scores=decoder_out["output_scores"],
                step=step + 1,
                history=decoder_out["history"],
            )

        return prev_decoder_out["output_tokens"]


