from meltingpot import substrate
import dm_env
from gymnasium import spaces
import numpy as np
import tree
import cv2

from typing import List, Tuple, Dict

from envs.multiagentenv import MultiAgentEnv

def timestep_to_RGB_obs(timestep: "dm_env.TimeStep", reshape: Tuple[int, int] = None) -> np.ndarray:
    """Stack every agent's RGB observation into one float32 array.

    Args:
        timestep: Timestep whose ``observation`` is a sequence of per-agent
            mappings, each containing an ``"RGB"`` image.
        reshape: Optional target size. It is handed unchanged to
            ``cv2.resize`` as ``dsize``, which OpenCV interprets as
            ``(width, height)``; the resized frames therefore have shape
            ``(reshape[1], reshape[0], channels)``. A falsy value (``None``,
            the default) keeps the frames at their native resolution.

    Returns:
        A single ``np.ndarray`` of dtype float32 with the agents stacked
        along axis 0 (not a list of per-agent arrays).
    """
    frames = [np.array(agent_obs["RGB"], dtype=np.float32) for agent_obs in timestep.observation]
    if reshape:
        # INTER_AREA is OpenCV's recommended interpolation for downscaling.
        frames = [cv2.resize(frame, reshape, interpolation=cv2.INTER_AREA) for frame in frames]
    return np.stack(frames, axis=0)

def timestep_to_avail_actions(
    timestep: "dm_env.TimeStep",
    num_actions: int = 8,
    shoot_action: int = 7,
) -> np.ndarray:
    """Build a per-agent action-availability mask from ``READY_TO_SHOOT``.

    Every action is marked available (1) except ``shoot_action``, which is
    set to 0 for agents whose ``READY_TO_SHOOT`` observation is falsy.

    Args:
        timestep: Timestep whose ``observation`` is a sequence of per-agent
            mappings, each containing a ``"READY_TO_SHOOT"`` entry.
        num_actions: Size of the discrete action set (number of mask
            columns). Defaults to 8, the value previously hard-coded.
        shoot_action: Column index of the action gated by
            ``READY_TO_SHOOT``. Defaults to 7, the value previously
            hard-coded.

    Returns:
        Float array of shape ``(n_agents, num_actions)`` holding 1 for
        available actions and 0 for unavailable ones.

    Raises:
        ValueError: If ``shoot_action`` is not a valid column index for
            ``num_actions`` columns.
    """
    if not -num_actions <= shoot_action < num_actions:
        raise ValueError(
            f"shoot_action {shoot_action} is out of range for {num_actions} actions"
        )
    ready_to_shoot = np.array(
        [agent_obs["READY_TO_SHOOT"] for agent_obs in timestep.observation]
    ).astype(bool)
    avail_actions = np.ones((len(ready_to_shoot), num_actions))
    # Only the shoot action is ever masked out; everything else stays enabled.
    avail_actions[~ready_to_shoot, shoot_action] = 0
    return avail_actions

def remove_world_observations_from_space(observation: spaces.Dict) -> spaces.Dict:
    """Return a copy of a Dict space without its ``WORLD.`` entries.

    Args:
        observation: Dict space to filter.

    Returns:
        A new ``spaces.Dict`` containing every sub-space whose key does not
        contain the substring ``"WORLD."``.
    """
    kept = {}
    for name in observation:
        if "WORLD." in name:
            # Drop any entry whose key mentions the WORLD. namespace.
            continue
        kept[name] = observation[name]
    return spaces.Dict(kept)

def spec_to_space(spec: tree.Structure[dm_env.specs.Array]) -> spaces.Space:
    """Translate a (possibly nested) dm_env spec into the matching Gym space.

    Leaf specs map as follows: ``DiscreteArray`` becomes ``Discrete``,
    ``BoundedArray`` becomes a ``Box`` with the spec's own bounds, and a plain
    ``Array`` becomes a ``Box`` spanning the full range of its dtype
    (infinite for floating types). Lists/tuples and dicts are converted
    recursively into ``Tuple`` and ``Dict`` spaces.

    Args:
        spec: The nested structure of specs.

    Returns:
        The Gym space corresponding to the given spec.

    Raises:
        NotImplementedError: For a plain ``Array`` whose dtype is neither
            floating point nor integer.
        ValueError: For any object that is not a spec, list, tuple or dict.
    """
    # The spec classes are tested from most specific to most general, so the
    # order of these checks must not be changed.
    if isinstance(spec, dm_env.specs.DiscreteArray):
        return spaces.Discrete(spec.num_values)

    if isinstance(spec, dm_env.specs.BoundedArray):
        return spaces.Box(spec.minimum, spec.maximum, spec.shape, spec.dtype)

    if isinstance(spec, dm_env.specs.Array):
        dtype = spec.dtype
        if np.issubdtype(dtype, np.floating):
            low, high = -np.inf, np.inf
        elif np.issubdtype(dtype, np.integer):
            limits = np.iinfo(dtype)
            low, high = limits.min, limits.max
        else:
            raise NotImplementedError(f'Unsupported dtype {spec.dtype}')
        return spaces.Box(low, high, spec.shape, dtype)

    if isinstance(spec, (list, tuple)):
        return spaces.Tuple([spec_to_space(item) for item in spec])

    if isinstance(spec, dict):
        return spaces.Dict({name: spec_to_space(item) for name, item in spec.items()})

    raise ValueError('Unexpected spec of type {}: {}'.format(type(spec), spec))

