# Environment variables for this test case, declared as a name -> value mapping.
# Presumably exported by the test launcher before the run starts -- confirm
# against the harness that consumes this file.
# The 1 / 0 values are unquoted, so a YAML loader yields integers, not strings;
# the consumer is assumed to stringify them when exporting -- TODO confirm.
ENV_VARS:
  CUDA_DEVICE_MAX_CONNECTIONS: 1
  TORCH_NCCL_AVOID_RECORD_STREAMS: 1
  NVTE_ALLOW_NONDETERMINISTIC_ALGO: 1
  # Plain scalar: the inner ":" is not followed by a space, so the whole text
  # "expandable_segments:True" loads as a single string value.
  PYTORCH_CUDA_ALLOC_CONF: expandable_segments:True
  NCCL_NVLS_ENABLE: 0
# Test category label for this configuration. How the label is interpreted is
# decided by the consumer of this file, which is not shown here.
TEST_TYPE: "release"
# Model/training arguments, written as a mapping whose keys are spelled like
# command-line flags (leading "--") and whose values are the flag arguments.
# Presumably the launcher renders each entry onto the training command line,
# emitting `true` entries as bare switches, and substitutes the ${...}
# placeholders (DATA_CACHE_PATH, DATA_PATH, DATA_BLEND, LOAD_PATH, OUTPUT_PATH,
# WANDB_EXPERIMENT) before use -- confirm against the harness.
MODEL_ARGS:
  # Distributed args
  --distributed-timeout-minutes: 60
  --tensor-model-parallel-size: 2
  --pipeline-model-parallel-size: 8
  --use-distributed-optimizer: true
  --overlap-grad-reduce: true
  --overlap-param-gather: true
  # Training args
  --use-mcore-models: true
  --sequence-parallel: true
  --use-flash-attn: true
  --disable-bias-linear: true
  --micro-batch-size: 1
  --global-batch-size: 256
  # 38400 samples / global batch size 256 = 150 global batches.
  --train-samples: 38400
  --exit-duration-in-mins: 230
  # Transformer Engine args
  --transformer-impl: transformer_engine
  # Data args
  --data-cache-path: ${DATA_CACHE_PATH}
  --tokenizer-type: Llama2Tokenizer
  --tokenizer-model: ${DATA_PATH}/tokenizer.model
  --data-path: ${DATA_BLEND}
  # Unquoted, but the commas make YAML load this as the string "99,1,0".
  --split: 99,1,0
  --no-mmap-bin-files: true
  --num-workers: 6
  # Network size args
  --untie-embeddings-and-output-weights: true
  --position-embedding-type: rope
  --no-rope-fusion: true  # TODO: remove this once we upgrade to the DEV container
  --rotary-percent: 1.0
  --normalization: RMSNorm
  --swiglu: true
  --num-layers: 56
  --hidden-size: 6144
  --ffn-hidden-size: 16384
  --num-attention-heads: 48
  --group-query-attention: true
  --num-query-groups: 8
  --seq-length: 4096
  --max-position-embeddings: 4096
  --make-vocab-size-divisible-by: 128
  # Regularization args
  --attention-dropout: 0.0
  --hidden-dropout: 0.0
  --clip-grad: 1.0
  --weight-decay: 0.1
  # Learning rate args
  # NOTE(review): both sample counts below exceed --train-samples (38400), so
  # this run would cover only the start of the schedule, assuming the schedule
  # is counted from sample zero -- confirm this is intended.
  --lr-decay-samples: 255126953
  --lr-warmup-samples: 162761
  --lr: 1.2e-5
  --min-lr: 1.2e-6
  --lr-decay-style: cosine
  --adam-beta1: 0.9
  --adam-beta2: 0.95
  # MoE args
  --expert-model-parallel-size: 8
  --num-experts: 8
  --moe-router-load-balancing-type: aux_loss
  --moe-router-topk: 2
  --moe-grouped-gemm: true
  # No decimal point in the mantissa: YAML 1.1 loaders (e.g. PyYAML) read this
  # as the string "1e-2", YAML 1.2 loaders as a float. Harmless only if the
  # consumer passes the text through unchanged -- TODO confirm.
  --moe-aux-loss-coeff: 1e-2
  --moe-token-dispatcher-type: alltoall
  # Validation args
  --eval-iters: 32
  # NOTE(review): 500 is larger than the 150 global batches implied by
  # --train-samples / --global-batch-size; if this interval is counted in
  # iterations, no periodic evaluation fires mid-run -- confirm.
  --eval-interval: 500
  # Checkpointing args
  --finetune: true
  --auto-detect-ckpt-format: true
  --load: ${LOAD_PATH}
  --save: ${OUTPUT_PATH}/checkpoints
  --no-ckpt-fully-parallel-save: true
  # NOTE(review): same value as --eval-interval and likewise above the 150
  # implied global batches -- confirm a checkpoint is still written.
  --save-interval: 500
  # Initialization args
  --init-method-std: 0.008
  # Logging args
  --log-timers-to-tensorboard: true
  --log-memory-to-tensorboard: true
  --log-num-zeros-in-grad: true
  --log-params-norm: true
  --log-validation-ppl-to-tensorboard: true
  --log-throughput: true
  --log-interval: 1
  --tensorboard-dir: ${OUTPUT_PATH}/tensorboard
  --wandb-project: megatron-core-release-runs
  --wandb-exp-name: ${WANDB_EXPERIMENT}
  # Mixed precision args
  --bf16: true
