#!/bin/bash
# Install/refresh the Python dependencies used by this script.
# `python -m pip` is used instead of a bare `pip` so the packages are installed
# into the same interpreter that later runs `python -m torch.distributed.launch`;
# a `pip` found on PATH can belong to a different Python installation.
# Failures are deliberately not fatal here (packages may already be present).
python -m pip install --upgrade pip
python -m pip install transformers sentencepiece datasets faiss-gpu tensorboard deepspeed==0.6.5 jsonlines beir accelerate
# pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html

# ---- GPU selection ---------------------------------------------------------
# Make GPUs 0-7 visible to every process started from this script.
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

# ---- Storage root ----------------------------------------------------------
# ALIAS is appended to the storage root; it is empty here, so the root is the
# mount point itself. Prefer /vc_data_2 and fall back to /data when the former
# is not an existing directory.
ALIAS=""
HOME_DIR="/vc_data_2/${ALIAS}"
[[ -d "${HOME_DIR}" ]] || HOME_DIR="/data/${ALIAS}"

# ---- Derived directory layout (everything lives under HOME_DIR) -------------
CODE_DIR="${HOME_DIR}/code/UnivSearchDev"
PLM_DIR="${HOME_DIR}/model_checkpoints"
COLLECTION_DIR="${HOME_DIR}/data/msmarco"
PROCESSED_DIR="${HOME_DIR}/data/msmarco/processed_data"
LOG_DIR="${HOME_DIR}/tensorboard"
CHECKPOINT_DIR="${HOME_DIR}/model_checkpoints"
EMBEDDING_DIR="${HOME_DIR}/embeddings_cache"
RESULT_DIR="${HOME_DIR}/result"
# trec_eval location inside the code checkout (note the trailing slash).
EVAL_DIR="${CODE_DIR}/metrics/trec/trec_eval-9.0.7/"

# ---- Experiment identity ---------------------------------------------------
PROJECT_DIR="${CODE_DIR}/projects/T5-DPR"
TITLE_CHOICE='with_title'
# Run name; reused in checkpoint, tensorboard and result paths.
MODEL_NAME="t5-${TITLE_CHOICE}-batch_size-32"
# Starting checkpoint handed to the trainer as --model_name_or_path.
BASELINE_PLM="${PLM_DIR}/t5-base-scaled"

# ---- Training / inference hyperparameters ----------------------------------
# Several of these are referenced only by the commented-out pipeline stages in
# the visible part of this script; they are kept so those stages can be
# re-enabled without edits.
NUM_TRAIN_EPOCH_ENCODER=3
PER_DEVICE_TRAIN_BATCH_SIZE=32
PER_DEVICE_INFERENCE_BATCH_SIZE=256
NUM_TRAIN_EPOCHS=15
SAVE_STEPS=500
LOG_STEPS=100
Q_MAX_LEN=32
P_MAX_LEN=128

# ---- Input / output files --------------------------------------------------
RESULTS_LOG="${RESULT_DIR}/msmarco_${TITLE_CHOICE}/${MODEL_NAME}-results.txt"
TRAIN_FILE="${PROCESSED_DIR}/t5-${TITLE_CHOICE}-batch_size-32/train.new.jsonl"
VAL_FILE="${PROCESSED_DIR}/t5-${TITLE_CHOICE}-batch_size-32/val.jsonl"


# Run everything from the code checkout: the driver below is addressed by a
# path relative to it and PYTHONPATH is set to ".". Abort if the directory
# cannot be entered, otherwise the training command would run from whatever
# directory the script was started in.
if ! cd "$CODE_DIR"; then
    echo "error: cannot change directory to CODE_DIR='$CODE_DIR'" >&2
    exit 1
fi
export PYTHONPATH=.
echo "start training..."

# python -m torch.distributed.launch --nproc_per_node=8 --master_port 19286 \
#     lib/openmatch/driver/train_dr.py  \
#     --output_dir $CHECKPOINT_DIR/msmarco_denseretrievers/$MODEL_NAME  \
#     --model_name_or_path $BASELINE_PLM  \
#     --do_train  \
#     --save_steps $SAVE_STEPS  \
#     --eval_steps $SAVE_STEPS  \
#     --logging_steps $LOG_STEPS \
#     --train_path $TRAIN_FILE  \
#     --eval_path $VAL_FILE  \
#     --fp16  \
#     --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE  \
#     --train_n_passages 8  \
#     --learning_rate 5e-6  \
#     --q_max_len $Q_MAX_LEN  \
#     --p_max_len $P_MAX_LEN  \
#     --num_train_epochs $NUM_TRAIN_EPOCHS  \
#     --use_t5_decoder  \
#     --logging_dir $LOG_DIR/msmarco/$MODEL_NAME  \
#     --evaluation_strategy steps \
#     --negatives_x_device True \
#     --remove_unused_columns False \
#     --overwrite_output_dir True \
#     --report_to tensorboard \
#     --dataloader_num_workers 8

