## *************************************
## Paths ("xx" values are placeholders to fill in before running)
## *************************************
# Root holding one subdirectory per dataset (used as ${DATA_DIR}/<dataset>).
DATA_DIR=xx
# Root for the base model, the training run's residual encoder checkpoint,
# and all inference outputs.
OUTPUT_DIR=xx
## *************************************
## Custom Input
## *************************************
# Passed to infer_beir.py as --residual_num_layer.
export residual_num_layer=3
# Base model directory, resolved as ${OUTPUT_DIR}/${model_name_or_path}.
export model_name_or_path=gpt-j-6b
# Training run whose residual_encoder.ckpt is loaded for inference.
export train_job_name=xx
# Inference outputs are written under ${OUTPUT_DIR}/${infer_job_name}.
infer_job_name=inference-beir.${train_job_name}
## *************************************
## *************************************
## Runtime Setup
## *************************************
# Comma-separated GPU ids given to deepspeed (--include localhost:<ids>).
export TOT_CUDA="0,1,2,3"
# Split the GPU list on commas/spaces without relying on unquoted word
# splitting (which would also glob-expand each field).
IFS=', ' read -r -a CUDAs <<< "${TOT_CUDA}"
# Number of GPUs in TOT_CUDA (not referenced elsewhere in this script).
CUDA_NUM=${#CUDAs[@]}
# deepspeed --master_port.
PORT="1234"
# GPUs made visible to the FAISS retrieval step.
export FAISS_CUDA="0,1,2,3"
# Default per-device eval batch size; individual datasets may use a smaller one.
eval_batch_size=16
# Passed as --sub_split_num to inference and --index_num to retrieval.
SplitNum=5
## *************************************

# Datasets processed, in this order, by the inference loop.
dataset_name_list=(trec-covid nfcorpus fiqa arguana scifact scidocs)

# ## **************************************************************************
# ## Starting Inference
# ## **************************************************************************

#######################################
# Set the inference settings for one dataset.
# Every value is reset to its default on each call, so an override made for
# one dataset never leaks into the datasets processed after it.
# Globals:
#   eval_batch_size (read)        - default per-device batch size
#   q_max_len, p_max_len,
#   remove_identical_qid_docid    (written, exported)
#   dataset_eval_batch_size       (written) - batch size for this dataset
# Arguments:
#   $1 - dataset name
#######################################
configure_dataset() {
    local name=$1

    export q_max_len=64
    export p_max_len=128
    export remove_identical_qid_docid=0
    dataset_eval_batch_size=${eval_batch_size}

    case "${name}" in
        arguana)
            export q_max_len=128
            export remove_identical_qid_docid=1
            ;;
        quora)
            export remove_identical_qid_docid=1
            ;;
        scifact | trec-news | robust04)
            # Longer passages: halve the batch for this dataset only.
            export p_max_len=256
            dataset_eval_batch_size=8
            ;;
    esac
}

# For each dataset: encode queries/corpus with deepspeed, retrieve the top 100
# passages with FAISS, convert the ranking to TREC format, evaluate it, and
# delete the embeddings once a ranking has been produced.
for dataset_name in "${dataset_name_list[@]}"
do
    dataset_out_dir="${OUTPUT_DIR}/${infer_job_name}/${dataset_name}"
    mkdir -p "${dataset_out_dir}/query"
    mkdir -p "${dataset_out_dir}/corpus"

    configure_dataset "${dataset_name}"

    echo "infer ${dataset_name} ..."
    echo "q_max_len: ${q_max_len}"
    echo "p_max_len: ${p_max_len}"
    echo "remove_identical_qid_docid: ${remove_identical_qid_docid}"
    echo "eval_batch_size: ${dataset_eval_batch_size}"

    deepspeed --include "localhost:${TOT_CUDA}" --master_port "${PORT}" ../ancetele/infer_beir.py \
        --deepspeed ./deepspeed_configs/ds_config_zero3.json \
        --output_dir "${dataset_out_dir}" \
        --model_name_or_path "${OUTPUT_DIR}/${model_name_or_path}" \
        --residual_encoder_name_or_path "${OUTPUT_DIR}/${train_job_name}/residual_encoder.ckpt" \
        --fp16 \
        --q_max_len "${q_max_len}" \
        --p_max_len "${p_max_len}" \
        --per_device_eval_batch_size "${dataset_eval_batch_size}" \
        --dataloader_num_workers 1 \
        --eval_dir "${DATA_DIR}/${dataset_name}" \
        --sub_split_num "${SplitNum}" \
        --residual_num_layer "${residual_num_layer}" \
        --cosine_scale 1

    ## *************************************
    ## Search Dev (GPU/CPU)
    ## *************************************
    # The corpus glob is quoted on purpose: the retriever expands it itself.
    retrieval_status=0
    CUDA_VISIBLE_DEVICES="${FAISS_CUDA}" python ../ancetele/faiss_retriever/do_retrieval.py \
        --query_reps "${dataset_out_dir}/query/qry.pt" \
        --passage_reps "${dataset_out_dir}/corpus/*.pt" \
        --index_num "${SplitNum}" \
        --batch_size "${dataset_eval_batch_size}" \
        --save_text \
        --depth 100 \
        --save_ranking_to "${dataset_out_dir}/test.rank.tsv" \
        --use_gpu || retrieval_status=$?

    python ../scripts/convert_result_to_trec.py \
        --input "${dataset_out_dir}/test.rank.tsv" \
        --remove_identical_qid_docid "${remove_identical_qid_docid}"

    # Reads the .teIn file, presumably written by convert_result_to_trec.py.
    python ../scripts/eval_beir_ndcg.py \
        --dataset_name "${dataset_name}" \
        --qrels_path "${DATA_DIR}/${dataset_name}/qrels/test.tsv" \
        --trec_path "${dataset_out_dir}/test.rank.tsv.teIn" \
        --log_path "${OUTPUT_DIR}/${infer_job_name}/res.tsv"

    # ## *************************************
    # ## (5) delete embedding
    # ## *************************************
    # Require a successful retrieval as well as a non-empty ranking, so a stale
    # test.rank.tsv from an earlier run cannot cause the embeddings (needed to
    # retry) to be deleted after a failed retrieval.
    if (( retrieval_status == 0 )) && [[ -s "${dataset_out_dir}/test.rank.tsv" ]]
    then
        echo "Successfully saved trec file! Delete embedding files :) "
        rm -rf -- "${dataset_out_dir}/query" "${dataset_out_dir}/corpus"
    else
        echo "There are some troubles in saving trec file ..."
    fi

done