# TWNM 模型架构深度分析

## 概览

TWNM（“The World is Not Mono”）强调真实世界音频的多声源与多任务属性，具体实现上采用“双分支 Whisper & Spatial MoE”架构以完成统一的多任务音频理解。模型将冻结的 Whisper 语义编码器与自研的 SpatialEncoder 结合，通过 MoE 融合层将双分支特征映射到 Qwen2 音频大模型的隐藏空间，实现参数高效的跨任务生成。

- 主干实现：`src/twnm/models/twnm.py`
- 预训练/推理封装：`src/twnm/models/twnm_pretrained_model.py` 与 `src/twnm/models/twnm_sft2.py`
- 空间编码器：`src/twnm/models/spatial_encoder/`
- 训练与评估脚本：`train.py`, `train_grpo.py`, `evaluation_*.py`, `inference*.py`
- 强化学习支持：`grpo_trainer.py`, `rewards.py`

下文按模块拆解 TWNM 的主要构件与运行流程。

---

## 仓库结构一览

- `src/twnm/`：核心 Python 包，包含模型、数据、RL 组件与工具函数。
- `configs/`：训练、推理、Deepspeed 等 YAML 配置。
- `scripts/`：分门别类的命令行入口（训练 / 推理 / 评测 / 调试）。
- `datasets/`：随仓库附带的小样本数据与 MMAU 子集。
- `assets/`：大文件存放处（如空间编码器 ckpt、提取后的 Qwen2 权重）。
- `tools/benchmark/`：用于生成空间问答基准的 LLM workflow。
- `outputs/`：训练日志与结果，占地大的产物集中于此。

---

## 1. 主干模型 `src/twnm/models/twnm.py`

### 1.1 前向流程概览

```python
class TWNM(BaseModel):
    def forward(self, samples):
        audios = samples["audios"]              # [B, 2, T]
        text = samples["text"]                  # 文本标签/答案
        router_label = samples["router_label"]  # MoE 监督标签

        prompt = [t + " <AcousticTokens>" for t in samples.get("task", ["AAC"] * len(audios))]

        encoder_hidden_states, router_logits = self.forward_encoder(audios, prompt, router_label)
        encoder_atts = torch.ones(encoder_hidden_states.size()[:-1], dtype=torch.long, device=audios.device)

        input_embeds, input_mask, decoder_targets = self.prepare_inputs_labels_for_multimodal(
            encoder_hidden_states, encoder_atts, prompt, text
        )
        decoder_output = self.decoder(
            input_ids=None,
            inputs_embeds=input_embeds,
            attention_mask=input_mask,
            labels=decoder_targets,
            return_dict=True,
        )

        ce_loss = decoder_output.loss
        router_loss = nn.BCEWithLogitsLoss()(router_logits, router_label.float())
        return {"loss": ce_loss + router_loss, "ce_loss": ce_loss, "router_loss": router_loss, "logits": decoder_output.logits}
```

流程关键步骤：

1. **提示拼装**：将任务提示与 `<AcousticTokens>` 标记组合，统一驱动生成式解码。
2. **双分支编码**：`forward_encoder` 同时调用 SpatialEncoder 与 Whisper，总结合并为 768 维序列。
3. **MoE 融合**：针对 Whisper 专家、四个空间专家与一条组合支路进行路由加权。
4. **多模态输入构建**：将左右提示嵌入、音频特征以及标签拼接后输入 Qwen2 解码器。
5. **损失**：文本交叉熵与 MoE 路由 BCE 之和。

### 1.2 `forward_encoder` 细节

```python
spatial_embeds = self.spatial_encoder.forward_as_encoder(audios)    # [B, 129, T', 192]
whisper_embeds = self.whisper.model.encoder(mels)['last_hidden_state']  # [B, 1500, 768]

spatial_projected = self.spatial_proj(spatial_embeds.reshape(B, T', 129 * 192))  # -> [B, T', 768]
spatial_aligned = interpolate(spatial_projected.transpose(1, 2), size=1500, mode="linear").transpose(1, 2)

combined_embeds = torch.add(self.spatial_norm(spatial_aligned), self.whisper_norm(whisper_embeds))

audio_context = self.audio_context_extractor(audios.mean(dim=1))    # [B, 256]
prompt_embeds = self.get_prompt_embeds(prompt, audios)              # [B, prompt_len, hidden]
moe_output, router_logits = self.moe_layer(whisper_embeds, spatial_aligned, combined_embeds, audio_context, prompt_embeds, ...)
```

要点：

- SpatialEncoder 输出的频率 × 时间 × 通道特征被展平后投影到 768 维，并使用线性插值对齐至 Whisper 帧长（默认 1500）。
- `AudioContextExtractor` 通过一维卷积与自适应池化从单声道波形中提取全局上下文，供 MoE 路由器使用。
- `prompt_embeds` 由解码器词嵌入直接生成，Teacher Forcing 时可使用外部标签指导专家选择。

### 1.3 MoE 融合层

`MoELayer` 由一个固定的组合专家、一个 Whisper 专家与四个空间专家组成。Router 接收音频上下文与 prompt 嵌入的均值，输出五维权重后与专家输出做加权平均。训练阶段可以按设定概率 (`teacher_forcing_ratio`) 使用标注好的 `router_label`。

---

## 2. 编码器组件

