from glob import glob
import os
import shutil
from typing import Any, Dict, Optional, Tuple

import cv2
from home_robot.mapping.semantic.constants import MapConstants
from home_robot.navigation_planner.discrete_planner import DiscretePlanner
import nvtx
import imageio
import omegaconf
import numpy as np
from natsort import natsorted
from scipy.spatial.transform.rotation import Rotation
import torch
import skfmm

from home_robot.agent.ovmm_agent.ovmm_agent import OpenVocabManipAgent, Skill
from home_robot.core.interfaces import ContinuousNavigationAction, DiscreteNavigationAction, Observations, ContinuousFullBodyAction
from vlfm.utils.geometry_utils import get_point_cloud, transform_points

from helios.agent.gaze.hierarchical_gaze import HierarchicalGaze
from helios.agent.planner.fm2_planner import Fm2Planner, get_times
from helios.agent.nav.vlfm_nav import VlfmNav
from helios.env.privileged_env import PrivEnv
from helios.agent.gaze.utils.map_interface import MapInterface
from helios.agent.utils.utils import gps_to_px, px_to_gps, xy_to_px, px_to_xy, obs_to_tf
from mpl_toolkits.mplot3d import Axes3D

from helios.agent.utils.visualization import write_frame, apply_mask, process_and_write_frame, apply_mask_abs
from threading import Thread

from matplotlib import pyplot as plt
from vlfm.utils.img_utils import monochannel_to_inferno_rgb

import time

import hydra

