#!/bin/bash
#SBATCH --job-name=4N@DAPO_v359_redall2_token_r_s1.0c-1.0p0.5norm1_ep1_bugs1fo0fg1cgs_0.0lr1e-06maxl4096negw1.0p512r16mb512ext1.0N4-Qwen2.5-Math-7B9@L_p0
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --time=48:00:00
#SBATCH --time-min=48:00:00
#SBATCH --partition=pool0_datahall_a
#SBATCH --account=d$mask$_base
#SBATCH --exclusive
#SBATCH --gres=gpu:8
#SBATCH --dependency=singleton
#SBATCH --output=/project/$mask$/$mask$/git/$mask$/ckpts/DAPOv3/DAPO_v359_redall2_token_r_s1.0c-1.0p0.5norm1_ep1_bugs1fo0fg1cgs_0.0lr1e-06maxl4096negw1.0p512r16mb512ext1.0N4-Qwen2.5-Math-7B/%j@4N@DAPO_v359_redall2_token_r_s1.0c-1.0p0.5norm1_ep1_bugs1fo0fg1cgs_0.0lr1e-06maxl4096negw1.0p512r16mb512ext1.0N4-Qwen2.5-Math-7B9@L_p0/job.out
#SBATCH --error=/project/$mask$/$mask$/git/$mask$/ckpts/DAPOv3/DAPO_v359_redall2_token_r_s1.0c-1.0p0.5norm1_ep1_bugs1fo0fg1cgs_0.0lr1e-06maxl4096negw1.0p512r16mb512ext1.0N4-Qwen2.5-Math-7B/%j@4N@DAPO_v359_redall2_token_r_s1.0c-1.0p0.5norm1_ep1_bugs1fo0fg1cgs_0.0lr1e-06maxl4096negw1.0p512r16mb512ext1.0N4-Qwen2.5-Math-7B9@L_p0/job.err
#SBATCH --exclude=pool0-0023,pool0-0006,pool0-0028,pool0-0002
#SBATCH --requeue

# Rendezvous settings exported for the processes started by the srun steps
# further down in this script.
#
# MASTER_PORT: 10000 plus the last four decimal digits of the job id
# (job id modulo 10000), computed with shell arithmetic instead of forking
# expr, echo and tail.
export MASTER_PORT=$((10000 + SLURM_JOBID % 10000))
# WORLD_SIZE: number of nodes times tasks per node, as reported by Slurm.
export WORLD_SIZE=$((SLURM_NNODES * SLURM_NTASKS_PER_NODE))
# MASTER_ADDR: first hostname of the allocation. Assigned and exported in two
# steps so the exit status of the command substitution is not masked by export.
MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR
# SLURM_LOG_DIR: per-job directory (named after the job id) that the rest of
# the script writes its environment dumps and logs into.
export SLURM_LOG_DIR=/project/$mask$/$mask$/git/$mask$/ckpts/DAPOv3/DAPO_v359_redall2_token_r_s1.0c-1.0p0.5norm1_ep1_bugs1fo0fg1cgs_0.0lr1e-06maxl4096negw1.0p512r16mb512ext1.0N4-Qwen2.5-Math-7B/$SLURM_JOB_ID@4N@DAPO_v359_redall2_token_r_s1.0c-1.0p0.5norm1_ep1_bugs1fo0fg1cgs_0.0lr1e-06maxl4096negw1.0p512r16mb512ext1.0N4-Qwen2.5-Math-7B9@L_p0

# Print the rendezvous settings to the job's stdout for later debugging.
# Every expansion sits inside the quotes so the value is printed verbatim,
# without word splitting or pathname expansion.
echo "Nodelist=$SLURM_JOB_NODELIST"
echo "MASTER_PORT=$MASTER_PORT"
echo "WORLD_SIZE=$WORLD_SIZE"
echo "MASTER_ADDR=$MASTER_ADDR"
echo "SLURM_LOG_DIR=$SLURM_LOG_DIR"
# Create the per-job log directory and snapshot the batch script's environment
# into it. The paths are quoted so a directory name containing whitespace does
# not split into several mkdir arguments or trigger an "ambiguous redirect".
mkdir -p "$SLURM_LOG_DIR/env"
env > "$SLURM_LOG_DIR/env/sbatch_env.sh"

# Container image path handed to srun through --container-image; exported so
# child processes can see it as well.
DOCKER_PATH="/project/$mask$/$mask$/sqsh/$mask$_mcore_v0.0.7_efa_verl$mask$"
export DOCKER_PATH

# Site-specific settings: replace these values with your own.
# verl_workdir is the checkout that the remote commands cd into; the W&B
# values are interpolated into the remote command prologue.
verl_workdir="/project/$mask$/$mask$/git/$mask$"
WANDB_API_KEY="$mask$"
WANDB_ENTITY="$mask$"

