#!/bin/bash
# ---------------------------------------------------------------------------
# Run configuration. Everything below is derived from ALIAS plus the literal
# settings in this section.
# ---------------------------------------------------------------------------

# GPUs exposed to every process launched by this script.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export CUDA_VISIBLE_DEVICES

# TODO(user): set ALIAS -- it is the only value that should need editing.
ALIAS=""

# Directory layout, all rooted at HOME_DIR.
HOME_DIR="/vc_data_2/${ALIAS}"
CODE_DIR="${HOME_DIR}/code/UnivSearchDev"
PLM_DIR="${HOME_DIR}/model_checkpoints"
COLLECTION_DIR="${HOME_DIR}/data/msmarco"
PROCESSED_DIR="${HOME_DIR}/data/msmarco/processed_data"
LOG_DIR="${HOME_DIR}/tensorboard"
CHECKPOINT_DIR="${HOME_DIR}/model_checkpoints"
EMBEDDING_DIR="${HOME_DIR}/embeddings_cache"
RESULT_DIR="${HOME_DIR}/result"
# Note the trailing slash: the evaluator is invoked as $EVAL_DIR/trec_eval.
EVAL_DIR="${CODE_DIR}/metrics/trec/trec_eval-9.0.7/"
PROJECT_DIR="${CODE_DIR}/projects/T5-ANCE"

# Run naming: TITLE_CHOICE and MODEL_NAME are embedded in the embedding,
# result, processed-data and checkpoint paths used further down.
TITLE_CHOICE="with_title"
MODEL_NAME="ance-${TITLE_CHOICE}-batch_size-32-t5-epochs15"
# Checkpoint used for the initial index build, retrieval and first training pass.
BASELINE_PLM="${CHECKPOINT_DIR}/msmarco_denseretrievers/t5-${TITLE_CHOICE}-batch_size-32-epochs15"

# Training / inference settings forwarded to the driver scripts.
NUM_TRAIN_EPOCH_ENCODER=3              # --num_train_epochs
PER_DEVICE_TRAIN_BATCH_SIZE=32         # --per_device_train_batch_size
PER_DEVICE_INFERENCE_BATCH_SIZE=256    # --per_device_eval_batch_size
NUM_TRAIN_ITERS=10                     # upper bound of the refresh loop
SAVE_STEPS=500                         # used for both --save_steps and --eval_steps
LOG_STEPS=100                          # --logging_steps
Q_MAX_LEN=32                           # --q_max_len
P_MAX_LEN=128                          # --p_max_len

# Train / validation files written by the split step of each iteration.
TRAIN_FILE="${PROCESSED_DIR}/${MODEL_NAME}/train.new.hn.jsonl"
VAL_FILE="${PROCESSED_DIR}/${MODEL_NAME}/val.hn.jsonl"

# Iteration counter; its initial value 0 labels the baseline evaluation that
# is logged before the refresh loop starts.
ITER_NUM=0
# File that evaluation output is appended to.
RESULTS_LOG="${RESULT_DIR}/msmarco_${TITLE_CHOICE}/${MODEL_NAME}-results.txt"


# Work from the code checkout: the driver scripts below are invoked with
# paths relative to it and PYTHONPATH=. resolves imports against it.
# Abort if the directory is missing -- otherwise every later step (including
# the rm -rf inside the refresh loop) would run from the wrong directory.
cd "$CODE_DIR" || { echo "error: cannot cd to CODE_DIR '$CODE_DIR'" >&2; exit 1; }
export PYTHONPATH=.

# Copy the tokenizer files from the t5-base-scaled directory into the
# baseline checkpoint directory. The trailing slash on the destination makes
# cp fail if BASELINE_PLM is not an existing directory, instead of silently
# creating a regular file at the checkpoint path. A failed copy is reported
# by cp but, as before, does not stop the script.
for tokenizer_file in \
    special_tokens_map.json \
    spiece.model \
    tokenizer.json \
    tokenizer_config.json; do
    cp -- "$PLM_DIR/t5-base-scaled/$tokenizer_file" "$BASELINE_PLM/"
done


echo "building index using baseline PLM $BASELINE_PLM..."

# Run build_index.py over the titled collection with the baseline checkpoint,
# launched on 2 processes. Output goes to the per-run embedding directory,
# which the retrieval steps below are pointed at via the same --output_dir.
# All path arguments are quoted so a path containing whitespace or glob
# characters is passed as a single argument.
python -m torch.distributed.launch --nproc_per_node=2 --master_port 19286 \
    lib/openmatch/driver/build_index.py \
    --output_dir "$EMBEDDING_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME" \
    --model_name_or_path "$BASELINE_PLM" \
    --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
    --corpus_path "$COLLECTION_DIR/raw_data/collection_with_title.tsv" \
    --use_t5_decoder \
    --doc_template "Title: <title> Text: <text>" \
    --doc_column_names id,title,text \
    --q_max_len "$Q_MAX_LEN" \
    --p_max_len "$P_MAX_LEN" \
    --fp16 \
    --dataloader_num_workers 1

