import gymnasium_robotics
from k_level_policy_gradients.src.core.environment import Environment, MDPInfo
from k_level_policy_gradients.src.utils.spaces import *


class Mamujoco(Environment):
    """
    Wraps a Mamujoco env to be compatible with MushroomRL multi-agent setup.

    The wrapped environment is the parallel (PettingZoo-style) MaMuJoCo env;
    its per-agent dictionaries are converted to lists that follow the fixed
    agent ordering captured at construction time.
    """

    def __init__(
        self,
        scenario="",
        partitioning="",
        horizon=1000,
        gamma=0.99,
        seed=0,
        bool_render=False,
    ):
        """Create a new multi-agent Mamujoco env compatible with MushroomRL.

        Arguments:
            scenario (str): Name of the scenario to load.
            partitioning (str): Partitioning of the agent (forwarded to the
                wrapped env as ``agent_conf``).
            horizon (int): Episode horizon stored in the MDP info.
            gamma (float): Discount factor stored in the MDP info.
            seed (int): Seed passed to the wrapped env on the first reset.
            bool_render (bool): If True, the wrapped env is created with
                ``render_mode="human"``; otherwise rendering is disabled.

        """
        self._env = gymnasium_robotics.mamujoco_v1.parallel_env(
            scenario=scenario,
            agent_conf=partitioning,
            render_mode="human" if bool_render else None,
        )
        self._n_agents = self._env.num_agents
        # Keep our own copy of the agent ordering: PettingZoo-style envs may
        # change their live `agents` list during an episode, and every
        # per-agent list built by this wrapper must use one stable ordering.
        self._agents = list(self._env.agents)
        self._seed = seed
        self._initial_seed_set = False

        # Set the state, observation, and action spaces
        action_spaces = self._env.action_spaces
        observation_spaces = self._env.observation_spaces

        action_space = [
            Box(action_spaces[agent].low, action_spaces[agent].high)
            for agent in self._agents
        ]
        # The global state space is unbounded; only its shape is taken from
        # the wrapped env.
        state_space_shape = self._env.state().shape
        state_space = Box(-np.inf, np.inf, shape=state_space_shape)
        observation_space = [
            Box(observation_spaces[agent].low, observation_spaces[agent].high)
            for agent in self._agents
        ]

        # Set the MDP info
        mdp_info = MDPInfo(
            state_space=state_space,
            observation_space=observation_space,
            action_space=action_space,
            discrete_actions=False,
            gamma=gamma,
            horizon=horizon,
            has_obs=True,
            has_action_masks=False,
            n_agents=self._n_agents,
        )

        super().__init__(mdp_info)

    def reset(self):
        """Reset the environment.

        On the very first call the wrapped env is reset once with the
        configured seed (that result is discarded) and then reset again
        without a seed, like every later call.

        Returns:
            dict: ``"state"`` (global state of the wrapped env), ``"obs"``
            (list of per-agent observations, in agent order) and ``"info"``
            (info returned by the wrapped env's reset).

        """
        if not self._initial_seed_set:
            # Seed once for the rest of the experiment: after the first reset
            # with a seed, the following unseeded resets are reproducible.
            self._env.reset(seed=self._seed)
            self._initial_seed_set = True
        obs, info = self._env.reset()

        obs = [obs[agent] for agent in self._agents]
        step = {"state": self._env.state(), "obs": obs, "info": info}

        return step

    def step(self, actions):
        """Apply the joint action and return the resulting transition.

        Arguments:
            actions (list): One action per agent, indexed in the same order
                as the agents captured at construction time.

        Returns:
            dict: A dictionary with the keys

            - ``"state"``: global state of the wrapped env after the step.
            - ``"obs"`` (list): new observation for each agent.
            - ``"rewards"`` (list): reward for each agent.
            - ``"absorbing"`` (bool): True if any agent is reported as
              terminated by the wrapped env.
            - ``"info"``: info returned by the wrapped env's step.

        """
        actions_dict = {agent: actions[i] for i, agent in enumerate(self._agents)}
        # The truncation flags are intentionally ignored here; only true
        # termination marks an absorbing state.
        # NOTE(review): episode length is presumably enforced by the caller
        # via `horizon` - confirm.
        observations, reward, terminated, _truncated, info = self._env.step(
            actions_dict
        )

        state = self._env.state()
        obs = [observations[agent] for agent in self._agents]
        rewards = [reward[agent] for agent in self._agents]

        absorbing = any(terminated.values())

        step = {
            "state": state,
            "obs": obs,
            "rewards": rewards,
            "absorbing": absorbing,
            "info": info,
        }
        return step

    def render(self, render_info):
        """Render the environment.

        Arguments:
            render_info: Unused by this wrapper; accepted so the method can
                be called with one argument.

        """
        self._env.render()

    def stop(self):
        """Close the environment"""
        self._env.close()
