# =========================================================
# ===============  SFT Training Config  ===================
# =========================================================

# ==== Model ====
# Base checkpoint path or hub id. Like every ${...} placeholder in this file it
# is presumably substituted (e.g. via envsubst) before the YAML is loaded.
model_name_or_path: "${MODEL_NAME}"
trust_remote_code: true        # permit the checkpoint's custom modeling code
flash_attn: sdpa               # attention backend selection
gradient_checkpointing: true   # recompute activations to save memory

# ==== Method ====
# Supervised fine-tuning of all model weights, sharded via DeepSpeed.
stage: "sft"
do_train: true
finetuning_type: "full"
deepspeed: "${DS_CONFIG}"      # DeepSpeed config file supplied by the launcher

# ==== Dataset ====
# Train/eval dataset names come from the environment; they are presumably
# resolved relative to dataset_dir by the trainer (confirm against the loader).
dataset_dir: "graphAGI"
dataset: "${Train_dataset}"
eval_dataset: "${Eval_dataset}"

# Prompt/chat template name; should match the base model's family.
template: "${TEMPLATE}"

cutoff_len: 2048
overwrite_cache: true
preprocessing_num_workers: 8
dataloader_num_workers: 4
remove_unused_columns: false
# NOTE(review): per_device_train_batch_size is 1, so every batch holds a single
# sample and length grouping saves no padding. Left enabled only so the sample
# ordering of existing runs does not change.
group_by_length: true

# ==== Multi-task ====
# Concatenate all training datasets into one stream and report eval metrics
# separately for each eval dataset.
mix_strategy: concat
eval_on_each_dataset: true

# ==== Multimodal freezing ====
# VT must expand to a bare YAML boolean (true / false) and must stay UNQUOTED:
# a quoted "false" loads as a non-empty string, which a consumer that only
# checks truthiness would treat as true.
freeze_vision_tower: ${VT}
freeze_multi_modal_projector: false
freeze_language_model: false

# ==== Output ====
output_dir: "saves/${Output_dir}"   # per-run directory under saves/
overwrite_output_dir: false
# Checkpoint every 100 steps; only one checkpoint is retained on disk.
save_steps: 100
save_total_limit: 1
logging_steps: 5
plot_loss: true
seed: 42

# ==== Train ====
# Samples per optimizer step per device = 1 * ACCUM_STEPS
# (multiply by the number of devices for the global batch size).
per_device_train_batch_size: 1
gradient_accumulation_steps: ${ACCUM_STEPS}
# `!!float` forces a float. PyYAML follows YAML 1.1, whose float pattern needs
# a dot, so an expansion like `1e-5` would otherwise load as the STRING "1e-5"
# and break the optimizer/scheduler arithmetic. `2.0e-5` and `0.1` also work.
learning_rate: !!float ${LR}
num_train_epochs: ${EPOCH}
lr_scheduler_type: "${LR_TYPE}"
warmup_ratio: !!float ${WR}
bf16: true

# ==== Eval ====
# NOTE(review): evaluation every 5000 steps is 50x rarer than checkpointing
# (save_steps: 100); confirm this spacing is intended.
eval_strategy: "steps"
eval_steps: 5000
per_device_eval_batch_size: 1

# ==== Optimizer ====
optim: "adamw_torch"    # PyTorch AdamW implementation
adam_beta1: 0.9
adam_beta2: 0.98        # lower than the HF default (0.999)
weight_decay: 0.01
max_grad_norm: 1.0      # gradient-norm clipping threshold

# ==== Logging ====
# Send metrics to Weights & Biases under the run name taken from $Run_name.
report_to: "wandb"
run_name: "${Run_name}"
