# * Set parameters
# Each value below is handed to train.py by the launch command further down;
# the flag it feeds is noted next to the assignment.

# policy
P_STEP=1                 # --p_step
P_BATCH_SIZE=8           # --p_batch_size

# evaluation
NUM_FID_IMAGES=5000      # --num_fid_images

# rollout
NUM_INFERENCE_STEPS=50   # --num_inference_steps
NUM_BEST_SAMPLES=16      # --num_best_samples
NUM_SAMPLES=50           # --num_samples
NUM_GROUPS=3             # --num_groups
G_BATCH_SIZE=8           # --g_batch_size

# Derived value for --buffer_size: the product of best samples, inference
# steps and samples, integer-divided (truncating) by NUM_PROCESSES.
# NUM_PROCESSES is not assigned in this section and must already be set here.
# NOTE(review): presumably this is the per-process share of the buffer - confirm.
BUFFER_SIZE=$(( (NUM_BEST_SAMPLES * NUM_INFERENCE_STEPS * NUM_SAMPLES) / NUM_PROCESSES ))

# run
# Launch train.py through `accelerate launch` in multi-GPU mode.
#
# The variables checked below are either not assigned in this section
# (they must come from earlier in the script or from the environment) or,
# for BUFFER_SIZE, derived from one of them. Abort with a clear message if
# any is unset or empty, instead of handing a malformed command line to
# accelerate/train.py.
: "${NUM_PROCESSES:?NUM_PROCESSES must be set}"
: "${NNODES:?NNODES must be set}"
: "${NODE_RANK:?NODE_RANK must be set}"
: "${MASTER_ADDR:?MASTER_ADDR must be set}"
: "${MASTER_PORT:?MASTER_PORT must be set}"
: "${LOG_DIR:?LOG_DIR must be set}"
: "${GT_FID_STATS:?GT_FID_STATS must be set}"
: "${MODEL_OUTPUT_DIR:?MODEL_OUTPUT_DIR must be set}"
: "${BUFFER_SIZE:?BUFFER_SIZE must be set (it is derived from NUM_PROCESSES)}"

# Options before train.py configure the launcher; options after it are
# passed to train.py. Every expansion is quoted so that a value containing
# whitespace stays a single argument. --g_batch_size is passed exactly once.
accelerate launch \
  --multi_gpu \
  --num_processes "${NUM_PROCESSES}" \
  --num_machines "${NNODES}" \
  --machine_rank "${NODE_RANK}" \
  --main_process_ip "${MASTER_ADDR}" \
  --main_process_port "${MASTER_PORT}" \
  train.py \
  --multi_gpu 1 \
  --prompt_path ./dataset/imagenet1k/data_meta.json \
  --image_size 256 \
  --model_type sit \
  --pretrained_model_name_or_path SiT-XL-2-256.pt \
  --output_dir "${LOG_DIR}" \
  --report_to wandb \
  --max_train_steps 50000 \
  --seed 3407 \
  --gt_fid_stats "${GT_FID_STATS}" \
  --num_fid_images "${NUM_FID_IMAGES}" \
  --g_batch_size "${G_BATCH_SIZE}" \
  --num_samples "${NUM_SAMPLES}" \
  --num_groups "${NUM_GROUPS}" \
  --num_inference_steps "${NUM_INFERENCE_STEPS}" \
  --buffer_size "${BUFFER_SIZE}" \
  --p_step "${P_STEP}" \
  --p_batch_size "${P_BATCH_SIZE}" \
  --flat_rollout \
  --grpo_flag 2 \
  --global_flag 1 \
  --num_best_samples "${NUM_BEST_SAMPLES}" \
  --save_interval 20 \
  --refill_interval 10 \
  --checkpointing_steps 40 \
  --model_output_dir "${MODEL_OUTPUT_DIR}" \
  --kl_weight 0
