from typing import Tuple

import numpy as np
from gymnasium import spaces
from miniworld.entity import Box, ImageFrame
from miniworld.miniworld import MiniWorldEnv, Room, DEFAULT_WALL_HEIGHT

from envs import ConditionalActionEnv


class NamedRoom(Room):
    """
    A ``Room`` that additionally carries a ``name`` label.

    Every geometry/texture argument is forwarded unchanged to ``Room``; the
    name is only stored on the instance so rooms can be told apart by label.
    """

    def __init__(
        self,
        outline,
        wall_height=DEFAULT_WALL_HEIGHT,
        floor_tex="floor_tiles_bw",
        wall_tex="concrete",
        ceil_tex="concrete_tiles",
        no_ceiling=False,
        name=None,
    ):
        # Forward positionally, in the same order as the parent signature.
        room_args = (outline, wall_height, floor_tex, wall_tex, ceil_tex, no_ceiling)
        super().__init__(*room_args)

        # Label of this room (``None`` when the caller did not provide one).
        self.name = name


class CustomMiniWorldEnv(MiniWorldEnv):
    """
    MiniWorld environment whose rooms and connecting passages are
    ``NamedRoom`` instances, so each of them can be given a ``name``.
    """

    def add_room(self, **kwargs):
        """
        Create a new named room and register it in ``self.rooms``.

        All keyword arguments are handed to ``NamedRoom``.
        Returns the room that was created.
        """

        assert len(self.wall_segs) == 0, "cannot add rooms after static data is generated"

        new_room = NamedRoom(**kwargs)
        self.rooms.append(new_room)
        return new_room

    def connect_rooms(
            self,
            room_a,
            room_b,
            min_x=None,
            max_x=None,
            min_z=None,
            max_z=None,
            max_y=None,
            name=None,
    ):
        """
        Connect two rooms along facing edges.

        A portal is cut into the facing edge of each room using the given
        extents. Unless the two portals coincide, a connecting ``NamedRoom``
        called ``name`` is created between them (textures, ceiling setting and
        default height are taken from ``room_a``) and appended to ``self.rooms``.
        """

        def facing_edge_pair():
            # First (edge of room_a, edge of room_b) pair that faces and touches.
            for edge_a in range(room_a.num_walls):
                normal_a = room_a.edge_norms[edge_a]

                for edge_b in range(room_b.num_walls):
                    normal_b = room_b.edge_norms[edge_b]

                    # Facing edges have (almost) opposite normals
                    if not (np.dot(normal_a, normal_b) > -0.9):
                        gap = room_b.outline[edge_b] - room_a.outline[edge_a]

                        # Touching edges have no separation along the normal
                        if not (np.dot(normal_a, gap) > 0.05):
                            return edge_a, edge_b

            return None, None

        idx_a, idx_b = facing_edge_pair()
        assert idx_a is not None, "matching edges not found in connect_rooms"

        # Both rooms get a portal with the very same extents
        portal_extent = dict(min_x=min_x, max_x=max_x, min_z=min_z, max_z=max_z, max_y=max_y)
        start_a, end_a = room_a.add_portal(edge=idx_a, **portal_extent)
        start_b, end_b = room_b.add_portal(edge=idx_b, **portal_extent)

        # End points of the two portals
        corner_a, along_a = room_a.outline[idx_a], room_a.edge_dirs[idx_a]
        corner_b, along_b = room_b.outline[idx_b], room_b.edge_dirs[idx_b]
        a = corner_a + along_a * start_a
        b = corner_a + along_a * end_a
        c = corner_b + along_b * start_b
        d = corner_b + along_b * end_b

        # If the portals are directly connected, stop
        if np.linalg.norm(a - d) < 0.001:
            return

        len_a = np.linalg.norm(b - a)
        len_b = np.linalg.norm(d - c)

        # Room outline points must be specified in counter-clockwise order;
        # only the x and z components of each corner are kept.
        corners = np.stack([c, b, a, d])
        outline = np.stack([corners[:, 0], corners[:, 2]], axis=1)

        if max_y is None:
            max_y = room_a.wall_height

        connector = NamedRoom(
            outline,
            wall_height=max_y,
            wall_tex=room_a.wall_tex_name,
            floor_tex=room_a.floor_tex_name,
            ceil_tex=room_a.ceil_tex_name,
            no_ceiling=room_a.no_ceiling,
            name=name,
        )
        self.rooms.append(connector)

        # Open the connector on the two edges that lie on the room portals
        connector.add_portal(1, start_pos=0, end_pos=len_a)
        connector.add_portal(3, start_pos=0, end_pos=len_b)