# Resolve the hostnames of the allocation, one array element per node.
# mapfile reads one hostname per output line, which avoids the word splitting
# and pathname expansion an unquoted $(...) inside an array literal undergoes.
mapfile -t nodes_array < <(scontrol show hostnames "$SLURM_JOB_NODELIST")

# The first node hosts the Ray head; ask that node for its own address(es).
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

# If the reported value contains a space it holds more than one address.
# Keep the first entry unless it is longer than 16 characters (treated as an
# IPv6 address here), in which case the second entry is used instead.
# This step is optional.
if [[ "$head_node_ip" == *" "* ]]; then
    IFS=' ' read -ra ADDR <<<"$head_node_ip"
    if ((${#ADDR[0]} > 16)); then
        head_node_ip=${ADDR[1]}
    else
        head_node_ip=${ADDR[0]}
    fi
    echo "IPV6 address detected. We split the IPV4 address as $head_node_ip"
fi

# An empty address would yield ip_head=":<port>", which no worker can reach;
# make that visible in the job's stderr instead of failing silently later.
if [[ -z "$head_node_ip" ]]; then
    echo "WARNING: could not determine the IP address of head node '$head_node'" >&2
fi

# Address (ip:port) of the Ray head; exported for the steps started later.
port=30006
ip_head=$head_node_ip:$port
export ip_head
echo "IP Head: $ip_head"

# Environment variables that must be in place before Ray is initialised.
export VLLM_ATTENTION_BACKEND=XFORMERS

# Ray timeout and port configuration.
# NOTE(review): confirm that the installed Ray version actually reads each of
# the variable names below.

# GCS connection: timeout raised to 60 seconds, retry count raised to 5,
# retry interval set to 5 seconds.
export RAY_GCS_SERVER_CONNECT_TIMEOUT_SECONDS=60 \
       RAY_GCS_SERVER_MAX_RETRIES=5 \
       RAY_GCS_SERVER_RETRY_INTERVAL_SECONDS=5

# Worker port range.
export RAY_worker_ports="20000-20050" \
       RAY_WORKER_PORT_START=20000 \
       RAY_WORKER_PORT_END=20050

# Dashboard, agent, metrics and runtime-env ports, plus the main Ray port.
export RAY_DASHBOARD_PORT=40000 \
       RAY_DASHBOARD_AGENT_LISTEN_PORT=40011 \
       RAY_DASHBOARD_AGENT_GRPC_PORT=40022 \
       RAY_metrics_export_port=40033 \
       RAY_runtime_env_agent_port=40044 \
       RAY_PORT=40055

# Memory-monitor overrides, currently disabled:
# export RAY_memory_usage_threshold=0.98
# export RAY_memory_monitor_refresh_ms=0


# Shell prologue that is prepended to the command string of every
# containerised `bash -c` started through srun in this script.
#
# Expansion timing matters here:
#   - \$VAR (escaped) stays a literal $VAR inside the string, so it is expanded
#     later by the bash that srun starts;
#   - ${VAR} (unescaped) is expanded right here, when BASE_COMMAND is assigned.
#
# Steps, in order: set LOCAL_RANK from SLURM_LOCALID (a plain shell variable,
# it is not exported), append that process's environment to a file under
# $SLURM_LOG_DIR/env, set the core-file size limit to 0, cd into the work dir
# and run `pip install -e .`, then export the W&B credentials, NCCL/PyTorch
# settings, IMAGINAIRE_OUTPUT_ROOT and the library search path.
#
# NOTE(review): the W&B API key is interpolated into the command text itself,
# so it travels on the srun/bash command line — confirm this is acceptable.
# NOTE(review): every srun in this script starts a single task, so
# SLURM_PROCID is presumably 0 in each step and all steps would append to the
# same env/0.sh and log_0.txt — confirm whether per-node files were intended.
BASE_COMMAND="LOCAL_RANK=\$SLURM_LOCALID; 
              env >> \$SLURM_LOG_DIR/env/\$SLURM_PROCID.sh 2>&1; 
              ulimit -c 0; 
              cd ${verl_workdir}; pip install -e .; 
              export WANDB_API_KEY=${WANDB_API_KEY}; 
              export WANDB_ENTITY=${WANDB_ENTITY}; 
              export TORCH_NCCL_ENABLE_MONITORING=0; 
              export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=1800; 
              export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True; 
              export IMAGINAIRE_OUTPUT_ROOT=/project/$mask$/$mask$/$mask$-output; 
              export LD_LIBRARY_PATH=/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/aws-ofi-nccl/install/lib:/usr/local/lib:\$LD_LIBRARY_PATH;"

# Start the Ray head on the first node of the allocation.
# The step runs inside the container image with the project/home mounts,
# executes the BASE_COMMAND prologue followed by `ray start --head`, and
# appends stdout/stderr to a log file under SLURM_LOG_DIR. $head_node_ip and
# $port are expanded here; the escaped \${...} are expanded on the node.
# The trailing & puts the step in the background so the script can continue.
echo "Starting HEAD at $head_node"
srun --nodes=1 --ntasks=1 -w "$head_node" \
    --container-image=${DOCKER_PATH} --container-mounts /project/$mask$/$mask$:/project/$mask$/$mask$:rw,/project/$mask$:/project/$mask$:rw,/home/$mask$:/home/$mask$:rw \
    bash -c "${BASE_COMMAND} ray start --head --node-ip-address=$head_node_ip --port=$port --num-gpus 8 --block >> \${SLURM_LOG_DIR}/log_\${SLURM_PROCID}.txt 2>&1" &
# Fixed 20-second pause before the workers are started.
# optional, though may be useful in certain versions of Ray < 1.0.
sleep 20

# Every node except the head (index 0 of nodes_array) becomes a Ray worker.
worker_num=$((SLURM_JOB_NUM_NODES - 1))

# Start one background srun step per worker node. Worker i receives its own
# window of 100 worker ports: 23000-23099 for i=1, 23100-23199 for i=2, ...
i=1
while ((i <= worker_num)); do
    node_i=${nodes_array[i]}
    min_worker_port=$((22900 + 100 * i))
    max_worker_port=$((min_worker_port + 99))
    echo "Starting WORKER $i at $node_i"
    srun --nodes=1 --ntasks=1 -w "$node_i" \
        --container-image=${DOCKER_PATH} --container-mounts /project/$mask$/$mask$:/project/$mask$/$mask$:rw,/project/$mask$:/project/$mask$:rw,/home/$mask$:/home/$mask$:rw \
        bash -c "${BASE_COMMAND} ray start --address $ip_head --min-worker-port=$min_worker_port --max-worker-port=$max_worker_port --num-gpus 8 --block  >> \${SLURM_LOG_DIR}/log_\${SLURM_PROCID}.txt 2>&1" &
    # Stagger consecutive worker launches by 5 seconds.
    sleep 5
    i=$((i + 1))
done

# Experiment identity; also used to build the default checkpoint directory.
project_name=DAPOv3
exp_name=DAPO_v359_redall2_token_r_s1.0c-1.0p0.5norm1_ep1_bugs1fo0fg1cgs_0.0lr1e-06maxl4096negw1.0p512r16mb512ext1.0N4-Qwen2.5-Math-7B

# Advantage estimator and KL coefficients.
adv_estimator=grpo
kl_coef=0.0
kl_loss_coef=0.0

# Lower / upper clip ratios.
clip_ratio_low=0.2
clip_ratio_high=0.28

# Group filtering and batch sizes.
enable_filter_groups=1
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=512
# The generation batch equals the training batch when group filtering is
# switched off ("0") and is three times as large otherwise.
case "$enable_filter_groups" in
    0) gen_prompt_bsz=$train_prompt_bsz ;;
    *) gen_prompt_bsz=$((train_prompt_bsz * 3)) ;;
esac

n_resp_per_prompt=16
train_prompt_mini_bsz=32

use_token_level_loss=True

# Paths: each keeps a non-empty value already present in the environment and
# otherwise falls back to a location under verl_workdir.
: "${MODEL_PATH:=${verl_workdir}/models/Qwen2.5-Math-7B}"
: "${CKPTS_DIR:=${verl_workdir}/ckpts/${project_name}/${exp_name}}"
: "${TRAIN_FILE:=${verl_workdir}/data/dapo-math-17k_10boxed.parquet}"  # update on 0405 add answer back
# TEST_FILE=${TEST_FILE:-"${verl_workdir}/data/aime-2024.parquet"}

# Algorithm
## Train: 512 prompt tokens plus the remainder of a 4096-token budget.
max_prompt_length=512
max_response_length=$((4096 - 512))
## Validation
val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout


# Overlong-buffer settings are selected by neg_weight; at present both arms
# assign the same values (buffer disabled).
neg_weight=1.0
case "$neg_weight" in
    -1.0)
        enable_overlong_buffer=False
        overlong_buffer_len=0
        overlong_penalty_factor=0.0
        ;;
    *)
        enable_overlong_buffer=False
        overlong_buffer_len=0
        overlong_penalty_factor=0.0
        ;;
