import warnings
from typing import Dict, List, Optional
# from trainer.trainers.base import BaseTrainer


class Callback(object):
    """Base class for training/evaluation hooks.

    Every hook defined here is a no-op; subclasses override only the hooks
    they care about.

    Attributes:
        trainer: Whatever object was last passed to :meth:`set_trainer`
            (``None`` until then).
        params: Dict merged from every :meth:`set_params` call (starts empty).
    """

    def __init__(self):
        self.trainer = None
        self.params = {}

    def set_params(self, params: Dict):
        """Merge ``params`` into :attr:`params` (later keys overwrite earlier ones)."""
        self.params.update(params)

    def set_trainer(self, trainer):
        """Store a reference to the owning trainer in :attr:`trainer`."""
        self.trainer = trainer

    def on_epoch_begin(self, epoch_number: int, logs: Dict):
        """
        Is called at the beginning of each epoch.
        Args:
            epoch_number (int): The epoch number.
            logs (dict): Usually an empty dict.
        """
        pass

    def on_epoch_end(self, epoch_number: int, logs: Dict):
        """
        Is called at the end of each epoch.
        Args:
            epoch_number (int): The epoch number.
            logs (dict): Contains the following keys:
                 * 'epoch': The epoch number.
                 * 'loss': The average loss of the batches.
                 * 'time': The computation time of the epoch.
                 * Other metrics: One key for each type of metrics. The metrics are also averaged.
                 * 'val_loss': The average loss of the batches on the validation set.
                 * Other metrics: One key for each type of metrics on the validation set. The metrics are also averaged.
        Example::
            logs = {'epoch': 6, 'time': 3.141519837, 'loss': 4.34462, 'accuracy': 0.766,
                    'val_loss': 5.2352, 'val_accuracy': 0.682}
        """
        pass

    def on_train_batch_begin(self, batch_number: int, logs: Dict):
        """
        Is called at the beginning of each training batch.
        Args:
            batch_number (int): The batch number.
            logs (dict): Usually an empty dict.
        """
        pass

    def on_train_batch_end(self, batch_number: int, logs: Dict):
        """
        Is called at the end of each training batch.
        Args:
            batch_number (int): The batch number.
            logs (dict): Usually an empty dict.
        """
        pass

    def on_eval_batch_begin(self, batch_number: int, logs: Dict):
        """
        Is called at the beginning of each evaluation (testing) batch.
        Args:
            batch_number (int): The batch number.
            logs (dict): Usually an empty dict.
        """
        pass

    def on_eval_batch_end(self, batch_number: int, logs: Dict):
        """
        Is called at the end of each evaluation (testing) batch.
        Args:
            batch_number (int): The batch number.
            logs (dict): Usually an empty dict.
        """
        pass

    def on_train_begin(self, logs: Dict):
        """
        Is called at the beginning of training.
        Args:
            logs (dict): Usually an empty dict.
        """
        pass

    def on_train_epoch_begin(self, logs: Dict):
        """
        Hook for the start of an epoch's training phase.

        NOTE(review): unlike :meth:`on_epoch_begin` this receives no epoch
        number; its exact timing relative to that hook depends on the
        trainer — confirm there.
        Args:
            logs (dict): Hook-specific values (contents defined by the trainer).
        """
        pass

    def on_train_epoch_end(self, logs: Dict):
        """
        Hook for the end of an epoch's training phase.

        NOTE(review): unlike :meth:`on_epoch_end` this receives no epoch
        number; its exact timing relative to that hook depends on the
        trainer — confirm there.
        Args:
            logs (dict): Hook-specific values (contents defined by the trainer).
        """
        pass

    def on_train_end(self, logs: Dict):
        """
        Is called at the end of training.
        Args:
            logs (dict): Usually an empty dict.
        """
        pass

    def on_eval_begin(self, logs: Dict):
        """
        Is called at the beginning of evaluation (testing).
        Args:
            logs (dict): Usually an empty dict.
        """
        pass

    def on_eval_end(self, logs: Dict):
        """
        Is called at the end of evaluation (testing).
        Args:
            logs (dict): Usually an empty dict.
        """
        pass

    def on_backward_end(self, batch_number: int):
        """
        Is called after the backpropagation but before the optimization step.
        Args:
            batch_number (int): The batch number.
        """
        pass


