import warnings

import numpy as np

from k_level_policy_gradients.src.utils.serialization import Serializable


class MDPInfo(Serializable):
    """
    Container for the static description of a (multi-agent) environment:
    its state/observation/action spaces, discount factor, horizon, time step
    and number of agents.

    """

    def __init__(
        self,
        state_space,
        observation_space,
        action_space,
        discrete_actions,
        gamma,
        horizon=None,
        dt=1e-1,
        has_obs=False,
        has_action_masks=False,
        n_agents=1,
    ):
        """
        Constructor.

        Args:
             state_space ([Box, Discrete]): the state space of the environment;
             observation_space ([Box, Discrete]): the observation space;
             action_space ([list, Box, Discrete]): the action space, or a list
                holding one action space per agent;
             discrete_actions: flag(s) describing whether the actions are
                discrete; stored as given (presumably a bool or one flag per
                agent — TODO confirm against callers);
             gamma (float): the discount factor;
             horizon (int, None): the horizon (``None`` by default);
             dt (float): the time step (presumably the simulation/control
                interval — confirm);
             has_obs (bool): flag stored as given; presumably signals that the
                environment emits observations distinct from the state;
             has_action_masks (bool): flag stored as given; presumably signals
                that the environment provides action masks;
             n_agents (int): the number of agents.

        """
        self.state_space = state_space
        self.observation_space = observation_space
        self.action_space = action_space
        self.discrete_actions = discrete_actions
        self.gamma = gamma
        self.horizon = horizon
        self.dt = dt
        self.has_obs = has_obs
        self.has_action_masks = has_action_masks
        self.n_agents = n_agents

        # NOTE(review): action_space may be a list (see docstring); confirm
        # the "mushroom" save method of Serializable handles lists.
        self._add_save_attr(
            state_space="mushroom",
            observation_space="mushroom",
            action_space="mushroom",
            # Previously missing: without it the attribute is lost on load.
            discrete_actions="primitive",
            gamma="primitive",
            horizon="primitive",
            dt="primitive",
            has_obs="primitive",
            has_action_masks="primitive",
            n_agents="primitive",
        )

    def _action_spaces(self):
        """
        Returns:
            The action spaces as a list: the elements of ``action_space`` when
            it is a list or tuple (one space per agent), otherwise a
            single-element list wrapping the lone space.

        """
        if isinstance(self.action_space, (list, tuple)):
            return list(self.action_space)
        return [self.action_space]

    def _combine(self, attr):
        """
        Combine attribute ``attr`` of the observation space with the same
        attribute of every action space, in agent order, using ``+``.

        ``+`` is used (rather than ``sum``) so the result works both for
        numeric values and for tuples, where it concatenates.

        Args:
            attr (str): name of the space attribute to combine.

        Returns:
            The combined value.

        """
        combined = getattr(self.observation_space, attr)
        for space in self._action_spaces():
            combined = combined + getattr(space, attr)
        return combined

    @property
    def size(self):
        """
        Returns:
            The ``size`` of the observation space combined (with ``+``) with
            the ``size`` of every agent's action space. Only works for
            discrete spaces.

        """
        return self._combine("size")

    @property
    def shape(self):
        """
        Returns:
            The concatenation of the shape tuple of the observation space and
            of every agent's action space.

        """
        return self._combine("shape")


class Environment(Serializable):
    """
    Basic interface used by any multi agent environment.

    Subclasses must implement ``reset``, ``step`` and ``render``; the base
    versions raise ``NotImplementedError``.

    """

    def __init__(self, mdp_info):
        """
        Constructor.

        Args:
             mdp_info (MDPInfo): an object containing the info of the
                environment.

        """
        self._mdp_info = mdp_info

        self._add_save_attr(_mdp_info="mushroom")

    def seed(self, seed):
        """
        Set the seed of the environment.

        The call is forwarded to ``self.env.seed`` when the subclass wraps an
        inner environment exposing a ``seed`` method. The base class defines
        no ``env`` attribute, so otherwise a warning is issued and the call
        has no effect (previously it raised ``AttributeError``).

        Args:
            seed (float): the value of the seed.

        """
        env = getattr(self, "env", None)
        if env is not None and hasattr(env, "seed"):
            env.seed(seed)
        else:
            warnings.warn(
                "This environment has no custom seed. The call will have no "
                "effect. You can set the seed manually by setting the "
                "numpy/torch seed."
            )

    def reset(self):
        """
        Reset the current state.

        Returns:
            The current state.

        """
        raise NotImplementedError

    def step(self, actions):
        """
        Move the agents from their current state according to the actions.

        Args:
            actions (list[np.ndarray]): the list of actions to execute
                (presumably one per agent).

        Returns:
            The list of observations reached by the agents executing the
            actions in their current state, the next state of the environment,
            the rewards obtained in the transition and a flag signalling
            whether the next state is absorbing. An additional info dictionary
            is also returned (possibly empty).

        """
        raise NotImplementedError

    def render(self):
        """
        Render the environment. Must be implemented by subclasses.

        """
        raise NotImplementedError

    def stop(self):
        """
        Method used to stop an mdp. Useful when dealing with real world
        environments, simulators, or when using openai-gym rendering.
        The base implementation does nothing.

        """
        pass

    @property
    def info(self):
        """
        Returns:
             An object containing the info of the environment.

        """
        return self._mdp_info

    @staticmethod
    def _bound(x, min_value, max_value):
        """
        Method used to bound state and action variables.

        Clips ``x`` element-wise into ``[min_value, max_value]``.

        Args:
            x: the variable to bound;
            min_value: the minimum value;
            max_value: the maximum value.

        Returns:
            The bounded variable.

        """
        return np.maximum(min_value, np.minimum(x, max_value))
