import jax.numpy as jnp

from egxc.utils.checkpointing import CheckpointManager
from egxc.utils.typing import NnParams


class EarlyStopping:
    """
    Patience-based early stopping that also tracks the best parameters seen.

    Call `stop(loss, params)` once per epoch. It returns True when training
    should stop:
      * immediately if the loss is NaN or infinite, or
      * once `patience` consecutive-or-not epochs have elapsed without a
        "significant" improvement since the last counter reset.

    An improvement is significant when the loss drops below the loss at the
    last reset by at least `min_relative_improvement` times that loss's
    magnitude. Any new best loss (significant or not) replaces `best_params`
    and, if a checkpointer is given, is passed to
    `checkpointer.save_best_params(params, prefix)`.
    """

    def __init__(
        self,
        patience: int,
        min_relative_improvement: float,
        checkpointer: CheckpointManager | None,
        prefix: str = '',
    ):
        self.patience = patience
        self.__best_loss = float('inf')
        # Reference loss for the relative-improvement test; updated only when
        # the patience counter is reset.
        self.__loss_at_last_reset = float('inf')
        self.__best_params = None
        self.__counter = 0
        self.__checkpointer = checkpointer
        self.__prefix = prefix
        self.__min_relative_improvement = min_relative_improvement

    def increment_patience_counter(self) -> bool:
        """
        Increment the patience counter and report its progress.

        Returns True if the patience counter has reached the patience threshold.
        """
        self.__counter += 1
        print(f'Early stopping counter: {self.__counter} / {self.patience}')
        return self.__counter >= self.patience

    def _is_significant_improvement(self, loss: float) -> bool:
        """Return True if `loss` beats the last-reset loss by the required relative margin."""
        reference = self.__loss_at_last_reset
        # Before the first reset the reference is +inf; `inf - r * inf` would be
        # NaN, so any finite loss counts as significant.
        if jnp.isinf(reference):
            return True
        # abs() keeps the margin pointing downwards even for negative losses.
        margin = self.__min_relative_improvement * abs(reference)
        return loss < reference - margin

    def stop(self, loss: float, params: NnParams) -> bool:
        """
        Update the early-stopping state with this epoch's loss and parameters.

        Returns True if the loss is NaN or infinite, or if `patience` epochs
        have passed without a significant improvement (see the class
        docstring). Otherwise, returns False.

        A new best loss always updates `best_params` (and the checkpoint, if a
        checkpointer was given), but it only resets the patience counter when
        it is also a significant improvement; smaller gains still count as
        stagnation.
        """
        if not jnp.isfinite(loss):
            print('Loss is NaN or infinite, stopping early')
            return True

        if loss >= self.__best_loss:
            return self.increment_patience_counter()

        self.__best_loss = loss
        self.__best_params = params
        if self.__checkpointer is not None:
            self.__checkpointer.save_best_params(params, self.__prefix)

        if self._is_significant_improvement(loss):
            self.__loss_at_last_reset = loss
            self.__counter = 0
            return False
        # New best, but by less than the required margin: count it as an epoch
        # without meaningful progress so that slow creeping never stalls forever.
        return self.increment_patience_counter()

    @property
    def best_params(self) -> NnParams | None:
        """Parameters passed with the lowest finite loss so far (None before any)."""
        return self.__best_params
