# The file is adapted from https://github.com/uoe-agents/epymarl/blob/main/src/envs/pz_wrapper.py

import gymnasium as gym
import pettingzoo
import numpy as np

import importlib
from typing import Tuple, List, Dict

from .multiagentenv import MultiAgentEnv

class PettingZooWrapper(MultiAgentEnv):
    """Adapt a PettingZoo parallel environment to the ``MultiAgentEnv`` interface.

    ``reset()`` and ``step()`` both return a dict of arrays stacked along a
    leading agent axis, ordered by ``possible_agents``. Per-agent rewards are
    summed and the total is broadcast to every agent (a shared team reward).
    """

    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 5,
    }

    def __init__(self, lib_name, env_name, seed, **kwargs):
        """Build ``pettingzoo.<lib_name>.<env_name>.parallel_env(**kwargs)`` and seed it.

        Args:
            lib_name: PettingZoo sub-package holding the environment module.
            env_name: Environment module name inside ``lib_name``.
            seed: Seed passed to the initial ``reset`` of the wrapped env.
            **kwargs: Forwarded verbatim to ``parallel_env``; must include
                ``max_cycles``.

        Raises:
            ValueError: If ``max_cycles`` is not supplied in ``kwargs``.
        """
        # Validate before importing/constructing the env; an ``assert`` would
        # be stripped under ``python -O``.
        if "max_cycles" not in kwargs:
            raise ValueError(
                "PettingZooWrapper requires a 'max_cycles' keyword argument"
            )
        env = importlib.import_module(f"pettingzoo.{lib_name}.{env_name}")
        self._env = env.parallel_env(**kwargs)
        self._env.reset(seed=seed)

        self.n_agents = int(self._env.num_agents)
        # presumably +1 counts the initial reset() entry (is_first=True) on top
        # of the max_cycles step() entries — confirm against the consumer.
        self.episode_limit = kwargs["max_cycles"] + 1

        agents = self._env.possible_agents
        self.action_space: List[gym.Space] = [self._env.action_space(a) for a in agents]
        self.observation_space: List[gym.Space] = [
            self._env.observation_space(a) for a in agents
        ]
        # Only agent 0 is inspected: assumes homogeneous Discrete action spaces
        # and identical observation shapes across agents.
        self.n_actions = int(self.action_space[0].n)
        self.obs_shape = self.observation_space[0].shape

    def _get_possible_agents_obs(self, obs_dict: Dict[str, np.ndarray]):
        """Return a tuple of float32 observations ordered by ``possible_agents``."""
        return tuple(
            np.array(obs_dict[agent], dtype=np.float32)
            for agent in self._env.possible_agents
        )

    def _pack(self, obss, rewards, terminated, truncated, is_first):
        """Assemble the result dict shared by ``reset()`` and ``step()``."""
        # np.bool_ rather than np.bool: the latter was removed in NumPy 1.24
        # and only re-added as an alias in NumPy 2.0.
        return {
            "obs": np.stack(obss, axis=0),
            "agent_mask": self.get_agent_mask(),
            "rewards": rewards,
            "terminated": np.array([terminated], dtype=np.bool_),
            "truncated": np.array([truncated], dtype=np.bool_),
            "is_first": np.array([is_first], dtype=np.bool_),
        }

    def reset(self, *args, **kwargs):
        """Reset the wrapped env and return the first result dict.

        Positional/keyword arguments are forwarded to the wrapped ``reset``.
        Rewards are zeros of shape ``(n_agents, 1)``, ``is_first`` is True, and
        both done flags are False.
        """
        obs_dict, _info_dict = self._env.reset(*args, **kwargs)
        return self._pack(
            self._get_possible_agents_obs(obs_dict),
            rewards=np.zeros((self.n_agents, 1), dtype=np.float32),
            terminated=False,
            truncated=False,
            is_first=True,
        )

    def get_agent_mask(self):
        """Return an all-ones ``(n_agents, 1)`` float32 mask (every agent active)."""
        return np.ones((self.n_agents, 1), dtype=np.float32)

    def render(self, mode="human"):
        """Render the wrapped environment.

        ``mode`` is accepted for backward compatibility but ignored: in the
        PettingZoo API this wrapper targets (``reset`` returns ``(obs, info)``,
        ``step`` returns a 5-tuple), ``render()`` takes no arguments and the
        mode is fixed by the ``render_mode`` passed at construction.
        """
        return self._env.render()

    # TODO: log info
    def step(self, actions):
        """Advance the wrapped env by one step.

        Args:
            actions: Sequence of per-agent actions ordered like
                ``possible_agents``.

        Returns:
            Result dict whose ``rewards`` entry is the sum of all agents'
            rewards broadcast to shape ``(n_agents, 1)``. ``terminated`` /
            ``truncated`` are True only when every agent reports so.
        """
        agents = self._env.possible_agents
        dict_actions = dict(zip(agents, actions))
        obs_dict, reward_dict, terminated_dict, truncated_dict, _info_dict = self._env.step(
            dict_actions
        )

        # NOTE(review): indexes every possible agent, so this assumes the env
        # keeps all agents present in each returned dict (no per-agent death).
        terminated = all(terminated_dict[a] for a in agents)
        truncated = all(truncated_dict[a] for a in agents)

        team_reward = np.array(
            [[reward_dict[a]] for a in agents], dtype=np.float32
        ).sum()
        return self._pack(
            self._get_possible_agents_obs(obs_dict),
            rewards=np.full((self.n_agents, 1), team_reward, dtype=np.float32),
            terminated=terminated,
            truncated=truncated,
            is_first=False,
        )

    def get_state_size(self) -> Tuple[int, ...]:
        """Return the shape of the wrapped env's global ``state_space``."""
        return self._env.state_space.shape

    def get_obs_size(self) -> Tuple[int, ...]:
        """Return the observation shape of the first agent."""
        return self.observation_space[0].shape

    def get_total_actions(self) -> List[int]:
        """Return the number of discrete actions for each agent."""
        return [space.n for space in self.action_space]

    def close(self):
        """Close the wrapped environment."""
        return self._env.close()
