# 代码修改详细说明

本文档详细列出了所有修改的文件以及具体的代码变更。

## 📝 文件修改总览

| 文件路径 | 修改类型 | 说明 |
|---------|---------|------|
| `models/model.py` | **完全重写** | 实现了新的语义原型方法，移除了原有的VQ和Separator方法 |
| `run.py` | **部分修改** | 修改了训练流程，适配新的模型接口和损失函数 |
| `eval.py` | **部分修改** | 修改了测试流程，适配新的模型接口 |
| `args_parse.py` | **添加参数** | 添加了新方法所需的所有超参数 |
| `METHOD_README.md` | **新建文件** | 方法说明文档 |

---

## 1. `models/model.py` - 完全重写

### 1.1 移除的类
- ❌ `Separator` - 原有的分离器类
- ❌ `DiscreteEncoder` - 原有的离散编码器类（使用VQ）
- ❌ `MyModel` 中的 `mix_cs_proj` 方法

### 1.2 新增的类

#### ① `SemanticPrototypeModule` (第14-43行)
```python
class SemanticPrototypeModule(nn.Module):
    """可学习语义原型模块"""
    def __init__(self, emb_dim, num_prototypes, temperature=1.0):
        # 语义原型参数: P = {p_k}_{k=1}^K
        self.prototypes = nn.Parameter(torch.randn(num_prototypes, emb_dim))
    
    def forward(self, node_feat):
        # 计算分配权重: alpha_ik = softmax(f_i^T p_k / tau)
        # 语义分量: s_i = sum_k alpha_ik * p_k
```

**功能**：实现节点到语义原型的软分配，计算节点级语义分量。

---

#### ② `NodeDecompositionModule` (第46-89行)
```python
class NodeDecompositionModule(nn.Module):
    """节点级语义-残差分解模块"""
    def __init__(self, args, config):
        self.gnn = ...  # GNN编码器
        self.prototype_module = SemanticPrototypeModule(...)
    
    def forward(self, data):
        # F = GNN编码
        # S = 语义分量
        # R = F - S (残差分量)
```

**功能**：集成GNN编码器和语义原型模块，输出节点特征F、语义分量S和残差R。

---

#### ③ `GraphLevelSemanticModule` (第91-134行)
```python
class GraphLevelSemanticModule(nn.Module):
    """分子级语义汇聚模块"""
    def forward(self, semantic_feat, batch):
        # z_G = Pool(S) = 分子级语义表示
    
    def get_topk_neighbors(self, z_G, k=None):
        # 获取top-k语义近邻用于分子间对齐
```

**功能**：
- 将节点级语义分量聚合为分子级语义表示
- 计算分子间语义相似度
- 选择top-k近邻用于分子间语义对齐

---

#### ④ `ContrastiveProjectionHead` (第137-151行)
```python
class ContrastiveProjectionHead(nn.Module):
    """对比学习投影头"""
    def __init__(self, emb_dim, proj_dim=None):
        self.proj_head = nn.Sequential(...)
    
    def forward(self, z):
        return self.proj_head(z)
```

**功能**：将分子级语义表示投影到对比学习空间。

---

#### ⑤ `AdversarialPerturbationModule` (第154-230行)
```python
class AdversarialPerturbationModule(nn.Module):
    """对抗扰动生成模块（内层最大化）"""
    def generate_perturbation(self, residual_feat, semantic_feat, batch, ...):
        # 内层最大化：在残差子空间中搜索最坏扰动
        for step in range(self.inner_steps):
            # 1. 构造扰动: tilde_f = s_d + (r_d + delta)
            # 2. 计算对比损失
            # 3. 梯度上升更新delta
            # 4. 投影到约束集合: ||delta||_2 <= epsilon
```

**功能**：实现双层优化的内层最大化，在约束集合内搜索最坏扰动。

---

#### ⑥ 新的 `MyModel` 类 (第232-403行)

**主要变化**：

1. **初始化部分** (第233-283行)
   ```python
   def __init__(self, args, config):
       # 移除了: self.separator, self.encoder
       # 新增了:
       self.node_decomp = NodeDecompositionModule(...)
       self.graph_semantic = GraphLevelSemanticModule(...)
       self.contrastive_head = ContrastiveProjectionHead(...)
       self.adv_perturb = AdversarialPerturbationModule(...)
   ```

