# Model arguments
# Base checkpoint to fine-tune, taken at the `main` revision.
model_name_or_path: Qwen/Qwen2.5-Math-7B
model_revision: main
# Load dtype; the same precision is requested again by `bf16: true` further down.
torch_dtype: bfloat16
# NOTE(review): presumably needs the flash-attn package and a GPU that supports
# it -- confirm for the target environment.
attn_implementation: flash_attention_2

# Data training arguments
# NOTE(review): the source paths below look relative to dataset_root -- confirm
# against the data-loading code.
dataset_root: ../data/countdown/corpora/
# Positive examples: the correct_train.csv files of the bfs and dfs corpora.
positive_dataset_sources:
- bfs/correct_train.csv
- dfs/correct_train.csv
# Negative examples: the failed_train.csv files of the same two corpora.
negative_dataset_sources:
- bfs/failed_train.csv
- dfs/failed_train.csv
dataset_text_field: text
dataset_num_proc: 16
# Explicit null: no chat template is configured here.
chat_template: null
# NOTE(review): presumably restricts the loss to completion tokens -- confirm.
train_on_completion_only: true

# SFT trainer config
## Unlikelihood (UL) loss settings
ul_loss_type: unlikelihood
# NOTE(review): if ul_alpha is the weight of the UL term, 0.0 would switch it
# off for this run -- confirm that is intended.
ul_alpha: 0.0

# Logging and evaluation
output_dir: logs
overwrite_output_dir: true
log_level: info
# Log by step count, every 75 steps.
logging_steps: 75
logging_strategy: steps
# Evaluation is turned off in all three settings below.
do_eval: false # we eval after training finishes
eval_on_start: false
# Quoted so the value stays the string 'no' instead of being read as a boolean.
eval_strategy: 'no'

# Note: all checkpoints are saved for later evaluation, because eval_loss is not a reliable signal in our task
num_train_epochs: 2
save_only_model: true # to save disk and time
load_best_model_at_end: false
save_strategy: steps
# NOTE(review): a fractional save_steps is presumably read as a share of the
# total training steps (0.05 -> every 5%) -- confirm for the trainer version used.
save_steps: 0.05

# Sequence length and batch sizes
packing: false # make sure it is false, since default packing causes contamination
max_seq_length: 256 # the actual max_seq_length depends on the minibatch
# Eval batch settings are still listed although evaluation is disabled (do_eval: false).
per_device_eval_batch_size: 16
eval_accumulation_steps: 1
per_device_train_batch_size: 8
gradient_accumulation_steps: 1

# Learning-rate schedule
learning_rate: 2.0e-05
lr_scheduler_type: cosine_with_min_lr
warmup_ratio: 0.1
# min_lr is nested here, so it is passed as a scheduler argument rather than a
# top-level option.
lr_scheduler_kwargs:
  min_lr: 7.0e-08

# Reproducibility, precision and memory
seed: 42
bf16: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  # Canonical lowercase boolean, matching every other flag in this file.
  use_reentrant: false
remove_unused_columns: true
# Experiment tracking
report_to:
- wandb
# Hub upload is disabled.
# NOTE(review): the hub_* keys below are presumably ignored while push_to_hub
# is false -- confirm.
push_to_hub: false
hub_model_id: null
hub_strategy: every_save
