datasets:
  # class_name, data_name and data_sampling_ratio are unused for this dataset,
  # but they still have to be supplied or the config fails with errors.
  - class_name: MegatronDataset
    data_name: Megatron
    data_sampling_ratio: 1
    class_args:
      # Option 1: load the data through --data-path, pointing at a single file.
      data_path:
        - data/lang=Matlab
      # Quoted so the comma-separated value is unambiguously a string.
      split: '100,0,0'
      sequence_length: 4096
      data_cache_path: /proj/checkpoints/mayank/cache
      eval_steps: 2

# Tokenizer selection.
tokenizer_args:
  # NOTE(review): looks like a Hugging Face Hub repo id — confirm against the loader.
  tokenizer_name: bigcode/starcoder

# Model definition: the class to build, its configuration, and implementation choices.
model_args:
  model_class: AutoModelForCausalLM
  pretrained_config:
    activation_function: swiglu
    add_bias: false
    # Every dropout probability is 0.
    attn_pdrop: 0
    embd_pdrop: 0
    resid_pdrop: 0
    initializer_range: 0.1
    # Explicit `1.0` mantissa: YAML 1.1 loaders such as PyYAML resolve a bare
    # `1e-05` (no decimal point) as a string instead of a float.
    layer_norm_epsilon: 1.0e-05
    model_type: moe_dolomite
    n_embd: 1024
    n_head: 16
    n_inner: 512
    n_layer: 24
    # Same value as datasets[0].class_args.sequence_length (4096).
    n_positions: 4096
    # num_experts_per_tok of 8 out of num_experts 32.
    num_experts: 32
    num_experts_per_tok: 8
    num_key_value_heads: 8
    normalization_function: rmsnorm
    position_embedding_type: rope
    rope_theta: 10000
    attention_head_type: gqa
    scale_attn_weights: true
    vocab_size: 49152
    tie_word_embeddings: true
    upcast_logits_for_loss: true
    # bos, eos and pad all share token id 0.
    bos_token_id: 0
    eos_token_id: 0
    pad_token_id: 0
    router_aux_loss_coef: 0.01
  moe_implementation: scattermoe
  attention_implementation: sdpa

# Tuning mode; this config selects pretraining.
tuning_args:
  tuning_method: pretraining

# Save settings: output path and interval.
save_args:
  save_path: /proj/checkpoints/mayank/test/sdpa-stage-0-1b-moe-compile
  # NOTE(review): presumably measured in training steps (25000 / 5000 = 5 intervals) — confirm the unit.
  save_interval: 5000

# Logging settings.
logging_args:
  # NOTE(review): presumably measured in training steps — confirm the unit.
  log_interval: 10

# Training-loop settings.
training_parameters:
  num_training_steps: 25000
  # Batch settings: micro-batch size 2 with 8 gradient-accumulation steps.
  micro_batch_size: 2
  gradient_accumulation_steps: 8
  # Evaluation is switched off; eval_interval (10000000) is also larger than
  # num_training_steps (25000).
  eval_during_training: false
  eval_interval: 10000000

# Optimizer selection.
optimizer_args:
  class_name: TorchAdamW
  # NOTE(review): presumably forwarded to the optimizer named by class_name — confirm.
  class_args:
    # Exponent floats carry an explicit `.0` mantissa: YAML 1.1 loaders such as
    # PyYAML resolve a bare `3e-4` / `1e-10` as a string instead of a float.
    lr: 3.0e-4
    weight_decay: 0.1
    betas:
      - 0.9
      - 0.95
    eps: 1.0e-10

# Learning-rate schedule. The three phase lengths below sum to 25000
# (2500 + 0 + 22500), the same value as training_parameters.num_training_steps.
lr_scheduler_args:
  lr_decay_style: cosine
  num_warmup_steps: 2500
  num_constant_steps: 0
  num_decay_steps: 22500

# Mixed-precision settings; dtype is bf16.
mixed_precision_args:
  dtype: bf16

# Distributed-training settings.
distributed_args:
  distributed_backend: torch
  # fp32 here, whereas mixed_precision_args.dtype is bf16.
  communication_dtype: fp32
  torch_compile: true
  # NOTE(review): presumably a ZeRO/sharding stage, with 0 meaning no sharding — confirm.
  stage: 0
