#!/usr/bin/env bash
# ------------- Environment variables -------------
# Stop at the first command that exits non-zero.
# NOTE(review): only errexit is enabled (no nounset), so a missing ARNOLD_*
# variable is not caught here: an unset ARNOLD_WORKER_GPU evaluates as 0 in
# the arithmetic below and the other ARNOLD_* values export as empty strings.
set -o errexit

# Switches exported to every child process.
export NCCL_DEBUG=WARN
export VLLM_USE_V1=1

# Batch size and generation count; both are passed to the launch command.
NUM_GENERATION=4
BATCH_SIZE=8

# GPUs used for training = GPUs reported for this worker minus the number
# set aside as inference workers (currently none).
NUM_INFER_WORKERS=0
NUM_GPU_FOR_TRAIN=$((ARNOLD_WORKER_GPU - NUM_INFER_WORKERS))

# Rendezvous settings copied from the ARNOLD_* variables (per the original
# note, these are injected automatically by Arnold).
export MASTER_ADDR="$ARNOLD_WORKER_0_HOST" MASTER_PORT="$ARNOLD_WORKER_0_PORT"
export NNODES="$ARNOLD_WORKER_NUM" NODE_RANK="$ARNOLD_ID"

# Process-count settings, all taken from the training GPU count.
# NOTE(review): WORLD_SIZE equals the per-node count and does not factor in
# NNODES — confirm this is intended for multi-node runs.
export NPROC_PER_NODE="$NUM_GPU_FOR_TRAIN" \
       LOCAL_WORLD_SIZE="$NUM_GPU_FOR_TRAIN" \
       WORLD_SIZE="$NUM_GPU_FOR_TRAIN"


# ------------- Paths (from the original script) -------------
# All paths are relative, so they resolve against the directory the script
# is started from.

# Checkpoint handed to the trainer via --model.
MODEL_PATH='ckpts/sft_ckpts/critic-sft-merged-new/v0-20250903-102221/checkpoint-1730'

# Training and validation data files.
VAL_DATA_PATH='data/grpo_val.jsonl'
TRAIN_DATA_PATH='data/grpo_train.jsonl'

# Where the trainer writes its output, and where this script keeps its log.
LOG_DIR='logs/dapo_logs'
OUTPUT_DIR='ckpts/dapo_ckpts/critic-sft-merged-new-1epoch'

# The launch command redirects into LOG_DIR, so it has to exist first.
mkdir -p "$LOG_DIR"

# ------------- Launch training -------------
# Runs `swift rlhf` once. The arguments are collected in an array (same
# words, same order as a single continued command line) so that they can be
# grouped and annotated. Flag semantics are defined by the swift CLI; the
# group labels below only reflect the flag names.
swift_args=(
    # Method, model, plugin and reward selection.
    --rlhf_type grpo
    --model "$MODEL_PATH"
    --model_type qwen2_audio
    --external_plugins scripts/grpo/plugin.py
    --reward_funcs external_mos_acc soft_overlong
    --train_type full
    --torch_dtype bfloat16

    # Datasets.
    --dataset "$TRAIN_DATA_PATH"
    --val_dataset "$VAL_DATA_PATH"

    # Epochs, batch sizes, generation count and optimiser step settings.
    --num_train_epochs 1
    --per_device_train_batch_size "$BATCH_SIZE"
    --per_device_eval_batch_size "$BATCH_SIZE"
    --num_generations "$NUM_GENERATION"
    --learning_rate 1e-6
    --gradient_accumulation_steps 4

    # Save / eval / logging intervals.
    --save_steps 500
    --eval_steps 1000
    --save_total_limit 2
    --logging_steps 1

    # Completion length and vLLM flags.
    --max_completion_length 512
    --use_vllm true
    --vllm_mode colocate
    --vllm_gpu_memory_utilization 0.6
    --vllm_max_model_len 8192
    --vllm_max_num_seqs 64

    # Output location, warmup and worker counts.
    --output_dir "$OUTPUT_DIR"
    --warmup_ratio 0.01
    --dataloader_num_workers 4
    --dataset_num_proc 4

    # Sampling temperature, completion logging, DeepSpeed and reporting.
    --temperature 1.0
    --log_completions true
    --deepspeed zero2
    --save_only_model True
    --report_to tensorboard

    # Loss variant plus resampling / overlong-handling flags.
    # NOTE(review): soft_cache_length 62 is an unusual value next to
    # max_completion_length 512 — confirm it is not a typo.
    --loss_type bnpo
    --epsilon_high 0.28
    --dynamic_sample true
    --max_resample_times 3
    --overlong_filter true
    --soft_cache_length 62
)

# stdout and stderr both go to one log file, which is truncated on each run.
train_log="$LOG_DIR/train-critic-sft-merged-new-1epoch.log"
swift rlhf "${swift_args[@]}" > "$train_log" 2>&1
