#!/bin/bash
# Run configuration: GPU visibility, directory layout, model/checkpoint names
# and hyper-parameters referenced by the commands in the rest of this script.
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

# Storage root: /vc_data_2/<ALIAS> when that directory exists, else /data/<ALIAS>.
ALIAS=""
HOME_DIR="/vc_data_2/${ALIAS}"
[ -d "${HOME_DIR}" ] || HOME_DIR="/data/${ALIAS}"

# Directory layout, all rooted at HOME_DIR.
CODE_DIR="${HOME_DIR}/code/UnivSearchDev"
PLM_DIR="${HOME_DIR}/model_checkpoints"
COLLECTION_DIR="${HOME_DIR}/data/msmarco"
PROCESSED_DIR="${HOME_DIR}/data/msmarco/processed_data"
LOG_DIR="${HOME_DIR}/tensorboard"
CHECKPOINT_DIR="${HOME_DIR}/model_checkpoints"
EMBEDDING_DIR="${HOME_DIR}/embeddings_cache"
RESULT_DIR="${HOME_DIR}/result"
EVAL_DIR="${CODE_DIR}/metrics/trec/trec_eval-9.0.7/"
PROJECT_DIR="${CODE_DIR}/projects/T5_ANCEPROMPT"

# Run identity: MODEL_NAME is used below as the per-run sub-directory / file
# prefix under the processed-data, embedding, result and checkpoint roots.
TITLE_CHOICE='with_title'
MODEL_NAME="anceprompt-dpr-k-10-${TITLE_CHOICE}-batch_size-32-lambda-2"

# Starting models. REF_PLM is left empty here and must be filled in before the
# warm-up steps that pass it as --model_name_or_path can work.
REF_PLM=""
BASELINE_PLM="${HOME_DIR}/model_checkpoints/msmarco_denseretrievers/t5-${TITLE_CHOICE}-batch_size-32"

# Epoch counts for the retriever training and the distillation training.
NUM_TRAIN_EPOCH_ENCODER=3
NUM_DISTILL_EPOCH_ENCODER=3

# Batch sizes. NUM_TRAIN_ITERS is not referenced by the visible commands.
PER_DEVICE_TRAIN_BATCH_SIZE=32
PER_DEVICE_INFERENCE_BATCH_SIZE=64
NUM_TRAIN_ITERS=6

# Results log and the train/validation files produced by the split steps.
RESULTS_LOG="${RESULT_DIR}/msmarco_${TITLE_CHOICE}/${MODEL_NAME}-results.txt"
TRAIN_FILE="${PROCESSED_DIR}/${MODEL_NAME}/train.new.jsonl"
VAL_FILE="${PROCESSED_DIR}/${MODEL_NAME}/val.jsonl"
TRAIN_DISTILL_FILE="${PROCESSED_DIR}/${MODEL_NAME}/train.distill.new.jsonl"
VAL_DISTILL_FILE="${PROCESSED_DIR}/${MODEL_NAME}/val.distill.jsonl"

# Checkpoint / logging cadence for the two training commands.
SAVE_STEPS=500
LOG_STEPS=100
DISTILL_SAVE_STEPS=200
DISTILL_LOG_STEPS=50

# Sequence lengths and number of grounding documents per query.
# TEMPERATURE is not referenced by the visible commands.
Q_MAX_LEN=32
P_MAX_LEN=128
GROUND_DOC_NUM=10
TEMPERATURE=1000

# Work from the code checkout: PYTHONPATH=. below is relative to it. Abort if
# the checkout is missing instead of running every later step from whatever
# directory the script happened to be started in.
cd "$CODE_DIR" || { echo "cannot cd to CODE_DIR: $CODE_DIR" >&2; exit 1; }
export PYTHONPATH=.

# Hand the MS MARCO, MeSH and Wiki corpus files to combine_all_grounding.py,
# which is told to write its output to data/combined/corpus.tsv. That combined
# file is the --corpus_path / --collection of the grounding steps below.
python "$PROJECT_DIR/project_scripts/msmarco/combine_all_grounding.py" \
    --msmarco_path "$COLLECTION_DIR/raw_data/collection_with_title.tsv" \
    --mesh_path "$HOME_DIR/data/mesh/corpus.tsv" \
    --wiki_path "$HOME_DIR/data/wiki/corpus.tsv" \
    --save_to_path "$HOME_DIR/data/combined/corpus.tsv"

