from transformers import EarlyStoppingCallback


class CustomEarlyStoppingCallback(EarlyStoppingCallback):
    """Early-stopping callback driven by an absolute threshold on the metric.

    Instead of tracking improvement over the best metric seen so far, this
    subclass counts consecutive evaluations whose metric is *not* strictly
    greater than ``early_stopping_threshold``. It therefore treats a
    lower-is-better metric (such as a loss) as converged once it sits at or
    below the threshold.

    NOTE(review): the decision to actually halt training is inherited from
    the parent class. Presumably its ``on_evaluate`` hook compares
    ``early_stopping_patience_counter`` against ``early_stopping_patience``;
    confirm this against the pinned transformers version.
    """

    def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: float = 0.0):
        """Forward the patience and threshold settings to the parent callback.

        Args:
            early_stopping_patience: How many consecutive evaluations at or
                below the threshold the patience counter is allowed to reach.
            early_stopping_threshold: Absolute value the metric must strictly
                exceed for the patience counter to be reset to zero.
        """
        super().__init__(early_stopping_patience, early_stopping_threshold)

    def check_metric_value(self, args, state, control, metric_value):
        """Update the patience counter from one evaluation result.

        ``args``, ``state`` and ``control`` are accepted only to match the
        parent hook's signature and are not read here. In particular,
        neither a best-so-far metric nor ``greater_is_better`` affects the
        decision.

        Args:
            metric_value: The metric value reported by this evaluation.
        """
        # Only a strictly-greater value resets the counter. Anything else
        # increments it, including a float NaN, because every comparison
        # with NaN evaluates to False.
        exceeds_threshold = metric_value > self.early_stopping_threshold
        self.early_stopping_patience_counter = (
            0 if exceeds_threshold else self.early_stopping_patience_counter + 1
        )
