from collections import deque

from gym_minigrid.minigrid import *
from gym_minigrid.register import register

from .base import SimpleEnv

class SimpleXIslandEnv(SimpleEnv):
    """Grid with an X-shaped wall "island" and the goal in the bottom-right.

    The grid is ``size + 2`` cells wide and 7 cells tall, surrounded by walls.
    Five single-cell walls at (3, 3), (2, 2), (4, 4), (4, 2) and (2, 4) form
    an "X".  The goal is placed at ``(width - 2, height - 2)``.

    On success the reward is the shortest possible number of moves from the
    agent's start cell to the goal divided by ``self.step_count``.
    """

    def __init__(
        self,
        size=8,
        agent_start_pos=None,
        agent_start_dir=0,
    ):
        """
        :param size: number of interior columns; the grid is ``size + 2``
            wide and always 7 tall.
        :param agent_start_pos: ``None`` for the default start (1, 1), the
            string ``"random"`` to let ``place_agent`` choose a cell each time
            the grid is generated, or an explicit ``(x, y)`` position.
        :param agent_start_dir: direction assigned to the agent after
            placement; ``None`` skips that assignment.
        """
        width = size + 2
        height = 7

        if agent_start_pos is None:
            self.agent_start_pos = (1, 1)
        elif isinstance(agent_start_pos, str) and agent_start_pos == "random":
            # None tells _gen_grid to use place_agent().  The isinstance guard
            # avoids an element-wise comparison when a position array is given.
            self.agent_start_pos = None
        else:
            self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        # Start cell of the current episode, filled in by _gen_grid and read
        # by _compute_min_step_count_to_goal.
        self._init_pos = None

        super().__init__(
            width=width,
            height=height,
            max_steps=4 * size * size,
            # Set this to True for maximum speed
            see_through_walls=True,
        )

    def _gen_grid(self, width, height):
        # Create an empty grid surrounded by walls.
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # Island: five single-cell walls arranged as an "X".
        for x, y in ((3, 3), (2, 2), (4, 4), (4, 2), (2, 4)):
            self.grid.wall_rect(x, y, 1, 1)

        # Place a goal square in the bottom-right interior corner.
        self.put_obj(Goal(), width - 2, height - 2)

        # Place the agent and remember where it started (needed by _reward).
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self._init_pos = self.agent_start_pos
        else:
            self._init_pos = self.place_agent()

        if self.agent_start_dir is not None:
            self.agent_dir = self.agent_start_dir

        self.mission = "get to the green goal square"

    def _reward(self):
        """
        Compute the reward to be given upon success.

        It is the minimum number of moves from the start cell to the goal
        divided by the number of steps the agent actually took.
        """
        min_step_count_to_goal = self._compute_min_step_count_to_goal()
        return min_step_count_to_goal / self.step_count

    def _compute_min_step_count_to_goal(self):
        """
        Return the fewest 4-connected moves from ``self._init_pos`` to a goal.

        Turns are not counted.  Cells holding an object whose
        ``can_overlap()`` is False (walls) are impassable.  The distance is
        found by breadth-first search over ``self.grid``, so it is correct
        for every ``size`` rather than only for the 5-column layout.

        :raises ValueError: with the start cell as ``(row, col)`` when it is
            outside the grid, blocked, or cannot reach a goal.
        """
        col, row = int(self._init_pos[0]), int(self._init_pos[1])
        grid = self.grid

        def passable(x, y):
            if not (0 <= x < grid.width and 0 <= y < grid.height):
                return False
            cell = grid.get(x, y)
            return cell is None or cell.can_overlap()

        if not passable(col, row):
            raise ValueError((row, col))

        start = (col, row)
        dist = {start: 0}
        queue = deque([start])
        while queue:
            x, y = queue.popleft()
            # BFS pops cells in non-decreasing distance order, so the first
            # goal reached is the nearest one.
            if isinstance(grid.get(x, y), Goal):
                return dist[(x, y)]
            for nxt in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
                if nxt not in dist and passable(*nxt):
                    dist[nxt] = dist[(x, y)] + 1
                    queue.append(nxt)

        raise ValueError((row, col))


