# Weighted-Reward Preference Optimization for Implicit Model Fusion

## Requirements

This repository includes a [requirements file](requirements.txt) specifying the Python package versions used in our experiments. We utilized `Python 3.10` and `CUDA 12.2` for this work.

## Training Data Construction
We use a subset of UltraFeedback provided in [princeton-nlp/llama3-ultrafeedback-armorm](https://huggingface.co/datasets/princeton-nlp/llama3-ultrafeedback-armorm) to construct our training dataset.

Our training dataset contains quadruples of $(x, y_{w_s}, y_{w_t}, y_l)$, where $y_{w_s}$ is a response from the Source LLMs, $y_{w_t}$ and $y_l$ are responses from the Target LLM.


### Target & Source LLMs
The Target and Source LLMs, along with their corresponding Huggingface IDs, are listed below:

| **Models**                   | **Huggingface ID**                                                                                                |
|------------------------------|-------------------------------------------------------------------------------------------------------------------|
| Target (LLaMA-3-8B-Instruct) | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)                 |
| Mistral-Large-Instruct-2407  | [mistralai/Mistral-Large-Instruct-2407](https://huggingface.co/mistralai/Mistral-Large-Instruct-2407)             |
| Gemma2-27B-IT                | [google/gemma-2-27b-it](https://huggingface.co/google/gemma-2-27b-it)                                             |
| Qwen2-72B-Instruct           | [Qwen/Qwen2-72B-Instruct](https://huggingface.co/Qwen/Qwen2-72B-Instruct)                                         |
| LLaMA3-70B-Instruct          | [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct)               |
| Gemma2-9B-IT                 | [google/gemma-2-9b-it](https://huggingface.co/google/gemma-2-9b-it)                                               |
| Internlm2.5-20B-Chat         | [internlm/internlm2_5-20b-chat](https://huggingface.co/internlm/internlm2_5-20b-chat)                             |
| DeepSeek-V2-Chat             | [deepseek-ai/DeepSeek-V2-Chat-0628](https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat-0628)                     |
| DeepSeek-Coder-V2-Instruct   | [deepseek-ai/DeepSeek-Coder-V2-Instruct-0724](https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Instruct-0724) |
| Yi-1.5-34B-Chat              | [01-ai/Yi-1.5-34B-Chat](https://huggingface.co/01-ai/Yi-1.5-34B-Chat)                                             |
| Phi-3-medium-4k-instruct     | [microsoft/Phi-3-medium-4k-instruct](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct)                   |

### Construction of $y_{w_s}$
1. For each prompt in the [Ultrafeedback](https://huggingface.co/datasets/princeton-nlp/llama3-ultrafeedback-armorm) dataset, we sample five responses from each source LLM. This was done using top-p sampling (p=0.95) with a temperature of 0.8, following the pipeline in the [SimPO repository](https://github.com/princeton-nlp/SimPO). 

2. We employ [ArmoRM-LLaMA3-8B-v0.1](https://huggingface.co/RLHFlow/ArmoRM-Llama3-8B-v0.1) as the reward model to score and rank these responses. The highest-scoring response across all source models is selected as one of the preferred responses, $y_{w_s}$. 


### Construction of $y_{w_t}$ and $y_l$

1. After the construction of $y_{w_s}$, we split the dataset into two parts: one-third of the data instances are randomly sampled for supervised fine-tuning (SFT), while the remaining instances are used for preference optimization, as detailed in our paper.

2. We apply Supervised Fine-Tuning (SFT) on the set of $y_{w_s}$ using first part of the dataset.

3. We then generate five samples from the SFT model using the remaining dataset. The response with the highest score is labeled as $y_{w_t}$, while the lowest-scoring response is regarded as $y_l$. 


Below is a example instance of our dataset, where "chosen" is a list of [$y_{w_s}$, $y_{w_t}$], and "rejected" is a list of [$y_l$].

```
{
    "prompt_id": "3ebac2832721f4ef8e9ead1bb19e251e5d21d60dbd9f89cae931fe4aac900058",
    "prompt": {
        "content": "Given the task definition and input, reply with output. The provided file includes inquiries about restaurants in Finnish, and we ask you to translate those to English language. Please bear in mind the following guidelines while doing the translation: 1) We are looking for the most naturally written and formal form of each sentence in your language. We are *NOT* looking for colloquial forms of the sentence. We are looking for formal form which is how you would type your queries in a text-based virtual assistant. 2) The words between quotation marks *SHOULD NOT* be translated. We expect you to keep those values intact and include the quotation marks around them as well. 3) The fully capitalized words like DATE_0, or DURATION_0 *SHOULD NOT* be translated. Please keep them as they are in the translations. 4) Please do not localize measurement units like miles to kilometers during your translation. miles should be translated to its equivalent in your language. 6) Note the input is all lowercased except for fully capitalized special placeholders (e.g. NUMBER, DATE, TIME). Please do the same in your translations.\n\nkerro minulle \" dave 's seafood shack \" -ravintolan sijainti.\n",
        "role": "user"
    },
    "chosen": [
        {
            "content": "Tell me the location of the \"dave's seafood shack\" restaurant.",
            "role": "assistant"
        },
        {
            "content": "Here is the translated output:\n\nTell me the location of \"Dave's Seafood Shack\".",
            "role": "assistant"
        }
    ],
    "rejected": [
        {
            "content": "Here's the translation:\n\nTell me the location of \"Dave's Seafood Shack\".",
            "role": "assistant"
        }
    ]
}
```
## Training configurations

We provide configuration files for both training stages, designed for an environment with 8x80G A800 GPUs. You may need to adjust `num_processes` and `per_device_train_batch_size` based on your specific computational setup.

### Commands

* To run Target-SFT:
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file training_configs/deepspeed_zero3.yaml scripts/run_sft.py training_configs/llama-3-8b-instruct-sft.yaml
```

* To run Target-SFT-WRPO:
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file training_configs/deepspeed_zero3.yaml scripts/run_wrpo.py training_configs/llama-3-8b-instruct-sft-wrpo.yaml
```
