#!/usr/bin/env bash
#
# Launch train_consistent.py through `accelerate launch` on GPUs 0-3.
#
# Usage:
#   MODEL_DIR=/path/or/model-id bash <this script>
#
# MODEL_DIR has no default in this file: either fill it in below or export it
# before running. The script aborts early when it is empty, because an empty
# value would otherwise disappear from the command line and shift every
# following argument.
set -euo pipefail

# Optimizer/schedule settings. They are forwarded to the trainer and are also
# embedded in the output directory name.
LR=1e-6
EPOCH=1

# Forwarded verbatim as --crop_ratio; its exact meaning is defined by
# train_consistent.py (not visible from this script).
CROP=0.3

DATA_DIR=./datasets/uground_21k_transformed
# Value for --model_name_or_path. Left blank on purpose in the file; falls
# back to the MODEL_DIR environment variable when one is provided.
MODEL_DIR=${MODEL_DIR:-}

if [[ -z "${MODEL_DIR}" ]]; then
  printf 'error: MODEL_DIR is empty; set it in this script or export MODEL_DIR before running.\n' >&2
  exit 1
fi

# NOTE(review): the "batch32" label is a literal, presumably 4 GPUs x 1 sample
# per device x 8 accumulation steps. It is not derived automatically, so update
# it by hand if any of those values change.
OUTPUT_DIR="output/epoch${EPOCH}-LR${LR}-CROP${CROP}-batch32-sft"

# Arguments are kept in an array so every value stays a single word even if it
# is empty or contains whitespace.
launch_args=(
  --main_process_port 29500
  --config_file configs/zero3.yaml
  train_consistent.py
  --dataset_name "${DATA_DIR}"
  --model_name_or_path "${MODEL_DIR}"
  --per_device_train_batch_size 1
  --per_device_eval_batch_size 1
  --gradient_accumulation_steps 8
  --eval_strategy steps
  --eval_steps 200
  --bf16
  --output_dir "${OUTPUT_DIR}"
  --torch_dtype bfloat16
  --save_steps 200
  --gradient_checkpointing true
  --save_only_model true
  --save_total_limit 3
  --logging_steps 1
  --warmup_steps 100
  --report_to wandb
  --run_name SFT-cropped-continue-sft
  --dataset_train_split train
  --dataset_test_split test
  --num_train_epochs "${EPOCH}"
  --learning_rate "${LR}"
  --dataloader_num_workers 4
  --dataloader_prefetch_factor 2
  --dataloader_pin_memory true
  --crop_ratio "${CROP}"
)

# The GPU selection applies to this one command only.
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch "${launch_args[@]}"