import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_info
import numpy as np
# Report the installed PyTorch Lightning version when this module is imported.
print("pl: {}".format(pl.__version__))


class AlteredModelCheckpoint(ModelCheckpoint):
    """``ModelCheckpoint`` that saves only on a fixed schedule of epochs.

    Instead of the parent's ``every_n_epochs`` modulus check, a checkpoint is
    written at the end of validation only when ``trainer.current_epoch`` is a
    member of ``self.epoch_list``.

    Args:
        ssl_train: Selects the epoch schedule. ``True`` gives epochs
            0, 20, 40, ..., 1000 (stored as a ``list``). ``False`` gives epochs
            1, 26, 51, ..., 1001 (stored as a NumPy array).
        *args: Forwarded unchanged to ``ModelCheckpoint.__init__``.
        **kwargs: Forwarded unchanged to ``ModelCheckpoint.__init__``.

    Note:
        ``ssl_train`` is the first positional parameter, so a positional
        argument meant for ``ModelCheckpoint`` (e.g. ``dirpath``) binds to
        ``ssl_train`` instead. Pass parent arguments by keyword.
    """

    def __init__(self, ssl_train: bool = True, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if ssl_train:
            self.epoch_list = list(range(0, 1001, 20))
        else:
            # NOTE(review): this schedule starts at 1 while the ssl schedule
            # starts at 0. Confirm the offset is intended, since it is compared
            # against ``trainer.current_epoch``.
            self.epoch_list = np.arange(1, 1002, 25)
        print(f"self.epoch_list:{self.epoch_list}")

    @property
    def super(self):
        """Return a ``super()`` proxy for the parent class, bound to this instance.

        The proxy is built on demand instead of being stored in ``__init__``.
        A stored ``super`` object cannot be pickled or deep-copied, which
        would make the whole callback unpicklable. It would also create a
        reference cycle back to ``self``.
        """
        # Inside a method body ``super`` resolves to the builtin, not this
        # property, so this builds the same proxy the old attribute held.
        return super(AlteredModelCheckpoint, self)

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Save a checkpoint at the end of the validation stage.

        A checkpoint is saved via ``self.save_checkpoint(trainer)`` only when
        all of the following hold:

        * the parent's ``_should_skip_saving_checkpoint`` check is falsy, or
          the current epoch is 0 (the skip check is ignored at epoch 0);
        * ``_save_on_train_epoch_end`` is falsy;
        * ``_every_n_epochs`` is at least 1;
        * ``trainer.current_epoch`` is in ``self.epoch_list``.

        The value of each condition is printed on every call.

        Args:
            trainer: The running trainer.
            pl_module: The module being trained (unused here).
        """
        epoch = trainer.current_epoch
        print(f"in validation end! epoch: {epoch}.")

        # Evaluate the parent's skip check once and reuse it for both the
        # diagnostic print and the decision below.
        should_skip = self._should_skip_saving_checkpoint(trainer)
        not_scheduled = epoch not in self.epoch_list

        print(f"self._should_skip_saving_checkpoint(trainer):{should_skip}")
        print(f"self._save_on_train_epoch_end:{self._save_on_train_epoch_end}")
        print(f"self._every_n_epochs < 1:{self._every_n_epochs < 1}")
        print(f"trainer.current_epoch not in self.epoch_list: "
              f"{not_scheduled}")

        if (
                # NOTE(review): epoch 0 bypasses the skip check. Presumably this
                # is so an initial checkpoint can be written; confirm the intent.
                (should_skip and epoch != 0)
                or self._save_on_train_epoch_end
                or self._every_n_epochs < 1
                or not_scheduled
        ):
            print(f"did not pass, epoch: {epoch}.")
            return

        self.save_checkpoint(trainer)
        print(f"saved checkpoint! epoch:{epoch}, dirpath: {self.dirpath}")
