## 📖 Documents

### 🌟 **Get Started**
#### Install Environment
```bash
conda create -n BabelRS python=3.9
conda activate BabelRS
pip install -r requirements.txt
```
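
To confirm that the environment resolved as expected, the short script below prints the interpreter version and the installed versions of a few packages. It is an illustrative sketch, not a repository script. Checking `torch` and `transformers` is an assumption, because the contents of `requirements.txt` are not listed here.

```python
import sys
from importlib import metadata


def report_versions(packages=("torch", "transformers")):
    """Map each distribution name to its installed version, or None if it is missing.

    The default package list is an assumption, not read from requirements.txt.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in the active environment
    return versions


# The conda environment above is created with python=3.9.
print(f"Python {sys.version_info.major}.{sys.version_info.minor}")
for pkg, ver in report_versions().items():
    print(f"{pkg}: {ver or 'not installed'}")
```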

Install `flash-attn==2.3.6` (optional, for training chat models):

```bash
pip install flash-attn==2.3.6 --no-build-isolation
```

Alternatively, you can compile it from source:
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
git checkout v2.3.6
python setup.py install
```
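
Because `flash-attn` is optional, code that passes `use_flash_attn=True` (as the Quick Start loading examples do) may want to check first whether the package can be imported. The sketch below uses only the standard library; the helper name `flash_attn_available` is hypothetical and not part of this repository.

```python
import importlib.util


def flash_attn_available() -> bool:
    """Return True if the optional flash_attn package can be located.

    find_spec only locates the module without importing it, so a broken
    CUDA build would still surface only on the first real import.
    """
    return importlib.util.find_spec("flash_attn") is not None


# Hypothetical usage: feed this flag to use_flash_attn=... when loading the model.
USE_FLASH_ATTN = flash_attn_available()
print(f"use_flash_attn={USE_FLASH_ATTN}")
```

The resulting flag could then be passed as `use_flash_attn=USE_FLASH_ATTN` to `AutoModel.from_pretrained`.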
#### Prepare Pretraining Datasets and Checkpoint
Download the following pretraining datasets:
1. ViTP dataset <a href="https://www.modelscope.cn/datasets/GreatBird/ViTP/files"><img src="https://img.shields.io/badge/ModelScope-Data-624aff"></a> or <a href="https://huggingface.co/GreatBird/ViTP"><img src="https://img.shields.io/badge/HuggingFace-Data-ffd21e?logo=huggingface"></a>. Note that the `.txt` files under `pretrain_data/images` contain the dataset download URLs.
2. MMRS-1M dataset at this [link](https://github.com/wivizhang/EarthGPT).
3. SARLang dataset at this [link](https://github.com/Jimmyxichen/SARLANG-1M).
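
Since the ViTP `.txt` files hold URLs, a small script can gather them into one de-duplicated list before downloading. This sketch assumes one URL per line, with blank lines and `#` comments ignored. That file layout and the helper names `parse_url_lines` and `collect_urls` are assumptions, not documented behavior of the dataset.

```python
from pathlib import Path


def parse_url_lines(text):
    """Return the URLs in one file's text (assumed: one per line; blanks and '#' comments skipped)."""
    urls = []
    for line in text.splitlines():
        url = line.strip()
        if url and not url.startswith("#"):
            urls.append(url)
    return urls


def collect_urls(txt_dir):
    """Gather unique URLs, in first-seen order, from every .txt file directly under txt_dir."""
    unique = {}
    for txt_file in sorted(Path(txt_dir).glob("*.txt")):
        for url in parse_url_lines(txt_file.read_text(encoding="utf-8")):
            unique.setdefault(url, None)  # dicts keep insertion order
    return list(unique)
```

The resulting list could be written to a file (the name `urls.txt` is hypothetical) and fetched with a standard tool such as `wget -i urls.txt`.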

The pretraining checkpoint will be released after the paper is accepted.

#### Start Pretraining
Place the pretraining datasets and annotations at the paths specified in `babelrs_configs/BabelRS_ft_data_instruct.json`, then run:
```bash
GPUS=8 PER_DEVICE_BATCH_SIZE=1 sh babelrs_configs/BabelRS_internvl_1b_20ksteps.sh
```
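
Before launching, it can help to confirm that every path referenced in the data config exists. The sketch below assumes the file follows InternVL's meta-file layout: a JSON object mapping each dataset name to an entry with `root` (image folder) and `annotation` (JSONL file) keys. That layout is an assumption for BabelRS, so check it against `babelrs_configs/BabelRS_ft_data_instruct.json`. The helpers `missing_paths` and `check_meta_file` are hypothetical.

```python
import json
from pathlib import Path


def missing_paths(meta):
    """Return (dataset, key, path) for each 'root'/'annotation' entry that is absent or missing on disk.

    `meta` is the parsed JSON, assumed to look like
    {dataset_name: {"root": <image dir>, "annotation": <jsonl file>, ...}}.
    Relative paths are resolved against the current working directory.
    """
    problems = []
    for name, entry in meta.items():
        for key in ("root", "annotation"):
            path = entry.get(key)
            if path is None or not Path(path).exists():
                problems.append((name, key, path))
    return problems


def check_meta_file(config_path):
    """Load a meta JSON file and report its missing paths."""
    with open(config_path, encoding="utf-8") as f:
        return missing_paths(json.load(f))
```

For example, `check_meta_file("babelrs_configs/BabelRS_ft_data_instruct.json")`, run from the repository root, should return an empty list once everything is in place, provided the paths in the file are relative to that root.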

**For more details, please refer to the [official documentation](https://github.com/OpenGVLab/InternVL).**

## Quick Start

We provide example code for running `InternVL2_5-8B` with `transformers`.

> Please use `transformers>=4.37.2` to ensure the model works correctly.
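
The version floor can be checked programmatically before loading the model. The sketch below compares only the numeric release part of a version string using the standard library; `packaging.version.Version` is the more complete option if pre-release ordering matters. The helper names `release_tuple` and `meets_minimum` are hypothetical.

```python
import re

MIN_TRANSFORMERS = "4.37.2"


def release_tuple(version):
    """Extract the leading numeric release, e.g. '4.38.0.dev0' -> (4, 38, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def meets_minimum(installed, required=MIN_TRANSFORMERS):
    """True if `installed` >= `required`, comparing release numbers only.

    Pre-release tags (dev/rc) are ignored, so '4.37.2.dev0' counts as 4.37.2.
    """
    a, b = release_tuple(installed), release_tuple(required)
    width = max(len(a), len(b))
    # Pad with zeros so that '4.40' compares like '4.40.0'.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b
```

For example, `meets_minimum(importlib.metadata.version("transformers"))` should return `True` in a correctly configured environment.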

### Model Loading

#### 16-bit (bf16 / fp16)

```python
import torch
from transformers import AutoTokenizer, AutoModel
path = "OpenGVLab/InternVL2_5-8B"
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True).eval().cuda()
```
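
The heading lists both bf16 and fp16, while the example uses `torch.bfloat16`. bfloat16 is natively supported on NVIDIA GPUs with compute capability 8.0 or newer (Ampere onward); on older cards, `torch.float16` is the usual substitute. The helper below is a hypothetical sketch of that choice, keyed on the `(major, minor)` tuple that `torch.cuda.get_device_capability()` returns.

```python
def pick_dtype_name(capability):
    """Return "bfloat16" for compute capability >= 8.0, else "float16".

    `capability` is a (major, minor) pair, as returned by
    torch.cuda.get_device_capability(); the 8.0 cutoff is the Ampere generation.
    """
    return "bfloat16" if tuple(capability) >= (8, 0) else "float16"
```

In use, this might look like `torch_dtype=getattr(torch, pick_dtype_name(torch.cuda.get_device_capability()))` in the `from_pretrained` call above.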

#### BNB 8-bit Quantization

```python
import torch
from transformers import AutoTokenizer, AutoModel
path = "OpenGVLab/InternVL2_5-8B"
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    load_in_8bit=True,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True).eval()
```
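
Switching from bf16 to 8-bit weights roughly halves the memory the weights occupy: bf16 stores 2 bytes per parameter, int8 stores 1. The back-of-the-envelope helper below illustrates that arithmetic for the weights only. It ignores activations, the KV cache, quantization overhead, and any layers left unquantized, and the `8e9` parameter count is a round figure for an "8B" model, not the exact count of `InternVL2_5-8B`.

```python
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2, "int8": 1}


def weight_memory_gib(num_params, dtype):
    """Rough size of the model weights alone, in GiB (2**30 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


# Illustrative round figure for an "8B" model (an assumption, not the exact parameter count).
for dtype in ("bfloat16", "int8"):
    print(f"{dtype}: ~{weight_memory_gib(8e9, dtype):.1f} GiB")
# bfloat16: ~14.9 GiB
# int8: ~7.5 GiB
```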

#### Multiple GPUs

The code below builds an explicit `device_map` to avoid errors during multi-GPU inference caused by tensors ending up on different devices. Keeping the first and last layers of the large language model (LLM) on the same device (GPU 0, which also hosts the vision encoder) prevents such errors.

```python
import math
import torch
from transformers import AutoTokenizer, AutoModel

def split_model(model_name):
    device_map = {}
    world_size = torch.cuda.device_count()
    num_layers = {
        'InternVL2_5-1B': 24, 'InternVL2_5-2B': 24, 'InternVL2_5-4B': 36, 'InternVL2_5-8B': 32,
        'InternVL2_5-26B': 48, 'InternVL2_5-38B': 64, 'InternVL2_5-78B': 80}[model_name]
    # Since the first GPU will be used for ViT, treat it as half a GPU.
    num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
    num_layers_per_gpu = [num_layers_per_gpu] * world_size
    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
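    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        for _ in range(num_layer):
            # Rounding up can allot more slots than there are layers; skip the extras.
            if layer_cnt >= num_layers:
                break
            device_map[f'language_model.model.layers.{layer_cnt}'] = i
            layer_cnt += 1
    # The vision encoder, the vision-to-language projector (mlp1), the embeddings,
    # the final norm, and the output head all stay on GPU 0. Both InternLM2-style
    # (tok_embeddings / output) and Qwen2/LLaMA-style (embed_tokens / lm_head)
    # module names are listed so the same map works across LLM backbones.
    device_map['vision_model'] = 0
    device_map['mlp1'] = 0
    device_map['language_model.model.tok_embeddings'] = 0
    device_map['language_model.model.embed_tokens'] = 0
    device_map['language_model.output'] = 0
    device_map['language_model.model.norm'] = 0
    device_map['language_model.lm_head'] = 0
    # Move the last decoder layer to GPU 0 so it shares a device with the first layer and the output head.
    device_map[f'language_model.model.layers.{num_layers - 1}'] = 0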

    return device_map

path = "OpenGVLab/InternVL2_5-8B"
device_map = split_model('InternVL2_5-8B')
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True,
    device_map=device_map).eval()
```
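
Because `split_model` only does integer arithmetic before the model is loaded, its layer placement can be previewed on a machine without GPUs. The function below is a hypothetical, torch-free restatement of that arithmetic (GPU 0 counted as half a GPU, surplus slots dropped, last layer moved back to GPU 0); it returns how many decoder layers each GPU would hold.

```python
import math
from collections import Counter


def layer_counts(num_layers, world_size):
    """Return how many LLM decoder layers each GPU receives under split_model's scheme.

    GPU 0 is treated as half a GPU (it also hosts the ViT); the last layer is then
    moved to GPU 0 so the LLM's first and last layers share a device.
    """
    per_gpu = math.ceil(num_layers / (world_size - 0.5))
    quotas = [per_gpu] * world_size
    quotas[0] = math.ceil(per_gpu * 0.5)
    owner = {}
    layer = 0
    for gpu, quota in enumerate(quotas):
        for _ in range(quota):
            if layer >= num_layers:  # rounding up can leave spare slots
                break
            owner[layer] = gpu
            layer += 1
    owner[num_layers - 1] = 0
    counts = Counter(owner.values())
    return [counts.get(gpu, 0) for gpu in range(world_size)]


# InternVL2_5-8B has 32 LLM layers, per the table in split_model.
print(layer_counts(32, 4))  # [6, 10, 10, 6]
```

With 8 GPUs and 32 layers, for instance, the same scheme gives `[4, 5, 5, 5, 5, 5, 3, 0]`, leaving the last GPU without any decoder layers, which may be worth knowing when choosing how many devices to expose.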
