# Launches a fine-tuning run of the Python module `main` through `srun`,
# passing the settings below as key=value command-line overrides.
# NOTE(review): the key=value / +key=value form looks like Hydra override
# syntax -- confirm against how `main` parses its arguments.
#
# NOTE(review): `module` and `conda activate` are normally shell functions
# that must already be defined by the invoking shell, so this file is
# presumably sourced or run from an initialised shell -- confirm. For that
# reason no `set -e` / `exit` guard is used here: either one would terminate
# a shell that sources this file.

# `xxx` looks like a placeholder; replace it with the module this site needs.
module load xxx

clear

conda activate mask_model

# --- Run settings -----------------------------------------------------------
batch=24       # used for loader.batch_size and loader.eval_batch_size
gpus=8         # used for trainer.devices
max_step=90948 # used for trainer.max_steps
lr=1e-4        # used for optim.lr
# Placeholder value: point this at the checkpoint to fine-tune from.
finetune_path=your_fine_tune_path_here.ckpt

# --- Command-line overrides -------------------------------------------------
# The overrides are collected in an array so that every expansion is quoted
# (a checkpoint path containing spaces stays a single argument) and so that
# the command no longer ends in a dangling line-continuation backslash, which
# would silently absorb any line added directly after it.
train_args=(
  "loader.batch_size=${batch}"
  "loader.eval_batch_size=${batch}"
  # Global batch size is the product of `batch` and `gpus` (24 * 8 = 192).
  "loader.global_batch_size=$((batch * gpus))"
  "model=medium"
  "data=openwebtext-split"
  "algo=mdlm"
  "training.finetune_path=${finetune_path}"
  # Run name carries a launch timestamp (YYYYmmdd-HHMMSS) so that repeated
  # launches get distinct names.
  "wandb.name=mdlm-owt-mid-$(date +%Y%m%d-%H%M%S)"
  "model.length=1024"
  "trainer.devices=${gpus}"
  "strategy=ddp"
  "eval.compute_generative_perplexity=False"
  "sampling.steps=1024"
  "trainer.max_steps=${max_step}"
  "optim.lr=${lr}"
  "+wandb.offline=false"
)

# `python -u` keeps stdout/stderr unbuffered so log output appears promptly.
srun python -u -m main "${train_args[@]}"
