#!/usr/bin/env bash
#
# Launch a training run of the `main` Python module through `srun`.
#
# The run is configured entirely through `key=value` command-line overrides
# (presumably Hydra-style; a leading `+` as in `+wandb.offline` would then add
# a key that is absent from the base config -- TODO confirm against `main`).
#
# Tunables:
#   batch - value passed as both the train and eval loader batch size.
#   gpus  - value passed as trainer.devices; also the multiplier used to
#           derive loader.global_batch_size (batch * gpus).
#
# NOTE(review): strict mode (`set -euo pipefail`) is intentionally not enabled
# here; it is unverified whether the site's `module` and `conda activate`
# shell functions tolerate `-e`/`-u`, or whether this file is sourced.

# NOTE(review): `xxx` looks like a placeholder module name -- confirm or
# replace with the module(s) this cluster actually requires.
module load xxx

conda activate mask_model

batch=42
gpus=8

# Every override is double-quoted so that each one reaches Python as exactly
# one argument regardless of the values substituted into it.
# `python -u` keeps stdout/stderr unbuffered so log output appears promptly.
# The wandb run name carries a launch timestamp (YYYYmmdd-HHMMSS) so that
# repeated launches get distinct names.
# The final argument deliberately has NO trailing backslash: a dangling line
# continuation would splice whatever is written on the next line into this
# command.
srun python -u -m main \
  "loader.batch_size=${batch}" \
  "loader.eval_batch_size=${batch}" \
  "loader.global_batch_size=$((batch * gpus))" \
  "model=small" \
  "data=openwebtext-split" \
  "algo=mdlm" \
  "wandb.name=mdlm-owt-small-$(date +%Y%m%d-%H%M%S)" \
  "model.length=1024" \
  "trainer.devices=${gpus}" \
  "strategy=ddp" \
  "eval.compute_generative_perplexity=False" \
  "sampling.steps=1024" \
  "+wandb.offline=false"
