# Environment variables for the test run (how they are exported is up to the
# consuming harness — not visible here).
ENV_VARS:
  SKIP_PYTEST: 1
  CUDA_DEVICE_MAX_CONNECTIONS: 1
  NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0
  NCCL_ALGO: Ring
  # Quoted because the value starts with the YAML indicator ':'; quoting keeps
  # it an unambiguous string regardless of YAML 1.1/1.2 parser quirks.
  CUBLAS_WORKSPACE_CONFIG: ':4096:8'
  ARTIFACTS_ROOT: /workspace/checkpoints
  # Flow-style YAML held as a single string; BEFORE_SCRIPT pipes it through
  # `yq -P` to produce distill_config.yaml.
  DISTILL_CONFIG: '{intermediate_layer_pairs: [["decoder.final_layernorm", "decoder.final_layernorm"]], logit_layers: ["output_layer", "output_layer"], skip_lm_loss: true, kd_loss_scale: 10.0}'
# Writes DISTILL_CONFIG as a block-style YAML file at the path that
# --export-kd-cfg reads. All expansions are double-quoted so the shell does not
# word-split or glob-expand the config text (which contains '[' and ']') or the
# cache path; printf is used instead of echo for predictable output.
BEFORE_SCRIPT: |
  mkdir -p "${DATA_CACHE_PATH}/distill" && printf '%s\n' "$DISTILL_CONFIG" | yq -P > "${DATA_CACHE_PATH}/distill/distill_config.yaml"
MODEL_ARGS:
  # Distillation inputs: teacher checkpoint under ARTIFACTS_ROOT, and the KD
  # config file written by BEFORE_SCRIPT (same path).
  --export-te-mcore-model: true
  --export-kd-teacher-load: ${ARTIFACTS_ROOT}/gpt_teacher
  --export-kd-cfg: ${DATA_CACHE_PATH}/distill/distill_config.yaml
  --auto-detect-ckpt-format: true
  --num-layers: 12
  --hidden-size: 512
  --num-attention-heads: 8
  --normalization: RMSNorm
  --log-params-norm: true
  --log-num-zeros-in-grad: true
  --log-validation-ppl-to-tensorboard: true
  --log-timers-to-tensorboard: true
  --tensorboard-dir: ${TENSORBOARD_PATH}
  --micro-batch-size: 2
  --global-batch-size: 16
  --seq-length: 1024
  --max-position-embeddings: 1024
  --position-embedding-type: rope
  --no-rope-fusion: true  # TODO: We can remove this once upgrading to the DEV container
  --rotary-percent: 0.5
  --swiglu: true
  --untie-embeddings-and-output-weights: true
  --disable-bias-linear: true
  --train-iters: 100
  --timing-log-level: 2
  --lr-decay-iters: 320000
  --save: ${CHECKPOINT_SAVE_PATH}
  --load: ${CHECKPOINT_LOAD_PATH}
  --data-path: ${DATA_PATH}/my-gpt3_00_text_document
  --vocab-file: ${DATA_PATH}/bpe/vocab.json
  --merge-file: ${DATA_PATH}/bpe/merges.txt
  --split: 949,50,1
  --distributed-backend: nccl
  --lr: 0.00015
  --lr-decay-style: cosine
  --min-lr: 1.0e-5
  # Written with a '.' so YAML 1.1 loaders (e.g. PyYAML) resolve a float,
  # as they do for --min-lr; bare `1e-2` resolves to a string there.
  --weight-decay: 1.0e-2
  --clip-grad: 1.0
  --lr-warmup-fraction: 0.01
  --use-distributed-optimizer: true
  --log-interval: 1
  --save-interval: 50
  --eval-interval: 1000
  --eval-iters: 10
  --transformer-impl: transformer_engine
  --tensor-model-parallel-size: 2
  --pipeline-model-parallel-size: 1
  --sequence-parallel: true
  --deterministic-mode: true
  --no-gradient-accumulation-fusion: true
  --use-checkpoint-opt_param-scheduler: true
  --ckpt-format: torch_dist
  --dist-ckpt-strictness: log_all  # backward compatibility for TE changes
  --data-cache-path: ${DATA_CACHE_PATH}
  --bf16: true
  --attention-backend: unfused
  --log-memory-to-tensorboard: true
# Test flavor for the harness; quoted to make the string type explicit.
TEST_TYPE: 'ckpt-resume'
