from collections import defaultdict
import imageio
import nvtx
import cv2
import numpy as np
from typing import Tuple, Any, Dict
from home_robot.core.interfaces import Observations
from home_robot.perception.constants import d3_40_colors_rgb
from helios.agent.gaze.gs_object_map import GsObjectMap
from helios.agent.gaze.gaussian_representation.slam import SLAM
from helios.agent.utils.utils import get_bounds, obs_to_tf, xy_to_px, px_to_gps, gen_poses_to_tf_om, px_to_xy, xy_to_px_tensor

from depth_camera_filtering import filter_depth
from vlfm.utils.geometry_utils import get_point_cloud, transform_points
from vlfm.utils.img_utils import fill_small_holes
from helios.agent.planner.fm2_planner import get_times

from helios.agent.gaze.merged_gaze import get_wayposes
import matplotlib.pyplot as plt

from helios.agent.gaze.utils.utils import calculate_instance_heightmap

class IntegratedSimpleMap:
    """Top-down map of the goal classes built by projecting detections through depth.

    Pixels of the object goal and the end receptacle (taken either from detector
    instance outputs or from the semantic channel) are back-projected with the
    filtered depth image, transformed into the episodic frame and rasterized into
    ``self.map``, a ``(map_size, map_size)`` ``uint8`` grid of class ids in which
    0 means "nothing mapped here". Connected components of each class inside the
    current bounds are tracked as instances in ``self.instance_maps``, with a
    visited flag per instance in ``self.instance_visited``.

    ``get_renderings`` only returns zero-filled placeholder images, and
    ``initialized`` is kept for compatibility with ``GsObjectMap``.
    """

    def __init__(
        self,
        img_size: Tuple[int, int],
        hfov: float,
        semantic_threshold_for_go: float,
        semantic_threshold_for_sr: float,
        semantic_threshold_for_er: float,
        min_depth: float,
        max_depth: float,
        map_size: int,
        min_cluster_size_sr: int,
        min_cluster_size_go: int
    ):
        """Create an empty map.

        Args:
            img_size: Image size as ``(height, width)``. The intrinsics below use
                ``img_size[1]`` for the focal length and principal-point x, and
                ``img_size[0]`` for principal-point y.
            hfov: Horizontal field of view in degrees (converted with ``np.radians``).
            semantic_threshold_for_go: Minimum detector score for object-goal instances.
            semantic_threshold_for_sr: Stored as ``od_thresh_sr``; not used in this class.
            semantic_threshold_for_er: Minimum detector score for end-receptacle instances.
            min_depth: Stored only; ``update`` reads ``info['min_depth']`` instead.
            max_depth: Stored only; ``update`` reads ``info['max_depth']`` instead.
            map_size: Side length of the square map, in cells.
            min_cluster_size_sr: Minimum instance size used by ``has_obj`` for every
                class other than the object goal.
            min_cluster_size_go: Minimum instance size used by ``has_obj`` for the
                object goal.
        """
        self.img_size = img_size
        # Pinhole intrinsics with square pixels (the same focal length on both axes).
        self.K = np.array([
            [(img_size[1] / 2) / np.tan(np.radians(hfov) / 2.0), 0.0, img_size[1] / 2, 0.0],
            [0.0, (img_size[1] / 2) / np.tan(np.radians(hfov) / 2.0), img_size[0] / 2, 0.0],
            [0.0, 0.0, 1, 0],
            [0.0, 0.0, 0, 1],
        ])
        self.initialized = False  # for compatibility with GsObjectMap
        self.device = "cuda"
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.episode_key = None

        self.class_mapping = {}

        self.map_size = map_size
        self.instance_maps = {}
        self.instance_visited: dict[int, dict[int, bool]] = defaultdict(dict)
        self.thresholds: dict[int, float] = {}
        # Also (re)created by reset(); initialized here so a fresh instance has
        # the same attributes as one that has been reset.
        self.instance_heights = defaultdict(dict)
        self.instance_orig_score_variables = defaultdict(dict)
        self.instance_curr_score_variables = defaultdict(dict)
        self.use_is_visited = None

        self.min_cluster_size = min_cluster_size_sr
        self.min_cluster_size_go = min_cluster_size_go

        # Cells where the object goal may still be picked; cleared by remove_invalid_pick.
        self.valid_pick = np.ones((self.map_size, self.map_size), dtype=bool)

        self.bad_instances = {}

        self.od_thresh_go = semantic_threshold_for_go
        self.od_thresh_sr = semantic_threshold_for_sr
        self.od_thresh_er = semantic_threshold_for_er

        self.map = np.empty((map_size, map_size), dtype=np.uint8)

    def reset(self):
        """Clear all per-episode state (maps, instances, pick validity, class mapping)."""
        self.instance_maps = {}
        self.instance_visited = defaultdict(dict)
        self.instance_heights = defaultdict(dict)
        self.instance_orig_score_variables = defaultdict(dict)
        self.instance_curr_score_variables = defaultdict(dict)
        self.initialized = False
        self.use_is_visited = None

        self.valid_pick.fill(1)
        self.map.fill(0)

        self.bad_instances = {}
        self.class_mapping = {}

    def reset_vectorized(self):
        """Alias of ``reset``."""
        self.reset()

    def set_episode_key(self, episode_key):
        """Store the identifier of the current episode."""
        self.episode_key = episode_key

    def get_renderings(self, obs: Observations):
        """Return six zero-filled placeholder images shaped like ``obs.rgb`` / ``obs.depth``.

        The fifth element is shaped like ``obs.depth``; the others are shaped like ``obs.rgb``.
        """
        return (
            np.zeros_like(obs.rgb),
            np.zeros_like(obs.rgb),
            np.zeros_like(obs.rgb),
            np.zeros_like(obs.rgb),
            np.zeros_like(obs.depth),
            np.zeros_like(obs.rgb)
        )

    def update(
        self,
        obs: Observations,
        info: Dict[str, Any],
    ):
        """Project the current frame's goal detections into the map.

        Steps:
          1. Record the goal class ids from ``obs.task_observations`` and cache the
             planner state passed in ``info`` (bounds, resolution, obstacle and
             navigable maps).
          2. Clamp, filter and hole-fill the depth image.
          3. Select end-receptacle and object-goal pixels. These come from the
             detector instances (``instance_map`` / ``instance_classes`` /
             ``instance_scores``, filtered by score thresholds) when an instance map
             is present, otherwise from ``obs.semantic``.
          4. Back-project those pixels, transform them into the episodic frame and
             write their class ids into ``self.map``, on obstacle cells only.
          5. Re-label connected components of each class inside ``self.bounds`` as
             instances, merging instances that have become connected.

        Args:
            obs: Current observation. Reads ``depth``, ``rgb`` (for the image shape),
                ``semantic`` and ``task_observations``.
            info: Dict with at least ``bounds`` (unpacked as ``(y0, y1, x0, x1)``),
                ``pixels_per_meter``, ``obstacle_map``, ``navigable_map``,
                ``min_depth``, ``max_depth``, ``hole_area_thresh``, ``fx`` and ``fy``.
        """
        task_obs = obs.task_observations
        self.class_mapping[task_obs['start_recep_goal']] = 2
        self.class_mapping[task_obs['object_goal']] = 1
        self.class_mapping[task_obs['end_recep_goal']] = 3

        tf_camera_to_episodic = obs_to_tf(obs)
        self.bounds = info['bounds']
        self.pixels_per_meter = info['pixels_per_meter']
        # Cached for visualize() and the instance bookkeeping below.
        self.start_recep_goal = task_obs['start_recep_goal']
        self.end_recep_goal = task_obs['end_recep_goal']
        self.object_goal = task_obs['object_goal']
        self.obstacle_map = info['obstacle_map']
        self.navigable_map = info['navigable_map']

        self.bad_instances[self.object_goal] = []
        self.bad_instances[self.end_recep_goal] = []

        # Clamp at max_depth, normalize relative to [min_depth, max_depth], filter
        # and fill small holes, then map back to the original depth scale.
        with nvtx.annotate('depth magic'):
            depth = obs.depth.copy()
            depth[depth > info['max_depth']] = info['max_depth']
            depth = (depth - info['min_depth']) / (info['max_depth'] - info['min_depth'])
            # NOTE(review): after the clamp above this appears never to trigger.
            depth[depth > 1.0] = 1.1
            depth = filter_depth(depth.reshape([depth.shape[-2], depth.shape[-1]]), blur_type=None)
            depth = fill_small_holes(depth, info['hole_area_thresh'])
            depth = depth * (info['max_depth'] - info['min_depth']) + info['min_depth']

        # NOTE: start-receptacle projection is currently disabled; only the end
        # receptacle and the object goal are written into the map.

        image_shape = obs.rgb.shape[:2]
        mask_er = np.zeros(image_shape, dtype=bool)
        mask_go = np.zeros(image_shape, dtype=bool)

        if 'instance_map' in task_obs:
            det_instance_map = task_obs['instance_map']
            for i_instance, (instance_class, instance_score) in enumerate(
                zip(task_obs['instance_classes'], task_obs['instance_scores'])
            ):
                # Instance class 1 feeds the object-goal mask and 3 the end-receptacle
                # mask (matching the values assigned in class_mapping above).
                if instance_class == 1 and instance_score >= self.od_thresh_go:
                    mask_go[det_instance_map == i_instance] = True
                elif instance_class == 3 and instance_score >= self.od_thresh_er:
                    mask_er[det_instance_map == i_instance] = True
        else:
            mask_go = obs.semantic == task_obs['object_goal']
            mask_er = obs.semantic == task_obs['end_recep_goal']

        # end recep
        with nvtx.annotate('end recep'):
            points = get_point_cloud(depth, mask_er, info['fx'], info['fy'])
            points = transform_points(tf_camera_to_episodic, points)
            px = xy_to_px(points[..., :2], info['pixels_per_meter'], self.map_size)
            px = px[self.map[px[:, 1], px[:, 0]] != self.object_goal]  # don't overwrite object
            px = px[info['obstacle_map'][px[:, 1], px[:, 0]]]  # objects are obstacles
            self.map[px[:, 1], px[:, 0]] = task_obs['end_recep_goal']

        # object (unlike the end receptacle, it may overwrite any mapped class)
        with nvtx.annotate('object'):
            if np.any(mask_go):
                points = get_point_cloud(depth, mask_go, info['fx'], info['fy'])  # n, 3
                points = transform_points(tf_camera_to_episodic, points)  # n, 3
                px = xy_to_px(points[..., :2], info['pixels_per_meter'], self.map_size)
                px = px[info['obstacle_map'][px[:, 1], px[:, 0]]]  # objects are obstacles
                self.map[px[:, 1], px[:, 0]] = task_obs['object_goal']

        # Cells known to be navigable cannot hold an object; clear them.
        self.map[info['navigable_map']] = 0

        # Object-goal cells that were marked as invalid picks are cleared to 0.
        mask = self.map == task_obs['object_goal']
        self.map[mask] *= self.valid_pick[mask]

        # Update the instance maps inside the current bounds.
        y0, y1, x0, x1 = self.bounds
        for obj_class in [self.end_recep_goal, self.object_goal]:
            if not np.isin(obj_class, self.map):
                continue
            if obj_class not in self.instance_maps:
                self.instance_maps[obj_class] = np.zeros((self.map_size, self.map_size), dtype=np.uint8)
                self.instance_visited[obj_class] = {}
            visited = self.instance_visited[obj_class]

            # Slicing gives a view, so writes to instance_map update self.instance_maps.
            instance_map = self.instance_maps[obj_class][y0:y1, x0:x1]
            semantic_mask = self.map[y0:y1, x0:x1] == obj_class
            num_labels, labels = cv2.connectedComponents(semantic_mask.astype(np.uint8), connectivity=8)
            for label in range(1, num_labels):  # component 0 is the background
                new_mask = labels == label
                # np.unique returns sorted ids; 0 means "no instance yet".
                overlap_labels = np.unique(instance_map[new_mask])
                overlap_labels = overlap_labels[overlap_labels != 0]
                if len(overlap_labels) == 0:
                    # New object: take the smallest positive id not yet used in the crop.
                    used = np.unique(instance_map)
                    free = np.isin(np.arange(1, len(used) + 1), used, invert=True)
                    new_label = np.argmax(free) + 1
                    instance_map[new_mask] = new_label
                    visited[new_label] = False
                elif len(overlap_labels) == 1:
                    # Grows an existing instance.
                    instance_map[new_mask] = overlap_labels[0]
                else:
                    # Component joins several instances: merge them into the smallest id.
                    keep = overlap_labels[0]
                    instance_map[new_mask] = keep
                    instance_map[np.isin(instance_map, overlap_labels[1:])] = keep
                    # The merged instance counts as visited if any source instance was.
                    visited[keep] = any(visited[lbl] for lbl in overlap_labels)
                    for lbl in overlap_labels[1:]:
                        visited.pop(lbl)

    def has_obj(self, obj_class: int) -> bool:
        """Return True if some instance of ``obj_class`` is large enough to count.

        An instance counts when more than the minimum cluster size of its cells are
        labelled ``obj_class`` in ``self.map``. ``min_cluster_size_go`` applies to the
        object goal and ``min_cluster_size`` applies to every other class. For the
        object goal, only cells still marked in ``valid_pick`` are counted.
        """
        if obj_class == self.object_goal:
            min_cluster_size = self.min_cluster_size_go
        else:
            min_cluster_size = self.min_cluster_size
        if obj_class not in self.instance_maps:
            return False
        instance_map = self.instance_maps[obj_class]
        class_mask = self.map == obj_class  # same for every instance; compute once
        for idx in np.unique(instance_map):
            if idx == 0:  # 0 marks cells that belong to no instance
                continue
            mask = instance_map == idx
            if obj_class == self.object_goal:
                mask &= self.valid_pick
            if np.sum(class_mask[mask]) > min_cluster_size:
                return True
        return False

    def has_obj_uncertainty_weighted(self, obj_class: int):
        """Alias of ``has_obj``; this map keeps no uncertainty estimates."""
        return self.has_obj(obj_class)

    def get_semantic_map(self, obj_class: int):
        """Return a boolean mask of the cells labelled ``obj_class``."""
        return self.map == obj_class

    def get_instance_map(self, obj_class: int, instance: int):
        """Return a boolean mask of the cells belonging to one instance of ``obj_class``."""
        return self.instance_maps[obj_class] == instance

    def mark_instance_visited(self, obj_class: int, instance: int):
        """Flag an instance of ``obj_class`` as visited."""
        self.instance_visited[obj_class][instance] = True

    def remove_invalid_pick(self, obj_idx):
        """Mark object-goal instance ``obj_idx`` as unpickable and visited.

        Its cells are cleared in ``valid_pick``, so later ``update`` calls zero them in
        ``self.map`` and ``has_obj`` ignores them. This is a no-op when the id is not
        present in the object-goal instance map.
        """
        obj_class = self.object_goal

        if obj_idx in np.unique(self.instance_maps[obj_class]):
            self.valid_pick[self.instance_maps[obj_class] == obj_idx] = False
            self.instance_visited[obj_class][obj_idx] = True

    @nvtx.annotate("GlobalObjectiveMap.visualize")
    def visualize(self) -> np.ndarray:
        """Render the map cropped to ``self.bounds`` as an RGB image.

        Free cells are white and obstacles black. Goal classes use
        ``d3_40_colors_rgb`` entries 2 (start receptacle), 1 (object goal) and
        3 (end receptacle), matching the values in ``class_mapping``. The image is
        flipped vertically before being returned.

        Returns:
            ``uint8`` array of shape ``(y1 - y0, x1 - x0, 3)``.
        """
        y0, y1, x0, x1 = self.bounds
        vis_img = np.full((y1 - y0, x1 - x0, 3), 255, dtype=np.uint8)
        vis_img[self.obstacle_map[y0:y1, x0:x1]] = (0, 0, 0)
        # self.map is a NumPy array, so boolean masks index vis_img directly.
        crop = self.map[y0:y1, x0:x1]
        # The start receptacle is drawn first: update() does not currently write it
        # into the map, so it must never hide a class that is actually mapped.
        for class_id, color_idx in (
            (self.start_recep_goal, 2),
            (self.object_goal, 1),
            (self.end_recep_goal, 3),
        ):
            vis_img[crop == class_id] = d3_40_colors_rgb[color_idx]
        return vis_img[::-1]