class MeltingPotWrapper(MultiAgentEnv):
    """Expose a Melting Pot substrate through the ``MultiAgentEnv`` interface.

    ``reset`` and ``step`` both return a dict with the stacked per-agent RGB
    observations (``"obs"``), the raw ``"READY_TO_SHOOT"`` flags, an
    ``"agent_mask"``, ``"rewards"`` of shape ``(n_agents, 1)`` and the
    one-element boolean arrays ``"terminated"``, ``"truncated"`` and
    ``"is_first"``. When ``use_avail_actions`` is enabled the dict also
    carries ``"avail_actions"``.
    """

    def __init__(self, env_config: Dict):
        """Build the substrate and derive spaces and sizes from its specs.

        Args:
            env_config: Configuration mapping. Required keys: ``"substrate"``
                and ``"roles"`` (forwarded to ``substrate.build``) and
                ``"max_cycles"`` (step count after which episodes are
                flagged as truncated). Optional keys: ``"reshape"`` (size
                passed to ``cv2.resize`` as ``(width, height)``; default
                ``None``) and ``"use_avail_actions"`` (default ``False``).
        """
        self._env = substrate.build(env_config['substrate'], roles=env_config['roles'])
        observation_specs = self._env.observation_spec()
        self.n_agents = len(observation_specs)
        # All agents are assumed to share the spec of agent 0.
        self.observation_space = remove_world_observations_from_space(spec_to_space(observation_specs[0]))
        self.action_space = spec_to_space(self._env.action_spec()[0])
        self.reshape = env_config.get("reshape", None)
        if self.reshape:
            # cv2.resize takes dsize as (width, height) but returns arrays
            # shaped (height, width, channels), so the advertised shape must
            # swap the two entries to match the frames actually produced.
            width, height = self.reshape
            self.obs_shape = (height, width, 3)
        else:
            self.obs_shape = self.observation_space["RGB"].shape
        # int() accepts both numpy scalars and plain Python ints, whereas
        # .item() only exists on the former.
        self.n_actions = int(self.action_space.n)
        self.max_cycles = env_config["max_cycles"]
        self.use_avail_actions = env_config.get("use_avail_actions", False)
        # Defined up front so step() before reset() does not hit an
        # AttributeError on the counter.
        self.num_steps = 0

    def _make_result(self, timestep, rewards, terminated, truncated, is_first):
        """Assemble the dict returned by both ``reset`` and ``step``."""
        # NOTE: no global "state" entry is emitted; render() exposes the
        # world view if one is needed.
        result = {
            "obs": timestep_to_RGB_obs(timestep, self.reshape),
            "READY_TO_SHOOT": np.array(
                [agent_obs["READY_TO_SHOOT"] for agent_obs in timestep.observation]
            ),
            "agent_mask": self.get_agent_mask(),
            "rewards": rewards,
            "terminated": np.array([terminated], dtype=bool),
            "truncated": np.array([truncated], dtype=bool),
            "is_first": np.array([is_first], dtype=bool),
        }
        if self.use_avail_actions:
            # Only build the mask when it is actually going to be returned.
            result["avail_actions"] = timestep_to_avail_actions(timestep)
        return result

    def reset(self, *args, **kwargs):
        """Reset the substrate and return the initial result dict.

        Positional and keyword arguments are accepted for interface
        compatibility and ignored. Rewards are all zero and ``"is_first"``
        is ``True``.
        """
        timestep = self._env.reset()
        self.num_steps = 0
        return self._make_result(
            timestep,
            rewards=np.zeros((self.n_agents, 1), dtype=np.float32),
            terminated=False,
            truncated=False,
            is_first=True,
        )

    def step(self, actions):
        """Advance the substrate by one step.

        Args:
            actions: Per-agent actions, forwarded unchanged to the
                substrate's ``step``.

        Returns:
            The result dict. ``"terminated"`` mirrors ``timestep.last()``
            and ``"truncated"`` becomes ``True`` once ``max_cycles`` steps
            have been taken since the last reset.
        """
        timestep = self._env.step(actions)
        self.num_steps += 1

        # float32 keeps the dtype consistent with the zeros emitted by reset().
        rewards = np.asarray(timestep.reward, dtype=np.float32).reshape(self.n_agents, 1)
        return self._make_result(
            timestep,
            rewards=rewards,
            terminated=timestep.last(),
            truncated=self.num_steps >= self.max_cycles,
            is_first=False,
        )

    def get_avail_actions(self):
        """Stub: returns ``None``.

        Availability masks are delivered through the ``"avail_actions"``
        entry of the ``reset``/``step`` result when ``use_avail_actions`` is
        enabled.
        """
        pass

    def close(self):
        """Close the underlying substrate."""
        self._env.close()

    def get_agent_mask(self):
        """Return an all-ones float32 mask of shape ``(n_agents, 1)``."""
        return np.ones((self.n_agents, 1), dtype=np.float32)

    def render(self) -> np.ndarray:
        """Render the environment.

        This allows you to set `record_env` in your training config, to record
        videos of gameplay.

        Returns:
            np.ndarray: The ``WORLD.RGB`` observation of the first agent, an
            array with shape (x, y, 3) representing RGB values for an x-by-y
            pixel image, suitable for turning into a video.
        """
        observation = self._env.observation()
        # RGB mode is used for recording videos.
        return observation[0]['WORLD.RGB']