### 2.1 SpatialEncoder 接口

- 开源仓库默认提供 `TorchScriptSpatialEncoder` 封装（`assets/checkpoints/spatial_encoder/spatial_encoder.ts`），适用于推理部署。
- 若需要替换成自研实现，可设置环境变量 `TWNM_SPATIAL_ENCODER_MODULE=<模块路径>`，或在内网环境下将 `private_impl/spatial_encoder_impl` 加入 `PYTHONPATH`。
- 无论采用哪种实现，都应保持接口一致：输入双通道波形，输出 `[B, F, T, H]` 的空间特征，并支持从 checkpoint 加载冻结权重。

### 2.2 Whisper Encoder

TWNM 默认加载 `openai/whisper-small`，只使用 encoder 部分输出 1500 帧、768 维的语义特征。模型及特征提取器冻结，保证语义分支稳定。

### 2.3 对齐与归一化

- `self.spatial_proj`: 将频率 × 通道维展开后线性投影至 768 维。
- `F.interpolate`: 在时间维上进行线性插值，匹配 Whisper 帧长度。
- `LayerNorm`: 对两个分支独立归一化后，再做逐元素求和得到 `combined_embeds`。

---

## 3. 解码与提示工程

- 解码器：自带 LoRA 适配的 `Qwen2ForCausalLM`（`assets/checkpoints/qwen2-audio-llm-extracted`），以 bfloat16 运行并开启梯度检查点。
- Prompt 组装：`prepare_inputs_labels_for_multimodal` 负责将提示左/右两部分、音频 token 以及可选标签拼接成解码器输入。
- Loss 设计：使用标准交叉熵（忽略 `<AcousticTokens>` 左右的占位 token）以及 Router BCE。
- 推理：`generate` 支持束搜索或核采样，并设定自定义 `eos_token_id=151643` 以兼容 Qwen2 Audio 的特殊 token。

---

## 4. 预训练与推理封装

### 4.1 `twnm_pretrained_model.py`

基于 `PreTrainedModel` 的封装便于与 HuggingFace 生态集成，提供：

- `TWNMConfig`: 管理 Whisper、SpatialEncoder checkpoint、LoRA 超参等。
- 量化支持：可选 `BitsAndBytesConfig`。
- 适配器管理：封装 `change_to_policy`、`disable_adapters` 等便于在 SFT / GRPO 模式间切换。
- 权重加载：自动清理 SpatialEncoder CKPT 的 `"model."` 前缀，并冻结所有编码器参数。

### 4.2 `twnm_sft2.py`

专为第二阶段 SFT/推理定制，加入 `is_inference` 标记以跳过训练特有逻辑，并兼容已有 LoRA checkpoint 的加载。

---

## 5. 基础抽象与数据管线

### 5.1 `BaseModel` (`src/twnm/models/base_model.py`)

- 统一维护任务提示模板（AAC / ASR / S2TT / MC）。
- 提供 Encoder / Decoder 的 LoRA 包装策略与参数统计工具。
- 保留 `build_audio_qformer` 等通用接口（当前 TWNM 主干未使用，但保持兼容以便未来扩展）。

### 5.2 数据与训练脚本

- `train.py`: 使用 `TWNMTrainer`（继承自 HF `Trainer`），重写 `compute_loss` 与 `log` 以输出 CE 与 Router Loss；保存时仅保留可训练参数。
- `JsonlDataset`: 将 jsonl 中的音频路径、指令和答案预处理为双通道、固定长度的张量，并附带 MoE 路由标签。
- `collate_fn`: 打包批次，确保 router label 形状与模型输入匹配。

---

## 6. 强化学习与奖励设计

- `grpo_trainer.py` / `grpo_trainer_patched.py`: 在 HF `Trainer` 基础上适配 GRPO 流程，支持自定义奖励函数、参考模型同步与多样式生成配置。
- `grpo_dataset.py`: 针对多选问答场景构建 `(audio, prompt, solution)` 样本。
- `rewards.py`: 
  - `format_reward`: 强制输出包含 `|<think>| ... |</think>|` 与 `|<answer>| ... |</answer>|` 的结构。
  - `result_reward`: 检查答案是否与标准选项一致。
  - `length_reward`: 约束思维链长度在目标区间附近。

---

## 7. 推理、评测与工具

- `inference.py`, `inference_grpo.py`, `inference_grpo_list.py`: 统一的推理脚本，可加载训练配置与 checkpoint，输出带格式标签的答案。
- `evaluation_*.py`: 面向空间 QA、MMAU 等任务的评测器，支持断点续评、LoRA 适配器切换以及解析模型输出中的答案。
- `tools.py`, `filter_mmau.py`, `verify_loss.py`: 提供数据检查、过滤与快速验证工具。

---

## 8. 代码库现状与精简

- 旧版 CED、LPS、对比学习等模块已移除，仓库当前仅保留实际使用的 Whisper + SpatialEncoder + MoE 架构。
- 所有 `USAM` / `usam` 命名统一区分为 `TWNM` / `twnm`，脚本与路径亦同步更新。
- 通过 LoRA 与冻结策略，训练时主要调优 MoE 模块与 decoder 适配器，便于在有限算力下扩展到新任务。

---

本分析文档聚焦于仓库中的现行实现，后续如有架构变更，需同步更新本文与代码注释以保持一致。***