# Make sure the per-run result directory exists before anything writes a
# .trec file or the results log beneath it. mkdir -p is a no-op when the
# directory is already there, so no separate existence test is needed; the
# path is quoted so it is always treated as a single directory name.
mkdir -p -- "$RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME"

echo "retrieving test using baseline PLM $BASELINE_PLM..."

# Run retrieve.py for the dev-small queries with the baseline checkpoint,
# pointing --output_dir at the per-run embedding directory and writing the
# ranking to dev.trec in the per-run result directory.
# Path arguments are quoted so each is passed as exactly one argument.
python lib/openmatch/driver/retrieve.py \
    --output_dir "$EMBEDDING_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME" \
    --model_name_or_path "$BASELINE_PLM" \
    --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
    --query_path "$COLLECTION_DIR/raw_data/queries.dev.small.tsv" \
    --use_t5_decoder \
    --query_template "<text>" \
    --query_column_names id,text \
    --q_max_len "$Q_MAX_LEN" \
    --fp16 \
    --trec_save_path "$RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME/dev.trec" \
    --dataloader_num_workers 1


# Score the baseline dev run with trec_eval (recip_rank.10 and recall.100).
# Both the header line and the metric output are shown on stdout and
# appended to the results log.
echo "evaluation: iter $ITER_NUM:" | tee -a "$RESULTS_LOG"
"$EVAL_DIR/trec_eval" -c -mrecip_rank.10 -mrecall.100 \
    "$COLLECTION_DIR/raw_data/qrels.dev.small.tsv" \
    "$RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME/dev.trec" \
    | tee -a "$RESULTS_LOG"

# ---------------------------------------------------------------------------
# Iterative hard-negative refresh loop.
#
# Each pass (ITER_NUM = 0 .. NUM_TRAIN_ITERS-1):
#   1. retrieve for the training queries with the current model
#      (baseline PLM on pass 0, checkpoint "-refresh<ITER_NUM>" afterwards),
#   2. rebuild the hard-negative training data from that ranking,
#   3. split it into validation / training files,
#   4. train, writing checkpoint "-refresh<ITER_NUM+1>",
#   5. rebuild the index and retrieve for the dev queries with the new checkpoint,
#   6. evaluate and append the main metrics to RESULTS_LOG.
#
# Exit codes of the Python steps are not checked (unchanged from before): a
# failing step prints its own error and the loop carries on.
# ---------------------------------------------------------------------------

# Loop-invariant paths.
ckpt_root="$CHECKPOINT_DIR/msmarco_denseretrievers"
index_dir="$EMBEDDING_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME"
run_result_dir="$RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME"
raw_dir="$COLLECTION_DIR/raw_data"
hn_dir="$PROCESSED_DIR/$MODEL_NAME"

