# Usage: <this script> DATASET
#   DATASET is one of: LF-AmazonTitles-131K, LF-WikiSeeAlso-320K,
#                      LF-Wikipedia-500K, LF-AmazonTitles-1.3M
DATA=${1:-}
MODEL=distilbert-base-uncased

if [[ -z "${DATA}" ]]; then
	echo "Usage: $0 DATASET (LF-AmazonTitles-131K | LF-WikiSeeAlso-320K | LF-Wikipedia-500K | LF-AmazonTitles-1.3M)" >&2
	exit 2
fi

# Per-dataset hyper-parameters.
# trn_bsz (TRN_BATCH_SIZE, per device) values are tuned on P4d.24xl (40GB GPU mem).
case "${DATA}" in
	LF-AmazonTitles-131K)
		Q_MAX_LEN=32;  P_MAX_LEN=32; TRN_BATCH_SIZE=896
		MAX_STEPS=3000; HNM_STEPS=1000; SAVE_STEPS=1000
		HNM_TOPK=50; LR=3e-4
		;;
	LF-WikiSeeAlso-320K)
		Q_MAX_LEN=128; P_MAX_LEN=32; TRN_BATCH_SIZE=576
		MAX_STEPS=3000; HNM_STEPS=1000; SAVE_STEPS=1000
		HNM_TOPK=50; LR=2e-4
		;;
	LF-Wikipedia-500K)
		Q_MAX_LEN=192; P_MAX_LEN=32; TRN_BATCH_SIZE=400
		MAX_STEPS=28000; HNM_STEPS=7000; SAVE_STEPS=2000
		HNM_TOPK=25; LR=2e-4
		;;
	LF-AmazonTitles-1.3M)
		Q_MAX_LEN=32;  P_MAX_LEN=32; TRN_BATCH_SIZE=896
		MAX_STEPS=40000; HNM_STEPS=8000; SAVE_STEPS=2000
		HNM_TOPK=50; LR=3e-4
		;;
	*)
		# Unknown dataset: report on stderr and fail with a non-zero status
		# (a bare `exit` here would return echo's status, i.e. success).
		echo "DATA=${DATA} is not support yet!" >&2
		exit 1
		;;
esac
# Dataset locations: the raw XMC directory (Y.trn.npz, Y.tst.npz,
# filter_labels_test.txt) and the preprocessed trn/tst/lbl folders.
XMC_DATA_DIR="./datasets/${DATA}"
PRC_DATA_DIR="./proc_data/${DATA}"

# Hyper-parameters shared by every dataset.
HNM_TYPE=q2z                 # passed as --hnm_type
MAX_LABEL_PER_QUERY=40
TRN_GROUP_SIZE=3
TRN_TEMPERATURE=0.04         # --temperature during training
INF_TEMPERATURE=0.04         # --temperature during inference
TST_BATCH_SIZE=2048          # per-device eval batch size
INFERENCE_METHOD=q2xz
INFERENCE_TOPK=200

# Experiment tag built from the per-dataset settings,
# e.g. RAEXMC_b896_step3000_hnm-1000-50 for LF-AmazonTitles-131K.
printf -v EXP_VERSION 'RAEXMC_b%s_step%s_hnm-%s-%s' \
	"${TRN_BATCH_SIZE}" "${MAX_STEPS}" "${HNM_STEPS}" "${HNM_TOPK}"

# Training: run sup_con_xmc.build_de_hnm under torchrun on a single node,
# writing the model and train.log into ${MODEL_PATH}/model. The step is
# skipped when ${MODEL_PATH}/model/model.safetensors already exists.
# NPROC_PER_NODE (environment, default 8) sets the number of processes.
MODEL_PATH="./output/${MODEL}/${DATA}/${EXP_VERSION}"
mkdir -p "${MODEL_PATH}/model" || exit 1

