## Strong/Weak LLM Training Pipeline

This repository provides a pipeline that augments clinical data, fine-tunes a model on the augmented data with supervised fine-tuning (SFT), and optionally trains a stronger expert with Direct Preference Optimization (DPO).

## Part 1: Augmentation

- Place your raw JSON dataset in `./train_data/` (or provide an absolute path).
- Use the augmentation notebook `data.ipynb` to generate:
  - SFT-style augmented data → saved under `./data/augmented/`
  - Preference data for DPO → saved under `./data/augmented_dpo/`

Tips
- Use the same task name throughout. Later steps select it with `--task_type`, which accepts `readmission`, `mortality`, or `period`.
- You can also bring your own already-augmented JSON files and place them in the folders above.
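
Before training, it can help to sanity-check a preference file, especially one you bring yourself. The sketch below assumes each record is a JSON object with `prompt`, `chosen`, and `rejected` string fields, a layout commonly used with TRL's `DPOTrainer`. The schema produced by `data.ipynb` may differ, and `find_bad_preference_records` is a hypothetical helper, not part of this repository.

```python
import json
from pathlib import Path

# Assumed field names (TRL-style preference records); adjust to your schema.
REQUIRED_FIELDS = ("prompt", "chosen", "rejected")


def find_bad_preference_records(records):
    """Return indices of records with a missing, empty, or non-string required field."""
    bad = []
    for i, rec in enumerate(records):
        ok = isinstance(rec, dict) and all(
            isinstance(rec.get(f), str) and rec[f].strip() for f in REQUIRED_FIELDS
        )
        # Identical chosen and rejected answers carry no preference signal.
        if ok and rec["chosen"] == rec["rejected"]:
            ok = False
        if not ok:
            bad.append(i)
    return bad


def check_preference_file(path):
    """Load a JSON list of preference records and report malformed entries."""
    records = json.loads(Path(path).read_text(encoding="utf-8"))
    return find_bad_preference_records(records)
```

An empty list from `check_preference_file` means every record passed these assumed checks.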

## Part 2: Training with SFT on Augmented Data

Use `train_sft.py` to train a LoRA-adapted Qwen model on the augmented SFT dataset. The script supports an automatic train/test split, post-training evaluation, and configurable hyperparameters.

Example: basic SFT training with testing
```bash
cd strong_weak_src
python train_sft.py \
  --dataset_path ./data/augmented/your_augmented_sft.json \
  --split_dataset \
  --test_size 200 \
  --task_type readmission \
  --sft_epochs 2 \
  --sft_lr 5e-5 \
  --lora_r 32 \
  --test_model
```

Example: higher-capacity SFT (larger LoRA rank and batch size)
```bash
python train_sft.py \
  --dataset_path ./data/augmented/your_augmented_sft.json \
  --split_dataset \
  --task_type period \
  --sft_epochs 3 \
  --lora_r 64 \
  --batch_size 4 \
  --gradient_accumulation_steps 2 \
  --test_model
```
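
With gradient accumulation, the optimizer steps once per `batch_size × gradient_accumulation_steps` examples. The arithmetic for the command above is sketched below, assuming `--batch_size` is the per-device batch size. The helper name and the `num_devices` parameter are illustrative, not script options.

```python
def effective_batch_size(batch_size, gradient_accumulation_steps, num_devices=1):
    """Number of examples contributing to each optimizer step."""
    return batch_size * gradient_accumulation_steps * num_devices


# Values from the command above, on a single device.
print(effective_batch_size(batch_size=4, gradient_accumulation_steps=2))  # 8
```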

Key arguments
- `--dataset_path`: Augmented SFT JSON file
- `--split_dataset`: Hold out the first N examples as the test set and train on the rest
- `--test_size`: Number of test examples N (default 200)
- `--task_type`: `readmission` | `mortality` | `period`
- `--model_name`: Base model (default `Qwen/Qwen3-0.6B`)
- `--lora_r`, `--lora_dropout`: LoRA config
- `--sft_lr`, `--sft_epochs`, `--batch_size`, `--gradient_accumulation_steps`: SFT optimizer and batching config
- `--test_model`: Evaluate on saved test split using `qwen_inference.py` logic
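
The split described for `--split_dataset` can be sketched as follows. This is a minimal illustration assuming the dataset is a JSON list of records and that nothing is shuffled before the split. `split_first_n` is a hypothetical helper, and the actual logic in `train_sft.py` may differ.

```python
def split_first_n(records, test_size=200):
    """Send the first `test_size` records to the test set and the rest to training."""
    if test_size < 0:
        raise ValueError("test_size must be non-negative")
    if test_size >= len(records):
        raise ValueError("test_size must be smaller than the dataset size")
    return records[test_size:], records[:test_size]


train, test = split_first_n(list(range(1000)), test_size=200)
print(len(train), len(test))  # 800 200
```

Because such a split is positional, an ordered file (for example, one sorted by date) would yield a test set that is not a random sample.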

Outputs
```
./trained_models/
└── {dataset_name}_Qwen3-0.6B-SFT_r{rank}_lr{lr}_ep{epochs}_bs{batch}_gas{gas}/

./test_data/
├── {dataset_name}_test.json
└── {dataset_name}_test_results_r{rank}_lr{lr}_ep{epochs}_bs{batch}_gas{gas}.json
```
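
The directory pattern above can be rebuilt in a few lines, which is handy for locating a run programmatically. In this sketch the hyperparameters are passed as strings, because the exact formatting the script uses for values such as the learning rate (`5e-5` vs. `5e-05`) is not specified here. `sft_run_dir` and the choice to derive `dataset_name` from the file stem are assumptions.

```python
from pathlib import Path


def sft_run_dir(dataset_path, rank, lr, epochs, batch, gas, root="./trained_models"):
    """Build the expected SFT output directory from the documented naming pattern."""
    dataset_name = Path(dataset_path).stem  # assumption: name is the file stem
    run_name = (
        f"{dataset_name}_Qwen3-0.6B-SFT_r{rank}_lr{lr}"
        f"_ep{epochs}_bs{batch}_gas{gas}"
    )
    return Path(root) / run_name


print(sft_run_dir("./data/augmented/your_augmented_sft.json", 32, "5e-5", 2, 4, 2))
```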

## Part 3: Training a Strong Expert with DPO

After SFT, you can train a stronger expert via DPO using the preference dataset produced in Part 1.

Example: SFT followed by DPO in one run
```bash
python train_sft.py \
  --dataset_path ./data/augmented/your_augmented_sft.json \
  --dpo_dataset_path ./data/augmented_dpo/your_augmented_dpo.json \
  --split_dataset \
  --enable_dpo \
  --task_type mortality \
  --sft_epochs 1 \
  --dpo_epochs 1 \
  --sft_lr 4e-5 \
  --dpo_lr 4e-5 \
  --test_model
```

Key DPO arguments
- `--enable_dpo`: Run DPO after SFT
- `--dpo_dataset_path`: Path to preference JSON (required with `--enable_dpo`)
- `--dpo_lr`, `--dpo_epochs`: DPO optimizer config

Outputs
```
./trained_models/
└── {dataset_name}_Qwen3-0.6B-SFT-DPO_r{rank}_sftlr{lr}_dpolr{lr}_sftep{ep}_dpoep{ep}_bs{batch}_gas{gas}/
```

Metrics and evaluation
- Binary tasks (`readmission`, `mortality`): accuracy, F1, full classification report
- Regression (`period`): MAE, RMSE, and accuracy within tolerance
- Evaluation uses the same code path as `qwen_inference.py` for consistency.
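
For reference, the metrics listed above can be computed as in the sketch below. It uses only the standard library so it runs without extra dependencies, and the repository's own evaluation may rely on scikit-learn instead. The `tolerance` default and the function names are assumptions made for illustration.

```python
import math


def binary_metrics(y_true, y_pred, positive=1):
    """Accuracy and positive-class F1 for a binary task."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    denom = 2 * tp + fp + fn
    f1 = 2 * tp / denom if denom else 0.0
    return {"accuracy": accuracy, "f1": f1}


def regression_metrics(y_true, y_pred, tolerance=1.0):
    """MAE, RMSE, and the share of predictions within `tolerance` of the target."""
    errors = [abs(t - p) for t, p in zip(y_true, y_pred)]
    n = len(errors)
    return {
        "mae": sum(errors) / n,
        "rmse": math.sqrt(sum(e * e for e in errors) / n),
        "within_tolerance": sum(e <= tolerance for e in errors) / n,
    }
```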

## Requirements

Install dependencies from `requirements.txt` in `strong_weak_src`:
```bash
pip install -r requirements.txt
```

Dependencies include `datasets`, `trl`, `transformers`, `peft`, `wandb`, `torch`, `accelerate`, `scikit-learn` (imported as `sklearn`), and `tqdm`.