from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torch
class MyEarlyStopping(EarlyStopping):
    """Early stopping driven by ``pl_module.metrics_stats['val']``.

    Differences from the stock Lightning ``EarlyStopping``:

    * The monitored values come from ``pl_module.metrics_stats['val']``, a
      mapping of metric name -> value, rather than from
      ``trainer.callback_metrics``. Every entry is exposed under a ``"val_"``
      prefix, so ``monitor`` must be spelled e.g. ``"val_loss"`` for a
      ``"loss"`` entry.
    * Validation results are ignored until
      ``trainer.current_epoch >= start_epoch``.
    * Only the after-validation check is supported. The train-epoch-end
      check is not implemented, so ``check_on_train_epoch_end`` defaults to
      ``False`` here.

    Args:
        *args: Positional arguments forwarded to ``EarlyStopping``.
        start_epoch: First value of ``trainer.current_epoch`` at which the
            stopping check is allowed to run.
        **kwargs: Keyword arguments forwarded to ``EarlyStopping``.
    """

    def __init__(self, *args, start_epoch=0, **kwargs):
        self.start_epoch = start_epoch
        super().__init__(*args, **kwargs)
        # Lightning's EarlyStopping.setup resolves an unset (None)
        # check_on_train_epoch_end to True for the common configuration
        # (validation once per training epoch). That would route the check to
        # on_train_epoch_end, which this subclass does not implement, and
        # crash training. Default to the validation-end path instead. An
        # explicit True passed by the caller is left untouched.
        if self._check_on_train_epoch_end is None:
            self._check_on_train_epoch_end = False

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        """Reject train-epoch-end checking, which this subclass does not support."""
        if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
            return
        raise NotImplementedError(
            f"{type(self).__name__} only supports checking after validation; "
            "construct it with check_on_train_epoch_end=False."
        )

    def on_validation_end(self, trainer, pl_module) -> None:
        """Run the stopping check on the module's validation stats once ``start_epoch`` is reached."""
        if self._check_on_train_epoch_end or self._should_skip_check(trainer):
            return
        if trainer.current_epoch < self.start_epoch:
            return

        # NOTE(review): assumes metrics_stats['val'] maps metric name -> scalar
        # (number or 0-d tensor) for the validation run that just finished.
        # Confirm against the LightningModule.
        val_stats = pl_module.metrics_stats['val']
        logs = {f"val_{name}": torch.tensor(value) for name, value in val_stats.items()}
        self._run_early_stopping_check_with_logs(trainer, logs)

    def _run_early_stopping_check_with_logs(self, trainer, logs):
        """Evaluate the stopping criteria on ``logs`` and tell the trainer to stop if they are met.

        Mirrors ``EarlyStopping._run_early_stopping_check``, but takes the
        metric mapping explicitly instead of reading ``trainer.callback_metrics``.

        Args:
            trainer: The Lightning trainer. Its ``should_stop`` flag is set
                when the criteria are met.
            logs: Mapping from metric name to tensor. It must contain
                ``self.monitor`` for the check to run.
        """
        # Early stopping is disabled under fast_dev_run. Also short-circuit
        # when _validate_condition_metric rejects logs (e.g. the monitored
        # metric is not present).
        if trainer.fast_dev_run or not self._validate_condition_metric(logs):
            return

        current = logs[self.monitor].squeeze()
        should_stop, reason = self._evaluate_stopping_criteria(current)

        # Combine the per-process decision across ranks so all processes act
        # the same way. NOTE(review): whether this stops when *any* rank or
        # only when *all* ranks vote to stop depends on the strategy's
        # reduce_boolean_decision. Confirm for the installed Lightning version.
        should_stop = trainer.strategy.reduce_boolean_decision(should_stop)
        trainer.should_stop = trainer.should_stop or should_stop
        if should_stop:
            self.stopped_epoch = trainer.current_epoch
        if reason and self.verbose:
            self._log_info(trainer, reason, self.log_rank_zero_only)
