#!/bin/bash
# GPUs exposed to every process started from this script.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export CUDA_VISIBLE_DEVICES

# TODO(user): set ALIAS. It is the only value meant to be edited here; every
# path below is derived from it.
ALIAS=""

# --- Directory layout, all rooted at /vc_data_2/<ALIAS> ----------------------
HOME_DIR="/vc_data_2/${ALIAS}"
CODE_DIR="${HOME_DIR}/code/UnivSearchDev"
PLM_DIR="${HOME_DIR}/model_checkpoints"
COLLECTION_DIR="${HOME_DIR}/data/msmarco"
PROCESSED_DIR="${HOME_DIR}/data/msmarco/processed_data"
LOG_DIR="${HOME_DIR}/tensorboard"
CHECKPOINT_DIR="${HOME_DIR}/model_checkpoints"
EMBEDDING_DIR="${HOME_DIR}/embeddings_cache"
RESULT_DIR="${HOME_DIR}/result"
# Note: keeps its trailing slash; the evaluation step appends "/trec_eval".
EVAL_DIR="${CODE_DIR}/metrics/trec/trec_eval-9.0.7/"

# --- Experiment settings ------------------------------------------------------
PROJECT_DIR="${CODE_DIR}/projects/T5-DPR"
TITLE_CHOICE='no_title'
MODEL_NAME="t5-${TITLE_CHOICE}-batch_size-32"
# Not referenced by any command visible in this script.
NUM_TRAIN_EPOCH_ENCODER=3
BASELINE_PLM="${PLM_DIR}/t5-base-scaled"
PER_DEVICE_TRAIN_BATCH_SIZE=32
PER_DEVICE_INFERENCE_BATCH_SIZE=256
NUM_TRAIN_EPOCHS=15

# --- Input / output files -----------------------------------------------------
RESULTS_LOG="${RESULT_DIR}/msmarco_${TITLE_CHOICE}/${MODEL_NAME}-results.txt"
TRAIN_FILE="${PROCESSED_DIR}/t5-${TITLE_CHOICE}/train.new.jsonl"
VAL_FILE="${PROCESSED_DIR}/t5-${TITLE_CHOICE}/val.jsonl"

# --- Step counts and sequence lengths ----------------------------------------
# SAVE_STEPS is passed as both --save_steps and --eval_steps during training.
SAVE_STEPS=500
LOG_STEPS=100
Q_MAX_LEN=32
P_MAX_LEN=128

# Run everything from the code checkout: the driver scripts further down are
# invoked through paths relative to it (lib/openmatch/...), and PYTHONPATH=.
# resolves against the current directory. Stop here if the checkout is missing
# instead of running the pipeline from whatever directory we started in.
cd "$CODE_DIR" || {
    echo "error: cannot cd to CODE_DIR '$CODE_DIR'" >&2
    exit 1
}
export PYTHONPATH=.

# echo "build train file..."
# rm -rf $PROCESSED_DIR/t5-$TITLE_CHOICE/
# python $PROJECT_DIR/project_scripts/msmarco/build_train.py \
#     --tokenizer_name $BASELINE_PLM  \
#     --negative_file $COLLECTION_DIR/raw_data/train.negatives.tsv  \
#     --qrels $COLLECTION_DIR/raw_data/qrels.train.tsv  \
#     --queries $COLLECTION_DIR/raw_data/train.query.txt  \
#     --collection $COLLECTION_DIR/raw_data/collection_with_title.tsv  \
#     --save_to $PROCESSED_DIR/t5-$TITLE_CHOICE/  \
#     --template "<text>"

# cat $PROCESSED_DIR/t5-$TITLE_CHOICE/*.jsonl > $PROCESSED_DIR/t5-$TITLE_CHOICE/train.jsonl

# tail -n 500 $PROCESSED_DIR/t5-$TITLE_CHOICE/train.jsonl > $PROCESSED_DIR/t5-$TITLE_CHOICE/val.jsonl
# head -n 400282 $PROCESSED_DIR/t5-$TITLE_CHOICE/train.jsonl > $PROCESSED_DIR/t5-$TITLE_CHOICE/train.new.jsonl

# Stage: training. Runs lib/openmatch/driver/train_dr.py under deepspeed,
# starting from $BASELINE_PLM and writing the model to
# $CHECKPOINT_DIR/msmarco_denseretrievers/$MODEL_NAME (overwritten if present,
# per --overwrite_output_dir). TensorBoard logs go under $LOG_DIR/msmarco.
echo "start training..."

