#!/usr/bin/env bash
# Multi-node verl training launcher: GRPO advantages with the GSPO policy loss,
# Megatron actor/ref and vLLM rollout, orchestrated through Ray.
#
# Run the same script on every node. Expected environment:
#   RANK    0 on the head node, a distinct non-zero value on each worker
#   NNODES  total number of nodes (defaults to 4); MASTER_PORT and
#           DASHBOARD_PORT are optional as well
set -euo pipefail

###################################################
# Please fill in / modify the following before running.
# Each setting may also come from an environment variable of the same name; a
# non-empty environment value wins. To hard-code a value, put it as the
# default, e.g. MODEL_PATH="${MODEL_PATH:-/path/to/model}".
RAY_DATA_HOME="${RAY_DATA_HOME:-}"
MODEL_PATH="${MODEL_PATH:-}"
TRAIN_FILE="${TRAIN_FILE:-}"
TEST_FILE="${TEST_FILE:-}"
VERL_PATH="${VERL_PATH:-}"
WANDB_KEY="${WANDB_KEY:-}"
PROJECT_NAME="${PROJECT_NAME:-}"
EXP_NAME="${EXP_NAME:-}"
###################################################

# Trace commands only from here on, so the settings above (notably WANDB_KEY)
# are never echoed into the job log.
set -x

# ---- Policy optimization ---------------------------------------------------
# Advantage estimator and policy-loss variant handed to verl
# (algorithm.adv_estimator / actor.policy_loss.loss_mode).
adv_estimator=grpo loss_mode=gspo

# KL regularization is switched off both as a reward penalty and as an actor loss.
use_kl_in_reward=False kl_coef=0.0
use_kl_loss=False kl_loss_coef=0.0

# Lower / upper clip ratios for the policy loss.
clip_ratio_low=0.0003 clip_ratio_high=0.0004

# ---- Sequence lengths (tokens) ---------------------------------------------
max_prompt_length=$((2 * 1024))   # 2048
max_response_length=$((8 * 1024)) # 8192
# Overlong-response buffer settings forwarded to the DAPO reward manager.
# NOTE(review): with a 0.0 penalty factor the buffer presumably adds no
# penalty at all — confirm against the reward manager.
enable_overlong_buffer=True overlong_buffer_len=$((4 * 1024)) overlong_penalty_factor=0.0

# How per-token losses are aggregated (actor.loss_agg_mode).
loss_agg_mode="seq-mean-token-mean"

# ---- Batch sizes -----------------------------------------------------------
# Prompts per training step, sampled responses per prompt, prompts per PPO
# mini-batch (data.train_batch_size / rollout.n / actor.ppo_mini_batch_size).
train_prompt_bsz=128 n_resp_per_prompt=16 train_prompt_mini_bsz=32

# ---- Ray cluster -------------------------------------------------------------
# Each keeps a non-empty value already present in the environment, otherwise
# falls back to the default given here.
: "${NNODES:=4}" "${MASTER_PORT:=6379}" "${DASHBOARD_PORT:=8265}"

# ---- Sampling ----------------------------------------------------------------
temperature=1.0 top_p=1.0
top_k=-1      # use 0 with the HF rollout, -1 with the vLLM rollout
val_top_p=0.7 # top-p for validation generations

# ---- Performance -------------------------------------------------------------
use_dynamic_bsz=True
# Per-GPU token budgets for dynamic batching, each equal to one full
# prompt+response length. The actor's "* 10 / 10" is a ratio of exactly 1,
# written as numerator/denominator because shell arithmetic is integer-only.
actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 10 / 10 ))
infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1 ))
# Megatron offloading: param/optimizer/grad for the actor, param for the ref.
offload=True
# Parallelism: vLLM rollout TP, then Megatron TP/PP/EP/ETP shared by actor and ref.
gen_tp=4
train_tp=4 train_pp=4
EP=4 ETP=1

# ---- Checkpoint / coordination directories ---------------------------------
# Refuse to run with unfilled settings. With an empty RAY_DATA_HOME the
# directories below would resolve under the filesystem root ("/save/...").
for required_var in RAY_DATA_HOME MODEL_PATH TRAIN_FILE TEST_FILE VERL_PATH PROJECT_NAME EXP_NAME; do
    if [[ -z "${!required_var:-}" ]]; then
        echo "ERROR: ${required_var} is empty; fill it in at the top of $0 before running." >&2
        exit 1
    fi
done
: "${RANK:?RANK must be set (0 on the head node, a distinct non-zero value on each worker)}"

