from home_robot.agent.ovmm_agent.ovmm_agent import Skill
from home_robot.core.interfaces import DiscreteNavigationAction, Observations, ContinuousFullBodyAction
import numpy as np
from helios.agent.planner.fm2_planner import get_times
from helios.agent.utils.utils import *
from typing import Any, Dict
import nvtx
import matplotlib.pyplot as plt

from helios.agent.gaze.object_state_maps.global_objective_map import GlobalObjectiveMap
from helios.agent.nav.vlfm_nav import watch_out_for_close_obstacles

import cv2
import os
from helios.agent.gaze.utils.utils import calculate_instance_heightmap

from helios.agent.gaze.utils.utils import get_waypose

class HierarchicalGaze:
    """Hierarchical (global + local) search policy for pick-and-place episodes.

    ``global_search`` decides what to look at next: an instance of the goal
    object, of the start receptacle, of the end receptacle, or an exploration
    frontier of ``self.objnav``. It stores the wayposes for the chosen
    instance. ``local_search`` then drives the robot through those wayposes
    with a planner. When they are exhausted, it hands control back by
    returning the next ``Skill``. Instance masks, scores and visited flags
    live in ``object_state_map``.
    """

    def __init__(
        self,
        object_state_map: GlobalObjectiveMap,
        num_gaze_points: int,
        gaze_radius: float,
        cluster_radius: float,
        pixels_per_meter: float,
        pause: bool,
        pause_time: int,
        stuck_time: int,
        max_steps_per_dist: int,
        max_steps_per_view: int,
        threshold_sr: float,
        threshold_go: float,
        threshold_er: float,
        visualize_eig: bool,
        visualize_3dgs: bool,
        max_steps: int,
        alpha_d: float,
        use_global_objective: bool,
        simple_version: bool,
        use_dist_scale: bool
    ):
        """Store the configuration and initialise all search state.

        Args:
            object_state_map: per-class instance/score map queried by the search.
            gaze_radius: copied into ``info['gaze_radius']`` by global_search().
            threshold_sr, threshold_go, threshold_er: per-class thresholds
                recorded by update() for the start receptacle, goal object
                and end receptacle.
            visualize_eig: forwarded to ``get_best_wayposes_eig``.
            max_steps, alpha_d, use_dist_scale: define ``info['dist_scale']``
                in global_search().
            use_global_objective: score frontiers/start receptacles with the
                object-state map's global objective instead of raw values.
            simple_version: use the simplified (numpy, non-uncertainty) path.
            num_gaze_points, cluster_radius, pixels_per_meter, pause,
            pause_time, stuck_time, max_steps_per_dist, max_steps_per_view,
            visualize_3dgs: stored as attributes; not read by this class.
        """
        self.object_state_map = object_state_map

        self.num_gaze_points = num_gaze_points
        self.gaze_radius = gaze_radius
        self.cluster_radius = cluster_radius
        self.pixels_per_meter = pixels_per_meter
        self.pause = pause
        self.pause_time = pause_time
        self.stuck_time = stuck_time
        self.max_steps_per_dist = max_steps_per_dist
        self.max_steps_per_view = max_steps_per_view
        self.threshold_sr = threshold_sr
        self.threshold_go = threshold_go
        self.threshold_er = threshold_er
        self.thresholds = {}
        self.visualize_eig = visualize_eig
        self.windowed_goal_map = {}
        self.visualize_3dgs = visualize_3dgs

        self.steps_local_search = 0
        self.prev_chosen_frontiers = []
        self.max_steps = max_steps
        self.a_d = alpha_d

        # Current local-search target and its wayposes.
        self.gaze_target = None
        self.wayposes = None
        self.i_waypose = None
        self.wayposes_idx = None
        self.end_on_gaze = False

        self.n_steps_stuck = 0
        self.max_steps_stuck = 20

        # Minimum "diff" score for accepting a goal/end-receptacle instance
        # when its main score is below the object-state-map threshold.
        self.threshold_score_diff = 0.05

        self.before_frontier = True

        # Set when set_best_waypose() fails; makes has_obj() answer False
        # for non-start-receptacle classes until the next completed gaze.
        self.has_obj_return_false = False

        self.use_global_objective = use_global_objective
        self.simple_version = simple_version

        self.use_dist_scale = use_dist_scale

        # Collaborators injected later through the set_* methods. They are
        # initialised here so the attributes always exist; e.g.
        # global_search() forwards viz_save_dir even if set_viz_save_dir()
        # was never called.
        self.objnav = None
        self.map_interface = None
        self.viz_save_dir = None

        # Per-episode fields that reset() (re)initialises.
        self.pause_counter = 0
        self.last_checked_pose = None
        self.start_recep_pos = None
        self.poses_around_start_recep = []
        self.curr_pose_index = 0

    @nvtx.annotate("HierarchicalGaze.reset()")
    def reset(self):
        """Reset the object-state map and all per-episode search state."""
        self.object_state_map.reset()
        self.pause_counter = 0
        self.last_checked_pose = None
        self.start_recep_pos = None
        self.poses_around_start_recep = []
        self.curr_pose_index = 0
        self.windowed_goal_map = {}

        self.steps_local_search = 0
        self.prev_chosen_frontiers = []

        self.gaze_target = None
        self.wayposes = None
        self.i_waypose = None
        self.wayposes_idx = None
        self.end_on_gaze = False

        self.n_steps_stuck = 0

        self.before_frontier = True

        self.has_obj_return_false = False

    def set_objectnav_method(self, objnav):
        """Attach the object-navigation agent whose frontiers and value map are used."""
        self.objnav = objnav

    def set_episode_key(self, episode_key):
        """Forward the episode key to the object-state map."""
        self.object_state_map.set_episode_key(episode_key=episode_key)

    def set_viz_save_dir(self, save_dir):
        """Set (and create) the directory passed to the EIG visualisation."""
        self.viz_save_dir = save_dir
        os.makedirs(self.viz_save_dir, exist_ok=True)

    def set_map_interface(self, map_interface):
        """Attach a map interface object (stored only; not read by this class)."""
        self.map_interface = map_interface

    @nvtx.annotate("HierarchicalGaze.act()")
    def act(
        self,
        obs: Observations,
        info: Dict[str, Any],
        planner: Any
    ):
        """Return ``(action, info, next_skill)`` for the current step.

        While a target with wayposes is active, or an end-on-gaze target is
        set, local search continues. Otherwise global search runs first. If
        it picks a frontier, this returns ``(None, info, Skill.NAV_TO_OBJ)``
        so navigation takes over.
        """
        if (self.wayposes is not None or self.end_on_gaze) and self.gaze_target is not None:
            return self.local_search(obs, info, planner)

        goal_type, info = self.global_search(obs, info)
        if goal_type == 'frontier':
            return None, info, Skill.NAV_TO_OBJ
        return self.local_search(obs, info, planner)

    def _start_wayposes(self, wayposes, wayposes_idx):
        """Begin following ``wayposes`` (for instance ``wayposes_idx``) from the first one."""
        self.i_waypose = 0
        self.wayposes = wayposes
        self.wayposes_idx = wayposes_idx

    def global_search(
        self,
        obs: Observations,
        info: Dict[str, Any],
    ):
        """Pick the next high-level target and prepare its wayposes.

        Candidates are considered in this order:

        1. End-receptacle instances, but only while the end receptacle is
           already ``self.gaze_target``.
        2. Goal-object instances.
        3. Start-receptacle instances.
        4. Exploration frontiers.

        Receptacle and object choices store their wayposes in
        ``self.wayposes``, with ``self.i_waypose`` reset to 0. A frontier
        choice is written to ``self.objnav.chosen_frontier``.

        Side effects on ``info``: sets ``'dist_scale'`` and ``'gaze_radius'``.

        Returns:
            ``(goal_type, info)``. ``goal_type`` is one of ``'end_recep'``,
            ``'goal_obj'``, ``'start_rec'`` or ``'frontier'``. ``'frontier'``
            is also returned when no candidate exists at all; in that case
            no frontier is selected.
        """
        task_obs = obs.task_observations
        self.end_on_gaze = False
        if self.use_dist_scale:
            # Grows as info['timestep'] approaches self.max_steps.
            info['dist_scale'] = self.a_d * (self.max_steps + 1) / (self.max_steps - info['timestep'] + 1)
        else:
            info['dist_scale'] = self.a_d
        info['gaze_radius'] = self.gaze_radius
        self.wayposes_idx = None

        best_score_c = -np.inf

        if self.gaze_target == task_obs['end_recep_goal']:
            end_recep = task_obs['end_recep_goal']
            if (not self.simple_version) and self.has_obj_nc(end_recep):
                (idx_best_c, best_score_c, wayposes_best_c,
                 idx_best_c_diff, best_score_c_diff, wayposes_best_c_diff) = \
                    self.object_state_map.get_best_wayposes_score(obs, info, end_recep)
                if best_score_c > self.object_state_map.threshold_er:
                    self._start_wayposes(wayposes_best_c, idx_best_c)
                    return ('end_recep', info)
                if best_score_c_diff > self.threshold_score_diff:
                    self._start_wayposes(wayposes_best_c_diff, idx_best_c_diff)
                    return ('end_recep', info)
                # Neither score is good enough: fall through and compare the
                # best end-receptacle candidate with the frontiers below.
            else:
                frontiers = self.objnav.map_policy.obstacle_map.frontiers
                if len(frontiers) == 0:
                    # np.argmax() raises on an empty array; keep exploring
                    # without changing the chosen frontier.
                    return ('frontier', info)
                values_f = np.array([self.objnav.map_policy.get_value(pt, 1) for pt in frontiers])
                self.objnav.chosen_frontier = frontiers[np.argmax(values_f)]
                return ('frontier', info)
        else:
            # Candidate goal-object instances.
            if (not self.simple_version) and self.has_obj_nc(task_obs['object_goal']):
                object_goal = task_obs['object_goal']
                (idx_best_o, best_score_o, wayposes_best_o,
                 idx_best_o_diff, best_score_o_diff, wayposes_best_o_diff) = \
                    self.object_state_map.get_best_wayposes_score(obs, info, object_goal)
                if best_score_o > self.object_state_map.threshold_go_wu:
                    self._start_wayposes(wayposes_best_o, idx_best_o)
                    self.gaze_target = object_goal
                    return ('goal_obj', info)
                if best_score_o_diff > self.threshold_score_diff:
                    self._start_wayposes(wayposes_best_o_diff, idx_best_o_diff)
                    self.gaze_target = object_goal
                    return ('goal_obj', info)

            # Candidate start-receptacle instances.
            if (not self.simple_version) and self.has_obj(task_obs['start_recep_goal']):
                idx_best_c, best_score_c, wayposes_best_c, _, _, _ = self.object_state_map.get_best_wayposes_eig(
                    obs, info, task_obs['start_recep_goal'], self.visualize_eig, self.viz_save_dir,
                    debug_info=True, use_eig=self.use_global_objective)
            else:
                idx_best_c, best_score_c, wayposes_best_c = None, -np.inf, None

        # Candidate frontiers. They are only scored when they may be needed:
        # simple mode, global-objective mode, or no receptacle candidate.
        idx_best_f, best_score_f = None, -np.inf
        if self.simple_version or self.use_global_objective or best_score_c == -np.inf:
            frontiers = self.objnav.map_policy.obstacle_map.frontiers
            if len(frontiers) > 0:
                values_f = np.array([self.objnav.map_policy.get_value(pt, 0) for pt in frontiers])
                if self.use_global_objective:
                    idx_best_f, best_score_f, _ = self.object_state_map.score_frontiers(
                        obs, info, frontiers, values_f, self.prev_chosen_frontiers)
                else:
                    idx_best_f = np.argmax(values_f)
                    best_score_f = values_f[idx_best_f]

        if idx_best_c is not None and wayposes_best_c is not None and best_score_c >= best_score_f:
            self._start_wayposes(wayposes_best_c, idx_best_c)
            if self.gaze_target != task_obs['end_recep_goal']:
                self.gaze_target = task_obs['start_recep_goal']
            return ('start_rec', info)
        if idx_best_f is not None and best_score_f >= best_score_c:
            chosen = frontiers[idx_best_f]
            self.objnav.chosen_frontier = chosen
            self.prev_chosen_frontiers.append((chosen, values_f[idx_best_f]))
            self.before_frontier = False
            return ('frontier', info)
        # No usable receptacle or frontier candidate: go back to exploring.
        self.before_frontier = False
        return ('frontier', info)

    def _compute_waypose(self, mask, idx, obs, info):
        """Return a one-element waypose list for ``mask`` (instance ``idx``).

        Exceptions raised by ``get_waypose`` propagate to the caller.
        """
        if self.simple_version:
            return [get_waypose(
                semantic_mask=mask,
                instance_height=0,
                obs=obs,
                info=info,
                min_goal_navigable_dist=None,
                min_goal_waypose_dist=0.5,
                max_goal_waypose_dist=0.7,
                height_map=info['height_map'],
            )]
        return [get_waypose(
            semantic_mask=mask,
            instance_height=self.object_state_map.instance_heights[self.gaze_target][idx],
            obs=obs,
            info=info,
            min_goal_navigable_dist=None,
            min_goal_waypose_dist=0.5,
            max_goal_waypose_dist=0.7,
        )]

    def set_best_waypose(self, obs, info, no_valid=False):
        """Select the best instance of ``self.gaze_target`` and compute its waypose.

        Instances with id 0 or listed in ``bad_instances`` are skipped. Each
        remaining instance is scored as the fraction of its mask that the
        semantic map labels as the target class. For the goal object, only
        ``valid_pick`` cells count. An instance with fewer matching cells
        than ``min_cluster_size`` (``min_cluster_size_go`` for the goal
        object) scores 0.

        For the end receptacle, the mask is first filtered with
        ``calculate_instance_heightmap``. Only instances it reports as valid
        are eligible, unless ``no_valid`` is set. If nothing was selected
        and no instance was valid, the method retries once with
        ``no_valid=True``.

        If ``get_waypose`` fails for an instance, that instance is skipped
        (best effort).

        Returns:
            True and sets ``self.wayposes`` / ``self.wayposes_idx`` on
            success; False otherwise.
        """
        obj_class = self.gaze_target
        osm = self.object_state_map
        if not osm.has_obj(obj_class):
            return False

        best_score = 0
        best_idx = -1
        triggered_is_valid = False

        # NOTE(review): `torch` is not among the visible explicit imports;
        # presumably it arrives via `from helios.agent.utils.utils import *`.
        if self.simple_version:
            candidate_ids = np.unique(osm.instance_maps[obj_class])
        else:
            candidate_ids = osm.instance_maps[obj_class].keys()

        for idx in candidate_ids:
            if idx == 0 or idx in osm.bad_instances[obj_class]:
                continue
            if self.simple_version:
                mask = osm.instance_maps[obj_class] == idx
                mask_size = np.sum(mask)
            else:
                mask = osm.instance_maps[obj_class][idx]
                mask_size = torch.sum(mask).item()

            if obj_class == osm.object_goal:
                if self.simple_version:
                    n_over = np.sum((osm.map == obj_class) * osm.valid_pick * mask)
                else:
                    n_over = torch.sum(osm.semantic_maps_wu[obj_class] * osm.valid_pick * mask).item()
                min_size = osm.min_cluster_size_go
            else:
                if self.simple_version:
                    n_over = np.sum((osm.map == obj_class) * mask)
                else:
                    n_over = torch.sum(osm.semantic_maps_wu[obj_class] * mask).item()
                min_size = osm.min_cluster_size
            score = n_over / mask_size if n_over >= min_size else 0

            if obj_class == osm.end_recep_goal:
                if not self.simple_version:
                    mask = mask.cpu().numpy()
                isvalid, good_mask = calculate_instance_heightmap(mask, info["height_map"])
                if isvalid:
                    triggered_is_valid = True
                elif no_valid:
                    good_mask = mask
                if (isvalid or no_valid) and score > best_score:
                    try:
                        self.wayposes = self._compute_waypose(good_mask, idx, obs, info)
                        best_score = score
                        best_idx = idx
                    except Exception:
                        # Best effort: skip instances without a computable waypose.
                        pass
            elif score > best_score:
                if self.gaze_target == obs.task_observations['object_goal']:
                    # Out-of-place on purpose. On the non-simple path `mask`
                    # is the tensor stored in osm.instance_maps, and `*=`
                    # would permanently shrink the stored instance mask.
                    mask = mask * osm.valid_pick
                try:
                    waypose_mask = mask if self.simple_version else mask.cpu().numpy()
                    self.wayposes = self._compute_waypose(waypose_mask, idx, obs, info)
                    best_score = score
                    best_idx = idx
                except Exception:
                    # Best effort: skip instances without a computable waypose.
                    pass

        if best_idx == -1:
            if obj_class == osm.end_recep_goal and not no_valid and not triggered_is_valid:
                return self.set_best_waypose(obs, info, no_valid=True)
            return False
        self.wayposes_idx = best_idx
        return True

    def _clear_local_target(self):
        """Forget the current gaze target, its wayposes and the stuck counter."""
        self.gaze_target = None
        self.wayposes = None
        self.i_waypose = None
        self.n_steps_stuck = 0

    @staticmethod
    def _joint9_action(target, obs, xyt):
        """Build a full-body action that moves joint 9 toward ``target``.

        The joint command is relative (``target - obs.joint[9]``); all other
        joints get 0 and the base motion is ``xyt``. Joint 9 is presumably
        the camera tilt joint -- TODO confirm.
        """
        joints = np.zeros(10)
        joints[9] = target - obs.joint[9]
        return ContinuousFullBodyAction(joints=joints, xyt=xyt)

    def local_search(
        self,
        obs: Observations,
        info: Dict[str, Any],
        planner: Any
    ):
        """Drive through ``self.wayposes`` and return ``(action, info, next_skill)``.

        Behaviour per call:

        - If ``self.wayposes`` is None (only allowed with ``end_on_gaze``),
          a waypose is chosen with set_best_waypose(). If that fails, the
          target is dropped and navigation is requested.
        - Wayposes whose pixel is not navigable are skipped.
        - A waypose is also abandoned after more than ``max_steps_stuck``
          consecutive steps within 2.0 of it.
        - When the planner returns STOP, the method advances to the next
          waypose and emits a joint-9 action towards ``waypose[3]``.
        - Once all wayposes are consumed, it returns ``(None, info,
          next_skill)``. Without ``end_on_gaze``, the instance is first
          marked visited.

        Raises:
            RuntimeError: called without wayposes and without ``end_on_gaze``.
        """
        assert self.gaze_target is not None, '[MergedGaze.act()] gaze_target is None'
        task_obs = obs.task_observations

        if self.wayposes is None:
            self.i_waypose = 0
            if not self.end_on_gaze:
                raise RuntimeError('local_search() called for exploration without setting wayposes')
            if not self.set_best_waypose(obs, info):
                next_skill = (
                    Skill.NAV_TO_REC
                    if self.gaze_target == task_obs['end_recep_goal'] else
                    Skill.NAV_TO_OBJ
                )
                self.has_obj_return_false = True
                self._clear_local_target()
                self.end_on_gaze = False
                return None, info, next_skill

        if self.i_waypose < len(self.wayposes):
            waypose = self.wayposes[self.i_waypose]
            px_goal = gps_to_px(waypose[:2], info['pixels_per_meter'], info['map_size'])
            if not info['navigable_map'][px_goal[0], px_goal[1]]:
                # Unreachable waypose: skip it.
                self.n_steps_stuck = 0
                self.i_waypose += 1
                self.end_on_gaze = False
                return self.local_search(obs, info, planner)

            # Count consecutive steps spent near the waypose and give up on it
            # after max_steps_stuck. NOTE(review): the x coordinate is negated
            # here but not in xyt_start below -- confirm the GPS convention.
            if np.linalg.norm(waypose[:2] - np.array([-obs.gps[0], obs.gps[1]])) < 2.0:
                self.n_steps_stuck += 1
            else:
                self.n_steps_stuck = 0
            if self.n_steps_stuck > self.max_steps_stuck:
                self.n_steps_stuck = 0
                self.i_waypose += 1
                self.end_on_gaze = False
                return self.local_search(obs, info, planner)

            action = planner.plan_to_pose(
                obstacle_map=info['dilated_obstacle_map'],
                seen_map=info['seen_map'],
                xyt_start=np.array([obs.gps[0], obs.gps[1], obs.compass[0]]),
                xyt_goal=waypose[:3],
                timestep=info['timestep'],  # just for visualization
                max_turn_degrees=180 if self.i_waypose > 0 else 30,
            )
            if action == DiscreteNavigationAction.STOP:
                # Waypose reached: advance and set joint 9 without moving the base.
                self.n_steps_stuck = 0
                self.i_waypose += 1
                return self._joint9_action(waypose[3], obs, np.zeros(3)), info, None
            if self.i_waypose == 0 and action is not None:
                # First waypose: keep the planner's base motion but also drive
                # joint 9 towards the waypose's target value.
                action = self._joint9_action(waypose[3], obs, action.xyt)
            action = watch_out_for_close_obstacles(action, obs, info)
            return action, info, None

        # All wayposes consumed: decide which skill runs next.
        if self.gaze_target == task_obs['start_recep_goal']:
            next_skill = Skill.NAV_TO_OBJ
        elif self.end_on_gaze:
            next_skill = Skill.PICK if self.gaze_target == task_obs['object_goal'] else Skill.PLACE
        else:
            next_skill = Skill.NAV_TO_OBJ if self.gaze_target == task_obs['object_goal'] else Skill.NAV_TO_REC
        if not self.end_on_gaze:
            self.object_state_map.instance_visited[self.gaze_target][self.wayposes_idx] = True
            self.wayposes_idx = None
        self._clear_local_target()
        self.has_obj_return_false = False
        return None, info, next_skill

    @nvtx.annotate("HierarchicalGaze.update()")
    def update(
        self,
        obs: Observations,
        info: Dict[str, Any],
    ):
        """Update the object-state map and expose it as ``info['merged_object_map']``.

        On the first call (while ``self.thresholds`` is empty), this also
        records the per-class thresholds for the start receptacle, goal
        object and end receptacle.
        """
        if not self.thresholds:
            # NOTE(review): reset() does not clear self.thresholds, so the keys
            # stay those of the first episode; nothing in this class reads
            # self.thresholds -- confirm before relying on it.
            task_obs = obs.task_observations
            self.thresholds[task_obs['start_recep_goal']] = self.threshold_sr
            self.thresholds[task_obs['object_goal']] = self.threshold_go
            self.thresholds[task_obs['end_recep_goal']] = self.threshold_er
        self.object_state_map.update(obs, info)
        info['merged_object_map'] = self.object_state_map

    def has_obj(self, obj_class: int):
        """Return whether ``obj_class`` is present according to the object-state map.

        The start receptacle uses ``object_state_map.has_obj``. Other classes
        use ``has_obj_uncertainty_weighted``, but report False while
        ``has_obj_return_false`` is set.
        """
        if obj_class == self.object_state_map.start_recep_goal:
            return self.object_state_map.has_obj(obj_class)
        if self.has_obj_return_false:
            print("HAS_OBJ return false")
            return False
        return self.object_state_map.has_obj_uncertainty_weighted(obj_class)

    def has_obj_nc(self, obj_class: int):
        """Return ``object_state_map.has_obj(obj_class)`` with no other overrides."""
        return self.object_state_map.has_obj(obj_class)

    def visualize_global(self, info, frontiers, frontier_scores, start_recs, start_rec_scores, wayposes):
        """Debug helper: save global-search candidates and their scores as images.

        Writes two files:

        - ``global_vis/{timestep}_locs.png``: start-receptacle masks in
          (0, 0, 255) and frontiers as (0, 255, 0) dots.
        - ``global_vis/{timestep}_scores.png``: scores normalised to 0-255,
          colour-mapped with INFERNO, unscored pixels white.

        Does nothing when both score lists are empty. Frontiers without a
        matching score are drawn only in the locations image. ``wayposes``
        is unused.
        """
        non_empty = [s for s in (start_rec_scores, frontier_scores) if len(s) > 0]
        if not non_empty:
            return
        max_score = max(np.max(s) for s in non_empty)
        min_score = min(np.min(s) for s in non_empty)
        print("SCORES: ", min_score, max_score)
        # Offset so the lowest score maps to a non-zero (non-background) value.
        min_score -= 0.1
        span = max_score - min_score
        size = self.object_state_map.map_size
        locs = np.full((size, size, 3), 255, np.uint8)
        scores = np.zeros((size, size, 3), np.uint8)

        for rec_mask, rec_score in zip(start_recs, start_rec_scores):
            locs[rec_mask] = (0, 0, 255)
            scores[rec_mask] = 255 * ((rec_score - min_score) / span)

        for i, frontier in enumerate(frontiers):
            fpt = xy_to_px(frontier, info['pixels_per_meter'], size)
            print("PT: ", fpt)
            center = tuple(int(c) for c in fpt)  # cv2 requires a tuple of ints
            cv2.circle(locs, center, radius=3, color=(0, 255, 0), thickness=-1)
            if i < len(frontier_scores):
                level = float(255 * ((frontier_scores[i] - min_score) / span))
                cv2.circle(scores, center, radius=3, color=(level, level, level), thickness=-1)

        background = np.all(scores == 0, axis=2)
        # applyColorMap returns BGR; plt.imsave expects RGB.
        scores = cv2.cvtColor(cv2.applyColorMap(scores, cv2.COLORMAP_INFERNO), cv2.COLOR_BGR2RGB)
        print("MASK: ", np.sum(background))
        scores[background] = 255

        t = info['timestep']
        os.makedirs("global_vis", exist_ok=True)
        plt.imsave(f"global_vis/{t}_locs.png", locs)
        plt.imsave(f"global_vis/{t}_scores.png", scores)