class Vault(CustomMiniWorldEnv, ConditionalActionEnv):
    """
    Multi-room environment with treasure in 5th room
    The agent must:
        - navigate to the door
        - navigate to the treasure
        - collect the treasure

    Four rooms (``room0`` .. ``room3``) are joined in a ring by hallways, and
    ``room3`` is additionally connected through ``vault_door`` to ``room4``,
    which holds the gold box. The action space consists of five skills
    (``to_gold``, ``centre``, ``clockwise``, ``anticlockwise``, ``door``); each
    skill is expanded into a sequence of low-level steps, and
    ``available_mask`` reports which skills can currently be started.
    """
    # Texture names, looked up under "portraits/", from which the optional
    # wall picture frames are drawn at random when the world is generated.
    portraits = [
        "adelaide_hanscom",
        "alessandro_allori",
        "alexandre_cabanel",
        "alexei_harlamov",
        "alexey_petrovich_antropov",
        "alice_pike_barney",
        "aman_theodor",
        "antonello_messina",
        "antonio_herrera_toro",
        "benjamin-constant",
        "benoist_marie-guillemine",
        "bouguereau_william-adolphe",
        "byron",
        "carl_fredric_breda",
        "cramacj_lucas",
        "cristobal_rojas",
        "delacroix_eugene_ferdinand_victor",
        "domenikos_theotokopoulos",
        "edmund_blair_leighton",
        "edwin_longsden_long",
        "falero_luis_ricardo",
        "felix_bonfils",
        "francesco_hayez",
        "francisco_goya_lucientes",
        "francisco_zurbaran",
        "franz_von_defregger",
        "frederic_westin",
        "frederic_yates",
        "frederick_leighton",
        "gaston_bussiere",
        "george_henry_hall",
        "giovanni_battista_tiepolo",
        "giovanni_bellini",
        "hans_holbein",
        "hayez_francesco",
        "henryk_siemiradzki",
        "ilja_jefimowitsch_repin",
        "james_carrol_beckwith",
        "jean-baptiste-camille_corot",
        "jean-leon_gerome",
        "john_william_godward",
        "julije_klovic",
        "juriaen_streek",
        "kiprenskij_orest_adamovic",
        "konstantin_makovsky",
        "lefebvre_jules_joseph",
        "leon-francois_comerre",
        "leopold_loffler",
        "lewis_john_frederick",
        "madrazo_garreta_raimundo",
        "marie_bashkirtseff",
        "moritz_kellerhoven",
        "nathaniel_jocelyn",
        "nikolai_alexandrowitsch_jaroschenko",
        "nils_johan_olsson_blommer",
        "paolo_veronese",
        "parmigianino",
        "paul_cesar_helleu",
        "regnault_henri",
        "richard_bergh",
        "robert_dampier",
        "robert_lefevre",
        "robert_leopold",
        "sichel_nathanael",
        "svetoslav_roerich",
        "velazquez_diego",
        "viktor_vasnetsov",
        "william-adolphe_bouguereau"
    ]

    @property
    def available_mask(self) -> Tuple:

        # [forward, centre, clockwise, anticlockwise, door]
        rooms = [room for room in self.rooms if 'room' in room.name]
        room_centres = [self.get_room_centre(room) for room in rooms]
        dist_to_centres = [self.dist(self.agent.pos, room_ctr) for room_ctr in room_centres]
        current_room = np.argmin(dist_to_centres)

        vector_to_centre = self.room_centres[current_room] - self.agent.pos
        dist_to_centre = np.linalg.norm(vector_to_centre)

        if self.agent.carrying:
            # if np.linalg.norm(self.gold.pos - self.agent.pos) < 0.5 and self.main_rooms[current_room].name == "room4":
            # we have reached the gold - nothing else to do, so no options left
            return 0, 0, 0, 0, 0

        get_gold = int(self.main_rooms[current_room].name == "room4")
        # can only walk to centre if you're not already there
        to_centre = int(dist_to_centre > 0.4 and self.main_rooms[current_room].name != "room4")

        # can always walk to hallways unless in final room
        to_hallway_clockwise = int(self.main_rooms[current_room].name != "room4")
        to_hallway_anticlockwise = int(self.main_rooms[current_room].name != "room4")

        to_vault = int(self.main_rooms[current_room].name == "room3")

        return get_gold, to_centre, to_hallway_clockwise, to_hallway_anticlockwise, to_vault

    def __init__(self, max_option_steps=20, add_portraits=False, goal_conditioned=False,
                 use_initiation_vector=False, **kwargs):
        """
        Set up the vault environment.

        :param max_option_steps: number of skill-level steps after which an
            episode is reported as truncated
        :param add_portraits: whether picture frames are hung on the room walls
        :param goal_conditioned: if true, observations are dictionaries with
            ``observation`` / ``achieved_goal`` / ``desired_goal`` entries
        :param use_initiation_vector: stored as ``_use_init_vec``; it is not
            read anywhere else in this class
        :param kwargs: forwarded to the parent environment constructor
        """
        # Configuration
        self._max_option_steps = max_option_steps
        self._add_portraits = add_portraits
        self._gc = goal_conditioned
        self._use_init_vec = use_initiation_vector

        # Per-episode bookkeeping
        self._prev_action = -1
        self._iter = 0
        self._goal_obs = None
        self._goal_pos = None
        self._goal_init = None

        # Candidate goal poses as (x, z, heading); the height is always 0.
        half_pi = np.pi / 2
        goal_specs = [
            # bottom-left
            (-4, 4, np.pi),             # 0
            (-4, 4, -half_pi),          # 1
            # bottom-right
            (4, 4, 0),                  # 2
            (4, 4, -half_pi),           # 3
            # top-right
            (4, -4, half_pi),           # 4
            (4, -4, 0),                 # 5
            # top-left
            (-4, -4, half_pi),          # 6
            (-4, -4, np.pi),            # 7
            # bottom hall
            (-1, 4, np.pi),             # 8
            (1, 4, 0),                  # 9
            # top hall
            (-1, -4, np.pi),            # 10
            (1, -4, 0),                 # 11
            # left hall
            (-4, -1, half_pi),          # 12
            (-4, 1, -half_pi),          # 13
            # right hall
            (4, -1, half_pi),           # 14
            (4, 1, -half_pi),           # 15
            # goals
            (-4, -9, half_pi),          # 16
            (-6.07, -13.76, 2.198),     # 17
        ]
        self._goal_states = [(np.array([x, 0, z]), heading) for x, z, heading in goal_specs]

        super().__init__(
            max_episode_steps=np.inf,
            **kwargs
        )

        if self._gc:
            # Flattened RGB image size; 80x60 is assumed when the caller did
            # not pass explicit observation dimensions.
            obs_width = kwargs.get("obs_width", 80)
            obs_height = kwargs.get("obs_height", 60)
            flat_shape = (obs_height * obs_width * 3,)
            self.observation_space = spaces.Dict({
                key: spaces.Box(low=0, high=1, shape=flat_shape, dtype=np.float64)
                for key in ("observation", "achieved_goal", "desired_goal")
            })

        # One discrete action per skill, in the same order as the names below
        self.action_space = spaces.Discrete(5)
        self.action_names = ["to_gold", "centre", "clockwise", "anticlockwise", "door"]

    def _gen_world(self):
        """
        Build the rooms, hallways, gold box, agent and (optionally) portraits.

        Four 6x6 rooms (``room0`` .. ``room3``) are joined in a ring by
        hallways, and ``room3`` is connected through ``vault_door`` to the
        treasure room ``room4`` that holds the gold. Afterwards the lists of
        main rooms / hallways and their centres are cached on the instance and
        the previous-skill marker is reset.
        """
        # Bottom-left room
        self.room0 = self.add_rect_room(
            min_x=-7, max_x=-1,
            min_z=1, max_z=7,
            wall_tex='cardboard',
            name='room0',
        )
        # Bottom-right room
        self.room1 = self.add_rect_room(
            min_x=1, max_x=7,
            min_z=1, max_z=7,
            wall_tex='brick_wall',
            name='room1',
        )
        # Top-right room
        self.room2 = self.add_rect_room(
            min_x=1, max_x=7,
            min_z=-7, max_z=-1,
            wall_tex='wood',
            name='room2',
        )
        # Top-left room
        self.room3 = self.add_rect_room(
            min_x=-7, max_x=-1,
            min_z=-7, max_z=-1,
            wall_tex='rock',
            name='room3',
        )
        # Treasure room
        self.room4 = self.add_rect_room(
            min_x=-7, max_x=-1,
            min_z=-15, max_z=-9,
            wall_tex='ceiling_tiles',
            name='room4'
        )

        # Add openings to connect the rooms together
        self.connect_rooms(self.room0, self.room1, min_z=3, max_z=5, max_y=2.2, name='hall_0_1')
        self.connect_rooms(self.room1, self.room2, min_x=3, max_x=5, max_y=2.2, name='hall_1_2')
        self.connect_rooms(self.room2, self.room3, min_z=-5, max_z=-3, max_y=2.2, name='hall_2_3')
        # NOTE(review): this hallway joins room3 and room0, so the "_3_4" suffix
        # looks like a typo for "_3_0". The name is left untouched because
        # passages are identified by their name strings.
        self.connect_rooms(self.room3, self.room0, min_x=-5, max_x=-3, max_y=2.2, name='hall_3_4')
        self.connect_rooms(self.room4, self.room3, min_x=-5, max_x=-3, max_y=2.2, name='vault_door')

        self.gold = self.place_entity(
            Box(color='yellow', size=0.5),
            room=self.room4,
            pos=np.array([-6.5, .9, -14.5]),
        )

        # The agent always starts in the middle of the bottom-left room
        self.place_entity(self.agent, pos=np.array([self.room0.mid_x, 0, self.room0.mid_z]))

        if self._add_portraits:
            self._add_portrait_frames()

        self.ignore_carrying_intersect = True

        # Rooms whose name contains 'room' are the main rooms, those whose
        # name contains 'hall' are the ring hallways ('vault_door' is neither).
        self.main_rooms = [room for room in self.rooms if 'room' in room.name]
        self.room_centres = [self.get_room_centre(room) for room in self.main_rooms]

        self.hallways = [room for room in self.rooms if 'hall' in room.name]
        self.hallway_centres = [self.get_room_centre(hallway) for hallway in self.hallways]
        self._prev_action = -1

    def _add_portrait_frames(self):
        """Append randomly positioned, randomly textured picture frames to ``self.entities``."""
        half_pi = np.pi / 2

        # One row per frame: (x, z, facing direction, frame width).
        # A coordinate written as a (base, span) tuple is sampled as
        # ``base + span * rand()``; a plain number is used as-is (it is the
        # coordinate of the wall the frame is attached to).
        frames = [
            # portraits in the bottom-right room
            (7, (2, 4), np.pi, 1.5),
            ((2, 4), 7, half_pi, 1.5),
            (1, (5.5, 1), 0, 0.75),
            (1, (1.5, 1), 0, 0.75),
            ((1.5, 1), 1, -half_pi, 0.75),
            ((5.5, 1), 1, -half_pi, 0.75),
            # portraits in the bottom-left room
            (-7, (2, 4), 0, 1.5),
            ((-2, -4), 7, half_pi, 1.5),
            (-1, (5.5, 1), np.pi, 0.75),
            (-1, (1.5, 1), np.pi, 0.75),
            ((-1.5, -1), 1, -half_pi, 0.75),
            ((-5.5, -1), 1, -half_pi, 0.75),
            # portraits in the top-right room
            (7, (-2, -4), np.pi, 1.5),
            ((2, 4), -7, -half_pi, 1.5),
            (1, (-5.5, -1), 0, 0.75),
            (1, (-1.5, -1), 0, 0.75),
            ((1.5, 1), -1, half_pi, 0.75),
            ((5.5, 1), -1, half_pi, 0.75),
            # portraits in the top-left room (two small frames on the z=-7
            # wall, presumably to leave room for the vault door opening)
            (-7, (-2, -4), 0, 1.5),
            ((-1.5, -1), -7, -half_pi, 0.75),
            ((-5.5, -1), -7, -half_pi, 0.75),
            (-1, (-5.5, -1), np.pi, 0.75),
            (-1, (-1.5, -1), np.pi, 0.75),
            ((-1.5, -1), -1, half_pi, 0.75),
            ((-5.5, -1), -1, half_pi, 0.75),
        ]

        def sample(coord):
            # Fixed coordinates consume no random number.
            if isinstance(coord, tuple):
                base, span = coord
                return base + span * np.random.rand()
            return coord

        for x_spec, z_spec, facing, width in frames:
            # Draw order (x, height, z, texture) is kept stable so that a
            # seeded ``np.random`` yields the same layout on every run.
            x = sample(x_spec)
            y = 0.8 + np.random.rand()
            z = sample(z_spec)
            texture = np.random.choice(self.portraits)
            self.entities.append(
                ImageFrame(
                    pos=[x, y, z],
                    dir=facing,
                    tex_name=f"portraits/{texture}",
                    width=width))

    @property
    def observation(self):
        obs = self.render_obs().reshape(-1) / 255.0
        if self._gc:
            obs = {
                "observation": obs,
                "achieved_goal": obs,
                "desired_goal": self._goal_obs
            }
        return obs

    @property
    def info(self):
        dic = {
            "position": self.agent.pos,
            "direction": self.agent.dir,
            "state": self.agent.pos,
            "goal": self._goal_pos,
            "goal_init": self._goal_init
        }
        return dic

    def compute_reward(self, achieved_goal, desired_goal, info):
        if isinstance(info, dict):
            return self.reward(info)
        else:
            return np.array([self.reward(x) for x in info])

    def reward(self, info):
        return float(np.linalg.norm(info["state"] - info["goal"]) < 1)

    def step(self, action):
        experiences = {'obs': [], 'reward': [], 'done': [], 'info': [], 'skills_valid': []}

        step_count = 0
        if self.available_mask[action] == 1:
            for low_action in self.get_next_skill_action(action):
                obs, reward, done, _, info = super().step(low_action)

                if low_action == self.actions.pickup and self.agent.carrying is self.gold:
                    reward += self._reward()
                    done = True

                experiences['obs'].append(obs)
                experiences['reward'].append(reward)
                experiences['done'].append(done)
                experiences['info'].append(info)
                step_count += 1
                if self.render_mode == "human":
                    self.render()

                if done:
                    break
            obs, reward, done = experiences['obs'][-1], sum(experiences['reward']), experiences['done'][-1]
        else:
            step_count = 1
            obs = self.render_obs()
            reward = 0
            done = False

        obs = obs.reshape(-1) / 255.0

        if self._gc and self._goal_obs is not None:
            obs = {
                "observation": obs.copy(),
                "achieved_goal": obs.copy(),
                "desired_goal": self._goal_obs.copy()
            }

        info = self.info
        info["steps"] = step_count
        if self._gc:
            reward = self.reward(info)
            done = reward > 0.5

        self._prev_action = action
        self._iter += 1
        truncated = self._iter >= self._max_option_steps
        return obs, reward, done, truncated, info

    def reset(self, seed=None, options=None):
        """
        Reset the environment and return ``(obs, info)``.

        In goal-conditioned mode a goal pose is sampled from ``_goal_states``
        (with a little noise), the agent is temporarily moved there to render
        the goal observation and record the skill mask at the goal, and is then
        put back at its start pose. The returned observation is then a
        dictionary with ``observation`` / ``achieved_goal`` / ``desired_goal``.
        """
        self._iter = 0
        self._prev_action = -1
        obs, _ = super().reset(seed=seed, options=options)
        obs = obs.reshape(-1) / 255.0
        if self._gc:
            start_pos, start_dir = self.agent.pos, self.agent.dir

            # NOTE(review): the goal is sampled from the global ``np.random``
            # stream, so ``seed`` does not appear to control it - confirm
            # whether that is intended.
            idx = np.random.randint(0, len(self._goal_states))
            pos, direc = self._goal_states[idx]
            eps = np.random.randn(3) * 0.1
            eps[1] = 0  # keep the goal on the floor plane
            pos = pos + eps
            direc = direc + np.random.randn() * 0.05

            # Move the agent by assigning its pose directly. Going through
            # ``place_entity`` for an entity that has already been placed would
            # register the agent in ``self.entities`` again on every call.
            self.agent.pos, self.agent.dir = pos, direc
            self._goal_obs = self.render_obs().reshape(-1) / 255.0
            self._goal_pos = pos
            self._goal_init = self.available_mask

            # Restore the start pose before handing control back
            self.agent.pos, self.agent.dir = start_pos, start_dir
            obs = {
                "observation": obs.copy(),
                "achieved_goal": obs.copy(),
                "desired_goal": self._goal_obs.copy()
            }
        info = self.info
        return obs, info

    def get_next_skill_action(self, skill):
        if skill == 0:  # walk to gold
            for timestep in range(60):
                yield self.skill_go_to_position(self.gold.pos)
            yield self.actions.pickup
        elif skill == 1:  # centre of room
            for timestep in range(40):
                yield self.skill_go_to_centre_of_current_room()
        elif skill == 2:  # clockwise
            # if in same direction, walk to centre to avoid collisions
            if skill == self._prev_action:
                for timestep in range(20):
                    yield self.skill_go_to_centre_of_current_room()
            for timestep in range(40):
                yield self.skill_go_to_hallway(direction='clockwise')
            for timestep in range(10):
                yield self.actions.move_forward
        elif skill == 3:  # anticlockwise
            # if in same direction, walk to centre to avoid collisions
            if skill == self._prev_action:
                for timestep in range(20):
                    yield self.skill_go_to_centre_of_current_room()
            for timestep in range(40):
                yield self.skill_go_to_hallway(direction='anticlockwise')
            for timestep in range(10):
                yield self.actions.move_forward
        elif skill == 4:
            for timestep in range(20):
                yield self.skill_go_to_centre_of_current_room()
            for timestep in range(40):
                yield self.skill_go_to_vault_door()
            for timestep in range(10):
                yield self.actions.move_forward
        else:
            yield skill

    def skill_go_to_centre_of_current_room(self):
        rooms = [room for room in self.rooms if 'room' in room.name]
        room_centres = [self.get_room_centre(room) for room in rooms]
        dist_to_centres = [self.dist(self.agent.pos, room_ctr) for room_ctr in room_centres]
        current_room = np.argmin(dist_to_centres)
        return self.skill_go_to_position(room_centres[current_room])

    def skill_go_to_position(self, target_pos):
        vector_to_centre = target_pos - self.agent.pos
        dist_to_centre = np.linalg.norm(vector_to_centre)

        if dist_to_centre > 0.4:
            angle_to_centre = self.vector2angle(vector_to_centre)
            angle_to_face_centre = self.relative_angle(angle_to_centre)
            if np.abs(angle_to_face_centre) < .2:
                return self.actions.move_forward
            elif angle_to_face_centre > 0:
                return self.actions.turn_left
            else:
                return self.actions.turn_right
        else:
            return None

    def skill_go_to_hallway(self, direction='clockwise'):
        hallways = [room for room in self.rooms if 'hall' in room.name]
        hallway_centres = [self.get_room_centre(hallway) for hallway in hallways]
        agent_angle = self.vector2angle(self.agent.pos)
        hallway_angles = [self.vector2angle(pos) for pos in hallway_centres]
        if direction in ['clockwise', 'cw']:
            next_hallway = np.argmin([np.mod(agent_angle - angle, 2 * np.pi) for angle in hallway_angles])
        else:
            next_hallway = np.argmin([np.mod(angle - agent_angle, 2 * np.pi) for angle in hallway_angles])
        return self.skill_go_to_position(hallway_centres[next_hallway])

    def skill_go_to_vault_door(self):
        rooms = [room for room in self.rooms if 'room' in room.name]
        room_centres = [self.get_room_centre(room) for room in rooms]
        dist_to_centres = [self.dist(self.agent.pos, room_ctr) for room_ctr in room_centres]
        current_room = rooms[np.argmin(dist_to_centres)]
        if current_room.name != 'room3':
            return False
        else:
            vault_door = [room for room in self.rooms if room.name == 'vault_door'][0]
            return self.skill_go_to_position(self.get_room_centre(vault_door))

    @staticmethod
    def get_room_centre(room):
        return np.array([room.mid_x, 0, room.mid_z])

    @staticmethod
    def vector2angle(vector):
        dx, _, dy = vector
        return np.arctan2(-dy, dx)

    def relative_angle(self, angle):
        return self.wrap_angle(angle - self.agent.dir)

    @staticmethod
    def dist(pos1, pos2):
        return np.linalg.norm(pos1 - pos2)

    @staticmethod
    def wrap_angle(angle):
        return np.angle(np.exp(1j * angle))


if __name__ == '__main__':
    # Manual demo: run 1000 sampled skills in a top-down window, resetting
    # whenever an episode terminates or is truncated.
    env = Vault(render_mode="human", view="top")
    try:
        observation, info = env.reset()
        # Hand-written skill sequence kept for debugging; to replay it, index
        # it with the loop counter instead of sampling an action:
        # plan = [3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 4, 0]
        for i in range(1000):
            action = env.sample_action()
            # action = plan[i]
            observation, reward, terminated, truncated, info = env.step(action)
            if terminated or truncated:
                observation, info = env.reset()
    finally:
        # Always release the environment, even if a step raised
        env.close()