project_name=$PROJECT_NAME
exp_name=$EXP_NAME
CKPTS_DIR="${RAY_DATA_HOME}/save/${project_name}/${exp_name}"
# Handshake directory: the head writes head_ip.txt here and every node drops a
# RANK_<n>.ready marker, so it has to live on storage shared by all nodes.
CHECK_DIR="${CKPTS_DIR}/ray_temp"
mkdir -p -- "$CHECK_DIR"

# Log in to Weights & Biases without writing the API key into the `set -x`
# trace; the previous xtrace state is restored afterwards.
case $- in
    *x*) restore_xtrace=true ;;
    *) restore_xtrace=false ;;
esac
{ set +x; } 2>/dev/null
if [[ -n "${WANDB_KEY:-}" ]]; then
    wandb login "$WANDB_KEY"
else
    # No key configured: call `wandb login` without a key argument and let
    # wandb fall back to its own credential handling.
    wandb login
fi
if [[ "$restore_xtrace" == true ]]; then
    set -x
fi

echo "CURRENT RANK: $RANK"

# Experience-store setting. It is exported (presumably read by the verl code)
# and also passed to the trainer as +trainer.experience_store_step.
experience_store_step=5

export VERL_DEFAULT_LOCAL_DIR="${CKPTS_DIR}"
export VERL_EXPERIENCE_STORE_STEP="${experience_store_step}"

# Print how many RANK_*.ready markers are present in CHECK_DIR (0 when none).
count_ready_markers() {
    local -a markers
    shopt -s nullglob
    markers=("${CHECK_DIR}"/RANK_*.ready)
    shopt -u nullglob
    echo "${#markers[@]}"
}

# EXIT-trap handler for the head node: stop the local Ray head so that an
# early exit (e.g. the trainer failing under `set -e`) does not leave Ray
# processes behind. A failure here only prints a warning.
stop_ray_head() {
    echo "[HEAD] Stopping Ray (exit trap)..."
    ray stop -f || echo "[HEAD] WARNING: 'ray stop -f' failed" >&2
}

