import pettingzoo.mpe as mpe
import numpy as np
import time
from gymnasium import spaces as gymnasium_spaces
from k_level_policy_gradients.src.core.environment import Environment, MDPInfo
from k_level_policy_gradients.src.utils.spaces import *


class MPE(Environment):
    """
    Interface for PettingZoo Multi Particle Environments (MPE).

    Any scenario listed in ``_SUPPORTED_ENVS`` can be created by name. The
    PettingZoo AEC (agent-by-agent) API is hidden behind a simultaneous-move
    interface: ``step`` takes one action per agent and returns one
    observation and one reward per agent, ordered like ``self.agent_ids``.
    """

    # Names of the PettingZoo MPE scenarios this interface can build.
    _SUPPORTED_ENVS = (
        "simple_v3",
        "simple_tag_v3",
        "simple_spread_v3",
        "simple_push_v3",
        "simple_adversary_v3",
    )

    def __init__(
        self,
        env_name,
        horizon=None,
        gamma=0.99,
        wrappers=None,
        wrappers_args=None,
        continuous_actions=True,
        bool_render=False,
        **env_args
    ):
        """
        Constructor.

        Args:
            env_name (str): name of the MPE scenario (e.g. "simple_spread_v3");
                must be one of ``MPE._SUPPORTED_ENVS``;
            horizon (int, None): the horizon. If None, use ``max_cycles`` from
                PettingZoo;
            gamma (float, 0.99): the discount factor;
            wrappers (list, None): list of wrappers to apply over the
                environment. It is possible to pass arguments to the wrappers
                by providing a tuple with two elements: the wrapper class and
                a dictionary containing the parameters needed by the wrapper
                constructor;
            wrappers_args (list, None): list of lists of positional arguments,
                one per wrapper;
            continuous_actions (bool, True): whether the agents use continuous
                (Box) or discrete actions;
            bool_render (bool, False): if True, create the environment with
                ``render_mode="human"``;
            **env_args: other keyword arguments forwarded to the MPE
                constructor. For "simple_spread_v3", ``N`` defaults to 2 and
                can be overridden here.

        Raises:
            ValueError: if ``env_name`` is not a supported MPE scenario.

        Agent ids are read from the PettingZoo environment after the first
        reset (e.g. "agent_0", "agent_1", ...).
        """
        if env_name not in self._SUPPORTED_ENVS:
            raise ValueError(f"Unknown MPE environment: {env_name}")

        # Work on a copy so scenario-specific defaults do not leak into the
        # kwargs that are later forwarded to plain wrappers.
        make_args = dict(env_args)
        if env_name == "simple_spread_v3":
            # Two agents by default (the previous hard-coded value), but the
            # caller may now choose a different N.
            make_args.setdefault("N", 2)

        render_mode = "human" if bool_render else None
        self.env = getattr(mpe, env_name).env(
            continuous_actions=continuous_actions,
            render_mode=render_mode,
            **make_args
        )

        self._first = True
        self._render_dt = getattr(self.env.unwrapped, "dt", 0.01)

        self.env.reset()
        self.num_agents = self.env.num_agents
        # Keep a stable copy: PettingZoo owns ``env.agents`` and may change
        # it between steps.
        self.agent_ids = list(self.env.agents)

        if wrappers is not None:
            if wrappers_args is None:
                wrappers_args = [()] * len(wrappers)
            for wrapper, args in zip(wrappers, wrappers_args):
                if isinstance(wrapper, tuple):
                    self.env = wrapper[0](self.env, *args, **wrapper[1])
                else:
                    # NOTE(review): the environment kwargs are also forwarded
                    # to plain wrappers (kept from the original behaviour);
                    # confirm this is intended.
                    self.env = wrapper(self.env, *args, **env_args)

        horizon = self._set_horizon(self.env, horizon)

        action_space = [
            self._convert_gymnasium_space(self.env.action_space(agent))
            for agent in self.agent_ids
        ]
        observation_space = [
            self._convert_gymnasium_space(self.env.observation_space(agent))
            for agent in self.agent_ids
        ]

        mdp_info = MDPInfo(
            observation_space, action_space, gamma, horizon, self._render_dt
        )

        super().__init__(mdp_info)

    def reset(self, state=None):
        """
        Reset the environment.

        Args:
            state (list, None): if given, it is stored on ``self.env.state``
                and returned as the initial per-agent observations.

        Returns:
            The list of per-agent initial observations (at least 1-D arrays).
        """
        self.env.reset()
        if state is None:
            observations = self._get_observations()
        else:
            # NOTE(review): this only assigns an attribute; nothing visible
            # here makes the MPE simulation start from ``state`` — confirm.
            self.env.state = state
            observations = state
        return [np.atleast_1d(obs) for obs in observations]

    def step(self, actions):
        """
        Apply one action per agent and advance the world by one time step.

        In PettingZoo's AEC API, ``self.env.step(action)`` applies the action
        of the currently selected agent and hands control to the next one;
        the world is simulated only once every agent has acted. The
        observations, rewards, terminations and infos are therefore read
        *after* the full cycle, so they describe the state reached by these
        actions (reading them with ``self.env.last()`` before each agent's
        step would return the previous time step's values).

        Args:
            actions (list): one action per agent, ordered like
                ``self.agent_ids``.

        Returns:
            The per-agent observations, the per-agent rewards, whether all
            agents are terminated (absorbing), and a dict mapping each agent
            id to its info dict.
        """
        action_clipped = self._clip_action(actions)

        # agent_iter(n) yields the selected agent at most n times, i.e. one
        # full cycle over the agents.
        for agent in self.env.agent_iter(self.num_agents):
            if self.env.terminations[agent] or self.env.truncations[agent]:
                # PettingZoo requires a None action for agents that are done.
                action = None
            else:
                index = self.agent_ids.index(agent)
                action = self._format_action(agent, action_clipped[index])
            self.env.step(action)

        observations = [np.atleast_1d(obs) for obs in self._get_observations()]
        rewards = self._get_rewards()
        absorbing = bool(np.all(self._get_terminations()))
        info_dict = dict(zip(self.agent_ids, self._get_infos()))
        return observations, rewards, absorbing, info_dict

    def render(self, mode="human"):
        """
        Render the environment.

        Only the first call renders explicitly and then sleeps for
        ``self._render_dt`` seconds; later calls do nothing. ``mode`` is
        accepted for interface compatibility and ignored.

        NOTE(review): this presumably relies on PettingZoo redrawing on each
        ``step`` when created with ``render_mode="human"`` — confirm.
        """
        if not self._first:
            return
        self.env.render()
        self._first = False
        time.sleep(self._render_dt)

    def stop(self):
        """Close the environment, ignoring errors raised while closing."""
        try:
            self.env.close()
        except Exception:
            # Best-effort cleanup: a failure here must not abort shutdown.
            pass

    def _clip_action(self, action):
        """
        Clip each agent's action to the bounds of its action space.

        Spaces that expose no ``low``/``high`` attributes are passed through
        unchanged.
        """
        action_clipped = []
        for i, a in enumerate(action):
            space = self.info.action_space[i]
            low = getattr(space, "low", None)
            high = getattr(space, "high", None)
            if low is None or high is None:
                action_clipped.append(a)
            else:
                action_clipped.append(np.clip(a, low, high))
        return action_clipped

    def _format_action(self, agent, action):
        """
        Convert one agent's action to the type its PettingZoo space accepts.

        Discrete actions become a Python ``int`` (gymnasium's
        ``Discrete.contains`` rejects float values); all other actions are
        cast to ``np.float32``.
        """
        if isinstance(self.env.action_space(agent), gymnasium_spaces.Discrete):
            return int(np.asarray(action).item())
        return np.float32(action)

    def _get_observations(self):
        """Return the current observation of every agent in ``agent_ids``."""
        return [self.env.observe(agent) for agent in self.agent_ids]

    def _get_rewards(self):
        """Return the last-step reward of every agent in ``agent_ids``."""
        return [self.env.rewards[agent] for agent in self.agent_ids]

    def _get_terminations(self):
        """Return the termination flag of every agent in ``agent_ids``."""
        return [self.env.terminations[agent] for agent in self.agent_ids]

    def _get_truncations(self):
        """Return the truncation flag of every agent in ``agent_ids``."""
        return [self.env.truncations[agent] for agent in self.agent_ids]

    def _get_infos(self):
        """Return the info dict of every agent in ``agent_ids``."""
        return [self.env.infos[agent] for agent in self.agent_ids]

    @staticmethod
    def _set_horizon(env, horizon):
        """
        Resolve the horizon and disable PettingZoo's own time limit.

        Args:
            env: the (possibly wrapped) PettingZoo environment;
            horizon (int, None): requested horizon; if None,
                ``env.max_cycles`` is used.

        Returns:
            The horizon to use.

        Raises:
            RuntimeError: if ``horizon`` is None and the environment has no
                ``max_cycles`` attribute.
        """
        if horizon is None:
            if not hasattr(env, "max_cycles"):
                raise RuntimeError("This MPE environment has no specified time limit!")
            horizon = env.max_cycles

        if hasattr(env, "max_cycles"):
            # Hack to ignore the PettingZoo time limit, so episode length is
            # governed by ``horizon`` only. PettingZoo wrappers forward
            # attribute reads to the wrapped env, but an assignment stays on
            # the wrapper object, so also patch the unwrapped environment,
            # which is the one that checks ``max_cycles``.
            env.max_cycles = np.inf
            base_env = getattr(env, "unwrapped", env)
            if base_env is not env:
                base_env.max_cycles = np.inf

        return horizon

    @staticmethod
    def _convert_gymnasium_space(space):
        """
        Convert a gymnasium ``Discrete``/``Box`` space to the project's space.

        Raises:
            ValueError: for any other gymnasium space type.
        """
        if isinstance(space, gymnasium_spaces.Discrete):
            return Discrete(space.n)
        elif isinstance(space, gymnasium_spaces.Box):
            return Box(low=space.low, high=space.high, shape=space.shape)
        else:
            raise ValueError(f"Unsupported gymnasium space: {space!r}")
