#!/usr/bin/env bash
# Launch train_tdm_demo_geometric_alignment_logval.py through `accelerate launch`
# as a single process on GPU 0 with fp16 mixed precision, reporting to wandb.
#
# Usage:
#   ./this_script.sh               # default batch size 32
#   bsz=16 ./this_script.sh        # override the train batch size
#
# Environment (optional):
#   bsz - value passed as --train_batch_size (default: 32). It stays exported,
#         as in the original script, so child processes can still read it.
set -euo pipefail

die() { printf '%s\n' "$*" >&2; exit 1; }

export bsz="${bsz:-32}"

# Fail early with a clear message instead of a bare "command not found".
command -v accelerate >/dev/null 2>&1 || die "error: 'accelerate' not found on PATH"

# NOTE(review): output_dir says "logeval" but the script name says "logval";
# confirm the directory name is intentional.
accelerate launch \
  --main_process_port 29503 \
  --num_processes=1 \
  --gpu_ids 0 \
  --mixed_precision=fp16 \
  train_tdm_demo_geometric_alignment_logval.py \
  --train_batch_size="$bsz" \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --max_train_steps=10001 \
  --learning_rate=2e-05 \
  --max_grad_norm=1 \
  --cfg 4.5 \
  --total_steps 900 \
  --lr_scheduler cosine_with_restarts \
  --lr_warmup_steps 50 \
  --use_huber \
  --use_separate \
  --report_to wandb \
  --checkpoints_total_limit 10 \
  --lambda_nsp 1.0 \
  --nsp_epsilon 1e-2 \
  --checkpointing_steps 250 \
  --validation_epochs 50 \
  --validation_file "coco_50.csv" \
  --output_dir "TDM-pixart-geometric-alignment-logeval-eps-1e-2"