#!/bin/bash

# --- 1. Core parameters ---
# Output directory: handed to the trainer as --output_dir and also used as the
# parent directory of the log file defined at the end of this section.
OUT_DIR="exp/twnm_grpo_v1"

# TWNM decoder model and SFT LoRA checkpoint locations.
# NOTE(review): these two variables are neither exported nor passed on the
# torchrun command line in this script -- confirm how train_grpo.py is meant
# to receive them.
DECODER_MODEL_PATH="<PATH_TO_TWNM>/assets/checkpoints/qwen2-audio-llm-extracted"
SFT_LORA_CKPT_PATH="<PATH_TO_TWNM>/exp/SFT2/checkpoint-1251/pytorch_model.bin"

# Spatial encoder checkpoint (currently disabled).
# SPATIAL_ENCODER_PATH="/path/to/your/spatial_encoder.ckpt"

# Training data file, passed to the trainer as --data_file.
DATA_FILE="/data2/wl/RL_benchmark/output/benchmark/benchmark_questions.jsonl"

# DeepSpeed configuration file.
# NOTE(review): not referenced by the launch command in this script -- confirm
# whether it should be passed to the trainer.
DS_CONFIG="./ds_zero2.json"


# --- 2. Distributed environment ---
# Address and fixed port forwarded to torchrun as --master_addr/--master_port.
MASTER_ADDR="127.0.0.1"
MASTER_PORT=32777

# Single-node topology: node count, this node's rank, and processes per node.
NODE_NUM=1
NODE_RANK=0
# Alternative: derive the process count from the GPUs visible on this machine.
# GPU_NUM=$(nvidia-smi -L | wc -l)
GPU_NUM=8

# Timestamped log file inside the output directory (one file per launch).
LOG_FILE="${OUT_DIR}/training_$(date '+%Y%m%d-%H%M%S').log"

# export CUDA_VISIBLE_DEVICES=2,3,4,5,6,7

# --- 3. Launch training ---
# Runs train_grpo.py under torchrun with GPU_NUM processes on this node and
# mirrors the combined stdout/stderr to both the terminal and LOG_FILE.
#
# Reads:  GPU_NUM, NODE_NUM, NODE_RANK, MASTER_ADDR, MASTER_PORT,
#         OUT_DIR, DATA_FILE, LOG_FILE
# Exits:  1 if the log directory cannot be created or if any stage of the
#         torchrun | tee pipeline fails.

# Without pipefail the pipeline's status is tee's, so a crashed training run
# would be reported as success and the trailing `|| exit 1` would never fire.
set -o pipefail

# tee cannot open the log file unless its parent directory already exists.
mkdir -p -- "$(dirname -- "${LOG_FILE}")" || exit 1

# All expansions are quoted so paths containing spaces or glob characters are
# passed through as single arguments.
torchrun --nproc_per_node="${GPU_NUM}" \
    --nnodes="${NODE_NUM}" \
    --node-rank="${NODE_RANK}" \
    --master_addr="${MASTER_ADDR}" \
    --master_port="${MASTER_PORT}" \
    train_grpo.py \
    --output_dir "${OUT_DIR}" \
    --data_file "${DATA_FILE}" \
    --bf16 true \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 3 \
    --num_train_epochs 5 \
    --learning_rate 1e-5 \
    --num_generations 3 \
    --logging_strategy="steps" \
    --logging_steps=1 \
    --save_strategy="steps" \
    --save_steps=1 \
    --save_total_limit=20 \
    --save_safetensors=False \
    --gradient_checkpointing true 2>&1 | tee "${LOG_FILE}" || exit 1
    # --resume_from_checkpoint exp/twnm_grpo_v1/checkpoint-500 \
    