from gym_minigrid.minigrid import *
from gym_minigrid.roomgrid import RoomGrid
from gym_minigrid.register import register
from gym.spaces import Discrete
import itertools
import random

# Object kinds used for goal objects and distractors in this module.
OBJ_TYPES = [
    "ball",
    "box",
]

# Color names used for objects, goal boxes and doors in this module.
OBJ_COLORS = [
    "red",
    "blue",
    "green",
    "grey",
    "purple",
    "yellow",
]

def check_objs_reachable(env):
    """
    Check that all objects are reachable from the agent's starting
    position without requiring any other object to be moved
    (without unblocking).

    A flood fill is run from ``env.agent_pos`` over horizontally and
    vertically adjacent cells. Empty cells and doors are expanded; any
    other occupied cell is recorded as reachable but is not expanded,
    because it blocks the way.

    :param env: environment exposing ``agent_pos`` and a ``grid`` that
        provides ``width``, ``height`` and ``get(i, j)``
    :return: ``True`` if every non-wall object was reached by the flood
        fill, ``False`` otherwise
    """

    grid = env.grid
    width, height = grid.width, grid.height

    # Reachable positions
    reachable = set()

    # Work list
    stack = [tuple(env.agent_pos)]

    while stack:
        i, j = stack.pop()

        if not (0 <= i < width and 0 <= j < height):
            continue

        if (i, j) in reachable:
            continue

        # This position is reachable
        reachable.add((i, j))

        cell = grid.get(i, j)

        # If there is something other than a door in this cell, it
        # blocks reachability. The type is compared by value: an identity
        # check on strings only works by accident of interning.
        if cell is not None and cell.type != 'door':
            continue

        # Visit the horizontal and vertical neighbors
        stack.extend(((i+1, j), (i-1, j), (i, j+1), (i, j-1)))

    # Check that all objects are reachable
    for i in range(width):
        for j in range(height):
            cell = grid.get(i, j)

            # Empty cells and walls do not need to be reachable
            if cell is None or cell.type == 'wall':
                continue

            if (i, j) not in reachable:
                return False

    # All objects reachable
    return True

def add_distractors(env, check_reachable=True):
    """
    Place ``env.num_dists`` randomly chosen distractor objects in the grid.

    Each distractor gets a random color from ``OBJ_COLORS`` and a random
    type from ``OBJ_TYPES``. When ``check_reachable`` is true, a placement
    that makes ``check_objs_reachable(env)`` fail is undone and retried at
    a new position. After the retry limit is exceeded the distractor is
    dropped, so fewer than ``env.num_dists`` objects may end up in the grid.

    :param env: environment providing ``num_dists``, ``place_obj`` and
        ``grid``
    :param check_reachable: if true, reject placements that leave some
        object unreachable without unblocking
    """
    # Constructor per type name. An unknown name raises KeyError instead
    # of silently reusing a stale object or hitting an unbound local.
    constructors = {'box': Box, 'ball': Ball, 'key': Key}
    max_retries = 100

    for _ in range(env.num_dists):
        color = random.choice(OBJ_COLORS)
        make_obj = constructors[random.choice(OBJ_TYPES)]

        pos = env.place_obj(make_obj(color))

        # make sure no unblocking is required
        retries = 0
        while check_reachable and not check_objs_reachable(env):
            # Undo the offending placement before trying again
            env.grid.set(*pos, None)
            retries += 1
            if retries > max_retries:
                # Give up on this distractor; it stays removed
                break
            pos = env.place_obj(make_obj(color))


