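# EMA callback config. The checkpointing keys below use the argument names of
# Lightning's ModelCheckpoint; see:
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html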

ema:
  _target_: src.callbacks.ema_callback.EMACallback
  decay: 0.9995  # should be around 0.999 to 0.9999
  # NOTE(review): presumably a step count before EMA fully applies; confirm
  # against src.callbacks.ema_callback.EMACallback.
  ema_warmup: 10000

  # for checkpointing
  dirpath: ${paths.output_dir}/checkpoints
  # NOTE(review): if these args are forwarded to Lightning's ModelCheckpoint,
  # its default auto_insert_metric_name=True renders this as e.g.
  # "step_step=1000.ckpt" -- confirm the intended filename.
  filename: "step_{step:03d}"
  monitor: val/loss
  mode: min
  every_n_train_steps: 1000
  save_last: false
  save_top_k: 1  # save k best models (determined by the metric above)
  save_weights_only: false  # if true, only the model's weights are saved