import dataclasses
import os
from typing import Any, Dict, Optional, Tuple, Union

import cv2
from home_robot.perception.constants import RearrangeDETICCategories
import nvtx
import skimage
import torch
import imageio
import numpy as np
from scipy.spatial.transform.rotation import Rotation
from vlfm.utils.geometry_utils import get_point_cloud, transform_points
from habitat_baselines.common.tensor_dict import TensorDict
from home_robot.core.interfaces import ContinuousFullBodyAction, ContinuousNavigationAction, DiscreteNavigationAction, Observations
from omegaconf import DictConfig

from home_robot.perception.constants import d3_40_colors_rgb
from helios.agent.nav.generating_habitat_vis import GeneratingHabitatVis
from helios.agent.nav.saving_habitat_vis import SavingHabitatVis
from helios.agent.nav.vlfm.vlfm_map_policy import VlfmMapPolicy
import skimage.draw

from helios.agent.planner.fm2_planner import get_times
from helios.agent.utils.utils import get_bounds, gps_to_px, obs_to_tf, xy_to_px

from depth_camera_filtering import filter_depth
from vlfm.utils.img_utils import fill_small_holes

class VlfmNav:
    """Exploration agent that feeds home-robot observations into a VLFM map policy.

    The agent keeps an OVMM semantic map next to the policy's obstacle map,
    picks frontier goals through ``map_policy`` and turns them into actions
    with an externally supplied planner.
    """

    def __init__(
        self,
        habitat_vis: SavingHabitatVis,
        video_dir: str,
        video_fps: int,
        visualize: bool,
        min_depth: float,
        max_depth: float,
        semantic_map_config: DictConfig,
        map_policy: VlfmMapPolicy,
        keep_exploring: bool,
    ):
        """Store the collaborators and initialise all per-episode state.

        Args:
            habitat_vis: Visualizer; must provide ``reset()`` and ``flush_frames()``.
            video_dir: Directory for videos (stored only; not read in this class).
            video_fps: Frame rate used by ``generate_video``.
            visualize: Whether visualization entries are added to ``info``.
            min_depth: Lower bound used to normalise / rescale depth.
            max_depth: Upper bound used to normalise / rescale depth.
            semantic_map_config: Passed through unchanged in the visualization info.
            map_policy: VLFM map policy owning the obstacle and value maps.
            keep_exploring: If true, revisit earlier goals once no frontier is
                left instead of raising.
        """
        self.habitat_vis = habitat_vis
        self.video_dir = video_dir
        self.video_fps = video_fps

        # Start from the same fill value reset() applies, so the map never
        # exposes uninitialised memory if it is read before the first reset().
        self.ovmm_semantic_map = np.full((map_policy.map_size, map_policy.map_size), 4.0)

        self._found_dist = 0.4
        self._last_goal = np.zeros(2)
        # Also created here (not only in reset()/update()) so that act() and
        # get_debug_str() cannot fail with AttributeError on a fresh instance.
        self.prev_goals = []
        self.target_object = None

        self.semantic_map_config = semantic_map_config
        self._visualize = visualize
        self._min_depth = min_depth
        self._max_depth = max_depth

        self.map_policy = map_policy
        self.keep_exploring = keep_exploring

        self.chosen_frontier = None

    def reset(self):
        """Clear visualizer, semantic map, map policy and goal history for a new episode."""
        self.habitat_vis.reset()
        self.ovmm_semantic_map.fill(4)
        self.map_policy.reset()
        self._last_goal = np.zeros(2)
        self.prev_goals = []

    def set_vis_dir(self, scene_id: str, episode_id: str):
        """Forward the scene/episode ids to the visualizer when it supports ``set_vis_dir``."""
        if hasattr(self.habitat_vis, 'set_vis_dir'):
            self.habitat_vis.set_vis_dir(scene_id, episode_id)

    @nvtx.annotate('VlfmNav.update')
    def update(self, obs: Observations, info: Dict[str, Any], info_only=False):
        """Integrate one observation into the maps and publish map state in ``info``.

        Args:
            obs: Current observation; reads ``task_observations``, ``depth``,
                ``semantic``, ``camera_K``, ``gps`` and ``compass``.
            info: Dictionary updated in place with image and map entries.
            info_only: When true, skip the depth projection, the semantic map
                write and ``map_policy.update``; only ``info`` is refreshed.

        Side effects:
            Sets ``self.target_object``, ``self.map_policy.obs`` and
            ``self.map_policy.info``.
        """
        self.target_object = obs.task_observations['start_recep_name']

        # Update OVMM semantic map
        tf_camera_to_episodic = obs_to_tf(obs)
        if not info_only:
            depth = obs.depth.copy()
            depth = (depth - self._min_depth) / (self._max_depth - self._min_depth)
            depth[depth > 1.0] = 1.1  # still filter out but make viz less bad
            depth = filter_depth(depth.reshape([depth.shape[-2], depth.shape[-1]]), blur_type=None)
            if self.map_policy.obstacle_map._hole_area_thresh == -1:
                # Threshold -1 disables hole filling: zero readings become max range.
                filled_depth = depth.copy()
                filled_depth[depth == 0] = 1.0
            else:
                filled_depth = fill_small_holes(depth, self.map_policy.obstacle_map._hole_area_thresh)
            # Back from the normalised range to the original depth units.
            scaled_depth = filled_depth * (self._max_depth - self._min_depth) + self._min_depth
            mask = scaled_depth < self._max_depth
            fx, fy = obs.camera_K[0, 0], obs.camera_K[1, 1]
            point_cloud_camera_frame = get_point_cloud(scaled_depth, mask, fx, fy)
            point_cloud_episodic_frame = transform_points(tf_camera_to_episodic, point_cloud_camera_frame)
            xy_points = point_cloud_episodic_frame[:, :2]
            pixel_points = self.map_policy.obstacle_map._xy_to_px(xy_points)

            # One semantic label per valid depth pixel, written at its map cell.
            self.ovmm_semantic_map[pixel_points[:, 1], pixel_points[:, 0]] = obs.semantic[mask]

        # inputs from home-robot to vlfm
        # (dataclasses.asdict copies the fields, so obs itself is not modified)
        obs_dict = dataclasses.asdict(obs)
        obs_dict['compass'] -= np.pi / 2
        obs_dict['gps'] = np.array([obs_dict['gps'][1], -obs_dict['gps'][0]])
        obs_dict['tf_camera_to_episodic'] = tf_camera_to_episodic

        if not info_only:
            self.map_policy.update(obs_dict, point_cloud_episodic_frame, pixel_points)

        ## vlfm "map" vis
        if self._visualize:
            window_size = 200
            start_px = gps_to_px(obs.gps, self.map_policy.obstacle_map.pixels_per_meter, self.map_policy.obstacle_map.size)
            y0, x0 = np.maximum(start_px - window_size // 2, 0)
            y1, x1 = y0 + window_size, x0 + window_size
            info.update({
                'top_down_map': {
                    'map': (self.map_policy.obstacle_map._map[::-1]).astype(np.uint8),
                    'fog_of_war_mask': None,
                    'agent_map_coord': [(xy_to_px(obs.gps, self.map_policy.obstacle_map.pixels_per_meter, self.map_policy.obstacle_map.size))],
                    'agent_angle': [obs_dict['compass'].item() - np.pi],
                    'lower_bound': (y0, x0),
                    'upper_bound': (y1, x1),
                }
            })

        info['rgb'] = obs_dict['rgb']
        info['depth'] = (obs_dict['depth'] - self._min_depth) / (self._max_depth - self._min_depth)

        # SimpleObjectMap.update()
        info.update({
            'hole_area_thresh': self.map_policy.obstacle_map._hole_area_thresh,
            'min_depth': self.map_policy._min_depth,
            'max_depth': self.map_policy._max_depth,
            'fx': self.map_policy.fx,
            'fy': self.map_policy.fy,
            'pixels_per_meter': self.map_policy.obstacle_map.pixels_per_meter,
            'navigable_map': self.map_policy.obstacle_map._navigable_map,
            'map_size': self.map_policy.map_size,
            'obstacle_map': self.map_policy.obstacle_map._map,
            'dilated_obstacle_map': self.map_policy.obstacle_map.dilated_obstacle_map,
            'seen_map': self.map_policy.obstacle_map.seen_map,
            'wall_map': self.map_policy.obstacle_map.wall_map,
            'bounds': self.map_policy.obstacle_map.bounds,
        })
        self.map_policy.obs = obs
        self.map_policy.info = info

    @nvtx.annotate('VlfmNav.act')
    def act(
        self,
        obs: Observations,
        info: Dict[str, Any],
        channel: int,
        planner: Any,
        force_new: bool = False
    ):
        """Choose (or keep) a frontier goal and return the next action.

        Args:
            obs: Current observation.
            info: Must contain ``'timestep'``; from timestep 28 onwards also
                ``'dilated_obstacle_map'`` and ``'seen_map'`` (plus whatever
                ``watch_out_for_close_obstacles`` reads).
            channel: Index into the last axis of the policy's value map.
            planner: Object providing ``plan_to_pose``.
            force_new: Request a new goal even if the current one is not reached.

        Returns:
            Tuple ``(action, info, None)``.

        Raises:
            RuntimeError: No frontiers remain and ``keep_exploring`` is false.
            ValueError: ``channel`` is outside the value map's channel range.
        """
        obs_dict = dataclasses.asdict(obs)
        obs_dict['compass'] -= np.pi / 2
        obs_dict['gps'] = np.array([obs_dict['gps'][1], -obs_dict['gps'][0]])
        obs_dict['depth'] = (obs_dict['depth'] - self._min_depth) / (self._max_depth - self._min_depth)

        obstacle_map = self.map_policy.obstacle_map
        last_goal_gps = np.array([-self._last_goal[1], self._last_goal[0]])
        last_goal_px = gps_to_px(last_goal_gps, obstacle_map.pixels_per_meter, obstacle_map.size)
        if not obstacle_map._navigable_map[last_goal_px[0], last_goal_px[1]]:
            # Only report this when it is actually the reason for re-planning,
            # not when the caller asked for a new goal.
            print("Last goal is not navigable. Forcing new frontier.")
            force_new = True

        if np.linalg.norm(self._last_goal - obs_dict['gps']) < self._found_dist or force_new:
            if len(obstacle_map.frontiers) == 0:
                self.chosen_frontier = None
                if not self.keep_exploring:
                    raise RuntimeError("No frontiers found during exploration.")
                # Drop earlier goals that are no longer navigable.
                self.prev_goals = [
                    goal for goal in self.prev_goals
                    if obstacle_map._navigable_map[tuple(gps_to_px(
                        np.array([-goal[1], goal[0]]),
                        obstacle_map.pixels_per_meter,
                        obstacle_map.size
                    ))]
                ]
                far_enough_prev_goals = [
                    goal for goal in self.prev_goals
                    if np.linalg.norm(goal - obs_dict['gps']) > self._found_dist
                ]
                # Revisit a random earlier goal; keep the current one if none qualifies.
                if len(far_enough_prev_goals) > 0:
                    self._last_goal = far_enough_prev_goals[np.random.randint(len(far_enough_prev_goals))]
            else:
                n_channels = self.map_policy.value_map._value_map.shape[-1]
                if (channel < 0) or (channel >= n_channels):
                    raise ValueError(
                        f"Channel out of bounds for vlfm_nav: {channel} "
                        f"(for {n_channels} channels)"
                    )
                if self.chosen_frontier is None:
                    self._last_goal, _ = self.map_policy.get_best_frontier(
                        obs_dict,
                        obstacle_map.frontiers,
                        channel,
                        force_new
                    )
                else:
                    # An externally chosen frontier is used once, then cleared.
                    self._last_goal = self.chosen_frontier
                    self.chosen_frontier = None
                self.prev_goals.append(self._last_goal)

        # viewpoint
        if info['timestep'] < 14:
            joints = np.zeros(10)
            joints[9] = (-np.pi / 3) - obs.joint[9]  # look 60 deg down (delta from current tilt)
            xyt = np.zeros(3)
            xyt[2] = np.pi / 6  # rotate 30 deg
            action = ContinuousFullBodyAction(
                joints=joints,
                xyt=xyt,
            )
        elif info['timestep'] < 28:
            joints = np.zeros(10)
            joints[9] = (-np.pi / 12) - obs.joint[9]  # look 15 deg down (delta from current tilt)
            xyt = np.zeros(3)
            xyt[2] = np.pi / 6  # rotate 30 deg
            action = ContinuousFullBodyAction(
                joints=joints,
                xyt=xyt,
            )
        else:
            action = planner.plan_to_pose(
                obstacle_map=info['dilated_obstacle_map'],
                seen_map=info['seen_map'],  # optional (can be none) but recommended
                xyt_start=np.array([obs.gps[0], obs.gps[1], obs.compass[0]]),
                xyt_goal=np.array([-self._last_goal[1], self._last_goal[0]]),  # xy (vlfm) -> gps
                timestep=info['timestep'],  # just for visualization
            )
            action = watch_out_for_close_obstacles(action, obs, info)

        return action, info, None

    def reached_frontier(self, obs):
        """Return whether the robot is within ``_found_dist`` of the current goal.

        A goal of exactly ``(0, 0)`` is the "no goal yet" value set by
        ``__init__``/``reset`` and is never reported as reached.
        """
        if np.all(self._last_goal == np.zeros(2)):
            return False
        robot_xy = np.array([obs.gps[1], -obs.gps[0]])
        return np.linalg.norm(self._last_goal - robot_xy) < self._found_dist

    def update_ovmm_vis_info(self,
        obs: Observations,
        info: Dict[str, Any],
    ):
        """Add downsampled map layers for the OVMM visualizer to ``info``.

        Args:
            obs: Current observation; reads ``gps`` and ``compass``.
            info: Dictionary updated in place when visualization is enabled.

        Returns:
            The same ``info`` dictionary, whether or not visualization is
            enabled, so the result can always be assigned back by the caller.
        """
        if not self._visualize:
            return info

        obstacle_map = self.map_policy.obstacle_map

        # outputs from vlfm to home-robot
        start_px = gps_to_px(obs.gps, obstacle_map.pixels_per_meter, obstacle_map.size)
        start = start_px[::-1] / obstacle_map.pixels_per_meter

        ## ovmm semantic map vis
        goal_px = xy_to_px(self._last_goal, obstacle_map.pixels_per_meter, obstacle_map.size)[::-1]
        goal_map = np.zeros_like(obstacle_map._map)
        goal_map[tuple(goal_px)] = 1
        # Subsampling stride; clamped to 1 because a stride of 0 (maps smaller
        # than 1000 px) is not a valid slice step.
        factor = max(1, self.map_policy.map_size // 1000)
        info.update({
            'obstacle_map': obstacle_map._map[::factor, ::factor],
            'dilated_obstacle_map': obstacle_map.dilated_obstacle_map[::factor, ::factor],
            'goal_map': goal_map[::factor, ::factor],
            'seen_map': obstacle_map.seen_map[::factor, ::factor],
            'sensor_pose': np.array([
                start[0],  # x
                start[1],  # y
                np.degrees(obs.compass[0]),  # theta
                0,
                1000,
                0,
                1000,
            ]),
            'semantic_map': self.ovmm_semantic_map[::factor, ::factor].copy(),
            'explored_map': obstacle_map.explored_area[::factor, ::factor],
            'semantic_map_config': self.semantic_map_config,
        })

        return info

    def get_debug_str(
        self,
        action,
        collision,
        n_collisions,
        timestep,
    ):
        """Format a one-line status string for logging.

        ``action`` is shown by its ``name`` attribute when it has one,
        otherwise by its ``xyt`` components (rotation in degrees).
        """
        if hasattr(action, 'name'):
            action_str = action.name
        else:
            action_str = f'(x={action.xyt[0]:.2f}, y={action.xyt[1]:.2f}, t={np.degrees(action.xyt[2]):.0f})'
        collision_str = ' COLLISION ' if collision else ''
        return (
            f'[VlfmNav()] step: {timestep} target object: {self.target_object} '
            f'action: {action_str} n_collisions: {n_collisions} {collision_str}'
        )

    def generate_video(self, dir, i_episode, current_episode_key):
        """Write ``*_vlfm.mp4`` by stacking VLFM frames on top of the ``*_ovmm.mp4`` frames.

        Does nothing when the OVMM video for this episode does not exist.
        Frames are paired with ``zip``, so the output stops at the shorter of
        the two sources.
        """
        prefix = f'{dir}/{i_episode:04d}-{current_episode_key}'
        ovmm_path = f'{prefix}_ovmm.mp4'
        if not os.path.exists(ovmm_path):
            return
        with imageio.get_writer(f'{prefix}_vlfm.mp4', fps=self.video_fps) as writer:
            with imageio.get_reader(ovmm_path) as reader:
                first_frame_size = None
                for vlfm_frame, ovmm_frame in zip(
                    self.habitat_vis.flush_frames(),
                    reader
                ):
                    # Scale the OVMM frame to the VLFM frame's width, keeping its aspect ratio.
                    target_width = vlfm_frame.shape[1]
                    target_height = ovmm_frame.shape[0] * target_width // ovmm_frame.shape[1]
                    ovmm_frame = cv2.resize(ovmm_frame, (target_width, target_height))
                    frame = np.concatenate((vlfm_frame, ovmm_frame), axis=0)
                    # Every frame is forced to the size of the first one written.
                    if first_frame_size is None:
                        first_frame_size = frame.shape
                    elif frame.shape != first_frame_size:
                        frame = cv2.resize(frame, (first_frame_size[1], first_frame_size[0]))
                    writer.append_data(frame)

def watch_out_for_close_obstacles(action, obs, info):
    """Redirect a navigation action to look at the nearest unseen cell in line of sight.

    Only ``ContinuousNavigationAction`` instances are processed; any other
    action is returned unchanged. A square window of about 2 m around the
    agent is searched for the closest cell (by ``get_times`` distance) that is
    not covered by the dilated ``seen_map`` and can be reached by a straight
    ray crossing no ``obstacle_map`` cell. If one exists, ``action.xyt`` is
    overwritten in place with a pure rotation towards it and the head tilt is
    aimed at it. If the resulting joint targets differ from ``obs.joint``, a
    ``ContinuousFullBodyAction`` carrying the joint deltas is returned instead.

    Args:
        action: Action produced by the planner.
        obs: Current observation; reads ``gps``, ``compass`` and ``joint``.
        info: Must contain ``'pixels_per_meter'``, ``'map_size'``,
            ``'dilated_obstacle_map'``, ``'seen_map'`` and ``'obstacle_map'``.

    Returns:
        The original action (possibly with ``xyt`` modified) or a new
        ``ContinuousFullBodyAction``.
    """
    if not isinstance(action, ContinuousNavigationAction):
        return action

    pixels_per_meter = info['pixels_per_meter']
    agent_window_size = int(round(2.0 * pixels_per_meter))
    agent_px = gps_to_px(obs.gps, pixels_per_meter, info['map_size'])
    y0, x0 = np.clip(agent_px - agent_window_size // 2, 0, info['map_size'] - agent_window_size)
    y1, x1 = y0 + agent_window_size, x0 + agent_window_size
    # Agent position inside the window. This equals the window centre unless
    # the window was clipped at the map border, so rays, distance and heading
    # below are measured from here rather than from the centre.
    agent_local = agent_px - (y0, x0)
    agent_row, agent_col = int(agent_local[0]), int(agent_local[1])

    dist_from_agent = get_times(
        obstacle_map=info['dilated_obstacle_map'][y0:y1, x0:x1],
        goal_map=tuple(agent_local),
        speed=None,
        dx=1 / pixels_per_meter,
    )
    # Cells already seen (with a one-pixel margin) are never candidates.
    seen_window = cv2.dilate(info['seen_map'][y0:y1, x0:x1].astype(np.uint8), np.ones((3, 3)))
    dist_from_agent[seen_window.astype(bool)] = np.inf

    # Default joint targets when nothing needs to be looked at.
    joints = np.zeros(10)
    joints[4] = 0.775  # arm height
    joints[6] = -np.pi / 2  # gripper pitch
    joints[9] = -np.pi / 12  # look 15 deg down

    if np.min(dist_from_agent) < np.inf:
        # Candidates ordered from nearest to farthest.
        nearest_first = np.argsort(dist_from_agent, axis=None)
        unseen_pxs = np.stack(np.unravel_index(nearest_first, dist_from_agent.shape), axis=1)
        for unseen_px in unseen_pxs:
            if dist_from_agent[tuple(unseen_px)] == np.inf:
                # Everything from here on is seen or unreachable.
                break
            ray = np.stack(skimage.draw.line(agent_row, agent_col, *unseen_px), axis=1)
            if np.any(info['obstacle_map'][ray[:, 0] + y0, ray[:, 1] + x0]):
                continue  # line of sight is blocked
            dist = np.linalg.norm(unseen_px - agent_local) / pixels_per_meter
            heading = np.arctan2(
                unseen_px[0] - agent_row,
                unseen_px[1] - agent_col,
            )
            # NOTE(review): the 1.0 looks like an assumed camera height in metres - confirm.
            tilt = np.clip(-np.pi / 2 + np.arctan2(dist, 1.0), -np.pi / 3, -np.pi / 12)
            joints[9] = tilt
            action.xyt = np.array([0, 0, heading - obs.compass[0]])
            print(f'Looking at closest unseen: {dist:.2f}m, pan {np.degrees(heading - obs.compass[0]):.0f} deg, tilt {np.degrees(tilt):.0f} deg')
            break

    if not np.allclose(joints, obs.joint):
        # Joints are commanded as deltas from the current configuration.
        action = ContinuousFullBodyAction(
            joints=joints - obs.joint,
            xyt=action.xyt,
        )
    return action