Summary: This tutorial provides implementation guidance for fine-tuning BGE (BAAI General Embedding) models using the FlagEmbedding library. It covers essential technical details for model configuration, data processing, and training parameters, specifically focusing on embedding model fine-tuning tasks. Key functionalities include setting up model checkpoints, handling data formats (query-positive-negative triplets), managing sequence lengths, and configuring training parameters like learning rate and batch size. The tutorial is particularly useful for tasks involving custom embedding model training, with specific features including knowledge distillation, gradient checkpointing, FP16 training, and distributed training optimization using deepspeed.

# Fine-tuning

In the previous section, we went through how to construct training and testing data properly. In this tutorial, we will actually fine-tune the model.

## Installation

To fine-tune BGE models with FlagEmbedding, install the package with the `finetune` extra:


```python
%pip install -U FlagEmbedding[finetune]
```

## Fine-tune

Below are the arguments for fine-tuning:

The following arguments are for model:
- `model_name_or_path`: The model checkpoint for initialization.
- `config_name`: Pretrained config name or path if not the same as model_name.
- `tokenizer_name`: Pretrained tokenizer name or path if not the same as model_name.
- `cache_dir`: Directory in which to cache the downloaded pre-trained models.
- `trust_remote_code`: Whether to trust and execute custom model code from the model repository.
- `token`: The token to use when accessing the model.

The following arguments are for data:
- `train_data`: One or more paths to training data. `query: str`, `pos: List[str]`, `neg: List[str]` are required in the training data. Argument type: multiple.
- `cache_path`: Directory in which to store the cached data.
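- `train_group_size`: The number of passages per query in each training group. There is always one positive, so the number of negatives is `train_group_size - 1`.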
- `query_max_len`: The maximum total input sequence length after tokenization for query. Sequences longer than this will be truncated.
- `passage_max_len`: The maximum total input sequence length after tokenization for passage. Sequences longer than this will be truncated.
- `pad_to_multiple_of`: If set will pad the sequence to be a multiple of the provided value.
- `max_example_num_per_dataset`: The max number of examples for each dataset.
- `query_instruction_for_retrieval`: Instruction for query.
- `query_instruction_format`: Format for query instruction.
- `knowledge_distillation`: Use knowledge distillation when `pos_scores: List[float]` and `neg_scores: List[float]` are in features of training data.
- `passage_instruction_for_retrieval`: Instruction for passage.
- `passage_instruction_format`: Format for passage instruction.
- `shuffle_ratio`: The ratio of shuffling the text.
- `same_dataset_within_batch`: All samples in the same batch come from the same dataset.
- `small_threshold`: The threshold of small dataset. All small datasets in the same directory will be merged into one dataset.
- `drop_threshold`: The threshold for dropping merged small dataset. If the number of examples in the merged small dataset is less than this threshold, it will be dropped.
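To make the expected record layout concrete, here is a minimal sketch that builds one training record and checks it against the fields listed above. The `validate_record` helper is hypothetical and not part of FlagEmbedding. It also assumes one JSON object per line in the training file and one teacher score per text when knowledge distillation is enabled.

```python
import json
from typing import Any, Dict, List


def validate_record(record: Dict[str, Any], knowledge_distillation: bool = False) -> List[str]:
    """Return a list of problems found in one training record (empty if none)."""
    problems = []
    if not isinstance(record.get("query"), str):
        problems.append("query must be a string")
    for key in ("pos", "neg"):
        value = record.get(key)
        if not isinstance(value, list) or not all(isinstance(t, str) for t in value):
            problems.append(f"{key} must be a list of strings")
    if knowledge_distillation:
        # Teacher scores are only checked when knowledge distillation is enabled.
        for text_key, score_key in (("pos", "pos_scores"), ("neg", "neg_scores")):
            scores = record.get(score_key)
            texts = record.get(text_key)
            if not isinstance(scores, list) or not all(isinstance(s, (int, float)) for s in scores):
                problems.append(f"{score_key} must be a list of numbers")
            elif isinstance(texts, list) and len(scores) != len(texts):
                # Assumption: one teacher score per text.
                problems.append(f"{score_key} must have one score per {text_key} text")
    return problems


example = {
    "query": "what is a dense retriever?",
    "pos": ["A dense retriever encodes queries and passages into vectors."],
    "neg": ["The weather was sunny.", "Bread is baked in an oven."],
}
line = json.dumps(example)  # assumption: one JSON object per line in the file
print(validate_record(json.loads(line)))
print(validate_record(json.loads(line), knowledge_distillation=True))
```

Running a check like this over every line before training catches missing `pos_scores` and `neg_scores` early, before `knowledge_distillation` is switched on.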

And the following extra arguments:
- `negatives_cross_device`: Share negatives across devices.
- `temperature`: Temperature by which the similarity scores are divided before computing the loss.
- `fix_position_embedding`: Freeze the parameters of position embeddings.
- `sentence_pooling_method`: The pooling method. Available options: cls, mean, last_token. Default: cls.
- `normalize_embeddings`: Whether to normalize the embeddings.
- `sub_batch_size`: Sub batch size for training.
- `kd_loss_type`: The loss type for knowledge distillation. Available options: kl_div, m3_kd_loss. Default: kl_div.
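The sketch below illustrates how `sentence_pooling_method`, `normalize_embeddings`, and `temperature` interact. It is written in plain Python for illustration and is not the library's implementation. The function names are hypothetical, only one query is scored, and the positive passage is assumed to sit at index 0.

```python
import math
from typing import List

Vector = List[float]


def pool(token_vectors: List[Vector], attention_mask: List[int], method: str = "cls") -> Vector:
    """Collapse per-token vectors into a single sentence vector."""
    kept = [v for v, m in zip(token_vectors, attention_mask) if m == 1]
    if method == "cls":
        return list(token_vectors[0])  # vector of the first token
    if method == "last_token":
        return list(kept[-1])  # vector of the last non-padding token
    if method == "mean":
        return [sum(col) / len(kept) for col in zip(*kept)]  # average over real tokens
    raise ValueError(f"unknown pooling method: {method}")


def normalize(v: Vector) -> Vector:
    """Scale a vector to unit length so a dot product equals cosine similarity."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def contrastive_loss(query: Vector, passages: List[Vector], temperature: float = 0.02) -> float:
    """Cross-entropy over temperature-scaled scores; passages[0] is the positive."""
    scores = [sum(q * p for q, p in zip(query, passage)) / temperature for passage in passages]
    top = max(scores)  # subtract the maximum for numerical stability
    log_sum_exp = top + math.log(sum(math.exp(s - top) for s in scores))
    return log_sum_exp - scores[0]


tokens = [[1.0, 0.0], [0.0, 2.0], [9.0, 9.0]]
mask = [1, 1, 0]  # the third token is padding
for method in ("cls", "mean", "last_token"):
    print(method, pool(tokens, mask, method))

query = normalize([1.0, 0.0])
passages = [normalize([3.0, 4.0]), normalize([0.0, 1.0])]
print(contrastive_loss(query, passages, temperature=0.02))
print(contrastive_loss(query, passages, temperature=1.0))
```

With normalized embeddings the raw scores lie in [-1, 1], so a small temperature such as 0.02 stretches the gap between the positive and the negatives before the softmax.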


```bash
%%bash
torchrun --nproc_per_node 2 \
    -m FlagEmbedding.finetune.embedder.encoder_only.base \
    --model_name_or_path BAAI/bge-large-en-v1.5 \
    --cache_dir ./cache/model \
    --train_data ./ft_data/training.json \
    --cache_path ./cache/data \
    --train_group_size 8 \
    --query_max_len 512 \
    --passage_max_len 512 \
    --pad_to_multiple_of 8 \
    --query_instruction_for_retrieval 'Represent this sentence for searching relevant passages: ' \
    --query_instruction_format '{}{}' \
    --knowledge_distillation False \
    --output_dir ./test_encoder_only_base_bge-large-en-v1.5 \
    --overwrite_output_dir \
    --learning_rate 1e-5 \
    --fp16 \
    --num_train_epochs 2 \
    --per_device_train_batch_size 2 \
    --dataloader_drop_last True \
    --warmup_ratio 0.1 \
    --gradient_checkpointing \
    --deepspeed config/ds_stage0.json \
    --logging_steps 1 \
    --save_steps 1000 \
    --negatives_cross_device \
    --temperature 0.02 \
    --sentence_pooling_method cls \
    --normalize_embeddings True \
    --kd_loss_type kl_div
```
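To get a feel for what `--negatives_cross_device` changes in the command above, the sketch below counts how many passages each query is scored against in one step. It assumes in-batch negatives are used and that every query contributes exactly `train_group_size` passages. The helper name is hypothetical.

```python
def passages_per_query(
    per_device_train_batch_size: int,
    train_group_size: int,
    world_size: int,
    negatives_cross_device: bool,
) -> int:
    """Count the passages each query is compared against in one step."""
    # Passages held by a single device: one group per query in the local batch.
    local = per_device_train_batch_size * train_group_size
    # Sharing across devices multiplies the pool by the number of processes.
    return local * world_size if negatives_cross_device else local


# Values taken from the torchrun command above (2 processes on one node).
shared = passages_per_query(2, 8, world_size=2, negatives_cross_device=True)
local_only = passages_per_query(2, 8, world_size=2, negatives_cross_device=False)
print(shared, local_only)  # 32 16
```

Under these assumptions, one of those passages is the query's own positive and the rest act as negatives, so sharing across the two devices doubles the candidate pool without increasing the per-device batch size.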
