from torch import nn


class EarlyStopping(nn.Module):
    """
    Early stopping to stop training when a monitored metric stops improving.

    Call :meth:`step` once per epoch with the monitored value (e.g. the
    validation loss). It returns ``True`` once the value has failed to improve
    for more than ``patience`` consecutive calls.

    Args:
        patience (int): Number of consecutive non-improving epochs that are
            tolerated. Training is signalled to stop on the next
            non-improving epoch, i.e. the (patience + 1)-th in a row. For
            example, with ``patience=2`` the first two non-improving epochs
            are ignored and ``step`` returns ``True`` on the third.
        threshold (float): Minimum absolute change to qualify as an
            improvement.
        mode (str): One of {"min", "max"}. In "min" mode, training will stop
            when the quantity monitored has stopped decreasing; in "max" mode it
            will stop when the quantity monitored has stopped increasing.

    Raises:
        ValueError: If ``mode`` is not "min" or "max".
    """

    def __init__(
        self,
        patience: int = 5,
        threshold: float = 1e-6,
        mode: str = "min",
    ) -> None:
        super().__init__()
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let an invalid mode silently behave as "max".
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.threshold = threshold
        # Consecutive calls to `step` without improvement.
        self.num_bad_epochs = 0
        # Start at the worst possible value so the first finite value improves.
        self.best = float("inf") if mode == "min" else float("-inf")
        self.mode = mode

    def is_better_than_current(self, val_loss: float) -> bool:
        """
        Check if the given value improves on the best value seen so far.

        An improvement must beat ``self.best`` by more than ``self.threshold``
        (an absolute margin) in the direction given by ``self.mode``. NaN never
        counts as an improvement, because comparisons with NaN are False.

        Args:
            val_loss (float): The current monitored value (e.g. validation loss).

        Returns:
            bool: True if ``val_loss`` is an improvement over ``self.best``,
            False otherwise.
        """
        if self.mode == "min":
            return val_loss < self.best - self.threshold

        # Other mode is "max"
        return val_loss > self.best + self.threshold

    def step(self, val_loss: float) -> bool:
        """
        Record one epoch's monitored value and report whether to stop.

        On improvement, the best value is updated and the bad-epoch counter is
        reset. Otherwise the counter is incremented. Stopping is signalled once
        the counter exceeds ``patience``. This function should be called after
        each epoch.

        Args:
            val_loss (float): The current monitored value (e.g. validation loss).

        Returns:
            bool: True if training should be stopped, False otherwise.
        """
        if self.is_better_than_current(val_loss):
            self.best = val_loss
            self.num_bad_epochs = 0
            return False

        self.num_bad_epochs += 1
        # Strictly greater: `patience` non-improving epochs are tolerated.
        return self.num_bad_epochs > self.patience