esac

# Performance-related parameters.
sp_size=1
use_dynamic_bsz=True
# Per-GPU token limits for actor and inference: one full prompt + response.
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
gen_tp=4
actor_offload=False
ref_offload=True
# Fixed 30-second pause before the training driver is submitted
# (presumably to give the Ray workers time to join the head — confirm).
sleep 30

# Dynamic wait instead of the fixed sleep (currently disabled):
# echo "Waiting for nodes to become ready..."
# while [[ $(ray status --address $ip_head | grep -c "Healthy") -lt $SLURM_JOB_NUM_NODES ]]; do
#     echo "Ready nodes: $(ray status --address $ip_head | grep -c "Healthy")/$SLURM_JOB_NUM_NODES"
#     sleep 30
# done
# echo "All nodes are ready!"

# Launch the trainer (verl.trainer.main_ppo) as one more single-task step on
# the head node, using the same container image and mounts as the Ray steps.
# There is no trailing &, so the script waits for this step to finish.
# The whole command is one double-quoted string handed to `bash -c`:
# unescaped ${...} and $((...)) are expanded here, before srun runs, while the
# escaped \${...} are passed through and expanded by the bash started on the
# node. stdout and stderr of the trainer are appended to
# ${SLURM_LOG_DIR}/log_${SLURM_PROCID}.txt.
# NOTE(review): --overlap presumably lets this step share the head node with
# the still-running Ray head step — confirm against the Slurm documentation.
# Previous, shorter validation set kept for reference:
# data.val_files=\"['/project/$mask$/$mask$/git/$mask$/data/aime-2024.parquet', '/project/$mask$/$mask$/git/$mask$/data/math500.parquet']\" \
PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "$head_node" \
    --container-image=${DOCKER_PATH} --container-mounts /project/$mask$/$mask$:/project/$mask$/$mask$:rw,/project/$mask$:/project/$mask$:rw,/home/$mask$:/home/$mask$:rw \
    bash -c "${BASE_COMMAND}
    python3 -m verl.trainer.main_ppo \
    data.train_files=${TRAIN_FILE} \
    data.val_files=\"['/project/$mask$/$mask$/git/$mask$/data/aime-2024-boxed_w_answer.parquet', '/project/$mask$/$mask$/git/$mask$/data/math500_boxed.parquet', '/project/$mask$/$mask$/git/$mask$/data/prime_math500_boxed.parquet', '/project/$mask$/$mask$/git/$mask$/data/prime_minerva_math.parquet', '/project/$mask$/$mask$/git/$mask$/data/minerva_math.parquet', '/project/$mask$/$mask$/git/$mask$/data/prime_olympiadbench.parquet', '/project/$mask$/$mask$/git/$mask$/data/olympiadbench.parquet', '/project/$mask$/$mask$/git/$mask$/data/aime2025_32_dapo_boxed_w_answer.parquet', '/project/$mask$/$mask$/git/$mask$/data/amc2023_32_dapo_boxed_w_answer.parquet']\" \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.filter_overlong_prompts=True \
    data.gen_batch_size=${gen_prompt_bsz} \
    data.train_batch_size=${train_prompt_bsz} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    algorithm.filter_groups.enable=${enable_filter_groups} \
    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
    algorithm.filter_groups.metric=${filter_groups_metric} \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.negw=1.0 \
    actor_rollout_ref.reduction=all2 \
    actor_rollout_ref.actor.ratio_type=token \
    actor_rollout_ref.actor.ppo_epochs=1 \
    actor_rollout_ref.c_guidance=0.0 \
    actor_rollout_ref.overlong_filter=0 \
    actor_rollout_ref.bugged_dynamic_scale=1 \
    actor_rollout_ref.soft_label=0.0 \
    actor_rollout_ref.clamp_uni=-1.0 \
    actor_rollout_ref.clamp_positive=0.5 \
    actor_rollout_ref.normalize=1 \
    actor_rollout_ref.ratio_scale=1.0 \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.model.path=${MODEL_PATH} \
    +actor_rollout_ref.model.override_config.attention_dropout=0. \
    +actor_rollout_ref.model.override_config.embd_pdrop=0. \
    +actor_rollout_ref.model.override_config.resid_pdrop=0. \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-06 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.use_token_level_loss=${use_token_level_loss} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.70 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.rollout.val_kwargs.top_k=${val_top_k} \
    actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \
    actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
    custom_reward_function.overlong_buffer.enable=${enable_overlong_buffer} \
    custom_reward_function.overlong_buffer.len=${overlong_buffer_len} \
    custom_reward_function.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
    trainer.logger=['console','wandb'] \
    trainer.project_name=${project_name} \
    trainer.experiment_name=${exp_name} \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=${SLURM_NNODES} \
    +trainer.val_before_train=False \
    trainer.test_freq=10.0 \
    trainer.save_freq=20.0 \
    trainer.total_epochs=100 \
    trainer.default_local_dir=${CKPTS_DIR} \
    trainer.resume_mode=auto \
      >> \${SLURM_LOG_DIR}/log_\${SLURM_PROCID}.txt 2>&1"