#!/usr/bin/env bash
# Two-stage BEIR evaluation for the T5 ANCE-prompt retriever:
#   1. encode a combined grounding corpus with GROUNDING_PLM and retrieve
#      candidates for each test query;
#   2. augment each test query with retrieved grounding passages, encode the
#      dataset corpus with FINAL_PLM and retrieve again;
#   3. score both runs with trec_eval.
# GROUNDING_PLM and FINAL_PLM must be filled in before running; HOME_DIR is
# prepended to every code/data path below.
#
# Abort on the first failing stage instead of letting later stages read a
# missing or partial .trec/.json file, and treat unset variables as errors.
set -euo pipefail

export CUDA_VISIBLE_DEVICES=0,1
HOME_DIR=""
CODE_DIR="$HOME_DIR/code/UnivSearchDev/"
COLLECTION_DIR="$HOME_DIR/data/beir/"
CHECKPOINT_DIR="$HOME_DIR/model_checkpoints"
EMBEDDING_DIR="$HOME_DIR/embeddings_cache"
RESULT_DIR="$HOME_DIR/result"
EVAL_DIR="$CODE_DIR/metrics/trec/trec_eval-9.0.7/"
BEIR_DIR="$CODE_DIR/metrics/beir"

TITLE_CHOICE='with_title'               # not referenced below
MODEL_NAME="anceprompt-whole-corpus"    # result/embedding subdirectory for this run
GROUND_MODEL_NAME="ancedistill"         # not referenced below
FINAL_MODEL_NAME="anceprompt"           # embedding subdirectory for the final index
DATASET_NAME='nfcorpus'
GROUND_DOC_NUM=10                       # passed as --prf_k and --ground_passage_num
PROJECT_DIR="$CODE_DIR/projects/T5_ANCEPROMPT"
##############################
# Model path or name passed as --model_name_or_path to each stage.
GROUNDING_PLM=""
FINAL_PLM=""
##############################
# Fail fast with a clear message: empty values would otherwise make `cp` lose
# its destination and hand the Python drivers no model at all.
: "${GROUNDING_PLM:?set GROUNDING_PLM before running}"
: "${FINAL_PLM:?set FINAL_PLM before running}"

PER_DEVICE_INFERENCE_BATCH_SIZE=32
LARGER_PER_DEVICE_INFERENCE_BATCH_SIZE=256   # used for the grounding-corpus index
Q_MAX_LEN=64
P_MAX_LEN=512
mkdir -p "$RESULT_DIR/beir/$DATASET_NAME/$GROUNDING_PLM/"
mkdir -p "$RESULT_DIR/beir/$DATASET_NAME/$FINAL_PLM/"
mkdir -p "$RESULT_DIR/beir/$DATASET_NAME/$MODEL_NAME/"
mkdir -p "$EMBEDDING_DIR/beir/$DATASET_NAME/$MODEL_NAME/"

# PYTHONPATH=. below is relative, so the cd must succeed.
cd "$CODE_DIR" || exit 1
export PYTHONPATH=.

# Copy the base T5 tokenizer files into FINAL_PLM (presumably the fine-tuned
# checkpoint was saved without them — TODO confirm).
TOKENIZER_SRC="$CHECKPOINT_DIR/t5-base-scaled"
for tok_file in special_tokens_map.json spiece.model tokenizer.json tokenizer_config.json; do
    cp -- "$TOKENIZER_SRC/$tok_file" "$FINAL_PLM/"
done

# echo "downloading data..."
#     python $BEIR_DIR/download_data.py --dataset_name=$DATASET_NAME --save_dir=$COLLECTION_DIR

# echo "preprocessing data..."
#     python $BEIR_DIR/preprocess_data.py --dataset_name=$DATASET_NAME --save_dir=$COLLECTION_DIR

