#!/usr/bin/env bash
#
# Launch Ver@K multi-turn GRPO training (verl.trainer.main_ppo with sglang
# rollout; default model Qwen/Qwen3-4B) on a single Slurm node, using a
# per-job CUDA MPS daemon.
#
# Usage (inside a Slurm allocation, e.g. via sbatch):
#   [VAR=value ...] sbatch <this-script> [extra Hydra overrides...]
#
# Most knobs (K, N, MODEL, TRAIN_FILE, VAL_FILE, ADV_ESTIMATOR, TEMP, ...) are
# read from environment variables of the same name and fall back to the
# defaults assigned below. Extra command-line arguments are forwarded verbatim
# to main_ppo after the built-in overrides.
set -euo pipefail
set -x

module load cuda/12.8
module load gcc/13.2.0

# Pin the CUDA toolkit explicitly. CUDA_HOME and TORCH_CUDA_ARCH_LIST are read
# by PyTorch's C++/CUDA extension builds; 8.9 targets the L40S GPUs.
export CUDA_HOME=/sw/pkgs/arc/cuda/12.8.1
export PATH=$CUDA_HOME/bin:$PATH
# Append the previous value only when it is non-empty. Under `set -u` an unset
# LD_LIBRARY_PATH would otherwise abort the script. An empty one would leave a
# trailing ':', which the dynamic loader treats as the current directory.
export LD_LIBRARY_PATH="$CUDA_HOME/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
export TORCH_CUDA_ARCH_LIST="8.9"  # L40S

# Full Python tracebacks from Hydra; silence HF tokenizers' fork-parallelism.
export HYDRA_FULL_ERROR=1
export TOKENIZERS_PARALLELISM=false

# Root of the verl checkout. The script runs from here, so relative paths used
# later (e.g. ./data/...) resolve against it.
: "${PROJECT_DIR:=/home/.../verification-k-verl/verl}"
# Hydra config directory handed to main_ppo via --config-path.
: "${CONFIG_PATH:=$PROJECT_DIR/examples/sglang_multiturn/config}"
cd "$PROJECT_DIR"

# require_positive_int NAME VALUE
#   Exit with an error unless VALUE is a positive decimal integer (no sign, no
#   leading zeros). K and N flow into shell arithmetic, file names and Hydra
#   overrides, so a typo should fail here, not halfway through the launch.
require_positive_int() {
  local name=$1 value=$2
  if [[ ! "$value" =~ ^[1-9][0-9]*$ ]]; then
    echo "$name must be a positive integer, got '$value'" >&2
    exit 1
  fi
}

# Ver@K / retry budget: maximum number of assistant turns (attempts) per
# sample. Also used as the user-turn cap and to size the response window.
K="${K:-2}"
require_positive_int K "$K"

# Your multi-turn dataset (must include interaction_kwargs in each sample).
# Relative paths resolve against PROJECT_DIR, the current directory.
TRAIN_FILE="${TRAIN_FILE:-./data/math_ver_k_retry_k${K}/train.parquet}"
VAL_FILE="${VAL_FILE:-./data/math_ver_k_retry_k${K}/test.parquet}"

MODEL="${MODEL:-Qwen/Qwen3-4B}"

# GRPO group size: rollouts sampled per prompt.
N="${N:-8}"
require_positive_int N "$N"

# Advantage estimator (plain GRPO or one of the Ver@K variants). Only the
# whitelisted names are accepted.
ADV_ESTIMATOR="${ADV_ESTIMATOR:-grpo_verk_step_reward_step_norm}"
case "$ADV_ESTIMATOR" in
  grpo|grpo_vectorized|grpo_verk_step_reward_step_norm|grpo_verk_step_reward_step_norm_reweight_future_only) ;;
  *) echo "Unsupported ADV_ESTIMATOR: $ADV_ESTIMATOR" >&2; exit 1 ;;
esac
# Short label embedded in log-file and experiment names; defaults to the estimator.
ADV_TAG="${ADV_TAG:-$ADV_ESTIMATOR}"

# Rollout sampling defaults (the values Qwen3 recommends for thinking mode).
: "${TEMP:=0.6}" "${TOP_P:=0.95}" "${TOP_K:=20}"

# Length budget: each assistant attempt may emit up to RETRY_RESP_LEN tokens,
# and the overall response window holds K attempts plus RESP_BUFFER tokens
# (presumably headroom for the interleaved feedback turns -- confirm).
: "${MAX_PROMPT_LEN:=512}" "${RETRY_RESP_LEN:=1024}" "${RESP_BUFFER:=256}"
MAX_RESP_LEN=$(( RESP_BUFFER + K * RETRY_RESP_LEN ))

