# Training data: a single dataset entry whose class_args point at ten pre-chunked dataset paths.
datasets:
  # class_name, data_name and data_sampling_ratio are not used, but must still be passed to avoid errors
  - class_name: MegatronDataset
    data_name: Megatron
    data_sampling_ratio: 1
    class_args:
      eval_steps: 2
      data_cache_path: /proj/checkpoints/mayank/cache
      # Option 1: data loading using --data-path.
      # The list alternates a number with a dataset path: ten chunks here, not a single file.
      # NOTE(review): each number is presumably the blending weight of the path that follows it — confirm.
      data_path:
        - 58996336
        - /proj/datasets/slim_pajama_gptneox_megatron/train/chunk1
        - 58982360
        - /proj/datasets/slim_pajama_gptneox_megatron/train/chunk2
        - 59060327
        - /proj/datasets/slim_pajama_gptneox_megatron/train/chunk3
        - 59040311
        - /proj/datasets/slim_pajama_gptneox_megatron/train/chunk4
        - 59201586
        - /proj/datasets/slim_pajama_gptneox_megatron/train/chunk5
        - 59023939
        - /proj/datasets/slim_pajama_gptneox_megatron/train/chunk6
        - 58932495
        - /proj/datasets/slim_pajama_gptneox_megatron/train/chunk7
        - 59087730
        - /proj/datasets/slim_pajama_gptneox_megatron/train/chunk8
        - 59073554
        - /proj/datasets/slim_pajama_gptneox_megatron/train/chunk9
        - 58995987
        - /proj/datasets/slim_pajama_gptneox_megatron/train/chunk10
      # NOTE(review): presumably train/validation/test proportions — confirm
      split: 98,1,1
      # same value as max_position_embeddings in model_args.pretrained_config
      sequence_length: 2048

# Tokenizer selection.
# NOTE(review): the dataset paths are named slim_pajama_gptneox_megatron, so the data was presumably
# tokenized with this same tokenizer — confirm.
tokenizer_args:
  tokenizer_name: EleutherAI/gpt-neox-20b

# Model definition: pretrained_config describes the architecture; the remaining keys are
# model-loading options.
model_args:
  model_class: AutoModelForCausalLM
  pretrained_config:
    activation_function: swiglu
    add_bias: true
    # bos, eos and pad all use token id 0
    bos_token_id: 0
    eos_token_id: 0
    initializer_range: 0.02
    # Written with an explicit mantissa dot: YAML 1.1 loaders (e.g. PyYAML) resolve "1e-05"
    # as a string, whereas "1.0e-05" is a float under both YAML 1.1 and 1.2.
    layer_norm_epsilon: 1.0e-05
    model_type: gpt_crosslayer
    hidden_size: 3072
    # 3072 / 12 = 256 hidden units per attention head
    num_attention_heads: 12
    intermediate_size: 8192
    num_layers: 32
    # same value as sequence_length in the dataset class_args
    max_position_embeddings: 2048
    normalization_function: rmsnorm
    num_key_value_heads: 1
    pad_token_id: 0
    position_embedding_type: rope
    rope_theta: 10000
    # One entry per layer (32 entries for num_layers: 32). Consecutive layers are paired and
    # both entries of a pair hold the index of the pair's first layer.
    # NOTE(review): entry i presumably names the layer whose attention state layer i reuses —
    # confirm against the gpt_crosslayer implementation.
    sharing_pattern:
      - 0
      - 0
      - 2
      - 2
      - 4
      - 4
      - 6
      - 6
      - 8
      - 8
      - 10
      - 10
      - 12
      - 12
      - 14
      - 14
      - 16
      - 16
      - 18
      - 18
      - 20
      - 20
      - 22
      - 22
      - 24
      - 24
      - 26
      - 26
      - 28
      - 28
      - 30
      - 30
    vocab_size: 50304
  efficient_initialization: false
  attention_implementation: flash_attention_2
  use_padding_free_transformer: true

# Selects the tuning method; this configuration runs pretraining.
tuning_args:
  tuning_method: pretraining

# Checkpoint output location and save interval.
save_args:
  save_path: /proj/checkpoints/mayank/mla-experiments/3b-base-cla
  # NOTE(review): presumably counted in training steps (num_training_steps is 25000) — confirm
  save_interval: 5000

# Logging cadence and experiment tracker settings.
logging_args:
  # NOTE(review): presumably counted in training steps — confirm
  log_interval: 10
  experiments_tracker_name: wandb
  # project and run name handed to the wandb tracker selected above
  wandb_args:
    project: multilayer-attention
    name: 3b-cla

# Run length, evaluation cadence and per-step batch settings.
training_parameters:
  # equals num_warmup_steps + num_decay_steps (2000 + 23000) in lr_scheduler_args
  num_training_steps: 25000
  # NOTE(review): this is 100x num_training_steps, so periodic evaluation presumably never
  # triggers during the run — confirm that disabling evaluation is intentional
  eval_interval: 2500000
  micro_batch_size: 4
  gradient_accumulation_steps: 4

# Optimizer selection; class_args holds the hyperparameters for the chosen optimizer class.
optimizer_args:
  class_name: TorchAdamW
  class_args:
    # Exponent-form values carry an explicit mantissa dot: YAML 1.1 loaders (e.g. PyYAML)
    # resolve "3e-4" / "1e-10" as strings, whereas "3.0e-4" / "1.0e-10" are floats under
    # both YAML 1.1 and 1.2.
    lr: 3.0e-4
    weight_decay: 0.1
    betas:
      - 0.9
      - 0.95
    eps: 1.0e-10

# Learning-rate schedule with cosine decay style.
# Warmup and decay steps sum to 25000, matching num_training_steps in training_parameters.
lr_scheduler_args:
  lr_decay_style: cosine
  num_warmup_steps: 2000
  num_decay_steps: 23000

# Mixed-precision setting; bf16 is selected for this run.
mixed_precision_args:
  dtype: bf16

# Distributed-training layout.
distributed_args:
  # NOTE(review): presumably the ZeRO sharding stage, given the zero_topology key below — confirm
  stage: 3
  zero_topology:
    # replication x sharding = 16 x 8 = 128
    # NOTE(review): presumably the total number of data-parallel ranks this config expects — confirm
    data_parallel_replication_world_size: 16
    data_parallel_sharding_world_size: 8