class HELIOS(OpenVocabManipAgent):
    def __init__(
        self,
        config: omegaconf.DictConfig,
        n_semantic_channels: int,
        planner: Fm2Planner,
        nav: VlfmNav,
        gaze: HierarchicalGaze,
        platform: str,
        device_id: int = 0,
        use_our_pick: bool = True,
        use_our_place: bool = True,
        stop_on_collision: bool = True
    ):
        """Build the HELIOS agent on top of ``OpenVocabManipAgent``.

        Args:
            config: Configuration forwarded to the parent agent; also read for
                ``habitat_baselines.image_dir`` / ``video_dir``.
            n_semantic_channels: Number of semantic channels (stored only).
            planner: Planner handed to the nav/gaze ``act`` calls.
            nav: Navigation module; may be falsy, in which case map geometry is
                taken from the parent's ``semantic_map``.
            gaze: Gaze module; may be falsy. A ``HierarchicalGaze`` additionally
                gets a ``MapInterface`` and the nav module wired in.
            platform: Platform identifier forwarded to ``check_collision``.
            device_id: Device index forwarded to the parent agent.
            use_our_pick: Use the scripted pick in ``_pick`` instead of the parent's.
            use_our_place: Use the scripted place in ``_place`` instead of the parent's.
            stop_on_collision: Run collision detection in ``update``.
        """
        # Clear Hydra's global state before the parent (re)initialises it.
        hydra.core.global_hydra.GlobalHydra.instance().clear()
        super().__init__(config, device_id=device_id)
        self.n_semantic_channels = n_semantic_channels
        self.our_planner = planner
        self.nav = nav
        self.gaze = gaze
        self.recep_indices = [0, 0]

        # Map resolution: prefer the nav obstacle map, otherwise derive it from
        # the parent's semantic map as (pixels across) / (extent in meters).
        if self.nav:
            self.px_per_m = self.nav.map_policy.obstacle_map.pixels_per_meter
        else:
            global_px = self.semantic_map.global_map.shape[-1]
            self.px_per_m = global_px * 100 / self.semantic_map.global_map_size_cm

        if self.gaze and isinstance(self.gaze, HierarchicalGaze):
            if self.nav:
                map_shape = self.nav.map_policy.obstacle_map._map.shape
            else:
                map_shape = self.semantic_map.global_map.shape
            self.map_interface = MapInterface(map_shape, self.px_per_m)
            self.gaze.set_map_interface(self.map_interface)
            self.gaze.set_objectnav_method(self.nav)

        self.last_skill = None
        self.viz_extra_t = 0

        self.platform = platform

        self.img_before_pick = None
        self.pose_before_pick = None
        self.pose_after_pick = None

        self.repeats = 0

        self.use_our_pick = use_our_pick
        self.use_our_place = use_our_place
        self.stop_on_collision = stop_on_collision

        # ``update`` passes the last action to ``check_collision``; make sure it
        # exists even if ``update`` runs before the first ``reset``.
        self.action = None
        self.prev_time = -1

    def set_episode_key(self, episode_key: str):
        """Propagate the episode key to the nav and gaze modules.

        Each module is only touched if it exposes ``set_episode_key``. If the
        gaze module also exposes ``set_viz_save_dir``, its visualisation folder
        is set to ``<image_dir without its last 6 chars>/eig_images/<key>``,
        where ``<key>`` is read back from ``gaze.object_state_map.episode_key``.
        """
        for module in (self.nav, self.gaze):
            if hasattr(module, 'set_episode_key'):
                module.set_episode_key(episode_key=episode_key)
        if not hasattr(self.gaze, 'set_viz_save_dir'):
            return
        image_root = self.config.habitat_baselines.image_dir[:-6]
        map_episode_key = self.gaze.object_state_map.episode_key
        self.gaze.set_viz_save_dir(f'{image_root}/eig_images/{map_episode_key}')

    def set_envs(self, envs: PrivEnv):
        """Give the environment handle to the planner when it accepts one."""
        if not hasattr(self.our_planner, 'set_envs'):
            return
        self.our_planner.set_envs(envs)

    @nvtx.annotate("HELIOS.reset")
    def reset(self):
        """Reset the parent agent, the nav/gaze modules and per-episode state."""
        super().reset()
        for module in (self.nav, self.gaze):
            if module:
                module.reset()

        # Skill / action bookkeeping.
        self.n_collisions = 0
        self.transition_pause_counter = -1
        self.next_skill = None
        self.action = None
        self.last_skill = None

        # Visualisation bookkeeping.
        self.viz_extra_t = 0
        self.repeats = 0

        # Pick bookkeeping.
        self.img_before_pick = self.pose_before_pick = self.pose_after_pick = None

        # No timestep processed yet in this episode.
        self.prev_time = -1

    def update(
        self,
        obs: Observations,
        info: Dict[str, Any],
        skip_map_update: bool = False
    ) -> None:
        """Refresh collision state, maps and visualisation for the current step.

        Writes ``image_dir``, ``video_dir``, ``timestep``, ``our_planner``,
        ``collision`` and (with a nav module) ``height_map`` into ``info``.
        Repeat calls within the same timestep (detected via ``prev_time``)
        skip the nav map update and the gaze update.

        Args:
            obs: Current observation.
            info: Per-step info dict, updated in place.
            skip_map_update: Force skipping the nav map update.
        """
        info['image_dir'] = self.config.habitat_baselines.image_dir
        info['video_dir'] = self.config.habitat_baselines.video_dir
        info['timestep'] = self.timesteps[0]
        info['our_planner'] = not isinstance(self.our_planner, DiscretePlanner)

        pose = np.array([obs.gps[0], obs.gps[1], obs.compass[0]])
        if self.stop_on_collision:
            info['collision'] = check_collision(
                self.last_poses[0],
                pose,
                self.action,
                self.px_per_m,
                self.platform
            )
        else:
            info['collision'] = False
        self.last_poses[0] = pose

        repeated_step = self.timesteps[0] == self.prev_time
        skip_map_update = skip_map_update or repeated_step

        if self.nav:
            self.nav.update(obs, info, skip_map_update)
            info['height_map'] = self.nav.map_policy.obstacle_map.height_map
        if self.gaze and not repeated_step:
            self.gaze.update(obs, info)
        if self.nav and self.visualize:
            # Guard on gaze: the nav-only configuration has no object_state_map.
            if skip_map_update and self.gaze:
                info['merged_object_map'] = self.gaze.object_state_map
            self.visualization_update(obs, info)

        self.prev_time = self.timesteps[0]

    def visualization_update(
        self,
        obs: Observations,
        info: Dict[str, Any]
    ):
        """Collect nav visualisation data and optionally queue a combined debug frame.

        Always forwards ``info`` to ``self.nav.habitat_vis.collect_data``. When
        the gaze module has ``visualize_3dgs`` enabled, it also renders the gaze
        object-state map. It overlays goal masks, invalid pick regions and gaze
        wayposes on the obstacle map. It then starts a ``Thread`` running
        ``process_and_write_frame`` for
        ``<image_dir>/<episode_key>/combined_<timestep>.png``. ``self.repeats``
        is passed to the writer and then reset to 0.
        """
        info['timestep'] = self.timesteps[0]
        self.nav.habitat_vis.collect_data(info, self.nav.map_policy)

        if not (hasattr(self.gaze, 'visualize_3dgs') and self.gaze.visualize_3dgs):
            return

        # Value-map / target-object index per skill: 0 for the object skills,
        # 1 for the receptacle skills (the same index _nav_to_obj/_nav_to_rec
        # pass to nav.act). Unlisted skills fall back to index 0 instead of
        # leaving the variables unbound.
        state = self.states[0]
        vm_idx, curr_skill = 0, 'UNKNOWN'
        for skill, idx in (
            (Skill.NAV_TO_OBJ, 0),
            (Skill.GAZE_AT_OBJ, 0),
            (Skill.PICK, 0),
            (Skill.NAV_TO_REC, 1),
            (Skill.GAZE_AT_REC, 1),
            (Skill.PLACE, 1),
            (Skill.FALL_WAIT, 1),
        ):
            if state == skill:
                vm_idx, curr_skill = idx, skill.name
                break

        render, render_sem, render_u, render_instances, render_d, render_orig = \
            self.gaze.object_state_map.get_renderings(obs)

        #save map images for scene rep -- use for scene representation visualization figure
        # om = self.nav.map_policy.obstacle_map
        # vm = self.nav.map_policy.value_map._value_map[..., 0]
        # vm[om.explored_area == 0] =0
        # map_img = np.flipud(vm)
        # zero_mask = map_img == 0
        # map_img[zero_mask] = np.max(map_img)
        # map_img = monochannel_to_inferno_rgb(map_img)
        # map_img[zero_mask] = (255, 255, 255)

        # value_map1 = cv2.cvtColor(map_img, cv2.COLOR_BGR2RGB)

        # vm = self.nav.map_policy.value_map._value_map[..., 1]
        # vm[om.explored_area == 0] =0
        # map_img = np.flipud(vm)
        # zero_mask = map_img == 0
        # map_img[zero_mask] = np.max(map_img)
        # map_img = monochannel_to_inferno_rgb(map_img)
        # map_img[zero_mask] = (255, 255, 255)

        # value_map2 = cv2.cvtColor(map_img, cv2.COLOR_BGR2RGB)

        # vis_img = np.ones((*om._map.shape[:2], 3), dtype=np.uint8) * 255
        # # Draw explored area in light green
        # vis_img[om.explored_area == 1] = (200, 255, 200)
        # # Draw unnavigable areas in gray
        # vis_img[om._navigable_map == 0] = om.radius_padding_color
        # # Draw obstacles in black
        # vis_img[om._map == 1] = (0, 0, 0)

        # occ_map = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)

        # plt.imsave(f"raw_map_vis/{info['timestep']}_value1.png", value_map1)
        # plt.imsave(f"raw_map_vis/{info['timestep']}_value2.png", value_map2)
        # plt.imsave(f"raw_map_vis/{info['timestep']}_occ.png", occ_map)

        #TODO: pause on transitions
        # Occupancy map with walls and goal masks overlaid. Make it contiguous:
        # a flipud view has negative strides, which cv2.circle rejects below
        # whenever the mask overlay does not replace the array.
        policy_info = self.nav.map_policy.get_policy_info(visualize=True)
        occ = np.ascontiguousarray(np.flipud(policy_info['obstacle_map'].copy()))
        occ[info['wall_map']] = (125, 0, 255)
        goal_ids = (
            obs.task_observations.get('object_goal'),
            obs.task_observations.get('start_recep_goal'),
            obs.task_observations.get('end_recep_goal'),
        )
        try:
            goal_masks = []
            for goal in goal_ids:
                if hasattr(self.gaze, 'simple_version') and self.gaze.simple_version:
                    goal_masks.append(self.gaze.object_state_map.map == goal)
                else:
                    goal_masks.append(self.gaze.object_state_map.semantic_maps[goal].cpu().numpy())
            occ = apply_mask_abs(occ, goal_masks).astype(np.uint8)
        except Exception:
            # Best-effort overlay: if any goal mask is unavailable, keep occ as is.
            pass

        third_person = info.get('third_person_image')
        if third_person is None:
            third_person = obs.rgb

        gaze_target = ("goal" if self.gaze.gaze_target == obs.task_observations['object_goal']
            else "start rec" if self.gaze.gaze_target == obs.task_observations['start_recep_goal']
            else "end rec" if self.gaze.gaze_target == obs.task_observations['end_recep_goal']
            else "none")

        if isinstance(self.gaze, HierarchicalGaze):
            if hasattr(self.gaze, 'simple_version') and self.gaze.simple_version:
                occ[~self.gaze.object_state_map.valid_pick] = (255, 125, 125)
            else:
                occ[~self.gaze.object_state_map.valid_pick.cpu().numpy()] = (255, 125, 125)

        if self.gaze.wayposes is not None:
            for i, waypose in enumerate(self.gaze.wayposes):
                px_goal = gps_to_px(waypose[:2], info['pixels_per_meter'], info['map_size'])
                # Highlight the waypose currently being pursued.
                color = (200, 0, 0) if i == self.gaze.i_waypose else (200, 0, 200)
                cv2.circle(occ, tuple(px_goal[::-1]), 5, color, 2)

        if 'bounds' in info:
            y0, y1, x0, x1 = info['bounds']
            occ = np.flipud(occ[y0:y1, x0:x1])
        else:
            occ = np.flipud(occ)

        if isinstance(self.gaze, HierarchicalGaze):
            gaze_str = f'Gaze target: {gaze_target} ({self.gaze.wayposes_idx})'
        else:
            gaze_str = f'Gaze target: {gaze_target}'

        max_depth = self.gaze.object_state_map.max_depth

        def depth_to_rgb(depth):
            # Clip before the uint8 cast so depths beyond max_depth saturate
            # at white instead of wrapping around.
            scaled = np.clip(depth / max_depth * 255.0, 0, 255)
            return cv2.cvtColor(scaled.astype(np.uint8), cv2.COLOR_GRAY2RGB)

        img_dict = {
            'rgb': obs.rgb,
            '3dgs': render,
            '3dgs_depth': depth_to_rgb(render_d),
            '3dgs_orig': render_orig,
            '3dgs_seg': render_sem,
            '3dgs_instances': render_instances,
            '3dgs_uncertainty': render_u,
            'depth': depth_to_rgb(obs.depth),
            'third_person': third_person,
            'occ': occ,
            'value': policy_info['value_maps'][vm_idx]
        }

        text_dict = {
            'value_map': policy_info['target_objects'][vm_idx],
            'instruction': f'"{info["goal_name"]}"',
            'current_skill': f'Current skill is {curr_skill}',
            'gaze_target': gaze_str,
            'start_rec': obs.task_observations['start_recep_name'],
            'end_rec': obs.task_observations['place_recep_name'],
            'goal_obj': obs.task_observations['object_name'],
        }

        self.nav.habitat_vis.add_and_start_thread(Thread(target=process_and_write_frame, args=(
            f"{self.config.habitat_baselines.image_dir}/{self.gaze.object_state_map.episode_key}/combined_{info['timestep']}.png",
            obs, img_dict, text_dict, self.repeats
        )))
        self.repeats = 0

    @nvtx.annotate("HELIOS._nav_to_obj")
    def _nav_to_obj(
        self,
        obs: Observations,
        info: Dict[str, Any]
    ) -> Tuple[DiscreteNavigationAction, Any, Optional[Skill]]:
        """Navigate toward the goal object, handing over to GAZE_AT_OBJ when it is mapped.

        Without a nav module the parent implementation is used.

        For the first ``init_steps`` timesteps the agent stays in this skill.
        After that, and only when a gaze module is present, the behaviour is:

        * If the gaze module reports the goal object, return ``None`` and
          ``Skill.GAZE_AT_OBJ``.
        * Else, for a ``HierarchicalGaze``, let it propose an action before a
          frontier is chosen.
        * Else, if the start receptacle is mapped (non-hierarchical gaze only),
          return ``None`` and ``Skill.GAZE_AT_OBJ``.

        Otherwise the nav module plans, retrying with ``force_new=True`` on
        ``ValueError``. Past the init phase, motion actions exposing ``xyt`` are
        wrapped in a ``ContinuousFullBodyAction`` that also drives joint 9
        toward -pi/12.

        Raises:
            RuntimeError: if ``update`` detected a collision.
        """
        if not self.nav:
            self.action, info, next_skill = super()._nav_to_obj(obs, info)
            return self.action, info, next_skill

        init_steps = 28  # force to stay in NAV_TO_OBJ for the start of the episode

        self.update(obs, info)
        if info['collision']:
            raise RuntimeError(f'_nav_to_obj() collision detected {info["collision"]}, {self.action}')
        self.action = None
        next_skill = None

        if info['timestep'] >= init_steps and self.gaze:
            has_object_goal = self.gaze.has_obj(obs.task_observations['object_goal'])
            has_start_recep_goal = self.gaze.has_obj(obs.task_observations['start_recep_goal'])
            if has_object_goal:
                self.gaze.gaze_target = None
                return self.action, info, Skill.GAZE_AT_OBJ
            if isinstance(self.gaze, HierarchicalGaze):
                self.gaze.end_on_gaze = False
                if self.gaze.before_frontier or self.nav.chosen_frontier is None:
                    self.action, info, next_skill = self.gaze.act(obs, info, self.our_planner)
            elif has_start_recep_goal:
                self.gaze.gaze_target = None
                return self.action, info, Skill.GAZE_AT_OBJ

        if self.action is None:
            try:
                self.action, info, next_skill = self.nav.act(obs, info, 0, planner=self.our_planner)
            except ValueError:
                # Retry with force_new=True (presumably forces a fresh target/plan).
                self.action, info, next_skill = self.nav.act(obs, info, 0, planner=self.our_planner, force_new=True)

        if (info['timestep'] >= init_steps
                and self.action is not None
                and hasattr(self.action, 'xyt')
                and not isinstance(self.action, ContinuousFullBodyAction)):
            joints = np.zeros(10)
            joints[9] = (-np.pi / 12) - obs.joint[9]
            self.action = ContinuousFullBodyAction(
                joints=joints,
                xyt=self.action.xyt,
            )
        self.nav.update_ovmm_vis_info(obs, info)
        return self.action, info, next_skill

    @nvtx.annotate("HELIOS._gaze_at_obj")
    def _gaze_at_obj(
        self,
        obs: Observations,
        info: Dict[str, Any]
    ) -> Tuple[DiscreteNavigationAction, Any, Optional[Skill]]:
        """Approach and look at the goal object using the gaze module.

        Without a gaze module the parent implementation is used.

        If the goal object is mapped, it becomes the gaze target. For a
        ``HierarchicalGaze`` this also switches to ``end_on_gaze`` and clears
        previously chosen wayposes. If there is still no target, a non-hierarchical
        gaze falls back to the start receptacle. Otherwise the skill hands back to
        ``Skill.NAV_TO_OBJ`` with a ``None`` action.

        Raises:
            RuntimeError: if ``update`` detected a collision.
        """
        if not self.gaze:
            self.action, info, next_skill = super()._gaze_at_obj(obs, info)
            return self.action, info, next_skill

        self.update(obs, info)
        # Check the collision caused by the previous action before acting again,
        # as the other skills do.
        if info['collision']:
            raise RuntimeError(f'_gaze_at_obj() collision detected {info["collision"]}, {self.action}')

        object_goal = obs.task_observations['object_goal']
        if self.gaze.has_obj(object_goal):
            self.gaze.gaze_target = object_goal
            if isinstance(self.gaze, HierarchicalGaze) and not self.gaze.end_on_gaze:
                self.gaze.end_on_gaze = True
                self.gaze.wayposes = None  # over-write any chosen waypoints to force it to go to obj
                self.gaze.has_obj_return_false = False

        if self.gaze.gaze_target is None:
            start_recep_goal = obs.task_observations['start_recep_goal']
            if self.gaze.has_obj(start_recep_goal) and not isinstance(self.gaze, HierarchicalGaze):
                self.gaze.gaze_target = start_recep_goal
            else:
                # Clear the stored action: NAV_TO_OBJ calls update() again in
                # the same step, and check_collision would otherwise see the
                # previous action paired with an unchanged pose.
                self.action = None
                return self.action, info, Skill.NAV_TO_OBJ

        self.action, info, next_skill = self.gaze.act(obs, info, self.our_planner)
        if self.nav:
            self.nav.update_ovmm_vis_info(obs, info)
        return self.action, info, next_skill
    
    @nvtx.annotate("HELIOS._pick")
    def _pick(
        self,
        obs: Observations,
        info: Dict[str, Any]
    ) -> Tuple[DiscreteNavigationAction, Any, Optional[Skill]]:
        """Run the scripted pick sequence, or defer to the parent when disabled.

        ``pick_step = timesteps[0] - pick_start_step[0]`` selects the action:

        * 0: ``MANIPULATION_MODE`` (and clear ``gaze.has_obj_return_false``).
        * 1: extend the arm and lift it to just above the surface height read
          from the obstacle height map in front of the robot.
        * 2: ``SNAP_OBJECT``.
        * 3: ``NAVIGATION_MODE``.
        * 4 and later: if ``prev_grasp_success``, go to ``NAV_TO_REC``.
          Otherwise mark the pick target invalid and return to ``NAV_TO_OBJ``.

        Steps 0, 2 and 3 set ``self.repeats = 5``, which
        ``visualization_update`` passes to the frame writer.
        """
        if not self.use_our_pick:
            self.action, info, next_skill = super()._pick(obs, info)
            return self.action, info, next_skill

        pick_step = self.timesteps[0] - self.pick_start_step[0]
        if pick_step == 0 or pick_step == 2 or pick_step == 3:  # TODO properly check when skill has ended
            self.repeats = 5
        self.update(obs, info, skip_map_update=True)
        self.nav.update_ovmm_vis_info(obs, info)

        next_skill = None
        if pick_step == 0:
            self.action = DiscreteNavigationAction.MANIPULATION_MODE
            if self.gaze and hasattr(self.gaze, 'has_obj_return_false'):
                self.gaze.has_obj_return_false = False
        elif pick_step == 1:
            obj_height = self._pick_surface_height(obs) + 0.09
            joints = np.zeros(10)
            joints[0] = 0.4 - obs.joint[:4].sum()  # arm extension target (joints 0-3 summed)
            joints[4] = np.clip(obj_height, 0.5, 1.1) - obs.joint[4]  # lift target
            self.action = ContinuousFullBodyAction(joints=joints, xyt=np.zeros(3))
        elif pick_step == 2:
            self.action = DiscreteNavigationAction.SNAP_OBJECT
        elif pick_step == 3:
            self.action = DiscreteNavigationAction.NAVIGATION_MODE
        else:
            # Final step. ``else`` rather than ``== 4`` keeps next_skill bound
            # should the skill ever run past step 4.
            self.action = None
            if obs.task_observations["prev_grasp_success"]:
                next_skill = Skill.NAV_TO_REC
            else:
                if self.gaze:
                    if isinstance(self.gaze, HierarchicalGaze):
                        self.gaze.object_state_map.remove_invalid_pick(self.gaze.wayposes_idx)
                    else:
                        self.gaze.object_state_map.remove_invalid_pick(obs, info)
                    self.gaze.gaze_target = None
                next_skill = Skill.NAV_TO_OBJ
        return self.action, info, next_skill

    def _pick_surface_height(self, obs: Observations, half_window: int = 10):
        """Return the max obstacle-map height in a square window ahead of the robot.

        The window centre is the robot's map pixel offset by ``0.7 * 40`` px
        along the heading. NOTE(review): 40 looks like an assumed px/m value;
        confirm it matches ``obstacle_map.pixels_per_meter``. The compass
        reading seems to change in MANIPULATION_MODE, so the offset direction
        used here would differ outside that mode. The window start is clamped
        at 0, so a robot near the map border does not produce a negative slice
        start (which would select an empty or unrelated region).
        """
        obstacle_map = self.nav.map_policy.obstacle_map
        heading = obs.compass[0]
        center = gps_to_px(obs.gps, obstacle_map.pixels_per_meter, obstacle_map.size) + np.array([
            -np.cos(heading) * 0.7 * 40,
            np.sin(heading) * 0.7 * 40,
        ]).astype(np.int32)
        row, col = int(center[0]), int(center[1])
        window = obstacle_map.height_map[
            max(row - half_window, 0):row + half_window,
            max(col - half_window, 0):col + half_window,
        ]
        return np.max(window)
    
    @nvtx.annotate("HELIOS._nav_to_rec")
    def _nav_to_rec(
        self,
        obs: Observations,
        info: Dict[str, Any]
    ) -> Tuple[DiscreteNavigationAction, Any, Optional[Skill]]:
        """Navigate toward the end receptacle, handing over to GAZE_AT_REC once mapped.

        Without a gaze module the parent implementation is used. Otherwise the
        end receptacle becomes the gaze target.

        * If it is already mapped, clear any gaze wayposes and return ``None``
          with ``Skill.GAZE_AT_REC``.
        * Otherwise a ``HierarchicalGaze`` may propose an action while no
          frontier is chosen. Failing that, the nav module plans, retrying with
          ``force_new=True`` on ``ValueError``.

        With a gaze module, motion actions exposing ``xyt`` are wrapped in a
        ``ContinuousFullBodyAction`` that also drives joint 9 toward -pi/12.

        Raises:
            RuntimeError: if ``update`` detected a collision.
        """
        if self.gaze:
            self.update(obs, info)
            end_recep_goal = obs.task_observations['end_recep_goal']
            has_end_recep_goal = self.gaze.has_obj(end_recep_goal)
            self.gaze.gaze_target = end_recep_goal
            if info['collision']:
                raise RuntimeError(f'_nav_to_rec() collision detected {info["collision"]}, {self.action}')
            self.action = None
            if has_end_recep_goal:  # receptacle found: transition to gaze
                if hasattr(self.gaze, 'wayposes'):
                    self.gaze.wayposes = None
                next_skill = Skill.GAZE_AT_REC
            else:
                if isinstance(self.gaze, HierarchicalGaze):
                    self.gaze.end_on_gaze = False
                    if self.nav.chosen_frontier is None:
                        self.action, info, next_skill = self.gaze.act(obs, info, self.our_planner)
                if self.action is None:
                    try:
                        self.action, info, next_skill = self.nav.act(obs, info, 1, planner=self.our_planner)
                    except ValueError:
                        self.action, info, next_skill = self.nav.act(obs, info, 1, planner=self.our_planner, force_new=True)
            self.nav.update_ovmm_vis_info(obs, info)
        else:
            self.action, info, next_skill = super()._nav_to_rec(obs, info)

        # Only wrap actions that carry a base motion; discrete actions have no
        # ``xyt`` (same guard as in _nav_to_obj).
        if (self.gaze
                and self.action is not None
                and hasattr(self.action, 'xyt')
                and not isinstance(self.action, ContinuousFullBodyAction)):
            joints = np.zeros(10)
            joints[9] = (-np.pi / 12) - obs.joint[9]
            self.action = ContinuousFullBodyAction(
                joints=joints,
                xyt=self.action.xyt,
            )
        return self.action, info, next_skill

    @nvtx.annotate("HELIOS._gaze_at_rec")
    def _gaze_at_rec(
        self,
        obs: Observations,
        info: Dict[str, Any]
    ) -> Tuple[DiscreteNavigationAction, Any, Optional[Skill]]:
        """Approach and look at the end receptacle using the gaze module.

        Without a gaze module the parent implementation is used. Otherwise:

        * If no gaze target is set yet, the end receptacle becomes the target.
        * For a ``HierarchicalGaze``, once the receptacle is mapped, switch to
          ``end_on_gaze`` and clear the gaze module's wayposes so it heads for
          the receptacle.

        Raises:
            RuntimeError: if ``update`` detected a collision.
        """
        if not self.gaze:
            self.action, info, next_skill = super()._gaze_at_rec(obs, info)
            return self.action, info, next_skill

        self.update(obs, info)
        # Check the collision caused by the previous action before acting again,
        # as the other skills do.
        if info['collision']:
            raise RuntimeError(f'_gaze_at_rec() collision detected {info["collision"]}, {self.action}')

        end_recep_goal = obs.task_observations['end_recep_goal']
        if self.gaze.gaze_target is None:
            self.gaze.gaze_target = end_recep_goal
        if isinstance(self.gaze, HierarchicalGaze):
            has_end_recep_goal = self.gaze.has_obj(end_recep_goal)
            if has_end_recep_goal and not self.gaze.end_on_gaze:
                self.gaze.end_on_gaze = True
                # Over-write any chosen waypoints on the gaze module (not the
                # agent) to force it to go to the receptacle.
                self.gaze.wayposes = None

        self.action, info, next_skill = self.gaze.act(obs, info, self.our_planner)
        if self.nav:
            self.nav.update_ovmm_vis_info(obs, info)
        return self.action, info, next_skill
    
    @nvtx.annotate("HELIOS._place")
    def _place(
        self,
        obs: Observations,
        info: Dict[str, Any]
    ) -> Tuple[DiscreteNavigationAction, Any, Optional[Skill]]:
        if self.use_our_place:
            self.update(obs, info, True)
            self.nav.update_ovmm_vis_info(obs, info)
            place_step = self.timesteps[0] - self.place_start_step[0]
            if place_step==2:
                self.repeats=5
            # print("IS HOLDING (PLACE): ", obs.task_observations["prev_grasp_success"])
            
            
            if place_step == 0:
                self.action = DiscreteNavigationAction.MANIPULATION_MODE
                next_skill = None
            elif place_step == 1: # lift and extend arm to maximum
                obj_px = gps_to_px(obs.gps, self.nav.map_policy.obstacle_map.pixels_per_meter, self.nav.map_policy.obstacle_map.size) + np.array([
                    -np.cos(obs.compass[0]) * 0.7 * 40,
                    np.sin(obs.compass[0]) * 0.7 * 40,
                ]).astype(np.int32)
                px_window_half_size = 10 #5
                obj_height = np.max(self.nav.map_policy.obstacle_map.height_map[
                    obj_px[0]-px_window_half_size:obj_px[0]+px_window_half_size,
                    obj_px[1]-px_window_half_size:obj_px[1]+px_window_half_size,
                ]) + 1.2 #TODO: adapt to object height. We also need to change wrist orientation in some cases
                # colormap = matplotlib.colormaps['viridis']
                # height_map = self.nav.map_policy.obstacle_map.height_map[obj_px[0]-80:obj_px[0]+80, obj_px[1]-80:obj_px[1]+80]
                # height_map = (height_map - np.min(height_map)) / (np.max(height_map) - np.min(height_map))
                # height_map = (colormap(height_map)[..., :3] * 255).astype(np.uint8)
                # imageio.imwrite('height_map_obj.png', height_map)
                # self_px = gps_to_px(obs.gps, self.nav.map_policy.obstacle_map.pixels_per_meter, self.nav.map_policy.obstacle_map.size)
                # height_map = self.nav.map_policy.obstacle_map.height_map[self_px[0]-80:self_px[0]+80, self_px[1]-80:self_px[1]+80]
                # height_map = (height_map - np.min(height_map)) / (np.max(height_map) - np.min(height_map))
                # height_map = (colormap(height_map)[..., :3] * 255).astype(np.uint8)
                # imageio.imwrite('height_map_self.png', height_map)
                # print(f'obj_px: {obj_px}, obj_height: {obj_height}')
                joints = np.zeros(10)
                xyt = np.zeros(3)
                joints[0] = 0.52 - obs.joint[:4].sum() # extend to maximum
                joints[4] = np.clip(obj_height, 0.5, 1.1) - obs.joint[4] # lift to maximum
                self.action = ContinuousFullBodyAction(joints=joints, xyt=xyt)
                next_skill = None

                # plt.imsave('height_map.png', self.nav.map_policy.obstacle_map.height_map)

                # px_window_half_size = 50
                # obj_px = gps_to_px(obs.gps, self.nav.map_policy.obstacle_map.pixels_per_meter, self.nav.map_policy.obstacle_map.size) + np.array([
                #     np.cos(obs.compass[0]) * 0.7 * 40,
                #     np.sin(obs.compass[0]) * 0.7 * 40,
                # ]).astype(np.int32)

                # plt.imsave('place_height_map1.png', self.nav.map_policy.obstacle_map.height_map[
                #     obj_px[0]-px_window_half_size:obj_px[0]+px_window_half_size,
                #     obj_px[1]-px_window_half_size:obj_px[1]+px_window_half_size,
                # ])

                # obj_px = gps_to_px(obs.gps, self.nav.map_policy.obstacle_map.pixels_per_meter, self.nav.map_policy.obstacle_map.size) + np.array([
                #     -np.cos(obs.compass[0]) * 0.7 * 40,
                #     np.sin(obs.compass[0]) * 0.7 * 40,
                # ]).astype(np.int32)

                # plt.imsave('place_height_map2.png', self.nav.map_policy.obstacle_map.height_map[
                #     obj_px[0]-px_window_half_size:obj_px[0]+px_window_half_size,
                #     obj_px[1]-px_window_half_size:obj_px[1]+px_window_half_size,
                # ])

                # obj_px = gps_to_px(obs.gps, self.nav.map_policy.obstacle_map.pixels_per_meter, self.nav.map_policy.obstacle_map.size) + np.array([
                #     np.cos(obs.compass[0]) * 0.7 * 40,
                #     -np.sin(obs.compass[0]) * 0.7 * 40,
                # ]).astype(np.int32)

                # plt.imsave('place_height_map3.png', self.nav.map_policy.obstacle_map.height_map[
                #     obj_px[0]-px_window_half_size:obj_px[0]+px_window_half_size,
                #     obj_px[1]-px_window_half_size:obj_px[1]+px_window_half_size,
                # ])

                # obj_px = gps_to_px(obs.gps, self.nav.map_policy.obstacle_map.pixels_per_meter, self.nav.map_policy.obstacle_map.size) + np.array([
                #     -np.cos(obs.compass[0]) * 0.7 * 40,
                #     -np.sin(obs.compass[0]) * 0.7 * 40,
                # ]).astype(np.int32)

                # plt.imsave('place_height_map4.png', self.nav.map_policy.obstacle_map.height_map[
                #     obj_px[0]-px_window_half_size:obj_px[0]+px_window_half_size,
                #     obj_px[1]-px_window_half_size:obj_px[1]+px_window_half_size,
                # ])

            elif place_step == 2:
                self.action = DiscreteNavigationAction.DESNAP_OBJECT
                next_skill = None
            elif place_step == 3: # reset to default
                joints = np.zeros(10)
                xyt = np.zeros(3)
                joints[0] = 0 - obs.joint[:4].sum() 
                joints[4] = 0.775 - obs.joint[4]
                self.action = ContinuousFullBodyAction(joints=joints, xyt=xyt)
                next_skill = None
            else:
                self.action = None
                next_skill = Skill.FALL_WAIT
                self.repeats=5
        else:
            self.action, info, next_skill = super()._place(obs, info)

        return self.action, info, next_skill
    
    def _fall_wait(
        self, obs: Observations, info: Dict[str, Any]
    ) -> Tuple[DiscreteNavigationAction, Any, Optional[Skill]]:
        """Idle until the fall-wait period has elapsed, then stop.

        Without visualization (or without a nav module) this simply defers to
        the parent implementation. Otherwise it first calls ``self.update`` and
        refreshes the nav module's OVMM visualization info. It then emits
        ``EMPTY_ACTION`` while fewer than ``self._fall_wait_steps`` steps have
        passed since ``self.fall_wait_start_step[0]``, and ``STOP`` afterwards.

        Returns:
            ``(action, info, None)``: this skill never hands over to another skill.
        """
        if not (self.visualize and self.nav):
            return super()._fall_wait(obs, info)

        # Keep the visualization state current while idling.
        self.update(obs, info)
        self.nav.update_ovmm_vis_info(obs, info)

        steps_waited = self.timesteps[0] - self.fall_wait_start_step[0]
        still_waiting = steps_waited < self._fall_wait_steps
        action = (
            DiscreteNavigationAction.EMPTY_ACTION
            if still_waiting
            else DiscreteNavigationAction.STOP
        )
        return action, info, None
    
    def end_eval(self):
        """Stop the gaze module's object-state-map backend, if it has one.

        This is a no-op in each of these cases:
        - there is no (truthy) gaze module;
        - the gaze module has no ``object_state_map`` attribute;
        - the map exposes no ``stop_backend``.

        This matches the defensive ``hasattr`` checks used in
        ``generate_video``.
        """
        if not self.gaze:
            return
        # getattr with a default avoids AttributeError for gaze modules that
        # were built without an object-state map.
        object_state_map = getattr(self.gaze, 'object_state_map', None)
        if hasattr(object_state_map, 'stop_backend'):
            object_state_map.stop_backend()
    
    @nvtx.annotate("HELIOS.generate_video")
    def generate_video(self, dir, i_episode, current_episode_key):
        """Assemble the episode's saved frames into videos, then delete the frames.

        Frames are read from ``{habitat_baselines.image_dir}/{current_episode_key}``
        and videos are written at ``habitat_baselines.video_fps``:

        * ``{dir}/{i_episode:04d}-{current_episode_key}_ovmm.mp4``: each
          ``snapshot_*.png`` with its matching planner snapshot placed to the
          right. The planner panel is black when no planner snapshot exists.
        * When ``self.gaze.visualize_3dgs`` is truthy: ``...{key}.mp4`` from the
          ``combined_*.png`` frames and, if any ``instance_*.png`` frames exist,
          ``...{key}_instance.mp4``.

        The nav and gaze modules then get a chance to write their own videos,
        and finally the per-episode image directory is removed.

        Args:
            dir: Output directory for the videos; created if missing.
            i_episode: Episode index, zero-padded to 4 digits in file names.
            current_episode_key: Episode identifier. It is also the name of the
                frame sub-directory under ``habitat_baselines.image_dir``.
        """
        if self.nav:
            self.nav.habitat_vis.flush_frames()
        if self.gaze:
            # Pause the gaze object-state model if it supports pausing.
            object_state_map = getattr(self.gaze, 'object_state_map', None)
            model = getattr(object_state_map, 'model', None)
            if hasattr(model, 'pause'):
                model.pause()

        episode_image_dir = f"{self.config.habitat_baselines.image_dir}/{current_episode_key}"
        fps = self.config.habitat_baselines.video_fps
        video_prefix = f"{dir}/{i_episode:04d}-{current_episode_key}"
        os.makedirs(dir, exist_ok=True)

        # Main video: agent snapshot | planner snapshot (square, same height).
        with imageio.get_writer(f"{video_prefix}_ovmm.mp4", fps=fps) as writer:
            for snapshot_path in natsorted(glob(f"{episode_image_dir}/snapshot_*.png")):
                snapshot = imageio.imread(snapshot_path)
                planner_snapshot = self._load_planner_snapshot(
                    self._planner_snapshot_path(snapshot_path), snapshot.shape[0]
                )
                writer.append_data(np.concatenate([snapshot, planner_snapshot], axis=1))

        if self.gaze and getattr(self.gaze, 'visualize_3dgs', False):
            with imageio.get_writer(f"{video_prefix}.mp4", fps=fps) as writer:
                for combined_path in natsorted(glob(f"{episode_image_dir}/combined_*.png")):
                    writer.append_data(imageio.imread(combined_path))
            instance_paths = natsorted(glob(f"{episode_image_dir}/instance_*.png"))
            if instance_paths:
                with imageio.get_writer(f"{video_prefix}_instance.mp4", fps=fps) as writer:
                    for instance_path in instance_paths:
                        writer.append_data(imageio.imread(instance_path))

        if self.nav and hasattr(self.nav, 'generate_video'):
            self.nav.generate_video(dir, i_episode, current_episode_key=current_episode_key)
        if self.gaze and hasattr(self.gaze, 'generate_video'):
            self.gaze.generate_video(dir, i_episode, current_episode_key=current_episode_key)
        shutil.rmtree(episode_image_dir)

    @staticmethod
    def _planner_snapshot_path(snapshot_path):
        """Map ``<dir>/snapshot_<NNNN>.png`` to ``<dir>/planner_snapshot_<N>.png``.

        Only the file name is rewritten, so directory components that contain
        "snapshot" or "_" are left untouched. The frame number is parsed with
        ``int`` and loses its zero padding. A non-numeric frame raises ValueError.
        """
        directory, file_name = os.path.split(snapshot_path)
        file_name = file_name.replace('snapshot', 'planner_snapshot', 1)
        prefix, _, frame = file_name.rpartition('_')
        frame = str(int(frame.replace('.png', ''))) + '.png'
        return os.path.join(directory, f"{prefix}_{frame}")

    @staticmethod
    def _load_planner_snapshot(path, size):
        """Load the planner snapshot at ``path`` resized to ``size`` x ``size``.

        Returns a black ``(size, size, 3)`` uint8 image when ``path`` does not exist.
        """
        if not os.path.exists(path):
            return np.zeros((size, size, 3), dtype=np.uint8)
        planner_snapshot = imageio.imread(path)
        if planner_snapshot.ndim == 2:
            # A 2-D image is treated as three side-by-side square panels (each
            # `height` pixels wide), which are stacked as the three channels.
            panel = planner_snapshot.shape[0]
            planner_snapshot = np.stack(
                [planner_snapshot[:, panel * c:panel * (c + 1)] for c in range(3)],
                axis=-1,
            )
        # cv2.resize takes dsize as (width, height).
        return cv2.resize(planner_snapshot, (size, size))


