# Training Comparison Visualization

这个功能允许你比较三个不同训练集在6个evaluation datasets上的表现，并生成可视化图表。

## 功能概述

- **6个Evaluation Datasets**: helpful-reject, helpful-rewrite, math-reject, math-rewrite, safety-reject, safety-rewrite
- **可视化**: 生成6张图，每张图对应一个数据集，每张图上有3条线对应3个训练集
- **横坐标**: Epoch
- **纵坐标**: Accuracy (准确率)

## 使用方法

### 1. 训练模型并收集数据

使用修改后的 `pd_train_accelerate.py` 进行训练：

```bash
python pd_train_accelerate.py --config_file pd_train_accelerate.yaml
```

训练过程中会自动在每个epoch结束时：
- 运行cross-prediction测试
- 保存结果到 `prompt_decoder/training_results/training_results_{run_name}.json`

### 2. 生成比较图表

使用 `plot_training_comparison.py` 生成比较图表：

```bash
# 基本用法：比较3个训练结果
python plot_training_comparison.py \
    --json_files \
        prompt_decoder/training_results/training_results_init_raw_bs32_lr1e-4_linear.json \
        prompt_decoder/training_results/training_results_ft_norm_bs32_lr5e-5_h16384.json \
        prompt_decoder/training_results/training_results_init_raw_bs32_lr3e-4_h8192.json \
    --output_dir plots

# 生成单独的图表（每个数据集一张图）
python plot_training_comparison.py \
    --json_files file1.json file2.json file3.json \
    --output_dir plots \
    --individual
```

### 3. 测试功能

运行测试脚本验证功能：

```bash
python test_plot_functionality.py
```

## 文件结构

```
├── pd_train_accelerate.py          # 修改后的训练脚本
├── pd_train_accelerate.yaml        # 训练配置文件
├── plot_training_comparison.py     # 绘图脚本
├── test_plot_functionality.py      # 测试脚本
├── README_training_comparison.md   # 本说明文档
└── prompt_decoder/
    └── training_results/
        ├── training_results_init_raw_bs32_lr1e-4_linear.json
        ├── training_results_ft_norm_bs32_lr5e-5_h16384.json
        └── training_results_init_raw_bs32_lr3e-4_h8192.json
```

## JSON数据格式

每个训练结果JSON文件包含：

```json
{
  "run_name": "80K",
  "epochs": [
    {
      "epoch": 0,
      "timestamp": 1234567890,
      "rewritten_results": {
        "test_subset_used": true,
        "six_datasets_accuracy": {
          "helpful_reject": {"correct": 45, "total": 50},
          "helpful_rewrite": {"correct": 42, "total": 50},
          "math_reject": {"correct": 38, "total": 50},
          "math_rewrite": {"correct": 40, "total": 50},
          "safety_reject": {"correct": 48, "total": 50},
          "safety_rewrite": {"correct": 46, "total": 50}
        }
      }
    }
  ]
}
```

## 图表说明

### 6张子图对应6个数据集：
1. **Helpful - Reject**: 帮助性任务中的拒绝比较
2. **Helpful - Rewrite**: 帮助性任务中的重写比较  
3. **Math - Reject**: 数学任务中的拒绝比较
4. **Math - Rewrite**: 数学任务中的重写比较
5. **Safety - Reject**: 安全性任务中的拒绝比较
6. **Safety - Rewrite**: 安全性任务中的重写比较

### 线条说明：
- 每条线代表一个训练配置
- 标签直接使用training dataset的名字（run_name）
- 例如：`80K` 表示使用80K数据集训练的结果

## 配置参数说明

### 训练配置 (YAML)
- `fine_tune`: 是否使用SAE参数进行fine-tuning
- `normalize`: 是否对参数进行归一化
- `hidden_layer`: 隐藏层维度（0表示线性层）
- `learning_rate`: 学习率
- `batch_size`: 批次大小
- `loss_type`: 损失函数类型（"mse"或"cosine"）

### 绘图配置
- `--json_files`: 要比较的JSON结果文件列表
- `--output_dir`: 输出图片的目录
- `--individual`: 是否为每个数据集生成单独的图片

## 注意事项

1. **数据收集**: 确保训练过程中每个epoch都完成了cross-prediction测试
2. **文件路径**: JSON文件路径要正确，建议使用绝对路径
3. **依赖库**: 需要安装matplotlib, numpy等绘图库
4. **内存使用**: 大量epoch数据可能占用较多内存

## 故障排除

### 常见问题：

1. **JSON文件未找到**
   - 检查文件路径是否正确
   - 确认训练已完成并生成了结果文件

2. **图片生成失败**
   - 检查matplotlib是否正确安装
   - 确认输出目录有写入权限

3. **数据格式错误**
   - 检查JSON文件格式是否正确
   - 确认包含必要的字段

4. **内存不足**
   - 减少同时比较的训练结果数量
   - 使用`--individual`参数分别生成图片