if [ "$RANK" -eq "0" ]; then
    # ------------------------------ head node ------------------------------
    echo "[HEAD $RANK] Starting Ray head node..."

    NODE_IP=$(hostname -I | awk '{print $1}')
    if [[ -z "$NODE_IP" ]]; then
        echo "[HEAD] ERROR: could not determine this node's IP address from 'hostname -I'." >&2
        exit 1
    fi
    # Publish the IP via a temp file + rename so a polling worker never reads
    # a partially written file.
    echo "$NODE_IP" > "${CHECK_DIR}/head_ip.txt.tmp"
    mv -f -- "${CHECK_DIR}/head_ip.txt.tmp" "${CHECK_DIR}/head_ip.txt"
    echo "[HEAD] IP is $NODE_IP, written to ${CHECK_DIR}/head_ip.txt"

    ray start --head \
        --node-ip-address="${NODE_IP}" \
        --dashboard-host=0.0.0.0 \
        --dashboard-port="${DASHBOARD_PORT}" \
        --port="${MASTER_PORT}" \
        --include-dashboard=True \
        --num-gpus=8
    # Ray is running from here on: make sure it is stopped on any early exit.
    trap stop_ray_head EXIT

    touch -- "${CHECK_DIR}/RANK_${RANK}.ready"
    echo "[HEAD] Waiting for all ${NNODES} workers to become ready..."

    # NNODES markers are expected in total: the head's own plus one per worker.
    while (( $(count_ready_markers) < NNODES )); do
        sleep 3
    done

    echo "[HEAD] All workers are ready. Waiting for cluster resources..."
    sleep 10

    echo "[HEAD] All workers are ready. Showing cluster status:"
    ray status

    echo "[HEAD] Launching training script..."

    # Remove the handshake files so they cannot be picked up by a later run.
    rm -f -- "${CHECK_DIR}"/RANK_*.ready "${CHECK_DIR}/head_ip.txt"

    cd -- "$VERL_PATH"

    # Hydra overrides for the trainer; keys prefixed with "+" add entries that
    # are absent from the base config. The logger list is quoted so the shell
    # never treats its brackets as a filename glob.
    python3 -m verl.trainer.main_ppo \
        --config-path=./config \
        --config-name='ppo_megatron_trainer.yaml' \
        data.train_files="${TRAIN_FILE}" \
        data.val_files="${TEST_FILE}" \
        data.prompt_key=prompt \
        data.truncation='left' \
        data.max_prompt_length=${max_prompt_length} \
        data.max_response_length=${max_response_length} \
        data.train_batch_size=${train_prompt_bsz} \
        actor_rollout_ref.nccl_timeout=3600 \
        actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
        actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
        algorithm.adv_estimator=${adv_estimator} \
        algorithm.use_kl_in_reward=${use_kl_in_reward} \
        algorithm.kl_ctrl.kl_coef=${kl_coef} \
        actor_rollout_ref.model.use_fused_kernels=True \
        actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
        actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
        actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
        actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
        actor_rollout_ref.actor.clip_ratio_c=10.0 \
        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
        actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
        actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
        actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
        actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
        actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
        actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
        actor_rollout_ref.model.path="${MODEL_PATH}" \
        +actor_rollout_ref.model.enable_gradient_checkpointing=True \
        actor_rollout_ref.actor.optim.lr=1e-6 \
        actor_rollout_ref.actor.optim.lr_warmup_steps=0 \
        actor_rollout_ref.actor.optim.weight_decay=0.1 \
        actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
        actor_rollout_ref.actor.megatron.param_offload=${offload} \
        actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \
        actor_rollout_ref.actor.megatron.grad_offload=${offload} \
        actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \
        actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \
        actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \
        actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \
        actor_rollout_ref.actor.entropy_coeff=0 \
        actor_rollout_ref.actor.optim.clip_grad=1.0 \
        actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
        actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
        actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
        actor_rollout_ref.rollout.enable_chunked_prefill=True \
        actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
        actor_rollout_ref.rollout.temperature=${temperature} \
        actor_rollout_ref.rollout.top_p=${top_p} \
        actor_rollout_ref.rollout.top_k=${top_k} \
        actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
        actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
        actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
        actor_rollout_ref.rollout.val_kwargs.do_sample=True \
        actor_rollout_ref.rollout.val_kwargs.n=2 \
        actor_rollout_ref.rollout.name=vllm \
        actor_rollout_ref.rollout.enforce_eager=True \
        actor_rollout_ref.rollout.free_cache_engine=True \
        actor_rollout_ref.rollout.calculate_log_probs=True \
        actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \
        actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \
        actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \
        actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \
        actor_rollout_ref.ref.megatron.param_offload=${offload} \
        actor_rollout_ref.actor.megatron.use_mbridge=True \
        +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \
        +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \
        +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \
        +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \
        +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=11 \
        +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=11 \
        reward_model.reward_manager=dapo \
        +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
        +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
        +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
        +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
        +reward_model.reward_kwargs.max_resp_len=${max_response_length} \
        trainer.logger='[wandb]' \
        trainer.project_name="${project_name}" \
        trainer.experiment_name="${exp_name}" \
        trainer.n_gpus_per_node=8 \
        trainer.nnodes="${NNODES}" \
        trainer.val_before_train=False \
        trainer.test_freq=5 \
        trainer.save_freq=25 \
        trainer.max_actor_ckpt_to_keep=1 \
        trainer.total_epochs=10 \
        trainer.default_local_dir="${CKPTS_DIR}" \
        trainer.resume_mode=auto \
        trainer.log_val_generations=10 \
        +trainer.experience_store_step=${experience_store_step} \
        +trainer.valid_store=1 \
        +actor_rollout_ref.actor.gspo_drift_disturb_lower_bound=0.1 \
        +actor_rollout_ref.actor.gspo_drift_disturb_lower_weight=0

    echo "[HEAD] Training completed. Cleaning up..."
    # Normal path: drop the safety trap and stop Ray explicitly, so a failing
    # `ray stop` still fails the script as before.
    trap - EXIT
    ray stop -f

else
    # ------------------------------ worker node ------------------------------
    echo "[WORKER $RANK] Waiting for head IP to appear at ${CHECK_DIR}/head_ip.txt..."
    sleep 10
    until [[ -f "${CHECK_DIR}/head_ip.txt" ]]; do
        sleep 1
    done
    # The head publishes its IP before running `ray start --head`, so give the
    # head's Ray services time to come up before connecting.
    sleep 10

    HEAD_IP=$(< "${CHECK_DIR}/head_ip.txt")
    if [[ -z "$HEAD_IP" ]]; then
        echo "[WORKER $RANK] ERROR: ${CHECK_DIR}/head_ip.txt is empty." >&2
        exit 1
    fi
    echo "[WORKER $RANK] Head IP is $HEAD_IP. Starting Ray worker and connecting to head..."

    ray start --address="${HEAD_IP}:${MASTER_PORT}" --num-gpus=8

    touch -- "${CHECK_DIR}/RANK_${RANK}.ready"
    echo "[WORKER $RANK] Worker is ready and connected to head at ${HEAD_IP}:${MASTER_PORT}."

    # Workers never exit on their own: the process stays alive while the head
    # drives training, and must be terminated externally (e.g. by the job
    # scheduler).
    sleep infinity
fi