#!/usr/bin/env bash
#
# Launches `python3 -m verl.trainer.main` with examples/config.yaml plus the
# command-line overrides listed below, and appends the run's output to a
# timestamped log file.
#
# Edit the placeholder paths in the "Settings" section before running.
# Takes no arguments. Exits non-zero if the trainer (or the logging pipeline)
# fails; "Finish!" is printed only after a successful run.

# -x: trace each command as it runs (kept from the original script).
set -x
# -e / pipefail: without these, the exit status of the pipeline below would be
# that of `tee`, so a crashed training run would still report success.
# -u: an accidentally misspelled variable aborts instead of expanding to "".
set -euo pipefail

# --- Settings ----------------------------------------------------------------

# Exported so that the trainer process can see it in its environment.
export TENSORBOARD_DIR="path/to/tensorboard_logs"

MODEL_PATH="path/to/Qwen2.5-VL-7B-Instruct"
TRAIN_DATA="path/to/cat_train.parquet"
VAL_DATA="path/to/benchmark.parquet"
EXPERIMENT_NAME="experiment_name"
FORMAT_PROMPT="./examples/format_prompt/decision.jinja"
REWARD_FUNCTION="./examples/reward_function/decision.py:compute_score"
NUM_GPUS=8

# One log file per launch, named by minute-resolution timestamp.
LOG_TIME=$(date +"%Y%m%d%H%M")
LOG_FILE="path/to/log/log_name_${LOG_TIME}.txt"

# `tee -a` cannot create missing parent directories, so make sure it exists.
mkdir -p -- "$(dirname -- "${LOG_FILE}")"

# --- Trainer arguments -------------------------------------------------------

# Built as an array so that every override is passed as exactly one word, even
# if a path contains spaces, and so that no line-continuation backslashes are
# needed. The order matches the original command line.
train_args=(
  "config=examples/config.yaml"
  "data.train_files=${TRAIN_DATA}"
  "data.val_files=${VAL_DATA}"
  "data.rollout_batch_size=256"
  "worker.actor.global_batch_size=64"
  "data.format_prompt=${FORMAT_PROMPT}"
  "worker.actor.model.model_path=${MODEL_PATH}"
  "trainer.experiment_name=${EXPERIMENT_NAME}"
  "trainer.n_gpus_per_node=${NUM_GPUS}"
  "trainer.total_episodes=5"
  "algorithm.adv_estimator=grpo"
  "data.prompt_key=prompt"
  "data.answer_key=ground_truth"
  "data.image_key=images"
  "worker.reward.reward_function=${REWARD_FUNCTION}"
  "worker.actor.optim.lr=1.0e-6"
  "worker.actor.offload.offload_params=false"
  "worker.actor.offload.offload_optimizer=false"
  "data.max_prompt_length=5120"
  "data.max_response_length=3072"
  "trainer.save_freq=5"
  "worker.rollout.seed=42"
  "worker.actor.fsdp.torch_dtype=bf16"
  "worker.actor.optim.strategy=adamw_bf16"
  # The list literal is passed through to the trainer as a single word.
  "trainer.logger=['console','tensorboard']"
  "data.val_batch_size=512"
  "worker.rollout.tensor_parallel_size=1"
  "trainer.val_before_train=true"
  "worker.actor.model.freeze_vision_tower=true"
  "worker.rollout.n=5"
  "algorithm.kl_coef=0.0"
  "algorithm.use_kl_loss=false"
)

# --- Run ---------------------------------------------------------------------

# stderr is merged into stdout so that error output (e.g. Python tracebacks)
# ends up in the log file as well as on the terminal.
python3 -m verl.trainer.main "${train_args[@]}" 2>&1 \
  | tee -a "${LOG_FILE}"

echo "Finish!"