# ---- Warm-up, stage 1: ground the raw train/dev queries with REF_PLM --------
# REF_PLM is passed as --model_name_or_path below. If it is left empty the
# option has no value and the commands cannot run, so stop here with a clear
# message instead of letting every step of the stage fail one by one.
: "${REF_PLM:?REF_PLM must be set to the reference model used for the warm-up index}"

# Shorthands for paths used repeatedly in this stage.
WARMUP_INDEX_DIR="$EMBEDDING_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME-combined"
WARMUP_RESULT_DIR="$RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME"
WARMUP_CORPUS="$HOME_DIR/data/combined/corpus.tsv"

# Create the per-run result directory before anything writes to it: the
# RESULTS_LOG file lives in its parent and the .trec/.json outputs live inside
# it. `tee -a` cannot append to a log whose directory does not exist yet.
mkdir -p "$WARMUP_RESULT_DIR"

# Encode the combined corpus with REF_PLM into the "-combined" embedding dir.
echo "warm up: building index using REF PLM $REF_PLM..."| tee -a "$RESULTS_LOG"
python -m torch.distributed.launch --nproc_per_node=2 --master_port 19286 \
    "$PROJECT_DIR/project_lib/driver/build_index.py" \
    --output_dir "$WARMUP_INDEX_DIR" \
    --model_name_or_path "$REF_PLM" \
    --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
    --corpus_path "$WARMUP_CORPUS" \
    --use_t5_decoder \
    --doc_template "Corpus: <name> <title> <text>" \
    --doc_column_names id,title,text,name \
    --q_max_len "$Q_MAX_LEN" \
    --p_max_len "$P_MAX_LEN" \
    --fp16 \
    --dataloader_num_workers 1

# Retrieve for the train queries; the ranking is saved as train_combined.trec.
echo "warm up: retrieving train using REF PLM $REF_PLM..." | tee -a "$RESULTS_LOG"
python "$PROJECT_DIR/project_lib/driver/retrieve.py" \
    --output_dir "$WARMUP_INDEX_DIR" \
    --model_name_or_path "$REF_PLM" \
    --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
    --query_path "$COLLECTION_DIR/raw_data/train.query.txt" \
    --use_t5_decoder \
    --query_template "<text>" \
    --query_column_names id,text \
    --q_max_len "$Q_MAX_LEN" \
    --fp16 \
    --trec_save_path "$WARMUP_RESULT_DIR/train_combined.trec" \
    --dataloader_num_workers 1

# Turn the train ranking into augmented queries (GROUND_DOC_NUM docs per query
# via --prf_k), saved as train.query.augmented.json.
echo "warm up: augmenting train..." | tee -a "$RESULTS_LOG"
python "$PROJECT_DIR/project_scripts/msmarco/augment_query.py" \
    --trec_file "$WARMUP_RESULT_DIR/train_combined.trec" \
    --queries "$COLLECTION_DIR/raw_data/train.query.txt" \
    --collection "$WARMUP_CORPUS" \
    --save_to "$WARMUP_RESULT_DIR/train.query.augmented.json" \
    --prf_k "$GROUND_DOC_NUM"

# Same two steps for the dev queries: retrieve, then augment.
echo "warm up: retrieving dev using REF PLM $REF_PLM..." | tee -a "$RESULTS_LOG"
python "$PROJECT_DIR/project_lib/driver/retrieve.py" \
    --output_dir "$WARMUP_INDEX_DIR" \
    --model_name_or_path "$REF_PLM" \
    --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
    --query_path "$COLLECTION_DIR/raw_data/queries.dev.small.tsv" \
    --use_t5_decoder \
    --query_template "<text>" \
    --query_column_names id,text \
    --q_max_len "$Q_MAX_LEN" \
    --fp16 \
    --trec_save_path "$WARMUP_RESULT_DIR/dev_combined.trec" \
    --dataloader_num_workers 1

echo "warm up: augmenting dev..." | tee -a "$RESULTS_LOG"
python "$PROJECT_DIR/project_scripts/msmarco/augment_query.py" \
    --trec_file "$WARMUP_RESULT_DIR/dev_combined.trec" \
    --queries "$COLLECTION_DIR/raw_data/queries.dev.small.tsv" \
    --collection "$WARMUP_CORPUS" \
    --save_to "$WARMUP_RESULT_DIR/dev.query.augmented.json" \
    --prf_k "$GROUND_DOC_NUM"

