from gym_minigrid.minigrid import *
from gym_minigrid.roomgrid import RoomGrid
from gym_minigrid.register import register
from gym.spaces import Discrete
import itertools
import random

# Object kinds that can be spawned as targets or distractors.
OBJ_TYPES = [
    "key",
    "ball",
    "box",
]

# Colours an object can be given.
OBJ_COLORS = [
    "red",
    "blue",
    "green",
    "purple",
    "yellow",
]

def check_objs_reachable(env):
    """
    Check that all objects are reachable from the agent's starting
    position without requiring any other object to be moved
    (without unblocking).

    A flood fill is run from ``env.agent_pos`` over the horizontal and
    vertical neighbours of each cell. Empty cells and doors are walked
    through; any other occupied cell is recorded as reachable but is not
    expanded, because it blocks the way.

    Args:
        env: Object exposing ``agent_pos`` (an ``(i, j)`` pair) and a
            ``grid`` with ``width``, ``height`` and ``get(i, j)``.
            ``get`` is expected to return a falsy value for an empty
            cell, or an object carrying a ``type`` string.

    Returns:
        True if every occupied cell whose type is not ``'wall'`` was
        reached by the flood fill, False otherwise.
    """

    grid = env.grid

    # Reachable positions
    reachable = set()

    # Work list
    stack = [env.agent_pos]

    while stack:
        i, j = stack.pop()

        if i < 0 or i >= grid.width or j < 0 or j >= grid.height:
            continue

        if (i, j) in reachable:
            continue

        # This position is reachable
        reachable.add((i, j))

        cell = grid.get(i, j)

        # If there is something other than a door in this cell, it
        # blocks reachability. Compare by value: identity comparison
        # against a string literal depends on interning and is unreliable.
        if cell and cell.type != 'door':
            continue

        # Visit the horizontal and vertical neighbors
        stack.append((i + 1, j))
        stack.append((i - 1, j))
        stack.append((i, j + 1))
        stack.append((i, j - 1))

    # Check that all objects are reachable
    for i in range(grid.width):
        for j in range(grid.height):
            cell = grid.get(i, j)

            # Empty cells and walls are not objects that need reaching
            if not cell or cell.type == 'wall':
                continue

            if (i, j) not in reachable:
                return False

    # All objects reachable
    return True

def add_distractors(env, check_reachable=True):
    """
    Place ``env.num_dists`` distractor objects of random colour and type.

    Each distractor is placed with ``env.place_obj``. When
    ``check_reachable`` is true and the placement makes
    ``check_objs_reachable(env)`` fail, the object is removed from the
    grid and a fresh object of the same colour and type is placed again.
    Once the retry budget is exhausted the last placement is removed and
    the distractor is left out.

    Args:
        env: Environment providing ``num_dists``, ``place_obj`` and
            ``grid.set``.
        check_reachable: Whether to reject placements that require
            unblocking.
    """
    retry_limit = 100

    def spawn(color, kind):
        # The lookup table is built on demand, once per spawned object.
        factories = {'box': Box, 'ball': Ball, 'key': Key}
        return factories[kind](color)

    for _ in range(env.num_dists):
        color = random.choice(OBJ_COLORS)
        type_ = random.choice(OBJ_TYPES)
        pos = env.place_obj(spawn(color, type_))

        if not check_reachable:
            continue

        # make sure no unblocking is required
        retries = 0
        while not check_objs_reachable(env):
            env.grid.set(*pos, None)
            retries += 1
            if retries > retry_limit:
                break
            pos = env.place_obj(spawn(color, type_))


