#!/usr/bin/env bash
# Environment shared by every process started from this script.
PYTHONUNBUFFERED=1
HYDRA_FULL_ERROR=1
VLLM_ATTENTION_BACKEND=XFORMERS # not needed with vllm 0.8.2
export PYTHONUNBUFFERED HYDRA_FULL_ERROR VLLM_ATTENTION_BACKEND

# Advantage estimator, forwarded as algorithm.adv_estimator.
adv_estimator="grpo"

# Overlong-response penalty (reward_model.overlong_buffer.*).
enable_overlong_buffer="True"
overlong_buffer_len=$((4 * 1024))
overlong_penalty_factor="1.0"

# Dynamic-sampling group filter (algorithm.filter_groups.*).
# gen_ratio scales train_prompt_bsz into the generation batch size; it is
# forced to 1 further down when enable_filter_groups is False.
enable_filter_groups="True"
filter_groups_metric="acc"
max_num_gen_batches="10"
gen_ratio="3"

# Dr.GRPO switch; when True, loss_agg_mode is replaced with "drgrpo" further down.
use_drgrpo="False"
loss_agg_mode=token-mean

# Training parameters.
kl_coef="0.0"
kl_loss_coef="0.0"

entropy_coeff="0.0"

clip_ratio_low="0.2"
clip_ratio_high="0.28"

train_prompt_bsz="512"
n_resp_per_prompt="16"
train_prompt_mini_bsz="32"
use_dynamic_bsz="True"
ppo_epochs="1"

sp_size="2"    # ulysses_sequence_parallel_size for actor and ref
offload="True" # FSDP param/optimizer offload
gen_tp="2"     # rollout tensor_model_parallel_size

test_freq="20"
save_freq="20"
total_epochs="100"

# Training-data construction.
prompt_key="question"
max_prompt_length=$((2 * 1024))
max_response_length=$((14 * 1024))
prompt_template_name="dapo_template" # dapo_template / dpsk_zero_box
extract_type="dapo"                  # dapo / r1

# Remote reward model.
# Exported for remote_verify.py; the data_source field of the training data
# must be set to "remote_verify".
export MATH_VERIFY_SERVER_URL="http://math-verify-server.bcloud.hb1b-h20.ml.baichuan-inc.com"

# Ray
# SCRIPT_DIR is the absolute directory containing this script; WORK_DIR is its
# grandparent. Every expansion is quoted so paths containing whitespace or
# glob characters survive intact.
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
WORK_DIR=$(dirname -- "$(dirname -- "$SCRIPT_DIR")")
echo "WORK_DIR=$WORK_DIR"
# Later commands reference files under WORK_DIR, so stop here if it cannot be
# entered instead of carrying on from the wrong directory.
cd -- "$WORK_DIR" || {
    echo "cannot cd to WORK_DIR=$WORK_DIR" >&2
    exit 1
}
NNODES=8


# Paths
MODEL_PATH='/global_data/sft/zhouyijie/models/Qwen2.5-Math-7B'
TRAIN_FILE='/global_data/sft/zhouyijie/rl_datasets/dapo/dapo-math-17k.parquet'
# Validation sets under WORK_DIR, rendered as a bracketed list of
# single-quoted paths.
printf -v TEST_FILE "['%s', '%s', '%s']" \
    "$WORK_DIR/recipe/dapo/test_data/math_test.parquet" \
    "$WORK_DIR/recipe/dapo/test_data/aime2024_test.parquet" \
    "$WORK_DIR/recipe/dapo/test_data/aime2025_test.parquet"
# Checkpoint directory; defaults to the CHECKPOINT_SAVE environment variable.
SAVE_PATH=${CHECKPOINT_SAVE}
LOG_ROLLOUT_DETAIL='false'
LOG_ROLLOUT_N_LIMIT='-1'
# "auto", or a resume path (per the original note); forwarded as trainer.resume_mode.
RESUME_PATH='auto'
# wandb
project_name="DAPO"
exp_name="DAPO-Qwen2.5-7B"
wandb_key=""

