# Hydra defaults list: entries are composed in order, so later entries
# override earlier ones.
defaults:
  # Select the `base` option of the `weights_cfg` group and merge it at the
  # top level (`_global_` package) instead of under a `weights_cfg` key.
  - weights_cfg@_global_: base
  # The `base` config this file builds on.
  - base
  # `_self_` last: values defined in this file take precedence over the
  # entries above.
  - _self_

# Training length.
# NOTE(review): if these are forwarded to Hugging Face TrainingArguments, a
# positive max_steps takes precedence over num_train_epochs, so the run would
# stop after 2 optimizer steps. Confirm this is intended (e.g. a smoke test).
num_train_epochs: 1
max_steps: 2
max_seq_length: 1024

# Checkpointing / evaluation.
# NOTE(review): save_steps (50) exceeds max_steps (2), so the step-50
# checkpoint boundary is never reached within the run. Confirm.
save_strategy: steps
save_steps: 50
do_eval: true

# NOTE(review): presumably the distributed process-group timeout in seconds
# (Hugging Face `ddp_timeout`); if so, this value effectively disables the
# timeout. Confirm.
ddp_timeout: 180000000

# NOTE(review): train_batch_size looks like the global (effective) batch size
# and per_device_train_batch_size the per-device micro-batch. How the two are
# reconciled (e.g. gradient accumulation) is not visible in this file.
train_batch_size: 512
per_device_train_batch_size: 8

# Gradient checkpointing is enabled, with `use_reentrant: false` passed along
# as a keyword argument for it.
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
# Optimizer schedule.
# NOTE(review): with Hugging Face schedulers, `constant` does not apply any
# warmup, so warmup_ratio below would have no effect; `constant_with_warmup`
# is the variant that honours it. Confirm which behaviour is intended.
learning_rate: 1.0e-5
lr_scheduler_type: constant
warmup_ratio: 0.1

trainer_log_name: trainer

# Accelerate dataloader setting. With dispatch_batches disabled, every process
# iterates its own dataloader rather than the main process loading batches and
# broadcasting slices to the others (noted originally as a free speedup on
# multi-GPU).
dispatch_batches: false
# Built as a plain dict through `_target_: builtins.dict`; it is referenced as
# `accelerator_config` under trainer_args.
trainer_accelerator_config:
  _target_: builtins.dict
  dispatch_batches: ${dispatch_batches}

# Dataloader params: each of the 4 workers prefetches 2 batches, so up to
# 4 * 2 = 8 batches are loaded ahead of the training loop.
dataloader_num_workers: 4
dataloader_prefetch_factor: 2

# Training-arguments spec; every value here is interpolated from the top-level
# keys of the same name in this file.
trainer_args:
  _target_: hydra_utils.transformers.TrainingArguments
  # dispatch_batches is intentionally not set directly on the arguments; it is
  # carried inside accelerator_config below.
  dataloader_num_workers: ${dataloader_num_workers}
  dataloader_prefetch_factor: ${dataloader_prefetch_factor}
  accelerator_config: ${trainer_accelerator_config}

# Trainer spec, targeting the project's `hydra_utils.transformers.Trainer`.
trainer:
  _target_: hydra_utils.transformers.Trainer
  # Interpolated from `loaded_model`, which is not defined in the visible part
  # of this config (presumably supplied by a composed default; verify).
  model: ${loaded_model}