# All path arguments are quoted so directories containing spaces survive.
# A non-zero exit stops the script: the later stages load the model from the
# output directory above and would otherwise run on a missing or stale one.
deepspeed --include="worker-0" lib/openmatch/driver/train_dr.py \
    --deepspeed "$PROJECT_DIR/configs/ds_config_noBm25.json" \
    --output_dir "$CHECKPOINT_DIR/msmarco_denseretrievers/$MODEL_NAME" \
    --model_name_or_path "$BASELINE_PLM" \
    --do_train \
    --save_steps "$SAVE_STEPS" \
    --eval_steps "$SAVE_STEPS" \
    --logging_steps "$LOG_STEPS" \
    --train_path "$TRAIN_FILE" \
    --eval_path "$VAL_FILE" \
    --fp16 \
    --per_device_train_batch_size "$PER_DEVICE_TRAIN_BATCH_SIZE" \
    --train_n_passages 8 \
    --learning_rate 5e-6 \
    --q_max_len "$Q_MAX_LEN" \
    --p_max_len "$P_MAX_LEN" \
    --num_train_epochs "$NUM_TRAIN_EPOCHS" \
    --use_t5_decoder \
    --logging_dir "$LOG_DIR/msmarco/$MODEL_NAME" \
    --evaluation_strategy steps \
    --negatives_x_device True \
    --remove_unused_columns False \
    --overwrite_output_dir True \
    --report_to tensorboard \
    --dataloader_num_workers 8 || {
    status=$?
    echo "error: training failed (exit status $status)" >&2
    exit "$status"
}

# Stage: index building. Runs lib/openmatch/driver/build_index.py through
# torch.distributed.launch (--nproc_per_node=2) over collection_with_title.tsv,
# loading the model from $CHECKPOINT_DIR/msmarco_denseretrievers/$MODEL_NAME
# and writing to $EMBEDDING_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME.
echo "building index..."

# Path arguments are quoted; a non-zero exit stops the script because the
# retrieval step is pointed at the same output directory.
python -m torch.distributed.launch --nproc_per_node=2 --master_port 19286 \
    lib/openmatch/driver/build_index.py \
    --output_dir "$EMBEDDING_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME" \
    --model_name_or_path "$CHECKPOINT_DIR/msmarco_denseretrievers/$MODEL_NAME" \
    --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
    --corpus_path "$COLLECTION_DIR/raw_data/collection_with_title.tsv" \
    --use_t5_decoder \
    --doc_template "<text>" \
    --doc_column_names id,title,text \
    --q_max_len "$Q_MAX_LEN" \
    --p_max_len "$P_MAX_LEN" \
    --fp16 \
    --dataloader_num_workers 1 || {
    status=$?
    echo "error: index building failed (exit status $status)" >&2
    exit "$status"
}

# Make sure the per-model result directory exists. `mkdir -p` is already a
# no-op for an existing directory, so no separate -d test is needed. The path
# is quoted, and a failure stops the script because dev.trec and the results
# log are written below this directory.
mkdir -p -- "$RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME" || {
    echo "error: cannot create $RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME" >&2
    exit 1
}

# Stage: retrieval. Runs lib/openmatch/driver/retrieve.py on
# queries.dev.small.tsv with --output_dir set to
# $EMBEDDING_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME, and saves the run in TREC
# format to $RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME/dev.trec.
echo "retrieving..."

# Path arguments are quoted; a non-zero exit stops the script so evaluation
# does not score a missing or outdated dev.trec.
python lib/openmatch/driver/retrieve.py \
    --output_dir "$EMBEDDING_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME" \
    --model_name_or_path "$CHECKPOINT_DIR/msmarco_denseretrievers/$MODEL_NAME" \
    --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
    --query_path "$COLLECTION_DIR/raw_data/queries.dev.small.tsv" \
    --use_t5_decoder \
    --query_template "<text>" \
    --query_column_names id,text \
    --q_max_len "$Q_MAX_LEN" \
    --fp16 \
    --trec_save_path "$RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME/dev.trec" \
    --dataloader_num_workers 1 || {
    status=$?
    echo "error: retrieval failed (exit status $status)" >&2
    exit "$status"
}

# Stage: evaluation. Scores the dev run against the dev qrels with trec_eval
# and with metrics/mrr.py, appending each heading and tool output to
# $RESULTS_LOG while also echoing it to the terminal, then prints the whole
# accumulated log. All paths are quoted so locations with spaces work.
qrels_dev="$COLLECTION_DIR/raw_data/qrels.dev.small.tsv"
trec_run="$RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME/dev.trec"

# NOTE(review): the heading says "MRR@100" while the trec_eval flags request
# recip_rank.10 and recall.100 -- confirm which label is intended.
echo "MRR@100" | tee -a "$RESULTS_LOG"
"$EVAL_DIR/trec_eval" -c -mrecip_rank.10 -mrecall.100 "$qrels_dev" "$trec_run" \
    | tee -a "$RESULTS_LOG"

echo "MRR@10" | tee -a "$RESULTS_LOG"
python "$CODE_DIR/metrics/mrr.py" "$qrels_dev" "$trec_run" \
    | tee -a "$RESULTS_LOG"

# The log is opened in append mode above, so this also shows earlier runs.
cat "$RESULTS_LOG"