# Parse command-line arguments.
#
# Every option has the form `--name value` and overwrites the configuration
# variable of the matching name, except --log_rollout_detail, which is a
# switch. For compatibility with existing invocations the switch still
# swallows one following token when that token does not start with `--`.
#
# Arguments: the script's command line ("$@").
# Globals:   assigns the variable that corresponds to each option given.
# Exits:     status 1 (after printing the offending token) on an unknown option.
parse_args() {
    local opt
    while [[ "$#" -gt 0 ]]; do
        opt="$1"
        # Convention inside the case: a value option drops its own name with
        # one `shift`; the `shift` after `esac` then drops the value.
        case "$opt" in
            # Algorithm options
            --kl_coef) kl_coef="$2"; shift ;;
            --kl_loss_coef) kl_loss_coef="$2"; shift ;;
            --entropy_coeff) entropy_coeff="$2"; shift ;;
            --clip_ratio_low) clip_ratio_low="$2"; shift ;;
            --clip_ratio_high) clip_ratio_high="$2"; shift ;;
            --enable_filter_groups) enable_filter_groups="$2"; shift ;;
            --filter_groups_metric) filter_groups_metric="$2"; shift ;;
            --max_num_gen_batches) max_num_gen_batches="$2"; shift ;;
            --gen_ratio) gen_ratio="$2"; shift ;;

            # Data options
            --train_file) TRAIN_FILE="$2"; shift ;;
            --test_file) TEST_FILE="$2"; shift ;;
            --prompt_key) prompt_key="$2"; shift ;;
            --max_prompt_length) max_prompt_length="$2"; shift ;;
            --max_response_length) max_response_length="$2"; shift ;;
            --extract_type) extract_type="$2"; shift ;;
            --prompt_template_name) prompt_template_name="$2"; shift ;;

            # Model options
            --model_path) MODEL_PATH="$2"; shift ;;
            --train_prompt_bsz) train_prompt_bsz="$2"; shift ;;
            --train_prompt_mini_bsz) train_prompt_mini_bsz="$2"; shift ;;
            --n_resp_per_prompt) n_resp_per_prompt="$2"; shift ;;
            --sp_size) sp_size="$2"; shift ;;
            --use_dynamic_bsz) use_dynamic_bsz="$2"; shift ;;
            --ppo_epochs) ppo_epochs="$2"; shift ;;
            --offload) offload="$2"; shift ;;
            --gen_tp) gen_tp="$2"; shift ;;

            # Training options
            --project_name) project_name="$2"; shift ;;
            --exp_name) exp_name="$2"; shift ;;
            --nnodes) NNODES="$2"; shift ;;
            --wandb_key) wandb_key="$2"; shift ;;
            --test_freq) test_freq="$2"; shift ;;
            --save_freq) save_freq="$2"; shift ;;
            --total_epochs) total_epochs="$2"; shift ;;
            --save_path) SAVE_PATH="$2"; shift ;;
            --resume_path) RESUME_PATH="$2"; shift ;;

            # Other options
            --log_rollout_detail)
                LOG_ROLLOUT_DETAIL="true"
                # Switch: consume a following token only when it is not itself
                # an option, so the next `--option` is never swallowed.
                if [[ "$#" -gt 1 && "$2" != --* ]]; then
                    shift
                fi
                ;;
            # Previously did `shift 2` on top of the trailing `shift`, which
            # discarded the option that followed the value.
            --log_rollout_n_per_step) LOG_ROLLOUT_N_LIMIT="$2"; shift ;;
            --enable_overlong_buffer) enable_overlong_buffer="$2"; shift ;;
            --overlong_buffer_len) overlong_buffer_len="$2"; shift ;;
            --overlong_penalty_factor) overlong_penalty_factor="$2"; shift ;;
            --use_drgrpo) use_drgrpo="$2"; shift ;;
            *) echo "Unknown parameter: $1"; exit 1 ;;
        esac
        # Nothing left after a value option dropped its name: the value was
        # omitted, so the variable was just set to the empty string.
        if [[ "$#" -eq 0 ]]; then
            echo "Warning: $opt was given without a value; using an empty string" >&2
            break
        fi
        shift
    done
}
parse_args "$@"