for ((ITER_NUM = 0; ITER_NUM < NUM_TRAIN_ITERS; ITER_NUM++)); do
    echo "Iter $ITER_NUM"
    NEW_ITER_NUM=$((ITER_NUM + 1))

    # Pick the model this pass starts from. Only the model path (and the
    # wording of the progress messages) differs between the first pass and
    # later ones, so the commands below are shared.
    if ((ITER_NUM == 0)); then
        current_model="$BASELINE_PLM"
        model_desc="baseline PLM $current_model"
    else
        current_model="$ckpt_root/$MODEL_NAME-refresh$ITER_NUM"
        model_desc="checkpoint $current_model"
    fi
    next_model="$ckpt_root/$MODEL_NAME-refresh$NEW_ITER_NUM"

    # --- 1. retrieve for the training queries -----------------------------
    echo "iter $ITER_NUM: retrieving train using $model_desc..."
    python lib/openmatch/driver/retrieve.py \
        --output_dir "$index_dir" \
        --model_name_or_path "$current_model" \
        --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
        --query_path "$raw_dir/train.query.txt" \
        --use_t5_decoder \
        --query_template "<text>" \
        --query_column_names id,text \
        --q_max_len "$Q_MAX_LEN" \
        --fp16 \
        --trec_save_path "$run_result_dir/train.trec" \
        --dataloader_num_workers 1

    # --- 2. rebuild hard negatives ----------------------------------------
    # Start from an empty directory every pass. The :? guards abort the
    # script rather than letting rm -rf run on a truncated path if either
    # variable is ever empty or unset.
    rm -rf -- "${PROCESSED_DIR:?}/${MODEL_NAME:?}"
    mkdir -p -- "$hn_dir"
    echo "iter $ITER_NUM: building hard negatives..."

    python "$PROJECT_DIR/project_scripts/msmarco/build_hn.py" \
        --tokenizer_name "$PLM_DIR/t5-base-scaled" \
        --hn_file "$run_result_dir/train.trec" \
        --qrels "$raw_dir/qrels.train.tsv" \
        --queries "$raw_dir/train.query.txt" \
        --collection "$raw_dir/collection_with_title.tsv" \
        --save_to "$hn_dir" \
        --template "Title: <title> Text: <text>" \
        --use_title True

    # --- 3. split into validation / training files ------------------------
    echo "iter $ITER_NUM: spliting file..."

    # The directory was recreated above, so the glob matches only *.hn.jsonl
    # files produced since then (presumably the build_hn.py output shards --
    # TODO confirm their naming).
    cat "$hn_dir"/*.hn.jsonl > "$hn_dir/train.hn.jsonl"
    # Last 500 lines become the validation file, the first 502439 lines the
    # training file.
    # NOTE(review): 502439 is hard-coded; the two parts only stay disjoint if
    # the merged file has at least 502939 lines -- confirm against the data.
    tail -n 500 "$hn_dir/train.hn.jsonl" > "$VAL_FILE"
    head -n 502439 "$hn_dir/train.hn.jsonl" > "$TRAIN_FILE"

    # --- 4. train -----------------------------------------------------------
    echo "iter $ITER_NUM: start training using $model_desc..."
    python -m torch.distributed.launch --nproc_per_node=8 --master_port 19286 \
        lib/openmatch/driver/train_dr.py \
        --output_dir "$next_model" \
        --model_name_or_path "$current_model" \
        --do_train \
        --save_steps "$SAVE_STEPS" \
        --eval_steps "$SAVE_STEPS" \
        --logging_steps "$LOG_STEPS" \
        --train_path "$TRAIN_FILE" \
        --eval_path "$VAL_FILE" \
        --fp16 \
        --per_device_train_batch_size "$PER_DEVICE_TRAIN_BATCH_SIZE" \
        --train_n_passages 8 \
        --learning_rate 5e-6 \
        --q_max_len "$Q_MAX_LEN" \
        --p_max_len "$P_MAX_LEN" \
        --num_train_epochs "$NUM_TRAIN_EPOCH_ENCODER" \
        --use_t5_decoder \
        --logging_dir "$LOG_DIR/msmarco/$MODEL_NAME-refresh$NEW_ITER_NUM" \
        --evaluation_strategy steps \
        --negatives_x_device True \
        --remove_unused_columns False \
        --overwrite_output_dir True \
        --dataloader_num_workers 8

    # --- 5. re-index and retrieve for dev with the new checkpoint ----------
    echo "iter $ITER_NUM: building index using checkpoint $next_model..."

    python -m torch.distributed.launch --nproc_per_node=2 --master_port 19286 \
        lib/openmatch/driver/build_index.py \
        --output_dir "$index_dir" \
        --model_name_or_path "$next_model" \
        --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
        --corpus_path "$raw_dir/collection_with_title.tsv" \
        --use_t5_decoder \
        --doc_template "Title: <title> Text: <text>" \
        --doc_column_names id,title,text \
        --q_max_len "$Q_MAX_LEN" \
        --p_max_len "$P_MAX_LEN" \
        --fp16 \
        --dataloader_num_workers 1

    echo "iter $ITER_NUM: retrieving using checkpoint $next_model..."

    python lib/openmatch/driver/retrieve.py \
        --output_dir "$index_dir" \
        --model_name_or_path "$next_model" \
        --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
        --query_path "$raw_dir/queries.dev.small.tsv" \
        --use_t5_decoder \
        --query_template "<text>" \
        --query_column_names id,text \
        --q_max_len "$Q_MAX_LEN" \
        --fp16 \
        --trec_save_path "$run_result_dir/dev.trec" \
        --dataloader_num_workers 1

    # --- 6. evaluate --------------------------------------------------------
    # Header and metrics both go to the results log, matching the baseline
    # evaluation before the loop (previously only the header line was logged
    # here, leaving the log without the per-iteration numbers).
    echo "evaluation: iter $NEW_ITER_NUM:" | tee -a "$RESULTS_LOG"
    "$EVAL_DIR/trec_eval" -c -mrecip_rank.10 -mrecall.100 \
        "$raw_dir/qrels.dev.small.tsv" "$run_result_dir/dev.trec" \
        | tee -a "$RESULTS_LOG"

    # Additional checks on a derived run file; output goes to stdout only.
    # NOTE(review): this script lives under projects/T5_ANCE_Aug while
    # PROJECT_DIR points at projects/T5-ANCE -- confirm the path is intended.
    python "$CODE_DIR/projects/T5_ANCE_Aug/project_scripts/msmarco/mrr_truncate.py" \
        --trec_file_last "$run_result_dir/dev.trec" \
        --save_to "$run_result_dir/dev_mrr_10.trec"

    "$EVAL_DIR/trec_eval" -c -mrecip_rank.10 -mrecall.100 \
        "$raw_dir/qrels.dev.small.tsv" "$run_result_dir/dev_mrr_10.trec"

    python "$CODE_DIR/metrics/mrr.py" \
        "$raw_dir/qrels.dev.small.tsv" "$run_result_dir/dev_mrr_10.trec"

    python "$PROJECT_DIR/project_scripts/msmarco/merge_prediction.py" \
        --trec_file_last "$run_result_dir/dev.trec" \
        --trec_file_new "$run_result_dir/dev_mrr_10.trec" \
        --save_to "$run_result_dir/dev_debug.trec"

done