# echo "building index..."

# python -m torch.distributed.launch --nproc_per_node=2 --master_port 19286 \
#     lib/openmatch/driver/build_index.py  \
#     --output_dir $EMBEDDING_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME  \
#     --model_name_or_path $CHECKPOINT_DIR/msmarco_denseretrievers/$MODEL_NAME  \
#     --per_device_eval_batch_size $PER_DEVICE_INFERENCE_BATCH_SIZE  \
#     --corpus_path $COLLECTION_DIR/raw_data/collection_with_title.tsv  \
#     --use_t5_decoder  \
#     --doc_template "Title: <title> Text: <text>"  \
#     --doc_column_names id,title,text  \
#     --q_max_len $Q_MAX_LEN  \
#     --p_max_len $P_MAX_LEN  \
#     --fp16  \
#     --dataloader_num_workers 1

# if [ ! -d "$RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME" ]; then
#     mkdir -p $RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME
# fi

# echo "retrieving..."

# python lib/openmatch/driver/retrieve.py  \
#     --output_dir $EMBEDDING_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME \
#     --model_name_or_path $CHECKPOINT_DIR/msmarco_denseretrievers/$MODEL_NAME  \
#     --per_device_eval_batch_size $PER_DEVICE_INFERENCE_BATCH_SIZE  \
#     --query_path $COLLECTION_DIR/raw_data/queries.dev.small.tsv  \
#     --use_t5_decoder  \
#     --query_template "<text>"  \
#     --query_column_names id,text  \
#     --q_max_len $Q_MAX_LEN  \
#     --fp16  \
#     --trec_save_path $RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME/dev.trec  \
#     --dataloader_num_workers 1

# $EVAL_DIR/trec_eval -c -mrecip_rank.10 -mrecall.100 $COLLECTION_DIR/raw_data/qrels.dev.small.tsv $RESULT_DIR/msmarco_$TITLE_CHOICE/$MODEL_NAME/dev.trec > $RESULTS_LOG

# cat $RESULTS_LOG

# Launch lib/openmatch/driver/train_dr.py through torch.distributed.launch
# with 8 processes per node (one per GPU listed in CUDA_VISIBLE_DEVICES) on
# master port 19286.
#
# Arguments are collected in an array and every expansion is quoted so that a
# path containing whitespace or glob characters is passed as a single argument.
#
# NOTE(review): output and logging dirs carry an "-unused" suffix, the
# save/eval/logging intervals are 400000 steps and the epoch count is 200
# rather than $SAVE_STEPS/$LOG_STEPS/$NUM_TRAIN_EPOCHS — presumably a
# deliberate placeholder run; confirm before changing.
train_args=(
    --output_dir "$CHECKPOINT_DIR/msmarco_denseretrievers/$MODEL_NAME-unused"
    --model_name_or_path "$BASELINE_PLM"
    --do_train
    --save_steps 400000
    --eval_steps 400000
    --logging_steps 400000
    --train_path "$TRAIN_FILE"
    --eval_path "$VAL_FILE"
    --fp16
    --per_device_train_batch_size "$PER_DEVICE_TRAIN_BATCH_SIZE"
    --train_n_passages 8
    --learning_rate 5e-6
    --q_max_len "$Q_MAX_LEN"
    --p_max_len "$P_MAX_LEN"
    --num_train_epochs 200
    --use_t5_decoder
    --logging_dir "$LOG_DIR/msmarco/$MODEL_NAME-unused"
    --evaluation_strategy steps
    --negatives_x_device True
    --remove_unused_columns False
    --overwrite_output_dir True
    --report_to tensorboard
    --dataloader_num_workers 8
)

python -m torch.distributed.launch --nproc_per_node=8 --master_port 19286 \
    lib/openmatch/driver/train_dr.py "${train_args[@]}"
    