# Derive the settings that depend on the (possibly overridden) configuration.
#
# Globals read:    extract_type, max_prompt_length, max_response_length,
#                  enable_filter_groups, gen_ratio, train_prompt_bsz, use_drgrpo
# Globals written: TEMPLATE_NAME (exported), actor_ppo_max_token_len,
#                  infer_ppo_max_token_len, gen_ratio, gen_prompt_bsz,
#                  loss_agg_mode
# Outputs:         prints the selected extract_type to stdout.
derive_settings() {
    export TEMPLATE_NAME="${extract_type}" # used by custom_math.py
    echo "extract_type=$extract_type"

    # Token budgets: prompt plus response length.
    actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
    infer_ppo_max_token_len=$((max_prompt_length + max_response_length))

    # With group filtering disabled, gen_ratio is 1. The flag is matched
    # case-insensitively because the same value is forwarded to Hydra, which
    # also accepts "false"; an exact "False" comparison would leave gen_ratio
    # unchanged while filtering is off.
    case "$enable_filter_groups" in
        [Ff][Aa][Ll][Ss][Ee]) gen_ratio=1 ;;
    esac
    # Generation batch size is the training batch size scaled by gen_ratio.
    gen_prompt_bsz=$((train_prompt_bsz * gen_ratio))

    # With Dr.GRPO the token-level loss aggregation is not used.
    case "$use_drgrpo" in
        [Tt][Rr][Uu][Ee]) loss_agg_mode="drgrpo" ;;
    esac
}
derive_settings

# Log in to wandb only when a key was supplied, and pick the trainer loggers
# accordingly. The key is quoted so it always reaches `wandb login` as a
# single argument.
if [[ -n "$wandb_key" ]]; then
    wandb login --relogin "$wandb_key"
    logger="['console','wandb']"
else
    logger="['console']"
fi

# Directory that receives rollout logs. `mkdir -p` already succeeds when the
# directory exists, so no separate existence test is needed; the path is
# quoted so whitespace in SAVE_PATH does not split it into several directories.
# NOTE(review): when SAVE_PATH is empty this resolves to "/rollout" — confirm
# that CHECKPOINT_SAVE or --save_path is always provided.
ROLLOUT_SAVE_PATH="${SAVE_PATH}/rollout"
mkdir -p -- "$ROLLOUT_SAVE_PATH"

# Work out this node's IP and the head node's IP. The node whose RANK is 0 is
# its own master; every other node resolves MASTER_ADDR through the helper
# script. Expansions are quoted so WORK_DIR / MASTER_ADDR are passed as single
# arguments.
# NOTE(review): the first helper runs under `python`, the second under
# `python3` — confirm whether that difference is intentional.
curr_ip=$(python "$WORK_DIR/examples/get_host_ip.py")
if [[ "$RANK" == "0" ]]; then
    master_ip="$curr_ip"
else
    master_ip=$(python3 "$WORK_DIR/examples/get_domain_ip.py" "$MASTER_ADDR")
fi
echo "master_ip=$master_ip"
echo "curr_ip=$curr_ip"

# Algorithm
val_top_k=-1

# Performance Related Parameter

