# Continual LLaVA: Continual Instruction Tuning in Large Vision-Language Models


## Declaration

This repository contains the source code of Continual LLaVA. In line with ICLR 2025 policy, we avoid explicit external links throughout the repository to prevent identity leakage.

Note that the code has not yet been cleaned up or refactored. The current version therefore requires manual modification of some hard-coded parts, as specified in the following sections. Nevertheless, we ensure reproducibility.


## Contents
- [Install and Preparation](#install)
- [Dataset](Data.md)
- [Train](#train)
- [Inference](#inference)
- [Evaluation](#evaluation)

## Install

1. Navigate to the `Continual_LLaVA` folder
```bash
cd Continual_LLaVA
```

2. Install the package
```Shell
conda create -n llava_continual python=3.10 -y
conda activate llava_continual
pip install --upgrade pip  # enable PEP 660 support
pip install -e .
```

3. Install additional packages for training
```Shell
pip install -e ".[train]"
pip install flash-attn --no-build-isolation
```

### Upgrade to the latest code base

```Shell
git pull
pip install -e .

# if you see some import errors when you upgrade,
# please try running the command below (without #)
# pip install flash-attn --no-build-isolation --no-cache-dir
```

### Download the CLIP model and initial checkpoint

Download the `clip-vit-large-patch14-336` and `llava-v1.5-7b` models from Hugging Face and place them under the `checkpoints` folder.
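
Before training, it can help to confirm that both models are where the code expects them. The following is a minimal sketch, not part of the repository: `check_checkpoints` is a hypothetical helper, and the sub-folder names are an assumption based on the model names above, so adjust them if your download uses different names.

```shell
# Hypothetical helper (not part of the repository): report whether the
# expected model folders exist under the checkpoints directory.
# Assumption: each model lives in a sub-folder named after the model.
check_checkpoints() {
    local root="${1:-checkpoints}"
    local missing=0
    local name
    for name in clip-vit-large-patch14-336 llava-v1.5-7b; do
        if [ -d "$root/$name" ]; then
            echo "found: $root/$name"
        else
            echo "missing: $root/$name"
            missing=1
        fi
    done
    # Non-zero exit status when at least one folder is missing.
    return "$missing"
}
```

Running `check_checkpoints` from the repository root inspects `checkpoints`; pass another directory as the first argument to check a different location.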

## Dataset
Please refer to [Data.md](Data.md).

## Train
### Stage1: Surrogate-Proxy Embeddings Alignment

- Modify `self.task_pool_index_range` in `llava/model/language_model/llava_llama.py` to match the desired task.

- Run the following:

```Shell
bash scripts/v1.5/pretrain_alignment.sh

python llava/train/merge_pretrained_key.py
```

### Stage2: Dual Increment Embedding Tuning

- Modify `self.task_pool_index_range` in `llava/model/language_model/llava_llama.py` to match the desired task.

- Set `self.retriever_state_dict` in `llava/model/language_model/llava_llama.py` to the output path of the surrogate-proxy embeddings alignment from Stage 1.

- Set `self.disable_task_id = False` in `llava/model/language_model/llava_llama.py`.
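
Because these settings are hard-coded, it is easy to launch a run with a stale value. The sketch below prints the current assignments with their line numbers so they can be reviewed before training. `show_hardcoded_settings` is a hypothetical helper, not part of the repository, and it assumes each attribute is set with a plain `self.<name> = ...` assignment.

```shell
# Hypothetical helper (not part of the repository): list the hard-coded
# settings in llava_llama.py together with their line numbers.
# Assumption: each attribute is assigned as `self.<name> = ...`.
show_hardcoded_settings() {
    local file="${1:-llava/model/language_model/llava_llama.py}"
    grep -n -E 'self\.(task_pool_index_range|retriever_state_dict|disable_task_id)[[:space:]]*=' "$file"
}
```

Calling `show_hardcoded_settings` from the repository root inspects the default file; the same check is useful before inference, where `self.disable_task_id` has to be flipped.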

Make sure you have at least 4 A100 GPUs, then run:

- For COAST-domain
```Shell
bash scripts/v1.5/finetune_continual_llava_domain.sh
```
- For COAST-capability
```Shell
bash scripts/v1.5/finetune_continual_llava_capability.sh
```
- For COAST-dataset
```Shell
bash scripts/v1.5/finetune_continual_llava_dataset.sh
```
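
Before launching any of the scripts above, the GPU requirement can be checked from the shell. This is a minimal sketch under stated assumptions: `require_gpus` is a hypothetical helper, it relies on `nvidia-smi -L` printing one `GPU ...` line per device, and it only counts devices without verifying that they are A100s.

```shell
# Hypothetical helper (not part of the repository): succeed only if at
# least MIN GPUs are visible (default 4). Counts devices only; it does
# not check the GPU model.
require_gpus() {
    local min="${1:-4}"
    local count=0
    if command -v nvidia-smi >/dev/null 2>&1; then
        # `nvidia-smi -L` prints one line starting with "GPU " per device.
        count="$(nvidia-smi -L | grep -c '^GPU ')"
    fi
    echo "visible GPUs: $count (need at least $min)"
    [ "$count" -ge "$min" ]
}
```

One possible use is `require_gpus 4 && bash scripts/v1.5/finetune_continual_llava_domain.sh`, which starts training only when the check passes.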

## Inference

- Set `self.disable_task_id = True` in `llava/model/language_model/llava_llama.py`

- For COAST-domain: set the output model checkpoint path in each `.sh` file (e.g. `scripts/v1.5/eval/chartqa.sh`), then run:
```Shell
bash scripts/v1.5/eval/continual_llava_eval_domain.sh
```

- For COAST-capability: set the output model checkpoint path in each `.sh` file (e.g. `scripts/v1.5/eval/conversation.sh`), then run:
```Shell
bash scripts/v1.5/eval/continual_llava_eval_capability.sh
```

- For COAST-dataset: set the output model checkpoint path in each `.sh` file (e.g. `scripts/v1.5/eval/VQAv2.sh`), then run:
```Shell
bash scripts/v1.5/eval/continual_coin_eval.sh

bash scripts/v1.5/eval/continual_llava_eval_dataset.sh
```

## Evaluation
- Use `scripts/convert_answer_for_gpt_eval.py` to convert the model answers into the format used for GPT evaluation.
- Then run `python3 llava/eval/eval_qa_gpt.py` to obtain the GPT evaluation score and accuracy.

## 📜 License

This project is released under the Apache 2.0 license.
