# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)

from swift.utils import get_logger

# Module-level logger obtained from swift.utils.get_logger; used by the callbacks below.
logger = get_logger()


class EarlyStopCallback(TrainerCallback):
    """Stop training once the best eval metric has stopped improving.

    On every checkpoint save the callback compares ``state.best_metric`` with
    the best value it has recorded so far. A save without improvement
    increments a counter, while an improvement resets it to zero. When the
    counter reaches ``total_interval`` the callback sets
    ``control.should_training_stop`` to ``True``.

    Args:
        total_interval: Number of consecutive saves without improvement that
            are tolerated before training is asked to stop.
    """

    def __init__(self, total_interval=3):
        # Best metric value observed so far (None until the first one is seen).
        self.best_metric = None
        # Number of consecutive saves without an improvement.
        self.interval = 0
        self.total_interval = total_interval

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Update the stall counter and request a stop when patience runs out.

        Args:
            args: Training arguments; ``greater_is_better`` selects whether a
                larger or a smaller metric counts as an improvement.
            state: Trainer state; ``best_metric`` and ``global_step`` are read.
            control: Trainer control; ``should_training_stop`` is set to
                ``True`` once the stall counter reaches ``total_interval``.
            **kwargs: Ignored.
        """
        current = state.best_metric
        # Without a recorded metric there is nothing to compare; this also
        # avoids handing None to the numpy comparison below.
        if current is not None:
            is_better = np.greater if args.greater_is_better else np.less
            if self.best_metric is None or is_better(current, self.best_metric):
                self.best_metric = current
                # An improvement restarts the patience window, so only
                # consecutive stalled saves count toward stopping.
                self.interval = 0
            else:
                self.interval += 1

        if self.interval >= self.total_interval:
            logger.info(
                f"Training stop because of eval metric is stable at step {state.global_step}"
            )
            control.should_training_stop = True


# Extra trainer callbacks exposed by this module; empty by default.
# NOTE(review): presumably read by whatever imports this module — confirm against the caller.
extra_callbacks = []
# EarlyStopCallback above is a simple example of an early-stop callback;
# uncomment the line below to enable it.
# extra_callbacks = [EarlyStopCallback()]
