# Base checkpoint and model-loading options
model_name_or_path: 'Qwen/Qwen2.5-Math-7B'
model_revision: 'main'
attn_implementation: 'flash_attention_2'
torch_dtype: 'bfloat16'

# Training data
# Each source below is a relative path containing a literal <model_name>
# placeholder; presumably the loader joins it onto dataset_root and fills in
# the placeholder — confirm against the training script.
dataset_root: '../data/countdown/corpora/'
dataset_num_proc: 16
chat_template: null
positive_dataset_sources:
- '<model_name>/cot_1/correct_train.csv'
- '<model_name>/cot_2/correct_train.csv'
- '<model_name>/cot_3/correct_train.csv'
- '<model_name>/cot_4/correct_train.csv'
- '<model_name>/cot_5/correct_train.csv'
- '<model_name>/cot_6/correct_train.csv'
- '<model_name>/cot_7/correct_train.csv'
- '<model_name>/cot_8/correct_train.csv'
- '<model_name>/cot_9/correct_train.csv'
negative_dataset_sources:
- '<model_name>/cot_1/failed_train.csv'
- '<model_name>/cot_2/failed_train.csv'
- '<model_name>/cot_3/failed_train.csv'
- '<model_name>/cot_4/failed_train.csv'
- '<model_name>/cot_5/failed_train.csv'
- '<model_name>/cot_6/failed_train.csv'
- '<model_name>/cot_7/failed_train.csv'
- '<model_name>/cot_8/failed_train.csv'
- '<model_name>/cot_9/failed_train.csv'

# Logging
output_dir: 'logs'
overwrite_output_dir: true
log_level: 'info'
logging_strategy: 'steps'
logging_steps: 20

# Evaluation is switched off during training; we eval after training finishes
do_eval: false
eval_strategy: 'no'
eval_on_start: false

# Checkpointing
# Note: we save all checkpoints for future evaluation, because eval_loss is
# not a reliable signal in our task.
num_train_epochs: 1
# NOTE(review): num_copy_positive looks like a project-specific option
# (presumably how many times the positive samples are duplicated) — confirm
# against the training script.
num_copy_positive: 2
save_only_model: true # to save disk and time
load_best_model_at_end: false
save_strategy: steps
# Fractional value: presumably interpreted as a ratio of the total training
# steps (as Hugging Face TrainingArguments does for values below 1), i.e. one
# checkpoint every 5% of the run — TODO confirm with the consumer of this file.
save_steps: 0.05

# pref config:
# packing = false by default
# NOTE(review): no packing key is set in this file, so whatever default the
# consumer uses applies.
max_length: 256 # the actual max_seq_length depends on the minibatch
# Evaluation batch settings. do_eval is false in this file, so these
# presumably only take effect if evaluation is re-enabled — confirm.
per_device_eval_batch_size: 16
eval_accumulation_steps: 1
# With gradient_accumulation_steps at 1, each optimizer step presumably
# consumes 8 samples per device — verify against the trainer in use.
per_device_train_batch_size: 8
gradient_accumulation_steps: 1

# Learning-rate schedule
# Keep the explicit decimal point and signed exponent in the floats below
# (2.0e-05, 7.0e-08): some YAML 1.1 loaders (e.g. PyYAML) read a shortened
# form such as 2e-5 as a string instead of a number.
learning_rate: 2.0e-05
lr_scheduler_type: cosine_with_min_lr
warmup_ratio: 0.1
lr_scheduler_kwargs:  # presumably forwarded to the cosine_with_min_lr scheduler — confirm
  min_lr: 7.0e-08

# Seed, precision and gradient-checkpointing settings
seed: 42
bf16: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  # Canonical lowercase boolean: `False` is only a boolean under some YAML
  # schemas; a strict YAML 1.2 JSON-schema loader would read it as the
  # (truthy) string "False".
  use_reentrant: false
remove_unused_columns: true
# Reporting and Hub settings (push_to_hub is off)
push_to_hub: false
hub_model_id: null
hub_strategy: 'every_save'
report_to:
- 'wandb'