class PickPlaceEnvV0(MiniGridEnv):
    """
    Environment in which the agent is instructed to place a ball in a box.

    A grey box is placed on each of the four interior corner cells and one
    ball is placed at a random position. Goals are ``(ball_color, box_pos)``
    pairs; an episode ends with a reward when the agent, while carrying an
    object, performs a drop facing one of the box positions paired with the
    current ball color.
    """

    def __init__(self, goals=None, dist_type=None, dist_color=None, num_dists=0, size=7):
        """
        :param goals: list of ``(ball_color, box_pos)`` pairs; all
            combinations are used when empty or ``None``
        :param dist_type: stored on the instance, not used by this class
        :param dist_color: stored on the instance, not used by this class
        :param num_dists: number of distractor objects to add
        :param size: side length of the square grid, walls included
        """
        self.dist_type = dist_type
        self.dist_color = dist_color
        self.num_dists = num_dists

        # The four interior corner cells host the boxes
        self.box_poss = [(1, 1), (size-2, 1), (1, size-2), (size-2, size-2)]
        self.all_goals = list(itertools.product(OBJ_COLORS, self.box_poss))
        self.goals = goals if goals else self.all_goals

        super().__init__(
            grid_size=size,
            max_steps=float('inf'),
            # Set this to True for maximum speed
            see_through_walls=True
        )

        # Only the first five discrete actions are exposed
        self.action_space = Discrete(5)

    def _gen_grid(self, width, height):
        """Build the walled grid and place boxes, agent, ball and distractors."""
        self.grid = Grid(width, height)

        # Generate the surrounding walls
        self.grid.horz_wall(0, 0)
        self.grid.horz_wall(0, height-1)
        self.grid.vert_wall(0, 0)
        self.grid.vert_wall(width-1, 0)

        # Pin the boxes to their corner cells before anything is placed at
        # random. Each box is restricted to a one-cell region and no retry
        # limit is passed, so if the agent or the ball already occupied a
        # corner, place_obj could never find a free cell there.
        for pos in self.box_poss:
            self.place_obj(Box('grey'), top=pos, size=(1, 1))

        # Randomize the agent start position and orientation
        self.place_agent()

        # Pick the ball color from a random goal and place the ball
        self.ball = self.goals[random.randint(0, len(self.goals)-1)][0]
        self.place_obj(Ball(self.ball))

        # Box positions that count as targets for the chosen ball color
        self.boxes = [goal[1] for goal in self.goals if goal[0] == self.ball]

        # Add distractors
        add_distractors(self)

        self.mission = 'put the ball in the box'

    def step(self, action):
        """
        Advance the environment and reward a drop attempted in front of a
        target box position.

        :return: ``(obs, reward, done, info)`` from the parent ``step``,
            with ``reward`` and ``done`` overridden on success
        """
        # Remember whether something was held before the action is applied
        held_before = self.carrying

        obs, reward, done, info = super().step(action)

        # If successfully dropping an object near the target
        if action == self.actions.drop and held_before:
            u, v = self.dir_vec
            ox, oy = (self.agent_pos[0] + u, self.agent_pos[1] + v)
            if (ox, oy) in self.boxes:
                reward = self._reward()
                done = True

        return obs, reward, done, info


class PickPlaceEnvV1(MiniGridEnv):
    """
    Environment in which the agent is instructed to place a ball in a box.

    Goals are ``(ball_color, box_color)`` pairs. One ball and one box are
    placed at random positions; the episode ends with a reward when the
    agent, while carrying an object, performs a drop facing a box whose
    color is paired with the current ball color.
    """

    def __init__(self, goals=None, dist_type=None, dist_color=None, num_dists=0, size=7):
        """
        :param goals: list of ``(ball_color, box_color)`` pairs; all
            combinations are used when empty or ``None``
        :param dist_type: stored on the instance, not used by this class
        :param dist_color: stored on the instance, not used by this class
        :param num_dists: number of distractor objects to add
        :param size: side length of the square grid, walls included
        """
        self.dist_type = dist_type
        self.dist_color = dist_color
        self.num_dists = num_dists

        self.all_goals = list(itertools.product(OBJ_COLORS, OBJ_COLORS))
        if goals:
            self.goals = goals
        else:
            self.goals = self.all_goals

        super().__init__(
            grid_size=size,
            max_steps=float('inf'),
            # Set this to True for maximum speed
            see_through_walls=True
        )

        self.action_space = Discrete(5)

    def _gen_grid(self, width, height):
        """Build the walled grid and place agent, ball, box and distractors."""
        self.grid = Grid(width, height)

        # Surround the grid with walls: top, bottom, left, right
        self.grid.horz_wall(0, 0)
        self.grid.horz_wall(0, height-1)
        self.grid.vert_wall(0, 0)
        self.grid.vert_wall(width-1, 0)

        # Random agent start position and orientation
        self.place_agent()

        # The ball color comes from a randomly drawn goal
        goal_index = random.randint(0, len(self.goals)-1)
        self.ball = self.goals[goal_index][0]
        self.place_obj(Ball(self.ball))

        # Box colors paired with that ball color; only the first is placed
        self.boxes = [box_color for ball_color, box_color in
                      ((goal[0], goal[1]) for goal in self.goals)
                      if ball_color == self.ball]
        self.place_obj(Box(self.boxes[0]))

        # Add distractors
        add_distractors(self)

        self.mission = 'put the ball in the box'

    def step(self, action):
        """
        Advance the environment and reward a drop attempted in front of a
        box with a target color.

        :return: ``(obs, reward, done, info)`` from the parent ``step``,
            with ``reward`` and ``done`` overridden on success
        """
        held_before = self.carrying

        obs, reward, done, info = super().step(action)

        # Nothing more to check unless a drop was attempted while holding
        # something
        if not (action == self.actions.drop and held_before):
            return obs, reward, done, info

        dx, dy = self.dir_vec
        front_x = self.agent_pos[0] + dx
        front_y = self.agent_pos[1] + dy
        front = self.grid.get(front_x, front_y)

        if isinstance(front, Box) and front.color in self.boxes:
            reward = self._reward()
            done = True

        return obs, reward, done, info