# ---- Warm-up, stage 2: build the first training set with BASELINE_PLM -------
# ITER_NUM is set again by the initializer of the refinement loop further down.
ITER_NUM=0

# Encode the MS MARCO collection (with titles) with the baseline model.
echo "warm up: building index using baseline PLM $BASELINE_PLM..."| tee -a "$RESULTS_LOG"
python -m torch.distributed.launch --nproc_per_node=2 --master_port 19286 \
    "$PROJECT_DIR/project_lib/driver/build_index.py" \
    --output_dir "$EMBEDDING_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME" \
    --model_name_or_path "$BASELINE_PLM" \
    --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
    --corpus_path "$COLLECTION_DIR/raw_data/collection_with_title.tsv" \
    --use_t5_decoder \
    --doc_template "Title: <title> Text: <text>" \
    --doc_column_names id,title,text \
    --q_max_len "$Q_MAX_LEN" \
    --p_max_len "$P_MAX_LEN" \
    --fp16 \
    --dataloader_num_workers 1

# Retrieval for the plain train queries (no grounding) -> train.trec.
# Note the batch size here is a literal 256, not PER_DEVICE_INFERENCE_BATCH_SIZE.
echo "warm up: retrieving train without grounding using baseline PLM $BASELINE_PLM..."| tee -a "$RESULTS_LOG"
python "$PROJECT_DIR/project_lib/driver/retrieve.py" \
    --output_dir "$EMBEDDING_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME" \
    --model_name_or_path "$BASELINE_PLM" \
    --per_device_eval_batch_size 256 \
    --query_path "$COLLECTION_DIR/raw_data/train.query.txt" \
    --use_t5_decoder \
    --query_template "<text>" \
    --query_column_names id,text \
    --q_max_len "$Q_MAX_LEN" \
    --fp16 \
    --trec_save_path "$RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME/train.trec" \
    --dataloader_num_workers 1

# Retrieval for the augmented train queries (with grounding) -> train_augmented.trec.
echo "warm up: retrieving train with grounding using baseline PLM $BASELINE_PLM..."| tee -a "$RESULTS_LOG"
python "$PROJECT_DIR/project_lib/driver/retrieve.py" \
    --output_dir "$EMBEDDING_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME" \
    --model_name_or_path "$BASELINE_PLM" \
    --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
    --query_path "$RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME/train.query.augmented.json" \
    --use_t5_decoder \
    --query_template "<text>" \
    --query_column_names id,text,grounds,ground_ids \
    --use_ground True \
    --q_max_len "$Q_MAX_LEN" \
    --fp16 \
    --trec_save_path "$RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME/train_augmented.trec" \
    --dataloader_num_workers 1 \
    --ground_passage_num "$GROUND_DOC_NUM"

# Start from an empty processed-data directory for this run. The ${var:?}
# guards abort instead of expanding to a much broader path if either variable
# is ever empty or unset.
rm -rf -- "${PROCESSED_DIR:?}/${MODEL_NAME:?}"
mkdir -p "$PROCESSED_DIR/$MODEL_NAME"

# Build the training examples from the grounded ranking; the *.hn.jsonl files
# concatenated below are expected to appear under --save_to.
echo "warm up: building train with grounds..."| tee -a "$RESULTS_LOG"
python "$PROJECT_DIR/project_scripts/msmarco/build_hn_with_ground.py" \
    --tokenizer_name "$PLM_DIR/t5-base-scaled" \
    --hn_file "$RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME/train_augmented.trec" \
    --qrels "$COLLECTION_DIR/raw_data/qrels.train.tsv" \
    --queries "$COLLECTION_DIR/raw_data/train.query.txt" \
    --collection "$COLLECTION_DIR/raw_data/collection_with_title.tsv" \
    --save_to "$PROCESSED_DIR/$MODEL_NAME" \
    --template "Title: <title> Text: <text>" \
    --use_title True \
    --augmented_q_file_path "$RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME/train.query.augmented.json"

