
# Model identifiers. The image processor, "tokenizer" and base model all use
# the same tag.
image_processor: google/vit-base-patch16-224-in21k   # HuggingFace tag
tokenizer: google/vit-base-patch16-224-in21k   # HF calls everything a "tokenizer"
model_base: google/vit-base-patch16-224-in21k  # HuggingFace tag
model_tag: vit-base  # short model name; NOTE(review): where it is used is not visible here — confirm

# Effective batch size. The per-device batch size (see training_args) and the
# number of devices should be chosen so that together they match this value.
effective_batch_size: 32
# effective_batch_size: 20  # for mercury1 or 2
inference_batch_size: 32  # Batch size to use during inference with model parallelism

seed: 13  # random seed; NOTE(review): what exactly gets seeded is decided by the consumer — confirm

# LoRA hyperparameters
lora_config:
  lora_r: 16           # rank
  lora_alpha: 16      # scaling factor
  lora_dropout: 0.0  # dropout rate; 0.0 means no dropout

# The argument names below should match those of TrainingArguments
# https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
training_args:

  # There seems to be no consensus on the number of epochs for fine-tuning
  # - 3 epochs: https://duarteocarmo.com/blog/fine-tune-llama-2-telegram
  # - 5 epochs: https://aws.amazon.com/blogs/machine-learning/fine-tune-llama-2-for-text-generation-on-amazon-sagemaker-jumpstart/
  # - 10 epochs: https://www.anyscale.com/blog/fine-tuning-llama-2-a-comprehensive-case-study-for-tailoring-models-to-unique-applications
  # However, we know that more epochs are not necessarily better
  num_train_epochs: 5

  # 1e-4 is a good default LR for fine-tuning (the post below used LoRA)
  # - https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
  # - however, for IMDB a smaller LR seems to be more stable
  # For full FT, the model can easily memorize/overfit the entire FT set, and
  # a smaller LR / shorter training is usually better
  learning_rate: 0.0001
  lr_scheduler_type: cosine

  per_device_train_batch_size: 8   # LoRA on dlab_hatespeech_race
  per_device_eval_batch_size: 64
  gradient_accumulation_steps: 1  # Raise this if the GPU is OOM (smaller GPUs, e.g. w/ 16GB)

  # Mixed precision: bf16 is enabled; fp16 is kept commented out as the alternative.
  # bf16 can sometimes be slower on skampere1 for some reason
  # fp16: true
  bf16: true  # Since we always run on skampere1, just use bf16 for better precision

  # Some sources suggest typically setting WD to 0.001
  # - https://medium.com/@ogbanugot/notes-on-fine-tuning-llama-2-using-qlora-a-detailed-breakdown-370be42ccca1
  # - https://www.datacamp.com/tutorial/fine-tuning-llama-2
  weight_decay: 0.001
  # weight_decay: 0.1  # full FT may need higher WD (note originally said "LR"; presumably a typo — confirm)

  # Flags for indicating what to do; not actionable by Trainer and only for the user
  do_train: true
  do_eval: true
  # Evaluate once per epoch.
  # NOTE(review): newer transformers releases appear to rename this argument to
  # `eval_strategy` — confirm against the installed version.
  evaluation_strategy: epoch

  ### DEBUG: disable eval and cut down training steps ###
  # max_steps: 50
  # Gradient checkpointing disables `use_cache=True`, slowing down inference
  # - https://discuss.huggingface.co/t/why-is-use-cache-incompatible-with-gradient-checkpointing/18811
  # gradient_checkpointing: true
  # `compile`: For LoRA training on mercury1, seems to give no speedup and messes up the saved model
  # torch_compile: true

  logging_steps: 5
  warmup_ratio: 0.01
  logging_first_step: true
  remove_unused_columns: true
  dataloader_num_workers: 0

  # Alternative checkpointing setups, kept for reference:
  # save_strategy: "no"  # disable checkpointing and just save after fine-tuning
  # save_strategy: steps  # save every `save_steps` steps
  # save_steps: 1167  # 2 epochs of dlab-race
  save_strategy: epoch  # currently active: save a checkpoint every epoch
  save_total_limit: 10  # keep at most 10 checkpoints