#!/usr/bin/env bash
# Launches `python3 -m verl.trainer.main` with `trainer.val_only=true` and
# writes the command's stdout to VAL_OUTPUT_FILE.
#
# The path values below are placeholders; edit them before running.

# Print each command (after expansion) before it is executed.
set -x

MODEL_PATH="path/to/model"
# NOTE(review): TRAIN_DATA and VAL_DATA point at the same parquet file here.
TRAIN_DATA="path/to/benchmark_chartbench.parquet"
VAL_DATA="path/to/benchmark_chartbench.parquet"
# The prompt template and reward function are meant to be swapped as a pair;
# the original note lists the available variants as: cot, code, decision.
FORMAT_PROMPT="./examples/format_prompt/decision.jinja"
REWARD_FUNCTION="./examples/reward_function/decision.py:compute_score"
VAL_OUTPUT_FILE="path/to/log/name.log"
NUM_GPUS=8

# The output redirection below fails (and the trainer never starts) when the
# log directory is missing, so create it up front.
mkdir -p -- "$(dirname -- "${VAL_OUTPUT_FILE}")" || exit 1

# Every override is quoted so that paths containing whitespace or glob
# characters reach the trainer as a single, unmodified argument.
#
# trainer.logger: the unquoted form ['console'] has its inner quotes removed
# by the shell and is also a glob bracket expression. Quoting the whole word
# passes the same text the trainer has always received ([console]) while
# disabling filename expansion.
#
# Only stdout is redirected; stderr (including the `set -x` trace) is not.
python3 -m verl.trainer.main \
    config=examples/config.yaml \
    "worker.actor.model.model_path=${MODEL_PATH}" \
    worker.rollout.gpu_memory_utilization=0.7 \
    worker.rollout.seed=42 \
    "data.train_files=${TRAIN_DATA}" \
    "data.val_files=${VAL_DATA}" \
    data.prompt_key=prompt \
    data.answer_key=ground_truth \
    data.image_key=images \
    "data.format_prompt=${FORMAT_PROMPT}" \
    "worker.reward.reward_function=${REWARD_FUNCTION}" \
    worker.rollout.tensor_parallel_size=1 \
    trainer.val_only=true \
    "trainer.n_gpus_per_node=${NUM_GPUS}" \
    data.val_batch_size=512 \
    data.rollout_batch_size=256 \
    data.max_prompt_length=5120 \
    data.max_response_length=3072 \
    trainer.val_generations_to_log=1000 \
    'trainer.logger=[console]' > "${VAL_OUTPUT_FILE}"