# Recipe metadata. The job template lives under `spec` and the parameter sets
# it is run with live under `products`.
type: basic
format_version: 1
maintainers:
  - maanug
loggers:
  - stdout
spec:
  # Job name template. A trailing backslash inside a double-quoted YAML scalar
  # escapes the line break, so the pieces below are joined with no whitespace
  # between them. The mcore_ / te_ / _vp<N> / _resume_<ckpt_format> /
  # _<args_meta> segments are emitted only when the corresponding value is
  # truthy.
  # tp_size and pp_size have no default in this spec; every `products` entry
  # supplies them.
  name: "{model}_{scope}_\
         {'mcore_' if use_mcore else ''}{'te_' if use_te else ''}\
         tp{tp_size}_pp{pp_size}{'_vp'+str(vp_size) if vp_size else ''}\
         {'_resume_'+str(ckpt_format) if ckpt_resume else ''}\
         {'_'+args_meta if args_meta else ''}\
         _{platforms}_{nodes}N{gpus}G"
  # model, scope, platforms, nodes and gpus are interpolated into `name`;
  # nodes is additionally passed to the script as NUM_NODES.
  model: bert
  variant: 345m
  build: mcore-pyt
  scope: mr
  nodes: 1
  gpus: 8
  platforms: dgx_a100
  # Boolean switches; the script renders them as 1/0 into USE_TE and USE_CORE.
  use_te: False
  use_mcore: True
  # Optional values. While null, vp_size and extra_args reach the script as a
  # literal "" and the _vp<N> / _<args_meta> name segments are omitted.
  vp_size: null
  extra_args: null
  args_meta: null
  micro_batch_size: 4 # MBS
  batch_size: 128 # GBS, JET schema requires 'batch_size'
  precision: bf16
  # NOTE(review): unit is not stated in this file - presumably seconds; confirm.
  time_limit: 1200
  # Maps a path to a dataset identifier; the script's DATA_PATH points inside
  # this path. NOTE(review): presumably the dataset is made available at that
  # path before the script runs - confirm against the recipe schema.
  artifacts: {/workspace/data/bert_data: text/the_pile/bert_shard00}
  # Within this file ckpt_format only feeds the _resume_<ckpt_format> name
  # segment; the script below does not reference it.
  ckpt_format: torch_dist
  # Passed to the script as CHECKPOINT_RESUME_TEST; a truthy value also raises
  # MAX_STEPS from 50 to 100.
  ckpt_resume: 0
  # Shell commands for one job. `|-` is a literal block scalar with strip
  # chomping: line breaks are kept and the final newline is dropped.
  # The {...} placeholders are filled from the spec/products values before the
  # script runs. NOTE(review): they look like Python expressions evaluated by
  # the recipe tooling - confirm. {assets_dir} is not defined in this file and
  # is presumably provided by the runner.
  # When vp_size or extra_args is null, the placeholder renders as a literal ""
  # so the KEY=VALUE argument still has a value.
  script: |-
    ls
    cd /workspace/megatron-lm

    ./tests/functional_tests/test_scripts/bert/pretrain_bert_distributed_test.sh \
        DATA_PATH=/workspace/data/bert_data/my-bert_00_text_sentence \
        CHECKPOINT_PATH=/workspace/checkpoints \
        TENSORBOARD_DIR={assets_dir} \
        DATA_CACHE=/workspace/data/index-cache \
        USE_TE={"1" if use_te else "0"} \
        TP_SIZE={tp_size} \
        PP_SIZE={pp_size} \
        NUM_NODES={nodes} \
        MAX_STEPS={100 if ckpt_resume else 50} \
        USE_CORE={"1" if use_mcore else "0"} \
        VP_SIZE={vp_size if vp_size is not None else '""'} \
        MBS={micro_batch_size} \
        GBS={batch_size} \
        CHECKPOINT_RESUME_TEST={ckpt_resume} \
        JOB_NAME={name} \
        ADDITIONAL_PARAMS={extra_args if extra_args is not None else '""'}
# Parameter sets for the spec above. Every value is a list, and tp_size/pp_size
# must be given here because the spec defines no default for them.
# NOTE(review): presumably each entry is expanded into one job per combination
# of the listed values, with unlisted keys taken from the spec - confirm
# against the JET schema.
products:
  # MCore
  - tp_size: [2]
    pp_size: [2]
    ckpt_resume: [0, 1]
  - tp_size: [2]
    pp_size: [2]
    ckpt_resume: [0, 1]
    extra_args: ['"--spec local"']
    args_meta: ["local_spec"]
  # Non-MCore
  - use_mcore: [False]
    tp_size: [2]
    pp_size: [2]
    ckpt_resume: [0, 1]
    ckpt_format: [torch]
    extra_args: ['"--transformer-impl local"']
  - use_mcore: [False]
    tp_size: [1]
    pp_size: [4]
    vp_size: [2]
    ckpt_resume: [0, 1]
    ckpt_format: [torch]
    extra_args: ['"--transformer-impl local"']