# python $BEIR_DIR/combine_all_grounding1.py  \
#     --msmarco_path $COLLECTION_DIR/$DATASET_NAME/corpus.tsv  \
#     --wiki_path $HOME_DIR/data/wiki/corpus.tsv  \
#     --mesh_path $HOME_DIR/data/mesh/corpus.tsv  \
#     --save_to_path $HOME_DIR/data/combined/corpus_$DATASET_NAME$.tsv

echo "building index..."
# Stage 1a: encode the combined grounding corpus with GROUNDING_PLM using two
# processes, matching CUDA_VISIBLE_DEVICES=0,1. The combined TSV carries an
# extra `name` column, which the doc template prepends as "Corpus: <name>".
# NOTE(review): the corpus filename intentionally keeps a literal '$' before
# ".tsv". The original `corpus_$DATASET_NAME$.tsv` leaves `$.` unexpanded, and
# every producer/consumer of this file uses that same name, so rename them
# all together if fixing it.
python -m torch.distributed.launch --nproc_per_node=2 --master_port 19286 \
    "$PROJECT_DIR/project_lib/driver/build_index.py" \
    --output_dir "$EMBEDDING_DIR/beir/$DATASET_NAME/$MODEL_NAME" \
    --model_name_or_path "$GROUNDING_PLM" \
    --per_device_eval_batch_size "$LARGER_PER_DEVICE_INFERENCE_BATCH_SIZE" \
    --corpus_path "$HOME_DIR/data/combined/corpus_${DATASET_NAME}\$.tsv" \
    --use_t5_decoder \
    --doc_template "Corpus: <name> <title> <text>" \
    --doc_column_names id,title,text,name \
    --q_max_len "$Q_MAX_LEN" \
    --p_max_len "$P_MAX_LEN" \
    --fp16 \
    --dataloader_num_workers 1


echo "retrieving..."
# Stage 1b: search the grounding index with the raw test queries and write the
# ranking to test.trec in this run's result directory.
python "$PROJECT_DIR/project_lib/driver/retrieve.py" \
    --output_dir "$EMBEDDING_DIR/beir/$DATASET_NAME/$MODEL_NAME" \
    --model_name_or_path "$GROUNDING_PLM" \
    --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
    --query_path "$COLLECTION_DIR/$DATASET_NAME/queries.test.tsv" \
    --use_t5_decoder \
    --query_template "<text>" \
    --query_column_names id,text \
    --q_max_len "$Q_MAX_LEN" \
    --fp16 \
    --trec_save_path "$RESULT_DIR/beir/$DATASET_NAME/$MODEL_NAME/test.trec" \
    --dataloader_num_workers 1

# Pick the grounding ranking used to augment queries. For arguana and quora the
# ranking is first passed through remove_same_qd.py (presumably drops hits
# whose document is the query itself — confirm), and the cleaned file is used.
ground_trec="$RESULT_DIR/beir/$DATASET_NAME/$MODEL_NAME/test.trec"
case "$DATASET_NAME" in
    arguana|quora)
        python "$BEIR_DIR/remove_same_qd.py" \
            --trec "$ground_trec" \
            --save_to "$RESULT_DIR/beir/$DATASET_NAME/$MODEL_NAME/test_cleaned.trec"
        ground_trec="$RESULT_DIR/beir/$DATASET_NAME/$MODEL_NAME/test_cleaned.trec"
        ;;
esac

# Build queries.test.augmented.json from the test queries plus the top
# GROUND_DOC_NUM retrieved passages of the combined corpus (literal '$' in the
# corpus filename is intentional; see the index-build stage).
echo "augmenting test queries using GROUNDING_PLM $GROUNDING_PLM..."
python "$BEIR_DIR/mix_corpus_combined.py" \
    --trec_file "$ground_trec" \
    --queries "$COLLECTION_DIR/$DATASET_NAME/queries.test.tsv" \
    --collection "$HOME_DIR/data/combined/corpus_${DATASET_NAME}\$.tsv" \
    --save_to "$RESULT_DIR/beir/$DATASET_NAME/$MODEL_NAME/queries.test.augmented.json" \
    --prf_k "$GROUND_DOC_NUM" \
    --threshold 5

