# GRPO-MA: Multi-Answer Generation in GRPO for Stable and Efficient Chain-of-Thought Training

## 📋 Table of Contents

- [Installation](#installation)
- [Quick Start](#quick-start)
- [Project Structure](#project-structure)
- [Adding a New Task](#adding-a-new-task)

## 🚀 Installation


```bash
cd grpo-ma
conda create -n grpo-ma python=3.10
conda activate grpo-ma
pip install -r requirements.txt
pip install -e .
```


## 🏃 Quick Start

### Training Trajectory Prediction

#### Downloading Dataset

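```bash
# Run from the repository root; Hugging Face repositories store large files with Git LFS
git lfs install
mkdir -p data
cd data
git clone https://huggingface.co/datasets/BAAI/ShareRobot
cd ..
```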

#### Downloading Base Weights

```bash
# Run from the repository root
mkdir -p pretrained_weights
cd pretrained_weights
git clone https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct
cd ..
```
#### Run the Training Script
Suppose you have a machine with N GPUs (for example, a 4×H100 machine). To generate K thoughts, each followed by M answers, set `THINK_NUM=K` and `ANSWER_NUM=M` before launching the script. Each GPU is then assigned K/N thoughts (so K should be a multiple of N) and generates KM/N answers.
```bash
export THINK_NUM=4   # K: number of thoughts
export ANSWER_NUM=4  # M: answers generated per thought
bash scripts/run_grpo_lora.sh
```

The experiments in the paper were run on a 4×H100 machine with K = M = 4, so each GPU was assigned a single thought together with its four answers.
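The per-GPU split described above is simple arithmetic. The helper below is a hypothetical illustration (it is not part of the repository's scripts) and assumes, as the K/N split implies, that K is a multiple of N.

```python
def per_gpu_workload(think_num: int, answer_num: int, num_gpus: int) -> tuple[int, int]:
    """Return (thoughts per GPU, answers per GPU) for K=think_num thoughts,
    M=answer_num answers per thought, and N=num_gpus GPUs."""
    if think_num % num_gpus != 0:
        # Assumption: thoughts are split evenly, so K/N must be an integer.
        raise ValueError(f"THINK_NUM={think_num} is not divisible by {num_gpus} GPUs")
    thoughts_per_gpu = think_num // num_gpus         # K / N
    answers_per_gpu = thoughts_per_gpu * answer_num  # K * M / N
    return thoughts_per_gpu, answers_per_gpu


# Paper setting: K = M = 4 on a 4xH100 machine.
print(per_gpu_workload(4, 4, 4))  # (1, 4)
# Doubling the number of thoughts on the same machine.
print(per_gpu_workload(8, 4, 4))  # (2, 8)
```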

For more information on hyperparameters, please refer to [scripts/README.md](scripts/README.md).


## 📁 Project Structure

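```text
grpo-ma/
├── config/              # Training hyperparameter configuration files
├── dataset/             # Dataset implementations
│   └── dataset_grpo.py  # GRPO dataset class
├── metadata/            # Dataset metadata files (JSON/JSONL)
├── model/               # Model components
│   ├── qwen_module.py   # Qwen model integration
│   ├── vlm_module.py    # Vision-language module interface
│   ├── reward_func.py   # Central reward function router
│   └── task_configs.py  # Task configuration loader
├── scripts/             # Launch scripts (see scripts/README.md)
│   └── run_grpo_lora.sh # GRPO-MA LoRA training launcher
├── task/                # Modular task definitions
│   ├── __init__.py      # Auto-loads every task module in the folder
│   ├── TEMPLATE.py      # Template for new tasks
│   ├── processing_utils.py      # Scaling helpers
│   └── trajectory_sharerobot.py # ShareRobot trajectory task
├── trainer/             # Training logic
│   └── grpo_ma_trainer.py  # Main GRPO-MA trainer
└── train.py             # Training entry point
```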

## ➕ Adding a New Task

Adding a new task touches four pieces: dataset samples, a dataset manifest, the task module itself, and a quick verification pass before training. The framework auto-discovers anything you add under `task/`, so once the pieces below are in place the new task is ready for GRPO training.

### 1. Prepare dataset samples

- 📂 Place your raw assets (images, videos, point clouds, etc.) under `data/` or another folder that can be referenced at training time.
- 🗃️ Create a metadata file inside `metadata/` (JSON or JSONL). Each entry should contain at least:
    - `question`: the prompt shown to the model.
    - `answer`: the ground-truth answer (string, number, JSON, etc.).
    - `question_type`: must match the `task_type` you will define in the task module (lowercase with hyphens).
    - Optional media fields such as `image` or `video`, plus any task-specific attributes. All fields of an entry are packed into a single item by `dataset_group` and passed to `reward_func` for reward calculation.

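Example entries for a hypothetical `my-task`, in JSONL format with one sample per line (the ShareRobot trajectory task's samples live in `metadata/grpo_sharerobot_trajectory_train.json`):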

```json
{"question": "Which tool is highlighted?", "answer": "hammer", "image": "train/toolbox_001.jpg", "question_type": "my-task"}
{"question": "Locate the object.", "answer": "[120, 48, 256, 220]", "image": "train/toolbox_002.jpg", "question_type": "my-task"}
```
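Before training, it can be worth checking that every entry carries the required keys and a well-formed `question_type`. The sketch below is a hypothetical standalone check, not part of the repository; it accepts either a JSON array or JSONL text and encodes the "lowercase with hyphens" rule as a regular expression.

```python
import json
import re

REQUIRED_KEYS = ("question", "answer", "question_type")
# Lowercase alphanumeric words joined by single hyphens, e.g. "my-task".
TASK_TYPE_RE = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")


def parse_metadata(text: str) -> list[dict]:
    """Parse metadata given either as a JSON array or as JSONL (one object per line)."""
    text = text.strip()
    if text.startswith("["):
        return json.loads(text)
    return [json.loads(line) for line in text.splitlines() if line.strip()]


def validate_entry(entry: dict, task_type: str) -> list[str]:
    """Return a list of problems with one metadata entry (empty if it looks valid)."""
    problems = [f"missing key: {key}" for key in REQUIRED_KEYS if key not in entry]
    qtype = entry.get("question_type")
    if qtype is not None:
        if not isinstance(qtype, str) or not TASK_TYPE_RE.match(qtype):
            problems.append(f"question_type {qtype!r} is not lowercase-with-hyphens")
        if qtype != task_type:
            problems.append(f"question_type {qtype!r} does not match task_type {task_type!r}")
    return problems
```

For example, read a metadata file, pass its contents to `parse_metadata`, and call `validate_entry(entry, "my-task")` on each result; an empty list means the entry has the required keys and a matching `question_type`.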

### 2. Describe the dataset with a YAML manifest

Training scripts read a YAML manifest that lists one or more metadata files. Create a new file such as `scripts/train/grpo_my_task.yaml` so that `train.py` knows how to sample your data. For example, a manifest for the ShareRobot trajectory data looks like this:

```yaml
datasets:
  - json_path: metadata/grpo_sharerobot_trajectory_train.json
    sampling_strategy: "all"
    data_root: data/ShareRobot/trajectory/images
    data_modality: image

```

Pass this YAML via `--dataset_name`, or set the `DATASET_NAME` environment variable before running `scripts/run_grpo_lora.sh`.
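A quick sanity check of the manifest can catch typos in paths before a run starts. The sketch below works on the dict that PyYAML's `yaml.safe_load` returns for the manifest; `check_manifest` is a hypothetical helper, and treating `json_path` and `data_root` as required is an assumption based on the example above, not a statement of what `train.py` enforces.

```python
from pathlib import Path

# Keys taken from the example manifest; which ones train.py actually
# requires is an assumption here.
REQUIRED_DATASET_KEYS = ("json_path", "data_root")


def check_manifest(manifest: dict, repo_root: str = ".") -> list[str]:
    """Return problems found in a parsed dataset manifest (empty list if none)."""
    datasets = manifest.get("datasets")
    if not isinstance(datasets, list) or not datasets:
        return ["manifest must contain a non-empty 'datasets' list"]
    problems = []
    root = Path(repo_root)
    for i, ds in enumerate(datasets):
        for key in REQUIRED_DATASET_KEYS:
            if key not in ds:
                problems.append(f"datasets[{i}]: missing '{key}'")
            elif not (root / ds[key]).exists():
                problems.append(f"datasets[{i}]: {key} '{ds[key]}' does not exist")
    return problems
```

Run it from the repository root on the loaded YAML, e.g. `check_manifest(yaml.safe_load(open("scripts/train/grpo_my_task.yaml")))`; an empty list means every listed path exists.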

### 3. Implement the task module

1. Copy the template: `cp task/TEMPLATE.py task/my_task.py`.
2. Fill out `TASK_CONFIG`:
    - `task_type`: the identifier used everywhere (`question_type`, registries, logs).
    - `description`, `input_format`, `output_format`, and an optional `grpo_template` describing the prompt structure.
    - `evaluation_metrics` and `format_requirements` (used for logging and automatic checks).
    - The most important fields are `task_type` and `grpo_template`; the other fields do not affect training.
3. Implement helper utilities if your task needs pre/post-processing (see `task/processing_utils.py` for scaling helpers).
    - Qwen2.5-VL models apply `smart_resize` to input images: height and width are rounded to multiples of 28, and the image is rescaled so that its total pixel count stays between `min_pixels` and `max_pixels`. In tasks whose answers are pixel coordinates (e.g., object detection), those coordinates must be rescaled to match.
4. Write the reward functions:
    - `format_reward`: checks that the output follows the task's expected format. For Instruct models, this check can be made stricter.
    - `accuracy_reward`: scores the predicted answer against the ground truth; the metric depends on the task.
5. Register the functions. The loader calls `register(...)` automatically; keep the names in sync:


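```python
def register(accuracy_registry, format_registry, answer_registry=None):
    # Called automatically by the task loader. The registry key must equal
    # TASK_CONFIG["task_type"], which in turn matches the samples' question_type.
    task_name = TASK_CONFIG["task_type"]
    accuracy_registry[task_name] = accuracy_reward
    format_registry[task_name] = format_reward
    if answer_registry is not None:
        answer_registry[task_name] = process_answer  # optional: rescale answers after resizing
```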
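For the bounding-box example from step 1, the two reward functions could look roughly like the sketch below. Several parts are assumptions rather than the framework's contract: the `(completion, item)` signature, the `<think>...</think><answer>...</answer>` output format, and IoU as the accuracy metric. `task/TEMPLATE.py` defines the signatures the framework actually calls, and your `grpo_template` determines the output format.

```python
import json
import re

TASK_CONFIG = {
    "task_type": "my-task",  # must equal question_type in the metadata
    "description": "Locate the highlighted object and return its bounding box.",
    "output_format": "[x1, y1, x2, y2] in resized-image pixel coordinates",
}

# Assumed output format: reasoning in <think> tags, final answer in <answer> tags.
ANSWER_RE = re.compile(r"^\s*<think>.*?</think>\s*<answer>(.*?)</answer>\s*$", re.DOTALL)


def extract_box(completion: str):
    """Return the predicted box as four floats, or None if it cannot be parsed."""
    match = ANSWER_RE.match(completion)
    if match is None:
        return None
    try:
        box = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None
    if not (isinstance(box, list) and len(box) == 4
            and all(isinstance(v, (int, float)) for v in box)):
        return None
    return [float(v) for v in box]


def iou(a, b) -> float:
    """Intersection-over-union of two [x1, y1, x2, y2] boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = max(0.0, a[2] - a[0]) * max(0.0, a[3] - a[1])
    area_b = max(0.0, b[2] - b[0]) * max(0.0, b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def format_reward(completion: str, item: dict) -> float:
    """1.0 if the completion follows the assumed tag format with a 4-number box."""
    return 1.0 if extract_box(completion) is not None else 0.0


def accuracy_reward(completion: str, item: dict) -> float:
    """IoU between the predicted box and the ground-truth box in item["answer"]."""
    pred = extract_box(completion)
    if pred is None:
        return 0.0
    gt = item["answer"]
    if isinstance(gt, str):  # answers may be stored as JSON strings in the metadata
        gt = json.loads(gt)
    return iou(pred, [float(v) for v in gt])
```

Returning a continuous IoU rather than a 0/1 hit gives the policy partial credit for nearly correct boxes; a thresholded variant is equally valid if your evaluation metric is binary.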

If you need to rescale answers (e.g., bounding boxes), implement `process_answer` and store it in `answer_registry` so the dataset loader can adapt answers after image resizing.

For details, refer to the comments in `task/TEMPLATE.py`.
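To illustrate why coordinates need rescaling, the sketch below re-implements the `smart_resize` rule for illustration (it mirrors the logic of the helper in Qwen's preprocessing code but omits its extreme-aspect-ratio check) and maps a box from the original image to the resized one. The default `min_pixels`/`max_pixels` values are placeholders, so pass whatever your processor is configured with. `rescale_box` is a hypothetical stand-in for what a `process_answer` might do; the real hook's signature is defined in `task/TEMPLATE.py`, and `task/processing_utils.py` may already provide equivalent helpers.

```python
import math


def smart_resize(height: int, width: int, factor: int = 28,
                 min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280):
    """Return (new_height, new_width): both multiples of `factor`, with the total
    pixel count kept within [min_pixels, max_pixels] and the aspect ratio roughly kept."""
    h_bar = max(factor, round(height / factor) * factor)
    w_bar = max(factor, round(width / factor) * factor)
    if h_bar * w_bar > max_pixels:
        # Too large: shrink both sides by the same ratio, rounding down.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        # Too small: enlarge both sides by the same ratio, rounding up.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


def rescale_box(box, orig_hw, new_hw):
    """Map an [x1, y1, x2, y2] box from the original image size to the resized one."""
    (orig_h, orig_w), (new_h, new_w) = orig_hw, new_hw
    sx, sy = new_w / orig_w, new_h / orig_h
    x1, y1, x2, y2 = box
    return [x1 * sx, y1 * sy, x2 * sx, y2 * sy]


new_hw = smart_resize(480, 640)
print(new_hw)  # (476, 644)
print([round(v) for v in rescale_box([120, 48, 256, 220], (480, 640), new_hw)])  # [121, 48, 258, 218]
```

Even a modest image like 640×480 shifts a box by a pixel or two here; for images above `max_pixels` the scale factor is much larger, which is where skipping the rescale hurts detection-style rewards most.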

### 4. Verify registration and launch training

- Make sure your new file is importable (`task/__init__.py` auto-loads everything in the folder). To double-check:

```bash
python -c "from task import print_summary; print_summary()"
```

- Point the trainer to your dataset manifest:

```bash
export DATASET_NAME=scripts/train/grpo_my_task.yaml
export RUN_NAME=Qwen2.5-VL-3B-GRPO-lora-my-task
bash scripts/run_grpo_lora.sh
```

During the first batches you should see logs for your `task_type` and the rewards you implemented. If something fails to load, rerun `python -c "from task import reload_tasks; reload_tasks(); print('done')"` to surface import errors.
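You can also exercise a task module's `register` hook directly, outside the loader. The helper below is a hypothetical sketch: it calls `register` with fresh dictionaries, following the signature shown in step 3, and reports whether the module's `task_type` ended up in each registry.

```python
from types import ModuleType


def check_task_module(module: ModuleType) -> dict:
    """Call module.register() on empty registries and report what was registered."""
    task_type = module.TASK_CONFIG["task_type"]
    accuracy, fmt, answer = {}, {}, {}
    module.register(accuracy, fmt, answer)
    return {
        "task_type": task_type,
        "accuracy_reward": task_type in accuracy,
        "format_reward": task_type in fmt,
        "process_answer": task_type in answer,  # only expected if answers are rescaled
    }
```

Run it from the repository root, e.g. `check_task_module(importlib.import_module("task.my_task"))`; both reward entries should be `True`.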

Once these steps are complete, the task becomes available to both training and evaluation pipelines without further wiring.
