# Environment variables for the test run.
# NOTE(review): presumably exported by the test harness before the training
# script is launched — confirm against the harness.
ENV_VARS:
  CUDA_DEVICE_MAX_CONNECTIONS: 1
  # 0 = do not allow non-deterministic algorithms (see the deterministic-mode
  # note on --attention-backend in MODEL_ARGS).
  NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0
  # Plain scalar: the colon is not followed by a space, so this stays one string.
  PYTORCH_CUDA_ALLOC_CONF: expandable_segments:True
  NCCL_NVLS_ENABLE: 0
  PYTHONWARNINGS: ignore
  NCCL_DEBUG: VERSION
# Arguments for the training entry point, keyed by their command-line flag.
# NOTE(review): presumably the test harness turns each entry into `flag value`
# (and `true` entries into bare flags) — confirm against the harness.
MODEL_ARGS:
  # Distributed args: TP=2, PP=2, EP=4, CP=1, expert TP=1
  --distributed-timeout-minutes: 60
  --tensor-model-parallel-size: 2
  --pipeline-model-parallel-size: 2
  --expert-model-parallel-size: 4
  --context-parallel-size: 1
  --expert-tensor-parallel-size: 1
  --use-distributed-optimizer: true
  --overlap-grad-reduce: true
  --overlap-param-gather: true
  # Use unfused attention: MLA with fused attention in deterministic mode
  # produces NaNs.
  --attention-backend: unfused # TODO: switch back to fused attention once that is fixed
  # Training args
  --use-mcore-models: true
  --sequence-parallel: true
  --disable-bias-linear: true
  --micro-batch-size: 4
  --global-batch-size: 32
  --train-iters: 50
  --exit-duration-in-mins: 230
  --no-check-for-nan-in-loss-and-grad: true
  --no-rope-fusion: true
  --cross-entropy-loss-fusion: true
  --cross-entropy-fusion-impl: native
  --manual-gc: true
  # NOTE(review): this interval (100) is larger than --train-iters (50) —
  # confirm that is intended.
  --manual-gc-interval: 100
  --recompute-granularity: selective
  # Quoted because the value starts with `[`, which YAML would otherwise parse
  # as a flow sequence.
  --recompute-modules: "[layernorm mla_up_proj mlp moe_act]"
  # Transformer Engine args
  --transformer-impl: transformer_engine
  # Data args (${...} placeholders are presumably substituted by the harness)
  --seq-length: 4096
  --data-cache-path: ${DATA_CACHE_PATH}
  --data-path: ${DATA_PATH}/my-gpt3_00_text_document
  --vocab-file: ${DATA_PATH}/bpe/vocab.json
  --merge-file: ${DATA_PATH}/bpe/merges.txt
  --split: 949,50,1
  # Network size args
  --num-layers: 16
  # Expression expands to 3 zeros followed by 13 ones (16 entries, matching
  # --num-layers); presumably 0 = dense layer and 1 = MoE layer — confirm.
  --moe-layer-freq: ([0]*3+[1]*13)
  # The value is backslash-escaped (presumably for the shell); the trailing
  # comment shows the unescaped layout string.
  # NOTE(review): --mtp-num-layers is 1 below, but this layout has no entry that
  # obviously denotes an MTP layer — confirm whether one is required.
  --pipeline-model-parallel-layout: Et*3\\|\\(tt\\|\\)*6tL # Et*3|(tt|)*6tL
  --hidden-size: 1024
  --ffn-hidden-size: 4096
  --num-attention-heads: 32
  --kv-channels: 128
  --max-position-embeddings: 4096
  --position-embedding-type: rope
  --rotary-base: 10000
  --make-vocab-size-divisible-by: 3232
  --normalization: RMSNorm
  --norm-epsilon: 1e-6
  --swiglu: true
  --untie-embeddings-and-output-weights: true
  --multi-latent-attention: true
  # MTP args: comment out the following two entries to disable MTP.
  # NOTE(review): METRICS also tracks "mtp_1 loss"; presumably that entry must
  # be removed as well when MTP is disabled.
  --mtp-num-layers: 1
  --mtp-loss-scaling-factor: 0.1
  # Regularization args
  --attention-dropout: 0.0
  --hidden-dropout: 0.0
  --clip-grad: 1.0
  --weight-decay: 0.1
  # NOTE(review): grouped under regularization here, although the flag name
  # suggests a model-architecture option.
  --qk-layernorm: true
  # Learning rate args
  --lr-warmup-fraction: .01
  --lr: 0.00015
  --min-lr: 1.0e-5
  --lr-decay-style: cosine
  --adam-beta1: 0.9
  --adam-beta2: 0.95
  # MoE args
  --num-experts: 32
  --moe-ffn-hidden-size: 1024
  --moe-shared-expert-intermediate-size: 1024
  --moe-router-load-balancing-type: seq_aux_loss
  --moe-router-topk: 4
  --moe-token-dispatcher-type: alltoall
  --moe-router-pre-softmax: true
  --moe-grouped-gemm: true
  --moe-aux-loss-coeff: 1e-4
  --moe-router-group-topk: 2
  --moe-router-num-groups: 4
  --moe-router-topk-scaling-factor: 2.0
  --moe-router-score-function: sigmoid
  --moe-router-enable-expert-bias: true
  --moe-router-bias-update-rate: 1e-3
  --moe-router-dtype: fp32
  --moe-permute-fusion: true
  # MLA args
  --q-lora-rank: 1536
  --kv-lora-rank: 512
  --qk-head-dim: 128
  --qk-pos-emb-head-dim: 64
  --v-head-dim: 128
  --rotary-scaling-factor: 40
  --mscale: 1.0
  --mscale-all-dim: 1.0
  # Validation args
  --eval-iters: 32
  # NOTE(review): this interval (200) is larger than --train-iters (50) —
  # confirm that skipping periodic evaluation is intended.
  --eval-interval: 200
  # Checkpointing args
  --save: ${CHECKPOINT_SAVE_PATH}
  --load: ${CHECKPOINT_LOAD_PATH}
  # 25 is half of --train-iters (50); presumably the mid-run checkpoint that the
  # ckpt-resume test restarts from — confirm.
  --save-interval: 25
  # Initialization args
  --init-method-std: 0.02
  # Logging args
  --log-timers-to-tensorboard: true
  --log-memory-to-tensorboard: true
  --log-num-zeros-in-grad: true
  --log-params-norm: true
  --log-validation-ppl-to-tensorboard: true
  --log-throughput: true
  --log-interval: 1
  --logging-level: 40
  --tensorboard-dir: ${TENSORBOARD_PATH}
  # Mixed precision args
  --bf16: true
  # Exit arg (not a mixed-precision setting); same value as --train-iters (50).
  --exit-interval: 50
# Test mode selector for the harness.
# NOTE(review): `ckpt-resume` presumably means the run is repeated from a saved
# checkpoint and compared with the uninterrupted run — confirm.
TEST_TYPE: ckpt-resume
# Names of the logged metrics this test case tracks. Single-quoted because
# they are literal strings with no escape sequences.
METRICS:
  - 'iteration-time'
  - 'lm loss'
  - 'num-zeros'
  - 'mem-allocated-bytes'
  - 'mem-max-allocated-bytes'
  - 'mtp_1 loss'