class SimpleXIslandBisEnv(SimpleXIslandEnv):
    """Variant of SimpleXIslandEnv with the goal at ``(width - 4, height - 3)``.

    The walls and X-shaped island are the same as in the parent layout.  For
    ``size=5`` the goal is the cell directly below the centre of the X, so it
    can only be entered from the cell beneath it.
    """

    def _gen_grid(self, width, height):
        # Create an empty grid surrounded by walls.
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # Island: five single-cell walls arranged as an "X".
        for x, y in ((3, 3), (2, 2), (4, 4), (4, 2), (2, 4)):
            self.grid.wall_rect(x, y, 1, 1)

        # Place the goal square (not in the corner: 3 columns from the right
        # wall, 2 rows above the bottom wall).
        self.put_obj(Goal(), width - 4, height - 3)

        # Place the agent and remember where it started (needed by _reward).
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self._init_pos = self.agent_start_pos
        else:
            self._init_pos = self.place_agent()

        if self.agent_start_dir is not None:
            self.agent_dir = self.agent_start_dir

        self.mission = "get to the green goal square"

    def _compute_min_step_count_to_goal(self):
        """
        Return the fewest 4-connected moves from ``self._init_pos`` to a goal.

        Turns are not counted.  Cells holding an object whose
        ``can_overlap()`` is False (walls) are impassable.  The distance is
        found by breadth-first search over ``self.grid``, so it is correct
        for every ``size`` rather than only for the 5-column layout.

        :raises ValueError: with the start cell as ``(row, col)`` when it is
            outside the grid, blocked, or cannot reach a goal.
        """
        col, row = int(self._init_pos[0]), int(self._init_pos[1])
        grid = self.grid

        def passable(x, y):
            if not (0 <= x < grid.width and 0 <= y < grid.height):
                return False
            cell = grid.get(x, y)
            return cell is None or cell.can_overlap()

        if not passable(col, row):
            raise ValueError((row, col))

        start = (col, row)
        dist = {start: 0}
        queue = deque([start])
        while queue:
            x, y = queue.popleft()
            # BFS pops cells in non-decreasing distance order, so the first
            # goal reached is the nearest one.
            if isinstance(grid.get(x, y), Goal):
                return dist[(x, y)]
            for nxt in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
                if nxt not in dist and passable(*nxt):
                    dist[nxt] = dist[(x, y)] + 1
                    queue.append(nxt)

        raise ValueError((row, col))



class SimpleXIslandEnv5(SimpleXIslandEnv):
    """SimpleXIslandEnv with ``size`` fixed to 5."""

    def __init__(self, **kwargs):
        super(SimpleXIslandEnv5, self).__init__(size=5, **kwargs)

class SimpleXIslandBisEnv5(SimpleXIslandBisEnv):
    """SimpleXIslandBisEnv with ``size`` fixed to 5."""

    def __init__(self, **kwargs):
        super(SimpleXIslandBisEnv5, self).__init__(size=5, **kwargs)

class SimpleXIslandEnv10(SimpleXIslandEnv):
    """SimpleXIslandEnv with ``size`` fixed to 10."""

    def __init__(self, **kwargs):
        super(SimpleXIslandEnv10, self).__init__(size=10, **kwargs)

class SimpleXIslandBisEnv10(SimpleXIslandBisEnv):
    """SimpleXIslandBisEnv with ``size`` fixed to 10."""

    def __init__(self, **kwargs):
        super(SimpleXIslandBisEnv10, self).__init__(size=10, **kwargs)

# Register every fixed-size variant with gym, in declaration order.
for _env_id, _entry_point in (
    ('MiniGrid-SimpleXIslandEnv-5-v0',
     'custom_minigrid.envs:SimpleXIslandEnv5'),
    ('MiniGrid-SimpleXIslandBisEnv-5-v0',
     'custom_minigrid.envs:SimpleXIslandBisEnv5'),
    ('MiniGrid-SimpleXIslandEnv-10-v0',
     'custom_minigrid.envs:SimpleXIslandEnv10'),
    ('MiniGrid-SimpleXIslandBisEnv-10-v0',
     'custom_minigrid.envs:SimpleXIslandBisEnv10'),
):
    register(id=_env_id, entry_point=_entry_point)
# Keep the loop variables out of the module namespace.
del _env_id, _entry_point