class CallbackList:
    """Fan-out container that forwards every hook call to each wrapped callback.

    Callbacks are invoked in insertion order. For hooks that take ``logs``,
    the same dict object is handed to every callback, so a callback can read
    keys written by the callbacks before it. ``None`` logs are replaced by a
    fresh empty dict; any other dict, including an empty one, is passed
    through as-is so in-place updates made by callbacks reach the caller.
    """

    def __init__(self, callbacks: Optional[List["Callback"]] = None):
        """
        Args:
            callbacks: Initial callbacks. ``None`` or an empty iterable gives an
                empty list. The input is copied into a new list, so later
                changes to the caller's sequence do not affect this container.
        """
        self.callbacks = list(callbacks or [])

    def _dispatch(self, hook: str, *args) -> None:
        """Call method ``hook`` on every callback, in order, with ``args``."""
        for callback in self.callbacks:
            getattr(callback, hook)(*args)

    @staticmethod
    def _ensure_logs(logs: Optional[Dict]) -> Dict:
        """Return ``logs`` unchanged, or a new empty dict when it is ``None``."""
        # Deliberately not ``logs or {}``: that would swap a caller-supplied
        # empty dict for a private one and discard keys callbacks add to it.
        return {} if logs is None else logs

    def append(self, callback: "Callback"):
        """Add ``callback`` to the end of the dispatch order."""
        self.callbacks.append(callback)

    def set_params(self, params: Dict):
        """Forward ``params`` to every callback's ``set_params``."""
        self._dispatch("set_params", params)

    def set_trainer(self, trainer):
        """Forward ``trainer`` to every callback's ``set_trainer``."""
        self._dispatch("set_trainer", trainer)

    def on_train_epoch_begin(self, logs: Optional[Dict] = None):
        self._dispatch("on_train_epoch_begin", self._ensure_logs(logs))

    def on_train_epoch_end(self, logs: Optional[Dict] = None):
        self._dispatch("on_train_epoch_end", self._ensure_logs(logs))

    def on_epoch_begin(self, epoch_number: int, logs: Optional[Dict] = None):
        self._dispatch("on_epoch_begin", epoch_number, self._ensure_logs(logs))

    def on_epoch_end(self, epoch_number: int, logs: Optional[Dict] = None):
        self._dispatch("on_epoch_end", epoch_number, self._ensure_logs(logs))

    def on_train_batch_begin(self, batch_number: int, logs: Optional[Dict] = None):
        self._dispatch("on_train_batch_begin", batch_number, self._ensure_logs(logs))

    def on_train_batch_end(self, batch_number: int, logs: Optional[Dict] = None):
        self._dispatch("on_train_batch_end", batch_number, self._ensure_logs(logs))

    def on_eval_batch_begin(self, batch_number: int, logs: Optional[Dict] = None):
        self._dispatch("on_eval_batch_begin", batch_number, self._ensure_logs(logs))

    def on_eval_batch_end(self, batch_number: int, logs: Optional[Dict] = None):
        self._dispatch("on_eval_batch_end", batch_number, self._ensure_logs(logs))

    def on_train_begin(self, logs: Optional[Dict] = None):
        self._dispatch("on_train_begin", self._ensure_logs(logs))

    def on_train_end(self, logs: Optional[Dict] = None):
        self._dispatch("on_train_end", self._ensure_logs(logs))

    def on_eval_begin(self, logs: Optional[Dict] = None):
        self._dispatch("on_eval_begin", self._ensure_logs(logs))

    def on_eval_end(self, logs: Optional[Dict] = None):
        self._dispatch("on_eval_end", self._ensure_logs(logs))

    def on_backward_end(self, batch_number: int):
        self._dispatch("on_backward_end", batch_number)

    def __iter__(self):
        """Iterate over the wrapped callbacks in dispatch order."""
        return iter(self.callbacks)


class StepCallback(Callback):
    """Callback that counts completed training batches.

    The counter starts at 0 and increases by one on every
    ``on_train_batch_end`` call; read it through :attr:`step`.
    """

    def __init__(self):
        super().__init__()
        self._step = 0

    @property
    def step(self):
        """Number of ``on_train_batch_end`` calls received so far."""
        return self._step

    def on_train_batch_end(self, batch_number: int, logs: Dict):
        # The batch number and logs are ignored; only the call is counted.
        self._step = self._step + 1
