#!/usr/bin/env bash
# Launch a multi-turn GRPO training run (sglang rollout with an interaction
# config) through verl's PPO trainer entry point.
#
# Usage:
#   <script> <nproc_per_node> <save_path> <hf_model_path> [train_parquet] [val_parquet]
#
#   nproc_per_node  number of processes torchrun starts on this node
#   save_path       trainer.default_local_dir (checkpoint/output directory)
#   hf_model_path   actor/rollout/ref model path
#   train_parquet   defaults to $HOME/data/gsm8k/train.parquet
#   val_parquet     defaults to $HOME/data/gsm8k/test.parquet
#
# Config paths are resolved relative to the current directory, so run this
# script from the project root.
set -euo pipefail
set -x

if [[ "$#" -lt 3 ]]; then
  # Diagnostics belong on stderr so they never pollute captured stdout.
  printf 'Usage: %s <nproc_per_node> <save_path> <hf_model_path> [train_parquet] [val_parquet]\n' "$0" >&2
  exit 1
fi

NPROC_PER_NODE="$1"
SAVE_PATH="$2"
MODEL_PATH="$3"
# Optional dataset overrides; fall back to the GSM8K parquet files.
TRAIN_PARQUET="${4:-$HOME/data/gsm8k/train.parquet}"
VAL_PARQUET="${5:-$HOME/data/gsm8k/test.parquet}"

# Fail fast with a clear message rather than letting torchrun reject a bad
# process count after the environment has already been set up.
if [[ ! "$NPROC_PER_NODE" =~ ^[1-9][0-9]*$ ]]; then
  printf 'Error: <nproc_per_node> must be a positive integer, got %s\n' "'${NPROC_PER_NODE}'" >&2
  exit 1
fi

# Print full Hydra stack traces when config composition fails.
export HYDRA_FULL_ERROR=1

# Raise the open-file soft limit. Under `set -e` a bare `ulimit -n 65535`
# aborts the whole script when the hard limit is lower (common in
# containers), so fall back to the hard limit and warn instead.
if ! ulimit -n 65535 2>/dev/null; then
  printf 'Warning: could not raise open-file limit to 65535; using hard limit %s\n' "$(ulimit -Hn)" >&2
  ulimit -n "$(ulimit -Hn)"
fi

# Paths are resolved relative to the current directory, so the script must
# be launched from the project root; check early instead of failing inside
# the trainer with a less obvious config-not-found error.
PROJECT_DIR="$(pwd)"
CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config"
if [[ ! -d "$CONFIG_PATH" ]]; then
  printf 'Error: config directory not found: %s (run this script from the project root)\n' "$CONFIG_PATH" >&2
  exit 1
fi

# Run verl's PPO trainer entry point with GRPO advantage estimation and
# sglang multi-turn rollouts capped at 5 assistant / 5 user turns. The
# key=value arguments after --config-name are Hydra-style overrides applied
# on top of ver_k_retry_multiturn_grpo_w_interaction from CONFIG_PATH.
# The interaction config lives under the same config directory, so derive
# its path from CONFIG_PATH rather than repeating the prefix.
#
# NOTE(review): main_ppo is started under torchrun with NPROC_PER_NODE
# processes, and trainer.n_gpus_per_node is not overridden here; confirm
# this launcher/resource pairing is what this entry point expects.
torchrun --nproc_per_node="${NPROC_PER_NODE}" -m verl.trainer.main_ppo \
  --config-path="${CONFIG_PATH}" \
  --config-name="ver_k_retry_multiturn_grpo_w_interaction" \
  algorithm.adv_estimator=grpo \
  data.train_files="${TRAIN_PARQUET}" \
  data.val_files="${VAL_PARQUET}" \
  data.train_batch_size=64 \
  data.max_prompt_length=512 \
  data.max_response_length=2048 \
  data.return_raw_chat=true \
  actor_rollout_ref.model.path="${MODEL_PATH}" \
  actor_rollout_ref.rollout.name=sglang \
  actor_rollout_ref.rollout.multi_turn.enable=true \
  actor_rollout_ref.rollout.multi_turn.interaction_config_path="${CONFIG_PATH}/interaction_config/ver_k_retry_interaction_config.yaml" \
  actor_rollout_ref.rollout.multi_turn.max_assistant_turns=5 \
  actor_rollout_ref.rollout.multi_turn.max_user_turns=5 \
  trainer.default_local_dir="${SAVE_PATH}" \
  trainer.total_epochs=1 \
  trainer.logger='[console]'