def check_collision(
    last_pose: np.ndarray,
    curr_pose: np.ndarray,
    last_action: ContinuousNavigationAction,
    pixels_per_m: int,
    platform: str
):
    """Compare how far the robot actually moved with how far the last action intended it to move.

    Args:
        last_pose: Pose before ``last_action``; only ``[:2]`` is used.
        curr_pose: Pose after ``last_action``; only ``[:2]`` is used.
            If either pose is ``None``, the actual displacement is taken as 0.
        last_action: The executed action. Despite the annotation, it is also
            compared against ``DiscreteNavigationAction.MOVE_FORWARD``. The
            intended displacement is:
            - 0.1 for ``MOVE_FORWARD``;
            - ``norm(xyt[:2])`` for an action whose ``xyt[:2]`` is non-zero;
            - 0 for anything else.
        pixels_per_m: Map resolution; used only for the ``'sim'`` tolerance.
        platform: ``'sim'`` or ``'spot'``; selects the tolerance.

    Returns:
        ``0`` if ``|actual - intended|`` is below the platform tolerance, and
        otherwise that absolute error. The tolerance is
        ``sqrt(2) / pixels_per_m`` for ``'sim'`` and ``0.9`` for ``'spot'``.
        Presumably a non-zero result signals a collision to the caller.
        NOTE(review): confirm how any other ``platform`` value is meant to be
        handled.
    """
    if last_pose is None or curr_pose is None:
        # No pose to compare against: treat as "did not move".
        actual_displacement = 0
    else:
        actual_displacement = np.linalg.norm(np.array(curr_pose[:2]) - np.array(last_pose[:2]))
    if last_action == DiscreteNavigationAction.MOVE_FORWARD:
        # Fixed forward step size; units presumably match the pose (metres?) — confirm.
        intended_displacement = 0.1
    elif hasattr(last_action, 'xyt') and last_action.xyt is not None and np.linalg.norm(last_action.xyt[:2]) > 0:
        # Continuous action: the planar part of xyt is the intended translation.
        intended_displacement = np.linalg.norm(last_action.xyt[:2])
    else:
        # Turns / non-translating actions are not expected to move the base.
        intended_displacement = 0

    error = abs(actual_displacement - intended_displacement)

    if platform == 'sim':
        # Tolerance is one map-pixel diagonal (sqrt(2) pixels converted by pixels_per_m).
        return 0 if error < np.sqrt(2) / pixels_per_m else error
    elif platform == 'spot':
        # Fixed, much looser tolerance on the Spot platform.
        return 0 if error < 0.9 else error