<h1 align="center">ElicitR: Unlocking Latent Reasoning in Dense Retrievers via Generative Regularization</h1>

<h4 align="center">
    <p>
        <a href="#installation">🔧 Installation</a> |
        <a href="#resources">📚 Resources</a> |
        <a href="#training">🚀 Training</a> |
        <a href="#eval">📊 Evaluation</a>
    </p>
</h4>

<h2 id="installation">Installation</h2>

To begin, set up the conda environment using the following command:

```
conda env create -f environment.yml
```

In <code>ElicitR</code>, we modify the `transformers` architecture to incorporate **in-batch** attention. To enable this, install a customized version of the `transformers` library. Specifically, place `modelling_llama.py` from this repo in the corresponding path of the `transformers` package. This file implements the in-batch attention mechanism. It may still contain comments from the original `transformers` (4.45.2) release: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py.

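To find where the patched file would go, one option is to ask Python where the installed package keeps its LLaMA modeling module. The sketch below assumes that "the corresponding path" means the `models/llama/` directory of the installed `transformers` package; `locate_module_file` is a hypothetical helper, not part of this repo. Note that the upstream module is named `modeling_llama.py`, so check the target filename when copying.

```python
import importlib.util
from pathlib import Path


def locate_module_file(module_name: str) -> Path:
    """Return the source file of an importable module (hypothetical helper)."""
    # find_spec imports parent packages but not the module itself.
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        raise ModuleNotFoundError(f"cannot locate {module_name!r}")
    return Path(spec.origin)


# Example, assuming transformers is installed in the active environment:
#   print(locate_module_file("transformers.models.llama.modeling_llama"))
```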
Finally, we train the model in a modular setup. To install the local package in editable mode, run:

```
cd src/tevatron
pip install -e .
```

<h2 id="resources">Resources</h2>

### Data

The datasets used to initialize generative regularization are listed below:

| Dataset                    | Source                                                                                                                                                              | Number of Batches | Batch Size |
|----------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------|------------|
| [Training Corpus with sampled examples](data/tevatron_wiki_chunk_sent_1000.jsonl)     | [Wikipedia](https://huggingface.co/datasets/Tevatron/wikipedia-nq-corpus)                                                                                           | 320,000           | 16         |
| [Code Training Corpus with sampled examples](data/merged_stackoverflow_chunk_sent_sampled_lib_tuto_1000.jsonl) | [Stackoverflow Posts](https://huggingface.co/datasets/code-rag-bench/stackoverflow-posts), [Online Tutorials](https://huggingface.co/datasets/code-rag-bench/online-tutorials), [Library Documentation](https://huggingface.co/datasets/code-rag-bench/library-documentation) | 358,763           | 16         |

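The sampled files linked in the table are JSON Lines, so a quick way to see what a record looks like is to load the first few lines. This is a minimal sketch: `peek_jsonl` is a hypothetical helper, and it makes no assumption about field names because the record schema is not described in this README.

```python
import json


def peek_jsonl(path, n=3):
    """Return the first n non-empty records of a JSON Lines file."""
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            records.append(json.loads(line))
            if len(records) >= n:
                break
    return records


# Example, run from the repo root:
#   for rec in peek_jsonl("data/tevatron_wiki_chunk_sent_1000.jsonl"):
#       print(rec)
```

Printing whole records first, rather than specific keys, avoids guessing at the schema.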

<h2 id="training">Training</h2>
ElicitR has two training stages: initialization of the generative regularizer, followed by contrastive learning combined with generative regularization.

The initialization of generative regularization can be found in `train_init.sh`.

```
export CUDA_VISIBLE_DEVICES=0,1,2,3
export TRITON_PRINT_AUTOTUNING=1

export ROOT_DIR=./
export INIT_OUTPUT_DIR=...
export INIT_RUN_NAME=...
export DATA_PATH=...  # file passed to --bm25_retrieval_file below

deepspeed --include localhost:0,1,2,3 --master_port 6022 --module tevatron.llm_retriever.driver.train \
  --deepspeed $ROOT_DIR/deepspeed/ds_zero3_config.json \
  --output_dir $INIT_OUTPUT_DIR \
  --model_name_or_path meta-llama/Llama-3.2-1B \
  --reference_model_name_or_path HuggingFaceTB/SmolLM2-135M \
  --lora \
  --lora_r 256 \
  --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \
  --save_steps 500 \
  --bm25_retrieval_file $DATA_PATH \
  --add_passage_prefix True \
  --add_query_prefix True \
  --first_half True \
  --bf16 \
  --pooling eos \
  --append_eos_token \
  --normalize \
  --temperature 0.01 \
  --attn_temperature 0.0001 \
  --per_device_train_batch_size 1 \
  --train_group_size 16 \
  --learning_rate 1e-4 \
  --passage_max_len 157 \
  --num_train_epochs 1 \
  --gradient_accumulation_steps 8 \
  --logging_steps 1 \
  --overwrite_output_dir \
  --warmup_steps 100 \
  --resume latest \
  --top_k 16 \
  --run_name $INIT_RUN_NAME
```

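As a rough check on the batch settings, the amount of data consumed per optimizer step can be worked out from the flags. This sketch assumes plain data parallelism across all listed GPUs and that each per-device sample is one group of `--train_group_size` passages (the usual Tevatron convention); `samples_per_optimizer_step` is a hypothetical helper. The second call uses the values from `train_elicitr.sh`.

```python
def samples_per_optimizer_step(num_gpus, per_device_batch, grad_accum, group_size):
    """Return (groups, passages) consumed per optimizer step."""
    groups = num_gpus * per_device_batch * grad_accum
    return groups, groups * group_size


# train_init.sh: 4 GPUs, per-device batch 1, accumulation 8, group size 16
print(samples_per_optimizer_step(4, 1, 8, 16))  # (32, 512)

# train_elicitr.sh: 8 GPUs, per-device batch 4, accumulation 4, group size 16
print(samples_per_optimizer_step(8, 4, 4, 16))  # (128, 2048)
```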
The script for contrastive learning with generative regularization is in `train_elicitr.sh`.

```
#!/bin/bash
source activate $env

export TRITON_CACHE_DIR=...
export TRITON_HOME=...
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export TRITON_PRINT_AUTOTUNING=1

export ROOT_DIR=./
export OUTPUT_DIR=...
export INIT_OUTPUT_DIR=...
export RUN_NAME=...

deepspeed --include localhost:0,1,2,3,4,5,6,7 --master_port 60010 \
--module tevatron.co_retriever.driver.train \
  --deepspeed $ROOT_DIR/deepspeed/ds_zero3_config.json \
  --output_dir $ROOT_DIR/$OUTPUT_DIR \
  --model_name_or_path meta-llama/Llama-3.2-1B \
  --reference_model_name_or_path HuggingFaceTB/SmolLM2-135M \
  --retriever_lora_name_or_path $ROOT_DIR/$INIT_OUTPUT_DIR/encoder \
  --reference_lora_name_or_path $ROOT_DIR/$INIT_OUTPUT_DIR/reference \
  --disable_v_norm True \
  --lora \
  --lora_r 256 \
  --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \
  --save_steps 1000 \
  --dataset_name Tevatron/msmarco-passage-aug \
  --top_k 8 \
  --query_prefix "query: " \
  --passage_prefix "passage: " \
  --bf16 \
  --pooling eos \
  --append_eos_token \
  --normalize \
  --temperature 0.01 \
  --attn_temperature 0.1 \
  --contrastive_loss_weight 0.5 \
  --per_device_train_batch_size 4 \
  --gradient_accumulation_steps 4 \
  --gradient_checkpointing \
  --train_group_size 16 \
  --learning_rate 5e-5 \
  --query_max_len 32 \
  --passage_max_len 196 \
  --num_train_epochs 1 \
  --warmup_steps 100 \
  --logging_steps 1 \
  --overwrite_output_dir \
  --run_name $RUN_NAME
```

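The `--normalize` and `--temperature 0.01` flags suggest a standard InfoNCE-style contrastive term over cosine similarities. The sketch below illustrates that generic objective in plain Python for a single query; it is an assumption, not the code in `tevatron.co_retriever`, and it does not cover the generative regularization term or how `--contrastive_loss_weight` combines the two. `l2_normalize` and `info_nce` are hypothetical names.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in vec]


def info_nce(query, passages, positive_index=0, temperature=0.01):
    """Cross-entropy of the positive passage under temperature-scaled cosine scores."""
    q = l2_normalize(query)
    logits = [
        sum(a * b for a, b in zip(q, l2_normalize(p))) / temperature
        for p in passages
    ]
    # Numerically stable log-sum-exp over all candidates.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(s - m) for s in logits))
    return log_z - logits[positive_index]


# One query with a matching positive and an orthogonal negative.
loss = info_nce([1.0, 0.0], [[2.0, 0.0], [0.0, 3.0]])
print(loss)
```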
<h2 id="eval">Evaluation</h2>

In our work, we evaluate our models mainly on two benchmarks, BRIGHT and BEIR.

For BRIGHT, we modify the public implementation from the original paper (commit hash: `d99e8391d967d4c2b3a74732530d2309e2fc92b6`). Please replace the original `retriever.py` with the `retriever.py` in this repo. The model type to evaluate is `repllamanoinstruct`.

For BEIR, we evaluate the trained models with a customized `mteb`.


```
import mteb
import torch
from mteb.model_meta import ModelMeta
from mteb.models.repllama_models import RepLLaMAWrapper, _loader

PEFT_MODEL=...

elicitr_llama_1b = ModelMeta(
    loader=_loader(
        RepLLaMAWrapper,
        base_model_name_or_path="meta-llama/Llama-3.2-1B",
        peft_model_name_or_path=PEFT_MODEL,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="ElicitR-1b",
    languages=["eng_Latn"],
    open_source=True,
    revision="",  # base-peft revision
    release_date="2024-09-15",
)
elicitr_llama_1b_model = elicitr_llama_1b.loader()

evaluation = mteb.MTEB(tasks=["SciFact", "NFCorpus"])
evaluation.run(model=elicitr_llama_1b_model, output_folder="results/ElicitR-1b")
```


<h2 id="citing">Citing</h2>
