#!/usr/bin/env bash
# Default configuration for the training launch below.
# Every value here is a plain shell variable; a subset of them can be
# overridden from the command line by the `--<lowercase_name> <value>` flags
# handled in the argument-parsing section that follows.

# Environment exported to the trainer process.
export PYTHONUNBUFFERED=1
export HYDRA_FULL_ERROR=1
export VLLM_ATTENTION_BACKEND=XFORMERS

# Data and optimisation.
PROMPT_KEY=question
TRAIN_BATCH_SIZE=256
PPO_MINI_BATCH_SIZE=256
MAX_PROMPT_LENGTH=512
MAX_RESPONSE_LENGTH=8192
ENTROPY_COEF=0.0
LEARNING_RATE=1e-6

# Rollout filtering and format reward.
FILTER_OVERLONG=False
FORMAT_REWARD=0.1

# Reward modeling.
REWARD_MANAGER=searchrl_parallel
REWARD_TYPE=em_score
VALIDATE_REWARD_TYPE=em_score
REWARD_FORMAT_TYPE=validate_format
QUERY_REPEAT_PENALTY_ENABLE=False
QUERY_REPEAT_PENALTY_FACTOR=0.1
enable_overlong_buffer=False
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=0.0

# Prompting, model and multi-turn rollout.
APPLY_CHAT=True
VAL_BEFORE_TRAIN=True
VAL_ONLY=False
PROMPT_TEMPLATE_NAME=multiple_answer_instruct
ACTOR_MODEL_PATH=Qwen2.5-7B-Instruct
TOOL_CALL_TAG=tool_call
TOOL_RESPONSE_TAG=tool_response
ROLLOUT_N=5
MAX_TURNS=20
MAX_TOKENS_PER_TURN=512

# Run identity, cluster shape, checkpointing and logging.
ROLLOUT_NAME=vllm_with_search
SEARCH_URL=http://.....
PROJECT_NAME=SearchRL
EXPERIMENT_NAME=example_case
NNODES=1
N_GPUS_PER_NODE=4
MAX_CKPT_TO_KEEP=1
SAVE_FREQ=10
TEST_FREQ=10
SAVE_INTERVAL=20
TOTAL_EPOCHS=4
# Keep a key that the caller already exported. A plain `WANDB_API_KEY=None`
# would overwrite it, and because the variable stays exported the trainer
# process would then see the literal string "None" as its API key.
# "None" remains the sentinel for "no key supplied".
WANDB_API_KEY=${WANDB_API_KEY:-None}
WANDB_RESUME=False
WANDB_RUNID=no_id
# Placeholder paths: replace "YourWorkSpace" and the "**" parts before use.
SAVE_PATH="YourWorkSpace/**/${EXPERIMENT_NAME}"
TRAIN_FILES='YourWorkSpace/train_**.parquet'
TEST_FILES='YourWorkSpace/dev_**.parquet'


