from typing import Any, Union

import cv2
import numpy as np
from frontier_exploration.frontier_detection import detect_frontier_waypoints
from frontier_exploration.utils.fog_of_war import reveal_fog_of_war
import nvtx

import skimage.draw
from vlfm.mapping.base_map import BaseMap
from vlfm.utils.geometry_utils import extract_yaw, get_point_cloud, transform_points
from vlfm.utils.img_utils import fill_small_holes

from vlfm.mapping.obstacle_map import ObstacleMap as ObstacleMap_VLFM

from helios.agent.planner.fm2_planner import get_times
from helios.agent.utils.utils import get_bounds, xy_to_px

class ObstacleMap(ObstacleMap_VLFM):
    """Generates two maps; one representing the area that the robot has explored so far,
    and another representing the obstacles that the robot has seen so far.

    On top of the VLFM base map, this subclass also maintains:

    * ``seen_map``   -- cells hit by any observed point, plus the robot footprint;
    * ``wall_map``   -- cells hit by points whose height lies in
      ``[min_wall_height, max_wall_height]``;
    * ``height_map`` -- the z value of the last observed point projected onto each cell.
    """

    _map_dtype: np.dtype = np.dtype(bool)
    _frontiers_px: np.ndarray = np.array([])
    frontiers: np.ndarray = np.array([])
    radius_padding_color: tuple = (100, 100, 100)
    # Frontiers strictly closer than this distance to the robot (same units as robot_xy,
    # i.e. meters) are discarded after every update.
    min_frontier_distance: float = 0.5

    def __init__(
        self,
        min_height: float,
        max_height: float,
        agent_radius: float,
        area_thresh: float,  # square meters
        hole_area_thresh: int,  # square pixels
        size: int,
        pixels_per_meter: int,
    ):
        super().__init__(min_height, max_height, agent_radius, area_thresh, hole_area_thresh, size, pixels_per_meter)

        # Robot footprint radius in millimetres (converted to meters via / 1000 when used).
        self.mm_robot_radius = 300
        px_robot_radius = int(self.mm_robot_radius * pixels_per_meter // 1000 + 1)
        # Boolean disk of radius px_robot_radius with shape (2r+1, 2r+1). It is used both to
        # dilate obstacles into the navigable map and as the robot-footprint stamp.
        offsets = np.mgrid[-px_robot_radius:px_robot_radius+1, -px_robot_radius:px_robot_radius+1]
        self._navigable_kernel = np.linalg.norm(offsets, axis=0) <= px_robot_radius
        self.seen_map = np.zeros((size, size), dtype=bool)
        self.wall_map = np.zeros((size, size), dtype=bool)
        self.height_map = np.zeros((size, size), dtype=np.float32)
        # Height band (same units as point-cloud z) whose points are recorded in wall_map.
        self.min_wall_height = 1.1
        self.max_wall_height = 1.3

    def reset(self) -> None:
        """Reset the base maps and clear the extra seen/wall/height layers."""
        super().reset()
        self.seen_map.fill(0)
        self.wall_map.fill(0)
        self.height_map.fill(0)

    @nvtx.annotate('ObstacleMap.update_map')
    def update_map(
        self,
        depth: Union[np.ndarray, Any],
        robot_xy: np.ndarray,
        tf_camera_to_episodic: np.ndarray,
        min_depth: float,
        max_depth: float,
        fx: float,
        fy: float,
        point_cloud_episodic_frame: np.ndarray,
        pixel_points: np.ndarray,
    ) -> None:
        """Integrate one pre-computed point cloud into all maps and recompute frontiers.

        Updates ``height_map``, ``_map``, ``wall_map``, ``dilated_obstacle_map``, ``bounds``,
        ``_navigable_map``, ``seen_map``, ``explored_area``, ``_frontiers_px`` and
        ``frontiers`` in place.

        Args:
            depth, tf_camera_to_episodic, min_depth, max_depth, fx, fy: Not read by this
                implementation, because the point cloud arrives pre-computed. They are kept
                so the signature stays compatible with existing callers.
            robot_xy: Robot (x, y) position in the episodic frame.
            point_cloud_episodic_frame: (N, >=3) points in the episodic frame. Column 2
                is used as the height.
            pixel_points: (N, 2) integer map pixels (column 0 = x / map column,
                column 1 = y / map row), row-aligned with ``point_cloud_episodic_frame``.
        """
        robot_px = xy_to_px(robot_xy, self.pixels_per_meter, self.size)
        # Plain Python ints: OpenCV's bindings can reject numpy integer scalars as a Point.
        seed = (int(robot_px[0]), int(robot_px[1]))
        rx, ry = seed

        # Height layer: the last point landing on a cell wins.
        in_bounds = self._in_bounds_mask(pixel_points)
        valid_px = pixel_points[in_bounds]
        self.height_map[valid_px[:, 1], valid_px[:, 0]] = point_cloud_episodic_frame[in_bounds][:, 2]

        # Obstacles: points in the obstacle height band, ignoring anything inside the
        # robot footprint (the robot must not mark itself as an obstacle).
        obstacle_xy = filter_points_by_height(point_cloud_episodic_frame, self._min_height, self._max_height)[:, :2]
        far_from_robot = np.linalg.norm(obstacle_xy - robot_xy[None], axis=1) > self.mm_robot_radius / 1000
        obstacle_px = self._xy_to_valid_px(obstacle_xy[far_from_robot])
        self._map[obstacle_px[:, 1], obstacle_px[:, 0]] = True

        # Walls: points in the wall height band.
        wall_xy = filter_points_by_height(point_cloud_episodic_frame, self.min_wall_height, self.max_wall_height)[:, :2]
        wall_px = self._xy_to_valid_px(wall_xy)
        self.wall_map[wall_px[:, 1], wall_px[:, 0]] = True

        # Everything not 4-connected to the robot through obstacle-free cells becomes an obstacle.
        # NOTE(review): if the robot's own pixel is already marked in _map (e.g. by an earlier
        # observation), floodFill presumably fills nothing and the whole map turns into
        # obstacles -- confirm this cannot happen upstream.
        with nvtx.annotate('flood fill obstacles'):
            self._map = ~self._reachable(np.logical_not(self._map), seed)

        # Dilate obstacles by the robot footprint to get the navigable map.
        with nvtx.annotate('dilate obstacles'):
            self.dilated_obstacle_map = cv2.dilate(
                self._map.astype(np.uint8),
                self._navigable_kernel.astype(np.uint8),
                iterations=1,
            ).astype(bool)
            row_min, row_max, col_min, col_max = get_bounds(self.dilated_obstacle_map, 0)
            # Always keep a small window around the robot inside the bounds. Lower bounds are
            # clamped at 0 because a negative slice start would count from the far edge.
            self.bounds = (
                max(min(row_min, ry - 1), 0),
                max(row_max, ry + 2),
                max(min(col_min, rx - 1), 0),
                max(col_max, rx + 2),
            )
            row_min, row_max, col_min, col_max = self.bounds
            self._navigable_map = ~self.dilated_obstacle_map
            self._navigable_map[:row_min] = False
            self._navigable_map[row_max:] = False
            self._navigable_map[:, :col_min] = False
            self._navigable_map[:, col_max:] = False

        # Ensure the current location is navigable.
        if not self._navigable_map[ry, rx]:
            with nvtx.annotate('ensure agent navigable'):
                # get_times presumably returns a per-cell distance field to the navigable
                # region (FM2 planner) -- confirm. Carve a disk of that radius around the robot.
                dist_to_navigable = get_times(None, self._navigable_map, None, 1)
                min_dist = dist_to_navigable[ry, rx]  # don't round dist
                if np.isfinite(min_dist):
                    # shape= clips the disk to the map so no out-of-range/negative indices occur.
                    rr, cc = skimage.draw.disk((ry, rx), min_dist, shape=self._navigable_map.shape)
                    self._navigable_map[rr, cc] = True
                    self.dilated_obstacle_map[rr, cc] = False

        # Keep only the navigable component connected to the robot.
        with nvtx.annotate('flood fill navigable'):
            self._navigable_map = self._reachable(self._navigable_map, seed)

        # Seen map: every observed point plus the robot footprint.
        seen_px = self._xy_to_valid_px(point_cloud_episodic_frame[:, :2])
        self.seen_map[seen_px[:, 1], seen_px[:, 0]] = True
        self._stamp_kernel(self.seen_map, rx, ry)

        # Explored area: robot footprint plus observed cells within 3 m (in pixels) of the
        # robot, restricted to navigable cells.
        with nvtx.annotate('update explored area'):
            near = np.linalg.norm(seen_px - np.asarray(seed), axis=1) <= self.pixels_per_meter * 3.0
            near_px = seen_px[near]
            self._stamp_kernel(self.explored_area, rx, ry)
            self.explored_area[near_px[:, 1], near_px[:, 0]] = True
            self.explored_area[~self._navigable_map] = False

        # Compute frontier locations, dropping those too close to the robot.
        with nvtx.annotate('compute frontiers'):
            self._frontiers_px = self._get_frontiers()
            if len(self._frontiers_px) == 0:
                self.frontiers = np.array([])
            else:
                self.frontiers = self._px_to_xy(self._frontiers_px)
                dist_frontiers = np.linalg.norm(self.frontiers - robot_xy, axis=1)
                keep = dist_frontiers >= self.min_frontier_distance
                self._frontiers_px = self._frontiers_px[keep]
                self.frontiers = self.frontiers[keep]

    def _in_bounds_mask(self, pixel_points: np.ndarray) -> np.ndarray:
        """Return a boolean mask of the (x, y) pixel rows that lie inside the map.

        Negative coordinates are rejected explicitly: numpy would otherwise wrap them and
        write to the opposite side of the map.
        """
        n_rows, n_cols = self.height_map.shape
        return (
            (pixel_points[:, 0] >= 0)
            & (pixel_points[:, 0] < n_cols)
            & (pixel_points[:, 1] >= 0)
            & (pixel_points[:, 1] < n_rows)
        )

    def _xy_to_valid_px(self, xy_points: np.ndarray) -> np.ndarray:
        """Project episodic-frame (x, y) points to map pixels, dropping out-of-map ones."""
        pixel_points = self._xy_to_px(xy_points)
        return pixel_points[self._in_bounds_mask(pixel_points)]

    def _stamp_kernel(self, target: np.ndarray, cx: int, cy: int) -> None:
        """OR ``_navigable_kernel`` into ``target`` centred on pixel (cx, cy), clipped to the map."""
        k_rows, k_cols = self._navigable_kernel.shape
        y0, x0 = cy - k_rows // 2, cx - k_cols // 2
        ty0, tx0 = max(y0, 0), max(x0, 0)
        ty1, tx1 = min(y0 + k_rows, target.shape[0]), min(x0 + k_cols, target.shape[1])
        if ty0 >= ty1 or tx0 >= tx1:
            return  # footprint lies entirely outside the map
        target[ty0:ty1, tx0:tx1] |= self._navigable_kernel[ty0 - y0:ty1 - y0, tx0 - x0:tx1 - x0]

    @staticmethod
    def _reachable(free: np.ndarray, seed: tuple) -> np.ndarray:
        """Return the 4-connected component of ``free`` cells that contains ``seed`` (x, y)."""
        free = np.asarray(free, dtype=bool)
        image = free.astype(np.uint8)
        # floodFill needs a mask 2 px larger than the image; non-zero mask cells block the fill.
        mask = np.pad(~free, 1, mode='constant', constant_values=1).astype(np.uint8)
        cv2.floodFill(image, mask, seed, 2)
        return image == 2

def filter_points_by_height(points: np.ndarray, min_height: float, max_height: float) -> np.ndarray:
    """Select the rows of ``points`` whose z coordinate (column 2) lies in a height band.

    Both bounds are inclusive and the original row order is preserved.
    """
    heights = points[:, 2]
    in_band = np.logical_and(heights >= min_height, heights <= max_height)
    return points[in_band]
