datasets:
  # NOTE: class_name, data_name and data_sampling_ratio are not used, but they
  # must still be passed to avoid errors
  - class_name: MegatronDataset
    data_name: Megatron
    data_sampling_ratio: 1
    class_args:
      # Option 1: data loading using --data-path with a single file
      data_path:
        - /proj/datasets/training_data_starcoder_cleaned_0324/EDGAR+Secfiling+FIDC+Earningcalltranscript
      # quoted so the comma-separated value is always loaded as a string
      # (presumably train,validation,test proportions - TODO confirm)
      split: "100,0,0"
      sequence_length: 4096
      data_cache_path: /proj/checkpoints/mayank/cache
      eval_steps: 2

# tokenizer used for this run, referenced by name
tokenizer_args:
  tokenizer_name: bigcode/starcoder

# kernel_args:
#   kernels:
#     - cute_swiglu
#     - cute_rmsnorm

# model definition: the architecture is described inline via pretrained_config
# rather than by pointing at an existing checkpoint
model_args:
  model_class: AutoModelForCausalLM
  pretrained_config:
    initializer_range: 0.1
    # written with an explicit decimal point: YAML 1.1 loaders (e.g. PyYAML)
    # read "1e-05" as a string, whereas "1.0e-05" is a float for every loader
    layer_norm_epsilon: 1.0e-05
    model_type: ladder_residual
    normalization_function: rmsnorm
    position_embedding_type: rope
    rope_theta: 10000
    hidden_size: 4096
    num_attention_heads: 32
    # attention_multiplier: 0.0078125
    # m_width: 6
    # m_emb: 12
    # m_residual: 0.22
    num_layers: 8
    init_method: normal
    router_aux_loss_coef: 0.01
    # bos, eos and pad all share token id 0
    bos_token_id: 0
    eos_token_id: 0
    pad_token_id: 0
    # 8 identical entries, matching num_layers (presumably one entry per
    # layer - TODO confirm)
    sequence_mixer_blocks:
      - sequence_mixer_type: softmax_attention
        attention_head_type: mha
      - sequence_mixer_type: softmax_attention
        attention_head_type: mha
      - sequence_mixer_type: softmax_attention
        attention_head_type: mha
      - sequence_mixer_type: softmax_attention
        attention_head_type: mha
      - sequence_mixer_type: softmax_attention
        attention_head_type: mha
      - sequence_mixer_type: softmax_attention
        attention_head_type: mha
      - sequence_mixer_type: softmax_attention
        attention_head_type: mha
      - sequence_mixer_type: softmax_attention
        attention_head_type: mha
    # 8 identical entries, matching num_layers (presumably one entry per
    # layer - TODO confirm)
    mlp_blocks:
      - mlp_type: MLP
        activation_function: swiglu
      - mlp_type: MLP
        activation_function: swiglu
      - mlp_type: MLP
        activation_function: swiglu
      - mlp_type: MLP
        activation_function: swiglu
      - mlp_type: MLP
        activation_function: swiglu
      - mlp_type: MLP
        activation_function: swiglu
      - mlp_type: MLP
        activation_function: swiglu
      - mlp_type: MLP
        activation_function: swiglu
  use_padding_free_transformer: true
  attention_implementation: flash_attention_2
  efficient_initialization: true

# selects the training mode for this run
tuning_args:
  tuning_method: pretraining

# checkpoint output location and frequency
save_args:
  save_path: /proj/checkpoints/mayank/test
  # presumably counted in training steps - TODO confirm
  save_interval: 25000

# logging and profiling settings
logging_args:
  log_interval: 10
  # relative path
  # NOTE(review): presumably setting this turns on torch profiler tracing -
  # confirm it is intended for this run
  torch_profiler_trace_path: tmp

# step count and per-step batch settings for the run
training_parameters:
  num_training_steps: 500000
  micro_batch_size: 8
  gradient_accumulation_steps: 1
  # evaluation is switched off; the interval below is also larger than
  # num_training_steps
  eval_during_training: false
  eval_interval: 1000000000

# optimizer selection; class_args holds the optimizer hyperparameters
optimizer_args:
  # params_group_method: mup
  class_name: TorchAdamW
  class_args:
    lr: 0.02
    weight_decay: 0.1
    betas:
      - 0.9
      - 0.95
    # written with an explicit decimal point: YAML 1.1 loaders (e.g. PyYAML)
    # read "1e-10" as a string, whereas "1.0e-10" is a float for every loader
    eps: 1.0e-10

# learning-rate schedule: warmup + constant + decay steps add up to
# 2500 + 0 + 497500 = 500000, the num_training_steps value in
# training_parameters
lr_scheduler_args:
  lr_decay_style: power
  num_warmup_steps: 2500
  num_constant_steps: 0
  num_decay_steps: 497500
  extra_lr_scheduler_args:
    # 4 * global_batch_size (4096 = 4 * 1024, i.e. a global batch of 1024
    # sequences)
    a: 4096
    # constant
    b: -0.51
    # global_batch_size in number of tokens (4194304 = 1024 * 4096)
    # NOTE(review): micro_batch_size 8 * gradient_accumulation_steps 1 * the
    # data-parallel size from distributed_args (presumably 1 * 4) gives 32
    # sequences rather than 1024 - confirm a and c are intended for this run
    c: 4194304

# numeric precision used for training
mixed_precision_args:
  dtype: bf16

# parallelism layout for the run
distributed_args:
  fsdp_algorithm: 2
  torch_compile: true
  tensor_parallel_world_size: 2
  # presumably the total device count is tensor parallel * sharding *
  # replication = 2 * 1 * 4 = 8 - TODO confirm
  zero_topology:
    data_parallel_sharding_world_size: 1
    data_parallel_replication_world_size: 4
