from collections import defaultdict
import imageio
import nvtx
import cv2
import numpy as np
from typing import Tuple, Any
from home_robot.core.interfaces import Observations
from home_robot.perception.constants import d3_40_colors_rgb
from helios.agent.gaze.gs_object_map import GsObjectMap
from helios.agent.gaze.gaussian_representation.slam import SLAM
from helios.agent.utils.utils import get_bounds, obs_to_tf, xy_to_px, px_to_gps, gen_poses_to_tf_om, px_to_xy, xy_to_px_tensor

import torch
import matplotlib.pyplot as plt

from helios.agent.gaze.utils.utils import calculate_instance_heightmap, get_wayposes

class GlobalObjectiveMap(GsObjectMap):
    """Top-down semantic and instance maps for the three goal classes of an episode.

    The goal class ids are read from ``obs.task_observations`` (``start_recep_goal``,
    ``object_goal`` and ``end_recep_goal``). For each of them this map keeps:

    * ``semantic_maps[cls]``: boolean ``(map_size, map_size)`` grid of obstacle cells
      onto which points of ``cls`` (queried from the parent with
      ``use_uncertainty=False``) project;
    * ``semantic_maps_wu[cls]``: the same grid built from the ``use_uncertainty=True``
      query with the ``*_wu`` thresholds;
    * ``instance_maps[cls][id]``: one boolean grid per instance id, together with a
      visited flag, the instance's max projected height and score variables used by
      ``score_srec`` / ``score_frontiers``.
    """

    def __init__(
        self,
        img_size: Tuple[int, int],
        hfov: float,
        model: SLAM,
        save_rep: bool,
        semantic_threshold_for_go: float,
        semantic_threshold_for_sr: float,
        semantic_threshold_for_er: float,
        min_depth: float,
        max_depth: float,
        map_size: int,
        threshold_sr: float,
        threshold_go: float,
        threshold_er: float,
        threshold_sr_wu: float,
        threshold_go_wu: float,
        threshold_er_wu: float,
        min_cluster_size_sr: int,
        min_cluster_size_go: int,
        allow_multiple_visits_sr: bool,
        allow_multiple_visits_go: bool,
        allow_multiple_visits_er: bool,
    ):
        """Create an empty map.

        Args:
            img_size, hfov, model, save_rep, semantic_threshold_for_{go,sr,er},
            min_depth, max_depth: forwarded unchanged to ``GsObjectMap``.
            map_size: side length, in cells, of every top-down grid.
            threshold_{sr,go,er}: thresholds passed to ``get_obj_locs`` when building
                the plain semantic / instance maps.
            threshold_{sr,go,er}_wu: thresholds for the uncertainty-weighted query.
            min_cluster_size_sr: minimum instance size used for both receptacle classes.
            min_cluster_size_go: minimum instance size used for the object goal.
            allow_multiple_visits_{sr,go,er}: when False, instances of that class that
                were marked visited are no longer selected.
        """
        super().__init__(
            img_size=img_size,
            hfov=hfov,
            model=model,
            save_rep=save_rep,
            semantic_threshold_for_go=semantic_threshold_for_go,
            semantic_threshold_for_sr=semantic_threshold_for_sr,
            semantic_threshold_for_er=semantic_threshold_for_er,
            min_depth=min_depth,
            max_depth=max_depth
        )
        self.map_size = map_size
        self.threshold_sr = threshold_sr
        self.threshold_go = threshold_go
        self.threshold_er = threshold_er
        self.threshold_sr_wu = threshold_sr_wu
        self.threshold_go_wu = threshold_go_wu
        self.threshold_er_wu = threshold_er_wu

        # Receptacle classes (start and end) share ``min_cluster_size``; the object
        # goal has its own minimum.
        self.min_cluster_size = min_cluster_size_sr
        self.min_cluster_size_go = min_cluster_size_go

        self.semantic_maps: dict[int, torch.Tensor] = {}
        self.semantic_maps_wu: dict[int, torch.Tensor] = {}
        self.instance_maps: dict[int, dict[int, torch.Tensor]] = defaultdict(dict)
        self.instance_visited: dict[int, dict[int, bool]] = defaultdict(dict)
        # Max projected height per instance, consumed by the waypose generators.
        # Created here as well as in reset() so update() works without a prior reset().
        self.instance_heights: dict[int, dict[int, Any]] = defaultdict(dict)

        self.instance_orig_score_variables: dict[int, dict[str, Any]] = defaultdict(dict)
        self.instance_curr_score_variables: dict[int, dict[str, Any]] = defaultdict(dict)

        # Initial mapping; update() adds entries keyed by the episode's goal class ids.
        self.thresholds: dict[int, float] = {1: threshold_go, 2: threshold_sr, 3: threshold_er}
        self.thresholds_wu: dict[int, float] = {}
        self.initialized = False

        # Cells that are still allowed for picking the object goal (see remove_invalid_pick).
        self.valid_pick = torch.ones((self.map_size, self.map_size), dtype=bool, device=self.device)

        # Order: start receptacle, object goal, end receptacle.
        self.allow_multiple_visits = [allow_multiple_visits_sr, allow_multiple_visits_go, allow_multiple_visits_er]
        # Per-class "skip visited instances" flags; filled lazily on the first map update.
        self.use_is_visited = None

        self.bad_instances = {}

    def reset(self):
        """Clear all per-episode state: maps, visited flags, scores and invalid picks."""
        super().reset()
        self.semantic_maps = {}
        self.instance_maps = defaultdict(dict)
        self.instance_visited = defaultdict(dict)
        self.instance_heights = defaultdict(dict)
        self.thresholds = {}
        self.thresholds_wu = {}
        self.semantic_maps_wu = {}
        self.instance_orig_score_variables = defaultdict(dict)
        self.instance_curr_score_variables = defaultdict(dict)
        self.initialized = False
        self.use_is_visited = None

        self.valid_pick.fill_(1)

        self.bad_instances = {}

    # ------------------------------------------------------------------ helpers

    def _min_cluster_size_for(self, obj_class: int) -> int:
        """Minimum instance size for ``obj_class`` (object goal vs. receptacles)."""
        if obj_class == self.object_goal:
            return self.min_cluster_size_go
        return self.min_cluster_size

    def _model_class_id(self, obj_class: int) -> int:
        """Class index passed to ``get_expected_score`` / ``get_eig_path`` for ``obj_class``."""
        if obj_class == self.start_recep_goal:
            return 2
        if obj_class == self.object_goal:
            return 1
        if obj_class == self.end_recep_goal:
            return 3
        return 0

    def _points_to_px(self, points: torch.Tensor) -> torch.Tensor:
        """Project the first two coordinates of ``points`` to map pixel indices."""
        xy = -torch.flip(points[:, :2], [-1])
        return xy_to_px_tensor(xy, self.pixels_per_meter, self.map_size)

    def _in_map_mask(self, px: torch.Tensor) -> torch.Tensor:
        """Boolean mask of the rows of ``px`` whose first two columns index inside the map.

        Negative indices are rejected: torch would otherwise wrap them to the
        opposite edge of the map.
        """
        rows, cols = px[:, 0], px[:, 1]
        return (rows >= 0) & (rows < self.map_size) & (cols >= 0) & (cols < self.map_size)

    def _forget_instance(self, obj_class: int, instance_id: int):
        """Drop an instance that is no longer observed, together with its bookkeeping."""
        self.instance_maps[obj_class].pop(instance_id)
        self.instance_visited[obj_class].pop(instance_id)
        self.instance_heights[obj_class].pop(instance_id, None)
        for variables in (self.instance_curr_score_variables, self.instance_orig_score_variables):
            for per_instance in variables.values():
                per_instance.pop(instance_id, None)

    def _rebuild_semantic_and_instance_maps(self, obj_class: int, obstacle_map_tensor: torch.Tensor, info) -> bool:
        """Clear both semantic grids of ``obj_class`` and rebuild the plain one and its instances.

        Returns False when the projected pixels are not a 2-D array; the caller then
        skips the uncertainty-weighted update for this class (its grid stays cleared).
        """
        for maps in (self.semantic_maps, self.semantic_maps_wu):
            if obj_class not in maps:
                maps[obj_class] = torch.empty((self.map_size, self.map_size), dtype=bool, device=self.device)
            maps[obj_class].fill_(False)

        has_obj, points, scores, instances = self.get_obj_locs(
            obj_class=obj_class,
            threshold=self.thresholds[obj_class],
            use_uncertainty=False,
            return_scores=True,
            return_instances=True
        )
        if not has_obj:
            return True

        px = self._points_to_px(points)
        if len(px.shape) != 2:
            return False

        # Keep points inside the map that land on obstacle cells (objects are obstacles).
        in_map = self._in_map_mask(px)
        px = px[in_map]
        on_obstacle = obstacle_map_tensor[px[:, 0], px[:, 1]]
        px = px[on_obstacle]
        heights = points[:, 2][in_map][on_obstacle]
        point_scores = scores[in_map][on_obstacle]
        instances = instances.detach()[in_map][on_obstacle]

        self.semantic_maps[obj_class][px[:, 0], px[:, 1]] = True

        min_size = self._min_cluster_size_for(obj_class)
        stale_ids = set(self.instance_maps[obj_class].keys())
        for instance in torch.unique(instances):
            instance_mask = instances == instance
            if torch.sum(instance_mask) < min_size:
                continue

            # Use the Python int for dict look-ups: a tensor never equals the int keys,
            # which would re-create the map and reset the visited flag every update.
            instance_id = instance.item()
            stale_ids.discard(instance_id)

            if instance_id not in self.instance_maps[obj_class]:
                self.instance_maps[obj_class][instance_id] = torch.empty(
                    (self.map_size, self.map_size), dtype=bool, device=self.device)
                self.instance_visited[obj_class][instance_id] = False
            instance_map = self.instance_maps[obj_class][instance_id]
            instance_map.fill_(False)
            instance_px = px[instance_mask]
            instance_map[instance_px[:, 0], instance_px[:, 1]] = True
            self.instance_heights[obj_class][instance_id] = torch.max(heights[instance_mask]).cpu().numpy()

            mean_score = torch.mean(point_scores[instance_mask]).item()
            if obj_class == self.start_recep_goal:
                center = torch.mean(torch.argwhere(instance_map).float(), 0).unsqueeze(0).cpu().numpy()
                center_xy = -px_to_xy(center, info['pixels_per_meter'], self.map_size)[:, ::-1]
                self.instance_curr_score_variables['center'][instance_id] = center_xy
                # The "original" centre/score is recorded only the first time the instance is seen.
                if instance_id not in self.instance_orig_score_variables['center']:
                    self.instance_orig_score_variables['center'][instance_id] = center_xy
                    self.instance_orig_score_variables['score'][instance_id] = mean_score

            self.instance_curr_score_variables[f'score_{obj_class}'][instance_id] = mean_score

        # Remove instances that are not present anymore.
        for instance_id in stale_ids:
            self._forget_instance(obj_class, instance_id)
        return True

    def _rebuild_uncertainty_weighted_map(self, obj_class: int, obstacle_map_tensor: torch.Tensor):
        """Fill ``semantic_maps_wu[obj_class]`` from the uncertainty-weighted query."""
        has_obj, points = self.get_obj_locs(
            obj_class=obj_class,
            threshold=self.thresholds_wu[obj_class],
            use_uncertainty=True,
            return_scores=False,
            return_instances=False
        )
        if not has_obj:
            return
        px = self._points_to_px(points)
        px = px[self._in_map_mask(px)]
        px = px[obstacle_map_tensor[px[:, 0], px[:, 1]]]  # objects are obstacles
        self.semantic_maps_wu[obj_class][px[:, 0], px[:, 1]] = True

    # --------------------------------------------------------------- public API

    def update(
        self,
        obs: Observations,
        info
    ):
        """Update the parent Gaussian map, then rebuild the per-class top-down maps.

        Args:
            obs: observation whose ``task_observations`` holds the class ids under
                ``start_recep_goal``, ``end_recep_goal`` and ``object_goal``.
            info: dict providing ``bounds``, ``pixels_per_meter``, ``obstacle_map``
                and ``navigable_map``.

        Goal ids and thresholds are refreshed on every call. The grids are rebuilt
        only when the parent reports ``did_update_3gds``.
        """
        super().update(obs, info)  # update gs

        self.bounds = info['bounds']
        self.pixels_per_meter = info['pixels_per_meter']
        self.start_recep_goal = obs.task_observations['start_recep_goal']
        self.end_recep_goal = obs.task_observations['end_recep_goal']
        self.object_goal = obs.task_observations['object_goal']
        self.obstacle_map = info['obstacle_map']
        self.navigable_map = info['navigable_map']

        self.thresholds[self.start_recep_goal] = self.threshold_sr
        self.thresholds[self.object_goal] = self.threshold_go
        self.thresholds[self.end_recep_goal] = self.threshold_er
        self.thresholds_wu[self.start_recep_goal] = self.threshold_sr_wu
        self.thresholds_wu[self.object_goal] = self.threshold_go_wu
        self.thresholds_wu[self.end_recep_goal] = self.threshold_er_wu

        if not self.did_update_3gds:
            return

        if self.use_is_visited is None:
            self.use_is_visited = {
                self.start_recep_goal: not self.allow_multiple_visits[0],
                self.object_goal: not self.allow_multiple_visits[1],
                self.end_recep_goal: not self.allow_multiple_visits[2],
            }
            self.bad_instances[self.object_goal] = []
            self.bad_instances[self.end_recep_goal] = []

        obstacle_map_tensor = torch.tensor(self.obstacle_map, device=self.device)

        for obj_class in (self.start_recep_goal, self.end_recep_goal, self.object_goal):
            if not self._rebuild_semantic_and_instance_maps(obj_class, obstacle_map_tensor, info):
                continue
            self._rebuild_uncertainty_weighted_map(obj_class, obstacle_map_tensor)

    def has_obj(self, obj_class: int):
        """Return True if some instance of ``obj_class`` (id 0 skipped) is selectable.

        An instance counts when more than the class's minimum cluster size of its
        cells are set in ``semantic_maps[obj_class]`` and, if visits are tracked for
        the class, it is not visited yet. For the object goal, cells excluded by
        ``remove_invalid_pick`` are not counted.
        """
        min_cluster_size = self._min_cluster_size_for(obj_class)
        for idx, mask in self.instance_maps[obj_class].items():
            if idx == 0:
                continue
            if obj_class == self.object_goal:
                # Out-of-place AND so the stored instance map is not modified.
                mask = mask & self.valid_pick
            if self.use_is_visited[obj_class] and self.instance_visited[obj_class][idx]:
                continue
            if torch.sum(self.semantic_maps[obj_class][mask]) > min_cluster_size:
                return True
        return False

    def has_obj_uncertainty_weighted(self, obj_class: int):
        """Return True if an instance overlaps enough cells of ``semantic_maps_wu[obj_class]``.

        Instance id 0 and ids listed in ``bad_instances`` are skipped (the bad list is
        ignored for the start receptacle). The object goal counts only valid-pick
        cells and needs ``>= min_cluster_size_go``; other classes need
        ``> min_cluster_size``.
        """
        if obj_class == self.object_goal:
            for idx, mask in self.instance_maps[obj_class].items():
                if idx == 0 or idx in self.bad_instances[obj_class]:
                    continue
                if torch.sum(self.semantic_maps_wu[obj_class] * self.valid_pick * mask) >= self.min_cluster_size_go:
                    return True
            return False

        check_bad = obj_class != self.start_recep_goal
        for idx, mask in self.instance_maps[obj_class].items():
            if idx == 0 or (check_bad and idx in self.bad_instances[obj_class]):
                continue
            if torch.sum(self.semantic_maps_wu[obj_class] * mask) > self.min_cluster_size:
                return True
        return False

    def _eligible_instances(self, obj_class: int):
        """Yield instance ids of ``obj_class`` that are large enough and, when tracked, unvisited."""
        min_cluster_size = self._min_cluster_size_for(obj_class)
        for l, instance_map in self.instance_maps[obj_class].items():
            if self.instance_visited[obj_class][l] and self.use_is_visited[obj_class]:
                continue
            if torch.sum(instance_map) >= min_cluster_size:
                yield l

    def _wayposes_for(self, obj_class: int, l: int, obs, info):
        """Generate candidate wayposes around instance ``l`` of ``obj_class``."""
        return get_wayposes(
            instance_mask=self.instance_maps[obj_class][l].cpu().numpy(),
            instance_height=self.instance_heights[obj_class][l],
            obs=obs,
            info=info,
            max_dist_from_instance=0.7,
            min_dist_from_instance=0.45,
            min_dist_between_waypoints=0.4,
        )

    def get_best_wayposes_score(self, obs, info, obj_class):
        """Pick the instance with the best expected score and the best score gain.

        For each eligible instance with at least one waypose, ``get_expected_score``
        returns ``(score, old_score)``.

        Returns:
            ``(idx_best, best_score, wayposes_best, idx_best_diff, best_score_diff,
            wayposes_best_diff)``; ids and wayposes are None and scores ``-inf`` when
            no instance qualifies.
        """
        best_score = -np.inf
        best_score_diff = -np.inf
        idx_best = None
        idx_best_diff = None
        wayposes_best = None
        wayposes_best_diff = None
        obj_class_model = self._model_class_id(obj_class)

        for l in self._eligible_instances(obj_class):
            wayposes = self._wayposes_for(obj_class, l, obs, info)
            if len(wayposes) == 0:
                continue
            pose_3d = [torch.tensor(gen_poses_to_tf_om(p, obs), device=self.device) for p in wayposes]
            score, old_score = self.get_expected_score(pose_3d, l, obj_class=obj_class_model)
            score_diff = score - old_score

            if score > best_score:
                best_score = score
                idx_best = l
                wayposes_best = wayposes.copy()
            if score_diff > best_score_diff:
                best_score_diff = score_diff
                idx_best_diff = l
                wayposes_best_diff = wayposes.copy()

        return idx_best, best_score, wayposes_best, idx_best_diff, best_score_diff, wayposes_best_diff

    def get_best_wayposes_eig(self, obs, info, obj_class, visualize_eig=False, save_dir='', debug_info=False, use_eig=True):
        """Pick the eligible instance of ``obj_class`` with the highest score.

        The per-instance value stored under ``instance_curr_score_variables[f'eig_{obj_class}']``
        is the EIG from ``get_eig_path`` (or 1 when ``use_eig`` is False), and ``-inf``
        when no wayposes could be generated. Start-receptacle instances are ranked by
        ``score_srec``; other classes by that stored value.

        Returns:
            ``(idx_best, best_score, wayposes_best)``, extended with the per-instance
            masks, scores and wayposes when ``debug_info`` is True.
        """
        best_score = -np.inf
        idx_best = None
        wayposes_best = None
        obj_class_model = self._model_class_id(obj_class)
        eig_key = f'eig_{obj_class}'

        instance_masks = []
        scores = []
        wayposes_l = []

        for l in self._eligible_instances(obj_class):
            try:
                wayposes = self._wayposes_for(obj_class, l, obs, info)
            except Exception:
                # Best effort: an instance whose wayposes cannot be generated is scored -inf below.
                wayposes = []

            if len(wayposes) == 0:
                self.instance_curr_score_variables[eig_key][l] = -np.inf
            elif not use_eig:
                self.instance_curr_score_variables[eig_key][l] = 1
            else:
                pose_3d = [torch.tensor(gen_poses_to_tf_om(p[:3], obs), device=self.device) for p in wayposes]
                eig, viz_images = self.get_eig_path(pose_3d, l, visualize=visualize_eig, obj_class=obj_class_model)
                if visualize_eig:
                    plt.imsave(f"{save_dir}/{info['timestep']}_{l}_visited={self.instance_visited[obj_class][l]}.png", np.hstack(viz_images))
                if obj_class == self.start_recep_goal and l not in self.instance_orig_score_variables['eig']:
                    self.instance_orig_score_variables['eig'][l] = eig
                self.instance_curr_score_variables[eig_key][l] = eig

            if obj_class == self.start_recep_goal:
                score = self.score_srec(l, obs, info)
            else:
                score = self.instance_curr_score_variables[eig_key][l]

            if score > best_score:
                best_score = score
                idx_best = l
                wayposes_best = wayposes.copy()
            if debug_info:
                instance_masks.append(self.instance_maps[obj_class][l])
                scores.append(score)
                wayposes_l.append(wayposes.copy())

        if debug_info:
            return idx_best, best_score, wayposes_best, instance_masks, scores, wayposes_l
        return idx_best, best_score, wayposes_best

    def score_srec(self, idx, obs, info):
        """Score start-receptacle instance ``idx`` as ``score * eig - dist_scale * distance``.

        Returns ``-inf`` if the instance has no centre or score, is visited (when
        visits are tracked) or is smaller than ``min_cluster_size``; returns 0 if no
        EIG has been computed for it yet.
        """
        srec = self.start_recep_goal
        curr = self.instance_curr_score_variables
        if idx not in curr['center']:
            return -np.inf
        if self.instance_visited[srec][idx] and self.use_is_visited[srec]:
            return -np.inf
        if torch.sum(self.instance_maps[srec][idx]) < self.min_cluster_size:
            return -np.inf
        if idx not in curr[f'score_{srec}']:
            return -np.inf
        if idx not in curr[f'eig_{srec}']:
            return 0  # TODO: may want to get wayposes and update in this case?

        robot_xy = np.array([obs.gps[1], -obs.gps[0]])
        dist_term = info['dist_scale'] * np.linalg.norm(curr['center'][idx] - robot_xy)
        return curr[f'score_{srec}'][idx] * curr[f'eig_{srec}'][idx] - dist_term

    def score_frontiers(self, obs, info, frontiers, values_f, prev_chosen_frontiers):
        """Score frontiers as ``V0 * value - dist_scale * distance`` and pick the best.

        ``V0`` is 5.0 unless start-receptacle instances and previously chosen
        frontiers exist. In that case it is the mean, over ``(loc, score)`` pairs in
        ``prev_chosen_frontiers``, of the best original instance score seen from
        ``loc`` (floored at 0) divided by ``score``.

        Returns:
            ``(idx_best_f, scores_f[idx_best_f], scores_f)``. ``np.argmax`` raises
            ``ValueError`` if ``frontiers`` is empty.
        """
        dist_scale = info['dist_scale']
        robot_xy = np.array([obs.gps[1], -obs.gps[0]])
        distances_f = np.array([np.linalg.norm(pt - robot_xy) for pt in frontiers])

        orig = self.instance_orig_score_variables
        if len(orig['center']) > 0 and len(prev_chosen_frontiers) > 0:
            sum_exp = 0.0
            for loc, score in prev_chosen_frontiers:
                exp_score_max = 0.0
                for i in orig['center']:
                    orig_eig = orig['eig'].get(i, 0.0)
                    exp_score = orig['score'][i] * orig_eig - dist_scale * np.linalg.norm(orig['center'][i] - loc)
                    exp_score_max = max(exp_score_max, exp_score)
                sum_exp += exp_score_max / score
            V0 = sum_exp / len(prev_chosen_frontiers)
        else:
            V0 = 5.0

        scores_f = [V0 * values_f[i] - dist_scale * distances_f[i] for i in range(len(frontiers))]
        idx_best_f = np.argmax(scores_f)
        return idx_best_f, scores_f[idx_best_f], scores_f

    def get_semantic_map(self, obj_class: int):
        """Return ``semantic_maps[obj_class]`` as a NumPy array."""
        return self.semantic_maps[obj_class].cpu().numpy()

    def get_instance_map(self, obj_class: int, instance: int):
        """Return the grid of one instance of ``obj_class`` as a NumPy array."""
        return self.instance_maps[obj_class][instance].cpu().numpy()

    def mark_instance_visited(self, obj_class: int, instance: int):
        """Flag an instance as visited."""
        self.instance_visited[obj_class][instance] = True

    def remove_invalid_pick(self, obj_idx):
        """Exclude object-goal instance ``obj_idx`` from future picks.

        Its cells are cleared in ``valid_pick`` and it is marked visited; unknown ids
        are ignored.
        """
        obj_class = self.object_goal
        if obj_idx in self.instance_maps[obj_class]:
            self.valid_pick[self.instance_maps[obj_class][obj_idx]] = False
            self.instance_visited[obj_class][obj_idx] = True

    @nvtx.annotate("GlobalObjectiveMap.visualize")
    def visualize(self) -> np.ndarray:
        """Render the map cropped to ``bounds`` and flipped vertically.

        Obstacles are black on white. Object goal, start receptacle and end
        receptacle cells are painted with ``d3_40_colors_rgb`` entries 1, 2 and 3,
        in that order. Classes without a semantic map yet are skipped.
        """
        y0, y1, x0, x1 = self.bounds
        vis_img = np.ones((y1 - y0, x1 - x0, 3), dtype=np.uint8) * 255
        vis_img[self.obstacle_map[y0:y1, x0:x1]] = (0, 0, 0)
        layers = (
            (self.object_goal, 1),
            (self.start_recep_goal, 2),
            (self.end_recep_goal, 3),
        )
        for obj_class, color_idx in layers:
            semantic_map = self.semantic_maps.get(obj_class)
            if semantic_map is None:
                continue
            vis_img[semantic_map[y0:y1, x0:x1].cpu().numpy()] = d3_40_colors_rgb[color_idx]
        return vis_img[::-1]
