# https://percent4.github.io/NLP%EF%BC%88%E4%B8%80%E7%99%BE%E9%9B%B6%E4%BA%8C%EF%BC%89ReRank%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83%E5%AE%9E%E8%B7%B5/
# pip install -U FlagEmbedding[finetune]

# Index of the GPU to train on. Used only inside this script (for the GPU wait
# step and CUDA_VISIBLE_DEVICES), so it is intentionally not exported.
DEVICE_ID=1

# CUDA toolkit location. These are exported so that the child processes started
# by this script actually see them; a plain assignment only reaches children
# when the variable already existed in the inherited environment.
export CUDA_HOME=/usr/local/cuda-12.4
export PATH="$CUDA_HOME/bin:$PATH"
# Append the previous value only when it is non-empty: a trailing ':' would add
# an empty entry to the library search path.
export LD_LIBRARY_PATH="$CUDA_HOME/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"

# Order CUDA devices by PCI bus id so that DEVICE_ID refers to the same
# numbering nvidia-smi uses (assumed intent -- confirm).
export CUDA_DEVICE_ORDER=PCI_BUS_ID
# Turn off Weights & Biases reporting for this run.
export WANDB_DISABLED=true

# Run the project's wait-for-gpu helper for the target GPU before training.
# NOTE(review): presumably it polls every 10 s and returns once the GPU looks
# idle on 2 consecutive checks -- confirm against src/utils.
# Abort if the helper fails, instead of silently starting training anyway.
if ! uv run -m src.utils wait-for-gpu \
  --gpu-id "$DEVICE_ID" \
  --check-interval 10 \
  --consecutive-counts 2; then
  echo "wait-for-gpu failed for GPU $DEVICE_ID; aborting before training" >&2
  exit 1
fi

# Training configuration: every path and hyperparameter used by this script.

# --- Paths: training set, base checkpoint, fine-tuned output directory ---
TRAIN_DATA="data/rerank_train/all.rerank.train.results.0704.15.jsonl"
MODEL_NAME_OR_PATH="./ckpt/bge-reranker-v2-m3"
OUTPUT_DIR="./ckpt/bge-reranker-v2-m3-0704-10"

# --- Optimizer ---
LEARNING_RATE="4e-5"
WEIGHT_DECAY="0.01"  # weight decay
WARMUP_RATIO="0.1"   # warm-up ratio

# --- Schedule and batching ---
NUM_TRAIN_EPOCHS="10"              # number of training epochs
PER_DEVICE_TRAIN_BATCH_SIZE="1"    # training batch size per device
GRADIENT_ACCUMULATION_STEPS="8"    # gradient accumulation steps

# --- Data shaping ---
# Training group size; the original note describes it as the number of
# negative samples paired with each positive sample.
TRAIN_GROUP_SIZE="16"
QUERY_MAX_LEN="256"    # maximum query length
PASSAGE_MAX_LEN="256"  # maximum passage length

# Create the output directory and record this run's hyperparameters in
# train_params.txt inside it.
# The directory path is quoted so a path containing spaces or glob characters
# is not split or expanded by the shell.
mkdir -p -- "$OUTPUT_DIR"
cat << EOF > "$OUTPUT_DIR/train_params.txt"
learning_rate: $LEARNING_RATE
num_train_epochs: $NUM_TRAIN_EPOCHS
train_data: $TRAIN_DATA
per_device_train_batch_size: $PER_DEVICE_TRAIN_BATCH_SIZE
gradient_accumulation_steps: $GRADIENT_ACCUMULATION_STEPS
train_group_size: $TRAIN_GROUP_SIZE
query_max_len: $QUERY_MAX_LEN
passage_max_len: $PASSAGE_MAX_LEN
weight_decay: $WEIGHT_DECAY
warmup_ratio: $WARMUP_RATIO
EOF

# Launch the FlagEmbedding encoder-only reranker fine-tuning module through
# torch.distributed.run with a single process, restricted to GPU $DEVICE_ID
# via CUDA_VISIBLE_DEVICES.
# All tunables come from the variables defined earlier in this script. In
# particular --warmup_ratio uses $WARMUP_RATIO rather than a literal, so the
# value actually used cannot drift from the one written to train_params.txt.
# Expansions are quoted so paths with spaces or glob characters stay intact.
CUDA_VISIBLE_DEVICES="$DEVICE_ID" uv run -m torch.distributed.run --nproc_per_node 1 \
  -m FlagEmbedding.finetune.reranker.encoder_only.base \
  --output_dir "$OUTPUT_DIR" \
  --model_name_or_path "$MODEL_NAME_OR_PATH" \
  --train_data "$TRAIN_DATA" \
  --learning_rate "$LEARNING_RATE" \
  --fp16 \
  --num_train_epochs "$NUM_TRAIN_EPOCHS" \
  --per_device_train_batch_size "$PER_DEVICE_TRAIN_BATCH_SIZE" \
  --gradient_accumulation_steps "$GRADIENT_ACCUMULATION_STEPS" \
  --dataloader_drop_last True \
  --train_group_size "$TRAIN_GROUP_SIZE" \
  --query_max_len "$QUERY_MAX_LEN" \
  --passage_max_len "$PASSAGE_MAX_LEN" \
  --weight_decay "$WEIGHT_DECAY" \
  --save_total_limit 4 \
  --warmup_ratio "$WARMUP_RATIO" \
  --logging_steps 10
