from k_level_policy_gradients.src.algorithms.agent import Agent
import numpy as np


class DiscreteRandomAgent(Agent):
    """
    Baseline agent that selects discrete actions uniformly at random.

    The ``policy`` argument is only forwarded to the ``Agent`` base class;
    ``draw_action`` never consults it, and ``fit`` performs no learning.
    """

    def __init__(self, mdp_info, policy, idx_agent):
        # Action space belonging to this agent in the (multi-agent) MDP info.
        own_space = mdp_info.action_space[idx_agent]
        # NOTE(review): `.size` is assumed to be the integer count of discrete
        # actions for this space type -- confirm against the space class.
        self.num_actions = own_space.size

        super().__init__(mdp_info, policy, idx_agent)

    def fit(self, dataset):
        """Do nothing: a random agent has nothing to learn from ``dataset``."""

    def draw_action(self, state):
        """
        Return a random action index, ignoring ``state``.

        The index is drawn with ``np.random.randint`` over
        ``[0, self.num_actions)`` and wrapped in a one-element array.
        """
        chosen = np.random.randint(self.num_actions)
        return np.array([chosen])


class ContinuousRandomAgent(Agent):
    """
    Baseline agent that draws each action by sampling its own action space.

    The ``policy`` argument is only forwarded to the ``Agent`` base class;
    ``draw_action`` never consults it, and ``fit`` performs no learning.
    """

    def __init__(self, mdp_info, policy, idx_agent):
        # Keep a handle on this agent's action space so actions can be sampled.
        agent_space = mdp_info.action_space[idx_agent]
        self.action_space = agent_space

        super().__init__(mdp_info, policy, idx_agent)

    def fit(self, dataset):
        """Do nothing: a random agent has nothing to learn from ``dataset``."""

    def draw_action(self, state):
        """
        Return whatever ``self.action_space.sample()`` produces, ignoring ``state``.

        NOTE(review): the distribution and shape of the sample are defined by
        the action-space class, which is not visible here.
        """
        sampler = self.action_space.sample
        return sampler()
