# Hydra defaults list. Each `override` entry selects the option used for the
# trainer_cfg / model_cfg / data_cfg config groups, and the `@_global_`
# package merges the selected config at the root of this config.
# `_self_` is listed last so the values in this file take precedence over
# those set by the composed trainer/model/data configs.
# NOTE(review): `smollm/drope.yaml` is the only entry spelled with an explicit
# `.yaml` extension. Hydra appears to accept both spellings, but confirm that
# nothing (e.g. a name derived from `hydra:runtime.choices`) depends on the
# exact spelling before normalizing it to `smollm/drope`.
defaults:
  - override /trainer_cfg@_global_: sequence_mixing_trainer
  - override /model_cfg@_global_: smollm/drope.yaml
  - override /data_cfg@_global_: fineweb_edu_shuffled
  - _self_

# Numeric precision: bfloat16 mixed precision plus TF32 matmuls
# (both are forwarded to trainer_args below).
tf32: true
bf16: true

# Passed to trainer_args alongside max_steps. With
# transformers.TrainingArguments, a positive max_steps takes precedence.
num_train_epochs: 1

# Emit a log record on every step.
logging_strategy: steps
logging_steps: 1

# Token budget: 120000 steps x 512 sequences x 2048 tokens ~= 125.8B tokens
# (~120B only if a step is counted as 2^20 ~= 1M tokens). This is an upper
# bound: it assumes train_batch_size (512) is the global batch and that every
# sequence fills max_seq_length (2048). With packing disabled, padding may
# reduce the real count - TODO confirm against the data pipeline.
max_steps: 120000

# Evaluation is switched off for this run.
do_eval: false

# Write a checkpoint every 5000 steps.
save_strategy: steps
save_steps: 5000

# Batch sizing. train_batch_size is presumably the global (effective) batch in
# sequences. gradient_accumulation_steps, which trainer_args reads, is not set
# in this file. It is expected to be derived from these values elsewhere
# (e.g. 512 / (32 x num_devices)) - TODO confirm in the trainer config.
train_batch_size: 512
per_device_train_batch_size: 32

# Activation (gradient) checkpointing. The kwargs mapping is forwarded to
# trainer_args and selects the non-reentrant checkpoint implementation.
gradient_checkpointing: true
gradient_checkpointing_kwargs: {use_reentrant: false}

# Optimizer hyperparameters (LR, weight decay, Adam betas/epsilon, gradient
# clipping), forwarded to trainer_args.
# Exponent floats carry a decimal point (3.0e-4 rather than 3e-4). OmegaConf
# parses either form as a float, but plain YAML 1.1 loaders (e.g. PyYAML)
# read `3e-4` as the string "3e-4". Under OmegaConf the parsed values are
# unchanged, and so is the `lr${learning_rate}` part of wandb_run_name.
learning_rate: 3.0e-4
weight_decay: 0.1
adam_beta1: 0.9
adam_beta2: 0.95
adam_epsilon: 1.0e-8
max_grad_norm: 1.0

# Cosine LR schedule after a linear warmup over the first 0.5% of training
# steps (0.005 x max_steps = 600 steps at max_steps: 120000).
warmup_ratio: 0.005
lr_scheduler_type: cosine

# DDP collective timeout, in seconds per TrainingArguments (18000 s = 5 h).
ddp_timeout: 18000

max_seq_length: 2048
# Sequence packing disabled (the consumer of this key is not in this file).
packing: false

# Eval batch size. NOTE(review): this only takes effect where trainer_args
# interpolates ${per_device_eval_batch_size}. Eval is also disabled in this
# config (do_eval: false).
per_device_eval_batch_size: 4

# Sequence-mixing model/training options. They are consumed by project code
# not shown here; the descriptions below are inferred from key names - verify.
forward_mode: sequence_mixing
copy_attention_weights: true
use_additional_features: false
reinitialize_base_model: false
freeze_base_model: false

# Loss configuration. Only next_token_loss_coef is non-zero, so presumably
# training uses the plain next-token objective, and the per-layer and
# distillation terms (and loss_type) have no effect - confirm in the trainer.
loss_type: mse
next_token_loss_coef: 1.0
per_layer_loss_coef: 0.0
distillation_loss_coef: 0.0
mask_per_layer_losses: true

# null presumably means no training batches are skipped.
skip_train_batches: null

# Keyword arguments for transformers.TrainingArguments (presumably built via
# Hydra's `_target_` instantiation). Every value is interpolated from a
# top-level key, so run-level overrides propagate here. output_dir,
# gradient_accumulation_steps, report_to, seed and do_train are not defined in
# this file; they must come from the composed configs or the command line.
trainer_args:
  _target_: transformers.TrainingArguments
  output_dir: ${output_dir}
  max_steps: ${max_steps}
  num_train_epochs: ${num_train_epochs}
  per_device_train_batch_size: ${per_device_train_batch_size}
  # Use the dedicated eval batch-size key, not the train batch size.
  per_device_eval_batch_size: ${per_device_eval_batch_size}
  gradient_accumulation_steps: ${gradient_accumulation_steps}
  gradient_checkpointing: ${gradient_checkpointing}
  gradient_checkpointing_kwargs: ${gradient_checkpointing_kwargs}
  learning_rate: ${learning_rate}
  weight_decay: ${weight_decay}
  adam_beta1: ${adam_beta1}
  adam_beta2: ${adam_beta2}
  adam_epsilon: ${adam_epsilon}
  max_grad_norm: ${max_grad_norm}
  lr_scheduler_type: ${lr_scheduler_type}
  warmup_ratio: ${warmup_ratio}
  logging_strategy: ${logging_strategy}
  logging_steps: ${logging_steps}
  save_strategy: ${save_strategy}
  save_steps: ${save_steps}
  report_to: ${report_to}
  run_name: ${wandb_run_name}
  bf16: ${bf16}
  tf32: ${tf32}
  seed: ${seed}
  ddp_timeout: ${ddp_timeout}
  do_train: ${do_train}
  do_eval: ${do_eval}
  # Keep all dataset columns; the trainer may need fields beyond the model's
  # forward() signature.
  remove_unused_columns: false


# Weights & Biases naming. The run name combines the trainer and model log
# names (not defined in this file), the step budget and the learning rate.
wandb_project: 'sequence_mixing_from2k'
wandb_run_name: '${trainer_log_name}_${model_log_name}_${max_steps}steps_lr${learning_rate}'