2. **forward方法** (第285-329行)
   ```python
   def forward(self, data, compute_adv=False):
       # 1. 节点分解: F, S, R = node_decomp(data)
       # 2. 分子级汇聚: z_G = graph_semantic(S, batch)
       # 3. 任务预测: logit = classifier(z_G)
       # 4. [训练时] 对抗扰动生成
       # 5. 计算各种损失
   ```
   
   **返回**：`logit, z_G, S, R, losses` (vs 原来的 `c_logit, c_f, s_f, cmt_loss, reg_loss`)

3. **新增损失计算方法**：
   - `compute_infonce_loss` (第331-349行) - InfoNCE对比损失
   - `compute_intra_loss` (第351-365行) - 分子内语义一致性损失
   - `compute_inter_loss` (第367-403行) - 分子间语义对齐损失

---

## 2. `run.py` - 训练流程修改

### 2.1 `train_step` 方法修改 (第150-193行)

**原代码**：
```python
def train_step(self, epoch):
    # 交替训练separator和encoder
    if epoch % 4 in range(1):
        set_requires_grad([self.model.separator], requires_grad=True)
        set_requires_grad([self.model.encoder], requires_grad=False)
    else:
        set_requires_grad([self.model.separator], requires_grad=False)
        set_requires_grad([self.model.encoder], requires_grad=True)
    
    # 模型前向传播
    c_logit, c_f, s_f, cmt_loss, reg_loss = self.model(data)
    
    # SimSiam损失
    mix_f = self.model.mix_cs_proj(c_f, s_f)
    inv_loss = self.simsiam_loss(c_f, mix_f)
    
    # 总损失
    loss = cls_loss + cmt_loss + self.args.inv_w * inv_loss + self.args.reg_w * reg_loss
```

**新代码**：
```python
def train_step(self, epoch):
    # 移除了交替训练逻辑（不再需要）
    
    # 模型前向传播（添加compute_adv=True）
    logit, z_G, S, R, losses = self.model(data, compute_adv=True)
    
    # 任务损失
    task_loss = self.metric.loss_func(logit, target.float(), ...)
    
    # 总损失 = 任务损失 + 对比损失 + 语义正则化损失
    total_loss = (task_loss + 
                 losses['contrastive'] + 
                 self.args.lambda_intra * losses['intra'] + 
                 self.args.lambda_inter * losses['inter'])
    
    # 移除了: simsiam_loss方法
    # 新增了: 更详细的tensorboard日志记录
```

**关键变化**：
- ❌ 移除了交替训练逻辑（`set_requires_grad`）
- ❌ 移除了 `simsiam_loss` 方法（第208-212行已删除）
- ✅ 修改模型调用：添加 `compute_adv=True` 参数
- ✅ 修改损失计算：使用新的损失字典结构
- ✅ 更新TensorBoard日志：记录5种损失（total, task, contrastive, intra, inter）

---

### 2.2 `test_step` 方法修改 (第196-206行)

**原代码**：
```python
@torch.no_grad()
def test_step(self, loader):
    logit, _, _, _, _ = self.model(data)
```

**新代码**：
```python
@torch.no_grad()
def test_step(self, loader):
    # 推理时不计算对抗扰动
    logit, _, _, _, _ = self.model(data, compute_adv=False)
```

**关键变化**：
- ✅ 添加 `compute_adv=False` 参数，推理时不进行对抗扰动计算

---

## 3. `eval.py` - 评估流程修改

### 3.1 `test_step` 方法修改 (第91-106行)

**原代码**：
```python
@torch.no_grad()
def test_step(self, loader):
    logit, _, _, _, _ = self.model(data)
```

**新代码**：
```python
@torch.no_grad()
def test_step(self, loader):
    # 推理时不计算对抗扰动
    logit, _, _, _, _ = self.model(data, compute_adv=False)
```

**关键变化**：
- ✅ 添加 `compute_adv=False` 参数

---

## 4. `args_parse.py` - 参数配置修改

### 4.1 移除的参数 (第26-40行)
```python
# ❌ 移除了VQ相关参数
# parser.add_argument("--num_e", default=4000, type=int)
# parser.add_argument("--commitment_weight", default=0.1, type=float)

# ❌ 移除了原方法的损失权重参数
# parser.add_argument("--inv_w", default=0.01, type=float)  # lambda_1
# parser.add_argument("--reg_w", default=0.5, type=float)   # lambda_2
# parser.add_argument("--gamma", default=0.8, type=float)   # threshold gamma
```

