
<div id="top" align="center">

FuseChat: Knowledge Fusion of Chat Models
-----------------------------
<img src="https://img.shields.io/badge/License-Apache_2.0-green.svg" alt="License">
</div>

## Requirements

We use Python 3.11 in this project.

Then install all the libraries listed in `requirements.txt`:

```bash
pip install -r requirements.txt
```

## Training

We conduct experiments using six representative chat LLMs as the source LLMs: 
[OpenChat-3.5-7B](https://huggingface.co/openchat/openchat_3.5), [Starling-LM-7B-alpha](https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha), 
[NH2-SOLAR-10.7B](https://huggingface.co/NousResearch/Nous-Hermes-2-SOLAR-10.7B), [InternLM2-Chat-20B](https://huggingface.co/bartowski/internlm2-chat-20b-llama-old), 
[Mixtral-8x7B-Instruct](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1), and [Qwen-1.5-Chat-72B](https://huggingface.co/Qwen/Qwen1.5-72B-Chat). 
For the pivot LLM, which also serves as the starting point for the target LLMs, we choose OpenChat-3.5-7B for its balance of scale and performance. 
Download all these models and place them in `${PROJ_PATH}/models/` before running the experiments, using the directory names that the scripts below expect (`openchat_3.5`, `Starling-LM-7B-alpha`, `Nous-Hermes-2-SOLAR-10.7B`, `internlm2-chat-20b`, `Mixtral-8x7B-Instruct-v0.1`, `Qwen1.5-72B-Chat`).
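The scripts below read each model from `${PROJ_PATH}/models/<name>`, where `<name>` is the value of `MODEL_NAME`, `PIVOT_NAME`, or `SOURCE_NAME`. As a minimal sketch, the hypothetical helper `check_models` (not part of the repository) lists any of these directories that are still missing, so a misplaced download is caught before a long run fails.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the FuseChat repository): print every model
# directory the training scripts expect under "$1/models" but cannot find.
# Returns 1 if at least one directory is missing, 0 otherwise.
check_models() {
  local proj_path="$1"
  local status=0
  local names=(
    openchat_3.5
    Starling-LM-7B-alpha
    Nous-Hermes-2-SOLAR-10.7B
    internlm2-chat-20b
    Mixtral-8x7B-Instruct-v0.1
    Qwen1.5-72B-Chat
  )
  local name
  for name in "${names[@]}"; do
    if [ ! -d "${proj_path}/models/${name}" ]; then
      echo "missing: ${proj_path}/models/${name}"
      status=1
    fi
  done
  return "$status"
}
```

For example, `check_models "${PROJ_PATH}" && echo "all models found"` prints nothing but the success message once every model is in place.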

We curated a comprehensive training dataset, FuseChat-Mixture, from various sources. 
It covers different styles and capabilities, includes both human-written and model-generated data, and spans general instruction following as well as specific skills. 
The dataset is provided in `data/fusechat_v1_clean_split_2048_filter_wrong.json`.

We obtain the target LLMs in the following four steps.

### 1. Get representations for each source LLM

The following script extracts the representations of a source LLM on the training data. Run it once per model by changing `MODEL_NAME`.

```bash
# We split the dataset into 4 splits, then process each split on one or multiple GPUs.

# OpenChat-3.5-7B Starling-LM-7B-alpha Nous-Hermes-2-SOLAR-10.7B internlm2-chat-20b Mixtral-8x7B-Instruct-v0.1 Qwen1.5-72B-Chat
export CUDA_VISIBLE_DEVICES=xx # specify one or multiple GPUs
PROJ_PATH=xx # specify your own project path
DATA_NAME="fusechat_v1_clean_split_2048_filter_wrong" 
MODEL_NAME=openchat_3.5 # Starling-LM-7B-alpha Nous-Hermes-2-SOLAR-10.7B internlm2-chat-20b Mixtral-8x7B-Instruct-v0.1 Qwen1.5-72B-Chat

for i in {0..3}; do
python ${PROJ_PATH}/train/get_data_representation.py \
    --model_name_or_path ${PROJ_PATH}/models/${MODEL_NAME} \
    --data_path ${PROJ_PATH}/data/${DATA_NAME}.json \
    --dataset_save_dir ${PROJ_PATH}/representations/${MODEL_NAME}_representation_split${i} \
    --tknz_dataset_path ${PROJ_PATH}/representations/${MODEL_NAME}_representation_tknz_split${i} \
    --cache_dir ${PROJ_PATH}/.cache/huggingface/datasets \
    --model_max_length 2048 \
    --load_in_half bf16 \
    --batch_size 32 \
    --top_k_logits 10 \
    --save_per_token_metric \
    --no_assert \
    --conv_temp "openchat" \
    --mask_instruction \
    --dataset_split_num 4 \
    --dataset_index ${i} \
    --get_representation \
    --device_map "auto"
done
```
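Step 1 has to be repeated for every source LLM, including the pivot. One way to do this is to wrap the command above in a loop over the model names. The sketch below does exactly that with the same arguments (including `--conv_temp "openchat"` for every model, as in the original script); `get_all_representations` is a hypothetical name, and the function assumes `PROJ_PATH` and `DATA_NAME` are set as above.

```bash
#!/usr/bin/env bash
# Hypothetical wrapper (not part of the repository): repeat the step-1 command
# for the pivot and the five other source LLMs, over the 4 dataset splits.
# Assumes PROJ_PATH and DATA_NAME are already set as in the script above.
get_all_representations() {
  local models=(
    openchat_3.5
    Starling-LM-7B-alpha
    Nous-Hermes-2-SOLAR-10.7B
    internlm2-chat-20b
    Mixtral-8x7B-Instruct-v0.1
    Qwen1.5-72B-Chat
  )
  local model i
  for model in "${models[@]}"; do
    for i in 0 1 2 3; do
      # Same arguments as the single-model script; stop at the first failure.
      python "${PROJ_PATH}/train/get_data_representation.py" \
        --model_name_or_path "${PROJ_PATH}/models/${model}" \
        --data_path "${PROJ_PATH}/data/${DATA_NAME}.json" \
        --dataset_save_dir "${PROJ_PATH}/representations/${model}_representation_split${i}" \
        --tknz_dataset_path "${PROJ_PATH}/representations/${model}_representation_tknz_split${i}" \
        --cache_dir "${PROJ_PATH}/.cache/huggingface/datasets" \
        --model_max_length 2048 \
        --load_in_half bf16 \
        --batch_size 32 \
        --top_k_logits 10 \
        --save_per_token_metric \
        --no_assert \
        --conv_temp "openchat" \
        --mask_instruction \
        --dataset_split_num 4 \
        --dataset_index "${i}" \
        --get_representation \
        --device_map "auto" || return 1
    done
  done
}
```

Larger models such as Qwen1.5-72B-Chat may need more GPUs in `CUDA_VISIBLE_DEVICES` than the 7B models, so running the loop in one go assumes enough GPU memory for the largest model.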

### 2. Align representations from different source LLMs

The following scripts align the representations of each source LLM to the pivot LLM.

#### 2.1. LLMs with the same vocabulary

For source LLMs that share the pivot LLM's vocabulary, we only need to merge their representations with the pivot LLM's representations into a single dataset.
```bash
# Pivot LLM: OpenChat-3.5-7B <-> Source LLMs: Starling-LM-7B-alpha Nous-Hermes-2-SOLAR-10.7B Mixtral-8x7B-Instruct-v0.1
PROJ_PATH=xx # specify your own project path
PIVOT_NAME=openchat_3.5  # Pivot LLM
SOURCE_NAME=Starling-LM-7B-alpha # Source LLMs: Starling-LM-7B-alpha Nous-Hermes-2-SOLAR-10.7B Mixtral-8x7B-Instruct-v0.1
for i in {0..3}; do
python ${PROJ_PATH}/train/replace_model.py \
  --dataset_dir ${PROJ_PATH}/representations/${PIVOT_NAME}_representation_split${i} \
  --replace_dataset_dir ${PROJ_PATH}/representations/${SOURCE_NAME}_representation_split${i} \
  --dataset_save_dir ${PROJ_PATH}/aligned_representations/${PIVOT_NAME}_${SOURCE_NAME}_representation_split${i} \
  --preprocessing_num_workers 32 \
  --batch_size 1000 \
  --replace_model model_0
done 

```
#### 2.2. LLMs with different vocabularies
For source LLMs whose vocabulary differs from the pivot LLM's, we need to perform token alignment followed by distribution alignment.
```bash
# Pivot LLM: OpenChat-3.5-7B <-> Source LLMs: internlm2-chat-20b Qwen1.5-72B-Chat
PROJ_PATH=xx # specify your own project path
PIVOT_NAME=openchat_3.5 # Pivot LLM
SOURCE_NAME=internlm2-chat-20b # Source LLMs: internlm2-chat-20b Qwen1.5-72B-Chat
align_type="default" # hard -> EM, soft -> MinED, default -> MS
token_alignment_matrix_file=${PROJ_PATH}/aligned_representations/${PIVOT_NAME}_${SOURCE_NAME}_token_sparse_matrix_${align_type}.npz
blending_to_base_file=${PROJ_PATH}/aligned_representations/${PIVOT_NAME}_${SOURCE_NAME}_token_mapping_${align_type}.json

# token alignment
python ${PROJ_PATH}/train/align_token_and_vocab.py \
    --align_type ${align_type} \
    --base_model_name_or_path ${PROJ_PATH}/models/${PIVOT_NAME} \
    --blending_model_name_or_path ${PROJ_PATH}/models/${SOURCE_NAME} \
    --base_dataset_dir "${PROJ_PATH}/representations/${PIVOT_NAME}_representation_tknz_split0,${PROJ_PATH}/representations/${PIVOT_NAME}_representation_tknz_split1,${PROJ_PATH}/representations/${PIVOT_NAME}_representation_tknz_split2,${PROJ_PATH}/representations/${PIVOT_NAME}_representation_tknz_split3" \
    --blending_dataset_dir "${PROJ_PATH}/representations/${SOURCE_NAME}_representation_tknz_split0,${PROJ_PATH}/representations/${SOURCE_NAME}_representation_tknz_split1,${PROJ_PATH}/representations/${SOURCE_NAME}_representation_tknz_split2,${PROJ_PATH}/representations/${SOURCE_NAME}_representation_tknz_split3" \
    --aligned_dataset_save_dir ${PROJ_PATH}/aligned_representations/${PIVOT_NAME}_${SOURCE_NAME}_representation_tknz \
    --model_max_length 2048 \
    --preprocessing_num_workers 32 \
    --batch_size 16 \
    --token_alignment_matrix_file ${token_alignment_matrix_file} \
    --blending_to_base_file ${blending_to_base_file} \
    --do_token_alignment \
    --blending_model_index 0 \
    --metric_level "sequence"

# distribution alignment
for i in {0..3}; do
python ${PROJ_PATH}/train/align_token_and_vocab.py \
   --align_type ${align_type} \
   --base_model_name_or_path ${PROJ_PATH}/models/${PIVOT_NAME} \
   --blending_model_name_or_path ${PROJ_PATH}/models/${SOURCE_NAME} \
   --base_dataset_dir ${PROJ_PATH}/representations/${PIVOT_NAME}_representation_split${i} \
   --blending_dataset_dir ${PROJ_PATH}/representations/${SOURCE_NAME}_representation_split${i} \
   --aligned_dataset_save_dir ${PROJ_PATH}/aligned_representations/${PIVOT_NAME}_${SOURCE_NAME}_representation_split${i} \
   --model_max_length 2048 \
   --preprocessing_num_workers 32 \
   --batch_size 16 \
   --temperature 0.5 \
   --token_alignment_matrix_file ${token_alignment_matrix_file} \
   --blending_to_base_file ${blending_to_base_file} \
   --do_distribution_alignment \
   --blending_model_index 0 \
   --metric_level "sequence" \
   --use_token_alignment_matrix
done

```
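The comma-separated directory lists passed to `--base_dataset_dir` and `--blending_dataset_dir` above (and to `--data_path` in step 4) are easy to mistype. A minimal sketch of a helper that builds such a list from a path prefix, an optional suffix, and the number of splits; `split_list` is a hypothetical name, not part of the repository.

```bash
#!/usr/bin/env bash
# Hypothetical helper: print "PREFIX0SUFFIX,PREFIX1SUFFIX,...,PREFIX{N-1}SUFFIX".
# Usage: split_list PREFIX SUFFIX NUM_SPLITS
split_list() {
  local prefix="$1" suffix="$2" num="$3"
  local out="" i
  for ((i = 0; i < num; i++)); do
    # ${out:+,} inserts a comma only once the list is non-empty.
    out+="${out:+,}${prefix}${i}${suffix}"
  done
  printf '%s\n' "$out"
}
```

For instance, `--base_dataset_dir "$(split_list "${PROJ_PATH}/representations/${PIVOT_NAME}_representation_tknz_split" "" 4)"` expands to the same four directories as the hand-written list, and a suffix of `_fnan` produces the step-4 `--data_path` list.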

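### 3. Filter instances with NaN loss from the dataset

Run this step once per source LLM, reusing the `PROJ_PATH`, `PIVOT_NAME`, and `SOURCE_NAME` values from step 2.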

```bash
for i in {0..3}; do
python ${PROJ_PATH}/train/filter_nan.py \
  --input_data_dir ${PROJ_PATH}/aligned_representations/${PIVOT_NAME}_${SOURCE_NAME}_representation_split${i} \
  --output_data_dir ${PROJ_PATH}/aligned_representations/${PIVOT_NAME}_${SOURCE_NAME}_representation_split${i}_fnan \
  --model_index 0
done
```

For each split `i`, the filtered representations are saved to `${PROJ_PATH}/aligned_representations/${PIVOT_NAME}_${SOURCE_NAME}_representation_split${i}_fnan`.

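### 4. Pairwise Knowledge Fusion

The following script fine-tunes the pivot LLM on the filtered representations of one source LLM (pairwise knowledge fusion). Repeating it for each source LLM yields one target LLM per source LLM.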

```bash
# OpenChat-3.5-7B <-> Starling-LM-7B-alpha Nous-Hermes-2-SOLAR-10.7B internlm2-chat-20b Mixtral-8x7B-Instruct-v0.1 Qwen1.5-72B-Chat
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
torchrun --nproc_per_node=8 --master_port=20001 ${PROJ_PATH}/train/train.py \
  --model_name_or_path "openchat/openchat_3.5" \
  --data_path "${PROJ_PATH}/aligned_representations/${PIVOT_NAME}_${SOURCE_NAME}_representation_split0_fnan,${PROJ_PATH}/aligned_representations/${PIVOT_NAME}_${SOURCE_NAME}_representation_split1_fnan,${PROJ_PATH}/aligned_representations/${PIVOT_NAME}_${SOURCE_NAME}_representation_split2_fnan,${PROJ_PATH}/aligned_representations/${PIVOT_NAME}_${SOURCE_NAME}_representation_split3_fnan" \
  --bf16 True \
  --output_dir "${PROJ_PATH}/checkpoints/${PIVOT_NAME}_${SOURCE_NAME}_pairwise_fusion_ckpt" \
  --num_train_epochs 3 \
  --per_device_train_batch_size 4 \
  --per_device_eval_batch_size 4 \
  --gradient_accumulation_steps 4 \
  --evaluation_strategy "no" \
  --save_strategy "epoch" \
  --save_steps 10000 \
  --save_total_limit 5 \
  --learning_rate 5e-6 \
  --weight_decay 0. \
  --warmup_ratio 0.03 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --fsdp "full_shard auto_wrap" \
  --fsdp_transformer_layer_cls_to_wrap 'MistralDecoderLayer' \
  --tf32 True \
  --model_max_length 2048 \
  --gradient_checkpointing True \
  --conv_temp "openchat" \
  --lazy_preprocess True \
  --flash_attn_transformers True \
  --do_train \
  --do_distill \
  --distill_with_ref_model True \
  --distill_with_aligned_model_0 True \
  --distill_with_aligned_model_1 False \
  --distill_loss_type "ce" \
  --distill_teacher_temperature 1.0 \
  --lm_loss_weight 0.9 \
  --distill_greater_as_gt True \
  --distill_greater_as_gt_type hard \
  --dataloader_num_workers 8 \
  --remove_unused_columns False
```
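With the settings above, each optimizer step processes 8 GPUs × 4 sequences per device × 4 accumulation steps = 128 sequences. If you train on a different number of GPUs, one option (our assumption, not a setting prescribed by the original scripts) is to rescale `--gradient_accumulation_steps` so that this global batch size stays at 128. The hypothetical helper `accum_steps` below computes the required value and rejects GPU counts that do not divide it evenly.

```bash
#!/usr/bin/env bash
# Hypothetical helper: gradient_accumulation_steps needed to keep the global
# batch size of the script above (8 GPUs x 4 per device x 4 steps = 128).
# Usage: accum_steps NUM_GPUS [PER_DEVICE_BATCH=4] [GLOBAL_BATCH=128]
accum_steps() {
  local gpus="$1" per_device="${2:-4}" global="${3:-128}"
  local per_step=$(( gpus * per_device ))
  if (( per_step <= 0 || global % per_step != 0 )); then
    echo "global batch ${global} is not divisible by ${gpus} x ${per_device}" >&2
    return 1
  fi
  echo $(( global / per_step ))
}
```

For example, `accum_steps 4` prints `8`: on four GPUs, `--gradient_accumulation_steps 8` keeps the same effective batch size as the eight-GPU setup.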

## Model Merging

The following scripts merge the target LLMs into FuseChat using different merging methods.

### Setup

Before merging, install our modified version of [mergekit](https://github.com/arcee-ai/mergekit), which is included in the `mergekit` directory of this repository:
```bash
cd mergekit
pip install -e .
```
### Usage
#### 1. Reproducing our method, SCE
Our method is implemented in `mergekit/mergekit/merge_methods/sce_merging.py`.
```bash
model_save_dir=xx # specify your path to save the merged models
mergekit-yaml mergekit/fusechat_configs/fusechat-sce.yml ${model_save_dir}/FUSECHAT-7B-SCE
```

#### 2. Reproducing the comparison of different merging methods
```bash
model_save_dir=xx # your path to save the merged models
mergekit-yaml mergekit/fusechat_configs/fusechat-linear.yml ${model_save_dir}/FUSECHAT-7B-LINEAR

mergekit-yaml mergekit/fusechat_configs/fusechat-ta.yml ${model_save_dir}/FUSECHAT-7B-TA

mergekit-yaml mergekit/fusechat_configs/fusechat-ties.yml ${model_save_dir}/FUSECHAT-7B-TIES

mergekit-yaml mergekit/fusechat_configs/fusechat-dare.yml ${model_save_dir}/FUSECHAT-7B-DARE
```

## Evaluation
We evaluate instruction-following and multi-turn conversation capabilities on two representative benchmarks, AlpacaEval 2.0 and MT-Bench.
### MT-Bench
MT-Bench comprises 80 multi-turn dialogues spanning the writing, roleplay, reasoning, math, coding, extraction, STEM, and humanities domains. The original benchmark uses GPT-4-0613 as the evaluator, which assigns each generated response a scalar score from 1 (lowest) to 10 (highest).
However, because some reference answers generated by GPT-4-0613 are incorrect, we follow [recent work](https://github.com/lm-sys/FastChat/pull/3158) and use the updated GPT-4-0125-Preview both to correct these reference answers and to judge the generated responses.
#### Usage
Download the [official code](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge) and follow its guidelines for evaluation. Our evaluation scripts are shown below.

```bash
# Step 1. Generate model answers to MT-bench questions
export CUDA_VISIBLE_DEVICES=0,1
python gen_model_answer.py \
  --model-path "FuseChat-7B-SCE" \
  --model-id "openchat_3.5_fusechat_7b_sce" \
  --num-gpus-per-model 1 \
  --num-gpus-total 2

# Step 2. Generate GPT-4-0125-Preview judgments.
# To use GPT-4-0125-Preview as the judge model, first download gpt-4-0125-preview.jsonl (https://github.com/lm-sys/FastChat/pull/3158/files)
# and place it in llm_judge/data/mt_bench/reference_answer. Then add "gpt-4-0125-preview" as a valid judge model in common.py.
export OPENAI_API_KEY=XXXXXX  # set the OpenAI API key
python gen_judgment.py \
  --model-list "openchat_3.5_fusechat_7b_sce" \
  --judge-model "gpt-4-0125-preview" \
  --parallel 8

# Step 3. Show MT-bench scores
python show_result.py --model-list "openchat_3.5_fusechat_7b_sce" --judge-model "gpt-4-0125-preview"
```

### AlpacaEval 2.0
AlpacaEval 2.0 contains 805 instructions from five test subsets. It reports the Win Rate and Length-Controlled Win Rate (LC Win Rate) of a model's responses against those of GPT-4. 
Following the default settings, we use GPT-4-1106-Preview to assess the quality of the generated responses. 
#### Usage
Download the [official code](https://github.com/tatsu-lab/alpaca_eval) and follow its guidelines for evaluation. The prompt used for generation is:
```
GPT4 Correct User: {instruction}<|end_of_turn|>GPT4 Correct Assistant: 
```
## Safety
Our model may sometimes generate harmful, hateful, or biased responses, or answer unsafe questions. Apply additional AI safety measures in use cases that require safe and moderated responses.