# Stage 2a: encode the dataset's own corpus with FINAL_PLM into a separate
# embedding directory. Uncomment the guard to skip re-encoding when that
# directory already exists. Reusing master_port 19286 is fine: the two
# distributed launches in this script run one after the other.
# if [ ! -d "$EMBEDDING_DIR/beir/$DATASET_NAME/$FINAL_MODEL_NAME" ]; then
echo "building index..."
python -m torch.distributed.launch --nproc_per_node=2 --master_port 19286 \
    "$PROJECT_DIR/project_lib/driver/build_index.py" \
    --output_dir "$EMBEDDING_DIR/beir/$DATASET_NAME/$FINAL_MODEL_NAME" \
    --model_name_or_path "$FINAL_PLM" \
    --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
    --corpus_path "$COLLECTION_DIR/$DATASET_NAME/corpus.tsv" \
    --use_t5_decoder \
    --doc_template "Title: <title> Text: <text>" \
    --doc_column_names id,title,text \
    --q_max_len "$Q_MAX_LEN" \
    --p_max_len "$P_MAX_LEN" \
    --fp16 \
    --dataloader_num_workers 1
# fi

echo "retrieving again..."
# Stage 2b: search the final index with the augmented queries (id, text and
# grounding columns from queries.test.augmented.json) and write
# test_augmented.trec.
python "$PROJECT_DIR/project_lib/driver/retrieve.py" \
    --output_dir "$EMBEDDING_DIR/beir/$DATASET_NAME/$FINAL_MODEL_NAME" \
    --model_name_or_path "$FINAL_PLM" \
    --per_device_eval_batch_size "$PER_DEVICE_INFERENCE_BATCH_SIZE" \
    --query_path "$RESULT_DIR/beir/$DATASET_NAME/$MODEL_NAME/queries.test.augmented.json" \
    --use_t5_decoder \
    --query_template "<text>" \
    --query_column_names id,text,grounds,ground_ids \
    --use_ground True \
    --q_max_len "$Q_MAX_LEN" \
    --fp16 \
    --trec_save_path "$RESULT_DIR/beir/$DATASET_NAME/$MODEL_NAME/test_augmented.trec" \
    --dataloader_num_workers 1 \
    --ground_passage_num "$GROUND_DOC_NUM"

# arguana/quora: post-process the final ranking the same way as the grounding one.
case "$DATASET_NAME" in
    arguana|quora)
        python "$BEIR_DIR/remove_same_qd.py" \
            --trec "$RESULT_DIR/beir/$DATASET_NAME/$MODEL_NAME/test_augmented.trec" \
            --save_to "$RESULT_DIR/beir/$DATASET_NAME/$MODEL_NAME/test_augmented_cleaned.trec"
        ;;
esac

echo "$DATASET_NAME"
echo "$MODEL_NAME"

# Score both rankings (MRR@10, nDCG@10). For arguana/quora the *_cleaned
# rankings produced by remove_same_qd.py are scored instead of the raw ones.
case "$DATASET_NAME" in
    arguana|quora) trec_suffix="_cleaned" ;;
    *) trec_suffix="" ;;
esac
run_dir="$RESULT_DIR/beir/$DATASET_NAME/$MODEL_NAME"
qrels="$COLLECTION_DIR/$DATASET_NAME/qrel.test.trec"

echo "scoring grounding model..."
"$EVAL_DIR/trec_eval" -c -mrecip_rank.10 -mndcg_cut.10 "$qrels" "$run_dir/test${trec_suffix}.trec"

echo "scoring final model..."
"$EVAL_DIR/trec_eval" -c -mrecip_rank.10 -mndcg_cut.10 "$qrels" "$run_dir/test_augmented${trec_suffix}.trec"

