import cv2
import gym
import numpy
import numpy as np

try:
    from gym_minigrid.minigrid import OBJECT_TO_IDX, COLOR_TO_IDX
except:
    OBJECT_TO_IDX, COLOR_TO_IDX=0,0

class MiniGridWrapper(gym.Wrapper):
    """Wrap a MiniGrid env so observations are channel-first image arrays.

    ``step`` and ``reset`` return only the ``"image"`` entry of the wrapped
    env's dict observation, transposed so that its last (channel) axis comes
    first. Extra helpers expose a full-grid encoding and rendered frames.
    """

    def __init__(self, env):
        super().__init__(env)

    @staticmethod
    def _to_channel_first(image):
        """Move the trailing axis of a 3-D array to the front: (a, b, c) -> (c, a, b)."""
        return np.transpose(image, (2, 0, 1))

    def step(self, action):
        """Step the wrapped env and return ``(image_chw, reward, done, info)``."""
        obs, reward, done, info = self.env.step(action)
        return self._to_channel_first(obs["image"]), reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env and return its channel-first ``"image"`` observation.

        Any keyword arguments are forwarded unchanged to the wrapped env's
        ``reset``; with no arguments the behaviour is a plain ``reset()``.
        """
        return self._to_channel_first(self.env.reset(**kwargs)["image"])

    def get_fullobs_img(self, size=160):
        """Render the whole environment and resize the frame to ``size`` x ``size``.

        Args:
            size: Output width and height in pixels (default 160).
        """
        # "rgb_array" is gym's conventional mode name for "return the frame
        # as an array"; "array" is not a standard mode.
        img = self.env.render("rgb_array")
        return cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)

    def get_obs(self):
        """Return the (non-transposed) ``"image"`` entry of ``env.gen_obs()``."""
        return self.env.gen_obs()["image"]

    def get_full_obs(self):
        """Return a channel-first encoding of the full grid with the agent drawn in.

        Each cell of ``grid.encode()`` holds 3 integers; the agent's cell is
        overwritten with (agent object index, red colour index, agent
        direction) before the result is transposed channel-first.

        Raises:
            ImportError: if ``gym_minigrid`` could not be imported at module
                load, so the object/colour lookup tables are unavailable.
        """
        # The module-level import fallback replaces both lookup tables with 0;
        # fail with a clear message instead of "'int' object is not subscriptable".
        if isinstance(OBJECT_TO_IDX, int) or isinstance(COLOR_TO_IDX, int):
            raise ImportError(
                "get_full_obs requires gym_minigrid "
                "(OBJECT_TO_IDX/COLOR_TO_IDX could not be imported)"
            )
        env = self.env.unwrapped
        full_grid = env.grid.encode()
        x, y = env.agent_pos
        full_grid[x][y] = np.array([
            OBJECT_TO_IDX['agent'],
            COLOR_TO_IDX['red'],
            env.agent_dir,
        ])
        return self._to_channel_first(full_grid)

    def get_partialobs_img(self, size=50):
        """Render the ``get_obs()`` image via ``env.get_obs_render``, resized to ``size`` x ``size``.

        Args:
            size: Output width and height in pixels (default 50).
        """
        img = self.env.get_obs_render(self.get_obs())
        return cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)

    def get_state_image(self):
        """Stack the full-env frame above the partial-view frame on a grey canvas.

        The canvas is as tall as both frames combined and as wide as the wider
        of the two; pixels not covered by either frame keep the value 192.
        Returns a ``uint8`` array of shape (height, width, 3).
        """
        full_img = self.get_fullobs_img()
        partial_img = self.get_partialobs_img()
        top_h = full_img.shape[0]
        height = top_h + partial_img.shape[0]
        # Use the wider of the two frames so neither is clipped or mis-broadcast.
        width = max(full_img.shape[1], partial_img.shape[1])
        canvas = np.full((height, width, 3), 192, dtype=np.uint8)
        canvas[:top_h, :full_img.shape[1], :] = full_img
        canvas[top_h:, :partial_img.shape[1], :] = partial_img
        return canvas




def make_minigrid(env_id, **kwargs):
    """Create the gym env ``env_id`` and wrap it in ``MiniGridWrapper``.

    Args:
        env_id: Registered gym environment id.
        **kwargs: Forwarded unchanged to ``gym.make``.

    Returns:
        The wrapped environment.
    """
    return MiniGridWrapper(gym.make(env_id, **kwargs))