def check_gaze_pose(
    gaze_pose: np.ndarray,
    obj_pose: np.ndarray,
    obs: Observations,
    info: Dict[str, Any],
):
    """Return whether ``gaze_pose`` is an acceptable viewpoint for the object at ``obj_pose``.

    A square window of ``info['wall_map']`` (side about ``4 * gaze_radius``
    metres) is cut around the object. ``get_times`` is then propagated from
    the object's pixel inside that window. The pose is rejected when:

    - it lies outside the window,
    - its value in the resulting time map exceeds ``1.5 * info['gaze_radius']``, or
    - its pixel is not navigable.

    Args:
        gaze_pose: candidate pose; only its first two entries are used.
        obj_pose: object position, converted with ``gps_to_px`` like the pose.
        obs: unused; kept for interface compatibility.
        info: must provide ``'pixels_per_meter'``, ``'map_size'``,
            ``'gaze_radius'``, ``'wall_map'`` and ``'navigable_map'``.

    Returns:
        True if the pose passes all checks, otherwise False.
    """
    import math  # not among the visible module-level imports

    ppm = info['pixels_per_meter']
    map_size = info['map_size']
    px_goal = np.asarray(gps_to_px(gaze_pose[:2], ppm, map_size))
    px_object = np.asarray(gps_to_px(obj_pose, ppm, map_size))

    # Window of wall_window_size + 1 pixels per side, clipped to stay inside
    # the map. Near a map border the object is therefore NOT at the centre.
    wall_window_size = int(math.ceil(info['gaze_radius'] * 4 * ppm)) + 1
    y0, x0 = np.clip(px_object - wall_window_size // 2, 0, map_size - wall_window_size)
    y1, x1 = y0 + wall_window_size + 1, x0 + wall_window_size + 1
    wall_map = info['wall_map'][y0:y1, x0:x1]
    rows, cols = wall_map.shape[:2]
    origin = np.array([y0, x0])

    goal_r, goal_c = (int(v) for v in px_goal - origin)
    if not (0 <= goal_r < rows and 0 <= goal_c < cols):
        # Outside the window, i.e. roughly 2 * gaze_radius or more from the
        # object along one axis. Rejecting here also stops negative indices
        # from silently wrapping around in the lookup below.
        return False

    # Object pixel in window coordinates. It equals the window centre unless
    # the window was clipped. NOTE(review): assumes get_times takes the
    # source as (row, col), the same order used to index wall_times.
    src_r, src_c = np.clip(px_object - origin, 0, [rows - 1, cols - 1])
    wall_times = get_times(wall_map, (int(src_r), int(src_c)), None, dx=1 / ppm)

    if wall_times[goal_r, goal_c] > info['gaze_radius'] * 1.5:
        return False
    return bool(info['navigable_map'][tuple(px_goal)])