import numpy as np
from k_level_policy_gradients.src.policy.policy import ParametricPolicy
from k_level_policy_gradients.src.utils.parameters import to_parameter


class EpsGreedy(ParametricPolicy):
    """
    Epsilon greedy policy over the Q-values produced by the approximator.

    The behaviour depends on the mode selected through ``set_mode``:
    ``"random"`` always samples uniformly among the allowed actions,
    ``"test"`` always acts greedily and ``"train"`` samples a random action
    with probability epsilon and acts greedily otherwise.

    """

    def __init__(self, epsilon):
        """
        Constructor.

        Args:
            epsilon ([float, Parameter]): the exploration coefficient. It indicates
                the probability of performing a random action in the current
                step.

        """
        super().__init__()

        self._epsilon = to_parameter(epsilon)

        self._add_save_attr(_epsilon="mushroom")

    def draw_action(self, state, action_mask=None):
        """
        Draw an action for the given state.

        Args:
            state: the input forwarded to the approximator's ``predict``;
            action_mask (array-like, None): boolean mask of the allowed
                actions. ``None`` means that every action is allowed.

        Returns:
            A numpy array containing the index of the selected action.

        Raises:
            ValueError: if the current mode is not one of ``"random"``,
                ``"test"`` or ``"train"``.

        """
        if self._should_explore():
            if action_mask is None:
                # Without a mask the number of actions is only known from
                # the size of the approximator output.
                action_mask = self._full_mask(self._approximator.predict(state))
            return self._random_action(action_mask)

        return self._greedy_action(self._approximator.predict(state), action_mask)

    def draw_action_hidden(self, state, hidden_state, action_mask=None):
        """
        Draw an action using a recurrent approximator.

        The approximator is always queried, whatever the mode, so that the
        updated hidden state can be returned together with the action.

        Args:
            state: the input forwarded to the approximator's ``predict``;
            hidden_state: the hidden state forwarded to the approximator's
                ``predict``;
            action_mask (array-like, None): boolean mask of the allowed
                actions. ``None`` means that every action is allowed.

        Returns:
            A tuple with a numpy array containing the index of the selected
            action and the hidden state returned by the approximator.

        Raises:
            ValueError: if the current mode is not one of ``"random"``,
                ``"test"`` or ``"train"``.

        """
        # Draw Q-values from recurrent network
        q, h = self._approximator.predict(state, hidden_state, output_hidden=True)

        if self._should_explore():
            if action_mask is None:
                action_mask = self._full_mask(q)
            return self._random_action(action_mask), h

        return self._greedy_action(q, action_mask), h

    def set_mode(self, mode):
        """
        Select how actions are drawn.

        Args:
            mode (str): one of ``"random"``, ``"test"`` or ``"train"``.

        """
        self._mode = mode

    def _should_explore(self):
        """
        Decide, according to the current mode, whether to act randomly.

        Returns:
            True if a random action has to be drawn, False if the greedy
            action has to be selected.

        Raises:
            ValueError: if the current mode is not recognised.

        """
        mode = self._mode
        if mode == "random":
            return True
        if mode == "test":
            return False
        if mode == "train":
            return bool(np.random.uniform() < self._epsilon())

        # Fail loudly instead of silently returning None to the caller.
        raise ValueError(
            f"Unknown policy mode {mode!r}; expected 'random', 'test' or 'train'."
        )

    @staticmethod
    def _full_mask(q):
        """
        Build a mask allowing every action, shaped like the squeezed Q-values.

        """
        return np.ones(np.atleast_1d(q.squeeze()).shape, dtype=bool)

    @staticmethod
    def _random_action(action_mask):
        """
        Sample uniformly among the actions allowed by the mask.

        """
        return np.array([np.random.choice(np.where(action_mask)[0])])

    @staticmethod
    def _greedy_action(q, action_mask):
        """
        Select the allowed action with the highest Q-value.

        Ties between maximal actions are broken uniformly at random.

        """
        if action_mask is None:
            action_mask = EpsGreedy._full_mask(q)

        # Forbidden actions get -inf so that they can never win the arg-max.
        q_mask = np.where(action_mask, q.squeeze(), -np.inf)
        max_a = np.argwhere(q_mask == np.max(q_mask)).ravel()
        if len(max_a) > 1:
            max_a = np.array([np.random.choice(max_a)])
        return max_a
