# Retriever Training

This repository contains the code for training custom ColBERT retriever models.
Notably, we train ColBERT-style retrievers on top of LLMs (decoders) as well as image-language models!

## Installation

```bash
pip install -r requirements.txt
```

## Training

```bash
accelerate launch scripts/train/train_colbert.py scripts/configs/train_colidefics_model.yaml 
```

### Configurations
All training arguments can be set through a YAML configuration file.
The file describes a `ColModelTrainingConfig` object and the arguments used to build each of its fields.

The dataclass is defined as follows:

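```python
from dataclasses import dataclass
from typing import Callable, Optional

from peft import LoraConfig
from transformers import Idefics2Processor, PreTrainedModel, PreTrainedTokenizer, TrainingArguments

from custom_colbert.loss.colbert_loss import ColbertLoss


@dataclass
class ColModelTrainingConfig:
    model: PreTrainedModel
    tr_args: TrainingArguments = None
    output_dir: str = None
    max_length: int = 256
    run_eval: bool = True
    run_train: bool = True
    peft_config: Optional[LoraConfig] = None
    add_suffix: bool = False
    processor: Idefics2Processor = None
    tokenizer: PreTrainedTokenizer = None
    loss_func: Optional[Callable] = ColbertLoss()
    dataset_loading_func: Optional[Callable] = None
```

### Example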

An example configuration file is:

```yaml
config:
  (): custom_colbert.utils.train_custom_colbert_models.ColModelTrainingConfig
  processor:
    (): custom_colbert.utils.wrapper.AutoProcessorWrapper
    pretrained_model_name_or_path: "HuggingFaceM4/idefics2-8b"
    do_image_splitting: false
  model:
    (): custom_colbert.utils.wrapper.AutoColModelWrapper
    pretrained_model_name_or_path: "HuggingFaceM4/idefics2-8b-chatty"
    training_objective: "colbertv1"
    # attn_implementation: "flash_attention_2"
    torch_dtype:  !ext torch.bfloat16
    quantization_config:
      (): transformers.BitsAndBytesConfig
      load_in_4bit: true
      bnb_4bit_quant_type: "nf4"
      bnb_4bit_compute_dtype:  "float16"
      bnb_4bit_use_double_quant: true

  dataset_loading_func: !ext custom_colbert.utils.dataset_transformation.load_docvqa_dataset
  max_length: 256
  run_eval: true
  loss_func:
    (): custom_colbert.loss.colbert_loss.ColbertLoss
  tr_args:
    (): transformers.training_args.TrainingArguments
    output_dir: null
    overwrite_output_dir: true
    num_train_epochs: 3
    per_device_train_batch_size: 4
    gradient_accumulation_steps: 8
    per_device_eval_batch_size: 4
    eval_strategy: "steps"
    dataloader_num_workers: 8
    # bf16: true
    save_steps: 500
    logging_steps: 10
    eval_steps: 50
    warmup_steps: 100
    learning_rate: 5e-5
    save_total_limit: 1

  peft_config:
    (): peft.LoraConfig
    r: 8
    lora_alpha: 8
    lora_dropout: 0.1
    init_lora_weights: "gaussian"
    bias: "none"
    task_type: "FEATURE_EXTRACTION"
    target_modules: '.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'
```
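In the example above, each `()` key appears to name a class or callable that is instantiated with its sibling keys as keyword arguments, while `!ext` seems to import an object without calling it. The sketch below illustrates the `()` convention using only the standard library. It is an assumption about how such a file is interpreted, not the loader this repository uses: `resolve` and `build` are hypothetical helpers, and stdlib classes stand in for the real model and trainer classes.

```python
import importlib
from typing import Any


def resolve(path: str) -> Any:
    """Import a dotted path such as 'datetime.timedelta' and return the object."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def build(node: Any) -> Any:
    """Recursively instantiate every mapping that carries a '()' key."""
    if isinstance(node, dict):
        # Build children first so nested objects are ready to pass as kwargs.
        kwargs = {key: build(value) for key, value in node.items() if key != "()"}
        if "()" in node:
            return resolve(node["()"])(**kwargs)
        return kwargs
    if isinstance(node, list):
        return [build(item) for item in node]
    return node


# Stand-in for a parsed YAML file (hypothetical values, stdlib classes only).
parsed = {
    "config": {
        "()": "types.SimpleNamespace",
        "max_length": 256,
        "run_eval": True,
        "timeout": {"()": "datetime.timedelta", "hours": 20},
    }
}

config = build(parsed)["config"]
print(config.max_length, config.timeout)
```

Building children before their parent is what would let a nested entry such as `quantization_config` be passed to the model wrapper as an already constructed object.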

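#### SLURM

On a SLURM cluster, training can be submitted with `sbatch`. The partition, account, and constraint values below are cluster-specific and will need adjusting: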

```bash
sbatch --nodes=1 --cpus-per-task=16 --mem-per-cpu=32GB --time=20:00:00 --gres=gpu:1  -p gpua100 --job-name=colidefics --output=colidefics.out --error=colidefics.err --wrap="accelerate launch scripts/train/train_colbert.py  scripts/configs/train_colidefics_model.yaml"

sbatch --nodes=1  --time=5:00:00 -A cad15443 --gres=gpu:8  --constraint=MI250 --job-name=colpali --wrap="python scripts/train/train_colbert.py scripts/configs/train_colpali_model.yaml"
```

## Appendix

### Get the raw datasets
Assuming you have the credentials for the private GCS bucket `gs://retriever-training`, run the following commands to download the raw datasets:

```bash
pip install "dvc[gs]"
gcloud auth application-default login
dvc pull
```

After downloading, the datasets should appear in the `data` directory.