# The node whose own IP equals master_ip starts the Ray head, submits the
# training job and stops the cluster afterwards; every other node joins the
# cluster as a blocking worker.
if [[ "$master_ip" == "$curr_ip" ]]; then
    if [[ "$LOG_ROLLOUT_DETAIL" == "true" ]]; then
        echo "log_rollout_detail is activated. Now start the server"
        # The log goes to CHECKPOINT_SAVE as before; when that variable is
        # unset (e.g. only --save_path was given) fall back to SAVE_PATH so the
        # redirection does not target "/streamlit.log".
        streamlit run "$WORK_DIR/examples/sample_moniter/rl_logging_board.py" --server.port 6789 > "${CHECKPOINT_SAVE:-$SAVE_PATH}/streamlit.log" 2>&1 &
    fi
    echo "########### run ray start ###########"
    ray start --include-dashboard=True --head --max-worker-port 12800 --runtime-env-agent-port 20100 --dashboard-agent-grpc-port 20101 --dashboard-agent-listen-port 20102 --metrics-export-port 20103
    # Pause before querying the cluster; workers wait 20s and then connect.
    sleep 50s
    ray status
    echo "########### run ray end ###########"

    # Hydra overrides for recipe.dapo.src.main_dapo. Each entry is one quoted
    # word, so values such as the bracketed logger/val_files lists are never
    # word-split or glob-expanded by the shell.
    hydra_overrides=(
        # Data
        "data.train_files=${TRAIN_FILE}"
        "data.val_files=${TEST_FILE}"
        "data.prompt_key=${prompt_key}"
        "data.prompt_template_name=${prompt_template_name}"
        "data.truncation=left"
        "data.max_prompt_length=${max_prompt_length}"
        "data.max_response_length=${max_response_length}"
        "data.gen_batch_size=${gen_prompt_bsz}"
        "data.train_batch_size=${train_prompt_bsz}"
        # Algorithm / loss
        "actor_rollout_ref.rollout.n=${n_resp_per_prompt}"
        "algorithm.adv_estimator=${adv_estimator}"
        "algorithm.kl_ctrl.kl_coef=${kl_coef}"
        "actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef}"
        "actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low}"
        "actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high}"
        "actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode}"
        "actor_rollout_ref.actor.use_drgrpo=${use_drgrpo}"
        "algorithm.filter_groups.enable=${enable_filter_groups}"
        "algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches}"
        "algorithm.filter_groups.metric=${filter_groups_metric}"
        # Batching and token budgets
        "actor_rollout_ref.model.use_remove_padding=True"
        "actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz}"
        "actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz}"
        "actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz}"
        "actor_rollout_ref.actor.ppo_epochs=${ppo_epochs}"
        "actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len}"
        "actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}"
        "actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}"
        # Model and optimizer
        "actor_rollout_ref.model.path=${MODEL_PATH}"
        "+actor_rollout_ref.model.override_config.attention_dropout=0."
        "+actor_rollout_ref.model.override_config.embd_pdrop=0."
        "+actor_rollout_ref.model.override_config.resid_pdrop=0."
        "actor_rollout_ref.model.enable_gradient_checkpointing=True"
        "actor_rollout_ref.actor.optim.lr=1e-6"
        "actor_rollout_ref.actor.optim.lr_warmup_steps=10"
        "actor_rollout_ref.actor.optim.weight_decay=0.1"
        "actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}"
        "actor_rollout_ref.actor.fsdp_config.param_offload=${offload}"
        "actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload}"
        "actor_rollout_ref.actor.entropy_coeff=${entropy_coeff}"
        "actor_rollout_ref.actor.grad_clip=1.0"
        "actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size}"
        # Rollout
        "actor_rollout_ref.rollout.gpu_memory_utilization=0.80"
        "actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp}"
        "actor_rollout_ref.rollout.enable_chunked_prefill=True"
        "actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length))"
        "actor_rollout_ref.rollout.val_kwargs.top_k=${val_top_k}"
        "actor_rollout_ref.rollout.val_kwargs.top_p=1.0"
        "actor_rollout_ref.rollout.val_kwargs.temperature=1.0"
        "actor_rollout_ref.rollout.val_kwargs.n=1"
        "actor_rollout_ref.rollout.val_kwargs.do_sample=True"
        # Reference model
        "actor_rollout_ref.ref.fsdp_config.param_offload=${offload}"
        "actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size}"
        "actor_rollout_ref.actor.fsdp_config.fsdp_size=-1"
        # Reward
        "reward_model.reward_manager=custom"
        "reward_model.overlong_buffer.enable=${enable_overlong_buffer}"
        "reward_model.overlong_buffer.len=${overlong_buffer_len}"
        "reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor}"
        # Trainer
        "trainer.logger=${logger}"
        "trainer.project_name=${project_name}"
        "trainer.experiment_name=${exp_name}"
        "trainer.n_gpus_per_node=8"
        "trainer.nnodes=${NNODES}"
        "trainer.val_before_train=False"
        "trainer.test_freq=${test_freq}"
        "trainer.save_freq=${save_freq}"
        "trainer.remove_previous_ckpt_in_save=True"
        "trainer.total_epochs=${total_epochs}"
        "trainer.log_rollout_detail.activate=${LOG_ROLLOUT_DETAIL}"
        "trainer.log_rollout_detail.n_limit_per_step=${LOG_ROLLOUT_N_LIMIT}"
        "trainer.log_rollout_detail.rollout_save_path=${ROLLOUT_SAVE_PATH}"
        "trainer.default_local_dir=${SAVE_PATH}"
        "trainer.resume_mode=${RESUME_PATH}"
    )

    # The JSON is built inside one double-quoted word so a WORK_DIR containing
    # whitespace cannot split the option.
    ray job submit --address="http://127.0.0.1:8265" \
        --runtime-env-json="{\"working_dir\": \"${WORK_DIR}\"}" \
        -- python3 -m recipe.dapo.src.main_dapo \
        "${hydra_overrides[@]}"
    echo 'job done, now shutdown ray cluster'
    ray stop --force
else
    # Wait before connecting to the head node's Ray port.
    sleep 20s
    ray start --address "$master_ip:6379" --block --runtime-env-agent-port 20100 --dashboard-agent-grpc-port 20101 --dashboard-agent-listen-port 20102 --metrics-export-port 20103
fi

# Hold the process open for an hour after the branch above returns.
# NOTE(review): presumably keeps the container around for inspection — confirm.
sleep 3600