#!/usr/bin/env bash
# Launch a diff_instruct fine-tuning run of `python -m main` under SLURM.
#
# Usage: set the module name and `finetune_path` below, then run this script
# (directly or via sbatch).
#
# Strict mode: abort on the first failing setup step instead of launching a
# job in a half-configured environment. `-u` is deliberately omitted because
# conda activation scripts commonly reference unset variables.
set -eo pipefail

# NOTE(review): "xxx" is a placeholder — replace with the module(s) this
# cluster needs (e.g. the CUDA / conda modules).
module load xxx

# `conda activate` only works after conda's shell hook has been loaded.
# Non-interactive shells (batch jobs, `bash script.sh`) do not read ~/.bashrc,
# so load the hook explicitly; doing so is harmless if it is already loaded.
eval "$(conda shell.bash hook)"
conda activate mask_model

# Checkpoint to fine-tune from (placeholder — set before running).
# Fail fast here rather than after SLURM has allocated resources and Python
# has started up. The original `clear` calls are dropped: in a batch job TERM
# is unset, so `clear` only prints an error.
finetune_path=your_fine_tune_path_here.ckpt
if [[ ! -f "$finetune_path" ]]; then
  printf 'error: fine-tune checkpoint not found: %s\n' "$finetune_path" >&2
  exit 1
fi

# Number of GPUs to use; one SLURM task (one DDP rank) is started per GPU.
gpus=1

# Currently only used to tag the W&B run name.
# NOTE(review): the "small, medium, large" options suggest this was meant to
# also select the model config (e.g. a `model=...` override), but no such
# override is passed to main — confirm against the Hydra config.
model_size=small  # small, medium, large

# Evaluated once by this shell, so every rank logs under the same run name.
run_name="reinforce-owt-${model_size}-$(date +%Y%m%d-%H%M%S)"

# The original passed --ntasks-per-node twice; once is sufficient.
# Hydra applies config-group overrides (strategy=ddp) before value overrides
# (strategy.find_unused_parameters=...), so their order on the line is irrelevant.
# The last argument has no trailing backslash, so lines added later cannot be
# accidentally appended to this command.
srun --ntasks-per-node="$gpus" \
  python -u -m main \
  mode=train \
  data=openwebtext-split \
  data.train_ratio=1.0 \
  data.valid_ratio=1.0 \
  training.finetune_path="$finetune_path" \
  trainer.max_steps=100000 \
  trainer.precision=bf16 \
  strategy.find_unused_parameters=true \
  trainer.limit_val_batches=8 \
  sampling.predictor=ancestral_cache \
  checkpointing.resume_from_ckpt=false \
  algo=diff_instruct \
  wandb.name="$run_name" \
  +wandb.offline=false \
  trainer.devices="$gpus" \
  strategy=ddp
