import copy
import logging
from typing import Dict, Any


class TrainingCallback:
    """
    Base class for training callbacks.

    Subclasses override the hooks they need; every default hook is a no-op.
    """

    def after_epoch(self, metrics: Dict[str, Any]) -> None:
        """Hook run after each training epoch; does nothing by default."""
        return None


class EarlyStopping(TrainingCallback):
    """
    Early stopping callback to halt training when a monitored metric stops improving.

    Each call to ``after_epoch`` compares the monitored metric with the best value
    seen so far. An epoch that does not strictly improve on it increments
    ``counter``; once ``counter`` reaches ``patience`` the callback optionally
    restores the best weights and raises an exception to stop training.

    ``model`` is ``None`` after construction; weight snapshots and restoring only
    happen once the caller has assigned a model exposing ``state_dict()`` and
    ``load_state_dict()``.
    """

    def __init__(self, metric='val_loss', patience=10, restore_best=True, minimize=True):
        """
        Initialize the early stopping callback.

        Parameters
        ----------
        metric: str, default: 'val_loss'
            The metric to monitor
        patience: int, default: 10
            Number of epochs with no improvement after which training will be stopped
        restore_best: bool, default: True
            Whether to restore the best model weights when early stopping triggers
        minimize: bool, default: True
            Whether the metric should be minimized (True) or maximized (False)
        """
        self.metric = metric
        self.patience = patience
        self.restore_best = restore_best
        self.minimize = minimize
        self.counter = 0         # consecutive epochs without improvement
        self.best_score = None   # best metric value so far; None until the first epoch
        self.epoch = 0           # number of epochs observed by this callback
        self.state_dict = None   # independent snapshot of the weights at best_score
        self.model = None        # attached by the caller; enables snapshot/restore

    def _is_improvement(self, current) -> bool:
        """Return True if ``current`` strictly beats the best score (always True at first)."""
        if self.best_score is None:
            return True
        if self.minimize:
            return current < self.best_score
        return current > self.best_score

    def _snapshot(self) -> None:
        """Store an independent copy of the model weights when restoring is enabled."""
        if self.restore_best and self.model is not None:
            # deepcopy, not dict.copy(): a shallow copy keeps references to the
            # live weight objects (e.g. PyTorch's state_dict() values share storage
            # with the parameters), so the "best" snapshot would silently follow
            # every later in-place update and restoring it would be a no-op.
            self.state_dict = copy.deepcopy(self.model.state_dict())

    def after_epoch(self, metrics: Dict[str, Any]) -> None:
        """
        Called after each training epoch.

        Parameters
        ----------
        metrics: Dict[str, Any]
            Dictionary of metrics from the training; must contain ``self.metric``

        Raises
        ------
        KeyError
            When ``self.metric`` is missing from ``metrics`` (or maps to None)
        Exception
            When training should be stopped (patience exhausted)
        """
        self.epoch += 1

        current = metrics.get(self.metric)
        if current is None:
            raise KeyError(f"Metric '{self.metric}' not found in metrics")

        if self._is_improvement(current):
            # First epoch or a strictly better value: record it and snapshot weights.
            self.best_score = current
            self.counter = 0
            self._snapshot()
            return

        self.counter += 1
        if self.counter >= self.patience:
            # Patience exhausted: put the best weights back before signalling a stop.
            if self.restore_best and self.model is not None and self.state_dict is not None:
                self.model.load_state_dict(self.state_dict)
            # NOTE(review): a dedicated Exception subclass would let callers catch
            # early stopping without catching every error; kept as Exception for
            # compatibility with existing handlers.
            raise Exception(f"Early stopping after epoch {self.epoch} (patience {self.patience}).")


def get_callbacks_from_config(train_cfg) -> list:
    """
    Build the training callbacks described by a configuration object.

    Parameters
    ----------
    train_cfg: object
        Configuration read through the attributes ``stopping_mode``,
        ``stopping_patience``, ``stopping_restore_best``, ``stopping_metric``
        and ``stopping_minimize``

    Returns
    -------
    list
        Empty when no stopping mode is set or patience is not positive;
        otherwise a single EarlyStopping callback
    """
    # Missing attribute and an explicit None both mean "no stopping configured".
    # NOTE(review): the value of stopping_mode is not otherwise used here.
    if getattr(train_cfg, 'stopping_mode', None) is None:
        return []

    # Early stopping is only created for a positive patience.
    if not train_cfg.stopping_patience > 0:
        return []

    return [
        EarlyStopping(
            patience=train_cfg.stopping_patience,
            restore_best=train_cfg.stopping_restore_best,
            metric=train_cfg.stopping_metric,
            minimize=train_cfg.stopping_minimize,
        )
    ]
