#!/usr/bin/env bash
# This script relies on bash-only syntax, so name the interpreter explicitly
# instead of depending on whichever shell happens to execute the file.

# Trace each command as it is executed so the launch line appears in the logs.
set -x

# Abort the script when the command that ran immediately before this call failed.
# Arguments: $1 - label of the step, included in the failure message
# Outputs:   prints "FAILED <label>" to stdout when the previous status was non-zero
# Returns:   0 when the previous status was 0; otherwise exits the shell with status 1
checkSuccess() {
    # Must be the first statement: any earlier command would overwrite $?.
    local prev_status=$?

    if (( prev_status == 0 )); then
        return 0
    fi

    echo "FAILED $1"
    exit 1
}

# Output location handed to the inference run via --output_path.
OUTPUT_PATH=./datas/ultrafeedback_binarized/llama-3.1-8b-instruct-train-prefs_generations

# Batch inference command line: the module name followed by its arguments.
# It is stored as an array so each argument reaches the launcher as exactly one
# word; building a string and expanding it unquoted would split or glob-expand
# any value (e.g. OUTPUT_PATH) that contains whitespace or wildcard characters.
batch_infer_args=(
    openrlhf.cli.batch_inference
    --eval_task self_multi_generate
    --pretrain ./checkpoint/llama-3.1-8b-instruct-sft
    --bf16
    --prompt_max_len 4096
    --max_new_tokens 1024
    --dataset HuggingFaceH4/ultrafeedback_binarized
    --input_key prompt
    --apply_chat_template
    --zero_stage 0
    --micro_batch_size 2
    --output_path "$OUTPUT_PATH"
    --dataset_split train_prefs
)

# Flat, space-joined form of the same command line, kept under the original
# variable name for any code that still expands it as a string.
batch_infer_commands="${batch_infer_args[*]}"

# Run batch inference through the deepspeed launcher.
deepspeed --module "${batch_infer_args[@]}"
# Alternative without the deepspeed launcher:
# CUDA_VISIBLE_DEVICES=0 python -m "${batch_infer_args[@]}"
# checkSuccess reads $?, so it must directly follow the command it checks.
checkSuccess "Batch Inference"