# Slurm already captures stdout/stderr. Set USE_NOHUP to "1" or "true" to send
# all further output (including the set -x trace) to a dedicated log file.
USE_NOHUP="${USE_NOHUP:-0}"
case "$USE_NOHUP" in
  1 | true)
    NOHUP_DIR="${NOHUP_DIR:-$PROJECT_DIR/outputs}"
    mkdir -p "$NOHUP_DIR"
    NOHUP_FILE="${NOHUP_FILE:-$NOHUP_DIR/nohup_math_k${K}_n${N}_resp${MAX_RESP_LEN}_${ADV_TAG}.out}"
    exec >"$NOHUP_FILE" 2>&1
    ;;
esac

# Global train batch size (prompts per step). Multi-turn rollouts are much
# heavier than single-turn ones.
: "${TRAIN_BS:=256}"

# Run naming and checkpoint placement. The default CKPTS_DIR sits under
# checkpoints/nfs (described by the author as an NFS-backed symlink). SAVE_FREQ
# is large enough that no periodic save is reached; the intent is to save only
# at the end of training.
: "${PROJECT_NAME:=verl_ver_k_retry_math}"
: "${EXP_NAME:=qwen3_4b_ver_k${K}_boxed_math_n${N}_resp${MAX_RESP_LEN}_${ADV_TAG}_sglang}"
: "${CKPTS_DIR:=$PROJECT_DIR/checkpoints/nfs/$PROJECT_NAME/$EXP_NAME}"
: "${SAVE_FREQ:=1000000000}"

# Interaction config reused from the existing ver_k_retry setup.
: "${INTERACTION_CFG:=$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/ver_k_retry_interaction_config.yaml}"

# This script needs a Slurm allocation: the job ID names the MPS directory and
# training is launched with srun. Fail with a clear message instead of set -u's
# bare "unbound variable".
: "${SLURM_JOB_ID:?this script must run inside a Slurm job (sbatch/salloc)}"

# Record which GPUs Slurm handed us, for post-mortem debugging.
echo "SLURM_JOB_ID=$SLURM_JOB_ID"
echo "SLURM_JOB_GPUS=${SLURM_JOB_GPUS-}"
echo "SLURM_STEP_GPUS=${SLURM_STEP_GPUS-}"
echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES-}"

nvidia-smi -L
nvidia-smi --query-gpu=index,uuid,name,compute_mode --format=csv
nvidia-smi --query-compute-apps=gpu_uuid,pid,process_name,used_memory --format=csv

# Per-job CUDA MPS control daemon. The pipe/log directory comes from mktemp -d,
# which creates a fresh, owner-only directory. A fixed /tmp/mps_<jobid> path
# could be pre-created by another user on a shared node, or could still hold
# stale pipes from an earlier run with the same job ID (e.g. after a requeue).
# The variables are exported so the srun step below (which inherits the
# environment by default) talks to this daemon.
MPS_DIR="$(mktemp -d "/tmp/mps_${SLURM_JOB_ID}.XXXXXX")"
export CUDA_MPS_PIPE_DIRECTORY="$MPS_DIR"
export CUDA_MPS_LOG_DIRECTORY="$MPS_DIR"

nvidia-cuda-mps-control -d
# Stop the daemon on exit. Best effort: `|| true` keeps a failing quit from
# tripping set -e inside the trap.
trap 'echo quit | nvidia-cuda-mps-control || true' EXIT

# Ray logging: print every worker's lines (no de-duplication) and run the Ray
# backend at debug verbosity.
export RAY_DEDUP_LOGS=0
export RAY_BACKEND_LOG_LEVEL=debug

# GPU layout for the single-node run. srun's --gpus and trainer.n_gpus_per_node
# must agree, so both read NGPUS. ROLLOUT_TP is the sglang tensor-parallel
# degree (presumably it must divide NGPUS -- confirm for your verl version).
NGPUS="${NGPUS:-2}"
ROLLOUT_TP="${ROLLOUT_TP:-2}"