class PickUpObjEnv(MiniGridEnv):
    """
    Environment in which the agent has to pick up an object.

    A walled room is generated containing one object drawn from
    ``goals`` plus ``num_dists`` random distractors. The episode ends as
    soon as the agent is carrying an object: the reward is ``rmax`` if
    the carried object's ``(color, type)`` is in ``goals`` and ``rmin``
    otherwise. Every other step yields the reward ``r``.
    """

    def __init__(self, goals=None, dist_type=None, dist_color=None, num_dists=4, size=7, seed=None):
        """
        Args:
            goals: Sequence of ``(color, type)`` pairs that count as
                goals. ``None`` selects every combination of
                ``OBJ_COLORS`` and ``OBJ_TYPES``.
            dist_type: Stored on the instance; not otherwise used in
                this class.
            dist_color: Stored on the instance; not otherwise used in
                this class.
            num_dists: Number of distractor objects to add.
            size: Passed to the base class as ``grid_size``.
            seed: Accepted for interface compatibility.
                NOTE(review): this value is currently not forwarded to
                the base class nor used anywhere in this class —
                confirm whether it should be.
        """
        self.dist_type = dist_type
        self.dist_color = dist_color
        self.num_dists = num_dists
        self.same_terminal_states = True
        self.r = -0.1      # reward for a step that does not end the episode
        self.rmin = -0.1   # reward for picking up a non-goal object
        self.rmax = 2      # reward for picking up a goal object
        self.partial_goal_obs = True

        self.all_goals = list(itertools.product(OBJ_COLORS, OBJ_TYPES))
        self.goals = goals if goals is not None else self.all_goals

        super().__init__(
            grid_size=size,
            max_steps=float('inf'),
            # Set this to True for maximum speed
            see_through_walls=True
        )

    @staticmethod
    def _build_goal_obj(color, type_):
        """Instantiate the world object for a ``(color, type)`` pair."""
        constructors = {'box': Box, 'ball': Ball, 'key': Key}
        if type_ not in constructors:
            raise ValueError("unknown object type: {!r}".format(type_))
        return constructors[type_](color)

    def _gen_grid(self, width, height):
        """
        Build the room: surrounding walls, the agent, one goal object
        (when ``goals`` is non-empty) and the distractors.
        """
        self.grid = Grid(width, height)

        # Generate the surrounding walls
        self.grid.horz_wall(0, 0)
        self.grid.horz_wall(0, height - 1)
        self.grid.vert_wall(0, 0)
        self.grid.vert_wall(width - 1, 0)

        # Randomize the agent start position and orientation
        self.place_agent()

        # Add the goal object.
        # NOTE(review): the goal is drawn with the global `random`
        # module, which is presumably independent of the environment's
        # own seeding — confirm.
        if len(self.goals) > 0:
            color, type_ = random.choice(self.goals)
            self.place_obj(self._build_goal_obj(color, type_))

        # Add distractors
        add_distractors(self)

        self.mission = 'pickup object'

    def get_goals_imgs(self):
        """
        Render one image per goal, showing that goal object on a tile.

        Each goal object is put into cell (1, 1) in turn, the grid is
        rendered, and the tile at grid position (1, 1) is cropped out of
        the frame. The original content of the cell is put back before
        returning, also when rendering fails.

        Returns:
            A list with one cropped render per entry of ``self.goals``,
            in the same order.
        """
        tile_size = 32
        previous = self.grid.get(1, 1)
        images = []

        # NOTE(review): if the agent currently occupies (1, 1) the
        # render presumably draws it on top of the tile — confirm whether
        # callers need to handle that.
        try:
            for color, type_ in self.goals:
                # The target cell is fixed, so it is written directly
                # instead of going through random placement.
                self.grid.set(1, 1, self._build_goal_obj(color, type_))

                frame = self.render(
                    mode='rgb_array',
                    highlight=False,
                    tile_size=tile_size
                )
                images.append(
                    frame[tile_size:tile_size * 2, tile_size:tile_size * 2, :]
                )
        finally:
            # Always restore the cell, including to empty, so that no
            # goal object is left behind in the grid.
            self.grid.set(1, 1, previous)

        return images

    def _narrowed_image(self, full_image):
        """
        Build the observation image used when an object was picked up.

        An observation is generated with ``agent_view_size`` temporarily
        set to 3, everything except indices ``[1, 1:]`` of its first two
        axes is zeroed, and the result is pasted into a zero array with
        the shape and dtype of ``full_image``: centred along the first
        axis and flush with the end of the second axis.
        """
        saved_view_size = self.agent_view_size
        self.agent_view_size = 3
        try:
            small = self.gen_obs()['image']
        finally:
            # Restore the view size even if observation generation fails
            self.agent_view_size = saved_view_size

        # Presumably keeps only the agent's own cell and the cell in
        # front of it — TODO confirm against the observation layout.
        small[[0, 2], :, :] = 0
        small[:, 0, :] = 0

        padded = np.zeros(shape=full_image.shape, dtype=full_image.dtype)
        ox = (padded.shape[0] - small.shape[0]) // 2
        oy = padded.shape[1] - small.shape[1]
        padded[ox:ox + small.shape[0], oy:oy + small.shape[1], :] = small
        return padded

    def step(self, action):
        """
        Advance the environment by one action.

        The reward of the base class is discarded and replaced: ``r`` by
        default, or ``rmax`` / ``rmin`` once the agent is carrying an
        object, in which case the episode is also marked as done.

        Returns:
            The ``(obs, reward, done, info)`` tuple.
        """
        obs, _, done, info = super().step(action)
        reward = self.r

        carried = self.carrying
        # If successfully picked an object
        if carried:
            goal = (carried.color, carried.type)
            reward = self.rmax if goal in self.goals else self.rmin
            done = True

            # Narrow down goal space
            if self.partial_goal_obs:
                obs['image'] = self._narrowed_image(obs['image'])

        return obs, reward, done, info


# Register the environment under the id 'MiniGrid-PickUpObj-v0'.
# NOTE(review): the entry point assumes this module is importable as
# `envs.envs` — confirm against the package layout.
register(
    id='MiniGrid-PickUpObj-v0',
    entry_point='envs.envs:PickUpObjEnv'
)
