from logging import getLogger

import numpy as np

from pfrl import explorer


def select_action_epsilon_greedily(epsilon, random_action_func, greedy_action_func):
    """Pick an action epsilon-greedily.

    A single uniform sample from ``np.random.rand()`` is compared against
    ``epsilon``; only the action function for the chosen branch is called.

    Args:
      epsilon: threshold the uniform sample is compared against
      random_action_func: function with no argument that returns action
      greedy_action_func: function with no argument that returns action
    Returns:
      A tuple ``(action, greedy)`` where ``greedy`` is ``True`` when the
      action came from ``greedy_action_func`` and ``False`` otherwise.
    """
    explore = np.random.rand() < epsilon
    if explore:
        action = random_action_func()
        return action, False
    action = greedy_action_func()
    return action, True


class ConstantEpsilonGreedy(explorer.Explorer):
    """Epsilon-greedy explorer whose epsilon never changes.
    Args:
      epsilon: epsilon used; must lie in [0, 1]
      random_action_func: function with no argument that returns action
      logger: logger used for per-step debug messages
    """

    def __init__(self, epsilon, random_action_func, logger=getLogger(__name__)):
        assert 0 <= epsilon <= 1
        self.logger = logger
        self.random_action_func = random_action_func
        self.epsilon = epsilon

    def select_action(self, t, greedy_action_func, action_value=None):
        """Return an action chosen epsilon-greedily and log the choice.

        ``t`` is only used in the debug message; ``action_value`` is unused.
        """
        action, is_greedy = select_action_epsilon_greedily(
            epsilon=self.epsilon,
            random_action_func=self.random_action_func,
            greedy_action_func=greedy_action_func,
        )
        if is_greedy:
            label = "greedy"
        else:
            label = "non-greedy"
        self.logger.debug("t:%s a:%s %s", t, action, label)
        return action

    def __repr__(self):
        current = self.epsilon
        return "ConstantEpsilonGreedy(epsilon={})".format(current)
    
class DecayEpsilonGreedy(explorer.Explorer):
    """Epsilon-greedy with multiplicatively decayed epsilon.

    Every call to ``select_action`` picks an action with the current epsilon
    and then sets epsilon to ``max(decay_factor * epsilon, end_epsilon)``.
    The decay is driven by the number of calls, not by the step ``t``.
    Args:
      epsilon: initial epsilon; must lie in [0, 1]
      random_action_func: function with no argument that returns action
      decay_factor: factor epsilon is multiplied by after each selection
      end_epsilon: lower bound applied to epsilon after each decay; must lie
        in [0, 1]
      logger: logger used
    """

    def __init__(self, epsilon, random_action_func, decay_factor, end_epsilon, logger=getLogger(__name__)):
        assert epsilon >= 0 and epsilon <= 1
        # Same range check the linear-decay explorer applies to its end_epsilon.
        assert end_epsilon >= 0 and end_epsilon <= 1
        self.epsilon = epsilon
        self.end_epsilon = end_epsilon
        self.random_action_func = random_action_func
        self.decay_factor = decay_factor
        self.logger = logger

    def select_action(self, t, greedy_action_func, action_value=None):
        """Return an epsilon-greedy action, then decay epsilon once.

        ``t`` is only used in the debug message; ``action_value`` is unused.
        """
        a, greedy = select_action_epsilon_greedily(
            self.epsilon, self.random_action_func, greedy_action_func
        )
        greedy_str = "greedy" if greedy else "non-greedy"
        # Decay after selecting so this call used the pre-decay epsilon.
        self.epsilon = max(self.decay_factor * self.epsilon, self.end_epsilon)
        self.logger.debug("t:%s a:%s %s", t, a, greedy_str)
        return a

    def __repr__(self):
        return "DecayEpsilonGreedy(epsilon={})".format(self.epsilon)
    

class LinearDecayEpsilonGreedy(explorer.Explorer):
    """Epsilon-greedy with linearly decayed epsilon

    Epsilon is interpolated linearly from ``start_epsilon`` at ``t = 0`` to
    ``end_epsilon`` at ``t = decay_steps`` and stays at ``end_epsilon``
    afterwards. While ``t < start_steps`` the action always comes from
    ``random_action_func``, regardless of epsilon.
    Args:
      start_epsilon: value of epsilon at t = 0; must lie in [0, 1]
      end_epsilon: value of epsilon once t reaches decay_steps; must lie in
        [0, 1]
      decay_steps: how many steps it takes for epsilon to decay; must be >= 0
      start_steps: number of initial steps during which only random actions
        are taken
      random_action_func: function with no argument that returns action
      logger: logger used
    """

    def __init__(
        self,
        start_epsilon,
        end_epsilon,
        decay_steps,
        start_steps,
        random_action_func,
        logger=getLogger(__name__),
    ):
        assert start_epsilon >= 0 and start_epsilon <= 1
        assert end_epsilon >= 0 and end_epsilon <= 1
        assert decay_steps >= 0
        self.start_epsilon = start_epsilon
        self.end_epsilon = end_epsilon
        self.decay_steps = decay_steps
        self.start_steps = start_steps
        self.random_action_func = random_action_func
        self.logger = logger
        self.epsilon = start_epsilon

    def compute_epsilon(self, t):
        """Return the epsilon scheduled for step ``t``."""
        # ``>=`` rather than ``>``: with decay_steps == 0 (which the
        # constructor allows) the interpolation below would divide by zero at
        # t == 0, and at t == decay_steps this returns end_epsilon exactly
        # instead of a float-rounded approximation of it.
        if t >= self.decay_steps:
            return self.end_epsilon
        epsilon_diff = self.end_epsilon - self.start_epsilon
        return self.start_epsilon + epsilon_diff * (t / self.decay_steps)

    def select_action(self, t, greedy_action_func, action_value=None):
        """Return the action for step ``t`` and log how it was chosen.

        Updates ``self.epsilon`` from ``t`` on every call, including during
        the random warm-up phase. ``action_value`` is unused.
        """
        self.epsilon = self.compute_epsilon(t)
        if t >= self.start_steps:
            a, greedy = select_action_epsilon_greedily(
                self.epsilon, self.random_action_func, greedy_action_func
            )
        else:
            # Warm-up phase: always random. Must be the boolean False; a
            # non-empty string here would be truthy and get logged as "greedy".
            a = self.random_action_func()
            greedy = False
        greedy_str = "greedy" if greedy else "non-greedy"
        self.logger.debug("t:%s a:%s %s", t, a, greedy_str)
        return a

    def __repr__(self):
        return "LinearDecayEpsilonGreedy(epsilon={})".format(self.epsilon)