class PickUpObjinRoom(RoomGrid):
    """
    Environment in which the agent is instructed to pick up balls or boxes
    in rooms.

    The rooms of the middle column are joined into a hallway and doors are
    added to the rooms of columns 1 and 2. A goal is a
    ``(color, type, room)`` triple where ``room`` is a ``(column, row)``
    room index, or ``None`` for an object placed in the hallway column.
    """
    def __init__(self, goals=None, dist_type=None, dist_color=None, num_dists=0, size=4, num_rows=3, seed=None):
        """
        :param goals: list of ``(color, type, room)`` goals; all
            combinations are used when empty or ``None``
        :param dist_type: stored on the instance, not used by this class
        :param dist_color: stored on the instance, not used by this class
        :param num_dists: number of distractor objects to add
        :param size: room size passed to ``RoomGrid``
        :param num_rows: number of room rows
        :param seed: seed passed to ``RoomGrid``
        """
        self.goals = goals
        self.dist_type = dist_type
        self.dist_color = dist_color
        self.num_dists = num_dists
        self.partial_goal_obs = True
        self.automatic_doors = True
        self.same_terminal_states = True

        # Room indices follow num_rows instead of a hard-coded 3 so that
        # no door or goal refers to a row that does not exist
        rows = range(num_rows)
        self.doors_colors = OBJ_COLORS.copy()
        self.doors_pos = [(x, y) for x in [1, 2] for y in rows]
        self.hallway = [(1, y) for y in rows]
        # Same ordering as doors_pos, so a room's list index is also the
        # index of the door leading to it. None stands for the hallway.
        self.rooms = [(x, y) for x in [0, 2] for y in rows] + [None]
        self.all_goals = list(itertools.product(OBJ_COLORS, OBJ_TYPES, self.rooms))
        self.goals = self.goals if self.goals else self.all_goals

        super().__init__(
            room_size=size,
            num_rows=num_rows,
            max_steps=float('inf'),
            seed=seed,
        )
        if self.automatic_doors:
            # Doors are operated automatically, so fewer actions are exposed
            self.action_space = Discrete(5)

    def _gen_grid(self, width, height):
        """Build the rooms, doors, agent, goal object and distractors."""
        super()._gen_grid(width, height)

        # Connect the middle column rooms into a hallway
        for j in range(1, self.num_rows):
            self.remove_wall(1, j, 3)

        # Add doors, with colors re-shuffled on every generation. Colors
        # repeat cyclically if there are more doors than colors.
        np.random.shuffle(self.doors_colors)
        num_colors = len(self.doors_colors)
        for idx, (col, row) in enumerate(self.doors_pos):
            self.add_door(col, row, 2, color=self.doors_colors[idx % num_colors], locked=False)

        self.balls = []
        self.boxes = []

        # Place the agent in the middle
        self.place_agent(1, self.num_rows // 2)

        # Add objects
        self.goal = self.goals[random.randint(0, len(self.goals)-1)]
        color, kind, room = self.goal[0], self.goal[1], self.goal[2]
        if room:
            self.add_object(room[0], room[1], kind, color)
        else:
            # Hallway goal: try a random hallway room first
            try:
                self.add_object(1, np.random.randint(self.num_rows), kind, color)
            except Exception:
                # Placement failed; fall back to a fixed hallway cell.
                # NOTE(review): agent_pos is compared with the literal
                # (1, 0); confirm that this is the intended cell.
                if not all(np.equal(self.agent_pos, (1, 0))):
                    top = (self.room_size, 1)
                else:
                    top = (self.room_size, self.room_size*(self.num_rows-1))
                obj = Ball(color) if kind == 'ball' else Box(color)
                self.place_obj(obj, top=top, size=(1, 1))

        # Add distractors
        add_distractors(self)

        self.mission = 'put * balls from rooms * in the * boxes in rooms *'

    def _update_automatic_doors(self):
        """Close open doors next to the agent, then open the one it faces."""
        ax, ay = self.agent_pos
        u, v = self.dir_vec

        # The four horizontal and vertical neighbors of the agent
        offsets = ((1, 0), (0, 1), (-1, 0), (0, -1))
        neighbors = [self.grid.get(ax + dx, ay + dy) for dx, dy in offsets]
        facing = self.grid.get(ax + u, ay + v)

        for cell in neighbors:
            if isinstance(cell, Door) and cell.is_open:
                cell.is_open = False

        # The faced cell is one of the neighbors, so a faced door was just
        # closed above and is (re)opened here
        if isinstance(facing, Door) and not facing.is_open:
            facing.is_open = True

    def step(self, action):
        """
        Advance the environment, operate the automatic doors and end the
        episode once an object has been picked up.

        While an object is carried, the reward is reset to 0 and set to
        ``self._reward()`` only if the carried object's
        ``(color, type, room)`` matches one of the goals. With
        ``partial_goal_obs`` the observation image is then replaced by a
        zeroed image containing only a 3x3 crop around the agent.

        :return: ``(obs, reward, done, info)``
        """
        obs, reward, done, info = super().step(action)

        if self.automatic_doors:
            self._update_automatic_doors()

        # If successfully picked an object
        if self.carrying:
            obj = self.carrying
            u, v = self.dir_vec
            ox, oy = (self.agent_pos[0] + u, self.agent_pos[1] + v)
            # Room index of the cell in front of the agent. Plain ints keep
            # str(pos) in the "(i, j)" form used by the comparison below.
            i = int(ox) // (self.room_size-1)
            j = int(oy) // (self.room_size-1)

            pos = (i, j) if (i, j) in self.rooms else None
            goal = [obj.color, obj.type, pos]
            reward = 0
            if self.same_terminal_states:
                done = True
            for g in self.goals:
                if str(g[0]) == str(goal[0]) and str(g[1]) == str(goal[1]) and str(g[2]) == str(goal[2]):
                    reward = self._reward()
                    done = True
                    break

            ### Narrow down goal space
            if self.partial_goal_obs:
                # Temporarily shrink the view to get a 3x3 observation;
                # the original size is restored even if gen_obs fails
                full_view_size = self.agent_view_size
                self.agent_view_size = 3
                try:
                    image_ = self.gen_obs()['image']
                finally:
                    self.agent_view_size = full_view_size

                # Blank the side columns and the far row of the crop
                image_[[0, 2], :, :] = 0
                image_[:, 0, :] = 0
                if (i, j) in self.rooms:
                    # Encode the room as a door with that room's door color
                    door_idx = self.rooms.index((i, j)) % len(self.doors_colors)
                    color = COLOR_TO_IDX[self.doors_colors[door_idx]]
                    image_[1, 1] = [OBJECT_TO_IDX['door'], color, 1]

                # Paste the crop, horizontally centered and aligned with
                # the end of the second axis, into an empty full-size image
                image = np.zeros(shape=obs['image'].shape, dtype=obs['image'].dtype)
                ox = (image.shape[0]-image_.shape[0])//2
                oy = image.shape[1]-image_.shape[1]
                image[ox:ox+image_.shape[0], oy:oy+image_.shape[1], :] = image_
                obs['image'] = image

        return obs, reward, done, info


class GoToObjinRoom(RoomGrid):
    """
    Environment in which the agent is instructed to goto balls or boxes in rooms.

    The rooms of the middle column are joined into a hallway and doors are
    added to the rooms of columns 1 and 2. A goal is a
    ``(color, type, room)`` triple where ``room`` is a ``(column, row)``
    room index, or ``None`` for an object placed in the hallway column.
    """
    def __init__(self, goals=None, dist_type=None, dist_color=None, num_dists=0, size=4, num_rows=3, seed=None):
        """
        :param goals: list of ``(color, type, room)`` goals; all
            combinations are used when empty or ``None``
        :param dist_type: stored on the instance, not used by this class
        :param dist_color: stored on the instance, not used by this class
        :param num_dists: number of distractor objects to add
        :param size: room size passed to ``RoomGrid``
        :param num_rows: number of room rows
        :param seed: seed passed to ``RoomGrid``
        """
        self.goals = goals
        self.dist_type = dist_type
        self.dist_color = dist_color
        self.num_dists = num_dists
        self.partial_goal_obs = True
        self.automatic_doors = True
        self.same_terminal_states = False

        # Room indices follow num_rows instead of a hard-coded 3 so that
        # no door or goal refers to a row that does not exist
        rows = range(num_rows)
        self.doors_colors = OBJ_COLORS.copy()
        self.doors_pos = [(x, y) for x in [1, 2] for y in rows]
        self.hallway = [(1, y) for y in rows]
        # Same ordering as doors_pos, so a room's list index is also the
        # index of the door leading to it. None stands for the hallway.
        self.rooms = [(x, y) for x in [0, 2] for y in rows] + [None]
        self.all_goals = list(itertools.product(OBJ_COLORS, OBJ_TYPES, self.rooms))
        self.goals = self.goals if self.goals else self.all_goals

        super().__init__(
            room_size=size,
            num_rows=num_rows,
            max_steps=float('inf'),
            seed=seed,
        )
        if self.automatic_doors:
            # Doors are operated automatically and nothing is picked up,
            # so only three actions are exposed
            self.action_space = Discrete(3)

    def _gen_grid(self, width, height):
        """Build the rooms, doors, agent, goal object and distractors."""
        super()._gen_grid(width, height)

        # Connect the middle column rooms into a hallway
        for j in range(1, self.num_rows):
            self.remove_wall(1, j, 3)

        # Add doors, with colors re-shuffled on every generation. Colors
        # repeat cyclically if there are more doors than colors.
        np.random.shuffle(self.doors_colors)
        num_colors = len(self.doors_colors)
        for idx, (col, row) in enumerate(self.doors_pos):
            self.add_door(col, row, 2, color=self.doors_colors[idx % num_colors], locked=False)

        self.balls = []
        self.boxes = []

        # Place the agent in the middle
        self.place_agent(1, self.num_rows // 2)

        # Add objects
        self.goal = self.goals[random.randint(0, len(self.goals)-1)]
        color, kind, room = self.goal[0], self.goal[1], self.goal[2]
        if room:
            self.add_object(room[0], room[1], kind, color)
        else:
            # Hallway goal: try a random hallway room first
            try:
                self.add_object(1, np.random.randint(self.num_rows), kind, color)
            except Exception:
                # Placement failed; fall back to a fixed hallway cell.
                # NOTE(review): agent_pos is compared with the literal
                # (1, 0); confirm that this is the intended cell.
                if not all(np.equal(self.agent_pos, (1, 0))):
                    top = (self.room_size, 1)
                else:
                    top = (self.room_size, self.room_size*(self.num_rows-1))
                obj = Ball(color) if kind == 'ball' else Box(color)
                self.place_obj(obj, top=top, size=(1, 1))

        # Add distractors
        add_distractors(self)

        self.mission = 'put * balls from rooms * in the * boxes in rooms *'

    def _update_automatic_doors(self):
        """Close open doors next to the agent, then open the one it faces."""
        ax, ay = self.agent_pos
        u, v = self.dir_vec

        # The four horizontal and vertical neighbors of the agent
        offsets = ((1, 0), (0, 1), (-1, 0), (0, -1))
        neighbors = [self.grid.get(ax + dx, ay + dy) for dx, dy in offsets]
        facing = self.grid.get(ax + u, ay + v)

        for cell in neighbors:
            if isinstance(cell, Door) and cell.is_open:
                cell.is_open = False

        # The faced cell is one of the neighbors, so a faced door was just
        # closed above and is (re)opened here
        if isinstance(facing, Door) and not facing.is_open:
            facing.is_open = True

    def step(self, action):
        """
        Advance the environment, operate the automatic doors and detect a
        forward move attempted into a ball or box.

        When the agent moves forward against a ball or box, the reward is
        reset to 0 and set to ``self._reward()`` (ending the episode) only
        if the object's ``(color, type, room)`` matches one of the goals.
        With ``partial_goal_obs`` the observation image is then replaced by
        a zeroed image containing only a 3x3 crop around the agent.

        :return: ``(obs, reward, done, info)``
        """
        # Look at the cell in front before the action is applied
        u, v = self.dir_vec
        ox, oy = (self.agent_pos[0] + u, self.agent_pos[1] + v)
        obj = self.grid.get(ox, oy)

        obs, reward, done, info = super().step(action)

        # Automatic opening and closing doors
        if self.automatic_doors:
            self._update_automatic_doors()

        # Cell in front of the agent after the action
        u, v = self.dir_vec
        ox, oy = (self.agent_pos[0] + u, self.agent_pos[1] + v)

        # If successfully reached an object
        if action == self.actions.forward and isinstance(obj, WorldObj) and obj.type in OBJ_TYPES:
            # Room index of the faced cell. Plain ints keep str(pos) in
            # the "(i, j)" form used by the comparison below.
            i = int(ox) // (self.room_size-1)
            j = int(oy) // (self.room_size-1)
            pos = (i, j) if (i, j) in self.rooms else None
            goal = [obj.color, obj.type, pos]
            reward = 0
            if self.same_terminal_states:
                done = True
            for g in self.goals:
                if str(g[0]) == str(goal[0]) and str(g[1]) == str(goal[1]) and str(g[2]) == str(goal[2]):
                    reward = self._reward()
                    done = True
                    break

            ### Narrow down goal space
            if self.partial_goal_obs:
                # Temporarily shrink the view to get a 3x3 observation;
                # the original size is restored even if gen_obs fails
                full_view_size = self.agent_view_size
                self.agent_view_size = 3
                try:
                    image_ = self.gen_obs()['image']
                finally:
                    self.agent_view_size = full_view_size

                # Blank the side columns and the far row of the crop
                image_[[0, 2], :, :] = 0
                image_[:, 0, :] = 0
                if (i, j) in self.rooms:
                    # Encode the room as a door with that room's door color
                    door_idx = self.rooms.index((i, j)) % len(self.doors_colors)
                    color = COLOR_TO_IDX[self.doors_colors[door_idx]]
                    image_[1, 0] = [OBJECT_TO_IDX['door'], color, 1]

                # Paste the crop, horizontally centered and aligned with
                # the end of the second axis, into an empty full-size image
                image = np.zeros(shape=obs['image'].shape, dtype=obs['image'].dtype)
                ox = (image.shape[0]-image_.shape[0])//2
                oy = image.shape[1]-image_.shape[1]
                image[ox:ox+image_.shape[0], oy:oy+image_.shape[1], :] = image_
                obs['image'] = image

        return obs, reward, done, info



class RoomsCorridor(RoomGrid):
    """
    Environment in which the agent is instructed to place a ball in a box.

    The rooms of the middle column are joined into a hallway. A goal is a
    pair ``((ball_room, ball_color), (box_room, box_color))`` where rooms
    are ``(column, row)`` room indices in the two outer columns. The
    episode ends when the agent, still holding an object, performs a drop
    facing a box; the reward is non-zero only if the pick-up room, carried
    color, box room and box color together form one of the goals.
    """

    def __init__(self, goals=None, dist_type=None, dist_color=None, num_dists=0, size=7, num_rows=3, seed=None):
        """
        :param goals: list of ``((ball_room, ball_color),
            (box_room, box_color))`` goals; all combinations are used when
            empty or ``None``
        :param dist_type: stored on the instance, not used by this class
        :param dist_color: stored on the instance, not used by this class
        :param num_dists: number of additional ball/box pairs to add
        :param size: room size passed to ``RoomGrid``
        :param num_rows: number of room rows
        :param seed: seed passed to ``RoomGrid``
        """
        self.dist_type = dist_type
        self.dist_color = dist_color
        self.num_dists = num_dists

        # Follow num_rows so that presets with fewer than three rows never
        # reference a room that does not exist
        self.rooms = [(x, y) for x in [0, 2] for y in range(num_rows)]
        balls = list(itertools.product(self.rooms, OBJ_COLORS))
        boxes = list(itertools.product(self.rooms, OBJ_COLORS))
        self.all_goals = list(itertools.product(balls, boxes))
        self.goals = goals if goals else self.all_goals

        # Room index in which the carried object was picked up
        self.carrying_room = None

        super().__init__(
            room_size=size,
            num_rows=num_rows,
            max_steps=float('inf'),
            seed=seed,
        )

    def _room_index(self, x, y):
        """Return the ``(column, row)`` room index containing cell (x, y)."""
        # Same cell-to-room arithmetic as the other room environments here
        span = self.room_size - 1
        return (int(x) // span, int(y) // span)

    @staticmethod
    def _goal_key(goal):
        """Normalize a goal to nested tuples so it compares by value."""
        return (
            (tuple(goal[0][0]), goal[0][1]),
            (tuple(goal[1][0]), goal[1][1]),
        )

    def _add_goal_pair(self, goal):
        """Place the ball and the box described by ``goal``.

        :return: ``(ball, ball_pos, box, box_pos)``
        """
        ball, ball_pos = self.add_object(goal[0][0][0], goal[0][0][1], 'ball', goal[0][1])
        box, box_pos = self.add_object(goal[1][0][0], goal[1][0][1], 'box', goal[1][1])
        return ball, ball_pos, box, box_pos

    def _gen_grid(self, width, height):
        """Build the rooms, doors, agent, goal pair and distractor pairs."""
        super()._gen_grid(width, height)

        # A new layout starts with nothing carried
        self.carrying_room = None

        # Connect the middle column rooms into a hallway
        for j in range(1, self.num_rows):
            self.remove_wall(1, j, 3)

        # The result of this draw was never used; the call is kept only so
        # that seeded layouts stay the same as before.
        self._rand_int(0, self.num_rows)

        # Add doors to every room of columns 1 and 2. With three rows the
        # colors are OBJ_COLORS[0..2] for column 1 and [3..5] for column 2.
        for k, col in enumerate((1, 2)):
            for row in range(self.num_rows):
                color = OBJ_COLORS[(k * self.num_rows + row) % len(OBJ_COLORS)]
                self.add_door(col, row, 2, color=color, locked=False)

        # Place the agent in the middle
        self.place_agent(1, self.num_rows // 2)

        self.balls = []
        self.boxes = []

        # Add objects
        self.goal = self.goals[random.randint(0, len(self.goals)-1)]
        self._add_goal_pair(self.goal)

        # NOTE(review): self.goal is overwritten by every distractor draw,
        # as before; confirm whether it should keep the first goal instead.
        max_retries = 100
        for _ in range(self.num_dists):
            self.goal = self.goals[random.randint(0, len(self.goals)-1)]
            ball, pos1, box, pos2 = self._add_goal_pair(self.goal)

            # make sure no unblocking is required
            placed = True
            retries = 0
            while not check_objs_reachable(self):
                self.grid.set(*pos1, None)
                self.grid.set(*pos2, None)
                retries += 1
                if retries > max_retries:
                    # Give up: the pair stays removed from the grid
                    placed = False
                    break
                self.goal = self.goals[random.randint(0, len(self.goals)-1)]
                ball, pos1, box, pos2 = self._add_goal_pair(self.goal)

            # Only record objects that are actually in the grid
            if placed:
                self.balls.append(ball)
                self.boxes.append(box)

        # Add distractors
        # add_distractors(self)

        self.mission = 'put * balls from rooms * in the * boxes in rooms *'

    def step(self, action):
        """
        Advance the environment, track the pick-up room and score a drop
        attempted in front of a box.

        :return: ``(obs, reward, done, info)``; when a drop is attempted
            facing a box while still holding an object, ``done`` is true
            and ``reward`` is ``self._reward()`` for a matching goal,
            otherwise 0
        """
        held_before = self.carrying

        obs, reward, done, info = super().step(action)

        # Cell in front of the agent
        u, v = self.dir_vec
        ox, oy = (self.agent_pos[0] + u, self.agent_pos[1] + v)

        if not self.carrying:
            self.carrying_room = None
        elif not held_before:
            # Something was just picked up from the cell in front
            self.carrying_room = self._room_index(ox, oy)

        # If successfully dropping an object near the target
        if action == self.actions.drop and self.carrying:
            obj = self.grid.get(ox, oy)
            if isinstance(obj, Box):
                attempt = (
                    (self.carrying_room, self.carrying.color),
                    (self._room_index(ox, oy), obj.color),
                )
                # Compare normalized tuples; a list never equals a tuple
                if any(self._goal_key(g) == attempt for g in self.goals):
                    reward = self._reward()
                else:
                    reward = 0
                done = True

        return obs, reward, done, info

class RoomsCorridorS3R1(RoomsCorridor):
    """RoomsCorridor preset with ``size=3`` and ``num_rows=1``."""

    def __init__(self, goals=None, dist_type=None, dist_color=None, num_dists=0, seed=None):
        layout = dict(size=3, num_rows=1)
        super().__init__(
            goals=goals,
            dist_type=dist_type,
            dist_color=dist_color,
            num_dists=num_dists,
            seed=seed,
            **layout
        )

class RoomsCorridorS3R2(RoomsCorridor):
    """RoomsCorridor preset with ``size=3`` and ``num_rows=2``."""

    def __init__(self, goals=None, dist_type=None, dist_color=None, num_dists=0, seed=None):
        layout = dict(size=3, num_rows=2)
        super().__init__(
            goals=goals,
            dist_type=dist_type,
            dist_color=dist_color,
            num_dists=num_dists,
            seed=seed,
            **layout
        )

class RoomsCorridorS3R3(RoomsCorridor):
    """RoomsCorridor preset with ``size=3`` and ``num_rows=3``."""

    def __init__(self, goals=None, dist_type=None, dist_color=None, num_dists=0, seed=None):
        layout = dict(size=3, num_rows=3)
        super().__init__(
            goals=goals,
            dist_type=dist_type,
            dist_color=dist_color,
            num_dists=num_dists,
            seed=seed,
            **layout
        )

class RoomsCorridorS4R3(RoomsCorridor):
    """RoomsCorridor preset with ``size=4`` and ``num_rows=3``."""

    def __init__(self, goals=None, dist_type=None, dist_color=None, num_dists=0, seed=None):
        layout = dict(size=4, num_rows=3)
        super().__init__(
            goals=goals,
            dist_type=dist_type,
            dist_color=dist_color,
            num_dists=num_dists,
            seed=seed,
            **layout
        )

class RoomsCorridorS5R3(RoomsCorridor):
    """RoomsCorridor preset with ``size=5`` and ``num_rows=3``."""

    def __init__(self, goals=None, dist_type=None, dist_color=None, num_dists=0, seed=None):
        layout = dict(size=5, num_rows=3)
        super().__init__(
            goals=goals,
            dist_type=dist_type,
            dist_color=dist_color,
            num_dists=num_dists,
            seed=seed,
            **layout
        )

class RoomsCorridorS6R3(RoomsCorridor):
    """RoomsCorridor preset with ``size=6`` and ``num_rows=3``."""

    def __init__(self, goals=None, dist_type=None, dist_color=None, num_dists=0, seed=None):
        layout = dict(size=6, num_rows=3)
        super().__init__(
            goals=goals,
            dist_type=dist_type,
            dist_color=dist_color,
            num_dists=num_dists,
            seed=seed,
            **layout
        )





# Environment ids and the entry points they are registered with, in
# registration order.
_ENV_REGISTRATIONS = (
    ('MiniGrid-PickPlace-v0', 'envs.envs:PickPlaceEnvV0'),
    ('MiniGrid-PickPlace-v1', 'envs.envs:PickPlaceEnvV1'),
    ('MiniGrid-GoToObjinRoom-v0', 'envs.envs:GoToObjinRoom'),
    ('MiniGrid-PickUpObjinRoom-v0', 'envs.envs:PickUpObjinRoom'),
    ('MiniGrid-RoomsCorridorS3R1-v0', 'envs.envs:RoomsCorridorS3R1'),
    ('MiniGrid-RoomsCorridorS3R2-v0', 'envs.envs:RoomsCorridorS3R2'),
    ('MiniGrid-RoomsCorridorS3R3-v0', 'envs.envs:RoomsCorridorS3R3'),
    ('MiniGrid-RoomsCorridorS4R3-v0', 'envs.envs:RoomsCorridorS4R3'),
    ('MiniGrid-RoomsCorridorS5R3-v0', 'envs.envs:RoomsCorridorS5R3'),
    ('MiniGrid-RoomsCorridorS6R3-v0', 'envs.envs:RoomsCorridorS6R3'),
)

# Register every environment defined in this module
for _env_id, _entry_point in _ENV_REGISTRATIONS:
    register(
        id=_env_id,
        entry_point=_entry_point
    )
