from pathlib import Path
from typing import Optional

from trident.core.module import TridentModule


class TridentModuleSaveOnVal(TridentModule):
    """A ``TridentModule`` that writes a weights-only checkpoint after each validation run.

    Checkpoints are written to ``<checkpoint_dir>/<n>.ckpt``, where ``n`` counts the
    non-sanity validation runs completed by this instance, starting at 0. Combined
    with Lightning's ``trainer.val_check_interval``, this makes it easy to keep
    checkpoints at a finer granularity than once per epoch.

    Args:
        checkpoint_dir: Directory to write checkpoints into. If falsy (``None`` or
            ``""``), no checkpoints are written.
        *args: Forwarded unchanged to ``TridentModule``.
        **kwargs: Forwarded unchanged to ``TridentModule``.
    """

    def __init__(self, checkpoint_dir: Optional[str], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.checkpoint_dir = checkpoint_dir
        # Number of completed (non-sanity) validation runs; used as the checkpoint
        # file name. Despite the name, this is not the training epoch: it advances
        # once per validation run, which may happen several times per epoch.
        # NOTE(review): the counter is held only in memory and starts at 0, so a
        # resumed run restarts numbering and may overwrite earlier files — confirm
        # whether that matters for callers.
        self._validation_epoch = 0

    def on_validation_end(self) -> None:
        """Save a weights-only checkpoint once a validation run has finished.

        This makes it easy to store checkpoints at Lightning's
        ``trainer.val_check_interval``. Validation runs performed during Lightning's
        sanity check (``trainer.num_sanity_val_steps``) are skipped and do not
        advance the counter. Nothing is saved when ``checkpoint_dir`` is falsy.

        NOTE(review): only sanity checking is excluded, so presumably a standalone
        ``trainer.validate()`` call also writes a checkpoint — confirm this is intended.
        """
        super().on_validation_end()
        # Guard clause: skip sanity-check runs and the "checkpointing disabled" case.
        if self.trainer.sanity_checking or not self.checkpoint_dir:
            return
        path = Path(self.checkpoint_dir) / f"{self._validation_epoch}.ckpt"
        self.trainer.save_checkpoint(path, weights_only=True)
        self._validation_epoch += 1
