# Configuration for supervised fine-tuning (SFT)

# Defaults list: this file itself (`_self_`) comes first, followed by the
# dataset entries selected under `sft_datasets`.
# NOTE(review): looks like Hydra's `defaults` convention — confirm.
defaults:
  - _self_
  - sft_datasets:
      # open-ended VQA
      - vqa
      - gqa
      - vsr
      - ocrvqa
      # multiple-choice VQA
      - aokvqa
      - scienceqa
      # referring expression comprehension
      - refcoco
      - refcoco+
      - refcocog
      - vg
      # instructions
      - llava150k

# Settings grouped under `training_config`.
training_config:
  # -1 means default to the dataset size
  total_training_steps: 50000
  max_length: 512
  batch_size: 4
  workers: 4
  template: "default"
  # 11 weights that sum to 1.0.
  # NOTE(review): presumably one weight per name listed under
  # `defaults.sft_datasets`, in the same order — confirm against the loader.
  sampling_weights:
    - 0.103
    - 0.103
    - 0.026
    - 0.103
    - 0.103
    - 0.103
    - 0.103
    - 0.103
    - 0.103
    - 0.047
    - 0.103

# NOTE(review): "{SFT_data_root}" looks like a placeholder to be replaced with
# the real dataset root before use — confirm how the consumer fills it in.
root: "{SFT_data_root}"
# Base run name; referenced further down as ${RUN_NAME} by `run_name` and
# `log_dir`.
RUN_NAME: "models/AKI"

# Training arguments carried over from the original argparse interface.
model_family: "AKI"
# Either sft_scratch or sft_resume: distinguishes starting SFT from the
# pre-trained model from continuing an already-SFT model.
training_mode: "sft_scratch"
lm_path: "microsoft/Phi-3.5-mini-instruct"
tokenizer_path: "microsoft/Phi-3.5-mini-instruct"
vision_encoder_path: "siglip-so400m-patch14-384"
vision_encoder_pretrained: "google"
checkpoint_steps: 5000
logging_steps: 5
loss: "supervised_finetune"
lr_scheduler: "cosine"
# Quoted for consistency with `log_dir`; the value is still ${RUN_NAME}.
run_name: "${RUN_NAME}"
num_epochs: 1
warmup_steps: 150
precision: "amp_bf16"
fsdp: true
# Floats in exponent notation carry an explicit decimal point: YAML 1.1
# loaders such as PyYAML read `2e-5` (no dot) as a string, not a float.
learning_rate: 2.0e-5
min_lr: 1.0e-6
weight_decay: 1.0e-4
report_to_tensorboard: true
log_dir: "${RUN_NAME}/logs"

# Arguments that originally had default values (i.e. were optional).
# NOTE(review): "{checkpoint_path}" looks like a placeholder to be replaced
# with a real checkpoint path — confirm how it interacts with `training_mode`.
resume_from_checkpoint: "{checkpoint_path}"
gradient_accumulation_steps: 1
seed: 42
dataset_resampled: false
horovod: false
no_set_device_rank: false
fsdp_use_orig_params: false
fsdp_sharding_strategy: "full"
dist_url: "env://"
dist_backend: "nccl"
offline: false
gradient_checkpointing: false
cpu_offload_gradients: false
delete_previous_checkpoint: false
