# Default fit config
# If no other --config is specified with mgen fit, this is used
# If another --config is specified, this is ignored
# To keep these defaults alongside a custom config, pass this file as an additional --config before your own,
# e.g. mgen fit --config defaults.yaml --config my_custom_config.yaml
# LightningCLI then applies the custom config on top, keeping every default that the custom config does not override
trainer:
  # --- Hardware, numeric precision and distribution ---
  accelerator: auto
  devices: auto
  precision: 32
  # Active strategy: DDP
  strategy:
    class_path: lightning.pytorch.strategies.DDPStrategy
  # Alternative strategy: FSDP (use it in place of the DDP entry above, not in addition to it)
  # strategy:
  #   class_path: lightning.pytorch.strategies.FSDPStrategy

  # --- Run length ---
  # Both limits are -1 (presumably "no limit" -- confirm against the Lightning Trainer docs)
  max_epochs: -1
  max_steps: -1

  # --- Optimisation ---
  gradient_clip_val: 1

  # --- Logging and output location ---
  default_root_dir: logs
  log_every_n_steps: 50

  # --- Profiling ---
  # PyTorch profiler, with profile_memory passed through dict_kwargs.
  # NOTE(review): this attaches a profiler in the default config -- confirm that is intended for every run.
  profiler:
    class_path: lightning.pytorch.profilers.PyTorchProfiler
    dict_kwargs:
      profile_memory: true
  # Callbacks are listed in the order given below
  callbacks:
    # Learning-rate monitor, configured with a "step" logging interval
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      dict_kwargs:
        logging_interval: step
    # Best-validation checkpoint: monitors val_loss in "min" mode and keeps the top 1
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      dict_kwargs:
        filename: 'best_val:{step}-{val_loss:.3f}-{train_loss:.3f}'
        save_top_k: 1
        mode: min
        monitor: val_loss
    # Best-training checkpoint: monitors train_loss in "min" mode and keeps the top 1.
    # The filename follows the same "{step}-{val_loss}-{train_loss}" layout as the other
    # checkpoint callbacks in this file; the extra literal "train:" label that used to sit
    # before {train_loss} was dropped so the name does not carry a redundant prefix.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      dict_kwargs:
        monitor: train_loss
        mode: min
        save_top_k: 1
        filename: "best_train:{step}-{val_loss:.3f}-{train_loss:.3f}"
    # Latest checkpoint: only a filename pattern is set here (no monitored metric),
    # prefixed "last:" to tell it apart from the best/periodic checkpoints
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      dict_kwargs:
        filename: 'last:{step}-{val_loss:.3f}-{train_loss:.3f}'
    # Periodic checkpoints: one every 1000 training steps, with save_top_k set to -1
    # (presumably "keep all of them" -- confirm against the ModelCheckpoint docs)
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      dict_kwargs:
        filename: 'step:{step}-{val_loss:.3f}-{train_loss:.3f}'
        save_top_k: -1
        every_n_train_steps: 1000