import gym
from gym import spaces
import numpy as np
from gym_minigrid.minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX


class FixEnv(gym.core.Wrapper):
    """
    Wrapper that keeps two pieces of bookkeeping on the wrapped environment:
    - ``done``: False after construction and after every reset, then updated
      with the ``done`` value returned by each step.
    - ``partial_goal_obs``: the flag passed to the constructor.

    Both attributes are stored on ``self.env`` (the wrapped environment), not
    on this wrapper, and the sibling observation wrappers in this file read
    them to decide how to render an observation.

    NOTE(review): earlier documentation also listed "remake environment on
    each reset" and "make goal rewards binary"; neither is implemented in
    this class -- confirm whether that happens elsewhere.
    """

    def __init__(self, env, partial_goal_obs=True):
        """Wrap ``env`` and initialise the flags stored on it."""
        super().__init__(env)
        wrapped = self.env
        wrapped.partial_goal_obs = partial_goal_obs
        wrapped.done = False

    def reset(self):
        """Reset the wrapped environment and clear its ``done`` flag."""
        first_obs = self.env.reset()
        # Clear the flag only after the inner reset has completed.
        self.env.done = False
        return first_obs

    def step(self, action):
        """
        Forward ``action`` to the wrapped environment and record its ``done``.

        Returns ``(obs, reward, done, info)`` where ``info`` is always a new
        empty dict; the info returned by the wrapped environment is dropped.
        """
        obs, reward, done, _inner_info = self.env.step(action)
        self.env.done = done

        empty_info = {}
        return obs, reward, done, empty_info


class FixFullyObsWrapper(gym.core.ObservationWrapper):
    """
    Observation wrapper that replaces the 'image' entry with the encoded
    full grid, with the agent written into its own cell.

    When the wrapped environment reports ``done`` and has
    ``partial_goal_obs`` set, the incoming ``obs['image']`` is instead pasted
    into the middle of a zero-filled array of the full grid's shape.
    """

    def __init__(self, env):
        """Wrap ``env`` and declare the full-grid image space."""
        super().__init__(env)

        # One cell per grid position, three channels per cell.
        grid_shape = (self.env.width, self.env.height, 3)
        self.observation_space.spaces["image"] = spaces.Box(
            low=0, high=255, shape=grid_shape, dtype='uint8'
        )

    def observation(self, obs):
        """
        Build the replacement observation.

        Returns a dict with the original 'mission' and the new 'image'.
        """
        base_env = self.unwrapped
        encoded = base_env.grid.encode()

        # Overwrite the agent's cell with (agent object id, red, direction).
        col, row = base_env.agent_pos[0], base_env.agent_pos[1]
        encoded[col][row] = np.array([
            OBJECT_TO_IDX['agent'],
            COLOR_TO_IDX['red'],
            base_env.agent_dir,
        ])

        if not (self.env.done and self.env.partial_goal_obs):
            return {'mission': obs['mission'], 'image': encoded}

        # Episode finished and partial goal view requested: centre the
        # incoming view on a blank canvas the size of the full grid.
        view = obs['image']
        view_w, view_h = view.shape[0], view.shape[1]
        canvas = np.zeros(shape=encoded.shape, dtype=encoded.dtype)
        off_x = (canvas.shape[0] - view_w) // 2
        off_y = (canvas.shape[1] - view_h) // 2
        canvas[off_x:off_x + view_w, off_y:off_y + view_h, :] = view

        return {'mission': obs['mission'], 'image': canvas}

class FixRGBImgObsWrapper(gym.core.ObservationWrapper):
    """
    Observation wrapper that replaces the 'image' entry with an RGB picture.

    Normally the picture is the unwrapped environment's ``render`` output.
    When the wrapped environment reports ``done`` and has
    ``partial_goal_obs`` set, the picture is the rendering of the incoming
    ``obs['image']`` centred on a uniform background of value 100. The
    'mission' entry is passed through unchanged.
    """

    def __init__(self, env, tile_size=8):
        """Wrap ``env``; ``tile_size`` is the pixel size used per grid cell."""
        super().__init__(env)

        self.tile_size = tile_size

        pixel_shape = (
            self.env.width * tile_size,
            self.env.height * tile_size,
            3,
        )
        self.observation_space.spaces['image'] = spaces.Box(
            low=0, high=255, shape=pixel_shape, dtype='uint8'
        )

    def observation(self, obs):
        """
        Build the RGB observation.

        Side effect: after the image is produced, the grid cell at the
        agent's position is set to None on every call.
        """
        base_env = self.unwrapped

        show_partial_goal = self.env.done and self.env.partial_goal_obs
        if show_partial_goal:
            # Background canvas with the declared image shape, filled with 100.
            canvas_shape = self.observation_space.spaces['image'].shape
            rgb_img = np.zeros(canvas_shape, dtype='uint8') + 100
            img = self.env.get_obs_render(
                obs['image'],
                tile_size=self.tile_size
            )
            # Paste the rendered view in the centre of the canvas.
            view_h, view_w = img.shape[0], img.shape[1]
            top = (rgb_img.shape[0] - view_h) // 2
            left = (rgb_img.shape[1] - view_w) // 2
            rgb_img[top:top + view_h, left:left + view_w, :] = img
        else:
            rgb_img = base_env.render(
                mode='rgb_array',
                highlight=False,
                tile_size=self.tile_size
            )

        # NOTE(review): clears whatever the grid holds at the agent's cell;
        # an earlier, disabled variant wrote the carried object there before
        # rendering -- confirm the clearing is still required.
        self.env.grid.set(*self.env.agent_pos, None)

        return {'mission': obs['mission'], 'image': rgb_img}