#!/bin/bash
# Environment setup for the training launcher: activates the Python virtual
# environment, sets the open-file limit, and switches to the project root.
#
# -x traces every command as it runs; -e aborts on the first failing command.
set -xe

# NOTE(review): this relative path is resolved against the caller's current
# directory (the cd below happens afterwards), so the script presumably has
# to be started from the directory containing uv_verl/ — confirm.
source uv_verl/bin/activate

# Set the open file descriptor limit for this shell and its children.
ulimit -n 65535

# Working directory: the parent of the directory that contains this script.
# realpath runs in its own assignment so that, under `set -e`, a failure
# stops the script instead of being hidden inside a nested substitution.
# All expansions are quoted so paths with spaces or glob characters survive.
script_path=$(realpath -- "$0")
WORK_DIR=$(dirname -- "$(dirname -- "${script_path}")")
cd -- "${WORK_DIR}"

# Append WORK_DIR to PYTHONPATH. The ":" separator is emitted only when
# PYTHONPATH already has content, so no empty entry is added to the path.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${WORK_DIR}"

# ------------------------------------------------------------------------------
# Model and experiment identity
# ------------------------------------------------------------------------------
MODEL_PATH="Qwen/Qwen2.5-Math-7B"
model_name="Qwen2.5-Math-7B"

project_name="APO"
apo_method="apo_ratio"
# Optional first CLI argument: appended verbatim to the experiment name.
additional_name="${1:-}"
experiment_name="APO_${model_name}_${apo_method}${additional_name}"

# Output root for this run; its logs/ subdirectory is created up front.
ckpts_dir="/data1/wty/outputs/${project_name}/${experiment_name}"
mkdir -p "${ckpts_dir}/logs"

# ------------------------------------------------------------------------------
# Parallelism (GPU count, tensor-parallel size, sequence-parallel size)
# ------------------------------------------------------------------------------
NGPUS=8
tp_size=1
sp_size=1

# ------------------------------------------------------------------------------
# Batching
# ------------------------------------------------------------------------------
use_dynamic_bsz=True
train_prompt_bsz=512
train_prompt_mini_bsz=32

# ------------------------------------------------------------------------------
# Sequence lengths and per-GPU token budgets
# ------------------------------------------------------------------------------
max_prompt_length=512
max_response_length=2048
# Prompt plus response length; the budgets below are 2x and 3x this value.
full_seq_len=$((max_prompt_length + max_response_length))
actor_ppo_max_token_len=$((full_seq_len * 2))
infer_ppo_max_token_len=$((full_seq_len * 3))

# ------------------------------------------------------------------------------
# Rollout sampling and advantage estimator
# ------------------------------------------------------------------------------
n_samples=8
temperature=1.0
top_p=1.0
top_k=-1
adv_estimator=grpo

# ------------------------------------------------------------------------------
# APO hyperparameters
# ------------------------------------------------------------------------------
apo_topk=8
apo_exclude_sampled=true
apo_push_coeff=1.05
apo_pull_coeff=0.1

# ------------------------------------------------------------------------------
# Clipping and KL settings
# ------------------------------------------------------------------------------
clip_ratio=0.2
clip_ratio_low=0.2
clip_ratio_high=0.2
clip_ratio_c=10.0
use_kl_loss=False
kl_coef=0.0

# How the loss is aggregated
loss_agg_mode="token-mean"

# ------------------------------------------------------------------------------
# Datasets (relative paths): one training set and four evaluation sets
# ------------------------------------------------------------------------------
train_path="data/DAPO/dapo-math-17k_dedup.parquet"
aime_test_path="data/DAPO/offline_eval/math__aime_repeated_8x_240.parquet"
aime25_test_path="data/DAPO/offline_eval/math__aime2025_2025.parquet"
math_test_path="data/DAPO/offline_eval/math__math_500.parquet"
minerva_test_path="data/DAPO/offline_eval/math__minerva_math_2025_processed.parquet"

# Launch training through verl's main_ppo entry point.
#
# Each array element is a single key=value config override. Building the
# argument list as an array with every expansion quoted guarantees that each
# override reaches Python as exactly one argv entry, even when a value
# contains whitespace or glob characters (the checkpoint directory embeds the
# optional suffix taken from the script's first CLI argument).
# Element order matches the original command line.

# Prompt plus response length, used for both rollout length limits below.
max_model_len=$((max_prompt_length + max_response_length))

