datasets:
  # class_name, data_name and data_sampling_ratio are not used, but they must
  # still be provided to avoid errors.
  - class_name: FSDPDataset
    data_name: FSDP
    data_sampling_ratio: 1
    class_args:
      # Booleans are written in canonical lowercase, matching the rest of the file.
      enable_checkpoint: true
      checkpoint_model_weights_only: false
      checkpoint_folder: checkpoints
      dataset_path: /proj/datasets/fms-dataloader-test/
      # Example value for datasets (left disabled):
      # "dump=CC-MAIN-2013-20,dump=CC-MAIN-2015-35,dump=CC-MAIN-2017-26,dump=CC-MAIN-2018-47,dump=CC-MAIN-2020-24,dump=CC-MAIN-2022-27"
      datasets: null
      # Example value for weights (left disabled): "1,1,1,1,1,1"
      weights: null
      # Alternative dataset selection kept for reference:
      # dataset_path: "/proj/data-eng/fsdp/data/R45_th7_200"
      # datasets: "Java,Python"
      # weights: "0.5,0.5"
      col_name: "tokens"
      file_type: "arrow"
      vocab_size: 200000
      bos_token: null
      eos_token: 0
      drop_tokens: ""
      # NOTE(review): the original note read "8GPUs * 2 Nodes * 4 workers * 2
      # (Some multiple)", which multiplies out to 128, while the value here is 8
      # and num_workers below is 1 -- confirm which value is intended.
      logical_shards: 8
      num_workers: 1
      seed: 2023
      sequence_length: 2048  # This sets the seq_len
 
tokenizer_args:
  # Other location recorded in the original note:
  # /proj/data-eng/tokenizers/bigcode_starcoder
  tokenizer_name: bigcode/starcoder

# Model definition: the model class plus the configuration values it is built from.
model_args:
  model_class: AutoModelForCausalLM
  pretrained_config:
    model_type: gpt_dolomite
    vocab_size: 200000
    max_position_embeddings: 2048
    hidden_size: 768
    num_layers: 12
    num_attention_heads: 12
    num_key_value_heads: null
    intermediate_size: null
    activation_function: gelu_pytorch_tanh
    attention_head_type: mqa
    normalization_function: layernorm
    # Written with a decimal point so that YAML 1.1 loaders (e.g. PyYAML)
    # resolve it as a float; a bare "1e-5" is loaded as a string there.
    layer_norm_epsilon: 1.0e-5
    initializer_range: 0.02
    use_cache: true
    # bos, eos and pad all share token id 0 in this configuration.
    bos_token_id: 0
    eos_token_id: 0
    pad_token_id: 0
    add_bias: true
    position_embedding_type: learned_absolute
  attention_implementation: flash_attention_2
  use_padding_free_transformer: true

# Tuning configuration; this file selects the "pretraining" method.
tuning_args:
  tuning_method: pretraining

# Load configuration. load_path is an absolute path.
# NOTE(review): presumably the checkpoint location to load from -- confirm
# against the consumer of this file.
load_args:
  load_path: /proj/checkpoints/mayank/fms

# Save configuration.
# NOTE(review): save_path is a relative path, presumably resolved against the
# working directory, and save_interval is presumably counted in training
# steps -- confirm both against the consumer of this file.
save_args:
  save_path: checkpoints
  save_interval: 10000

# Training loop settings.
training_parameters:
  num_training_steps: 100000
  # Evaluation is switched off here, yet eval_interval is still set.
  # NOTE(review): presumably eval_interval is ignored while
  # eval_during_training is false -- confirm.
  eval_during_training: false
  eval_interval: 500
  micro_batch_size: 6
  gradient_accumulation_steps: 4

# Optimizer selection and the arguments passed to it.
optimizer_args:
  class_name: TorchAdamW
  class_args:
    # lr and eps are written with a decimal point so that YAML 1.1 loaders
    # (e.g. PyYAML) resolve them as floats; bare "1e-5" / "1e-10" are loaded
    # as strings there.
    lr: 1.0e-5
    weight_decay: 0.1
    betas:
      - 0.9
      - 0.95
    eps: 1.0e-10

# Learning-rate scheduler settings; only the decay style is set in this file.
lr_scheduler_args:
  lr_decay_style: cosine

# Mixed-precision settings; the dtype is set to bf16.
mixed_precision_args:
  dtype: bf16

# Logging settings.
# NOTE(review): log_interval is presumably counted in training steps -- confirm.
logging_args:
  log_interval: 10
