Thank you for taking the time to review our work! Because adapting the code to MTP mode took substantial effort, the source code of the open-source frameworks we use may contain many changes. Below we briefly go through the data and code structure of our work.

# Overall
- **MTP Implementation**: Since we found no existing MTP implementation on github.com, we build on the open-source framework transformers==4.51.1 and implement causal MTP modules as described in DeepSeek-V3.

- **MPO Implementation**: For the RL framework, we build on the open-source framework verl==0.3.1 and modify its source code to support multi-token log-probability computation.

- **Dataset**: We use only two open-source math datasets, GSM8K and MATH, both available on huggingface.co. Their training sets serve as the warm-up data for the MTP modules.

- **Models**: This paper uses only open-source models, all of which can be downloaded from huggingface.co.
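For reference, the causal MTP module described in the DeepSeek-V3 technical report works as follows; the implementation in "code/qwen2_mtp" may differ in details. For prediction depth $k$ (with $D$ MTP modules in total), $\mathbf{h}_i^{0}$ is the main model's final hidden state at position $i$, $M_k \in \mathbb{R}^{d \times 2d}$ is a projection, $\mathrm{TRM}_k$ is a Transformer block, and $\mathrm{Emb}$ and $\mathrm{OutHead}$ are shared with the main model. Module $k$ at position $i$ predicts token $t_{i+k+1}$:

$$
\begin{aligned}
\mathbf{h}_i'^{k} &= M_k\left[\mathrm{RMSNorm}(\mathbf{h}_i^{k-1});\ \mathrm{RMSNorm}(\mathrm{Emb}(t_{i+k}))\right],\\
\mathbf{h}_{1:T-k}^{k} &= \mathrm{TRM}_k\left(\mathbf{h}_{1:T-k}'^{k}\right),\\
P_{i+k+1}^{k} &= \mathrm{OutHead}\left(\mathbf{h}_i^{k}\right),\\
\mathcal{L}_{\mathrm{MTP}}^{k} &= -\frac{1}{T}\sum_{i=2+k}^{T+1}\log P_i^{k}[t_i],
\qquad
\mathcal{L}_{\mathrm{MTP}} = \frac{\lambda}{D}\sum_{k=1}^{D}\mathcal{L}_{\mathrm{MTP}}^{k}.
\end{aligned}
$$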

# Environment Setup
We implement MPO with CUDA 12.4 and Python 3.11. The basic environment for MPO can be set up with:

1. `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124`

2. `pip install transformers==4.51.1`

3. `pip install vllm==0.8.2`

4. `pip3 install flash-attn==2.7.3 --no-build-isolation`

5. `cd path_of_verl`

6. `pip3 install -e .[vllm]`

7. `pip install trl`
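After installation, the pinned versions can be checked with the standard-library `importlib.metadata` module. This is a convenience sketch, not part of the original setup; it only compares the versions pinned in the steps above (torch is not pinned there, so it is not checked).

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional, Tuple

# Versions pinned in the installation steps above.
EXPECTED = {"transformers": "4.51.1", "vllm": "0.8.2", "flash-attn": "2.7.3"}


def find_mismatches(expected: Dict[str, str]) -> Dict[str, Tuple[Optional[str], str]]:
    """Return {package: (installed_version_or_None, expected_version)} for every mismatch."""
    mismatches = {}
    for pkg, want in expected.items():
        try:
            have: Optional[str] = version(pkg)
        except PackageNotFoundError:
            have = None  # package is not installed at all
        if have != want:
            mismatches[pkg] = (have, want)
    return mismatches


if __name__ == "__main__":
    for pkg, (have, want) in find_mismatches(EXPECTED).items():
        print(f"{pkg}: installed {have}, expected {want}")
```

An empty output means every pinned package matches.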

# Implementation Setup

- **MTP source code**: For reference, we provide our implementation of the MTP modules for Qwen2.5 in the directory "code/qwen2_mtp". To enable the MTP modules, override the Qwen2 implementation in the transformers library with the code in "code/qwen2_mtp/causal_heads". In addition, because we changed some of the output configuration, you also need to override the corresponding transformers source files with those provided in "code/common", and replace your downloaded Qwen2.5 model's config.json with "code/common/config.json".

- **Verl MPO source code**: Building on the original PPO trainer, we modified verl's source code to obtain and record MTP logits during MPO training. We provide the key file "verl/verl/trainer/ppo/core_algos.py" here for reference only.

- **Important Note**: We did our best to remove meta-information from the code, so some paths in the code may be missing.
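The Qwen2 file overriding described under **MTP source code** can be scripted. A minimal sketch, assuming the files in "code/qwen2_mtp/causal_heads" are `.py` modules meant to replace same-named files in transformers' `models/qwen2` package directory; `override_files` and `transformers_qwen2_dir` are hypothetical helpers. The destinations of the files in "code/common" are not specified here, so the sketch covers only the Qwen2 files.

```python
import shutil
from pathlib import Path
from typing import List, Union

PathLike = Union[str, Path]


def override_files(src_dir: PathLike, dst_dir: PathLike, backup_suffix: str = ".orig") -> List[Path]:
    """Copy every .py file in src_dir into dst_dir, backing up each replaced file once."""
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    copied = []
    for src in sorted(src_dir.glob("*.py")):
        dst = dst_dir / src.name
        backup = dst.with_name(dst.name + backup_suffix)
        if dst.exists() and not backup.exists():
            shutil.copy2(dst, backup)  # keep the stock library file for later restore
        shutil.copy2(src, dst)
        copied.append(dst)
    return copied


def transformers_qwen2_dir() -> Path:
    """Directory holding transformers' Qwen2 modeling code (hypothetical helper)."""
    import transformers  # deferred so override_files works without transformers

    return Path(transformers.__file__).parent / "models" / "qwen2"
```

For example, `override_files("code/qwen2_mtp/causal_heads", transformers_qwen2_dir())`. Keeping a `.orig` backup makes it easy to restore the unmodified library later.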

# MPO Process

## Warmup MTP Modules

Fine-tune the model on the training sets of GSM8K and MATH, enabling gradients only for parameters whose names contain "mtp" (all other weights stay frozen). Before training, remember to call `model.load_state_dict()` to initialize the MTP modules from the checkpoint weights.
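A minimal sketch of the freezing step, assuming a standard PyTorch module whose MTP parameters have "mtp" in their names. `ToyMTPModel` is a hypothetical stand-in for the modified Qwen2.5 model, and the learning rate is arbitrary.

```python
from typing import List

import torch
from torch import nn


def freeze_all_but_mtp(model: nn.Module, keyword: str = "mtp") -> List[str]:
    """Enable gradients only for parameters whose name contains `keyword`.

    Returns the names of the parameters that remain trainable.
    """
    trainable = []
    for name, param in model.named_parameters():
        param.requires_grad = keyword in name
        if param.requires_grad:
            trainable.append(name)
    return trainable


class ToyMTPModel(nn.Module):
    """Hypothetical toy model: a frozen backbone plus one MTP module."""

    def __init__(self, hidden: int = 8, vocab: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.backbone = nn.Linear(hidden, hidden)
        self.mtp_heads = nn.ModuleList([nn.Linear(hidden, hidden)])
        self.lm_head = nn.Linear(hidden, vocab, bias=False)


model = ToyMTPModel()
trainable = freeze_all_but_mtp(model)
# Hand only the trainable (MTP) parameters to the optimizer.
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-4)
print(trainable)  # ['mtp_heads.0.weight', 'mtp_heads.0.bias']
```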

## Prepare Data for MPO
Data preparation is exactly the same as in the original verl (see its documentation on github.com). For GSM8K, run

> python verl/examples/data_preprocess/gsm8k.py

and for MATH, run

> python verl/examples/data_preprocess/math_dataset.py
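For orientation, these scripts write one record per problem to parquet. The sketch below mimics the GSM8K case as we understand verl's `gsm8k.py`: the final answer after `####` becomes the ground truth for the rule-based reward. The field names, the `openai/gsm8k` data-source string and the appended instruction are assumptions based on verl 0.3.x; check the script itself before relying on them.

```python
import re

# Instruction appended to each question (assumed to mirror verl's gsm8k.py).
INSTRUCTION = "Let's think step by step and output the final answer after \"####\"."


def extract_gsm8k_answer(solution: str) -> str:
    """Return the final numeric answer that follows '####' in a GSM8K solution."""
    match = re.search(r"#### (\-?[0-9\.\,]+)", solution)
    if match is None:
        raise ValueError("no '####' answer marker found")
    return match.group(1).replace(",", "")  # drop thousands separators


def to_verl_record(question: str, solution: str, split: str, idx: int) -> dict:
    """Build one RL training record in the (assumed) verl layout."""
    return {
        "data_source": "openai/gsm8k",
        "prompt": [{"role": "user", "content": f"{question} {INSTRUCTION}"}],
        "ability": "math",
        "reward_model": {"style": "rule", "ground_truth": extract_gsm8k_answer(solution)},
        "extra_info": {"split": split, "index": idx},
    }


record = to_verl_record("What is 2 + 3?", "2 + 3 = 5\n#### 5", "train", 0)
```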

## Running MPO

Run

> sh verl/examples/ppo_trainer/run_deepseek7b_llm.sh

To reproduce the best results we have obtained so far:

For Deepseek-Distilled-qwen2.5-1.5b on GSM8K, use 1 MTP module with MPO ratio weights $\beta$ of 0.9:0.1.

For Deepseek-Distilled-qwen2.5-1.5b on MATH, use 1 MTP module with MPO ratio weights $\beta$ of 0.925:0.075.

Note that we did not use any training tricks for either the baseline or MPO; both use the default settings in the script run_deepseek7b_llm.sh.
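The $\beta$ ratios weight the heads against each other. The actual formulation lives in "verl/verl/trainer/ppo/core_algos.py"; purely as a hypothetical illustration, the sketch below assumes that each head $k$ (the main head plus each MTP head) yields its own PPO importance ratio and that the per-head clipped surrogate losses are mixed with weights $\beta_k$, e.g. `[0.9, 0.1]` for one MTP module. The function name and tensor layout are our assumptions, not verl's API.

```python
from typing import List

import torch


def multi_head_ppo_loss(
    old_log_probs: torch.Tensor,  # [num_heads, batch, seq]
    log_probs: torch.Tensor,      # [num_heads, batch, seq]
    advantages: torch.Tensor,     # [batch, seq]
    response_mask: torch.Tensor,  # [batch, seq], 1 for response tokens
    head_weights: List[float],    # hypothetical beta per head, e.g. [0.9, 0.1]
    clip_ratio: float = 0.2,
) -> torch.Tensor:
    """HYPOTHETICAL: beta-weighted sum of per-head PPO clipped surrogate losses."""
    assert len(head_weights) == log_probs.shape[0], "need one beta per head"
    mask = response_mask.to(log_probs.dtype)
    total = log_probs.new_zeros(())
    for k, beta in enumerate(head_weights):
        ratio = torch.exp(log_probs[k] - old_log_probs[k])
        unclipped = -advantages * ratio
        clipped = -advantages * torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)
        per_token = torch.maximum(unclipped, clipped)  # pessimistic PPO bound
        head_loss = (per_token * mask).sum() / mask.sum().clamp(min=1.0)
        total = total + beta * head_loss
    return total
```

Whatever the exact form in `core_algos.py`, the number of $\beta$ entries has to equal the number of heads, which is why the config and the trainer must agree.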

# Implementation Note

Make sure the settings in the config.json of the LLM you use match the MPO process, especially the number of MTP modules. If you change the number of MTP modules in the checkpoint's config.json, adjust the code in "verl/verl/trainer/ppo/core_algos.py" so that it assigns a ratio weight to each of those MTP heads.
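A small consistency check can catch a mismatch before training starts. A minimal sketch: `MTP_COUNT_KEY` is a hypothetical field name (look up the real one in "code/common/config.json"), and the expected count assumes one weight for the main head plus one per MTP module, matching the two-component $\beta$ used with 1 MTP module above.

```python
import json
from pathlib import Path
from typing import List

# HYPOTHETICAL key name: replace with the field used in code/common/config.json.
MTP_COUNT_KEY = "num_mtp_modules"


def load_config(path) -> dict:
    """Read a Hugging Face style config.json into a dict."""
    return json.loads(Path(path).read_text())


def check_head_weights(config: dict, head_weights: List[float]) -> None:
    """Raise if the number of ratio weights does not match main head + MTP heads."""
    num_mtp = int(config.get(MTP_COUNT_KEY, 0))
    expected = num_mtp + 1  # assumption: main head + one weight per MTP module
    if len(head_weights) != expected:
        raise ValueError(
            f"config has {num_mtp} MTP module(s), so {expected} ratio weights are "
            f"expected, got {len(head_weights)}"
        )


# Usage (path is illustrative):
# check_head_weights(load_config("path/to/model/config.json"), [0.9, 0.1])
```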