import numpy as np

class ActionSelector():
    """Choose actions according to a named behaviour policy.

    Supported ``policy_name`` values:

    * ``'main'`` -- delegate to ``learner.select_action(obs_t)``.
    * ``'random'`` -- uniform random action from ``range(num_actions)``.
    * ``'push_in_vel_direction'`` -- action computed from the sign of the
      ``theta_dot_1`` keyword argument (positive -> 0, zero -> 1,
      negative -> 2).
    * ``'round_robin'`` -- follow auxiliary task ``self.aux_to_follow`` via
      the learner, moving on to the next auxiliary task (cyclically) every
      ``k`` calls.
    * ``'follow_aux'`` -- follow the auxiliary task given by the
      ``aux_to_follow`` keyword argument.
    * ``'sticky_actions'`` -- with probability ``sticky_prob`` repeat the
      previous sticky action, otherwise sample a new uniform action, which
      then becomes the one that is repeated.

    Any other ``policy_name`` makes ``select_action`` return ``None``.
    """

    def __init__(self, num_actions, policy_name, learner, num_aux_tasks,
                 k=10, sticky_prob=0.9):
        """
        Args:
            num_actions: Number of discrete actions; random actions are drawn
                from ``range(num_actions)``.
            policy_name: Which behaviour policy to use (see class docstring).
            learner: Object exposing ``select_action(obs, follow_main=...,
                aux=...)``; only used by ``'main'``, ``'round_robin'`` and
                ``'follow_aux'``.
            num_aux_tasks: Number of auxiliary tasks cycled through by
                ``'round_robin'`` (must be non-zero for that policy, since
                it is used as a modulus).
            k: Number of consecutive calls spent on each auxiliary task under
                ``'round_robin'``. Defaults to 10.
            sticky_prob: Probability of repeating the previous action under
                ``'sticky_actions'``. Defaults to 0.9.
        """
        self.num_actions = num_actions
        self.num_aux_tasks = num_aux_tasks
        self.learner = learner
        self.policy_name = policy_name
        # Auxiliary task currently followed by the 'round_robin' policy.
        self.aux_to_follow = 0
        # Calls since construction, or since the last round-robin switch.
        self.step = 0
        self.k = k
        # Previous action for 'sticky_actions'. It is seeded with a random
        # action so the very first call has something to repeat, and it is
        # kept as a plain Python int.
        self.a_tm1 = int(np.random.randint(self.num_actions))
        self.sticky_prob = sticky_prob

    def select_action(self, obs_t, **kwargs):
        """Return the action for observation ``obs_t`` under the configured policy.

        Args:
            obs_t: Current observation; forwarded to the learner for the
                learner-backed policies and ignored by the others.
            **kwargs: Policy-specific extras: ``theta_dot_1`` is required for
                ``'push_in_vel_direction'`` and ``aux_to_follow`` is required
                for ``'follow_aux'``.

        Returns:
            The chosen action, or ``None`` if ``policy_name`` is not
            recognised. Actions produced by this class itself ('random',
            'push_in_vel_direction', 'sticky_actions') are plain Python ints.

        Raises:
            KeyError: If a keyword argument required by the policy is missing.
        """
        # Advance round-robin to the next auxiliary task once the current one
        # has been followed for k calls.
        if self.policy_name == 'round_robin' and self.step == self.k:
            self.aux_to_follow = (self.aux_to_follow + 1) % self.num_aux_tasks
            self.step = 0

        if self.policy_name == 'main':
            a_t = self.learner.select_action(obs_t)
        elif self.policy_name == 'random':
            a_t = np.random.randint(self.num_actions)
        elif self.policy_name == 'push_in_vel_direction':
            # sign +1 -> 0, 0 -> 1, -1 -> 2. NOTE(review): presumably a
            # 3-action torque space (e.g. Acrobot); confirm the intended direction.
            a_t = 1 - int(np.sign(kwargs['theta_dot_1']))
        elif self.policy_name == 'round_robin':
            a_t = self.learner.select_action(obs_t, follow_main=False, aux=self.aux_to_follow)
        elif self.policy_name == 'follow_aux':
            a_t = self.learner.select_action(obs_t, follow_main=False, aux=kwargs['aux_to_follow'])
        elif self.policy_name == 'sticky_actions':
            if np.random.random() < self.sticky_prob:
                # Repeat the previous action. Ints are immutable, so no copy
                # is needed; np.copy here used to return a 0-d ndarray.
                a_t = int(self.a_tm1)
            else:
                a_t = int(np.random.randint(self.num_actions))
                self.a_tm1 = a_t
        else:
            a_t = None

        self.step += 1
        return a_t