from gym_minigrid.minigrid import *
from gym_minigrid.register import register

from .base import SimpleEnv

def random_start_pos(width, height):
    """
    Pick a random walkable cell of a serpentine maze of the given size.

    The layout assumed is the SimpleMaze pattern:
    - every odd interior row is a fully open corridor;
    - every even interior row is a wall with a single gap, at the right end
      (x = width - 2) when ``row % 4 == 2`` and at the left end (x = 1) when
      ``row % 4 == 0``.

    The outer wall ring (row/column 0 and height-1/width-1) is never
    returned. The choice uses numpy's global RNG, so ``np.random.seed``
    makes it reproducible.

    Args:
        width: full grid width, including the outer walls.
        height: full grid height, including the outer walls.

    Returns:
        An ``(x, y)`` tuple.

    Raises:
        ValueError: if the grid has no interior cells.
    """
    options = []
    # Interior rows only: row 0 and row height-1 are the outer walls.
    for row in range(1, height - 1):
        if row % 2 == 0:
            if row % 4 == 0:
                options.append((1, row))
            else:
                # Rightmost interior column; width-1 is the outer wall.
                options.append((width - 2, row))
        else:
            options.extend((col, row) for col in range(1, width - 1))
    if not options:
        raise ValueError(
            "grid of size {}x{} has no interior cells".format(width, height)
        )
    return options[np.random.choice(len(options))]

class SimpleMazeEnv(SimpleEnv):
    """
    Serpentine corridor maze with a single goal square.

    Layout for ``size=5`` (the 7x7 grid includes the outer walls)::

        XXXXXXX
        X@....X
        XXXXX.X
        X.....X
        X.XXXXX
        X....GX
        XXXXXXX

    Internal walls are only drawn on rows 2 and 4, so for larger sizes the
    area below row 4 is open floor.

    Args:
        size: interior side length; the grid is ``(size + 2) x (size + 2)``.
        agent_start_pos: ``None`` to start at ``(1, 1)``, ``"random"`` to let
            ``place_agent()`` choose a cell every time the grid is generated,
            or an explicit ``(x, y)`` position.
        agent_start_dir: initial ``agent_dir``; ``None`` leaves ``agent_dir``
            untouched.
    """

    def __init__(
        self,
        size=8,
        agent_start_pos=None,
        agent_start_dir=0,
    ):
        width = size + 2
        height = size + 2

        if agent_start_pos is None:
            self.agent_start_pos = (1, 1)
        elif isinstance(agent_start_pos, str) and agent_start_pos == "random":
            # None tells _gen_grid to call place_agent() instead.
            self.agent_start_pos = None
        else:
            self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        # Agent position when the grid was last generated; read by the reward.
        self._init_pos = None

        super().__init__(
            width=width,
            height=height,
            max_steps=4 * size * size,
            # Set this to True for maximum speed
            see_through_walls=True
        )

    def _gen_grid(self, width, height):
        # Create an empty grid surrounded by walls
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # Row 2 leaves a gap at the right end, row 4 a gap at the left end.
        self.grid.horz_wall(1, 2, width - 3)
        self.grid.horz_wall(2, 4, width - 3)

        # Goal goes in the bottom-left corner when the interior height is a
        # multiple of 3, otherwise in the bottom-right corner.
        if (height - 2) % 3 == 0:
            self.put_obj(Goal(), 1, height - 2)
        else:
            self.put_obj(Goal(), width - 2, height - 2)

        # Place the agent and remember where it started.
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self._init_pos = self.agent_start_pos
        else:
            self._init_pos = self.place_agent()

        if self.agent_start_dir is not None:
            self.agent_dir = self.agent_start_dir

        self.mission = "get to the green goal square"

    def _reward(self):
        """
        Compute the reward to be given upon success.

        Returns the optimal path length (in grid moves) from the starting
        cell to the goal divided by ``self.step_count``.
        """
        min_step_count_to_goal = self._compute_min_step_count_to_goal()
        return min_step_count_to_goal / self.step_count

    def _compute_min_step_count_to_goal(self):
        """
        Return the number of moves on a shortest path from ``self._init_pos``
        to the goal square.

        Performs a breadth-first search over the 4-connected grid, treating
        empty cells and cells whose object ``can_overlap()`` as walkable, so
        the result is correct for every ``size`` rather than only the 7x7
        layout. Only moves between cells are counted; turning is not.

        Raises:
            ValueError: if no initial position has been recorded or the goal
                cannot be reached from it.
        """
        from collections import deque

        if self._init_pos is None:
            raise ValueError(self._init_pos)

        grid = self.grid
        start = tuple(self._init_pos)
        dist = {start: 0}
        queue = deque([start])
        while queue:
            x, y = queue.popleft()
            cell = grid.get(x, y)
            if cell is not None and cell.type == 'goal':
                return dist[(x, y)]
            for nxt in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
                nx, ny = nxt
                if nxt in dist:
                    continue
                if not (0 <= nx < grid.width and 0 <= ny < grid.height):
                    continue
                neighbour = grid.get(nx, ny)
                if neighbour is None or neighbour.can_overlap():
                    dist[nxt] = dist[(x, y)] + 1
                    queue.append(nxt)
        raise ValueError(self._init_pos)


