# Training datasets. This is a list, so more than one dataset entry can be
# given; this example uses a single one.
datasets:
  - class_name: HuggingFaceDataset
    class_args:
      # edit the dataset location and the column names below to match your data
      # NOTE(review): "assisstant" is spelled this way in the dataset id; confirm
      # against the hub before "correcting" it
      data_path: mayank-mishra/glaive-code-assisstant-v3-20k
      # presumably the dataset columns holding the prompt and the target text
      # (TODO confirm)
      input_key: question
      output_key: answer
    # metadata for internal use only
    data_name: glaive-20k
    # the data sampling ratio is meaningless when there is only 1 dataset
    data_sampling_ratio: 1
    # templates used to format the input and output for training.
    # presumably __input__ / __output__ are replaced by the sample's input and
    # output text (TODO confirm). The double quotes are required so that \n is
    # read as a newline.
    input_format: "input: __input__\noutput: "
    output_format: "__output__"
    # presumably per-sample token limits for the input and the output
    # (TODO confirm how longer samples are handled)
    max_input_tokens: 4096
    max_output_tokens: 4096

# Model selection and attention settings.
model_args:
  model_name: ibm-granite/granite-3b-code-base
  model_class: AutoModelForCausalLM
  attention_implementation: flash_attention_2
  # the padding-free transformer requires a gpt_dolomite model.
  # To convert granite models to that class (and back again after training),
  # see the README of this repo
  use_padding_free_transformer: false

random_args:
  # seed for replicating an experiment; note that flash attention is
  # non-deterministic, so exact replication generally won't work
  seed: 42

tuning_args:
  # NOTE(review): presumably selects full fine-tuning of the model as opposed to
  # some other tuning method; confirm the supported values in the repo docs
  tuning_method: full_finetuning

save_args:
  # NOTE(review): presumably the checkpoint output directory and the number of
  # training steps between saves (TODO confirm). save_interval equals
  # num_training_steps (5000) in this file.
  save_path: checkpoints
  save_interval: 5000

training_parameters:
  # total samples seen during training =
  #   num_training_steps * micro_batch_size * gradient_accumulation_steps * data_parallel_size
  # with 2 GPUs (data_parallel_size = num_GPUs = 2) this is
  #   5000 * 2 * 1 * 2 = 20000
  # i.e. 1 epoch, since the dataset used here has exactly 20k samples
  num_training_steps: 5000
  eval_during_training: false
  micro_batch_size: 2
  gradient_accumulation_steps: 1
  gradient_clipping: 1

# Optimizer selection. NOTE(review): class_args are presumably forwarded to the
# optimizer's constructor; confirm against the repo.
optimizer_args:
  class_name: TorchAdamW
  class_args:
    # maximum learning rate.
    # Scientific-notation floats are written with a decimal point in the
    # mantissa (1.0e-5, not 1e-5): YAML 1.1 loaders such as PyYAML resolve the
    # dot-less form as a string instead of a float.
    lr: 1.0e-5
    weight_decay: 0.1
    betas:
      - 0.9
      - 0.95
    eps: 1.0e-10

lr_scheduler_args:
  # decay style: linear, cosine or exponential
  lr_decay_style: cosine
  # number of linear warmup steps
  num_warmup_steps: 200
  # the final lr will be 0.1 * max lr (the max lr is set in optimizer_args)
  lr_decay_factor: 0.1

mixed_precision_args:
  # NOTE(review): presumably bf16 selects bfloat16 mixed precision; confirm the
  # other supported values in the repo docs
  dtype: bf16

distributed_args:
  # use ZeRO-3 for model sharding: it saves the most memory but needs more
  # communication. This is fine here since we are training on 2 GPUs that are
  # connected via NVLink
  stage: 3
  # enable at your own risk, as this is not completely tested
  torch_compile: false
  # when enabled, the dataset is loaded only on the first GPU, which sends part
  # of the data to the other GPUs; not recommended unless the datasets are
  # immensely large
  dispatching_dataloader: false
