# Retrieval-Augmented Thought Process as Sequential Decision Making

This repository contains the source code for the paper: "Retrieval Augmented Thought Process for Private Data Handling in Healthcare".
Submitted to ICLR 2025.


## Installation

Instructions for setting up the environment to run the code.

```bash
cd RATP
pip install -r requirements.txt
```

Instructions for obtaining the emrQA datasets and the associated knowledge base:

1. Request access to the n2c2 datasets at https://portal.dbmi.hms.harvard.edu/projects/n2c2-nlp/.
2. The questions are included in the emrQA datasets. We only used the query2 file.
3. The knowledge base consists of the "Training Data: Track 2 Training Data v2" set.

Instructions for obtaining the Wikipedia knowledge base:

1. Download a Wikipedia dump by following these instructions: https://en.wikipedia.org/wiki/Wikipedia:Database_download.

Instructions for obtaining the EHRQA datasets and the associated knowledge base:

1. Request access to MIMIC-IV at https://physionet.org/content/mimic-iv-demo/2.2/
2. Download the dataset from https://physionet.org/content/drugehrqa/1.0.0/

Instructions for processing the knowledge base:

1. Run `data/splitting_run_parallel.py`. If necessary, change the `INPUT_DIR` and `OUTPUT_DIR` variables.
2. Run `data.id.py`. If necessary, change the `INPUT_DIR` variable.
3. Run `embeding_run_parallel.py`. If necessary, change the `INPUT_DIR` and `OUTPUT_DIR` variables.
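Assuming each step consumes the output of the previous one, the three steps have to run in the order listed. Below is a minimal shell sketch that chains them and stops at the first failing step. It assumes it is sourced from the repository root with the script paths exactly as written above; `KB_STEPS`, `process_kb` and the `DRY_RUN` switch are hypothetical names introduced for illustration and are not part of the repository.

```bash
#!/usr/bin/env bash
# Hypothetical helper, not part of the repository: runs the knowledge base
# processing steps in the listed order and stops at the first failing step.
# Edit INPUT_DIR/OUTPUT_DIR inside each script beforehand, as described above.

# Script paths copied verbatim from the list above (assumed relative to the repo root).
KB_STEPS=(
  data/splitting_run_parallel.py
  data.id.py
  embeding_run_parallel.py
)

process_kb() {
  local step
  for step in "${KB_STEPS[@]}"; do
    if [ "${DRY_RUN:-0}" = "1" ]; then
      # Dry run: print the command instead of executing it.
      echo "python $step"
    else
      echo "==> python $step" >&2
      # Abort the remaining steps if this one fails.
      python "$step" || { echo "step failed: $step" >&2; return 1; }
    fi
  done
}
```

After sourcing the file, call `process_kb`, or `DRY_RUN=1 process_kb` to only print the three commands. The sketch deliberately avoids `set -e`, so sourcing it does not change the options of an interactive shell; failures are handled by the explicit `return 1` instead.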


The method uses Llama 2 Chat 70B (https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) or Mixtral 8x7B Instruct (https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1).


## Running the experiments

To run the experiments on the BoolQ test set, use the following command for each method:

LLM:

```bash
python source/evaluate.py \
--method_name LLM \
--query_set_name boolq_test \
--collection_name wikipedia_dump_100 \
--oracle False \
--save_graph False \
--output_dir results/ \
--verbose False 
```
RAG:

```bash
python source/evaluate.py \
--method_name LLMRAG \
--query_set_name boolq_test \
--collection_name wikipedia_dump_100 \
--document_rate 1 \
--oracle False \
--save_graph False \
--output_dir results/ \
--verbose False
```

MCTS - oracle:

```bash
python source/evaluate.py \
--method_name mcts \
--query_set_name boolq_test \
--collection_name wikipedia_dump_100 \
--document_rate 2 \
--size 15 \
--oracle True \
--save_graph False \
--output_dir results/ \
--verbose False \
--p_document 0.7 \
--threshold 0.5
```

MCTS - selfcritic:

```bash
python source/evaluate.py \
--method_name mcts \
--query_set_name boolq_test \
--collection_name wikipedia_dump_100 \
--document_rate 2 \
--size 15 \
--oracle False \
--save_graph False \
--output_dir results/ \
--verbose False \
--p_document 0.7 \
--threshold 0.9
```

MCTS - model-based estimation:

```bash
python source/evaluate.py \
--method_name qmcts \
--query_set_name boolq_test \
--collection_name wikipedia_dump_100 \
--document_rate 2 \
--size 15 \
--oracle False \
--save_graph False \
--output_dir results/ \
--verbose False \
--threshold 0.21
```


To run the experiments on the emrQA test set, use the following command for each method:

LLM:

```bash
python source/evaluate.py \
--method_name LLM \
--query_set_name ade_qa_med2_test \
--collection_name med_records_100 \
--oracle False \
--save_graph False \
--output_dir results/ \
--verbose False
```
RAG:

```bash
python source/evaluate.py \
--method_name LLMRAG \
--query_set_name ade_qa_med2_test \
--collection_name med_records_100 \
--document_rate 1 \
--oracle False \
--save_graph False \
--output_dir results/ \
--verbose False
```

MCTS - oracle:

```bash
python source/evaluate.py \
--method_name mcts \
--query_set_name ade_qa_med2_test \
--collection_name med_records_100 \
--document_rate 5 \
--size 25 \
--oracle True \
--save_graph False \
--output_dir results/ \
--verbose False \
--p_document 1 \
--threshold 0.5
```

MCTS - selfcritic:

```bash
python source/evaluate.py \
--method_name mcts \
--query_set_name ade_qa_med2_test \
--collection_name med_records_100 \
--document_rate 5 \
--size 25 \
--oracle False \
--save_graph False \
--output_dir results/ \
--verbose False \
--p_document 1 \
--threshold 0.9
```

MCTS - model-based estimation:

```bash
python source/evaluate.py \
--method_name qmcts \
--query_set_name ade_qa_med2_test \
--collection_name med_records_100 \
--document_rate 5 \
--size 25 \
--oracle False \
--save_graph False \
--output_dir results/ \
--verbose False \
--threshold 0.9
```
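The ten commands for the two test sets share the `--save_graph`, `--output_dir` and `--verbose` flags and differ only in the query set, the collection and the method-specific hyperparameters. The sketch below replays all of them from a single table, with every flag value copied from the commands above. It is a hypothetical helper, not part of the repository; `RUNS`, `COMMON_FLAGS`, `run_all` and the `DRY_RUN` switch are names introduced here for illustration, and it assumes it is sourced from the repository root.

```bash
#!/usr/bin/env bash
# Hypothetical helper, not part of the repository: replays the BoolQ and emrQA
# evaluation commands from one table. Flag values are copied from the README.

# Flags shared by every run.
COMMON_FLAGS="--save_graph False --output_dir results/ --verbose False"

# One row per run: query set, collection, then the method-specific flags.
RUNS=(
  "boolq_test wikipedia_dump_100 --method_name LLM --oracle False"
  "boolq_test wikipedia_dump_100 --method_name LLMRAG --document_rate 1 --oracle False"
  "boolq_test wikipedia_dump_100 --method_name mcts --document_rate 2 --size 15 --oracle True --p_document 0.7 --threshold 0.5"
  "boolq_test wikipedia_dump_100 --method_name mcts --document_rate 2 --size 15 --oracle False --p_document 0.7 --threshold 0.9"
  "boolq_test wikipedia_dump_100 --method_name qmcts --document_rate 2 --size 15 --oracle False --threshold 0.21"
  "ade_qa_med2_test med_records_100 --method_name LLM --oracle False"
  "ade_qa_med2_test med_records_100 --method_name LLMRAG --document_rate 1 --oracle False"
  "ade_qa_med2_test med_records_100 --method_name mcts --document_rate 5 --size 25 --oracle True --p_document 1 --threshold 0.5"
  "ade_qa_med2_test med_records_100 --method_name mcts --document_rate 5 --size 25 --oracle False --p_document 1 --threshold 0.9"
  "ade_qa_med2_test med_records_100 --method_name qmcts --document_rate 5 --size 25 --oracle False --threshold 0.9"
)

run_all() {
  local row query_set collection extra
  local -a cmd
  for row in "${RUNS[@]}"; do
    # First two fields are the query set and collection; the rest are flags.
    read -r query_set collection extra <<< "$row"
    # Unquoted $extra/$COMMON_FLAGS: word splitting into separate flags is intended.
    cmd=(python source/evaluate.py --query_set_name "$query_set" \
         --collection_name "$collection" $extra $COMMON_FLAGS)
    if [ "${DRY_RUN:-0}" = "1" ]; then
      # Dry run: print the command instead of executing it.
      echo "${cmd[*]}"
    else
      "${cmd[@]}" || { echo "failed: ${cmd[*]}" >&2; return 1; }
    fi
  done
}
```

After sourcing the file, `DRY_RUN=1 run_all` prints the ten commands and `run_all` executes them in sequence. Keeping the flags in one table makes it less likely that a hyperparameter changed for one method silently diverges from the values documented above.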


## Content

- `data/` Directory containing all the data and the scripts to process it.
- `model/` Directory containing the weights of the score estimation model.
- `source/controller/` Directory containing all the code related to information retrieval, MCTS, and the thought process.
- `source/models/` Directory containing the models that manage the different LLM requests.
- `source/qlearner/` Directory containing the training script for the score estimation model.
- `source/language_models/` Directory containing the LLM API.