class ComplexMazeEnv(SimpleMazeEnv):
    """
    Maze variant with gaps at both ends of row 2 and one gap in row 4.

    Layout for ``size=5``::

        XXXXXXX
        X@....X
        X.XXX.X
        X.....X
        XX.XXXX
        X....GX
        XXXXXXX

    Construction arguments and ``_reward`` come from SimpleMazeEnv. The
    optimal path length that the reward relies on is computed here by a BFS
    over this grid's own layout.
    """

    def _gen_grid(self, width, height):
        # Create an empty grid surrounded by walls
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # Row 2: gaps at x=1 and x=width-2. Row 4: single gap at x=2.
        self.grid.horz_wall(2, 2, width - 4)
        self.grid.horz_wall(1, 4, 1)
        self.grid.horz_wall(3, 4, width - 4)

        # Place a goal square in the bottom-right corner
        self.put_obj(Goal(), width - 2, height - 2)

        # Place the agent and remember where it started; the reward needs
        # _init_pos and would otherwise fail on a None position.
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self._init_pos = self.agent_start_pos
        else:
            self._init_pos = self.place_agent()

        if self.agent_start_dir is not None:
            self.agent_dir = self.agent_start_dir

        self.mission = "get to the green goal square"

    def _compute_min_step_count_to_goal(self):
        """
        Return the number of moves on a shortest path from ``self._init_pos``
        to the goal square.

        Performs a breadth-first search over the 4-connected grid, treating
        empty cells and cells whose object ``can_overlap()`` as walkable.
        Only moves between cells are counted; turning is not.

        Raises:
            ValueError: if no initial position has been recorded or the goal
                cannot be reached from it.
        """
        from collections import deque

        if self._init_pos is None:
            raise ValueError(self._init_pos)

        grid = self.grid
        start = tuple(self._init_pos)
        dist = {start: 0}
        queue = deque([start])
        while queue:
            x, y = queue.popleft()
            cell = grid.get(x, y)
            if cell is not None and cell.type == 'goal':
                return dist[(x, y)]
            for nxt in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
                nx, ny = nxt
                if nxt in dist:
                    continue
                if not (0 <= nx < grid.width and 0 <= ny < grid.height):
                    continue
                neighbour = grid.get(nx, ny)
                if neighbour is None or neighbour.can_overlap():
                    dist[nxt] = dist[(x, y)] + 1
                    queue.append(nxt)
        raise ValueError(self._init_pos)


