from gym_minigrid.minigrid import *

from . import multigrid
from . import register


class Room:
    """Record describing one rectangular room of a multi-room level.

    Attributes are stored exactly as passed in:
      top          -- (x, y) of the room's top-left corner cell
      size         -- (width, height) of the room, walls included
      entryDoorPos -- (x, y) of the door leading into the room
      exitDoorPos  -- (x, y) of the door leading out, or None if unset
    """

    def __init__(self, top, size, entryDoorPos, exitDoorPos):
        # Geometry of the room
        self.top, self.size = top, size
        # Door cells connecting this room to its neighbours
        self.entryDoorPos, self.exitDoorPos = entryDoorPos, exitDoorPos

# class MultiRoomEnv(MiniGridEnv):
class MultiRoomEnv(multigrid.MultiGridEnv):
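    """
    Environment made of a chain of rooms connected by doors (subgoals).
    A red and a blue ball are placed in the last room; one of them is the
    goal, and the agent starts in the first room.
    """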

    def __init__(self,
        minNumRooms,
        maxNumRooms,
        maxRoomSize=10,
        gridSize=25,
        agent_view_size=5,
        p=0.5,
        rewards=(0.75, 1.0),
        reward_spreads=(10, 0),
        obl_correction=False,
        fixed_environment=False,
        seed=0
    ):
        """Configure the multi-room environment and initialise the base env.

        Args:
            minNumRooms: Smallest number of rooms a level may have (> 0).
            maxNumRooms: Largest number of rooms a level may have
                (>= minNumRooms). Also sets max_steps to maxNumRooms * 20.
            maxRoomSize: Largest room side length, walls included (>= 4).
            gridSize: Passed to the base class as grid_size.
            agent_view_size: Passed through to the base class.
            p: Threshold on a uniform draw; the red ball is the goal when
                the draw is below p, otherwise the blue ball is.
            rewards: Pair (red, blue) of reward centres.
            reward_spreads: Pair (red, blue) of half-widths of the interval
                around the matching centre from which the goal ball's
                reward is drawn.
            obl_correction: When true, the red/blue draw uses the unseeded
                generator instead of the environment's seeded one.
            fixed_environment: Passed through to the base class.
            seed: Passed through to the base class.

        Raises:
            AssertionError: If the room-count or room-size bounds are invalid.
        """
        # The defaults for rewards / reward_spreads are tuples rather than
        # lists so that no mutable object is shared between instances.
        assert minNumRooms > 0
        assert maxNumRooms >= minNumRooms
        assert maxRoomSize >= 4

        self.minNumRooms = minNumRooms
        self.maxNumRooms = maxNumRooms
        self.maxRoomSize = maxRoomSize

        self.rooms = []

        self.world = multigrid.World

        self.p = p
        self.rewards = rewards
        self.reward_spreads = reward_spreads
        self.obl_correction = obl_correction
        # Separate generator that is not tied to the level seed.
        self.unseeded_np_random,_ = gym.utils.seeding.np_random()

        self.agent_start_pos = None
        self.agent_start_dir = None

        super().__init__(
            grid_size=gridSize,
            max_steps=self.maxNumRooms * 20,
            n_agents=1,
            minigrid_mode=True,
            agent_view_size=agent_view_size,
            see_through_walls=False,
            competitive=True,
            fixed_environment=fixed_environment,
            seed=seed
        )

        self.action_space = gym.spaces.Discrete(6)

    def seed(self, seed):
        """Seed the environment, record the seed as the level seed, and reset.

        Returns the observation produced by the reset.
        """
        super().seed(seed)
        self.level_seed = seed
        return self.reset()

    def reset_metrics(self):
        """Reset the extra level metrics; all three are placeholder values."""
        (self.n_clutter_placed,
         self.passable,
         self.shortest_path_length) = (0, True, 0)

    def reset_agent_status(self):
        """Reinitialise per-agent position, direction, done flag and inventory.

        Every list gets one entry per agent; directions start at the
        current value of ``agent_start_dir``.
        """
        start_dir = self.agent_start_dir
        agent_ids = range(self.n_agents)

        self.agent_pos = [None for _ in agent_ids]
        self.agent_dir = [start_dir for _ in agent_ids]
        self.done = [False for _ in agent_ids]
        self.carrying = [None for _ in agent_ids]

    def reset(self):
        """Start a new episode: regenerate the grid and return the observation.

        NOTE(review): ``reset_agent_status`` runs before ``agent_start_dir``
        is cleared, so ``agent_dir`` is filled with the start direction left
        over from the previous level (None on the first reset).
        """
        self.step_count = 0
        self.reset_agent_status()

        # Forget the previous level's start pose and goal.
        self.agent_start_pos = self.agent_start_dir = self.goal_pos = None

        # Extra metrics, then the level itself.
        self.reset_metrics()
        self._gen_grid(self.width, self.height)

        return self.gen_obs()

    def reset_agent(self):
        """Re-seed with the stored level seed and return the reset observation."""
        level_seed = self.level_seed
        return self.seed(level_seed)

    def _gen_grid(self, width, height):
        """Build a fresh level: a chain of rooms, two balls and the agent.

        Room chains are sampled with ``_placeRoom`` until one contains the
        requested number of rooms. The rooms are then drawn with walls and
        doors, a red and a blue ball are put in the last room, one ball is
        chosen as the goal and given a sampled reward (the other keeps
        reward 0), and the agent is placed in the first room.

        Args:
            width: Grid width in cells.
            height: Grid height in cells.
        """
        self.target = None
        roomList = []
        self.doorPosList = []

        # Choose a random number of rooms to generate
        numRooms = self._rand_int(self.minNumRooms, self.maxNumRooms+1)

        # Keep sampling until a chain with the requested room count is found;
        # the longest chain seen so far is retained.
        while len(roomList) < numRooms:
            curRoomList = []

            # The y coordinate is bounded by the grid height. (Previously
            # both coordinates were bounded by the width, which is only
            # correct for square grids.)
            entryDoorPos = (
                self._rand_int(0, width - 2),
                self._rand_int(0, height - 2)
            )

            # Recursively place the rooms
            self._placeRoom(
                numRooms,
                roomList=curRoomList,
                minSz=4,
                maxSz=self.maxRoomSize,
                entryDoorWall=2,
                entryDoorPos=entryDoorPos
            )

            if len(curRoomList) > len(roomList):
                roomList = curRoomList

        # Store the list of rooms in this environment
        assert len(roomList) > 0
        self.rooms = roomList

        # Create the grid
        self.grid = multigrid.Grid(width, height)
        wall = Wall()

        prevDoorColor = None

        # For each room
        for idx, room in enumerate(roomList):

            topX, topY = room.top
            sizeX, sizeY = room.size

            # Draw the top and bottom walls
            for i in range(0, sizeX):
                self.grid.set(topX + i, topY, wall)
                self.grid.set(topX + i, topY + sizeY - 1, wall)

            # Draw the left and right walls
            for j in range(0, sizeY):
                self.grid.set(topX, topY + j, wall)
                self.grid.set(topX + sizeX - 1, topY + j, wall)

            # If this isn't the first room, place the entry door
            if idx > 0:
                # Pick a door color different from the previous one
                doorColors = set(COLOR_NAMES)
                if prevDoorColor:
                    doorColors.remove(prevDoorColor)
                # Note: the use of sorting here guarantees determinism,
                # This is needed because Python's set is not deterministic
                doorColor = self._rand_elem(sorted(doorColors))

                entryDoor = multigrid.Door(doorColor)
                self.grid.set(*room.entryDoorPos, entryDoor)
                prevDoorColor = doorColor

                prevRoom = roomList[idx-1]
                prevRoom.exitDoorPos = room.entryDoorPos

        # Place the red and blue balls in the last room.
        last_room_size = np.array(roomList[-1].size)

        top_left = roomList[-1].top
        bottom_right = roomList[-1].top + last_room_size + np.array([-1,-1])
        # The most recently recorded door is the last room's entry door.
        last_door_pos = self.doorPosList[-1]

        # Work out which wall the entry door is on and exclude the interior
        # row/column next to it from the ball area: for top/left doors the
        # area is shifted away by one cell, for bottom/right doors it is
        # only shortened.
        if top_left[1] == last_door_pos[1]: # Top
            last_room_size += np.array([0,-1])
            ball_pos_offset = np.array([0,1])
        elif top_left[0] == last_door_pos[0]: # Left
            last_room_size += np.array([-1,0])
            ball_pos_offset = np.array([1,0])
        elif bottom_right[1] == last_door_pos[1]: # Bottom
            last_room_size += np.array([0,-1])
            ball_pos_offset = np.array([0,0])
        else: # Right
            last_room_size += np.array([-1,0])
            ball_pos_offset = np.array([0,0])

        # Remove the surrounding walls from the area and step inside them.
        last_room_size += np.array([-2,-2])
        ball_pos_offset += np.array([1,1])

        # Draw two distinct flat cell indices within the ball area.
        ball_pos_idx = self.np_random.choice(
            range(np.prod(last_room_size)),
            size=(2,),
            replace=False)

        # Convert each flat index to (column, row) and then to grid coords.
        self.red_pos = roomList[-1].top + np.array([
                ball_pos_idx[0]%last_room_size[0], ball_pos_idx[0]//last_room_size[0]]) + ball_pos_offset

        self.blue_pos = roomList[-1].top + np.array([
                ball_pos_idx[1]%last_room_size[0], ball_pos_idx[1]//last_room_size[0]]) + ball_pos_offset

        if self.obl_correction:
            red_outcome = self.unseeded_np_random.rand() # Maintain uniform prior over levels
        else:
            red_outcome = self.np_random.rand()

        self.red_reward, self.blue_reward = 0,0

        # Only the goal ball gets a reward, drawn from the unseeded generator
        # as rewards[i] - reward_spreads[i] plus rand() * 2 * reward_spreads[i].
        if red_outcome < self.p:
            self.goal_pos = self.red_pos
            self.red_reward = self.unseeded_np_random.rand()*(2*self.reward_spreads[0]) + self.rewards[0] - self.reward_spreads[0]
        else:
            self.goal_pos = self.blue_pos
            self.blue_reward = self.unseeded_np_random.rand()*(2*self.reward_spreads[1]) + self.rewards[1] - self.reward_spreads[1]

        self.put_obj(
            multigrid.Ball(self.world,
              index=COLOR_TO_IDX['red'],
              reward=self.red_reward),
            *self.red_pos)

        self.put_obj(
            multigrid.Ball(self.world,
              index=COLOR_TO_IDX['blue'],
              reward=self.blue_reward),
            *self.blue_pos)

        # Randomize the starting agent position and direction
        if len(roomList) == 1:
            # Single room: the agent shares the room with the balls, so pick
            # explicitly among the interior cells not occupied by a ball.
            first_room_choices = roomList[0].size + np.array([-2, -2])
            first_w = first_room_choices[0]
            first_top_left = roomList[0].top
            agent_pos_choices = [
                tuple(first_top_left + np.array([v%first_w + 1, v//first_w + 1])) \
                for v in range(np.prod(first_room_choices))]

            red_pos_tuple = tuple(self.red_pos)
            blue_pos_tuple = tuple(self.blue_pos)

            if red_pos_tuple in agent_pos_choices:
                agent_pos_choices.remove(red_pos_tuple)

            if blue_pos_tuple in agent_pos_choices:
                agent_pos_choices.remove(blue_pos_tuple)

            rand_idx = self.np_random.choice(range(len(agent_pos_choices)))
            agent_pos = agent_pos_choices[rand_idx]
            self.agent_start_pos = self.place_agent(agent_pos, np.array([1,1]))
        else:
            self.agent_start_pos = self.place_agent(roomList[0].top, roomList[0].size)

        self.agent_start_dir = self._rand_int(0,4)

        self.mission = 'traverse the rooms to get to the goal'

    def _handle_pickup(self, agent_id, reward, fwd_pos, fwd_cell):
        """Mark the agent as done and record which ball cell it faced.

        ``self.target`` becomes 'red' or 'blue' when ``fwd_pos`` matches the
        corresponding ball position, and None otherwise. ``reward`` and
        ``fwd_cell`` are not used here.
        """
        self.done[agent_id] = True

        picked_cell = tuple(fwd_pos)

        if picked_cell == tuple(self.red_pos):
            picked_color = 'red'
        elif picked_cell == tuple(self.blue_pos):
            picked_color = 'blue'
        else:
            picked_color = None

        self.target = picked_color

    def _placeRoom(
        self,
        numLeft,
        roomList,
        minSz,
        maxSz,
        entryDoorWall,
        entryDoorPos
    ):
        """Try to place one room, then recursively the rooms that follow it.

        A room of random size is positioned so that ``entryDoorPos`` lies on
        its ``entryDoorWall``. If it fits, the room is appended to
        ``roomList`` and its entry door position to ``self.doorPosList``,
        and up to 8 attempts are made to attach the next room through a
        randomly chosen exit door.

        Wall indices: 0 = right, 1 = south/bottom, 2 = left, 3 = north/top.

        Args:
            numLeft: Number of rooms still to place, this one included.
            roomList: List of ``Room`` objects placed so far; mutated in place.
            minSz: Lower bound passed to ``_rand_int`` for each side length.
            maxSz: Largest side length (``maxSz+1`` is the upper bound
                passed to ``_rand_int``).
            entryDoorWall: Index of the wall holding the entry door.
            entryDoorPos: (x, y) of the entry door.

        Returns:
            False if this room falls outside the grid or overlaps an earlier
            room. True if this room was placed, even when none of the
            attempts to place the following room succeeded; callers compare
            ``len(roomList)`` to the requested count to detect that.
        """
        # Choose the room size randomly
        sizeX = self._rand_int(minSz, maxSz+1)
        sizeY = self._rand_int(minSz, maxSz+1)

        # The first room will be at the door position
        if len(roomList) == 0:
            topX, topY = entryDoorPos
        # Entry on the right
        elif entryDoorWall == 0:
            topX = entryDoorPos[0] - sizeX + 1
            y = entryDoorPos[1]
            # Slide the room along the wall; presumably _rand_int's upper
            # bound is exclusive, keeping the door off the corners.
            topY = self._rand_int(y - sizeY + 2, y)
        # Entry wall on the south
        elif entryDoorWall == 1:
            x = entryDoorPos[0]
            topX = self._rand_int(x - sizeX + 2, x)
            topY = entryDoorPos[1] - sizeY + 1
        # Entry wall on the left
        elif entryDoorWall == 2:
            topX = entryDoorPos[0]
            y = entryDoorPos[1]
            topY = self._rand_int(y - sizeY + 2, y)
        # Entry wall on the top
        elif entryDoorWall == 3:
            x = entryDoorPos[0]
            topX = self._rand_int(x - sizeX + 2, x)
            topY = entryDoorPos[1]
        else:
            assert False, entryDoorWall

        # If the room is out of the grid, can't place a room here
        if topX < 0 or topY < 0:
            return False
        # NOTE(review): the x test uses > while the y test uses >=, so a room
        # may touch the right edge of the grid but not the bottom edge.
        if topX + sizeX > self.width or topY + sizeY >= self.height:
            return False

        # If the room intersects with previous rooms, can't place it here
        # The most recent room (roomList[-1]) is skipped: the new room is
        # attached to it through the entry door.
        for room in roomList[:-1]:
            nonOverlap = \
                topX + sizeX < room.top[0] or \
                room.top[0] + room.size[0] <= topX or \
                topY + sizeY < room.top[1] or \
                room.top[1] + room.size[1] <= topY

            if not nonOverlap:
                return False

        # Add this room to the list
        roomList.append(Room(
            (topX, topY),
            (sizeX, sizeY),
            entryDoorPos,
            None
        ))

        # Add door to door list
        self.doorPosList.append(entryDoorPos)

        # If this was the last room, stop
        if numLeft == 1:
            return True

        # Try placing the next room
        for i in range(0, 8):

            # Pick which wall to place the out door on
            wallSet = set((0, 1, 2, 3))
            wallSet.remove(entryDoorWall)
            exitDoorWall = self._rand_elem(sorted(wallSet))
            # The next room is entered through the opposite wall
            nextEntryWall = (exitDoorWall + 2) % 4

            # Pick the exit door position
            # Exit on right wall
            if exitDoorWall == 0:
                exitDoorPos = (
                    topX + sizeX - 1,
                    topY + self._rand_int(1, sizeY - 1)
                )
            # Exit on south wall
            elif exitDoorWall == 1:
                exitDoorPos = (
                    topX + self._rand_int(1, sizeX - 1),
                    topY + sizeY - 1
                )
            # Exit on left wall
            elif exitDoorWall == 2:
                exitDoorPos = (
                    topX,
                    topY + self._rand_int(1, sizeY - 1)
                )
            # Exit on north wall
            elif exitDoorWall == 3:
                exitDoorPos = (
                    topX + self._rand_int(1, sizeX - 1),
                    topY
                )
            else:
                assert False

            # Recursively create the other rooms
            success = self._placeRoom(
                numLeft - 1,
                roomList=roomList,
                minSz=minSz,
                maxSz=maxSz,
                entryDoorWall=nextEntryWall,
                entryDoorPos=exitDoorPos
            )

            if success:
                break

        # This room was placed, whether or not a following room was.
        return True

    def step(self, actions):
        """Advance the base environment one step.

        When the step reports done, ``info['target']`` is set to the value
        of ``self.target``. Returns (obs, reward, done, info).
        """
        obs, reward, done, info = super().step(actions)

        if done:
            info['target'] = self.target

        return obs, reward, done, info

    def goal_color(self):
        """Return 'red' if the goal is at the red ball's cell, else 'blue'."""
        goal_is_red = tuple(self.goal_pos) == tuple(self.red_pos)
        return 'red' if goal_is_red else 'blue'

class MultiRoomEnvN2S4(MultiRoomEnv):
    """Preset with exactly 2 rooms and maxRoomSize=4 on the default grid.

    ``use_walls`` is accepted but not used by this constructor.
    """

    def __init__(self, seed=0, fixed_environment=False, p=0.5,
                 rewards=(0.75,1.0), reward_spreads=(10,0),
                 use_walls=False, obl_correction=False):
        # Layout constants that define this preset.
        layout = dict(minNumRooms=2, maxNumRooms=2, maxRoomSize=4)
        super().__init__(
            **layout,
            p=p,
            rewards=rewards,
            reward_spreads=reward_spreads,
            obl_correction=obl_correction,
            fixed_environment=fixed_environment,
            seed=seed
        )

class MultiRoomEnvN4S5(MultiRoomEnv):
    """Preset with exactly 4 rooms and maxRoomSize=5 on the default grid.

    ``use_walls`` is accepted but not used by this constructor.
    """

    def __init__(self, seed=0, fixed_environment=False, p=0.5,
                 rewards=(0.75,1.0), reward_spreads=(10,0),
                 use_walls=False, obl_correction=False):
        # Layout constants that define this preset.
        layout = dict(minNumRooms=4, maxNumRooms=4, maxRoomSize=5)
        super().__init__(
            **layout,
            p=p,
            rewards=rewards,
            reward_spreads=reward_spreads,
            obl_correction=obl_correction,
            fixed_environment=fixed_environment,
            seed=seed
        )

class MultiRoomEnvN6(MultiRoomEnv):
    """Preset with exactly 6 rooms; room and grid sizes keep their defaults.

    ``use_walls`` is accepted but not used by this constructor.
    """

    def __init__(self, seed=0, fixed_environment=False, p=0.5,
                 rewards=(0.75,1.0), reward_spreads=(10,0),
                 use_walls=False, obl_correction=False):
        # Layout constants that define this preset.
        layout = dict(minNumRooms=6, maxNumRooms=6)
        super().__init__(
            **layout,
            p=p,
            rewards=rewards,
            reward_spreads=reward_spreads,
            obl_correction=obl_correction,
            fixed_environment=fixed_environment,
            seed=seed
        )

# Custom levels
class MultiRoomEnvN2S4G14(MultiRoomEnv):
    """Preset with exactly 2 rooms, maxRoomSize=4 and gridSize=14.

    ``use_walls`` is accepted but not used by this constructor.
    """

    def __init__(self, seed=0, fixed_environment=False, p=0.5,
                 rewards=(0.75,1.0), reward_spreads=(10,0),
                 use_walls=False, obl_correction=False):
        # Layout constants that define this preset.
        layout = dict(minNumRooms=2, maxNumRooms=2, maxRoomSize=4, gridSize=14)
        super().__init__(
            **layout,
            p=p,
            rewards=rewards,
            reward_spreads=reward_spreads,
            obl_correction=obl_correction,
            fixed_environment=fixed_environment,
            seed=seed
        )

class MultiRoomEnvN4S5G14(MultiRoomEnv):
    """Preset with exactly 4 rooms, maxRoomSize=5 and gridSize=14.

    ``use_walls`` is accepted but not used by this constructor.
    """

    def __init__(self, seed=0, fixed_environment=False, p=0.5,
                 rewards=(0.75,1.0), reward_spreads=(10,0),
                 use_walls=False, obl_correction=False):
        # Layout constants that define this preset.
        layout = dict(minNumRooms=4, maxNumRooms=4, maxRoomSize=5, gridSize=14)
        super().__init__(
            **layout,
            p=p,
            rewards=rewards,
            reward_spreads=reward_spreads,
            obl_correction=obl_correction,
            fixed_environment=fixed_environment,
            seed=seed
        )

class MultiRoomEnvN4Random(MultiRoomEnv):
    """Preset with 1 to 4 rooms, maxRoomSize=5 and gridSize=14.

    ``use_walls`` is accepted but not used by this constructor.
    """

    def __init__(self, seed=0, fixed_environment=False, p=0.5,
                 rewards=(0.75,1.0), reward_spreads=(10,0),
                 use_walls=False, obl_correction=False):
        # Layout constants that define this preset.
        layout = dict(minNumRooms=1, maxNumRooms=4, maxRoomSize=5, gridSize=14)
        super().__init__(
            **layout,
            p=p,
            rewards=rewards,
            reward_spreads=reward_spreads,
            obl_correction=obl_correction,
            fixed_environment=fixed_environment,
            seed=seed
        )

class MultiRoomEnvN4S7Random(MultiRoomEnv):
    """Preset with 1 to 4 rooms, maxRoomSize=7 and gridSize=13.

    ``use_walls`` is accepted but not used by this constructor.
    """

    def __init__(self, seed=0, fixed_environment=False, p=0.5,
                 rewards=(0.75,1.0), reward_spreads=(10,0),
                 use_walls=False, obl_correction=False):
        # Layout constants that define this preset.
        layout = dict(minNumRooms=1, maxNumRooms=4, maxRoomSize=7, gridSize=13)
        super().__init__(
            **layout,
            p=p,
            rewards=rewards,
            reward_spreads=reward_spreads,
            obl_correction=obl_correction,
            fixed_environment=fixed_environment,
            seed=seed
        )


# Resolve this module's import path; it prefixes the entry points of the
# environments registered in this file.
if hasattr(__loader__, 'name'):
    module_path = __loader__.name
elif hasattr(__loader__, 'fullname'):
    module_path = __loader__.fullname
else:
    # The loader exposes neither attribute: fall back to the module's own
    # name instead of leaving module_path undefined (which would raise
    # NameError at the first registration).
    module_path = __name__


# (environment id, class name in this module, episode step limit), listed in
# registration order. The last three entries are the custom levels.
_MULTIROOM_BC_ENVS = (
    ('MultiGrid-MultiRoomBC-N2-S4-v0', 'MultiRoomEnvN2S4', 40),
    ('MultiGrid-MultiRoomBC-N4-S5-v0', 'MultiRoomEnvN4S5', 80),
    ('MultiGrid-MultiRoomBC-N6-v0', 'MultiRoomEnvN6', 120),
    ('MultiGrid-MultiRoomBC-N4-Random-v0', 'MultiRoomEnvN4Random', 80),
    ('MultiGrid-MultiRoomBC-N4-S7-Random-v0', 'MultiRoomEnvN4S7Random', 80),
    ('MultiGrid-MultiRoomBC-N2-S4-G14-v0', 'MultiRoomEnvN2S4G14', 40),
    ('MultiGrid-MultiRoomBC-N4-S5-G14-v0', 'MultiRoomEnvN4S5G14', 80),
)

for _env_id, _class_name, _step_limit in _MULTIROOM_BC_ENVS:
    register.register(
        env_id=_env_id,
        entry_point=module_path + ':' + _class_name,
        max_episode_steps=_step_limit
    )
