from gym_minigrid.minigrid import *
from gym_minigrid.register import register

from .base import SimpleEnv

class DoorKeyEnv(SimpleEnv):
    """
    Environment with a door and key, sparse reward

    The grid is surrounded by walls and split by a vertical wall that
    contains a locked yellow door. A yellow key is placed at (1, 1), the
    agent starts at (1, size-2) and the goal is at (width-2, height-2).
    Movement uses absolute directions (up/right/down/left) and the
    pickup/open actions act on any of the four neighbouring cells.
    """

    class Actions(IntEnum):
        # Absolute-direction moves
        up      = 0
        right   = 1
        down    = 2
        left    = 3
        # Interactions with the four neighbouring cells
        pickup  = 4
        open    = 5

    # (dx, dy) offsets of the neighbouring cells, in the order they are
    # scanned by the pickup/open actions: up, right, down, left.
    _NEIGHBOR_OFFSETS = ((0, -1), (1, 0), (0, 1), (-1, 0))

    def __init__(self,
        size=8,
        agent_start_dir=0,
        **kwargs
    ):
        """
        Build a size x size door/key grid.

        size: side length of the grid, forwarded to the base class as
            ``grid_size``; the step limit is ``10 * size * size``.
        agent_start_dir: value assigned to ``agent_dir`` when the grid is
            generated.
        kwargs: accepted but not forwarded to the base class.
        """
        # NOTE(review): **kwargs is silently dropped here. Forwarding it would
        # require SimpleEnv.__init__ to accept the extra keywords - confirm
        # before changing.
        self.agent_start_dir = agent_start_dir
        self.agent_start_pos = (1, size-2)

        super().__init__(
            grid_size=size,
            max_steps=10*size*size,
        )

    @property
    def valid_positions(self):
        """
        Return the (i, j) positions of empty cells in columns 0..splitIdx.

        The list is built when ``self._valid_positions`` is falsy and is
        reused on later accesses.
        """
        # NOTE(review): relies on self._valid_positions already existing
        # (presumably initialised by SimpleEnv - confirm).
        if not self._valid_positions:
            # Same split column as the one used in _gen_grid
            splitIdx = (self.width-2) // 2 + 1
            self._valid_positions = [
                (i, j)
                for i in range(min(self.width, splitIdx + 1))
                for j in range(self.height)
                if self.grid.get(i, j) is None
            ]
        return self._valid_positions

    def _neighbor_positions(self):
        """Yield the positions of the four cells adjacent to the agent."""
        for offset in self._NEIGHBOR_OFFSETS:
            yield self.agent_pos + np.array(offset)

    def _pickup_neighbor(self):
        """
        Pick up the first adjacent object that can be picked up.

        Nothing is picked up while the agent is already carrying something.
        Returns the position of the picked-up object (its cell is cleared,
        so step() will move the agent there), or the agent's current
        position when nothing was picked up.
        """
        for pos in self._neighbor_positions():
            cell = self.grid.get(*pos)
            if cell and cell.can_pickup() and self.carrying is None:
                self.carrying = cell
                # Mark the carried object as being off the grid
                self.carrying.cur_pos = np.array([-1, -1])
                self.grid.set(*pos, None)
                return pos
        return self.agent_pos

    def _open_neighbor(self):
        """
        Toggle the first adjacent door that is not open.

        Returns the position of that door (step() moves the agent onto it
        only if the door can be overlapped after the toggle), or the
        agent's current position when no such door is adjacent.
        """
        for pos in self._neighbor_positions():
            cell = self.grid.get(*pos)
            if cell and cell.type == 'door' and not cell.is_open:
                # Whether a locked door actually opens is decided by toggle()
                cell.toggle(self, pos)
                return pos
        return self.agent_pos

    def get_fwd_pos(self, action):
        """
        Apply the interaction part of ``action`` and return the target cell.

        For the four move actions this is the neighbouring cell in that
        direction. For pickup/open the interaction is performed here and the
        position of the affected cell is returned, or the agent's own
        position when there was nothing to interact with.

        Raises AssertionError for an action that is not one of the six
        known actions.
        """
        if action == self.actions.up:
            fwd_pos = self.agent_pos + np.array((0, -1))
        elif action == self.actions.right:
            fwd_pos = self.agent_pos + np.array((1, 0))
        elif action == self.actions.down:
            fwd_pos = self.agent_pos + np.array((0, 1))
        elif action == self.actions.left:
            fwd_pos = self.agent_pos + np.array((-1, 0))
        elif action == self.actions.pickup:
            fwd_pos = self._pickup_neighbor()
        elif action == self.actions.open:
            fwd_pos = self._open_neighbor()
        else:
            # Raised explicitly (not via `assert`) so the check survives
            # `python -O` instead of falling through to an unbound fwd_pos.
            raise AssertionError(f"unknown action: {action}")
        return fwd_pos

    def step(self, action):
        """
        Advance the environment by one action.

        Returns ``(obs, reward, done, info)`` where ``info`` is an empty
        dict. The reward is non-zero only on the step that targets the
        goal cell; the episode also ends on a lava cell or when
        ``max_steps`` is reached.
        """
        self.step_count += 1

        reward = 0
        done = False

        fwd_pos = self.get_fwd_pos(action)
        fwd_cell = self.grid.get(*fwd_pos)

        # Move only into empty cells or cells that allow overlap
        if fwd_cell is None or fwd_cell.can_overlap():
            self.agent_pos = fwd_pos
        if fwd_cell is not None and fwd_cell.type == 'goal':
            done = True
            reward = self._reward()
        if fwd_cell is not None and fwd_cell.type == 'lava':
            done = True

        if self.step_count >= self.max_steps:
            done = True

        obs = self.gen_obs()

        return obs, reward, done, {}

    def _gen_grid(self, width, height):
        """Build the fixed door/key layout on a width x height grid."""
        # Create an empty grid
        self.grid = Grid(width, height)

        # Generate the surrounding walls
        self.grid.wall_rect(0, 0, width, height)

        # Place a goal in the bottom-right corner
        self.put_obj(Goal(), width - 2, height - 2)

        # Create a vertical splitting wall at a fixed column
        splitIdx = (width-2) // 2 + 1
        self.grid.vert_wall(splitIdx, 0)

        # Place the agent at its fixed start position and orientation
        self.agent_dir = self.agent_start_dir
        self.agent_pos = self.agent_start_pos

        # Place a locked door in the wall at a fixed row
        doorIdx = 2
        self.put_obj(Door('yellow', is_locked=True), splitIdx, doorIdx)

        # Place a yellow key at a fixed cell
        self.put_obj(Key('yellow'), 1, 1)

        self.mission = "use the key to open the door and then get to the goal"

    def _reward(self):
        """
        Compute the reward to be given upon success

        Returns ``6 / step_count``; the constant 6 is the same for every
        grid size.
        """
        min_step_count_to_goal = 6
        return min_step_count_to_goal / self.step_count

