---
# Model and data selection for the fine-tuning run.

# Identifier of the base model to fine-tune.
model_id: 'meta-llama/Meta-Llama-3-8B-Instruct'

# Dataset location; '.' is a relative path.
# NOTE(review): presumably resolved against the training script's working
# directory - confirm against the consumer.
dataset_path: '.'

# Maximum sequence length. The trailing 2048 is an alternative value
# (presumably the earlier setting - confirm before relying on it).
max_seq_len: 3072  # 2048

# Trainer settings. The key names match Hugging Face `TrainingArguments`
# fields; presumably the training script parses them into that class (confirm).

# Output location, experiment tracking, and logging/checkpoint/eval cadence.
output_dir: './checkpoints/llama-3-8b-sep-qlora'
report_to: 'wandb'
logging_steps: 10
save_strategy: 'epoch'
# NOTE(review): newer transformers releases rename this key to `eval_strategy`;
# verify against the pinned library version before upgrading.
evaluation_strategy: 'epoch'

# Batch sizing: per-device batch of 1 with 2 gradient-accumulation steps.
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
gradient_accumulation_steps: 2

# Optimizer, learning-rate schedule, and run length.
optim: 'adamw_torch'
learning_rate: 0.0002
# NOTE(review): a plain 'constant' schedule normally has no warmup phase, so
# `warmup_ratio` below is likely ignored; 'constant_with_warmup' would honor
# it. Confirm which behavior is intended before changing either value.
lr_scheduler_type: 'constant'
warmup_ratio: 0.03
max_grad_norm: 0.3
num_train_epochs: 3

# Numeric precision and memory-saving options.
bf16: true
tf32: true
gradient_checkpointing: true

# Fully Sharded Data Parallel (FSDP) settings.
# Alternative that additionally enables the "offload" option (kept for
# reference; swap it in for the active line below to use it):
# fsdp: "full_shard auto_wrap offload"
fsdp: 'full_shard auto_wrap'
fsdp_config:
  # The two flags below are the strings 'false', not YAML booleans.
  # NOTE(review): presumably the consumer expects string values here -
  # confirm before unquoting them.
  use_orig_params: 'false'
  forward_prefetch: 'false'
  backward_prefetch: 'backward_pre'