datasets:
  # class_name, data_name & data_sampling_ratio are not used but need to be passed to avoid errors
  - class_name: MegatronDataset
    data_name: Megatron
    data_sampling_ratio: 1
    class_args:
      eval_steps: 2
      data_cache_path: cache
      # Option 4: data loading using --(train|val|test)-weighted-split-paths
      #
      # Every `split` value below is quoted so that it always loads as a string.
      # YAML 1.1 loaders (e.g. PyYAML) resolve an unquoted 0:0.98 as a
      # sexagesimal (base-60) float, i.e. 0.98, instead of the string "0:0.98".
      #
      # train is only allowed to have 1 group
      train_weighted_split_paths:
        - TRAIN:
            - weight: 0.2
              split: "0:0.98"
              path: data/train/lang=Matlab
            - weight: 0.3
              split: "0:0.98"
              path: data/train/lang=Verilog
            - weight: 0.5
              split: "0:0.98"
              path: data/train/lang=Zig
      # val & test can have multiple groups
      val_weighted_split_paths:
        - VAL_lang=Matlab:
            - weight: 1
              split: "0.98:0.99"
              path: data/val/lang=Matlab
        - VAL_lang=Verilog:
            - weight: 1
              split: "0.98:0.99"
              path: data/val/lang=Verilog
        - VAL_lang=Zig:
            - weight: 1
              split: "0.98:0.99"
              path: data/val/lang=Zig
        - VAL_all_sources_weighted:
            - weight: 0.2
              split: "0.98:0.99"
              path: data/val/lang=Matlab
            - weight: 0.3
              split: "0.98:0.99"
              path: data/val/lang=Verilog
            - weight: 0.5
              split: "0.98:0.99"
              path: data/val/lang=Zig
      test_weighted_split_paths:
        - TEST_lang=Matlab:
            - weight: 1
              split: "0.99:1"
              path: data/test/lang=Matlab
        - TEST_lang=Verilog:
            - weight: 1
              split: "0.99:1"
              path: data/test/lang=Verilog
        - TEST_lang=Zig:
            - weight: 1
              split: "0.99:1"
              path: data/test/lang=Zig
        - TEST_all_sources_weighted:
            - weight: 0.2
              split: "0.99:1"
              path: data/test/lang=Matlab
            - weight: 0.3
              split: "0.99:1"
              path: data/test/lang=Verilog
            - weight: 0.5
              split: "0.99:1"
              path: data/test/lang=Zig
      sequence_length: 2048

# Tokenizer selection for this run.
tokenizer_args:
  # NOTE(review): presumably a Hugging Face Hub model id — confirm against the loader.
  tokenizer_name: bigcode/starcoder

# Model definition: the model class plus the config used to build it.
model_args:
  model_class: AutoModelForCausalLM
  pretrained_config:
    model_type: gpt_dolomite
    # NOTE(review): vocab_size 50257 and token id 50256 are the GPT-2 values,
    # while tokenizer_args selects bigcode/starcoder — confirm they agree with
    # the tokenizer that produced the training data.
    vocab_size: 50257
    max_position_embeddings: 2048
    hidden_size: 768
    num_layers: 12
    num_attention_heads: 12
    num_key_value_heads: null
    intermediate_size: null
    activation_function: gelu_pytorch_tanh
    attention_head_type: mqa
    normalization_function: layernorm
    # Written with an explicit ".0": YAML 1.1 loaders (e.g. PyYAML) only resolve
    # exponent notation as a float when the mantissa contains a ".", so a bare
    # 1e-5 would load as the string "1e-5".
    layer_norm_epsilon: 1.0e-5
    initializer_range: 0.02
    bos_token_id: 50256
    eos_token_id: 50256
    pad_token_id: 50256
    add_bias: true
    position_embedding_type: learned_absolute
  attention_implementation: flash_attention_2
  use_padding_free_transformer: true

# Selects the tuning method; this config uses `pretraining`.
tuning_args:
  tuning_method: pretraining

# Checkpoint output settings.
save_args:
  save_path: checkpoints
  # NOTE(review): presumably counted in training steps (num_training_steps is 100) — confirm.
  save_interval: 50

# Run length, evaluation cadence and micro batch size.
training_parameters:
  num_training_steps: 100
  # NOTE(review): presumably counted in training steps — confirm.
  eval_interval: 50
  micro_batch_size: 6

# Optimizer selection and the keyword arguments listed for it.
optimizer_args:
  class_name: TorchAdamW
  class_args:
    # Exponent-notation floats carry an explicit ".0": YAML 1.1 loaders (e.g.
    # PyYAML) only resolve them as floats when the mantissa contains a ".", so
    # a bare 1e-5 / 1e-10 would load as the strings "1e-5" / "1e-10".
    lr: 1.0e-5
    weight_decay: 0.1
    betas:
      - 0.9
      - 0.95
    eps: 1.0e-10

# Learning-rate schedule; only the decay style is set here.
lr_scheduler_args:
  lr_decay_style: cosine

# Mixed-precision settings; dtype is set to bf16.
mixed_precision_args:
  dtype: bf16