overrides=(
    # --- Data ---
    "data.train_files=['${train_path}']"
    "data.val_files=['${aime_test_path}', '${math_test_path}', '${aime25_test_path}', '${minerva_test_path}']"
    data.prompt_key=prompt
    data.truncation='left'
    "data.max_prompt_length=${max_prompt_length}"
    "data.max_response_length=${max_response_length}"
    "data.train_batch_size=${train_prompt_bsz}"

    # --- Algorithm ---
    "algorithm.adv_estimator=${adv_estimator}"
    algorithm.norm_adv_by_std_in_grpo=False
    algorithm.use_kl_in_reward=False
    "algorithm.kl_ctrl.kl_coef=${kl_coef}"

    # --- Model ---
    "actor_rollout_ref.model.path=${MODEL_PATH}"
    actor_rollout_ref.model.use_remove_padding=True
    actor_rollout_ref.model.enable_gradient_checkpointing=True

    # --- Rollout (vllm) ---
    "actor_rollout_ref.rollout.n=${n_samples}"
    actor_rollout_ref.rollout.name=vllm
    "actor_rollout_ref.rollout.temperature=${temperature}"
    "actor_rollout_ref.rollout.top_p=${top_p}"
    "actor_rollout_ref.rollout.top_k=${top_k}"
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8
    "actor_rollout_ref.rollout.tensor_model_parallel_size=${tp_size}"
    actor_rollout_ref.rollout.enable_chunked_prefill=True
    actor_rollout_ref.rollout.max_num_seqs=8
    "actor_rollout_ref.rollout.max_model_len=${max_model_len}"
    "actor_rollout_ref.rollout.max_num_batched_tokens=${max_model_len}"
    "actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz}"
    "actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}"

    # --- Actor: batching, clipping, loss ---
    "actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz}"
    "actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}"
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1
    "actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len}"
    "actor_rollout_ref.actor.clip_ratio=${clip_ratio}"
    "actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low}"
    "actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high}"
    "actor_rollout_ref.actor.clip_ratio_c=${clip_ratio_c}"
    "actor_rollout_ref.actor.use_kl_loss=${use_kl_loss}"
    "actor_rollout_ref.actor.kl_loss_coef=${kl_coef}"
    "actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode}"
    actor_rollout_ref.actor.entropy_coeff=0
    actor_rollout_ref.actor.grad_clip=1.0
    actor_rollout_ref.actor.shuffle=True

    # --- Actor: optimizer and sharding ---
    actor_rollout_ref.actor.optim.lr=1e-6
    actor_rollout_ref.actor.optim.lr_warmup_steps=10
    actor_rollout_ref.actor.optim.weight_decay=0.0
    "actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size}"
    "actor_rollout_ref.actor.fsdp_config.fsdp_size=${NGPUS}"
    actor_rollout_ref.actor.fsdp_config.param_offload=True
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True
    actor_rollout_ref.actor.use_torch_compile=False

    # --- Actor: APO policy loss ---
    "actor_rollout_ref.actor.policy_loss.loss_mode=${apo_method}"
    "actor_rollout_ref.actor.policy_loss.apo_push_coeff=${apo_push_coeff}"
    "actor_rollout_ref.actor.policy_loss.apo_pull_coeff=${apo_pull_coeff}"
    "actor_rollout_ref.actor.policy_loss.apo_topk=${apo_topk}"
    "actor_rollout_ref.actor.policy_loss.apo_exclude_sampled=${apo_exclude_sampled}"

    # --- Reference model ---
    "actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz}"
    "actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}"
    actor_rollout_ref.ref.fsdp_config.param_offload=True
    "actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size}"

    # --- Reward ---
    reward_model.reward_manager=naive

    # --- Trainer ---
    "trainer.project_name=${project_name}"
    "trainer.experiment_name=${experiment_name}"
    "trainer.n_gpus_per_node=${NGPUS}"
    trainer.nnodes=1
    trainer.total_epochs=10
    trainer.save_freq=10
    trainer.max_actor_ckpt_to_keep=2
    trainer.test_freq=5
    trainer.val_before_train=True
    "trainer.default_local_dir=${ckpts_dir}"
    trainer.resume_mode="auto"
    trainer.logger='["console", "swanlab"]'
    trainer.log_val_generations=10
)

python3 -m verl.trainer.main_ppo "${overrides[@]}"