class DoorKeyEnv5x5(DoorKeyEnv):
    """DoorKeyEnv with the grid size fixed to 5."""

    def __init__(self, **kwargs):
        # Pin the size; every other keyword goes to DoorKeyEnv unchanged.
        grid_size = 5
        super().__init__(size=grid_size, **kwargs)

class DoorKeyEnv6x6(DoorKeyEnv):
    """DoorKeyEnv with the grid size fixed to 6."""

    def __init__(self, **kwargs):
        # Pin the size; every other keyword goes to DoorKeyEnv unchanged.
        grid_size = 6
        super().__init__(size=grid_size, **kwargs)


class DoorKeyEnv7x7(DoorKeyEnv):
    """DoorKeyEnv with the grid size fixed to 7."""

    def __init__(self, **kwargs):
        # Pin the size; every other keyword goes to DoorKeyEnv unchanged.
        grid_size = 7
        super().__init__(size=grid_size, **kwargs)

class DoorKeyEnv8x8(DoorKeyEnv):
    """DoorKeyEnv with the grid size fixed to 8."""

    def __init__(self, **kwargs):
        # Pin the size; every other keyword goes to DoorKeyEnv unchanged.
        grid_size = 8
        super().__init__(size=grid_size, **kwargs)

class DoorKeyEnv16x16(DoorKeyEnv):
    """DoorKeyEnv with the grid size fixed to 16."""

    def __init__(self, **kwargs):
        # Pin the size; every other keyword goes to DoorKeyEnv unchanged.
        grid_size = 16
        super().__init__(size=grid_size, **kwargs)

# Environment ids paired with the entry point of the matching fixed-size
# class; they are registered in the order listed.
_SIMPLE_DOORKEY_ENVS = (
    ('MiniGrid-SimpleDoorKey-5x5-v0', 'custom_minigrid.envs:DoorKeyEnv5x5'),
    ('MiniGrid-SimpleDoorKey-6x6-v0', 'custom_minigrid.envs:DoorKeyEnv6x6'),
    ('MiniGrid-SimpleDoorKey-7x7-v0', 'custom_minigrid.envs:DoorKeyEnv7x7'),
    ('MiniGrid-SimpleDoorKey-8x8-v0', 'custom_minigrid.envs:DoorKeyEnv8x8'),
    ('MiniGrid-SimpleDoorKey-16x16-v0', 'custom_minigrid.envs:DoorKeyEnv16x16'),
)

for _env_id, _entry_point in _SIMPLE_DOORKEY_ENVS:
    register(id=_env_id, entry_point=_entry_point)
