from k_level_policy_gradients.src.utils.serialization import Serializable


class Agent(Serializable):
    """
    Base class for agents: it stores the MDP information and the policy, and
    exposes the hooks used to draw actions, fit on a dataset, and switch the
    policy between its operating modes.

    """

    def __init__(self, mdp_info, policy, idx_agent):
        """
        Constructor.

        Args:
            mdp_info (MDPInfo): information about the MDP;
            policy (Policy): the policy followed by the agent;
            idx_agent: index of this agent; it is stored as ``_idx_agent``
                and serialized as a primitive value.

        """
        self.mdp_info = mdp_info
        self.policy = policy
        self._idx_agent = idx_agent

        # Optional collaborators injected later through ``set_logger`` and
        # ``set_profiler``. They are initialized here so that reading them
        # before the setter is called yields ``None`` instead of raising
        # ``AttributeError``. They are deliberately not registered for
        # serialization below.
        self._logger = None
        self.profiler = None

        self._add_save_attr(
            mdp_info="mushroom",
            policy="mushroom",
            _idx_agent="primitive",
        )

    def fit(self, dataset):
        """
        Fit step. Must be implemented by subclasses.

        Args:
            dataset (list): the dataset.

        Raises:
            NotImplementedError: always, since this base class does not
                implement any learning rule.

        """
        raise NotImplementedError("Agent is an abstract class")

    def draw_action(self, state):
        """
        Return the action to execute in the given state, as drawn by the
        policy of the agent.

        Args:
            state (np.ndarray): the state where the agent is.

        Returns:
            The action to be executed.

        """
        return self.policy.draw_action(state)

    def episode_start(self):
        """
        Hook to be called when a new episode starts. It resets the policy.

        """
        self.policy.reset()

    def stop(self):
        """
        Method used to stop an agent. Useful when dealing with real world
        environments, simulators, or to cleanup environments internals after
        a core learn/evaluate to enforce consistency. The base implementation
        does nothing.

        """
        pass

    def set_logger(self, logger):
        """
        Setter that can be used to pass a logger to the algorithm.

        Args:
            logger (Logger): the logger to be used by the algorithm.

        """
        self._logger = logger

    def set_profiler(self, profiler):
        """
        Setter that can be used to pass a profiler to the algorithm.

        Args:
            profiler: the profiler object to be stored on the agent.

        """
        self.profiler = profiler

    def set_random_mode(self):
        """
        Switch the policy to its ``"random"`` mode.

        """
        self.policy.set_mode("random")

    def set_training_mode(self):
        """
        Switch the policy to its ``"train"`` mode.

        """
        self.policy.set_mode("train")

    def set_testing_mode(self):
        """
        Switch the policy to its ``"test"`` mode.

        """
        self.policy.set_mode("test")
