# Dataset and batching settings.
data:
  train_batch_size: 256
  # This is also the validation batch size.
  micro_batch_size: 16
  # YAML itself does not expand `~`; presumably the consuming code expands it
  # to the home directory — TODO confirm.
  train_files: '~/data/gsm8k/train.parquet'
  val_files: '~/data/gsm8k/test.parquet'
  # Names of the dataset fields that hold the prompt and the target response
  # (presumably parquet column names — verify against the dataset reader).
  prompt_key: question
  response_key: answer
  # NOTE(review): presumably a length limit in tokens, with over-long samples
  # cut on the right-hand side — confirm against the dataset code.
  max_length: 1024
  truncation: right
  balance_dp_token: false
  # No chat template is configured here.
  chat_template: null
# Model loading and FSDP settings.
model:
  # Location of the pretrained model to start from. YAML does not expand `~`;
  # presumably the consuming code does — TODO confirm.
  partial_pretrain: '~/models/gemma-1.1-7b-it'
  fsdp_config:
    wrap_policy:
      # NOTE(review): presumably a size threshold for the FSDP wrapping
      # policy — confirm how the trainer interprets 0.
      min_num_params: 0
    # Both offload switches are disabled.
    cpu_offload: false
    offload_params: false
  # No external library is configured.
  external_lib: null
  enable_gradient_checkpointing: false
  trust_remote_code: false
# Optimizer hyperparameters.
optim:
  # Written with an explicit mantissa dot on purpose: YAML 1.1 resolvers such
  # as stock PyYAML only treat exponent notation as a float when the mantissa
  # contains a `.`, so a bare `1e-5` would be loaded as the string "1e-5".
  # `1.0e-5` is the same value and resolves to a float under both YAML 1.1
  # and YAML 1.2 loaders.
  lr: 1.0e-5
  betas: [0.9, 0.95]
  weight_decay: 0.01
  # NOTE(review): presumably the fraction of training steps spent on
  # learning-rate warmup — confirm against the scheduler code.
  warmup_steps_ratio: 0.1
  # NOTE(review): presumably the gradient-clipping threshold — confirm.
  clip_grad: 1.0

# Run-level settings: naming, output locations, schedule and logging.
trainer:
  seed: 1
  # These two names are interpolated into default_local_dir below.
  project_name: gsm8k-sft
  experiment_name: test
  default_local_dir: 'checkpoints/${trainer.project_name}/${trainer.experiment_name}'
  # Change the HDFS path here.
  default_hdfs_dir: 'hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/'
  # No checkpoint to resume from is configured.
  resume_path: null
  logger:
    - console
  # NOTE(review): both an epoch count and an explicit step count are set;
  # which one bounds the run is decided by the trainer code — confirm that
  # this combination is intended.
  total_epochs: 5
  total_training_steps: 300

