Summary: This tutorial walks through the implementation of the BGE (BAAI General Embedding) model, focusing on text embedding generation. It demonstrates two approaches: a detailed PyTorch-based implementation using the Transformers library and a simplified version using FlagEmbedding. The tutorial covers model initialization, `[CLS]` token pooling (emphasized as critical for performance), and L2 normalization of embeddings. Key functionality includes handling both single and batch text inputs, tokenization with a maximum length of 512, and embedding generation with a BERT-base architecture (12 layers, 768 hidden dimensions). This guide is particularly useful for text embedding generation and semantic search.

# BGE Explanation

In this section, we will go through BGE and BGE-v1.5's structure and how they generate embeddings.

## 0. Installation

Install the required packages in your environment.


```python
%%capture
%pip install -U transformers FlagEmbedding
```

## 1. Encode sentences

To see exactly how a sentence is encoded, let's first load the tokenizer and model from Hugging Face Transformers instead of FlagEmbedding.


```python
from transformers import AutoTokenizer, AutoModel
import torch

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")
model = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5")

sentences = ["embedding", "I love machine learning and nlp"]
```

Run the following cell to inspect the architecture of bge-base-en-v1.5. It uses BERT-base as its base model, with 12 encoder layers and a hidden dimension of 768. Calling `model.eval()` also puts the model in evaluation mode, which disables dropout.

Note that corresponding BGE and BGE-v1.5 models share the same architecture. For example, bge-base-en and bge-base-en-v1.5 have the same structure.


```python
model.eval()
```

First, let's tokenize the sentences.


```python
inputs = tokenizer(
    sentences, 
    padding=True, 
    truncation=True, 
    return_tensors='pt', 
    max_length=512
)
inputs
```

From the results, we can see that each sentence begins with token 101 and ends with token 102, which are the `[CLS]` and `[SEP]` special tokens used in BERT.
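
The shorter sentence in the batch is also padded so that both rows have the same length, and the `attention_mask` marks which positions are real tokens. The following is a rough, pure-Python sketch of that idea, not the tokenizer's actual implementation. The helper `pad_batch` is hypothetical, and apart from 101 (`[CLS]`), 102 (`[SEP]`), and 0 (`[PAD]`), the token ids are made up for illustration.

```python
from typing import Dict, List

PAD_ID = 0  # id of [PAD] in the BERT vocabulary


def pad_batch(batch: List[List[int]]) -> Dict[str, List[List[int]]]:
    """Right-pad every sequence to the longest one and build the attention mask."""
    longest = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = longest - len(ids)
        input_ids.append(ids + [PAD_ID] * n_pad)
        # 1 marks a real token, 0 marks padding
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# 101 = [CLS], 102 = [SEP]; the ids in between are made up for illustration
toy_batch = [[101, 7, 102], [101, 5, 6, 8, 9, 102]]
padded = pad_batch(toy_batch)
print(padded["input_ids"])
print(padded["attention_mask"])
```

The zeros in the attention mask are what mean pooling later uses to ignore padded positions.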


```python
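with torch.no_grad():  # inference only, no gradient tracking needed
    last_hidden_state = model(**inputs, return_dict=True).last_hidden_state
last_hidden_state.shape  # (batch_size, sequence_length, hidden_size)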
```

Here we implement the pooling function with two options: take the last hidden state of `[CLS]`, or take the mean of the last hidden states over all non-padding tokens.


```python
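def pooling(last_hidden_state: torch.Tensor, pooling_method='cls', attention_mask: torch.Tensor = None):
    if pooling_method == 'cls':
        # hidden state of the first token ([CLS]) of each sequence
        return last_hidden_state[:, 0]
    elif pooling_method == 'mean':
        # average over real tokens only; padded positions are masked out
        s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
        d = attention_mask.sum(dim=1, keepdim=True).float()
        return s / d
    else:
        raise ValueError(f"Unsupported pooling method: {pooling_method}")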

Unlike the more commonly used mean pooling, BGE is trained to use the last hidden state of `[CLS]` as the sentence embedding:

`sentence_embeddings = model_output[0][:, 0]`

If you use mean pooling, there will be a significant decrease in performance. Therefore, make sure to use the correct method to obtain sentence vectors.


```python
embeddings = pooling(
    last_hidden_state, 
    pooling_method='cls', 
    attention_mask=inputs['attention_mask']
)
embeddings.shape
```
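
To make the two pooling options concrete, here is a small sketch on a hand-written toy tensor instead of real model output. The values and the names `hidden` and `mask` are purely illustrative. It also computes a naive mean over all positions to show why the attention mask matters.

```python
import torch

# toy batch: 2 sequences, 3 token positions, hidden size 2
hidden = torch.tensor([
    [[1.0, 1.0], [3.0, 5.0], [100.0, 100.0]],  # last position is padding
    [[2.0, 0.0], [4.0, 2.0], [6.0, 4.0]],
])
mask = torch.tensor([[1, 1, 0], [1, 1, 1]])

# CLS pooling: hidden state of the first token of each sequence
cls_vec = hidden[:, 0]

# masked mean pooling: zero out padded positions, divide by the number of real tokens
weights = mask.unsqueeze(-1).float()
mean_vec = (hidden * weights).sum(dim=1) / weights.sum(dim=1)

# naive mean for comparison: the padded position leaks into the average
naive_mean = hidden.mean(dim=1)

print(cls_vec)
print(mean_vec)
print(naive_mean)
```

For the first sequence the naive mean is dominated by the padded position, while the masked mean only averages the two real tokens. The second sequence has no padding, so both means agree.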

Assembling them together, we get the whole encoding function:


```python
def _encode(sentences, max_length=512, convert_to_numpy=True):

    # handle the case of single sentence and a list of sentences
    input_was_string = False
    if isinstance(sentences, str):
        sentences = [sentences]
        input_was_string = True

    inputs = tokenizer(
        sentences, 
        padding=True, 
        truncation=True, 
        return_tensors='pt', 
        max_length=max_length
    )

    with torch.no_grad():  # inference only, no gradient tracking needed
        last_hidden_state = model(**inputs, return_dict=True).last_hidden_state
    
    embeddings = pooling(
        last_hidden_state, 
        pooling_method='cls', 
        attention_mask=inputs['attention_mask']
    )

    # normalize the embedding vectors
    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)

    # convert to numpy if needed
    if convert_to_numpy:
        embeddings = embeddings.detach().numpy()

    return embeddings[0] if input_was_string else embeddings
```

## 2. Comparison

Now let's run the function we wrote to get the embeddings of the two sentences:


```python
embeddings = _encode(sentences)
print(f"Embeddings:\n{embeddings}")

scores = embeddings @ embeddings.T
print(f"Similarity scores:\n{scores}")
```
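
Because the embeddings are L2-normalized, the matrix product above yields cosine similarities. The following sketch checks this on two arbitrary toy vectors, not real BGE embeddings. The variable names are illustrative.

```python
import torch
import torch.nn.functional as F

# two arbitrary, unnormalized toy vectors
x = torch.tensor([[3.0, 4.0], [1.0, 0.0]])

# L2 normalization: divide each row by its Euclidean norm
unit = F.normalize(x, dim=-1)
manual = x / x.norm(dim=-1, keepdim=True)

# for unit vectors, the dot product equals the cosine similarity
dot_scores = unit @ unit.T
cos_scores = F.cosine_similarity(x.unsqueeze(1), x.unsqueeze(0), dim=-1)

print(unit.norm(dim=-1))  # every row now has norm 1
print(dot_scores)
print(cos_scores)
```

This is why the tutorial can compute similarity scores with a plain `@` instead of an explicit cosine similarity call.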

Then, run the API provided in FlagEmbedding:


```python
from FlagEmbedding import FlagModel

# use a separate name so the Transformers `model` used by `_encode` is not overwritten
flag_model = FlagModel('BAAI/bge-base-en-v1.5')

embeddings = flag_model.encode(sentences)
print(f"Embeddings:\n{embeddings}")

scores = embeddings @ embeddings.T
print(f"Similarity scores:\n{scores}")
```

As expected, the two encoding functions return the same results, up to small floating-point differences (for example, FlagEmbedding may run the model in half precision on a GPU). The full implementation in FlagEmbedding handles large datasets by batching and also supports GPUs and parallelization. Feel free to check the [source code](https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/inference/embedder/encoder_only/base.py) for more details.
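
The batching idea can be sketched in a few lines. This is a simplified illustration under stated assumptions, not FlagEmbedding's actual implementation. The helpers `iter_batches` and `encode_in_batches` are hypothetical, and `dummy_encode` is a stand-in encoder so the sketch runs without loading a model.

```python
from typing import Callable, Iterator, List, Sequence


def iter_batches(items: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of at most `batch_size` items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def encode_in_batches(
    texts: Sequence[str],
    encode_fn: Callable[[Sequence[str]], Sequence],
    batch_size: int = 2,
) -> List:
    """Encode `texts` batch by batch and collect one embedding per text, in order."""
    all_embeddings: List = []
    for batch in iter_batches(texts, batch_size):
        all_embeddings.extend(encode_fn(batch))
    return all_embeddings


# stand-in encoder (hypothetical): one 1-dimensional "embedding" per text
def dummy_encode(batch: Sequence[str]) -> List[List[float]]:
    return [[float(len(text))] for text in batch]


corpus = ["a", "bb", "ccc", "dddd", "eeeee"]
batches = list(iter_batches(corpus, 2))
vectors = encode_in_batches(corpus, dummy_encode, batch_size=2)
print(batches)
print(vectors)
```

Processing a fixed number of sentences at a time keeps peak memory bounded, since padding is only applied up to the longest sentence in each batch. A function like `_encode` from section 1 could be passed as `encode_fn`, provided the global `model` it relies on still refers to the Transformers model.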
