from typing import List

import numpy as np
import mxnet as mx

from gluonts.evaluation import Evaluator
from gluonts.dataset.common import Dataset
from gluonts.mx import copy_parameters, GluonPredictor
from gluonts.mx.trainer.callback import Callback


class EarlyStopping(Callback):
    """
    Early Stopping mechanism based on the metric reported at each epoch end.

    After every epoch the ``epoch_loss`` value handed to :meth:`on_epoch_end`
    is compared with the best value seen so far. If it fails to improve by
    more than ``min_delta`` for ``patience`` consecutive epochs, the callback
    asks the trainer to stop.

    NOTE(review): whether ``epoch_loss`` is a validation or a training loss
    depends on the trainer calling this callback -- confirm against caller.

    Parameters
    ----------
    patience
        Number of consecutive epochs without improvement (by more than
        ``min_delta``) after which training is stopped. With ``patience=0``
        training stops at the first epoch that does not improve.
    min_delta
        Minimum change in the monitored metric counting as an improvement.
        A new value must beat the best value by strictly more than this.
    verbose
        Controls whether the monitored metric is printed after each epoch.
        The stop message is printed regardless of this flag.
    minimize_metric
        If True, lower metric values are better; otherwise higher values are
        better.
    num_samples
        The amount of samples drawn to calculate the inference metrics.
        NOTE(review): stored but currently not used by this class.
    """

    def __init__(
        self,
        patience: int = 10,
        min_delta: float = 0.0,
        verbose: bool = True,
        minimize_metric: bool = True,
        num_samples: int = 100,
    ):
        assert patience >= 0, "EarlyStopping Callback patience needs to be >= 0"
        assert min_delta >= 0, "EarlyStopping Callback min_delta needs to be >= 0.0"
        assert num_samples >= 1, "EarlyStopping Callback num_samples needs to be >= 1"

        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.num_samples = num_samples
        self.minimize_metric = minimize_metric

        # Start from the worst possible value so the first finite metric
        # always counts as an improvement.
        if minimize_metric:
            self.best_metric_value = np.inf
            self.is_better = np.less
        else:
            self.best_metric_value = -np.inf
            self.is_better = np.greater

        self.validation_metric_history: List[float] = []
        # NOTE(review): never assigned anywhere in this class.
        self.best_network = None
        # Consecutive epochs without a sufficient improvement.
        self.n_stale_epochs = 0

    def _is_improvement(self, value: float) -> bool:
        """Return True if ``value`` beats the best value by more than ``min_delta``."""
        # Move the reference point by min_delta in the direction of
        # improvement so smaller changes are not counted. With min_delta=0
        # this reduces to a plain strict comparison against the best value.
        if self.minimize_metric:
            threshold = self.best_metric_value - self.min_delta
        else:
            threshold = self.best_metric_value + self.min_delta
        return bool(self.is_better(value, threshold))

    def on_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: mx.gluon.nn.HybridBlock,
        trainer: mx.gluon.Trainer,
        best_epoch_info: dict,
        ctx: mx.Context,
    ) -> bool:
        """
        Record this epoch's metric and decide whether training should go on.

        Parameters
        ----------
        epoch_no
            Index of the epoch that just finished; used only in the stop
            message.
        epoch_loss
            Value of the monitored metric for this epoch.
        training_network, trainer, best_epoch_info, ctx
            Part of the callback interface; not used by this callback.

        Returns
        -------
        bool
            False once the metric has not improved for ``patience``
            consecutive epochs, True otherwise.
        """
        should_continue = True
        current_metric_value = epoch_loss
        self.validation_metric_history.append(current_metric_value)

        if self.verbose:
            print(
                f"Validation loss: {current_metric_value}, best: {self.best_metric_value}"
            )

        if self._is_improvement(current_metric_value):
            self.best_metric_value = current_metric_value
            self.n_stale_epochs = 0
        else:
            self.n_stale_epochs += 1
            # ">=" instead of "==": with patience=0 the counter (always >= 1
            # here) could never equal the patience, so training never stopped.
            if self.n_stale_epochs >= self.patience:
                should_continue = False
                print(
                    f"EarlyStopping callback initiated stop of training at epoch {epoch_no}."
                )

        return should_continue