class SimpleMazeEnvInv(SimpleEnv):
    """
    Inverted serpentine maze: the goal is in the top-left corner and the
    agent starts in the bottom-right corner by default.

    Layout for ``size=5``::

        XXXXXXX
        XG....X
        XXXXX.X
        X.....X
        X.XXXXX
        X....@X
        XXXXXXX

    The grid height is fixed at 7; only the width scales with ``size``.

    Args:
        size: interior width; the grid is ``(size + 2) x 7``.
        agent_start_pos: ``None`` to start at ``(width - 2, height - 2)``,
            ``"random"`` to let ``place_agent()`` choose a cell every time
            the grid is generated, or an explicit ``(x, y)`` position.
        agent_start_dir: initial ``agent_dir``; ``None`` leaves ``agent_dir``
            untouched.
    """

    def __init__(
        self,
        size=8,
        agent_start_pos=None,
        agent_start_dir=2,
    ):

        width = size + 2
        height = 7

        if agent_start_pos is None:
            self.agent_start_pos = (width - 2, height - 2)
        elif isinstance(agent_start_pos, str) and agent_start_pos == "random":
            # None tells _gen_grid to call place_agent() instead.
            self.agent_start_pos = None
        else:
            self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        # Agent position when the grid was last generated; read by the reward.
        self._init_pos = None

        super().__init__(
            width=width,
            height=height,
            max_steps=4 * size * size,
            # Set this to True for maximum speed
            see_through_walls=True
        )

    def _gen_grid(self, width, height):
        # Create an empty grid surrounded by walls
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # Row 2 leaves a gap at the right end, row 4 a gap at the left end.
        self.grid.horz_wall(1, 2, width - 3)
        self.grid.horz_wall(2, 4, width - 3)

        # Goal in the top-left corner.
        self.put_obj(Goal(), 1, 1)

        # Place the agent and remember where it started.
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self._init_pos = self.agent_start_pos
        else:
            self._init_pos = self.place_agent()

        if self.agent_start_dir is not None:
            self.agent_dir = self.agent_start_dir

        self.mission = "get to the green goal square"

    def _reward(self):
        """
        Compute the reward to be given upon success.

        Returns the optimal path length (in grid moves) from the starting
        cell to the goal divided by ``self.step_count``.
        """
        min_step_count_to_goal = self._compute_min_step_count_to_goal()
        return min_step_count_to_goal / self.step_count

    def _compute_min_step_count_to_goal(self):
        """
        Return the number of moves on a shortest path from ``self._init_pos``
        to the goal square.

        Performs a breadth-first search over the 4-connected grid, treating
        empty cells and cells whose object ``can_overlap()`` as walkable, so
        the result is correct for every width rather than only width 7.
        Only moves between cells are counted; turning is not.

        Raises:
            ValueError: if no initial position has been recorded or the goal
                cannot be reached from it.
        """
        from collections import deque

        if self._init_pos is None:
            raise ValueError(self._init_pos)

        grid = self.grid
        start = tuple(self._init_pos)
        dist = {start: 0}
        queue = deque([start])
        while queue:
            x, y = queue.popleft()
            cell = grid.get(x, y)
            if cell is not None and cell.type == 'goal':
                return dist[(x, y)]
            for nxt in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
                nx, ny = nxt
                if nxt in dist:
                    continue
                if not (0 <= nx < grid.width and 0 <= ny < grid.height):
                    continue
                neighbour = grid.get(nx, ny)
                if neighbour is None or neighbour.can_overlap():
                    dist[nxt] = dist[(x, y)] + 1
                    queue.append(nxt)
        raise ValueError(self._init_pos)


class ComplexMazeEnvInv(SimpleMazeEnv):
    """
    Inverted ComplexMaze: goal in the top-left corner, agent starting in the
    bottom-right corner by default.

    Layout for ``size=5``::

        XXXXXXX
        XG....X
        X.XXX.X
        X.....X
        XX.XXXX
        X....@X
        XXXXXXX

    The grid is ``(size + 2) x (size + 2)`` as in SimpleMazeEnv, and
    ``_reward`` is inherited from it. The optimal path length is computed
    here by a BFS over this grid's layout.
    """

    def __init__(
        self,
        size=8,
        agent_start_pos=None,
        agent_start_dir=0,
    ):
        if agent_start_pos is None:
            # Bottom-right interior cell (width - 2, height - 2). The
            # inherited default (1, 1) coincides with the goal square.
            agent_start_pos = (size, size)
        super().__init__(
            size=size,
            agent_start_pos=agent_start_pos,
            agent_start_dir=agent_start_dir,
        )

    def _gen_grid(self, width, height):
        # Create an empty grid surrounded by walls
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # Row 2: gaps at x=1 and x=width-2. Row 4: single gap at x=2.
        self.grid.horz_wall(2, 2, width - 4)
        self.grid.horz_wall(1, 4, 1)
        self.grid.horz_wall(3, 4, width - 4)

        # Place a goal square in the top-left corner
        self.put_obj(Goal(), 1, 1)

        # Place the agent and remember where it started; the reward needs
        # _init_pos and would otherwise fail on a None position.
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self._init_pos = self.agent_start_pos
        else:
            self._init_pos = self.place_agent()

        if self.agent_start_dir is not None:
            self.agent_dir = self.agent_start_dir

        self.mission = "get to the green goal square"

    def _compute_min_step_count_to_goal(self):
        """
        Return the number of moves on a shortest path from ``self._init_pos``
        to the goal square.

        Performs a breadth-first search over the 4-connected grid, treating
        empty cells and cells whose object ``can_overlap()`` as walkable.
        Only moves between cells are counted; turning is not.

        Raises:
            ValueError: if no initial position has been recorded or the goal
                cannot be reached from it.
        """
        from collections import deque

        if self._init_pos is None:
            raise ValueError(self._init_pos)

        grid = self.grid
        start = tuple(self._init_pos)
        dist = {start: 0}
        queue = deque([start])
        while queue:
            x, y = queue.popleft()
            cell = grid.get(x, y)
            if cell is not None and cell.type == 'goal':
                return dist[(x, y)]
            for nxt in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
                nx, ny = nxt
                if nxt in dist:
                    continue
                if not (0 <= nx < grid.width and 0 <= ny < grid.height):
                    continue
                neighbour = grid.get(nx, ny)
                if neighbour is None or neighbour.can_overlap():
                    dist[nxt] = dist[(x, y)] + 1
                    queue.append(nxt)
        raise ValueError(self._init_pos)


