datasets:
  # class_name, data_name and data_sampling_ratio must be present to avoid
  # errors, but their values are not used for this dataset.
  - class_name: MegatronDataset
    data_name: Megatron
    data_sampling_ratio: 1
    class_args:
      sequence_length: 2048
      # Quoted so every YAML loader reads the comma-separated spec as a string.
      # presumably train,validation,test proportions — confirm with the dataset class
      split: "98,1,1"
      eval_steps: 2
      # Option 1: data loading using --data-path with single file
      data_path:
        - /proj/datasets/training_data_starcoder_cleaned_0324/fineweb-edu
      data_cache_path: /proj/checkpoints/mayank/cache

# Tokenizer to load.
# NOTE(review): this identifier differs from model_args.model_name — confirm the
# tokenizer vocabulary matches the model.
tokenizer_args:
  tokenizer_name: bigcode/starcoder

# Model to be trained (presumably the student, since tuning_method is
# distillation — confirm).
model_args:
  model_name: ibm/PowerLM-3b
  model_class: AutoModelForCausalLM
  attention_implementation: sdpa
  use_padding_free_transformer: false
  efficient_initialization: false

# Teacher model and KL-divergence settings.
# NOTE(review): model_name is identical to model_args.model_name — confirm that
# using the same checkpoint for both is intended.
teacher_args:
  model_name: ibm/PowerLM-3b
  model_class: AutoModelForCausalLM
  dtype: bf16
  kl_divergence_weight: 1
  kl_divergence_method: forward

# Selects the training mode; teacher_args is presumably only read when this is
# distillation — confirm.
tuning_args:
  tuning_method: distillation

# Checkpoint output location and frequency.
save_args:
  # presumably measured in training steps — confirm
  save_interval: 5000
  save_path: /proj/checkpoints/mayank/tmp

# Logging frequency (presumably in training steps — confirm).
logging_args:
  log_interval: 10

# Total step count and per-step batch configuration.
training_parameters:
  num_training_steps: 25000
  micro_batch_size: 2
  gradient_accumulation_steps: 4
  # NOTE(review): larger than num_training_steps, so an every-N-steps evaluation
  # would never fire — presumably a deliberate way to disable it; confirm.
  eval_interval: 2500000

# Optimizer selection. class_args is presumably forwarded as keyword arguments
# to the optimizer named by class_name — confirm.
optimizer_args:
  class_name: TorchAdamW
  class_args:
    # Exponent literals carry an explicit decimal point on purpose: YAML 1.1
    # loaders (e.g. PyYAML) resolve a bare `3e-4` or `1e-10` as a string, while
    # `3.0e-4` and `1.0e-10` are floats under both YAML 1.1 and YAML 1.2.
    lr: 3.0e-4
    eps: 1.0e-10
    weight_decay: 0.1
    betas:
      - 0.9
      - 0.95

# Learning-rate schedule.
lr_scheduler_args:
  # 2000 warmup steps + 23000 decay steps = 25000, the value of
  # training_parameters.num_training_steps.
  num_warmup_steps: 2000
  num_decay_steps: 23000
  lr_decay_style: cosine

# Mixed-precision dtype; same value as teacher_args.dtype.
mixed_precision_args:
  dtype: bf16

# Distributed training layout.
distributed_args:
  # presumably the ZeRO sharding stage — confirm
  stage: 3
  # presumably selects the FSDP implementation version — confirm
  fsdp_algorithm: 2
  communication_dtype: fp32
  zero_topology:
    # 8 sharding ranks x 1 replica; presumably must match the number of
    # data-parallel processes launched — confirm.
    data_parallel_sharding_world_size: 8
    data_parallel_replication_world_size: 1
