# export HF_ENDPOINT=https://hf-mirror.com
# trl dpo --model_name_or_path /mnt/data/***/hf_models/gpt2 --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style --output_dir tmp 

# Run the wandb CLI's `disabled` subcommand before training starts; the
# training launch in this script additionally passes `--report_to none`.
wandb disabled

#######################################
# Detect how many GPUs are visible and store the count in NUM_GPUS.
#
# Only lines of `nvidia-smi -L` output that start with "GPU " are counted, so
# any other line nvidia-smi prints (for example an error message when it
# cannot list devices) is not mistaken for a GPU.
#
# Globals:   NUM_GPUS (written) - GPU count, or -1 when nvidia-smi is missing
# Arguments: none
# Outputs:   the detected count on stdout; the "unavailable" notice on stderr
# Returns:   0
#######################################
detect_num_gpus() {
    if ! command -v nvidia-smi &> /dev/null; then
        # Keep the -1 sentinel so later code can tell that detection did not run.
        NUM_GPUS=-1
        echo "nvidia-smi 命令不可用，请确保 NVIDIA 驱动已安装。" >&2
        return 0
    fi

    # grep -c prints 0 (and exits non-zero) when nothing matches; only the
    # printed count matters here.
    NUM_GPUS=$(nvidia-smi -L | grep -c '^GPU ')
    echo "发现 ${NUM_GPUS} 块 GPU。"
    return 0
}

detect_num_gpus

# export HF_ENDPOINT=https://hf-mirror.com
# accelerate launch --config_file=accelerate_configs/multi_gpu.yaml --num_processes ${NUM_GPUS} dpo.py \
#     --dataset_name=trl-internal-testing/hh-rlhf-helpful-base-trl-style \
#     --model_name_or_path=/mnt/data/***/hf_models/gpt2 \
#     --per_device_train_batch_size 4 \
#     --learning_rate 1e-3 \
#     --gradient_accumulation_steps 1 \
#     --logging_steps 3 \
#     --eval_steps 10 \
#     --output_dir=tmp \
#     --sanity_check \
#     --warmup_steps 5 \
#     --report_to wandb \
#     --bf16 \
#     --logging_first_step \
#     --no_remove_unused_columns

# ---------------------------------------------------------------------------
# Run configuration.
# One assignment per setting is active; the commented-out lines are
# alternative values kept for quick switching.
# ---------------------------------------------------------------------------

# Accelerate config: base name of a YAML file under accelerate_configs/.
# config_file=multi_gpu
config_file="deepspeed_zero2"

# Base model: short name (part of the run tag) and checkpoint path.
# model_name=gpt2
# model_path=/mnt/data/***/hf_models/gpt2
model_name="qwen"
model_path="/mnt/data/***/hf_models/Qwen2-7B-Instruct"

# Dataset: short name (part of the run tag) and the path derived from it.
# dataset_name=toy
# dataset_name=goldn
dataset_name="weak"
dataset_path="/mnt/data/***/multiagent_doc2graph/train_router/data/${dataset_name}"
# dataset_path=/mnt/data/***/multiagent_doc2graph/train_router/goldn

# Alternative dataset settings (TRL-style HH-RLHF data).
# dataset_name=trlstyle
# dataset_path=trl-internal-testing/hh-rlhf-helpful-base-trl-style
# dataset_path=/mnt/data/***/multiagent_doc2graph/train_router/hh-rlhf-helpful-base-trl-style

# Run tag: <dataset>_<model>_<accelerate config>; names the output dir and log.
tag="${dataset_name}_${model_name}_${config_file}"

printf '%s\n' "dataset_name ${dataset_name}, tag ${tag}"

# Launch DPO training (dpo.py) through `accelerate launch` with the selected
# accelerate config and --num_processes taken from NUM_GPUS. Stdout and stderr
# of the run are both written to log/train_log/<tag>.log.
#
# Every expansion is quoted so each value reaches the command as one literal
# argument (no word splitting or glob expansion of the paths).
log_dir="log/train_log"
log_file="${log_dir}/${tag}.log"

# The output redirection below fails, and the launch never starts, if the log
# directory does not exist, so create it first.
if ! mkdir -p -- "${log_dir}"; then
    echo "Cannot create log directory ${log_dir}." >&2
    exit 1
fi

accelerate launch \
    --config_file "accelerate_configs/${config_file}.yaml" \
    --num_processes "${NUM_GPUS}" \
    dpo.py \
    --dataset_name "${dataset_path}" \
    --model_name_or_path "${model_path}" \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --learning_rate 1e-5 \
    --gradient_accumulation_steps 8 \
    --logging_steps 3 \
    --eval_steps 5 \
    --output_dir "/mnt/data/***/multiagent_doc2graph/train_router/output_model/${tag}" \
    --warmup_steps 5 \
    --report_to none \
    --bf16 \
    --logging_first_step \
    --max_prompt_length 512 \
    --max_length 512 \
    --no_remove_unused_columns > "${log_file}" 2>&1
train_status=$?

# Propagate a failed run to the caller instead of always reporting success.
if (( train_status != 0 )); then
    echo "Training failed with exit status ${train_status}; see ${log_file}." >&2
    exit "${train_status}"
fi

echo "Done."

# dpo_router, 2, bash -c "cd /mnt/data/***/multiagent_doc2graph/train_router && source /mnt/data/***/anaconda3_not_root/bin/activate doc2graph && bash train.sh"