from k_level_policy_gradients.src.utils.serialization import Serializable
import numpy as np


def to_parameter(x):
    """
    Ensure that ``x`` is a :class:`Parameter`.

    Args:
        x: either an existing ``Parameter``, which is returned unchanged, or
            any other value, which becomes the initial value of a new
            constant ``Parameter``.

    Returns:
        A ``Parameter`` instance.

    """
    if not isinstance(x, Parameter):
        x = Parameter(x)
    return x


class Parameter(Serializable):
    """
    Base class for a scalar parameter whose value may depend on how many
    times it has been used, such as a learning rate or an exploration rate.

    This base implementation always yields its initial value, optionally
    clipped to ``[min_value, max_value]``. Subclasses override
    :meth:`compute` to make the value depend on the update counter.

    """

    def __init__(self, initial_value, min_value=None, max_value=None):
        """
        Constructor.

        Args:
            initial_value (float): starting value of the parameter;
            min_value (float, None): lower bound the value is clipped to,
                or None for no lower bound;
            max_value (float, None): upper bound the value is clipped to,
                or None for no upper bound.

        """
        self._initial_value = initial_value
        self._min_value = min_value
        self._max_value = max_value
        # Incremented once per update_counter() call (i.e. once per __call__).
        self._n_updates = 0

        # All state of this class is made of plain values, so every field is
        # registered with the "primitive" serialization method.
        save_attrs = dict.fromkeys(
            ("_initial_value", "_min_value", "_max_value", "_n_updates"),
            "primitive",
        )
        self._add_save_attr(**save_attrs)

    def __call__(self):
        """
        Register one more use of the parameter, then return its value.

        Returns:
            The (possibly clipped) value computed after the counter has been
            incremented.

        """
        self.update_counter()
        current = self.get_value()
        return current

    def get_value(self):
        """
        Return the current value of the parameter without touching the
        update counter.

        Returns:
            The output of :meth:`compute`, clipped with ``np.clip`` to the
            configured bounds when at least one bound is set.

        """
        raw = self.compute()
        # np.clip requires at least one bound, so skip it when both are None.
        unbounded = self._min_value is None and self._max_value is None
        return raw if unbounded else np.clip(raw, self._min_value, self._max_value)

    def update_counter(self):
        """
        Increment the number of times the parameter has been used.
        """
        self._n_updates = self._n_updates + 1

    def compute(self):
        """
        Compute the unclipped value of the parameter. Subclasses override
        this to implement a schedule.

        Returns:
            The initial value; the base parameter is constant.

        """
        return self._initial_value

    @property
    def initial_value(self):
        """
        Returns:
            The initial value of the parameter.

        """
        return self._initial_value


class DelayedLinearParameter(Parameter):
    """
    Parameter that keeps its initial value for the first ``n_start`` updates,
    then changes linearly so that it equals ``threshold_value`` once the
    update counter reaches ``n_end``. After that it stays at
    ``threshold_value``, because the threshold is installed as the clipping
    bound.

    Args:
        initial_value (float): value of the parameter before ``n_start``
            updates;
        threshold_value (float): value reached after ``n_end`` updates; also
            the bound the parameter is clipped to (upper bound when
            increasing, lower bound when decreasing);
        n_start (int): number of updates before the linear change starts;
        n_end (int): number of updates after which the linear change ends;
            must be strictly greater than ``n_start``.

    Raises:
        ValueError: if ``n_end <= n_start``.

    Example:
        To create an epsilon that starts at 1.0, stays at 1.0 for 1000 steps,
        and decays linearly to 0.1 by step 2000, use
        ``DelayedLinearParameter(1.0, 0.1, 1000, 2000)``.
    """

    def __init__(self, initial_value, threshold_value, n_start, n_end):
        if n_end <= n_start:
            # A zero span would divide by zero below. A negative span would
            # flip the slope's sign, so the threshold would be installed as the
            # wrong clipping bound and would clip the initial value at once.
            raise ValueError(
                "n_end must be greater than n_start, got n_start={} and "
                "n_end={}".format(n_start, n_end)
            )

        # Change in value per update while inside [n_start, n_end].
        self._coeff = (threshold_value - initial_value) / (n_end - n_start)

        self._n_start = n_start
        self._n_end = n_end

        if self._coeff >= 0:  # increasing: threshold is the upper bound
            super().__init__(
                initial_value=initial_value, min_value=None, max_value=threshold_value
            )
        else:  # decreasing: threshold is the lower bound
            super().__init__(
                initial_value=initial_value, min_value=threshold_value, max_value=None
            )

        self._add_save_attr(
            _n_start="primitive", _n_end="primitive", _coeff="primitive"
        )

    def compute(self):
        """
        Returns:
            ``initial_value`` while fewer than ``n_start`` updates happened,
            otherwise the linear schedule value. Past ``n_end`` the linear
            value overshoots the threshold; ``get_value`` clips it back.

        """
        if self._n_updates < self._n_start:
            return self._initial_value
        else:
            return self._coeff * (self._n_updates - self._n_start) + self._initial_value
