## Interactive Learning for LLM Reasoning
[![Framework: PyTorch](https://img.shields.io/badge/Framework-PyTorch-orange.svg)](https://pytorch.org/) 

<p align="center">
  <img src="method.png" width="1200"/>
</p>

This is the official implementation of the paper "Interactive Learning for LLM Reasoning". The framework is built on [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF).


**1. Data**

The training data (`Train.jsonl`) and validation data (`Valid.jsonl`) are in `./data`. Both files include a continuous difficulty level measured by LLM self-ranking.
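Before training, it can help to check what the data files actually contain. The following is a minimal sketch, assuming only that both files are standard JSON Lines (one JSON object per line). It makes no assumption about field names; `load_jsonl` and `summarize` are hypothetical helpers, not part of this repository.

```python
import json
from pathlib import Path


def load_jsonl(path):
    """Read a JSON Lines file into a list of dicts, skipping blank lines."""
    records = []
    with Path(path).open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


def summarize(records):
    """Return the record count and the sorted field names of the first record."""
    fields = sorted(records[0].keys()) if records else []
    return len(records), fields


if __name__ == "__main__":
    # Paths follow the ./data layout described above.
    for name in ("Train.jsonl", "Valid.jsonl"):
        path = Path("data") / name
        if path.exists():
            count, fields = summarize(load_jsonl(path))
            print(f"{name}: {count} records, fields: {fields}")
```

Run from the repository root, this prints the number of records and the field names of each file, which shows how the difficulty level is stored.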

**2. Conda environment**

```shell
conda create -n ILR python=3.10
conda activate ILR
pip install -r requirements.txt
```

**3. Replace path**

Replace the paths and parameters in the shell scripts under `examples/scripts` with your own.

Taking `train_ilr_group1_ray.sh` as an example, the settings are:

```text
--LLM1LR learning rate of LLM1
--LLM1KL kl coefficient of LLM1
--LLM2LR learning rate of LLM2
--LLM2KL kl coefficient of LLM2
--TEMP temperature
--TAG tag of your checkpoints
--temp-dir the path for temporary output of ray
--num-gpus the number of GPUs to use (minimum: 4 A100s)
--working_dir the current work dir
--pretrain_llm1 the path of pretrained llm1
--save_path_llm1 save path of the final llm1 model
--ckpt_path_llm1 path of intermediate training checkpoints of llm1
--use_tensorboard_llm1 tensorboard path of llm1
--pretrain_llm2 the path of pretrained llm2
--save_path_llm2 save path of the final llm2 model
--ckpt_path_llm2 path of intermediate training checkpoints of llm2
--use_tensorboard_llm2 tensorboard path of llm2
--reward_pretrain the path of reward model
--eval_ability_file the path of validation file (Valid.jsonl)
--prompt_data the path of training file (Train.jsonl)
```
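A wrong input path is only reported once the Ray job has started, so a quick pre-flight check can save time. The following is a minimal sketch, assuming the input files sit at the example locations shown; `find_missing_paths` is a hypothetical helper, not part of this repository.

```python
from pathlib import Path


def find_missing_paths(paths):
    """Return the names of entries whose path does not exist on disk."""
    return [name for name, p in paths.items() if not Path(p).exists()]


if __name__ == "__main__":
    # Hypothetical values: replace them with the paths set in your shell script.
    inputs = {
        "prompt_data": "data/Train.jsonl",
        "eval_ability_file": "data/Valid.jsonl",
    }
    missing = find_missing_paths(inputs)
    if missing:
        print("Missing input paths:", ", ".join(missing))
    else:
        print("All input paths exist.")
```

The same dictionary can be extended with the pretrained model and reward model paths when they point to local directories.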

**4. ILR Training**

The model groups used in our paper are:

- Group1 (different series, same scale): Llama-3.1-8B-Instruct and Qwen2.5-7B-Instruct
- Group2 (different series, different scale): Llama-3.1-8B-Instruct and Qwen2.5-14B-Instruct
- Group3 (same series, different scale): Qwen2.5-7B-Instruct and Qwen2.5-14B-Instruct
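The groups above can be written down as a small lookup table, for example to fill in the two `--pretrain_*` settings. This is a sketch only: the strings are the Hugging Face Hub identifiers of these models, the assignment of each model to llm1 or llm2 is an assumption, and `pretrain_args` is a hypothetical helper.

```python
# Hugging Face Hub identifiers of the three models listed above.
LLAMA_8B = "meta-llama/Llama-3.1-8B-Instruct"
QWEN_7B = "Qwen/Qwen2.5-7B-Instruct"
QWEN_14B = "Qwen/Qwen2.5-14B-Instruct"

# Assumption: the first model of each pair is passed as llm1 and the second
# as llm2; the source does not state which model takes which role.
GROUPS = {
    "group1": (LLAMA_8B, QWEN_7B),
    "group2": (LLAMA_8B, QWEN_14B),
    "group3": (QWEN_7B, QWEN_14B),
}


def pretrain_args(group):
    """Return the two pretrain settings for a group as a dict."""
    llm1, llm2 = GROUPS[group]
    return {"pretrain_llm1": llm1, "pretrain_llm2": llm2}


if __name__ == "__main__":
    for name in GROUPS:
        print(name, pretrain_args(name))
```

If the models are stored locally, replace the Hub identifiers with the local directories.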

Run the following command to train the LLMs (Group1 shown):

```shell
bash examples/scripts/train_ilr_group1_ray.sh
```

**5. ILR Evaluation**

For mathematical benchmarks, we use [qwen-math](https://github.com/QwenLM/Qwen2.5-Math) as the evaluation tool.

For the code benchmark MBPP, we use [ms-swift](https://github.com/modelscope/ms-swift) as the evaluation tool.

Please refer to their instructions to evaluate your checkpoints on the datasets used in our paper (GSM8K, MATH-500, Minerva Math, Olympiad Bench, AIME, MBPP).