class SimpleMazeEnv3(SimpleMazeEnv):
    """Preset of SimpleMazeEnv with ``size=3``; other keyword arguments are forwarded."""

    def __init__(self, **kwargs):
        super(SimpleMazeEnv3, self).__init__(**kwargs, size=3)

class SimpleMazeEnv5(SimpleMazeEnv):
    """Preset of SimpleMazeEnv with ``size=5``; other keyword arguments are forwarded."""

    def __init__(self, **kwargs):
        super(SimpleMazeEnv5, self).__init__(**kwargs, size=5)

class SimpleMazeEnvInv5(SimpleMazeEnvInv):
    """Preset of SimpleMazeEnvInv with ``size=5``; other keyword arguments are forwarded."""

    def __init__(self, **kwargs):
        super(SimpleMazeEnvInv5, self).__init__(**kwargs, size=5)

class ComplexMazeEnv5(ComplexMazeEnv):
    """Preset of ComplexMazeEnv with ``size=5``; other keyword arguments are forwarded."""

    def __init__(self, **kwargs):
        super(ComplexMazeEnv5, self).__init__(**kwargs, size=5)

class ComplexMazeEnvInv5(ComplexMazeEnvInv):
    """Preset of ComplexMazeEnvInv with ``size=5``; other keyword arguments are forwarded."""

    def __init__(self, **kwargs):
        super(ComplexMazeEnvInv5, self).__init__(**kwargs, size=5)

class SimpleMazeEnv10(SimpleMazeEnv):
    """Preset of SimpleMazeEnv with ``size=10``; other keyword arguments are forwarded."""

    def __init__(self, **kwargs):
        super(SimpleMazeEnv10, self).__init__(**kwargs, size=10)

class SimpleMazeEnvInv10(SimpleMazeEnvInv):
    """
    Preset of SimpleMazeEnvInv (the inverted maze) with ``size=10``.

    Other keyword arguments are forwarded unchanged. This must derive from
    SimpleMazeEnvInv, not SimpleMazeEnv, so that the id registered for it
    (``MiniGrid-SimpleMazeInv-10-v0``) builds the inverted maze.
    """

    def __init__(self, **kwargs):
        super().__init__(size=10, **kwargs)

class ComplexMazeEnv10(ComplexMazeEnv):
    """Preset of ComplexMazeEnv with ``size=10``; other keyword arguments are forwarded."""

    def __init__(self, **kwargs):
        super(ComplexMazeEnv10, self).__init__(**kwargs, size=10)

class ComplexMazeEnvInv10(ComplexMazeEnvInv):
    """Preset of ComplexMazeEnvInv with ``size=10``; other keyword arguments are forwarded."""

    def __init__(self, **kwargs):
        super(ComplexMazeEnvInv10, self).__init__(**kwargs, size=10)

# Register every maze preset with gym, in a fixed order:
# (environment id, "module:ClassName" entry point).
for _env_id, _entry_point in (
    ('MiniGrid-SimpleMaze-3-v0', 'custom_minigrid.envs:SimpleMazeEnv3'),
    ('MiniGrid-SimpleMaze-5-v0', 'custom_minigrid.envs:SimpleMazeEnv5'),
    ('MiniGrid-SimpleMazeInv-5-v0', 'custom_minigrid.envs:SimpleMazeEnvInv5'),
    ('MiniGrid-ComplexMaze-5-v0', 'custom_minigrid.envs:ComplexMazeEnv5'),
    ('MiniGrid-ComplexMazeInv-5-v0', 'custom_minigrid.envs:ComplexMazeEnvInv5'),
    ('MiniGrid-SimpleMaze-10-v0', 'custom_minigrid.envs:SimpleMazeEnv10'),
    ('MiniGrid-SimpleMazeInv-10-v0', 'custom_minigrid.envs:SimpleMazeEnvInv10'),
    ('MiniGrid-ComplexMaze-10-v0', 'custom_minigrid.envs:ComplexMazeEnv10'),
    ('MiniGrid-ComplexMazeInv-10-v0', 'custom_minigrid.envs:ComplexMazeEnvInv10'),
):
    register(id=_env_id, entry_point=_entry_point)
# Drop the loop variables so they do not linger in the module namespace.
del _env_id, _entry_point