train.py \
    --data_path '/path/to/reasoning_maze/interleave_sft_data.json' \
    --decoder_path '/path/to/sketch_decoder.ckpt' \
    --image_dir '/path/to/reasoning_maze' \
    --model_path '/path/to/stage1/ckpt' \
    --learning_rate 1e-4 \
    --max_grad_norm 1.0 \
    --num_train_epochs 5 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 2 \
    --per_device_eval_batch_size 1 \
    --weight_decay 0.01 \
    --ds_config './ds_cfg.json' \
    --save_steps 200 \
    --eval_steps 100 \
    --logging_steps 100 \
    --resume_from_checkpoint True \
    --output_dir '/path/to/output_dir/' \
    --wandb_project project_name \
    --validation_split 0.005 \
    --augment \
    --freeze-backbone \
    --text_loss_weight 0.0 \
    --sum-loss \
    --loss_type l1
# The command above runs train.py with one flag per line.
# - Every /path/to/... value (and project_name) is a placeholder; replace
#   it before running.
# - `train.py` is a bare command name, so the shell looks it up on $PATH.
#   NOTE(review): if this is run directly, not appended to a launcher such
#   as deepspeed/torchrun, it likely needs `python train.py` -- TODO confirm.
# - Flag spelling is mixed (`--freeze-backbone`, `--sum-loss` vs. underscores
#   elsewhere). Keep each exactly as train.py's argument parser defines it.
# - `--resume_from_checkpoint True` passes the literal string "True".
#   NOTE(review): confirm train.py treats it as a boolean, not a path.
# - Presumably the effective train batch per device is 8 x 2 = 16
#   (per_device_train_batch_size x gradient_accumulation_steps) -- verify.