# Ver@K Training in verl

This repository contains the code and experiments for the paper "Learning to Correct: Reinforcement Learning for Multi-Attempt Chain-of-Thought".
It is a fork of verl that implements verifier-in-the-loop multi-turn training and step-aware GRPO advantage estimators.

## Fork provenance

- Upstream repo: volcengine/verl
- Fork point commit: f59c23a6d44c35bf2bf6a5461dd3846416740950
- This fork adds: Ver@K interaction logic, Ver@K-aware GRPO estimators, dataset
  preprocessing scripts, SGLang multi-turn configs, and experiment scripts/plots.

## What is Ver@K Retry?

Ver@K Retry trains a model to solve verifiable tasks when it is allowed up to K
attempts per prompt, with a verifier giving feedback after each attempt. The
interaction is multi-turn:

1. The model produces an attempt.
2. A verifier checks correctness.
3. If incorrect, the verifier returns feedback and the model retries.
4. The episode ends on first success or after K attempts.

This repo implements the interaction as a reusable component and provides
step-aware advantage estimators so that training can assign credit to the
attempts and steps that led to eventual success.
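
The four steps above can be sketched in a few lines of Python. This is an illustrative sketch only, not the code in `verl/interactions/ver_k_retry_interaction.py`: the names `run_ver_k_episode`, `Episode`, `toy_generate`, and `toy_verify` are hypothetical, and the generator and verifier are stand-ins for the model and the real verifier.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple


@dataclass
class Episode:
    """Per-episode record: one entry per attempt (hypothetical structure)."""
    attempts: List[str] = field(default_factory=list)
    successes: List[bool] = field(default_factory=list)


def run_ver_k_episode(
    generate: Callable[[str, List[str]], str],
    verify: Callable[[str], Tuple[bool, str]],
    prompt: str,
    k: int,
) -> Episode:
    """Run up to k attempts, stopping at the first verified success."""
    episode = Episode()
    feedback_history: List[str] = []
    for _ in range(k):
        attempt = generate(prompt, feedback_history)  # 1. model produces an attempt
        ok, feedback = verify(attempt)                # 2. verifier checks correctness
        episode.attempts.append(attempt)
        episode.successes.append(ok)
        if ok:                                        # 4. episode ends on first success
            break
        feedback_history.append(feedback)             # 3. feedback is returned for the retry
    return episode


# Toy stand-ins: the "model" answers wrongly first, then corrects itself.
def toy_generate(prompt: str, feedback_history: List[str]) -> str:
    return "5" if not feedback_history else "4"


def toy_verify(attempt: str) -> Tuple[bool, str]:
    return attempt == "4", "Incorrect, try again."


episode = run_ver_k_episode(toy_generate, toy_verify, "What is 2 + 2?", k=3)
print(episode.attempts, episode.successes)
```

With `k=3` the toy episode stops after the second attempt, because the loop exits as soon as the verifier accepts an answer.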

## Repository map (Ver@K-specific)

- `verl/interactions/ver_k_retry_interaction.py`
  Ver@K Retry interaction loop (verifier feedback, early stop, per-turn metadata).

- `examples/sglang_multiturn/config/interaction_config/`
  Ver@K interaction configs for math and maze tasks.

- `examples/reward_fns/`
  Verifier + reward functions used by the interaction:
  - `boxed_answer_reward.py` (MATH)
  - `maze_path_reward.py` (Maze)

- `examples/data_preprocess/*_ver_k_retry.py`
  Preprocess scripts that build the RL dataset with per-sample interaction_kwargs.

- `examples/sglang_multiturn/config/ver_k_retry_multiturn_grpo_w_interaction*.yaml`
  Minimal configs for Ver@K multi-turn GRPO (math + maze).

- `scripts/ver_k/`
  Slurm and local run scripts for Qwen3 4B Ver@K experiments.

- `scripts/analysis/plot_wandb_ver_k_estimators.py`
  W&B plotting script to compare Ver@K estimators.

## Data preparation

All scripts write parquet files into `./data/<dataset>_ver_k_retry_k<k>/`.

MATH:
```
python examples/data_preprocess/math_ver_k_retry.py \
  --k_max_attempts 2 --num_proc 8
```

Maze (generated):
```
python examples/data_preprocess/maze_5x5_ver_k_retry.py \
  --grid_n 5 --n_train 10000 --n_test 1000 --k_max_attempts 2
```

## Training

Slurm and local run scripts are in `scripts/ver_k/`.
They are configured through environment variables. The key ones are:

- `K` (max attempts)
- `N` (GRPO group size / rollout.n)
- `ADV_ESTIMATOR` (Ver@K-aware estimator)
- `TRAIN_FILE`, `VAL_FILE`
- `MODEL`

Example (math, Qwen3-4B):
```
K=2 N=16 ADV_ESTIMATOR=grpo_verk_step_reward_step_norm \
  sbatch scripts/ver_k/run_qwen3_4b_math_ver_k_retry_grpo_sglang.sbatch
```

Maze example:
```
K=2 GRID_N=5 N=16 ADV_ESTIMATOR=grpo_verk_step_reward_step_norm \
  sbatch scripts/ver_k/run_qwen3_4b_maze_ver_k_retry_grpo_sglang.sbatch
```

## Verifier + interaction configuration

The core interaction is defined in `verl/interactions/ver_k_retry_interaction.py`.
Important config fields (see `examples/sglang_multiturn/config/interaction_config/`):

- `verifier_type`: `gsm8k_strict`, `exact_match`, or `callable`
- `verifier_fn_path`: Python path to a callable verifier
- `answer_extraction`: `boxed`, `answer_tag`, `gsm8k_hashes`, or `last_line`
- `include_history`: if true, previous wrong attempts are appended to the feedback
- `track_turn_metadata`: logs per-turn success and reward in the batch
- `log_attempts`: writes sampled interactions to `outputs/ver_k_retry_attempts*.log`

The math config uses `boxed_answer_reward.verifier_fn`; the maze config uses
`maze_path_reward.verifier_fn`.
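
To illustrate what a `callable` verifier with `boxed` answer extraction might look like, here is a minimal sketch. It is not the implementation in `examples/reward_fns/boxed_answer_reward.py`: the function names (`extract_last_boxed`, `toy_boxed_verifier`) and the `(response, ground_truth) -> bool` signature are assumptions made for illustration, so check the actual reward module for the interface the interaction expects.

```python
from typing import Optional


def extract_last_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in text, or None.

    Braces are matched by depth so nested LaTeX such as \\frac{1}{2} survives.
    """
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    depth = 1
    chars = []
    for c in text[start + len(marker):]:
        if c == "{":
            depth += 1
        elif c == "}":
            depth -= 1
            if depth == 0:
                return "".join(chars).strip()
        chars.append(c)
    return None  # braces never closed


def toy_boxed_verifier(response: str, ground_truth: str) -> bool:
    """Hypothetical verifier: exact string match on the last boxed answer."""
    answer = extract_last_boxed(response)
    return answer is not None and answer == ground_truth.strip()


print(toy_boxed_verifier("So the answer is \\boxed{\\frac{1}{2}}.", "\\frac{1}{2}"))
print(toy_boxed_verifier("I think it is 4.", "4"))
```

Exact string matching is the simplest choice here; a real math verifier would usually normalize equivalent forms before comparing.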

## Ver@K-aware advantage estimators

All Ver@K estimators are implemented in `verl/trainer/ppo/core_algos.py` and
activated by `algorithm.adv_estimator`:

- `grpo_vectorized`
- `grpo_verk_step_reward_step_norm` (Step-level normalization across trajectories that reached each attempt.)
- `grpo_verk_step_reward_step_norm_reweight_future_only` (Step-level normalization with future-only weighting.)
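
As a rough illustration of "step-level normalization across trajectories that reached each attempt", the sketch below normalizes each attempt's reward using only the trajectories in the group that got that far. It is one plausible reading of the description, not the code in `verl/trainer/ppo/core_algos.py`: the function name, the per-attempt reward layout, the use of the population standard deviation, and the `eps` value are all assumptions.

```python
from statistics import mean, pstdev
from typing import List


def step_normalized_advantages(
    step_rewards: List[List[float]], eps: float = 1e-6
) -> List[List[float]]:
    """step_rewards[i][t] is the reward of trajectory i at attempt t.

    Trajectories may have different lengths because an episode ends on its
    first success. Attempt t is normalized with the mean and standard
    deviation of the trajectories that actually reached attempt t.
    """
    max_steps = max(len(r) for r in step_rewards)
    advantages = [[0.0] * len(r) for r in step_rewards]
    for t in range(max_steps):
        reached = [i for i, r in enumerate(step_rewards) if len(r) > t]
        values = [step_rewards[i][t] for i in reached]
        mu = mean(values)
        sigma = pstdev(values)  # population std; an assumption of this sketch
        for i in reached:
            advantages[i][t] = (step_rewards[i][t] - mu) / (sigma + eps)
    return advantages


# One GRPO group of 4 rollouts with K=2: two succeed on the first attempt,
# one succeeds on the retry, one fails both attempts.
group = [[1.0], [0.0, 1.0], [0.0, 0.0], [1.0]]
adv = step_normalized_advantages(group)
print([[round(a, 3) for a in row] for row in adv])
```

In this toy group the second attempt is normalized over only the two rollouts that retried, so a successful retry receives a positive advantage (about +1) even though its first attempt received a negative one (about -1).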

## Metrics and logging

During multi-turn rollouts we compute and log:

- Per-turn success rates (`turn_success/p_turnX_cond`)
- Assistant turns per episode (`assistant_turns/mean`)
- Optional step-level weights and reward statistics for Ver@K estimators
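
The first two metrics can be illustrated with a short sketch. It assumes that the `_cond` suffix means "success rate at turn X among episodes that reached turn X" and that turns are numbered from 1; both are readings of the metric names rather than confirmed behavior, and `turn_metrics` is a hypothetical helper.

```python
from typing import Dict, List


def turn_metrics(episodes: List[List[bool]]) -> Dict[str, float]:
    """episodes[i][t] is True if episode i succeeded on attempt t."""
    metrics: Dict[str, float] = {}
    max_turns = max(len(e) for e in episodes)
    for t in range(max_turns):
        # Condition on the episodes that actually reached turn t.
        reached = [e for e in episodes if len(e) > t]
        metrics[f"turn_success/p_turn{t + 1}_cond"] = sum(e[t] for e in reached) / len(reached)
    metrics["assistant_turns/mean"] = sum(len(e) for e in episodes) / len(episodes)
    return metrics


batch = [[True], [False, True], [False, False], [True]]
print(turn_metrics(batch))
```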

Attempt logs (for debugging) are saved to:

- `outputs/ver_k_retry_attempts.log` (math)
- `outputs/ver_k_retry_attempts_maze.log` (maze)

## Analysis: W&B plots

Use the helper script to compare estimators and produce plots:

```
python scripts/analysis/plot_wandb_ver_k_estimators.py \
  --entity <your_wandb_entity> \
  --project verl_ver_k_retry_math \
  --output outputs/plots/qwen3_4b_math_ver_k2_n20_estimators
```

## Acknowledgements

This work is built on top of verl. Please refer to the upstream project for
full documentation and additional training infrastructure.