# Per-GPU token caps for dynamic batching. verl's micro-batch packing rejects
# a cap smaller than one padded sequence (prompt + response), and MAX_RESP_LEN
# grows with K. Keep the historical 4096/8192 caps, but raise them to the full
# sequence length when needed (e.g. K=4: 512 + 4352 = 4864 > 4096).
MAX_SEQ_LEN=$(( MAX_PROMPT_LEN + MAX_RESP_LEN ))
PPO_MAX_TOKEN_LEN="${PPO_MAX_TOKEN_LEN:-$(( MAX_SEQ_LEN > 4096 ? MAX_SEQ_LEN : 4096 ))}"
LOG_PROB_MAX_TOKEN_LEN="${LOG_PROB_MAX_TOKEN_LEN:-$(( MAX_SEQ_LEN > 8192 ? MAX_SEQ_LEN : 8192 ))}"

# Launch training as a single-task job step. Hydra loads
# ver_k_retry_multiturn_grpo_w_interaction from CONFIG_PATH and applies the
# overrides below. Any extra arguments given to this script ("$@") are
# appended last.
srun --ntasks=1 --cpus-per-task="${SLURM_CPUS_PER_TASK:-8}" --gpus="$NGPUS" \
  python3 -u -m verl.trainer.main_ppo \
  --config-path="$CONFIG_PATH" \
  --config-name='ver_k_retry_multiturn_grpo_w_interaction' \
  algorithm.adv_estimator="$ADV_ESTIMATOR" \
  data.train_files="$TRAIN_FILE" \
  data.val_files="$VAL_FILE" \
  data.train_batch_size="$TRAIN_BS" \
  data.return_raw_chat=True \
  data.max_prompt_length="$MAX_PROMPT_LEN" \
  data.max_response_length="$MAX_RESP_LEN" \
  data.filter_overlong_prompts=True \
  data.truncation='error' \
  actor_rollout_ref.model.path="$MODEL" \
  actor_rollout_ref.actor.optim.lr=1e-6 \
  actor_rollout_ref.model.use_remove_padding=True \
  actor_rollout_ref.actor.use_kl_loss=True \
  actor_rollout_ref.actor.kl_loss_coef=0.001 \
  actor_rollout_ref.actor.kl_loss_type=low_var_kl \
  actor_rollout_ref.actor.entropy_coeff=0 \
  actor_rollout_ref.model.enable_gradient_checkpointing=True \
  actor_rollout_ref.actor.ppo_mini_batch_size=64 \
  actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
  actor_rollout_ref.actor.fsdp_config.param_offload=False \
  actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
  actor_rollout_ref.actor.use_dynamic_bsz=true \
  actor_rollout_ref.actor.ppo_max_token_len_per_gpu="$PPO_MAX_TOKEN_LEN" \
  actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=true \
  actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu="$LOG_PROB_MAX_TOKEN_LEN" \
  actor_rollout_ref.ref.log_prob_use_dynamic_bsz=true \
  actor_rollout_ref.ref.log_prob_max_token_len_per_gpu="$LOG_PROB_MAX_TOKEN_LEN" \
  actor_rollout_ref.rollout.name=sglang \
  actor_rollout_ref.rollout.tensor_model_parallel_size="$ROLLOUT_TP" \
  actor_rollout_ref.rollout.gpu_memory_utilization=0.45 \
  actor_rollout_ref.rollout.dtype=bfloat16 \
  actor_rollout_ref.rollout.n="$N" \
  actor_rollout_ref.rollout.temperature="$TEMP" \
  actor_rollout_ref.rollout.top_p="$TOP_P" \
  actor_rollout_ref.rollout.top_k="$TOP_K" \
  actor_rollout_ref.rollout.agent.default_agent_loop=tool_agent \
  actor_rollout_ref.rollout.multi_turn.enable=true \
  actor_rollout_ref.rollout.multi_turn.interaction_config_path="$INTERACTION_CFG" \
  actor_rollout_ref.rollout.multi_turn.max_assistant_turns="$K" \
  actor_rollout_ref.rollout.multi_turn.max_user_turns="$K" \
  actor_rollout_ref.rollout.multi_turn.max_assistant_response_length="$RETRY_RESP_LEN" \
  actor_rollout_ref.rollout.multi_turn.tokenization_sanity_check_mode=ignore_strippable \
  actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
  actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
  actor_rollout_ref.ref.fsdp_config.param_offload=True \
  algorithm.use_kl_in_reward=False \
  trainer.critic_warmup=0 \
  trainer.logger='["console","wandb"]' \
  trainer.project_name="$PROJECT_NAME" \
  trainer.experiment_name="$EXP_NAME" \
  trainer.default_local_dir="$CKPTS_DIR" \
  trainer.n_gpus_per_node="$NGPUS" \
  trainer.nnodes=1 \
  trainer.save_freq="$SAVE_FREQ" \
  trainer.test_freq=5 \
  trainer.total_epochs=2 \
  "$@"
