from collections import OrderedDict

import gym
import numpy as np
from gym import spaces
from gym_minigrid.minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX


class MultiGridFullyObsWrapper(gym.core.ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding.

    Replaces ``obs['image']`` with ``env.grid.encode()`` of the whole grid
    (declared as a ``(width, height, 3)`` uint8 Box) and overlays the first
    agent (index 0) at its position as a red agent cell whose third channel
    is ``env.agent_dir[0]``. Agents other than index 0 are not overlaid.
    """

    def __init__(self, env):
        super().__init__(env)

        full_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # one 3-channel cell per grid square
            dtype='uint8'
        )

        # Build a new Dict space rather than writing into the existing one:
        # by default a gym wrapper exposes the wrapped env's own space object,
        # so mutating its ``.spaces`` in place would also change the inner
        # env's advertised image space. OrderedDict keeps the original key
        # order (gym's Dict may otherwise re-sort plain dict keys).
        new_spaces = OrderedDict(self.observation_space.spaces)
        new_spaces["image"] = full_image_space
        self.observation_space = spaces.Dict(new_spaces)

    def _process_obs(self, obs):
        """
        Overwrite ``obs['image']`` with the full-grid encoding.

        Args:
            obs: observation dict from the wrapped env; modified in place.

        Returns:
            The same ``obs`` dict with its ``'image'`` entry replaced.
        """
        env = self.unwrapped
        full_grid = env.grid.encode()

        # Note env.agent_pos is an array of length K, for K multigrid agents.
        # Only agent 0 is drawn; its position can be None (presumably when the
        # agent has not been placed yet), in which case no overlay is added.
        agent_pos = env.agent_pos[0]
        if agent_pos is not None:
            full_grid[agent_pos[0]][agent_pos[1]] = np.array([
                OBJECT_TO_IDX['agent'],
                COLOR_TO_IDX['red'],
                env.agent_dir[0]
            ])

        obs['image'] = full_grid

        return obs

    def observation(self, obs):
        """gym ObservationWrapper hook: return the fully observable obs."""
        return self._process_obs(obs)

    def seed(self, seed):
        """
        Seed the unwrapped env and post-process the value it returns.

        NOTE(review): the return value of ``env.seed(seed)`` is passed to
        ``_process_obs``, which indexes it as ``obs['image']``. This assumes
        the underlying multigrid env's ``seed`` returns an observation dict
        (a standard gym ``seed`` returns a list of seeds, which would fail
        here) — confirm against the env implementation.
        """
        env = self.unwrapped
        return self._process_obs(env.seed(seed))