""" File copied from rllib
References: https://docs.ray.io/en/latest/_modules/ray/tune/stopper.html#Stopper
"""

import time
import numpy as np

from expground.logger import Log
from expground.types import Any, Dict


class Stopper:
    """Base class for stop conditions driven by a step counter.

    The base class only does the bookkeeping: ``step`` advances the counter,
    ``reset`` zeroes it and ``counter`` exposes it. It never decides to stop
    by itself; ``is_terminal`` and ``stop_all`` raise ``NotImplementedError``
    until a subclass overrides them.

    .. code-block:: python

        class MaxStepStopper(Stopper):
            def __init__(self, max_step):
                super().__init__()
                self._max_step = max_step

            def is_terminal(self):
                return self.counter >= self._max_step

    """

    def __init__(self):
        # Number of ``step`` calls since construction or the last ``reset``.
        self._global_step = 0

    @property
    def counter(self) -> int:
        """Number of ``step`` calls recorded since the last reset."""
        return self._global_step

    def reset(self):
        """Set the step counter back to zero."""
        self._global_step = 0

    def step(
        self,
        rollout_statis: Dict,
        training_statis: Dict,
        time_step: int,
        episode_th: int,
    ):
        """Record one step.

        The base implementation ignores all four arguments and only advances
        the counter; subclasses override this to inspect the values.
        """
        self._global_step = self._global_step + 1

    def is_terminal(self) -> bool:
        """Tell whether the stop condition is met; subclasses must override."""
        raise NotImplementedError

    def stop_all(self):
        """Returns true if the experiment should be terminated.

        Not implemented in the base class.
        """
        raise NotImplementedError


class CombinedStopper(Stopper):
    """Drive several stoppers as a single one.

    ``step`` and ``reset`` are forwarded to every wrapped stopper, and the
    combination keeps its own step counter on top of theirs.
    """

    def __init__(self, *stoppers: Stopper):
        self._global_step = 0
        self._stoppers = stoppers

    @property
    def counter(self):
        """Number of ``step`` calls made on the combination itself."""
        return self._global_step

    def reset(self):
        """Reset every wrapped stopper, then the combined counter."""
        for stopper in self._stoppers:
            stopper.reset()
        self._global_step = 0

    def step(
        self,
        rollout_statis: Dict,
        training_statis: Dict,
        time_step: int,
        episode_th: int,
    ):
        """Forward the step to each wrapped stopper, then count it here."""
        for stopper in self._stoppers:
            stopper.step(rollout_statis, training_statis, time_step, episode_th)
        super().step(rollout_statis, training_statis, time_step, episode_th)

    def is_terminal(self) -> bool:
        """Return True only when every wrapped stopper is terminal."""
        # NOTE(review): Ray Tune's CombinedStopper stops as soon as *any*
        # child stops; here *all* children must be terminal. Confirm that
        # this is the intended semantics.
        terminals = [stopper.is_terminal() for stopper in self._stoppers]
        Log.debug(f"stopper terminal state: {terminals}")
        return all(terminals)


# class NoopStopper(Stopper):
#     def __call__(self, trial_id, result):
#         return False

#     def stop_all(self):
#         return False


# class FunctionStopper(Stopper):
#     def __init__(self, function):
#         self._fn = function

#     def __call__(self, trial_id, result):
#         return self._fn(trial_id, result)

#     def stop_all(self):
#         return False

#     @classmethod
#     def is_valid_function(cls, fn):
#         is_function = callable(fn) and not issubclass(type(fn), Stopper)
#         if is_function and hasattr(fn, "stop_all"):
#             raise ValueError(
#                 "Stop object must be ray.tune.Stopper subclass to be detected "
#                 "correctly."
#             )
#         return is_function


