# Evolution Strategies for Math Reasoning

This repository contains code for training large language models on math reasoning tasks with **Evolution Strategies (ES)**. Training uses **Ray** and **vLLM** for efficient multi-GPU rollout evaluation and supports multiple stochastic rollouts per prompt.

The primary entry point is `train.py`.

---

## Requirements

- Python **≥ 3.10**
- CUDA-enabled GPUs (multi-GPU strongly recommended)
- Linux environment
- CUDA, PyTorch, Ray, and vLLM properly installed

---

## Setup

### 1. Create a virtual environment

```bash
python -m venv es
source es/bin/activate
```

### 2. Install the repository

From the repository root:

```bash
pip install -e .
```

### 3. Install additional dependencies

```bash
pip install vllm==0.11.0
pip install wandb
```

### 4. Install math-grading dependencies

```bash
pip install math-verify
pip install pylatexenc
pip install latex2sympy2_extended
```

> **Note**  
> Ensure `nvidia-smi` works and Ray can see your GPUs.
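
A quick way to run the first check from Python is sketched below. It counts the `GPU <index>: ...` lines printed by `nvidia-smi -L`; the helper names are hypothetical and not part of this repository. On the Ray side, `ray.init()` followed by `ray.cluster_resources().get("GPU")` should report the same number of GPUs.

```python
import shutil
import subprocess


# Hypothetical helpers (not part of this repository).
def count_gpus_from_smi(output: str) -> int:
    """Count devices in `nvidia-smi -L` output (one 'GPU <i>: ...' line per GPU)."""
    return sum(1 for line in output.splitlines() if line.strip().startswith("GPU "))


def detect_gpu_count() -> int:
    """Return the number of GPUs nvidia-smi reports, or 0 if it is unavailable."""
    if shutil.which("nvidia-smi") is None:
        return 0
    result = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True)
    if result.returncode != 0:
        return 0
    return count_gpus_from_smi(result.stdout)


if __name__ == "__main__":
    print(f"nvidia-smi sees {detect_gpu_count()} GPU(s)")
```

On a single node, the printed count is a reasonable value for `--n-vllm-engines`.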

---

## Datasets

Training and evaluation datasets must be saved in the Hugging Face Datasets
`save_to_disk()` format.

Each example must contain:

```json
{
  "problem": "...",
  "answer": "..."
}
```

The training and evaluation data are included in this repository. The data is also publicly available; copies in the expected format can be downloaded from the following GitHub repository, among other sources: https://github.com/sail-sg/understand-r1-zero/tree/main/datasets.
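
Before calling `save_to_disk()`, it can help to confirm that every record has both required fields as strings, matching the JSON above. The sketch below is a hypothetical helper (not part of this repository) that checks plain Python dicts; valid records can then be written with the Hugging Face Datasets API, e.g. `Dataset.from_list(records).save_to_disk(path)`.

```python
REQUIRED_FIELDS = ("problem", "answer")


def find_schema_errors(records):
    """Hypothetical helper: return (index, message) pairs for malformed records."""
    errors = []
    for i, record in enumerate(records):
        for field in REQUIRED_FIELDS:
            if field not in record:
                errors.append((i, f"missing '{field}'"))
            elif not isinstance(record[field], str):
                errors.append((i, f"'{field}' is {type(record[field]).__name__}, expected str"))
    return errors


records = [
    {"problem": "What is 2 + 3?", "answer": "5"},
    {"problem": "Solve x^2 = 4 for x > 0.", "answer": 2},  # answer is not a string
    {"answer": "7"},  # problem is missing
]
for index, message in find_schema_errors(records):
    print(f"record {index}: {message}")
```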

### Example directory structure

```
datasets/
├── train/
│   └── math_lvl3to5_8k/
└── evaluation_suite/
    ├── math/
    ├── aime/
    ├── minerva/
    └── olympiad_bench/
```

- `--train-dataset` points to the training dataset directory
- `--eval-dataset` points to the evaluation suite directory
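
If each subdirectory of `--eval-dataset` holds one benchmark saved with `save_to_disk()`, as in the tree above, the benchmarks can be found with a plain directory listing. This is an illustrative sketch with a hypothetical helper name; how `train.py` actually enumerates the suite may differ.

```python
import tempfile
from pathlib import Path


def list_eval_benchmarks(suite_dir):
    """Hypothetical helper: map benchmark name -> directory for each subdirectory."""
    suite = Path(suite_dir)
    return {p.name: p for p in sorted(suite.iterdir()) if p.is_dir()}


if __name__ == "__main__":
    # Build a throwaway layout mirroring the example tree, then list it.
    with tempfile.TemporaryDirectory() as tmp:
        for name in ("math", "aime", "minerva", "olympiad_bench"):
            (Path(tmp) / name).mkdir()
        print(sorted(list_eval_benchmarks(tmp)))
```

Each returned path could then be loaded with `datasets.load_from_disk(str(path))`.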

---

## Usage

### Basic Training Run

To fine-tune **Qwen/Qwen2.5-Math-7B** using ES:

```bash
python train.py \
    --model "Qwen/Qwen2.5-Math-7B" \
    --sigma 0.001 \
    --population-size 30 \
    --mu 1 \
    --n-iterations 1000 \
    --eval-freq 16 \
    --train-dataset "datasets/train/math_lvl3to5_8k" \
    --eval-dataset "datasets/evaluation_suite" \
    --batch-size 99999 \
    --mini-batch-size 1024 \
    --max-tokens 3000 \
    --n-vllm-engines 8 \
    --logging "wandb"
```

---

## Command-Line Arguments

### Evolution Strategies

| Argument | Description |
|--------|-------------|
| `--sigma` | Noise scale for ES perturbations |
| `--alpha` | Learning rate (defaults to `sigma / 2`) |
| `--population-size` | Number of ES perturbations per generation |
| `--mu` | Number of ES generations per data batch |
| `--reward-shaping` | `z-scores`, `centered-ranks`, or `softmax-centered-ranks` |
| `--mirror-sampling` | Enable mirrored ES sampling |
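
The first two `--reward-shaping` options have standard definitions in the ES literature. The sketch below shows one common form of each, assuming `rewards` holds one scalar fitness per population member; the helper names are illustrative, the code is not taken from `train.py`, and `softmax-centered-ranks` is left out because this README does not define it.

```python
import statistics


def z_scores(rewards, eps=1e-8):
    """Standardize rewards to zero mean and unit (population) standard deviation."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


def centered_ranks(rewards):
    """Replace rewards by their ranks rescaled to [-0.5, 0.5]; ties broken by position."""
    n = len(rewards)
    if n == 1:
        return [0.0]
    order = sorted(range(n), key=lambda i: rewards[i])
    ranks = [0] * n
    for rank, i in enumerate(order):
        ranks[i] = rank
    return [r / (n - 1) - 0.5 for r in ranks]


rewards = [0.2, 0.9, 0.5, 0.5]
print(z_scores(rewards))
print(centered_ranks(rewards))
```

Ties matter here: with a binary math grader, many members can share a reward. Z-scores keep equal rewards equal, while this rank version breaks ties by position; some implementations average tied ranks instead.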

---

### Sampling & Rollouts

| Argument | Description |
|--------|-------------|
| `--n-samples` | Number of rollouts per prompt |
| `--rollout-reduce` | How the rewards of a prompt's rollouts are combined: `mean` or `max` |
| `--train-temperature` | Training sampling temperature |
| `--train-top-p` | Training top-p |
| `--eval-temperature` | Evaluation sampling temperature |
| `--eval-top-p` | Evaluation top-p |
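
With `--n-samples` rollouts per prompt, each prompt yields several rewards. A minimal sketch of how `--rollout-reduce` could collapse them, assuming a 0/1 correctness reward per rollout and a member fitness equal to the average over prompts (both assumptions, not confirmed by `train.py`; `member_fitness` is a hypothetical name):

```python
import statistics

# Per-prompt reducers selectable by name, mirroring the two documented options.
REDUCERS = {"mean": statistics.fmean, "max": max}


def member_fitness(rewards_per_prompt, reduce="mean"):
    """Reduce each prompt's rollout rewards, then average over prompts.

    rewards_per_prompt: one inner list of n_samples rewards per prompt.
    """
    reducer = REDUCERS[reduce]
    return statistics.fmean(reducer(samples) for samples in rewards_per_prompt)


# 3 prompts x 4 rollouts each; 1.0 = correct, 0.0 = incorrect
rewards = [
    [1.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, 0.0, 0.0],
    [1.0, 1.0, 1.0, 0.0],
]
print(member_fitness(rewards, "mean"))
print(member_fitness(rewards, "max"))
```

For 0/1 rewards, `max` gives credit if any sample is correct (a pass@k-style signal), while `mean` estimates the per-prompt success rate.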

---

### Data & Batching

| Argument | Description |
|--------|-------------|
| `--batch-size` | Number of examples per ES generation |
| `--mini-batch-size` | Mini-batch size for rollout pipelining |
| `--per-member-random-batch` | Draw a separate random batch for each population member |
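
The sketch below illustrates both batching ideas under stated assumptions: each generation's batch is drawn with a seeded RNG and capped at the dataset size, members either share one batch or (as with `--per-member-random-batch`) get their own, and each batch is cut into `--mini-batch-size` chunks that can be sent to the engines one after another. All function names and the seeding scheme are hypothetical.

```python
import random


def sample_batch(n_examples, batch_size, seed):
    """Draw a random subset of example indices without replacement."""
    rng = random.Random(seed)
    return rng.sample(range(n_examples), min(batch_size, n_examples))


def member_batches(n_examples, batch_size, population_size, generation, per_member):
    """Return one index list per population member for a given generation."""
    if per_member:
        # Hypothetical seeding scheme: a distinct seed per (generation, member) pair.
        return [
            sample_batch(n_examples, batch_size, seed=generation * population_size + m)
            for m in range(population_size)
        ]
    shared = sample_batch(n_examples, batch_size, seed=generation)
    return [shared for _ in range(population_size)]


def mini_batches(indices, mini_batch_size):
    """Split a batch into consecutive chunks of at most mini_batch_size."""
    return [indices[i:i + mini_batch_size] for i in range(0, len(indices), mini_batch_size)]
```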

---

### Infrastructure & Logging

| Argument | Description |
|--------|-------------|
| `--n-vllm-engines` | Number of persistent vLLM engines (usually = GPUs) |
| `--logging` | Logging backend (`wandb` or none) |
| `--seed` | Random seed |
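
One common way to spread ES work over persistent engines is to assign population members round-robin and to regenerate each member's perturbation from integer seeds, so only seeds (not full noise tensors) need to be communicated. The sketch below shows that idea on a small parameter vector; whether `train.py` uses this exact scheme is an assumption, and the function names are hypothetical. In such a scheme, `--seed` would naturally serve as the base seed.

```python
import random


def assign_members(population_size, n_engines):
    """Round-robin: engine e evaluates members e, e + n_engines, e + 2 * n_engines, ..."""
    return {e: list(range(e, population_size, n_engines)) for e in range(n_engines)}


def member_noise(base_seed, generation, member, dim):
    """Regenerate a member's Gaussian perturbation deterministically from seeds."""
    # String seeds are hashed deterministically by random.Random, so any
    # process can rebuild the same noise from (base_seed, generation, member).
    rng = random.Random(f"{base_seed}-{generation}-{member}")
    return [rng.gauss(0.0, 1.0) for _ in range(dim)]
```

Under this round-robin scheme, the example command's `--population-size 30` and `--n-vllm-engines 8` would give engines 0 to 5 four members each and engines 6 and 7 three each.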

---

## Training Behavior

1. Ray launches persistent vLLM engines (one per GPU).
2. For each ES generation:
   - Model weights are perturbed per population member.
   - Rollouts are generated on the training batch.
   - Rewards are computed using a math grader.
3. Rewards are normalized and used to update the base model.
4. Evaluation runs every `--eval-freq` iterations on all evaluation datasets.
5. Best-performing population members are checkpointed automatically.
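
The loop above follows the usual ES gradient estimate. As a toy illustration (not the repository's code), the sketch below maximizes a simple quadratic in place of the math reward, using mirrored sampling (each noise vector is evaluated as `+eps` and `-eps`), z-score shaping, and the documented default `alpha = sigma / 2`. The update `theta <- theta + alpha / (N * sigma) * sum_i F_i * eps_i` is the classic form from Salimans et al. (2017); some implementations fold the `1 / sigma` factor into `alpha`, so the exact scaling in `train.py` may differ.

```python
import random
import statistics

TARGET = (1.0, -2.0, 0.5)


def objective(theta):
    """Toy fitness to maximize (stands in for the math reward): -||theta - TARGET||^2."""
    return -sum((x - t) ** 2 for x, t in zip(theta, TARGET))


def es_step(theta, sigma, alpha, population_size, rng):
    """One ES generation with mirrored sampling; population_size must be even."""
    dim = len(theta)
    half = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(population_size // 2)]
    # Mirrored sampling: each noise vector is evaluated as +vec and -vec.
    noises = half + [[-v for v in vec] for vec in half]
    rewards = [objective([x + sigma * v for x, v in zip(theta, vec)]) for vec in noises]
    # z-score shaping (fall back to 1.0 if all rewards are identical).
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards) or 1.0
    shaped = [(r - mean) / std for r in rewards]
    # theta <- theta + alpha / (N * sigma) * sum_i F_i * eps_i
    scale = alpha / (population_size * sigma)
    return [
        x + scale * sum(f * vec[j] for f, vec in zip(shaped, noises))
        for j, x in enumerate(theta)
    ]


sigma = 0.1
alpha = sigma / 2  # matches the documented default alpha = sigma / 2
rng = random.Random(0)
theta = [0.0, 0.0, 0.0]
for generation in range(100):
    theta = es_step(theta, sigma, alpha, population_size=20, rng=rng)
print(theta, objective(theta))
```

Because the rewards are normalized, far from the optimum each update moves `theta` by roughly `alpha / sigma` (here 0.5) along the estimated ascent direction regardless of the raw reward scale, so this toy settles into a small neighborhood of the target rather than converging exactly.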

---

## Outputs

Each run produces a timestamped experiment directory:

```
experiments/
└── es-finetuned-math-.../
    ├── checkpoints/
    ├── eval-output/
    └── train-output/
```

### Checkpoints

- Best population-eval checkpoints are saved automatically
- Final model weights are saved at the end of training

### Logging

If `--logging wandb` is enabled, the following are logged:

- Training reward statistics
- Evaluation pass@1 metrics
- ES hyperparameters
- Sampling diagnostics

---

## Notes & Tips

- For best throughput:
  - Match `--n-vllm-engines` to GPU count
  - Tune `--mini-batch-size` based on GPU memory
- Designed for **single-node, multi-GPU** training

---

## Entry Point

```bash
python train.py
```
