## Hexa: Self-Improving Framework for Knowledge-Grounded Dialogue System

### Requirements
```
pip install retriv contractions rouge html2text googlesearch-python matplotlib seaborn jsonlines
pip install --upgrade transformers==4.27.2
pip install sentence-transformers
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
```
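To confirm that the pinned versions actually ended up in the active environment, a small check with the standard-library `importlib.metadata` can help. This is a minimal sketch: the version table is copied from the commands above, and only the public version is compared. Local labels such as `+cu116` are ignored, so it does not distinguish CUDA from CPU builds of the same release (`torch.version.cuda` reports that).

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the install commands above.
EXPECTED = {
    "transformers": "4.27.2",
    "torch": "1.12.1+cu116",
    "torchvision": "0.13.1+cu116",
    "torchaudio": "0.12.1",
}


def check_versions(expected, get_version=version):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    problems = {}
    for pkg, want in expected.items():
        try:
            have = get_version(pkg)
        except PackageNotFoundError:
            problems[pkg] = (want, None)
            continue
        # Compare only the public version; "+cu116" is a local build label.
        if have.split("+")[0] != want.split("+")[0]:
            problems[pkg] = (want, have)
    return problems


if __name__ == "__main__":
    mismatches = check_versions(EXPECTED)
    for pkg, (want, have) in mismatches.items():
        print(f"{pkg}: expected {want}, found {have or 'not installed'}")
    if not mismatches:
        print("All pinned packages match.")
```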

### Run Search Server
1. Download the Wikipedia corpus. We use the passage split **data.wikipedia_split.psgs_w100** provided by the [DPR repository](https://github.com/facebookresearch/DPR/blob/main/dpr/data/download_data.py); pass this name to that repo's `download_data.py` via `--resource`.
2. Build the BM25-based static search index:
```
python search_engines/misc/build_wiki.py --corpus_path=[downloaded tsv file path] --target_path=[target indexing file path]
```
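3. Run the search server. The first argument (8080) is the port that the training and evaluation commands below connect to via `--server_port=8080`: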
```
sh run_search_server.sh 8080 8 
```
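Before launching training or evaluation, it can help to confirm that something is listening on the search-server port. The sketch below only tests whether a TCP connection to `localhost:8080` can be opened (the host is an assumption; the port matches `--server_port`). It does not exercise the server's query API, whose request format is not documented here.

```python
import socket


def server_is_listening(host="localhost", port=8080, timeout=2.0):
    """Return True if a TCP connection to host:port can be opened."""
    try:
        # create_connection raises OSError (e.g. ConnectionRefusedError or a
        # timeout) when nothing accepts the connection.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


if __name__ == "__main__":
    ok = server_is_listening(port=8080)
    print(f"search server on port 8080: {'up' if ok else 'not reachable'}")
```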

### Hexa Training
1. Download the initial BlenderBot 3 (3B) checkpoint by following the instructions on the [BB3 project page](https://parl.ai/projects/bb3/).
2. Run the training script:
```
python run_self_learn.py --master_port=1111 --visible_gpus=0,1,2,3 --server_port=8080 --name_space=hexa-highest --initial_model_path=[BB3-3B model path] --bt_inc_rate=0.1 --inc_rate=0.0 --num_bootstrap=4000 --start_num_loop=0 --lr=2e-6 --bs=1 --acc=4 --max_iter=10
```
* To run STaR instead, use the same command with `--name_space=star`.
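When switching between runs such as Hexa and STaR, it can be convenient to keep the flags in one place and override only what changes. This is a hypothetical helper (`build_command` and `HEXA_ARGS` are illustrative names, not part of the repository); the flag values are copied from the training command above, including the model-path placeholder, which must be filled in before running.

```python
import shlex

# Flags copied from the Hexa training command above; the placeholder is kept as-is.
HEXA_ARGS = {
    "master_port": 1111,
    "visible_gpus": "0,1,2,3",
    "server_port": 8080,
    "name_space": "hexa-highest",
    "initial_model_path": "[BB3-3B model path]",
    "bt_inc_rate": 0.1,
    "inc_rate": 0.0,
    "num_bootstrap": 4000,
    "start_num_loop": 0,
    "lr": "2e-6",  # kept as a string so it is not rendered as "2e-06"
    "bs": 1,
    "acc": 4,
    "max_iter": 10,
}


def build_command(overrides=None):
    """Return the run_self_learn.py argv list with any flag overrides applied."""
    args = {**HEXA_ARGS, **(overrides or {})}
    return ["python", "run_self_learn.py"] + [f"--{k}={v}" for k, v in args.items()]


if __name__ == "__main__":
    # Print the STaR variant as a shell-quoted command line.
    print(shlex.join(build_command({"name_space": "star"})))
```

The resulting list can be passed directly to `subprocess.run`, which avoids shell-quoting problems with the model path.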
### Evaluation
```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=1111 self_learn_eval.py --config_path=configs/selflearn_3b_eval.yaml --model.model_init_from_hf=true --knowledge_conditioning=combined --memory_decision=compute --server_port=8080 --debug=false --trainer.do_train=False --trainer.per_device_eval_batch_size=1 --name_space=hexa-highest-SR --scheme=bt --num_loop=9 --num_bootstrap=4000 --finetune_num_epoch=1 --max_num_entries=50000 --save_eval_samples=True
```
* After evaluation, JSON files with the evaluation scores are written to `./experiment/bb3_3b/[your model path]/selflearn_eval_[task name]_log.json`.
* The generated samples are saved to `./experiment/data/eval/[your model path]/[task name].json`.
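Because the scores are split into one log per model and task, a small script can gather them in one pass. This sketch globs for the `selflearn_eval_*_log.json` pattern under `./experiment/bb3_3b` and keys each log by (model path, task name). The internal structure of the logs is not described here, so it returns the parsed JSON without assuming particular metric names.

```python
import glob
import json
import os

PREFIX, SUFFIX = "selflearn_eval_", "_log.json"


def load_eval_logs(root="./experiment/bb3_3b"):
    """Return {(model_path, task_name): parsed_json} for every eval log under root."""
    pattern = os.path.join(root, "**", PREFIX + "*" + SUFFIX)
    logs = {}
    for path in sorted(glob.glob(pattern, recursive=True)):
        # Model path relative to root; task name taken from the file name.
        model_path = os.path.relpath(os.path.dirname(path), root)
        task_name = os.path.basename(path)[len(PREFIX):-len(SUFFIX)]
        with open(path, encoding="utf-8") as f:
            logs[(model_path, task_name)] = json.load(f)
    return logs


if __name__ == "__main__":
    for (model_path, task_name), scores in load_eval_logs().items():
        print(f"{model_path} / {task_name}")
        print(json.dumps(scores, indent=2))
```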