class EarlyStopping(Stopper):
    """Stop once the best observed values of a metric have plateaued.

    Metric values are fed in through ``__call__``; the inherited ``step``
    only advances the step counter and does not look at any metric.

    NOTE(review): ``is_terminal`` is not overridden here, so it still raises
    ``NotImplementedError`` as in the base class. Confirm how this stopper
    is meant to be polled by callers that rely on ``is_terminal``.
    """

    def __init__(self, metric, std=0.001, top=10, mode="min", patience=0):
        """Create the EarlyStopping object.

        Stops the entire experiment when the metric has plateaued
        for at least the number of consecutive iterations given by
        the patience parameter.

        Args:
            metric (str): The metric to be monitored.
            std (float): The standard deviation at or below which the
                top values are considered to have plateaued.
            top (int): The number of best values to consider.
            mode (str): The mode to select the top results.
                Can either be "min" or "max".
            patience (int): Number of consecutive plateaued calls required
                before stopping. ``0`` stops at the first plateau.

        Raises:
            ValueError: If the mode parameter is not "min" nor "max".
            ValueError: If the top parameter is not an integer
                greater than 1.
            ValueError: If the standard deviation parameter is not
                a strictly positive float.
            ValueError: If the patience parameter is not
                a non-negative integer.
        """
        if mode not in ("min", "max"):
            raise ValueError("The mode parameter can only be either min or max.")
        if not isinstance(top, int) or top <= 1:
            raise ValueError(
                "Top results to consider must be a positive integer greater than one."
            )
        if not isinstance(patience, int) or patience < 0:
            # Zero is accepted (it is the default), so the message must not
            # claim that a strictly positive value is required.
            raise ValueError("Patience must be a non-negative integer.")
        if not isinstance(std, float) or std <= 0:
            raise ValueError(
                "The standard deviation must be a strictly positive float number."
            )

        # Initialise the base-class step counter so that the inherited
        # ``counter`` and ``step`` work on this stopper.
        super().__init__()

        self._mode = mode
        self._metric = metric
        self._patience = patience
        self._iterations = 0  # consecutive calls during which the plateau held
        self._std = std
        self._top = top
        self._top_values = []  # best ``top`` metric values seen so far

    def __call__(self, trial_id, result):
        """Record ``result[metric]`` and return whether tuning has to stop.

        ``trial_id`` is accepted but not used. A ``KeyError`` is raised if
        ``result`` does not contain the monitored metric.
        """
        self._top_values.append(result[self._metric])
        ranked = sorted(self._top_values)
        if self._mode == "min":
            # Best values are the smallest ones.
            self._top_values = ranked[: self._top]
        else:
            # Best values are the largest ones.
            self._top_values = ranked[-self._top :]

        if self.has_plateaued():
            self._iterations += 1
        else:
            # The plateau was broken: start counting again.
            self._iterations = 0

        # Re-run the checks, this time including the patience counter.
        return self.stop_all()

    def has_plateaued(self):
        """Return whether ``top`` values are held and their std is <= ``std``."""
        return (
            len(self._top_values) == self._top and np.std(self._top_values) <= self._std
        )

    def stop_all(self):
        """Return whether to stop and prevent trials from starting."""
        return self.has_plateaued() and self._iterations >= self._patience

    def reset(self):
        """Reset the step counter and forget all recorded metric values."""
        super().reset()
        self._iterations = 0
        self._top_values = []


class TimeoutStopper(Stopper):
    """Stops all trials after a certain timeout.

    The clock starts at ``reset()`` or, if ``reset()`` was never called, at
    the first ``step()``. Elapsed time is measured up to the most recent
    ``step()`` call, not up to the moment ``is_terminal()`` is queried.

    Args:
        timeout (int|float|datetime.timedelta): Either a number specifying
            the timeout in seconds, or a `datetime.timedelta` object.

    Raises:
        ValueError: If ``timeout`` is neither a number nor a
            ``datetime.timedelta``.
    """

    def __init__(self, timeout):
        from datetime import timedelta

        if isinstance(timeout, timedelta):
            self._timeout_seconds = timeout.total_seconds()
        elif isinstance(timeout, (int, float)):
            self._timeout_seconds = timeout
        else:
            raise ValueError(
                "`timeout` parameter has to be either a number or a "
                f"`datetime.timedelta` object. Found: {type(timeout)}"
            )

        # To leave setup overhead out of the measurement, the start time is
        # not taken here but on ``reset()`` or on the first ``step()``.
        self._start = None
        self._end = None
        self._terminal = False

        super().__init__()

    def is_terminal(self) -> bool:
        """Return whether the time between start and last step hit the timeout."""
        if self._start is None or self._end is None:
            # The clock has not started yet, so no time can have elapsed.
            return False
        return self._end - self._start >= self._timeout_seconds

    def step(
        self,
        rollout_statis: Dict,
        training_statis: Dict,
        time_step: int,
        episode_th: int,
    ):
        """Record the current time as the latest observation and count the step."""
        now = time.time()
        if self._start is None:
            # ``reset()`` was never called: start the clock at the first step
            # so that ``is_terminal()`` has a valid reference point.
            self._start = now
        self._end = now
        super().step(rollout_statis, training_statis, time_step, episode_th)

    def reset(self):
        """Zero the step counter and restart the clock from now."""
        self._global_step = 0
        # Use a single reading so that start and end are exactly equal.
        now = time.time()
        self._start = now
        self._end = now


class MaxIterationStopper(Stopper):
    """Terminal once the step counter reaches ``max_iteration``."""

    def __init__(self, max_iteration: int) -> None:
        super().__init__()
        # Counter value at (or above) which the stopper reports terminal.
        self._max_iteration = max_iteration

    def step(
        self,
        rollout_statis: Dict,
        training_statis: Dict,
        time_step: int,
        episode_th: int,
    ):
        """Count one iteration; the arguments are passed to the base class."""
        super().step(rollout_statis, training_statis, time_step, episode_th)

    def is_terminal(self) -> bool:
        """Return whether the counter has reached the iteration limit."""
        return self._max_iteration <= self.counter

    def stop_all(self):
        """Force the terminal state by jumping the counter to the limit."""
        self._global_step = self._max_iteration


