from lightning.pytorch.callbacks.base import Callback


# https://github.com/Lightning-AI/lightning/issues/596#issuecomment-619112563
class ModelCheckpointAtEpochEnd(Callback):
    """Drive the trainer's checkpoint callback from the epoch-end hook.

    When the trainer reports that validation is disabled, this callback calls
    the checkpoint callback's ``on_validation_end`` hook itself at the end of
    each epoch.
    NOTE(review): presumably that hook is what performs the save and is not
    reached otherwise when validation is off -- confirm against the Lightning
    version in use.
    """

    def on_epoch_end(self, trainer, pl_module):
        """Record the current epoch and trigger checkpointing if needed.

        Args:
            trainer: The running trainer. Its ``callback_metrics``,
                ``current_epoch``, ``disable_validation`` and
                ``checkpoint_callback`` attributes are read.
            pl_module: The module being trained; forwarded unchanged to the
                checkpoint callback.
        """
        # Assign into the object returned by ``trainer.callback_metrics`` so
        # the 'epoch' entry is visible to whatever reads those metrics later.
        metrics = trainer.callback_metrics
        metrics['epoch'] = trainer.current_epoch

        if not trainer.disable_validation:
            return

        checkpoint_callback = trainer.checkpoint_callback
        if not checkpoint_callback:
            # No checkpoint callback is configured (the attribute is falsy,
            # e.g. None), so there is nothing to trigger; returning here
            # avoids an AttributeError on every epoch.
            return
        checkpoint_callback.on_validation_end(trainer, pl_module)