#!/usr/bin/env bash
#
# Two-stage GRPO training for PilotRAG on Qwen2.5-3B-Instruct (verl).
#
#   Stage 1 (steps 1..SWITCH_STEP):              EM reward only.
#   Stage 2 (steps SWITCH_STEP+1..TOTAL_STEPS):  EM reward + efficiency reward.
#
# SWITCH_STEP is handed to verl.trainer.main_ppo_two_stages as the
# `+switch_step` override; the stage switch itself happens in that module.
#
# Prerequisites (taken from the settings below):
#   - 4 GPUs (CUDA_VISIBLE_DEVICES=0,1,2,3, trainer.n_gpus_per_node=4)
#   - ${TRAIN_DATA_DIR}/train.parquet and ${DEV_DATA_DIR}/dev.parquet
#   - a retriever service reachable at http://localhost:8001/search
#
# All stdout/stderr of the trainer is appended to "${EXPERIMENT_NAME}.log".

# pipefail: without it the pipeline's status is that of `tee`, so a crashed
# training run would still make this script exit 0.
set -euo pipefail
# set -x   # uncomment to trace every command

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export CUDA_VISIBLE_DEVICES=0,1,2,3

export TRAIN_DATA_DIR='RL_dataset'
export DEV_DATA_DIR='RL_dataset'
export BASE_MODEL='/models/huggingface.co/Qwen/Qwen2.5-3B-Instruct'

export EXPERIMENT_NAME=PilotRAG-grpo-qwen2.5-3b-instruct-2stages

export VLLM_ATTENTION_BACKEND=XFORMERS

# Two-stage training configuration.
readonly SWITCH_STEP=20   # last step of stage 1 (EM reward only)
readonly TOTAL_STEPS=40   # stage 2 runs from SWITCH_STEP+1 to here
readonly SEED=42

# Overrides are collected in an array so that each one reaches the trainer as
# exactly one, properly quoted argument. This matters for values containing
# shell metacharacters, e.g. the list literal in trainer.logger, which would
# otherwise be treated as a glob pattern.
train_args=(
  "seed=${SEED}"

  # --- data ---
  "data.train_files=${TRAIN_DATA_DIR}/train.parquet"
  "data.val_files=${DEV_DATA_DIR}/dev.parquet"
  "data.train_data_num=null"
  "data.val_data_num=null"
  "data.train_batch_size=256"
  "data.val_batch_size=32"
  "data.max_prompt_length=4096"
  "data.max_response_length=500"
  "data.max_start_length=2048"
  "data.max_obs_length=500"
  "data.shuffle_train_dataloader=True"

  # --- algorithm ---
  "algorithm.adv_estimator=grpo"
  "algorithm.no_think_rl=false"

  # --- model ---
  "actor_rollout_ref.model.path=${BASE_MODEL}"
  "actor_rollout_ref.model.enable_gradient_checkpointing=true"
  "actor_rollout_ref.model.use_remove_padding=True"

  # --- actor ---
  "actor_rollout_ref.actor.optim.lr=1e-6"
  # NOTE(review): a ratio of 1 presumably spreads warmup over the whole run;
  # confirm this is intended.
  "actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=1"
  "actor_rollout_ref.actor.use_kl_loss=true"
  "actor_rollout_ref.actor.kl_loss_coef=0.001"
  "actor_rollout_ref.actor.kl_loss_type=low_var_kl"
  "actor_rollout_ref.actor.ppo_mini_batch_size=128"
  "actor_rollout_ref.actor.ppo_micro_batch_size=32"
  "actor_rollout_ref.actor.fsdp_config.param_offload=true"
  "actor_rollout_ref.actor.fsdp_config.grad_offload=true"
  "actor_rollout_ref.actor.fsdp_config.optimizer_offload=true"
  "actor_rollout_ref.actor.state_masking=true"

  # --- rollout (vLLM) ---
  "actor_rollout_ref.rollout.log_prob_micro_batch_size=128"
  "actor_rollout_ref.rollout.tensor_model_parallel_size=1"
  "actor_rollout_ref.rollout.name=vllm"
  "actor_rollout_ref.rollout.gpu_memory_utilization=0.6"
  "actor_rollout_ref.rollout.n_agent=5"
  "actor_rollout_ref.rollout.temperature=1"

  # --- reference policy ---
  "actor_rollout_ref.ref.log_prob_micro_batch_size=128"
  "actor_rollout_ref.ref.fsdp_config.param_offload=True"

  # --- trainer ---
  "trainer.logger=['console']"
  "+trainer.val_only=false"
  "+trainer.val_before_train=false"
  "trainer.default_hdfs_dir=null"
  "trainer.n_gpus_per_node=4"
  "trainer.nnodes=1"
  "trainer.save_freq=10"
  "trainer.test_freq=10"
  "trainer.experiment_name=${EXPERIMENT_NAME}"
  "trainer.total_epochs=2"
  "trainer.total_training_steps=${TOTAL_STEPS}"
  "trainer.resume_from_steps=0"
  "trainer.default_local_dir=/verl_checkpoints/${EXPERIMENT_NAME}"

  # --- agent loop / retrieval ---
  "max_turns=4"
  "retriever.url=http://localhost:8001/search"
  "retriever.topk=3"

  # --- two-stage switch ---
  "+switch_step=${SWITCH_STEP}"
)

PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo_two_stages "${train_args[@]}" \
  2>&1 | tee -a "${EXPERIMENT_NAME}.log"