import dataclasses
from typing import Any, Dict, Optional, Tuple, Union

import cv2
import nvtx
import torch
import imageio
import numpy as np
from scipy.spatial.transform.rotation import Rotation
from vlfm.utils.geometry_utils import get_point_cloud, transform_points
from habitat_baselines.common.tensor_dict import TensorDict
from home_robot.core.interfaces import DiscreteNavigationAction, Observations

from helios.agent.nav.vlfm.mapping.obstacle_map import ObstacleMap
from depth_camera_filtering import filter_depth
from torch import Tensor
from trimesh import transformations as tra
from vlfm.mapping.frontier_map import FrontierMap
# from vlfm.mapping.obstacle_map import ObstacleMap
from vlfm.mapping.value_map import ValueMap
from vlfm.policy.utils.acyclic_enforcer import AcyclicEnforcer
from vlfm.utils.geometry_utils import (
    closest_point_within_threshold,
    xyz_yaw_to_tf_matrix,
)
from vlfm.vlm.blip2itm import BLIP2ITM
from vlfm.utils.img_utils import pixel_value_within_radius

from helios.agent.utils.utils import gps_to_px, xy_to_px

class VlfmMapPolicy:
    """Frontier-selection policy backed by VLFM obstacle and value maps.

    ``update`` scores each RGB frame against two text prompts (the start and
    place receptacles) with BLIP-2 ITM and fuses those scores into a
    two-channel ``ValueMap``; it also updates an ``ObstacleMap`` from depth and
    point-cloud data. ``get_best_frontier`` then chooses which frontier to
    pursue, ranking frontiers either by value-map score (for one channel) or by
    Euclidean distance to the robot.
    """

    # Marker colors/sizes used when drawing frontiers on the visualizations.
    # The double underscore in the first name is historical; kept so existing
    # references keep working.
    _selected__frontier_color: Tuple[int, int, int] = (0, 255, 0)
    _frontier_color: Tuple[int, int, int] = (0, 0, 255)
    _circle_marker_thickness: int = 2
    _circle_marker_radius: int = 5

    def __init__(
        self,
        min_depth: float,
        max_depth: float,
        min_obstacle_height: float,
        max_obstacle_height: float,
        agent_radius: float,
        obstacle_map_area_threshold: float,
        hole_area_thresh: int,
        image_width: int,
        camera_height: int,
        camera_fov: int,
        max_line_len: int,
        sort_by_distance: bool,
        pixels_per_meter: int,
        map_size: int,
    ):
        """Build the BLIP-2 scorer, the obstacle map and the value map.

        Args:
            min_depth: Minimum depth forwarded to both map updates.
            max_depth: Maximum depth forwarded to both map updates.
            min_obstacle_height: Lower height bound forwarded to ObstacleMap.
            max_obstacle_height: Upper height bound forwarded to ObstacleMap.
            agent_radius: Forwarded to ObstacleMap.
            obstacle_map_area_threshold: Forwarded to ObstacleMap as
                ``area_thresh`` (square meters).
            hole_area_thresh: Forwarded to ObstacleMap (square pixels).
            image_width: Image width in pixels; used to derive the focal length.
            camera_height: Stored on the instance; not read elsewhere in this
                class.
            camera_fov: Horizontal field of view in degrees (stored in radians).
            max_line_len: Stored on the instance; not read elsewhere in this
                class.
            sort_by_distance: If True, ``get_best_frontier`` ranks frontiers by
                distance to the robot instead of by value-map score.
            pixels_per_meter: Map resolution shared by both maps.
            map_size: Map size shared by both maps.
        """
        self.pixels_per_meter = pixels_per_meter
        self.map_size = map_size
        self.blip2itm = BLIP2ITM()
        self.obstacle_map = ObstacleMap(
            min_height=min_obstacle_height,
            max_height=max_obstacle_height,
            agent_radius=agent_radius,
            area_thresh=obstacle_map_area_threshold,  # square meters
            hole_area_thresh=hole_area_thresh,  # square pixels
            size=map_size,
            pixels_per_meter=pixels_per_meter,
        )

        # Channel 0 scores the start receptacle prompt, channel 1 the place
        # receptacle prompt (see the `values` array built in `update`).
        self.value_map = ValueMap(
            value_channels=2,
            size=map_size,
            pixels_per_meter=pixels_per_meter,
            use_max_confidence=False,
            fusion_type="default",
            obstacle_map=self.obstacle_map,
        )

        # Radius (meters) around a frontier used when reading its value.
        self._frontier_radius = 0.5

        # (0, 0) is used as the "no frontier pursued yet" sentinel.
        self._last_frontier = np.zeros(2)
        self._last_value = 0.0
        self._acyclic_enforcer = AcyclicEnforcer()

        self._compute_frontiers = True

        self.max_line_len = max_line_len
        self.fov = camera_fov * (np.pi / 180)
        self._min_depth = min_depth
        self._max_depth = max_depth
        self.image_width = image_width
        self.camera_height = camera_height
        # Pinhole focal length (pixels) from the horizontal FOV; fx == fy
        # assumes square pixels.
        self.fx = self.fy = self.image_width / (2 * np.tan(self.fov / 2))

        self.sort_by_distance = sort_by_distance

    def reset(self):
        """Clear both maps and all per-episode frontier-selection state."""
        self.value_map.reset()
        self.obstacle_map.reset()

        self._last_frontier = np.zeros(2)
        self._last_value = 0.0
        # Without this, state-action pairs recorded in a previous episode keep
        # suppressing frontiers as "cyclic" in the next one (upstream VLFM also
        # recreates the enforcer on reset).
        # NOTE(review): if AcyclicEnforcer stores its history in a class-level
        # attribute, a new instance may still share it — confirm.
        self._acyclic_enforcer = AcyclicEnforcer()

    def update(self, obs_dict: Dict[str, Any], point_cloud_episodic_frame, pixel_points):
        """Fuse one observation into the obstacle and value maps.

        Reads ``obs_dict['task_observations']['start_recep_name'/'place_recep_name']``,
        ``'rgb'``, ``'depth'``, ``'gps'``, ``'compass'`` and
        ``'tf_camera_to_episodic'``. ``point_cloud_episodic_frame`` and
        ``pixel_points`` are forwarded unchanged to ``ObstacleMap.update_map``.

        Returns:
            ``obs_dict['tf_camera_to_episodic']``, unchanged.
        """
        self.target_object = obs_dict['task_observations']['start_recep_name']
        self.target_objects = [
            obs_dict['task_observations']['start_recep_name'],
            obs_dict['task_observations']['place_recep_name'],
        ]
        rgb = obs_dict['rgb']

        #### Normalize depth and filter (from VLFM)
        # Filtering is currently disabled; depth is used as provided.
        depth = obs_dict['depth']#.copy()
        # depth[depth > 1.0] = 1.1  # still filter out but make viz less bad
        # depth = filter_depth(depth.reshape([depth.shape[-2],depth.shape[-1]]), blur_type=None)

        ### VLFM map update ###

        prompt_start = (
            "Seems like there is a "
            + obs_dict['task_observations']["start_recep_name"]
            + " ahead."
        )
        prompt_goal = (
            "Seems like there is a "
            + obs_dict['task_observations']["place_recep_name"]
            + " ahead."
        )

        # One score per value-map channel: [start receptacle, place receptacle].
        values = np.array(
            [
                self.blip2itm.cosine(rgb, prompt_start),
                self.blip2itm.cosine(rgb, prompt_goal),
            ]
        )

        self.obstacle_map.update_map(
            depth,
            obs_dict['gps'],
            obs_dict['tf_camera_to_episodic'],
            self._min_depth,
            self._max_depth,
            self.fx,
            self.fy,
            point_cloud_episodic_frame,
            pixel_points
        )

        self.value_map.update_map(
            values,
            depth,
            obs_dict['tf_camera_to_episodic'],
            self._min_depth,
            self._max_depth,
            self.fov,
        )

        self.value_map.update_agent_traj(obs_dict['gps'], obs_dict['compass'][0])
        self.obstacle_map.update_agent_traj(obs_dict['gps'], obs_dict['compass'][0])

        return obs_dict['tf_camera_to_episodic']

    def get_best_frontier(
        self,
        obs: Union[Dict[str, Tensor], "TensorDict"],
        frontiers: np.ndarray,
        channel: int,
        force_new: bool = False,
    ) -> Tuple[np.ndarray, float]:
        """Return the frontier to pursue and its ranking value.

        Frontiers are ranked either by value-map score for ``channel``
        (descending) or, when ``self.sort_by_distance`` is set, by Euclidean
        distance to the robot (ascending). The previously pursued frontier is
        kept if it is still present and its value has not dropped by more than
        0.01; otherwise the best-ranked non-cyclic frontier is chosen.

        Args:
            obs: Observations; only ``obs['gps']`` (robot xy) is read.
            frontiers: Candidate frontiers, an array of 2D points.
            channel: Value-map channel used for ranking (ignored when sorting
                by distance).
            force_new: If True, never re-select the previously pursued
                frontier when any alternative exists.

        Returns:
            The chosen frontier and its ranking value (a score, or a distance
            when sorting by distance).

        Raises:
            ValueError: If ``frontiers`` is empty (from ``min`` on an empty
                range in the fallback path).
        """
        robot_xy = np.array(obs['gps'])

        if self.sort_by_distance:
            # Ascending distance: the closest frontier comes first.
            values = np.array([np.linalg.norm(frontier - robot_xy) for frontier in frontiers])
            sorted_indices = np.argsort(values)
            sorted_pts = frontiers[sorted_indices]
            sorted_values = values[sorted_indices]
        else:
            # Descending value-map score for the requested channel.
            sorted_pts, sorted_values = self.value_map.sort_waypoints(
                frontiers,
                radius=self._frontier_radius,
                reduce_fn=lambda vals: [v[channel] for v in vals],
            )

        best_frontier_idx = None
        top_two_values = tuple(sorted_values[:2])
        # The zero vector is the "nothing pursued yet" sentinel.
        has_last_frontier = not np.array_equal(self._last_frontier, np.zeros(2))

        # If there is a last point pursued, then we consider sticking to pursuing it
        # if it is still in the list of frontiers and its current value is not much
        # worse than self._last_value.
        if has_last_frontier and (not force_new):
            curr_index = None

            for idx, p in enumerate(sorted_pts):
                if np.array_equal(p, self._last_frontier):
                    # Last point is still in the list of frontiers
                    curr_index = idx
                    break

            if curr_index is None:
                closest_index = closest_point_within_threshold(
                    sorted_pts, self._last_frontier, threshold=0.5
                )

                if closest_index != -1:
                    # There is a point close to the last point pursued
                    curr_index = closest_index
            if curr_index is not None:
                curr_value = sorted_values[curr_index]
                if curr_value + 0.01 > self._last_value:
                    # The last point pursued is still in the list of frontiers and its
                    # value is not much worse than self._last_value
                    best_frontier_idx = curr_index

        # If there is no last point pursued, then just take the best point, given that
        # it is not cyclic.
        if best_frontier_idx is None:
            for idx, frontier in enumerate(sorted_pts):
                cyclic = self._acyclic_enforcer.check_cyclic(
                    robot_xy, frontier, top_two_values
                )
                if cyclic:
                    print("Suppressed cyclic frontier.")
                    continue
                if force_new and has_last_frontier:
                    if np.array_equal(frontier, self._last_frontier):
                        print("Suppressed choosing previous frontier")
                        continue
                best_frontier_idx = idx
                break

        # When forced to move on but every candidate was cyclic, ignore the
        # cyclic check and take the best frontier other than the previous one.
        if (
            (best_frontier_idx is None)
            and force_new
            and has_last_frontier
        ):
            for idx, frontier in enumerate(sorted_pts):
                if np.array_equal(frontier, self._last_frontier):
                    print("Suppressed choosing previous frontier")
                    continue
                best_frontier_idx = idx
                break
            if best_frontier_idx is None:
                print("Only previous frontier remaining.")

        if best_frontier_idx is None:
            print("All frontiers are cyclic. Just choosing the closest one.")
            # Index into sorted_pts (not the unsorted `frontiers`) so the index
            # refers to the same point/value pair read below, and take the
            # minimum distance so the choice matches the message above.
            best_frontier_idx = min(
                range(len(sorted_pts)),
                key=lambda i: np.linalg.norm(sorted_pts[i] - robot_xy),
            )

        best_frontier = sorted_pts[best_frontier_idx]
        best_value = sorted_values[best_frontier_idx]
        self._acyclic_enforcer.add_state_action(robot_xy, best_frontier, top_two_values)
        self._last_value = best_value
        self._last_frontier = best_frontier

        return best_frontier, best_value

    def get_value(self, point: np.ndarray, channel: int) -> float:
        """Return the value-map score around ``point`` for one channel.

        Converts the metric ``(x, y)`` point to value-map pixel coordinates and
        reduces the ``channel`` slice within ``self._frontier_radius`` meters
        with ``pixel_value_within_radius``.

        NOTE(review): the axis convention (negated x/y, and the first pixel
        coordinate flipped against the map's first dimension) must match how
        ValueMap rasterizes points — confirm.
        """
        radius_px = int(self._frontier_radius * self.value_map.pixels_per_meter)
        x, y = point
        px = int(-x * self.value_map.pixels_per_meter) + self.value_map._episode_pixel_origin[0]
        py = int(-y * self.value_map.pixels_per_meter) + self.value_map._episode_pixel_origin[1]
        point_px = (self.value_map._value_map.shape[0] - px, py)
        return pixel_value_within_radius(self.value_map._value_map[..., channel], point_px, radius_px)

    @nvtx.annotate("VlfmMapPolicy.get_policy_info")
    def get_policy_info(self, visualize: bool) -> Dict[str, Any]:
        """Collect per-step policy info, optionally including visualizations.

        Requires ``update`` to have run (for ``self.target_object`` /
        ``self.target_objects``).
        NOTE(review): also reads ``self.obs.gps`` / ``self.obs.compass``, but
        ``self.obs`` is never assigned in this class — presumably set by the
        owning agent before this call; confirm.
        """
        policy_info = {
            "target_object": self.target_object.split("|")[0],
            "target_objects": self.target_objects,
            "gps": str(np.array([self.obs.gps[0], self.obs.gps[1]])),
            "yaw": np.rad2deg(self.obs.compass[0]),
            # breaks with config.GROUND_TRUTH_SEMANTICS=1
            # "target_detected": 1 in obs.task_observations['instance_classes'],
            "target_point_cloud": np.array([]),
            "target_point_clouds": [],
            "nav_goal": self._last_frontier,
            "stop_called": False,
            # don't render these on egocentric images when making videos:
            "render_below_images": [
                "target_object",
            ],
        }

        if visualize:
            policy_info.update(self.get_visual_info())

        return policy_info

    @nvtx.annotate("VlfmMapPolicy.get_visual_info")
    def get_visual_info(self) -> Dict[str, Any]:
        """Render the obstacle map, per-channel value maps and a seen-map crop.

        Returns a dict with ``'obstacle_map'`` (RGB image with frontier
        markers), ``'value_maps'`` (one RGB image per value channel, with the
        top-ranked frontier highlighted, followed by a local crop of the
        seen/navigable/wall maps), ``'render_below_images'`` and ``'debug'``.
        """
        policy_info = {}

        # obstacle map: highlight the frontier currently being pursued.
        obstacle_map_vis = self.obstacle_map.visualize()
        policy_info["obstacle_map"] = cv2.cvtColor(obstacle_map_vis, cv2.COLOR_BGR2RGB)
        frontiers = self.obstacle_map.frontiers
        marker_kwargs = {
            'radius': self._circle_marker_radius,
            'thickness': self._circle_marker_thickness,
            'color': None,
        }
        for frontier in frontiers:
            marker_kwargs['color'] = self._selected__frontier_color if np.array_equal(frontier, self._last_frontier) else self._frontier_color
            self.value_map._traj_vis.draw_circle(
                policy_info['obstacle_map'],
                frontier[:2],
                **marker_kwargs
            )

        # value maps: one image per channel; the lambdas are invoked
        # immediately, so capturing the loop variable is safe here.
        visuals_bounds =  [
            self.value_map.visualize(None, reduce_fn=lambda m: m[..., value_channel], obstacle_map=self.obstacle_map)
            for value_channel in range(self.value_map._value_channels)
        ]
        policy_info['value_maps'] = []
        for value_channel, value_map in enumerate(visuals_bounds):
            sorted_frontiers, sorted_values = self.value_map.sort_waypoints(
                waypoints=frontiers,
                radius=0.1,
                reduce_fn=lambda waypoints_values: [waypoint_values[value_channel] for waypoint_values in waypoints_values],
            )
            value_map = cv2.cvtColor(
                value_map,
                cv2.COLOR_BGR2RGB,
            )
            # Highlight the top-ranked frontier for this channel.
            for i_frontier, frontier in enumerate(sorted_frontiers):
                marker_kwargs['color'] = self._selected__frontier_color if i_frontier == 0 else self._frontier_color
                self.value_map._traj_vis.draw_circle(
                    value_map,
                    frontier[:2],
                    **marker_kwargs
                )
            policy_info['value_maps'].append(value_map)

        ## vlfm seen map vis: a window_size x window_size crop around the
        ## robot, rows reversed (y1 down to y0 + 1) to flip it vertically.
        # NOTE(review): the window is only clamped at 0; near the map's far
        # edge the slices come back shorter than seen_vis and the boolean
        # indexing below would fail — confirm this cannot happen.
        # NOTE: `+=` on uint8 wraps where masks overlap (e.g. wall on a seen
        # cell), so overlapping colors are not saturated.
        window_size = 100
        y0, x0 = np.maximum(gps_to_px(self.obs.gps, self.pixels_per_meter, self.map_size) - window_size // 2, 0)
        y1, x1 = y0 + window_size, x0 + window_size
        seen_vis = np.zeros((y1-y0, x1-x0, 3), dtype=np.uint8)
        seen_vis[self.obstacle_map.seen_map[y1:y0:-1, x0:x1]] += np.array([0, 0, 255], dtype=np.uint8)
        seen_vis[~self.obstacle_map._navigable_map[y1:y0:-1, x0:x1]] += np.array([255, 0, 0], dtype=np.uint8)
        seen_vis[self.obstacle_map.wall_map[y1:y0:-1, x0:x1]] += np.array([255, 255, 255], dtype=np.uint8)
        # Mark the robot at the crop center.
        seen_vis[(y1-y0) // 2, (x1-x0) // 2] += np.array([0, 255, 0], dtype=np.uint8)
        policy_info['value_maps'].append(seen_vis)
        policy_info['render_below_images'] = ['debug']
        policy_info['debug'] = ''

        return policy_info