class TimeStepStopper(Stopper):
    """Terminal once the last reported ``time_step`` reaches ``max_time_step``."""

    def __init__(self, max_time_step) -> None:
        super().__init__()
        self._cur_time_step = 0  # value of ``time_step`` from the latest ``step``
        self._max_time_step = max_time_step

    def reset(self):
        """Zero both the step counter and the remembered time step."""
        super().reset()
        self._cur_time_step = 0

    def step(
        self,
        rollout_statis: Dict,
        training_statis: Dict,
        time_step: int,
        episode_th: int,
    ):
        """Count the step and remember the ``time_step`` it reported."""
        super().step(rollout_statis, training_statis, time_step, episode_th)
        self._cur_time_step = time_step

    def is_terminal(self) -> bool:
        """Return whether the remembered time step has reached the limit."""
        return self._cur_time_step >= self._max_time_step


class EpisodeStopper(TimeStepStopper):
    """Variant of ``TimeStepStopper`` that compares ``episode_th`` to the limit.

    The constructor, ``is_terminal`` and ``reset`` are inherited; only the
    value stored by ``step`` differs.
    """

    def step(
        self,
        rollout_statis: Dict,
        training_statis: Dict,
        time_step: int,
        episode_th: int,
    ):
        """Count the step and remember the ``episode_th`` it reported."""
        # Call the root implementation directly: the parent's ``step`` would
        # store ``time_step`` instead of the episode index.
        Stopper.step(
            self,
            rollout_statis,
            training_statis,
            time_step,
            episode_th,
        )
        self._cur_time_step = episode_th


class ExploitabilityStopper(Stopper):
    """Terminal once the tracked minimum value drops to the threshold.

    The update rule is not implemented yet: ``step`` always raises
    ``NotImplementedError``, so the tracked minimum stays at ``inf`` unless
    a subclass provides the update.

    Args:
        threshld (float): Value at or below which the stopper is terminal.
    """

    def __init__(self, threshld: float) -> None:
        # Initialise the base-class state as well, for consistency with the
        # other stoppers in this module.
        super().__init__()
        self._counter = 0
        self._threshold = threshld
        self._min_value = float("inf")  # smallest value recorded so far

    @property
    def counter(self) -> int:
        """Value of this stopper's own counter."""
        return self._counter

    def step(self, rollout_statis, traininng_statis, time_step=None, episode_th=None):
        """Placeholder for the update; always raises ``NotImplementedError``.

        ``time_step`` and ``episode_th`` are accepted (and optional) so that
        the four-argument call used by the base interface reaches the
        intended ``NotImplementedError`` rather than failing with a
        ``TypeError`` about the argument count.
        """
        raise NotImplementedError

    def is_terminal(self) -> bool:
        """Return whether the tracked minimum is at or below the threshold."""
        return self._min_value <= self._threshold

    def reset(self):
        """Zero the counter and forget the tracked minimum."""
        self._counter = 0
        self._min_value = float("inf")


def get_stopper(config: Dict[str, Any]) -> "Stopper":
    """Build a stopper from a stop-condition config.

    Each key of ``config`` selects a stopper class and its value is handed
    to that class: ``early_stopping`` expects a dict of keyword arguments,
    every other supported key passes its value as the single positional
    argument.

    Args:
        config: Mapping from stop-condition name to its setting. Supported
            names are ``timeout``, ``early_stopping``, ``max_iteration``,
            ``exploitability``, ``max_timestep`` and ``max_episode``.

    Returns:
        The single stopper when one condition is given, otherwise a
        ``CombinedStopper`` wrapping all of them in config order.

    Raises:
        ValueError: If ``config`` is empty or contains an unknown name.
        NotImplementedError: For the reserved names ``function`` and ``noop``.
    """
    if not config:
        # Without this guard the final ``stoppers[0]`` fails with an
        # uninformative IndexError.
        raise ValueError(
            "Stop config is empty: at least one stop condition is required."
        )

    stoppers = []
    for k, v in config.items():
        if k == "timeout":
            stoppers.append(TimeoutStopper(v))
        elif k == "early_stopping":
            stoppers.append(EarlyStopping(**v))
        elif k == "max_iteration":
            stoppers.append(MaxIterationStopper(v))
        elif k == "exploitability":
            stoppers.append(ExploitabilityStopper(v))
        elif k == "max_timestep":
            stoppers.append(TimeStepStopper(v))
        elif k == "max_episode":
            stoppers.append(EpisodeStopper(v))
        elif k in ("function", "noop"):
            # Reserved names whose stoppers are not available in this module.
            raise NotImplementedError
        else:
            raise ValueError(f"No such stopper named with: `{k}`")

    if len(stoppers) > 1:
        return CombinedStopper(*stoppers)
    return stoppers[0]


# Default stop-condition config, in the key/value form read by `get_stopper`:
# a single `max_iteration` condition with a limit of 2.
DEFAULT_STOP_CONDITIONS = {
    "max_iteration": 2,
}
