datasets:
  # Placeholder dataset intended only for throughput estimation. Replace it
  # with your own dataset; the other example configs in this repo show how.
  - data_name: 'debug'
    class_name: 'DebugDataset'
    class_args:
      num_examples: 100
    data_sampling_ratio: 1
    # Input and output token limits are equal (65536 each, 131072 combined).
    max_input_tokens: 65536
    max_output_tokens: 65536

model_args:
  # Absolute checkpoint path; presumably site-specific, so adjust it for
  # your environment.
  model_name: '/proj/checkpoints/mayank/granite-8b-code-instruct-128k'
  model_class: 'AutoModelForCausalLM'
  attention_implementation: 'flash_attention_2'
  use_padding_free_transformer: true

tuning_args:
  # Tuning method selected for this run.
  tuning_method: 'full_finetuning'

save_args:
  # NOTE(review): save_interval (5000) is larger than num_training_steps
  # (100) under training_parameters, so presumably no periodic checkpoint is
  # written during this run. Confirm that this is intended.
  save_interval: 5000
  save_path: 'checkpoints'

training_parameters:
  # Run length and batching.
  num_training_steps: 100
  micro_batch_size: 1
  gradient_accumulation_steps: 1
  gradient_clipping: 1
  # Evaluation during training is switched off; eval_interval is still set
  # and is presumably ignored while the flag is false. TODO confirm.
  eval_during_training: false
  eval_interval: 50

optimizer_args:
  class_name: TorchAdamW
  class_args:
    # Exponent-form floats are written with an explicit decimal point.
    # YAML 1.1 loaders such as PyYAML resolve a bare `1e-5` as the string
    # "1e-5" rather than a float, whereas `1.0e-5` is a float under both
    # YAML 1.1 and YAML 1.2.
    lr: 1.0e-5
    weight_decay: 0.1
    betas:
      - 0.9
      - 0.95
    eps: 1.0e-10

lr_scheduler_args:
  # Learning-rate decay style for the scheduler.
  lr_decay_style: 'cosine'

mixed_precision_args:
  # Data type used for mixed precision.
  dtype: 'bf16'

distributed_args:
  # Backend and FSDP selection.
  distributed_backend: 'torch'
  fsdp_algorithm: 2
  # Tensor-parallel settings.
  tensor_parallel_size: 8
  tensor_parallel_word_embeddings: true
  sequence_parallel: true
  # Gradient checkpointing settings.
  gradient_checkpointing_method: 'block'
  gradient_checkpointing_args:
    checkpoint_every: 3