#######################################
# Parse `--option value` pairs and store each value in the matching
# upper-case configuration variable (e.g. `--rollout_n 8` sets ROLLOUT_N=8).
# Globals:   the configuration variables named in the case arms (written);
#            EXPERIMENT_NAME and SAVE_PATH are also read.
# Arguments: the script's command-line arguments.
# Outputs:   an error message on stderr for an unknown option or an option
#            that is missing its value.
# Returns:   0 on success; exits the script with status 1 on a usage error.
#######################################
parse_args() {
    local target
    local save_path_given=false
    # Values in effect before parsing, used below to re-derive SAVE_PATH.
    local initial_experiment=${EXPERIMENT_NAME-}
    local initial_save_path=${SAVE_PATH-}

    while [[ $# -gt 0 ]]; do
        case "$1" in
            --prompt_key)           target=PROMPT_KEY ;;
            --train_batch_size)     target=TRAIN_BATCH_SIZE ;;
            --ppo_mini_batch_size)  target=PPO_MINI_BATCH_SIZE ;;
            --max_prompt_length)    target=MAX_PROMPT_LENGTH ;;
            --max_response_length)  target=MAX_RESPONSE_LENGTH ;;
            --apply_chat)           target=APPLY_CHAT ;;
            --prompt_template_name) target=PROMPT_TEMPLATE_NAME ;;
            --actor_model_path)     target=ACTOR_MODEL_PATH ;;
            --reward_manager)       target=REWARD_MANAGER ;;
            --rollout_n)            target=ROLLOUT_N ;;
            --search_url)           target=SEARCH_URL ;;
            --project_name)         target=PROJECT_NAME ;;
            --experiment_name)      target=EXPERIMENT_NAME ;;
            --nnodes)               target=NNODES ;;
            --n_gpus_per_node)      target=N_GPUS_PER_NODE ;;
            --save_freq)            target=SAVE_FREQ ;;
            --test_freq)            target=TEST_FREQ ;;
            --total_epochs)         target=TOTAL_EPOCHS ;;
            --wandb_api_key)        target=WANDB_API_KEY ;;
            --save_path)            target=SAVE_PATH ;;
            --train_files)          target=TRAIN_FILES ;;
            --test_files)           target=TEST_FILES ;;
            *)
                echo "unknown argument '$1'" >&2
                exit 1 ;;
        esac

        # `shift 2` with a single remaining argument fails without shifting,
        # which would spin this loop forever; reject the dangling option.
        if [[ $# -lt 2 ]]; then
            echo "missing value for argument '$1'" >&2
            exit 1
        fi

        # Assign to the variable whose name is held in $target.
        printf -v "$target" '%s' "$2"
        if [[ "$target" == SAVE_PATH ]]; then
            save_path_given=true
        fi
        shift 2
    done

    # SAVE_PATH's default is computed from EXPERIMENT_NAME before parsing.
    # If only the experiment name was overridden, re-derive the path so the
    # run does not land in the directory of the default experiment.
    # An explicit --save_path, or a SAVE_PATH that does not end in the
    # initial experiment name, is left untouched.
    if [[ "$save_path_given" == false && -n "$initial_experiment" ]]; then
        if [[ "${EXPERIMENT_NAME-}" != "$initial_experiment" && "$initial_save_path" == */"$initial_experiment" ]]; then
            SAVE_PATH="${initial_save_path%"$initial_experiment"}${EXPERIMENT_NAME}"
        fi
    fi
}

parse_args "$@"

# Log in to Weights & Biases when an API key was supplied; the string "None"
# is the sentinel for "no key". An empty key is treated the same way, because
# `wandb login` without a key argument would wait for interactive input.
# NOTE(review): the key is passed on the command line and is therefore visible
# in the process list while `wandb login` runs.
if [[ -n "$WANDB_API_KEY" && "$WANDB_API_KEY" != "None" ]]; then
    # Quoted so the key is always passed as exactly one argument.
    wandb login --relogin "$WANDB_API_KEY"
    # Point wandb's local files at the run's save directory.
    export WANDB_DIR="${SAVE_PATH}"
fi

# Create the run directory and its rollout sub-directory before launching.
# `mkdir -p` already succeeds when the directory exists, so no `-d` pre-check
# is needed. The operands are quoted (and preceded by `--`) so that paths
# containing spaces, glob characters or a leading dash are created verbatim.
ROLLOUT_SAVE_PATH="${SAVE_PATH}/rollout"
if ! mkdir -p -- "$SAVE_PATH" "$ROLLOUT_SAVE_PATH"; then
    # Stop here: the launch below writes its log and outputs under SAVE_PATH.
    echo "failed to create output directories under '$SAVE_PATH'" >&2
    exit 1
fi

#######################################
# Build the list of Hydra overrides passed to verl.trainer.main_ppo.
# Globals:   reads the configuration variables referenced below;
#            writes the array TRAINER_OVERRIDES.
# Arguments: none.
# Returns:   0.
#######################################
build_trainer_overrides() {
    # The actor, rollout and reference token budgets all use the same value:
    # prompt length plus response length.
    local max_token_len=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))

    # Each element is exactly one argument, so values containing spaces or
    # glob characters (paths, URLs) reach the trainer unchanged.
    TRAINER_OVERRIDES=(
        "algorithm.adv_estimator=grpo"
        "algorithm.kl_ctrl.kl_coef=0.001"
        "algorithm.use_kl_in_reward=False"
        "data.train_files=${TRAIN_FILES}"
        "data.val_files=${TEST_FILES}"
        "data.prompt_key=${PROMPT_KEY}"
        "data.train_batch_size=${TRAIN_BATCH_SIZE}"
        "data.max_prompt_length=${MAX_PROMPT_LENGTH}"
        "data.max_response_length=${MAX_RESPONSE_LENGTH}"
        "data.apply_chat=${APPLY_CHAT}"
        "data.prompt_template_name=${PROMPT_TEMPLATE_NAME}"
        "actor_rollout_ref.model.path=${ACTOR_MODEL_PATH}"
        "actor_rollout_ref.model.enable_gradient_checkpointing=True"
        "actor_rollout_ref.model.use_remove_padding=True"
        # Use the configured learning rate; falls back to the value that was
        # previously hard-coded here.
        "actor_rollout_ref.actor.optim.lr=${LEARNING_RATE:-1e-6}"
        "actor_rollout_ref.actor.ppo_mini_batch_size=${PPO_MINI_BATCH_SIZE}"
        "actor_rollout_ref.actor.use_dynamic_bsz=True"
        "actor_rollout_ref.actor.ulysses_sequence_parallel_size=4"
        "actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token_len}"
        "actor_rollout_ref.actor.use_kl_loss=False"
        "actor_rollout_ref.actor.kl_loss_coef=0.0"
        "actor_rollout_ref.actor.kl_loss_type=low_var_kl"
        "actor_rollout_ref.actor.fsdp_config.param_offload=False"
        "actor_rollout_ref.actor.fsdp_config.optimizer_offload=True"
        "actor_rollout_ref.actor.clip_ratio_high=0.28"
        "actor_rollout_ref.actor.entropy_coeff=${ENTROPY_COEF}"
        "actor_rollout_ref.actor.checkpoint.save_interval=${SAVE_INTERVAL}"
        "actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token_len}"
        "actor_rollout_ref.rollout.tensor_model_parallel_size=2"
        "actor_rollout_ref.rollout.name=${ROLLOUT_NAME}"
        "actor_rollout_ref.rollout.gpu_memory_utilization=0.6"
        "actor_rollout_ref.rollout.max_turns=${MAX_TURNS}"
        "actor_rollout_ref.rollout.max_tokens_per_turn=${MAX_TOKENS_PER_TURN}"
        "actor_rollout_ref.rollout.n=${ROLLOUT_N}"
        "actor_rollout_ref.rollout.search_url=${SEARCH_URL}"
        "actor_rollout_ref.rollout.max_num_batched_tokens=65536"
        "actor_rollout_ref.rollout.filter_overlong=${FILTER_OVERLONG}"
        "actor_rollout_ref.rollout.enable_chunked_prefill=False"
        "actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token_len}"
        "actor_rollout_ref.ref.fsdp_config.param_offload=True"
        "reward_model.reward_manager=${REWARD_MANAGER}"
        "reward_model.reward_type=${REWARD_TYPE}"
        "reward_model.validate_reward_type=${VALIDATE_REWARD_TYPE}"
        "reward_model.format.type=${REWARD_FORMAT_TYPE}"
        "reward_model.query_repeat_penalty.enable=${QUERY_REPEAT_PENALTY_ENABLE}"
        "reward_model.query_repeat_penalty.penalty_factor=${QUERY_REPEAT_PENALTY_FACTOR}"
        "reward_model.format_reward=${FORMAT_REWARD}"
        "reward_model.overlong_buffer.enable=${enable_overlong_buffer}"
        "reward_model.overlong_buffer.log=${enable_overlong_buffer}"
        "reward_model.overlong_buffer.len=${overlong_buffer_len}"
        "reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor}"
        "trainer.critic_warmup=0"
        "trainer.logger=[console, wandb]"
        "trainer.project_name=${PROJECT_NAME}"
        "trainer.experiment_name=${EXPERIMENT_NAME}"
        "trainer.n_gpus_per_node=${N_GPUS_PER_NODE}"
        "trainer.nnodes=${NNODES}"
        "trainer.save_freq=${SAVE_FREQ}"
        "trainer.max_actor_ckpt_to_keep=${MAX_CKPT_TO_KEEP}"
        "trainer.max_critic_ckpt_to_keep=${MAX_CKPT_TO_KEEP}"
        "trainer.test_freq=${TEST_FREQ}"
        "trainer.total_epochs=${TOTAL_EPOCHS}"
        "trainer.default_hdfs_dir=null"
        "trainer.default_local_dir=${SAVE_PATH}"
        "trainer.val_before_train=${VAL_BEFORE_TRAIN}"
        "trainer.val_only=${VAL_ONLY}"
        "trainer.resume_mode=auto"
        "trainer.rollout_save_path=${ROLLOUT_SAVE_PATH}"
        "wandb.resume=${WANDB_RESUME}"
        "wandb.run_id=${WANDB_RUNID}"
        "hydra.run.dir=${SAVE_PATH}/outputs"
    )

    # DATALOADER_WORKERS has no default in this script. Only forward it when
    # the caller provided one (e.g. via the environment); otherwise an empty
    # `data.dataloader_workers=` override would be sent to the trainer.
    if [[ -n "${DATALOADER_WORKERS:-}" ]]; then
        TRAINER_OVERRIDES+=("data.dataloader_workers=${DATALOADER_WORKERS}")
    fi
}

build_trainer_overrides

# Without pipefail the pipeline's status is tee's, so a crashed trainer would
# still make this script report success.
set -o pipefail
python3 -m verl.trainer.main_ppo "${TRAINER_OVERRIDES[@]}" | tee "${SAVE_PATH}/run.log"