# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import gym_minigrid.minigrid as minigrid
import numpy as np
from . import multigrid

# from . import register

import envs.registration as register


class MazeEnv(multigrid.MultiGridEnv):
    """Single-agent maze environment specified via a bit map.

    The full grid is ``size`` x ``size`` cells and is always enclosed by an
    outer wall. The interior is described by ``bit_map``, a 2-D array of shape
    ``(size - 2, size - 2)`` indexed ``[row, col]``; every truthy entry
    ``bit_map[y, x]`` becomes a wall at grid position ``(x + 1, y + 1)``.

    ``start_pos`` and ``goal_pos`` are ``(col, row)`` pairs in full-grid
    coordinates, i.e. column/row 0 is the outer wall.
    """

    def __init__(
        self,
        agent_view_size=5,
        minigrid_mode=True,
        max_steps=None,
        bit_map=None,
        start_pos=None,
        goal_pos=None,
        size=15,
    ):
        """Store the maze layout and initialise the underlying ``MultiGridEnv``.

        Args:
          agent_view_size: Forwarded to ``MultiGridEnv``.
          minigrid_mode: Forwarded to ``MultiGridEnv``.
          max_steps: Episode step limit; ``None`` means ``2 * size * size``.
          bit_map: Interior wall layout of shape ``(size - 2, size - 2)``.
            If omitted, or if its shape does not match ``size`` (a warning is
            issued in that case), the built-in 13x13 layout is used instead.
            That built-in layout only fits ``size=15``.
          start_pos: Agent start as ``(col, row)``; defaults to ``(7, 1)``.
          goal_pos: Goal as ``(col, row)``; defaults to ``(7, 13)``.
          size: Side length of the full grid, outer wall included.
        """
        default_agent_start_x = 7
        default_agent_start_y = 1
        default_goal_start_x = 7
        default_goal_start_y = 13
        self.start_pos = (
            np.array([default_agent_start_x, default_agent_start_y])
            if start_pos is None
            else start_pos
        )
        self.goal_pos = (
            (default_goal_start_x, default_goal_start_y) if goal_pos is None else goal_pos
        )

        if max_steps is None:
            max_steps = 2 * size * size

        if bit_map is not None:
            bit_map = np.array(bit_map)
            expected_shape = (size - 2, size - 2)
            if bit_map.shape != expected_shape:
                # Report through ``warnings`` rather than ``print`` so library
                # users can filter, capture, or escalate the message.
                warnings.warn(
                    f"Bit map shape {bit_map.shape} does not match size {size} "
                    f"(expected {expected_shape}). Using default maze.",
                    stacklevel=2,
                )
                bit_map = None

        if bit_map is None:
            # Built-in 13x13 interior layout (fits only ``size=15``).
            self.bit_map = np.array(
                [
                    [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0],
                    [0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0],
                    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1],
                    [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
                    [1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
                    [0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0],
                    [0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1],
                    [0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0],
                    [1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0],
                    [1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
                    [1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0],
                    [0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0],
                ]
            )
        else:
            self.bit_map = bit_map

        super().__init__(
            n_agents=1,
            grid_size=size,
            agent_view_size=agent_view_size,
            max_steps=max_steps,
            see_through_walls=True,  # Set this to True for maximum speed
            minigrid_mode=minigrid_mode,
        )

    def _gen_grid(self, width, height):
        """Populate ``self.grid``: outer walls, goal, agent, then interior walls."""
        # Create an empty grid
        self.grid = multigrid.Grid(width, height)

        # Generate the surrounding walls
        self.grid.wall_rect(0, 0, width, height)

        # Goal
        self.put_obj(minigrid.Goal(), self.goal_pos[0], self.goal_pos[1])

        # Agent
        self.place_agent_at_pos(0, self.start_pos)

        # Interior walls. ``bit_map`` is indexed [row, col], so each axis is
        # iterated over its own extent: rows map to y and columns map to x.
        # NOTE(review): walls are placed after the goal/agent; presumably a
        # wall entry on either cell would overwrite it — keep layouts clear.
        n_rows, n_cols = self.bit_map.shape
        for y in range(n_rows):
            for x in range(n_cols):
                if self.bit_map[y, x]:
                    # Add an offset of 1 for the outer walls
                    self.put_obj(minigrid.Wall(), x + 1, y + 1)


class HorizontalMazeEnv(MazeEnv):
    """15x15 maze crossed from the left edge to the right edge.

    The agent starts at (col 1, row 7) and the goal sits at (col 13, row 5),
    both in full-grid coordinates. A short but non-optimal path is 80 moves.
    """

    def __init__(self):
        # 13x13 interior, indexed [row, col]; 1 marks a wall.
        layout = np.array(
            [
                [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0],
                [0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1],
                [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
                [0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1],
                [0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0],
                [1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0],
                [0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0],
                [0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0],
                [0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0],
                [0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0],
                [0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0],
                [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
                [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
            ]
        )
        agent_start = np.array([1, 7])  # (col, row)
        goal = np.array([13, 5])  # (col, row)
        super().__init__(
            bit_map=layout,
            start_pos=agent_start,
            goal_pos=goal,
            size=15,
        )


class Maze3Env(MazeEnv):
    """15x15 maze from the top edge to the right edge.

    The agent starts at (col 4, row 1) and the goal sits at (col 13, row 7),
    in full-grid coordinates. The docstring this replaces claimed a short
    but non-optimal path of 80 moves. That is the same figure as
    HorizontalMazeEnv and is unverified for this layout.
    """

    def __init__(self):
        # 13x13 interior, indexed [row, col]; 1 marks a wall.
        layout = np.array(
            [
                [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
                [0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0],
                [0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1],
                [1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0],
                [0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0],
                [0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0],
                [0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1],
                [0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0],
                [0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0],
                [0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0],
                [0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0],
                [0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
            ]
        )
        agent_start = np.array([4, 1])  # (col, row)
        goal = np.array([13, 7])  # (col, row)
        super().__init__(
            bit_map=layout,
            start_pos=agent_start,
            goal_pos=goal,
            size=15,
        )


class SmallCorridorEnv(MazeEnv):
    """A shorter backtracking env.

    An open ring surrounds two banks of dead-end corridors, one above and
    one below a solid horizontal divider. The goal is at the closed end of
    one randomly chosen corridor, right next to the divider. The agent
    starts at (col 1, row 7) on the left end of the divider row.
    """

    def __init__(self):
        # Goal is sampled row first, then column, from NumPy's global RNG.
        # NOTE(review): the draw happens once per instance via ``np.random``,
        # not the env's own seeded RNG, so env seeding does not control it —
        # confirm that is intended.
        goal_row = np.random.choice([6, 8])
        goal_col = np.random.choice([3, 5, 7, 9, 11])

        # 13x13 interior: odd columns 1..11 of rows 1..11 are corridor walls,
        # and row 6 is a solid divider between the upper and lower banks.
        layout = np.zeros((13, 13), dtype=int)
        layout[1:12, 1:12:2] = 1
        layout[6, 1:12] = 1

        super().__init__(
            goal_pos=np.array([goal_col, goal_row]),
            start_pos=np.array([1, 7]),
            bit_map=layout,
            size=15,
        )


class LargeCorridorEnv(MazeEnv):
    """A long backtracking env.

    This is the 21x21 counterpart of SmallCorridorEnv. An open ring surrounds
    two banks of dead-end corridors separated by a horizontal divider. The
    goal is at the closed end of one randomly chosen corridor next to the
    divider. The agent starts at (col 1, row 10).
    """

    def __init__(self):
        # Goal is sampled row first, then column, from NumPy's global RNG.
        # NOTE(review): the draw happens once per instance via ``np.random``,
        # not the env's own seeded RNG, so env seeding does not control it —
        # confirm that is intended.
        goal_row = np.random.choice([9, 11])
        goal_col = np.random.choice([3, 5, 7, 9, 11, 13, 15, 17])

        # 19x19 interior: odd columns 1..17 of rows 1..17 are corridor walls,
        # and row 9 is a solid divider between the upper and lower banks.
        layout = np.zeros((19, 19), dtype=int)
        layout[1:18, 1:18:2] = 1
        layout[9, 1:18] = 1

        super().__init__(
            goal_pos=np.array([goal_col, goal_row]),
            start_pos=np.array([1, 10]),
            bit_map=layout,
            size=21,
        )


class LabyrinthEnv(MazeEnv):
    """15x15 spiral labyrinth leading to the centre.

    The agent starts in the bottom-left corner at (col 1, row 13) and the goal
    is the centre cell (col 7, row 7). A short but non-optimal path is
    118 moves.
    """

    def __init__(self):
        # 13x13 interior, indexed [row, col]; 1 marks a wall.
        layout = np.array(
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
                [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
                [0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0],
                [0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0],
                [0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0],
                [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
                [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
                [0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0],
                [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0],
                [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0],
                [1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0],
                [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
            ]
        )
        agent_start = np.array([1, 13])  # (col, row)
        goal = np.array([7, 7])  # (col, row)
        super().__init__(
            bit_map=layout,
            start_pos=agent_start,
            goal_pos=goal,
            size=15,
        )


class Labyrinth2Env(MazeEnv):
    """Second 15x15 spiral labyrinth leading to the centre.

    The agent starts in the top-left corner at (col 1, row 1) and the goal is
    the centre cell (col 7, row 7). The docstring this replaces claimed a
    short but non-optimal path of 118 moves. That is the same figure as
    LabyrinthEnv and is unverified for this layout.
    """

    def __init__(self):
        # 13x13 interior, indexed [row, col]; 1 marks a wall.
        layout = np.array(
            [
                [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
                [0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0],
                [0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0],
                [0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0],
                [0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
                [1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
                [0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0],
                [0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0],
                [0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0],
                [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
                [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            ]
        )
        agent_start = np.array([1, 1])  # (col, row)
        goal = np.array([7, 7])  # (col, row)
        super().__init__(
            bit_map=layout,
            start_pos=agent_start,
            goal_pos=goal,
            size=15,
        )


class NineRoomsEnv(MazeEnv):
    """3x3 grid of rooms on a 15x15 board. Can be completed in 27 moves.

    Interior column 4 and column 8, and interior row 4 and row 8, are walls
    with doorways. The agent starts at (col 2, row 2) and the goal sits at
    (col 12, row 12), in full-grid coordinates.
    """

    def __init__(self):
        # 13x13 interior, indexed [row, col]; 1 marks a wall.
        layout = np.array(
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
                [1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
                [1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1],
                [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
            ]
        )
        agent_start = np.array([2, 2])  # (col, row)
        goal = np.array([12, 12])  # (col, row)
        super().__init__(
            bit_map=layout,
            start_pos=agent_start,
            goal_pos=goal,
            size=15,
        )


class NineRoomsFewerDoorsEnv(MazeEnv):
    """3x3 grid of rooms with fewer doorways than NineRoomsEnv.

    Start (col 2, row 2) and goal (col 12, row 12) match NineRoomsEnv. The
    docstring this replaces said "Can be completed in 27 moves", which is the
    same figure as NineRoomsEnv. With fewer doors that may be stale; verify.
    """

    def __init__(self):
        # 13x13 interior, indexed [row, col]; 1 marks a wall.
        layout = np.array(
            [
                [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
                [1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
                [1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1],
                [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
            ]
        )
        agent_start = np.array([2, 2])  # (col, row)
        goal = np.array([12, 12])  # (col, row)
        super().__init__(
            bit_map=layout,
            start_pos=agent_start,
            goal_pos=goal,
            size=15,
        )


class SixteenRoomsEnv(MazeEnv):
    """4x4 grid of rooms on a 15x15 board.

    Interior columns 3, 6 and 9, and interior rows 3, 6 and 9, are walls with
    doorways. The agent starts at (col 2, row 2) and the goal sits at
    (col 12, row 12).

    NOTE(review): the docstring this replaces said "Can be completed in
    16 moves". Start and goal are 20 cells apart (Manhattan distance), so
    that figure looks too low; verify.
    """

    def __init__(self):
        # 13x13 interior, indexed [row, col]; 1 marks a wall.
        layout = np.array(
            [
                [0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0],
                [1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0],
                [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0],
                [1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1],
                [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0],
                [0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1],
                [0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0],
            ]
        )
        agent_start = np.array([2, 2])  # (col, row)
        goal = np.array([12, 12])  # (col, row)
        super().__init__(
            bit_map=layout,
            start_pos=agent_start,
            goal_pos=goal,
            size=15,
        )


class SixteenRoomsFewerDoorsEnv(MazeEnv):
    """4x4 grid of rooms with fewer doorways than SixteenRoomsEnv.

    Start (col 2, row 2) and goal (col 12, row 12) match SixteenRoomsEnv.

    NOTE(review): the docstring this replaces said "Can be completed in
    16 moves". Start and goal are 20 cells apart (Manhattan distance), so
    that figure looks too low; verify.
    """

    def __init__(self):
        # 13x13 interior, indexed [row, col]; 1 marks a wall.
        layout = np.array(
            [
                [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0],
                [1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1],
                [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0],
                [1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1],
                [0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1],
                [0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0],
            ]
        )
        agent_start = np.array([2, 2])  # (col, row)
        goal = np.array([12, 12])  # (col, row)
        super().__init__(
            bit_map=layout,
            start_pos=agent_start,
            goal_pos=goal,
            size=15,
        )


class MiniMazeEnv(MazeEnv):
    """Tiny 6x6 maze (4x4 interior) intended for debugging.

    The agent starts at (col 1, row 1) and the goal sits at (col 1, row 3).
    """

    def __init__(self):
        # 4x4 interior, indexed [row, col]; 1 marks a wall.
        layout = np.array(
            [
                [0, 0, 0, 0],
                [1, 1, 1, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 1],
            ]
        )
        super().__init__(
            bit_map=layout,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([1, 3]),
            size=6,
        )


class MediumMazeEnv(MazeEnv):
    """A 10x10 maze (8x8 interior).

    The agent starts at (col 5, row 1) and the goal sits at (col 3, row 8).
    """

    def __init__(self):
        # 8x8 interior, indexed [row, col]; 1 marks a wall.
        layout = np.array(
            [
                [0, 1, 0, 0, 0, 1, 1, 0],
                [0, 1, 0, 1, 0, 1, 0, 0],
                [0, 1, 0, 1, 1, 1, 1, 0],
                [0, 0, 0, 0, 0, 1, 0, 0],
                [1, 1, 1, 1, 0, 1, 0, 1],
                [0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 1, 1, 1, 1, 1, 0],
                [0, 0, 0, 1, 0, 0, 0, 0],
            ]
        )
        super().__init__(
            bit_map=layout,
            start_pos=np.array([5, 1]),
            goal_pos=np.array([3, 8]),
            size=10,
        )


class EmptyField9x9Env(MazeEnv):
    """9x9 grid with no interior walls: the easiest possible layout.

    The agent starts in the top-left interior corner (1, 1) and the goal is
    in the bottom-right interior corner (7, 7).
    """

    def __init__(self):
        grid_side = 9
        super().__init__(
            size=grid_side,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([7, 7]),
            bit_map=np.zeros((grid_side - 2, grid_side - 2), dtype=int),
        )


class SingleWallDoor9x9Env(MazeEnv):
    """One vertical wall with a single doorway (very easy detour).

    Interior column 3 is a wall except for a two-cell gap at rows 2-3.
    Start (1, 1), goal (7, 7).
    """

    def __init__(self):
        layout = np.zeros((7, 7), dtype=int)
        layout[:, 3] = 1  # full-height vertical wall ...
        layout[2:4, 3] = 0  # ... opened at rows 2-3 (the doorway)
        super().__init__(
            size=9,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([7, 7]),
            bit_map=layout,
        )


class EasyZigZag9x9Env(MazeEnv):
    """Sparse pillars/short walls that induce 1–2 gentle turns.

    The layout has a 3-cell vertical stub in interior column 1 (rows 1-3)
    and a single pillar at interior (row 4, col 5). Start (1, 1), goal (7, 7).
    """

    def __init__(self):
        layout = np.zeros((7, 7), dtype=int)
        layout[1:4, 1] = 1  # short vertical stub
        layout[4, 5] = 1  # lone pillar
        super().__init__(
            size=9,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([7, 7]),
            bit_map=layout,
        )


# === Moderate Envs ===


class TwoDoorBarrier9x9Env(MazeEnv):
    """Long vertical barrier with two doors; small detour near end.

    Interior column 3 is a wall except at rows 2 and 5 (the doors). An extra
    pillar sits at interior (row 4, col 5). Start (1, 1), goal (7, 7).
    """

    def __init__(self):
        layout = np.zeros((7, 7), dtype=int)
        layout[:, 3] = 1  # vertical barrier
        layout[[2, 5], 3] = 0  # door #1 and door #2
        layout[4, 5] = 1  # pillar forcing a small detour
        super().__init__(
            size=9,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([7, 7]),
            bit_map=layout,
        )


class LShapedCorridor9x9Env(MazeEnv):
    """Horizontal wall with doorway + vertical segment; L-shaped detour.

    Interior row 2 is a wall except at column 3. Interior column 5 has wall
    cells at rows 3, 5 and 6, leaving row 4 open. Start (1, 1), goal (7, 7).
    """

    def __init__(self):
        layout = np.zeros((7, 7), dtype=int)
        layout[2, :] = 1  # horizontal wall ...
        layout[2, 3] = 0  # ... with a doorway at col=3
        layout[[3, 5, 6], 5] = 1  # broken vertical segment
        super().__init__(
            size=9,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([7, 7]),
            bit_map=layout,
        )


class SnakingPassage9x9Env(MazeEnv):
    """Two barriers create a gentle 'snake' path to the goal.

    The walls are:
    - interior column 2, rows 0-2;
    - interior column 4, rows 1-4;
    - interior row 4, columns 1-4.
    Row 3 is open at column 2. Start (1, 1), goal (7, 7).
    """

    def __init__(self):
        layout = np.zeros((7, 7), dtype=int)
        layout[0:3, 2] = 1  # first barrier
        layout[1:5, 4] = 1  # second barrier
        layout[4, 1:5] = 1  # floor joining into the second barrier
        super().__init__(
            size=9,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([7, 7]),
            bit_map=layout,
        )


# === 11x11 SIMPLE MAZE VARIANTS ===


class EmptyField11x11Env(MazeEnv):
    """11x11 grid with no interior walls: the easiest possible layout.

    The agent starts in the top-left interior corner (1, 1) and the goal is
    in the bottom-right interior corner (9, 9).
    """

    def __init__(self):
        grid_side = 11
        super().__init__(
            size=grid_side,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([9, 9]),
            bit_map=np.zeros((grid_side - 2, grid_side - 2), dtype=int),
        )


class SingleWallDoor11x11Env(MazeEnv):
    """One vertical wall with a single doorway (very easy detour).

    Interior column 4 is a wall except at row 2. Start (1, 1), goal (9, 9).
    """

    def __init__(self):
        layout = np.zeros((9, 9), dtype=int)
        layout[:, 4] = 1  # full-height vertical wall ...
        layout[2, 4] = 0  # ... with the door at row 2
        super().__init__(
            size=11,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([9, 9]),
            bit_map=layout,
        )


class EasyZigZag11x11Env(MazeEnv):
    """Sparse pillars/short walls that induce 1–2 gentle turns.

    The layout has a vertical stub in interior column 1 (rows 1-3) and a
    two-cell stub in column 5 (rows 4-5). Start (1, 1), goal (9, 9).
    """

    def __init__(self):
        layout = np.zeros((9, 9), dtype=int)
        layout[1:4, 1] = 1  # first stub
        layout[4:6, 5] = 1  # second stub
        super().__init__(
            size=11,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([9, 9]),
            bit_map=layout,
        )


class TwoDoorBarrier11x11Env(MazeEnv):
    """Long vertical barrier with two doors.

    Interior column 4 is a wall except at rows 2 and 5. Start (1, 1),
    goal (9, 9).
    """

    def __init__(self):
        layout = np.zeros((9, 9), dtype=int)
        layout[:, 4] = 1  # vertical barrier
        layout[[2, 5], 4] = 0  # door 1 and door 2
        super().__init__(
            size=11,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([9, 9]),
            bit_map=layout,
        )


class LShapedCorridor11x11Env(MazeEnv):
    """Horizontal wall with doorway + vertical segment.

    Interior row 2 is a wall except at column 3. Interior column 5 is a wall
    for rows 3-6. Start (1, 1), goal (9, 9).
    """

    def __init__(self):
        layout = np.zeros((9, 9), dtype=int)
        layout[2, :] = 1  # horizontal wall ...
        layout[2, 3] = 0  # ... with a doorway at col=3
        layout[3:7, 5] = 1  # vertical segment below it
        super().__init__(
            size=11,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([9, 9]),
            bit_map=layout,
        )


class SnakingPassage11x11Env(MazeEnv):
    """Two barriers create a gentle 'snake' path to the goal.

    The walls are:
    - interior column 2, rows 0-2;
    - interior column 4, rows 1-4;
    - interior row 4, columns 1-4;
    - interior column 1, rows 5-6.
    Start (1, 1), goal (9, 9).
    """

    def __init__(self):
        layout = np.zeros((9, 9), dtype=int)
        layout[0:3, 2] = 1  # first barrier
        layout[1:5, 4] = 1  # second barrier
        layout[4, 1:5] = 1  # floor joining into the second barrier
        layout[5:7, 1] = 1  # short drop below the floor
        super().__init__(
            size=11,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([9, 9]),
            bit_map=layout,
        )


# === 11x11 LABYRINTH VARIANTS ===


class Labyrinth11x11Env(MazeEnv):
    """Dense 11x11 labyrinth with a main corridor.

    Start (1, 1), goal (9, 9), in full-grid (col, row) coordinates.
    """

    def __init__(self):
        # 9x9 interior, indexed [row, col]; 1 marks a wall.
        layout = np.array(
            [
                [0, 1, 0, 1, 0, 1, 0, 1, 0],
                [0, 0, 0, 1, 0, 0, 0, 1, 0],
                [1, 1, 0, 1, 0, 1, 0, 1, 0],
                [0, 0, 0, 0, 0, 1, 0, 0, 0],
                [0, 1, 1, 1, 0, 1, 0, 1, 1],
                [0, 0, 0, 1, 0, 0, 0, 0, 0],
                [0, 1, 0, 1, 1, 1, 1, 1, 0],
                [0, 1, 0, 0, 0, 0, 0, 1, 0],
                [0, 1, 1, 1, 0, 1, 0, 1, 0],
            ]
        )
        super().__init__(
            size=11,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([9, 9]),
            bit_map=layout,
        )


class LabyrinthDense11x11Env(MazeEnv):
    """Denser 11x11 labyrinth with multiple detours.

    Start (1, 1), goal (9, 9), in full-grid (col, row) coordinates.
    """

    def __init__(self):
        # 9x9 interior, indexed [row, col]; 1 marks a wall.
        layout = np.array(
            [
                [0, 0, 1, 0, 1, 0, 1, 0, 1],
                [0, 0, 0, 1, 0, 1, 0, 0, 0],
                [1, 0, 1, 1, 0, 1, 1, 0, 1],
                [0, 0, 0, 0, 0, 0, 1, 0, 0],
                [1, 1, 1, 0, 1, 0, 1, 1, 1],
                [0, 0, 0, 0, 0, 0, 0, 0, 0],
                [1, 0, 1, 1, 1, 1, 1, 0, 1],
                [0, 0, 0, 0, 0, 0, 0, 0, 0],
                [1, 0, 1, 0, 1, 0, 1, 0, 0],
            ]
        )
        super().__init__(
            size=11,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([9, 9]),
            bit_map=layout,
        )


class CrossRooms11x11Env(MazeEnv):
    """Rooms-with-cross-walls + doorways; labyrinth-like structured layout.

    Every even interior row has wall stubs at columns 2, 4 and 6. Rows 2, 4
    and 6 additionally close columns 0 and 8. Odd rows are fully open
    corridors. Start (1, 1), goal (9, 9).
    """

    def __init__(self):
        layout = np.zeros((9, 9), dtype=int)
        layout[::2, 2:7:2] = 1  # stubs at cols 2/4/6 on every even row
        layout[2:7:2, ::8] = 1  # edge stubs at cols 0/8 on rows 2/4/6
        super().__init__(
            size=11,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([9, 9]),
            bit_map=layout,
        )


class Maze11x11Env(MazeEnv):
    """Sparser 'maze' variant: fewer walls, moderate difficulty.

    Start (1, 1), goal (9, 9), in full-grid (col, row) coordinates.
    """

    def __init__(self):
        # 9x9 interior, indexed [row, col]; 1 marks a wall.
        layout = np.array(
            [
                [0, 1, 0, 0, 0, 1, 0, 0, 0],
                [0, 1, 0, 1, 0, 1, 0, 1, 0],
                [0, 0, 0, 1, 0, 0, 0, 1, 0],
                [1, 1, 0, 1, 1, 1, 0, 1, 0],
                [0, 0, 0, 0, 0, 1, 0, 0, 0],
                [0, 1, 1, 1, 0, 1, 0, 1, 0],
                [0, 0, 0, 0, 0, 0, 0, 1, 0],
                [0, 1, 0, 1, 0, 1, 0, 1, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0],
            ]
        )
        super().__init__(
            size=11,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([9, 9]),
            bit_map=layout,
        )


# === Registration ===


# Entry points are registered as "<module_path>:<ClassName>". Prefer the
# module name recorded by the loader; fall back to ``__name__`` when the loader
# exposes neither attribute (e.g. ``__loader__`` is None). Without the
# fallback, ``module_path`` would stay unbound and every registration below
# would raise NameError.
if hasattr(__loader__, "name"):
    module_path = __loader__.name
elif hasattr(__loader__, "fullname"):
    module_path = __loader__.fullname
else:
    module_path = __name__


# Register the 9x9 starter envs and the classic 15x15/other mazes.
# Each pair is (environment id, class name in this module); registration
# order matches the listing below.
for _env_id, _class_name in (
    ("MultiGrid-EmptyField9x9-v0", "EmptyField9x9Env"),
    ("MultiGrid-SingleWallDoor9x9-v0", "SingleWallDoor9x9Env"),
    ("MultiGrid-EasyZigZag9x9-v0", "EasyZigZag9x9Env"),
    ("MultiGrid-TwoDoorBarrier9x9-v0", "TwoDoorBarrier9x9Env"),
    ("MultiGrid-LShapedCorridor9x9-v0", "LShapedCorridor9x9Env"),
    ("MultiGrid-SnakingPassage9x9-v0", "SnakingPassage9x9Env"),
    ("MultiGrid-Maze-v0", "MazeEnv"),
    ("MultiGrid-MiniMaze-v0", "MiniMazeEnv"),
    ("MultiGrid-MediumMaze-v0", "MediumMazeEnv"),
    ("MultiGrid-Maze2-v0", "HorizontalMazeEnv"),
    ("MultiGrid-Maze3-v0", "Maze3Env"),
    ("MultiGrid-SmallCorridor-v0", "SmallCorridorEnv"),
    ("MultiGrid-LargeCorridor-v0", "LargeCorridorEnv"),
    ("MultiGrid-Labyrinth-v0", "LabyrinthEnv"),
    ("MultiGrid-Labyrinth2-v0", "Labyrinth2Env"),
    ("MultiGrid-SixteenRooms-v0", "SixteenRoomsEnv"),
    ("MultiGrid-SixteenRoomsFewerDoors-v0", "SixteenRoomsFewerDoorsEnv"),
    ("MultiGrid-NineRooms-v0", "NineRoomsEnv"),
    ("MultiGrid-NineRoomsFewerDoors-v0", "NineRoomsFewerDoorsEnv"),
):
    register.register(id=_env_id, entry_point=f"{module_path}:{_class_name}")
# Drop the loop variables so they do not linger as module attributes.
del _env_id, _class_name


# === 9x9 test variants (7x7 interior bit maps) ===


class Maze9x9Env(MazeEnv):
    """Size-9 maze (outer wall included) on a 7x7 interior bit map.

    Start (1, 1), goal (7, 7), in full-grid (col, row) coordinates.
    """

    def __init__(self):
        # 7x7 interior, indexed [row, col]; 1 marks a wall.
        layout = np.array(
            [
                [0, 1, 0, 0, 0, 1, 0],
                [0, 1, 0, 1, 0, 1, 0],
                [0, 0, 0, 1, 0, 0, 0],
                [1, 1, 0, 1, 1, 1, 0],
                [0, 0, 0, 0, 0, 1, 0],
                [0, 1, 1, 1, 0, 1, 0],
                [0, 0, 0, 0, 0, 0, 0],
            ]
        )
        super().__init__(
            size=9,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([7, 7]),
            bit_map=layout,
        )


class Labyrinth9x9Env(MazeEnv):
    """Size-9 labyrinth layout; denser walls but solvable.

    Start (1, 1), goal (7, 7), in full-grid (col, row) coordinates.
    """

    def __init__(self):
        # 7x7 interior, indexed [row, col]; 1 marks a wall.
        layout = np.array(
            [
                [0, 1, 0, 1, 0, 1, 0],
                [0, 1, 0, 1, 0, 1, 0],
                [0, 0, 0, 1, 0, 0, 0],
                [1, 1, 0, 1, 1, 1, 0],
                [0, 0, 0, 0, 0, 1, 0],
                [0, 1, 0, 1, 0, 1, 0],
                [0, 0, 0, 0, 0, 0, 0],
            ]
        )
        super().__init__(
            size=9,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([7, 7]),
            bit_map=layout,
        )


class SixteenRooms9x9Env(MazeEnv):
    """Size-9 "rooms"-style layout on a 7x7 interior.

    Every even interior row (0, 2, 4, 6) has wall stubs at columns 2 and 4.
    Rows 2 and 4 additionally close columns 0 and 6. Odd rows are fully open
    and act as doorways between the cells. Start (1, 1), goal (7, 7).
    """

    def __init__(self):
        layout = np.zeros((7, 7), dtype=int)
        layout[::2, 2:5:2] = 1  # stubs at cols 2/4 on every even row
        layout[2:5:2, ::6] = 1  # edge stubs at cols 0/6 on rows 2/4
        super().__init__(
            size=9,
            agent_view_size=3,
            start_pos=np.array([1, 1]),
            goal_pos=np.array([7, 7]),
            bit_map=layout,
        )


# === Registrations for the 9x9 tests ===
# (environment id, class name in this module), registered in this order.
for _env_id, _class_name in (
    ("MultiGrid-Maze-9x9-v0", "Maze9x9Env"),
    ("MultiGrid-Labyrinth-9x9-v0", "Labyrinth9x9Env"),
    ("MultiGrid-SixteenRooms-9x9-v0", "SixteenRooms9x9Env"),
):
    register.register(id=_env_id, entry_point=f"{module_path}:{_class_name}")
# Drop the loop variables so they do not linger as module attributes.
del _env_id, _class_name

# === Registrations for 11x11 Environments ===

# (environment id, class name in this module), registered in this order.
# The ids are public and kept exactly as published, including the mixed
# "...11x11-v0" / "...-11x11-v0" spelling.
for _env_id, _class_name in (
    ("MultiGrid-EmptyField11x11-v0", "EmptyField11x11Env"),
    ("MultiGrid-SingleWallDoor11x11-v0", "SingleWallDoor11x11Env"),
    ("MultiGrid-EasyZigZag11x11-v0", "EasyZigZag11x11Env"),
    ("MultiGrid-TwoDoorBarrier11x11-v0", "TwoDoorBarrier11x11Env"),
    ("MultiGrid-LShapedCorridor11x11-v0", "LShapedCorridor11x11Env"),
    ("MultiGrid-SnakingPassage11x11-v0", "SnakingPassage11x11Env"),
    ("MultiGrid-Labyrinth-11x11-v0", "Labyrinth11x11Env"),
    ("MultiGrid-LabyrinthDense-11x11-v0", "LabyrinthDense11x11Env"),
    ("MultiGrid-CrossRooms-11x11-v0", "CrossRooms11x11Env"),
    ("MultiGrid-Maze-11x11-v0", "Maze11x11Env"),
):
    register.register(id=_env_id, entry_point=f"{module_path}:{_class_name}")
# Drop the loop variables so they do not linger as module attributes.
del _env_id, _class_name