### 4.2 新增的参数 (第33-63行)

#### ① 语义原型模块参数 (第33-37行)
```python
parser.add_argument("--num_prototypes", default=64, type=int,
                    help="Number of semantic prototypes K")
parser.add_argument("--prototype_temperature", default=1.0, type=float,
                    help="Temperature parameter tau for prototype assignment")
```

#### ② 分子级语义模块参数 (第39-43行)
```python
parser.add_argument("--top_k", default=5, type=int,
                    help="Top-k neighbors for inter-molecular semantic alignment")
parser.add_argument("--inter_temperature", default=1.0, type=float,
                    help="Temperature parameter rho for inter-molecular weights")
```

#### ③ 对比学习参数 (第45-49行)
```python
parser.add_argument("--proj_dim", default=None, type=int,
                    help="Projection dimension for contrastive learning")
parser.add_argument("--contrastive_temperature", default=0.1, type=float,
                    help="Temperature parameter gamma for InfoNCE loss")
```

#### ④ 对抗训练参数 (第51-57行)
```python
parser.add_argument("--epsilon", default=0.1, type=float,
                    help="Perturbation bound epsilon")
parser.add_argument("--inner_steps", default=3, type=int,
                    help="Number of steps T for inner maximization")
parser.add_argument("--inner_lr", default=0.1, type=float,
                    help="Learning rate eta_delta for inner maximization")
```

#### ⑤ 损失权重参数 (第59-63行)
```python
parser.add_argument("--lambda_intra", default=0.1, type=float,
                    help="Weight lambda_intra for intra-molecular semantic consistency")
parser.add_argument("--lambda_inter", default=0.1, type=float,
                    help="Weight lambda_inter for inter-molecular semantic alignment")
```

---

## 5. `METHOD_README.md` - 新建文档

创建了完整的方法说明文档，包括：
- 方法概述
- 核心组件说明
- 主要改动
- 使用方法
- 方法流程图
- 注意事项
- 与原方法的区别

---

## 📊 代码统计

| 文件 | 新增行数 | 删除行数 | 修改行数 | 说明 |
|-----|---------|---------|---------|------|
| `models/model.py` | ~400 | ~150 | - | 完全重写，新增6个类 |
| `run.py` | ~50 | ~50 | ~30 | 修改训练和测试流程 |
| `eval.py` | ~2 | 0 | ~2 | 修改测试接口 |
| `args_parse.py` | ~30 | ~10 | 0 | 添加新参数 |
| `METHOD_README.md` | ~135 | 0 | 0 | 新建文档 |

---

## 🔍 关键接口变化

### 模型接口变化

**原接口**：
```python
c_logit, c_f, s_f, cmt_loss, reg_loss = model(data)
```

**新接口**：
```python
logit, z_G, S, R, losses = model(data, compute_adv=False)
# losses = {
#     'contrastive': tensor,
#     'intra': tensor,
#     'inter': tensor
# }
```

### 损失函数变化

**原损失**：
```
L = L_cls + L_cmt + lambda_1 * L_inv + lambda_2 * L_reg
```

**新损失**：
```
L = L_task + L_con + lambda_intra * L_intra + lambda_inter * L_inter
```

---

## ⚠️ 注意事项

1. **向后兼容性**：新代码**不兼容**原有模型检查点，需要重新训练
2. **参数名称**：所有参数名称已更改，请更新训练脚本
3. **推理模式**：必须使用 `compute_adv=False`，否则会影响性能
4. **Batch Size**：建议batch size >= 32，以确保分子间对齐效果

---

## 🚀 迁移指南

如果你想从旧代码迁移到新代码：

1. **更新训练命令**：
   ```bash
   # 旧命令
   python run.py --num_e 4000 --inv_w 0.01 --reg_w 0.5 --gamma 0.8 ...
   
   # 新命令
   python run.py --num_prototypes 64 --lambda_intra 0.1 --lambda_inter 0.1 --epsilon 0.1 ...
   ```

2. **更新模型加载**：旧的checkpoint无法直接使用，需要重新训练

3. **检查依赖**：确保所有新参数都有默认值，可以直接运行