if [[ ! -f "${MODEL_PATH}/model/model.safetensors" ]]; then
	# Output is mirrored to train.log; `tee` hides torchrun's exit status,
	# so it is read from PIPESTATUS right after the pipeline.
	torchrun --nnodes 1 --nproc-per-node "${NPROC_PER_NODE:-8}" \
		-m sup_con_xmc.build_de_hnm --fp16 \
		--model_name_or_path "${MODEL}" \
		--output_dir "${MODEL_PATH}/model" \
		--lbl_folder "${PRC_DATA_DIR}/lbl" \
		--trn_folder "${PRC_DATA_DIR}/trn" \
		--y_npz_path "${XMC_DATA_DIR}/Y.trn.npz" \
		--dataloader_num_workers 4 \
		--q_max_len "${Q_MAX_LEN}" \
		--p_max_len "${P_MAX_LEN}" \
		--per_device_train_batch_size "${TRN_BATCH_SIZE}" \
		--per_device_eval_batch_size "${TST_BATCH_SIZE}" \
		--train_group_size "${TRN_GROUP_SIZE}" \
		--temperature "${TRN_TEMPERATURE}" \
		--learning_rate "${LR}" \
		--max_steps "${MAX_STEPS}" \
		--hnm_steps "${HNM_STEPS}" \
		--hnm_topk "${HNM_TOPK}" \
		--hnm_type "${HNM_TYPE}" \
		--save_steps "${SAVE_STEPS}" \
		--logging_steps 20 \
		--negatives_x_device \
		--train_dual True \
		--use_q_neg True \
		--max_label_per_query "${MAX_LABEL_PER_QUERY}" \
		--overwrite_output_dir \
		|& tee "${MODEL_PATH}/model/train.log"
	TRAIN_STATUS=${PIPESTATUS[0]}
	# Abort on failure so inference/evaluation never run against a missing
	# or partially trained checkpoint.
	if [[ ${TRAIN_STATUS} -ne 0 ]]; then
		echo "Training failed (exit ${TRAIN_STATUS}); see ${MODEL_PATH}/model/train.log" >&2
		exit "${TRAIN_STATUS}"
	fi
fi

# Inference: run sup_con_xmc.searcher under torchrun with the trained model in
# ${MODEL_PATH}/model, writing its output into ${INDEX_PATH}.
# NPROC_PER_NODE (environment, default 8) sets the number of processes.
INDEX_PATH="${MODEL_PATH}/indexer_${INFERENCE_METHOD}"
mkdir -p "${INDEX_PATH}" || exit 1
# Stop here if the searcher fails, so evaluation does not score stale or
# missing predictions.
torchrun --nnodes 1 --nproc-per-node "${NPROC_PER_NODE:-8}" \
	-m sup_con_xmc.searcher --fp16 \
	--model_name_or_path "${MODEL_PATH}/model" \
	--per_device_eval_batch_size "${TST_BATCH_SIZE}" \
	--trn_folder "${PRC_DATA_DIR}/trn" \
	--tst_folder "${PRC_DATA_DIR}/tst" \
	--lbl_folder "${PRC_DATA_DIR}/lbl" \
	--inp_key_col "qid" \
	--lbl_key_col "lid" \
	--text_max_len 256 \
	--dataloader_num_workers 4 \
	--output_dir "${INDEX_PATH}" \
	--overwrite_output_dir \
	--temperature "${INF_TEMPERATURE}" \
	--y_npz_path "${XMC_DATA_DIR}/Y.trn.npz" \
	--inference_topk "${INFERENCE_TOPK}" \
	--inference_method "${INFERENCE_METHOD}" \
	|| { echo "Inference (sup_con_xmc.searcher) failed; skipping evaluation." >&2; exit 1; }

# Evaluation: run sup_con_xmc.evaluate on the predictions
# ${INDEX_PATH}/P.k-${INFERENCE_TOPK}.npz (presumably written by the inference
# step), passing Y.trn.npz, Y.tst.npz and filter_labels_test.txt from the
# dataset directory. Output is mirrored to eval_xmc.log.
# Fail loudly if the predictions are missing: otherwise the error would be
# masked by `tee` and the script would still exit 0.
if [[ ! -f "${INDEX_PATH}/P.k-${INFERENCE_TOPK}.npz" ]]; then
	echo "Missing predictions ${INDEX_PATH}/P.k-${INFERENCE_TOPK}.npz; run inference first." >&2
	exit 1
fi
python -m sup_con_xmc.evaluate \
	-d "${DATA}" \
	-t "${XMC_DATA_DIR}/Y.trn.npz" \
	-y "${XMC_DATA_DIR}/Y.tst.npz" \
	-p "${INDEX_PATH}/P.k-${INFERENCE_TOPK}.npz" \
	-f "${XMC_DATA_DIR}/filter_labels_test.txt" \
	|& tee "${INDEX_PATH}/eval_xmc.log"