# Concatenate the shards, then take the last 500 lines as validation data and
# the first 502439 lines as training data.
echo "warm up: spliting file..."| tee -a "$RESULTS_LOG"
cat "$PROCESSED_DIR/$MODEL_NAME"/*.hn.jsonl > "$PROCESSED_DIR/$MODEL_NAME/train.hn.jsonl"
tail -n 500 "$PROCESSED_DIR/$MODEL_NAME/train.hn.jsonl" > "$VAL_FILE"
head -n 502439 "$PROCESSED_DIR/$MODEL_NAME/train.hn.jsonl" > "$TRAIN_FILE"

# Make sure the per-run result directory exists for the loop that follows.
mkdir -p "$RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME"

# ---- Refinement loop ---------------------------------------------------------
# Each iteration: train the retriever, evaluate it on the grounded dev queries,
# distill into a grounder, regenerate the grounded train/dev queries, evaluate
# the grounder, and rebuild the training file for the next iteration.
# The loop bound is a literal 1, so only iteration 0 runs.
for ((ITER_NUM=0; ITER_NUM<1; ITER_NUM++))
do
    echo "Iter $ITER_NUM" | tee -a "$RESULTS_LOG"
    NEW_ITER_NUM=$((ITER_NUM + 1))
    # Not referenced by any command in this loop; kept for compatibility.
    LAST_ITER_NUM=$((ITER_NUM - 1))

    # Shorthands for the paths used throughout one iteration.
    CKPT_ROOT="$CHECKPOINT_DIR/msmarco_denseretrievers"
    RUN_EMB_DIR="$EMBEDDING_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME"
    RUN_RESULT_DIR="$RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME"
    RUN_DATA_DIR="$PROCESSED_DIR/$MODEL_NAME"
    COMBINED_CORPUS="$HOME_DIR/data/combined/corpus.tsv"
    MSMARCO_CORPUS="$COLLECTION_DIR/raw_data/collection_with_title.tsv"
    TRAIN_QUERIES="$COLLECTION_DIR/raw_data/train.query.txt"
    DEV_QUERIES="$COLLECTION_DIR/raw_data/queries.dev.small.tsv"
    DEV_QRELS="$COLLECTION_DIR/raw_data/qrels.dev.small.tsv"

    # Retriever checkpoint written by this iteration.
    REFRESH_CKPT="${CKPT_ROOT}/${MODEL_NAME}-refresh${NEW_ITER_NUM}"
    # Grounder names: GROUNDER_TAG is the prefix printed in log messages,
    # GROUNDER_CKPT is where the distillation step writes, and
    # GROUNDER_RETRIEVAL_CKPT is what the retrieval steps after it load.
    # NOTE(review): these last two differ ("-lambda_2-pseudo-pos" vs
    # "-lambda_2"), and no command in this script writes the "-lambda_2"
    # checkpoint. Likewise the combined index built from GROUNDER_CKPT goes to
    # "<run>/end-pseudo-pos" while retrieval reads "<run>-combined". Paths are
    # kept exactly as they were; confirm whether this is intentional.
    GROUNDER_TAG="${CKPT_ROOT}/${MODEL_NAME}-grounder${NEW_ITER_NUM}"
    GROUNDER_CKPT="${GROUNDER_TAG}-lambda_2-pseudo-pos"
    GROUNDER_RETRIEVAL_CKPT="${GROUNDER_TAG}-lambda_2"

    # ------------------------- train the retriever ---------------------------
    # Iteration 0 starts from the baseline model, later iterations from the
    # previous iteration's refresh checkpoint. Only the init model differs, so
    # the training command itself is written once.
    if [ "$ITER_NUM" -le 0 ]; then
        TRAIN_INIT_MODEL="$BASELINE_PLM"
        echo "iter $ITER_NUM: training model starting from baseline LM $BASELINE_PLM" | tee -a "$RESULTS_LOG"
    else
        TRAIN_INIT_MODEL="${CKPT_ROOT}/${MODEL_NAME}-refresh${ITER_NUM}"
        echo "iter $ITER_NUM: training model starting from checkpoint $TRAIN_INIT_MODEL" | tee -a "$RESULTS_LOG"
    fi
    python -m torch.distributed.launch --nproc_per_node=8 --master_port 19286 \
        "$PROJECT_DIR/project_lib/driver/train_dr.py" \
        --output_dir "$REFRESH_CKPT" \
        --model_name_or_path "$TRAIN_INIT_MODEL" \
        --do_train \
        --save_steps "$SAVE_STEPS" \
        --eval_steps "$SAVE_STEPS" \
        --logging_steps "$LOG_STEPS" \
        --train_path "$TRAIN_FILE" \
        --eval_path "$VAL_FILE" \
        --fp16 \
        --per_device_train_batch_size "$PER_DEVICE_TRAIN_BATCH_SIZE" \
        --train_n_passages 8 \
        --learning_rate 5e-6 \
        --q_max_len "$Q_MAX_LEN" \
        --p_max_len "$P_MAX_LEN" \
        --num_train_epochs "$NUM_TRAIN_EPOCH_ENCODER" \
        --use_t5_decoder \
        --logging_dir "$LOG_DIR/msmarco/${MODEL_NAME}-refresh${NEW_ITER_NUM}" \
        --evaluation_strategy steps \
        --negatives_x_device True \
        --remove_unused_columns False \
        --overwrite_output_dir True \
        --report_to tensorboard \
        --dataloader_num_workers 8 \
        --ground_passage_num "$GROUND_DOC_NUM"

    # Re-encode the MS MARCO collection with the freshly trained retriever.
    echo "iter $ITER_NUM: building index using checkpoint ${REFRESH_CKPT}..." | tee -a "$RESULTS_LOG"
    python -m torch.distributed.launch --nproc_per_node=2 --master_port 19286 \
        "$PROJECT_DIR/project_lib/driver/build_index.py" \
        --output_dir "$RUN_EMB_DIR" \
        --model_name_or_path "$REFRESH_CKPT" \
        --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
        --corpus_path "$MSMARCO_CORPUS" \
        --use_t5_decoder \
        --doc_template "Title: <title> Text: <text>" \
        --doc_column_names id,title,text \
        --q_max_len "$Q_MAX_LEN" \
        --p_max_len "$P_MAX_LEN" \
        --fp16 \
        --dataloader_num_workers 1

    # ------------- evaluate the retriever on grounded dev queries ------------
    echo "iter $ITER_NUM: retrieving test using checkpoint ${REFRESH_CKPT}..." | tee -a "$RESULTS_LOG"
    python "$PROJECT_DIR/project_lib/driver/retrieve.py" \
        --output_dir "$RUN_EMB_DIR" \
        --model_name_or_path "$REFRESH_CKPT" \
        --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
        --query_path "$RUN_RESULT_DIR/dev.query.augmented.json" \
        --use_t5_decoder \
        --query_template "<text>" \
        --query_column_names id,text,grounds,ground_ids \
        --use_ground True \
        --q_max_len "$Q_MAX_LEN" \
        --fp16 \
        --trec_save_path "$RUN_RESULT_DIR/dev_augmented.trec" \
        --dataloader_num_workers 1 \
        --ground_passage_num "$GROUND_DOC_NUM"

    # Reciprocal rank @10 and recall @100 against the dev qrels, appended to the log.
    echo "iter $ITER_NUM: evaluation with grounding using checkpoint ${REFRESH_CKPT}..." | tee -a "$RESULTS_LOG"
    "$EVAL_DIR/trec_eval" -c -mrecip_rank.10 -mrecall.100 "$DEV_QRELS" "$RUN_RESULT_DIR/dev_augmented.trec" | tee -a "$RESULTS_LOG"

    # -------------------- build the distillation data ------------------------
    # Retrieval over the grounded train queries. Per the message below this run
    # is what prepares query_ground_att.pkl, which the distill-data builder
    # reads from the embedding directory (presumably written by retrieve.py;
    # verify against that driver).
    echo "prepare for query_ground_att.pkl"
    python "$PROJECT_DIR/project_lib/driver/retrieve.py" \
        --output_dir "$RUN_EMB_DIR" \
        --model_name_or_path "$REFRESH_CKPT" \
        --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
        --query_path "$RUN_RESULT_DIR/train.query.augmented.json" \
        --use_t5_decoder \
        --query_template "<text>" \
        --query_column_names id,text,grounds,ground_ids \
        --use_ground True \
        --q_max_len "$Q_MAX_LEN" \
        --fp16 \
        --trec_save_path "$RUN_RESULT_DIR/train_augmented.trec" \
        --dataloader_num_workers 1 \
        --ground_passage_num "$GROUND_DOC_NUM"

    # Empty the processed-data directory (this removes the retriever's
    # train/val files, which are no longer needed in this iteration). The
    # ${var:?} guards abort rather than expand to a broader path.
    rm -rf -- "${PROCESSED_DIR:?}/${MODEL_NAME:?}"
    mkdir -p "$RUN_DATA_DIR"
    echo "iter $ITER_NUM: building distill data..." | tee -a "$RESULTS_LOG"
    python "$PROJECT_DIR/project_scripts/msmarco/build_distill_combined.py" \
        --q_att_file_path "$RUN_EMB_DIR/query_ground_att.pkl" \
        --tokenizer_name "$PLM_DIR/t5-base-scaled" \
        --augmented_q_file_path "$RUN_RESULT_DIR/train.query.augmented.json" \
        --distill_train_file_dir "$RUN_DATA_DIR" \
        --hn_file "$RUN_RESULT_DIR/train_combined.trec" \
        --qrels "$COLLECTION_DIR/raw_data/qrels.train.tsv" \
        --queries "$TRAIN_QUERIES" \
        --collection "$COMBINED_CORPUS" \
        --save_to "$RUN_DATA_DIR" \
        --template "Corpus: <name> <title> <text>" \
        --use_title True

    # Concatenate the shards; last 500 lines -> validation, first 502439 -> train.
    echo "iter $ITER_NUM: spliting distilling file..." | tee -a "$RESULTS_LOG"
    cat "$RUN_DATA_DIR"/*.distill.jsonl > "$RUN_DATA_DIR/train.distill.jsonl"
    tail -n 500 "$RUN_DATA_DIR/train.distill.jsonl" > "$VAL_DISTILL_FILE"
    head -n 502439 "$RUN_DATA_DIR/train.distill.jsonl" > "$TRAIN_DISTILL_FILE"

    # ------------------------ distill into the grounder ----------------------
    # Same pattern as retriever training: iteration 0 starts from the baseline
    # model, later iterations from the previous grounder checkpoint.
    echo "iter $ITER_NUM: distilling to grounder" | tee -a "$RESULTS_LOG"
    if [ "$ITER_NUM" -eq 0 ]; then
        DISTILL_INIT_MODEL="$BASELINE_PLM"
    else
        DISTILL_INIT_MODEL="${CKPT_ROOT}/${MODEL_NAME}-grounder${ITER_NUM}-lambda_2-pseudo-pos"
    fi
    python -m torch.distributed.launch --nproc_per_node=8 --master_port 19286 \
        "$PROJECT_DIR/project_lib/driver/train_distill.py" \
        --output_dir "$GROUNDER_CKPT" \
        --model_name_or_path "$DISTILL_INIT_MODEL" \
        --do_train \
        --save_steps "$DISTILL_SAVE_STEPS" \
        --eval_steps "$DISTILL_LOG_STEPS" \
        --logging_steps "$DISTILL_LOG_STEPS" \
        --train_path "$TRAIN_DISTILL_FILE" \
        --eval_path "$VAL_DISTILL_FILE" \
        --fp16 \
        --per_device_train_batch_size "$PER_DEVICE_TRAIN_BATCH_SIZE" \
        --train_n_passages "$GROUND_DOC_NUM" \
        --learning_rate 5e-6 \
        --q_max_len "$Q_MAX_LEN" \
        --p_max_len "$P_MAX_LEN" \
        --num_train_epochs "$NUM_DISTILL_EPOCH_ENCODER" \
        --use_t5_decoder \
        --logging_dir "$LOG_DIR/msmarco/${MODEL_NAME}-grounder${NEW_ITER_NUM}-lambda_2-pseudo-pos" \
        --evaluation_strategy steps \
        --remove_unused_columns False \
        --overwrite_output_dir True \
        --report_to tensorboard \
        --softmax_temperature 2000 \
        --dataloader_num_workers 8 \
        --loss_lambda 2 \
        --negatives_x_device True \
        --use_relevant True

    # --------------- regenerate grounded queries with the grounder -----------
    # Encode the combined corpus with the distilled grounder.
    echo "iter $ITER_NUM: building combined index using checkpoint ${GROUNDER_TAG}..." | tee -a "$RESULTS_LOG"
    python -m torch.distributed.launch --nproc_per_node=2 --master_port 19286 \
        "$PROJECT_DIR/project_lib/driver/build_index.py" \
        --output_dir "$RUN_EMB_DIR/end-pseudo-pos" \
        --model_name_or_path "$GROUNDER_CKPT" \
        --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
        --corpus_path "$COMBINED_CORPUS" \
        --use_t5_decoder \
        --doc_template "Corpus: <name> <title> <text>" \
        --doc_column_names id,title,text,name \
        --q_max_len "$Q_MAX_LEN" \
        --p_max_len "$P_MAX_LEN" \
        --fp16 \
        --dataloader_num_workers 1

    # Train queries -> train_combined.trec -> train.query.augmented.json.
    echo "iter $ITER_NUM: retrieving train using checkpoint ${GROUNDER_TAG}..." | tee -a "$RESULTS_LOG"
    python "$PROJECT_DIR/project_lib/driver/retrieve.py" \
        --output_dir "${RUN_EMB_DIR}-combined" \
        --model_name_or_path "$GROUNDER_RETRIEVAL_CKPT" \
        --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
        --query_path "$TRAIN_QUERIES" \
        --use_t5_decoder \
        --query_template "<text>" \
        --query_column_names id,text \
        --q_max_len "$Q_MAX_LEN" \
        --fp16 \
        --trec_save_path "$RUN_RESULT_DIR/train_combined.trec" \
        --dataloader_num_workers 1

    echo "iter $ITER_NUM: augmenting train queries using GROUNDING_PLM ${GROUNDER_TAG}..." | tee -a "$RESULTS_LOG"
    python "$PROJECT_DIR/project_scripts/msmarco/augment_query.py" \
        --trec_file "$RUN_RESULT_DIR/train_combined.trec" \
        --queries "$TRAIN_QUERIES" \
        --collection "$COMBINED_CORPUS" \
        --save_to "$RUN_RESULT_DIR/train.query.augmented.json" \
        --prf_k "$GROUND_DOC_NUM"

    # Dev queries -> dev_combined.trec -> dev.query.augmented.json.
    echo "iter $ITER_NUM: retrieving test using checkpoint ${GROUNDER_TAG}..." | tee -a "$RESULTS_LOG"
    python "$PROJECT_DIR/project_lib/driver/retrieve.py" \
        --output_dir "${RUN_EMB_DIR}-combined" \
        --model_name_or_path "$GROUNDER_RETRIEVAL_CKPT" \
        --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
        --query_path "$DEV_QUERIES" \
        --use_t5_decoder \
        --query_template "<text>" \
        --query_column_names id,text \
        --q_max_len "$Q_MAX_LEN" \
        --fp16 \
        --trec_save_path "$RUN_RESULT_DIR/dev_combined.trec" \
        --dataloader_num_workers 1

    echo "iter $ITER_NUM: augmenting test queries using GROUNDING_PLM ${GROUNDER_TAG}..." | tee -a "$RESULTS_LOG"
    python "$PROJECT_DIR/project_scripts/msmarco/augment_query.py" \
        --trec_file "$RUN_RESULT_DIR/dev_combined.trec" \
        --queries "$DEV_QUERIES" \
        --collection "$COMBINED_CORPUS" \
        --save_to "$RUN_RESULT_DIR/dev.query.augmented.json" \
        --prf_k "$GROUND_DOC_NUM"

    # ------------- evaluate the grounder directly on MS MARCO ----------------
    # Separate "-distill" index over the MS MARCO collection, then plain
    # (ungrounded) retrieval for train and dev queries.
    echo "iter $ITER_NUM: building index on msmarco using checkpoint ${GROUNDER_TAG}..." | tee -a "$RESULTS_LOG"
    python -m torch.distributed.launch --nproc_per_node=2 --master_port 19286 \
        "$PROJECT_DIR/project_lib/driver/build_index.py" \
        --output_dir "${RUN_EMB_DIR}-distill" \
        --model_name_or_path "$GROUNDER_RETRIEVAL_CKPT" \
        --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
        --corpus_path "$MSMARCO_CORPUS" \
        --use_t5_decoder \
        --doc_template "Title: <title> Text: <text>" \
        --doc_column_names id,title,text \
        --q_max_len "$Q_MAX_LEN" \
        --p_max_len "$P_MAX_LEN" \
        --fp16 \
        --dataloader_num_workers 1

    echo "iter $ITER_NUM: retrieving train on msmarco using checkpoint ${GROUNDER_TAG}..." | tee -a "$RESULTS_LOG"
    python "$PROJECT_DIR/project_lib/driver/retrieve.py" \
        --output_dir "${RUN_EMB_DIR}-distill" \
        --model_name_or_path "$GROUNDER_RETRIEVAL_CKPT" \
        --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
        --query_path "$TRAIN_QUERIES" \
        --use_t5_decoder \
        --query_template "<text>" \
        --query_column_names id,text \
        --q_max_len "$Q_MAX_LEN" \
        --fp16 \
        --trec_save_path "$RUN_RESULT_DIR/train.trec" \
        --dataloader_num_workers 1

    echo "iter $ITER_NUM: retrieving test on msmarco using checkpoint ${GROUNDER_TAG}..." | tee -a "$RESULTS_LOG"
    python "$PROJECT_DIR/project_lib/driver/retrieve.py" \
        --output_dir "${RUN_EMB_DIR}-distill" \
        --model_name_or_path "$GROUNDER_RETRIEVAL_CKPT" \
        --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
        --query_path "$DEV_QUERIES" \
        --use_t5_decoder \
        --query_template "<text>" \
        --query_column_names id,text \
        --q_max_len "$Q_MAX_LEN" \
        --fp16 \
        --trec_save_path "$RUN_RESULT_DIR/dev.trec" \
        --dataloader_num_workers 1

    echo "iter $ITER_NUM: evaluating distilled grounder ${GROUNDER_TAG}..." | tee -a "$RESULTS_LOG"

    "$EVAL_DIR/trec_eval" -c -mrecip_rank.10 -mrecall.100 "$DEV_QRELS" "$RUN_RESULT_DIR/dev.trec" | tee -a "$RESULTS_LOG"

    # ----------- rebuild the retriever training file for next round ----------
    # Retrieval with the retriever over the regenerated grounded train queries.
    echo "iter $ITER_NUM: retrieving train using checkpoint ${REFRESH_CKPT}..." | tee -a "$RESULTS_LOG"
    python "$PROJECT_DIR/project_lib/driver/retrieve.py" \
        --output_dir "$RUN_EMB_DIR" \
        --model_name_or_path "$REFRESH_CKPT" \
        --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
        --query_path "$RUN_RESULT_DIR/train.query.augmented.json" \
        --use_t5_decoder \
        --query_template "<text>" \
        --query_column_names id,text,grounds,ground_ids \
        --use_ground True \
        --q_max_len "$Q_MAX_LEN" \
        --fp16 \
        --trec_save_path "$RUN_RESULT_DIR/train_augmented.trec" \
        --dataloader_num_workers 1 \
        --ground_passage_num "$GROUND_DOC_NUM"

    # Empty the processed-data directory again (drops the distillation files).
    rm -rf -- "${PROCESSED_DIR:?}/${MODEL_NAME:?}"
    mkdir -p "$RUN_DATA_DIR"
    echo "iter $ITER_NUM: building train with grounds..." | tee -a "$RESULTS_LOG"
    python "$PROJECT_DIR/project_scripts/msmarco/build_hn_with_ground.py" \
        --tokenizer_name "$PLM_DIR/t5-base-scaled" \
        --hn_file "$RUN_RESULT_DIR/train_augmented.trec" \
        --qrels "$COLLECTION_DIR/raw_data/qrels.train.tsv" \
        --queries "$TRAIN_QUERIES" \
        --collection "$MSMARCO_CORPUS" \
        --save_to "$RUN_DATA_DIR" \
        --template "Title: <title> Text: <text>" \
        --use_title True \
        --augmented_q_file_path "$RUN_RESULT_DIR/train.query.augmented.json"

    # Concatenate the shards; last 500 lines -> validation, first 502439 -> train.
    echo "iter $ITER_NUM: spliting training file..." | tee -a "$RESULTS_LOG"
    cat "$RUN_DATA_DIR"/*.hn.jsonl > "$RUN_DATA_DIR/train.hn.jsonl"
    tail -n 500 "$RUN_DATA_DIR/train.hn.jsonl" > "$VAL_FILE"
    head -n 502439 "$RUN_DATA_DIR/train.hn.jsonl" > "$TRAIN_FILE"
done


