from home_robot.core.interfaces import DiscreteNavigationAction, Observations
import numpy as np
import torch
from helios.agent.gaze.gaussian_representation.slam import SLAM
from typing import Any, Dict, Tuple, List
import nvtx

from helios.agent.utils.utils import obs_to_tf

import time

from helios.agent.utils.visualization import apply_mask
import cv2
from helios.agent.utils.utils import get_bounds
from home_robot.perception.constants import d3_40_colors_rgb

class GsObjectMap:
    def __init__(
        self,
        img_size: Tuple[int, int],
        hfov: float,
        model: SLAM,
        save_rep: bool,
        semantic_threshold_for_go: float,
        semantic_threshold_for_sr: float,
        semantic_threshold_for_er: float,
        min_depth: float = 0.0,
        max_depth: float = 10.0
    ):
        self.img_size = img_size
        self.K = np.array([
            [(img_size[1] / 2) / np.tan(np.radians(hfov) / 2.0), 0.0, img_size[1] / 2, 0.0],
            [0.0, (img_size[1] / 2) / np.tan(np.radians(hfov) / 2.0), img_size[0] / 2, 0.0],
            [0.0, 0.0, 1, 0],
            [0.0, 0.0, 0, 1],
        ])
        self.model = model
        self.initialized = False
        self.device = "cuda"
        self.save_rep = save_rep
        self.semantic_threshold_for_go = semantic_threshold_for_go
        self.semantic_threshold_for_sr = semantic_threshold_for_sr
        self.semantic_threshold_for_er = semantic_threshold_for_er
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.episode_key = None

        self.class_mapping = {}

        self.did_update_3gds = False

    def reset_vectorized(self):
        """Alias for ``reset``; it simply delegates to it."""
        self.reset()

    def set_episode_key(self, episode_key):
        """Store the key that ``update`` passes to ``model.save_model`` when ``save_rep`` is set."""
        self.episode_key = episode_key

    def reset(self):
        self.model.reset() # just reset no matter init or not
        self.initialized = False
        self.class_mapping = {}
        self.did_update_3gds = False

    def update(
        self,
        obs: Observations,
        info: Dict[str, Any],
    ):
        self.did_update_3gds = False
        instances = []
        if 'instance_map' in obs.task_observations.keys():
            semantic_c = np.zeros((
                    obs.task_observations['instance_map'].shape[0],
                    obs.task_observations['instance_map'].shape[1],
                    self.model.n_semantic_channels,
            ))
            for i_instance, (instance_class, instance_score) in enumerate(zip(obs.task_observations['instance_classes'], obs.task_observations['instance_scores'])):
                if instance_class == 4 or instance_class==0:
                    continue
                if instance_class == 1:
                    if instance_score < self.semantic_threshold_for_go:
                        continue
                elif instance_class == 2:
                    if instance_score < self.semantic_threshold_for_sr:
                        continue
                elif instance_class == 3:
                    if instance_score < self.semantic_threshold_for_er:
                        continue

                semantic_c[:, :, instance_class][obs.task_observations['instance_map'] == i_instance] = 1 #instance_score 
                instances += [(instance_class, torch.tensor(obs.task_observations['instance_map'] == i_instance,device=self.device))]
        else:
            semantic_c = np.zeros((
                    obs.semantic.shape[0],
                    obs.semantic.shape[1],
                    self.model.n_semantic_channels,
            ))
            # ["misc", "object_category", "start_receptacle", "goal_receptacle", "others"]
            semantic_c[:,:,2] = obs.semantic == obs.task_observations['start_recep_goal']
            semantic_c[:,:,1] = obs.semantic == obs.task_observations['object_goal']
            semantic_c[:,:,3] = obs.semantic == obs.task_observations['end_recep_goal']

        self.class_mapping[obs.task_observations['object_goal']]=1
        self.class_mapping[obs.task_observations['start_recep_goal']]=2
        self.class_mapping[obs.task_observations['end_recep_goal']]=3

        where_sem = np.sum(semantic_c[:,:,1:]>1e-2, axis=2)>0
        semantic_c[:,:,0][~where_sem] = 1

        #Set already seen parts which are not detected as an object to have a score of 0.5 as misc
        not_overlapping = 0
        if self.initialized:
            where_seen = self.get_where_seen(obs, remove_bad_views=True)

            if np.sum(where_seen)==0: #set not_overlapping to update threshold as there are no Gaussians present
                not_overlapping = 200
            else:
                not_overlapping = np.sum(where_sem*(~where_seen))
            where_update = where_sem + where_seen
        else:
            where_update = where_sem

        if np.sum(where_update) == 0:
            return 
        
        self.did_update_3gds = True
        
        transform = np.array([[0, 0, 1, 0], [-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
        tf_camera_to_episodic = obs_to_tf(obs)

        pose = tf_camera_to_episodic @ transform

        data = (
            torch.tensor(obs.rgb, device=self.device, dtype=torch.float),
            torch.tensor(obs.depth, device=self.device, dtype=torch.float).unsqueeze(-1),
            torch.tensor(semantic_c, device=self.device, dtype=torch.float),
            torch.tensor(self.K, device=self.device, dtype=torch.float),
            torch.tensor(pose, device=self.device, dtype=torch.float),
            torch.tensor(where_update, device=self.device, dtype=torch.bool),
            instances
        )

        # #Save data for making figures
        # folder = "save_data_for_gs"
        # t = info['timestep']
        # np.savez(f'{folder}/{t}.npz', rgb=obs.rgb, depth=obs.depth, semantic_c=semantic_c, K=self.K, pose=pose, where_update=where_update)

        # print("TIME 3DGS get data: ", time.time()-t0)
        if not self.initialized:
            self.model.first_frame(data)
            self.initialized = True
        else:
            self.model.step(data, not_overlapping=not_overlapping)
            if self.save_rep:
                self.model.save_model(self.episode_key)

    def has_obj(self, obj_class: int, threshold: float):
        if not self.initialized:
            return False
        return self.model.has_obj(self.class_mapping[obj_class], threshold).detach().cpu().item()

    def get_obj_locs(
        self,
        obj_class: int,
        threshold: float,
        use_uncertainty: bool,
        return_scores: bool,
        return_instances: bool
    ):
        if self.initialized:
            return self.model.get_obj_locs(
                self.class_mapping[obj_class],
                threshold,
                use_uncertainty,
                return_scores,
                return_instances
            )
        else:
            if return_scores:
                if return_instances:
                    return False, None, None, None
                else:
                    return False, None, None
            elif return_instances:
                return False, None, None
            else:
                return False, None

    def get_renderings(self, obs: Observations):
        if self.initialized:
            transform = np.array([[0, 0, 1, 0], [-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
            tf_camera_to_episodic = obs_to_tf(obs)

            pose = np.linalg.inv(tf_camera_to_episodic @ transform)

            rgb, semantic_c, uncertainty_c, instances, depth = self.model.get_renders(pose, return_instances=True)

            # semantic color
            semantic_masks = [semantic_c[:,:,self.class_mapping[cs]]>= self.thresholds[cs] for cs in self.class_mapping.keys()]
            semantic = apply_mask(rgb.copy(), semantic_masks)

            # instance color
            instance_uniques, instance_counts = np.unique(instances, return_counts=True)
            if instance_uniques[0] == 0:
                instance_uniques, instance_counts = instance_uniques[1:], instance_counts[1:]
            instance_visual = np.zeros_like(rgb)
            for i_instance, instance in enumerate(instance_uniques):
                if i_instance >= 40:
                    break
                instance_mask = instances == instance
                instance_visual[instance_mask] = rgb[instance_mask] // 2 + d3_40_colors_rgb[i_instance] // 2

            where_seen = self.get_where_seen(obs)
            render_where_seen = rgb.copy()

            channel = np.argmax(semantic_c, axis=-1)
            uncertainty = np.zeros((semantic.shape[0],semantic.shape[1]))
            for c in range(4):
                uncertainty[channel == c] = uncertainty_c[:,:,c][channel == c]
            uncertainty = cv2.applyColorMap((255*5*uncertainty).astype(np.uint8), cv2.COLORMAP_OCEAN)

            for i in range(3):
                render_where_seen[:,:,i] *= where_seen
                semantic[:,:,i] *= where_seen
                uncertainty[:,:,i] *= where_seen
                semantic_c[:,:,i+1] *= where_seen
                instance_visual[:,:,i] *= where_seen

            kernel = np.ones((50,50), np.uint8)

            # instance label
            for i_instance, instance in enumerate(instance_uniques):
                if i_instance >= 40:
                    break
                instance_mask = instances == instance
                if np.any(instance_mask*where_seen):
                    y0, y1, x0, x1 = get_bounds(instance_mask, 0)
                    ym, xm = (y0 + y1) // 2 - 10, (x0 + x1) // 2 - 20
                    ym, xm = np.clip(ym, 0, rgb.shape[0] - 80), np.clip(xm, 0, rgb.shape[1] - 100)
                    cv2.putText(instance_visual, f'{instance}', (xm, ym), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 4, cv2.LINE_AA)
                    cv2.putText(instance_visual, f'{instance}', (xm ,ym), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2, cv2.LINE_AA)
            alpha = self.model.config.objective.alpha_uncertainty
            # semantic label
            for cs in self.class_mapping.keys():
                c = self.class_mapping[cs]
                if np.sum(semantic_c[:,:,c]>=self.thresholds[cs])>0:
                    mask = semantic_c[:,:,c]>=self.thresholds[cs]
                    n_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
                        cv2.dilate(mask.astype(np.uint8)*255, kernel, iterations=1)
                    )

                    size_thresh = 200
                    for i in range(1, n_labels):
                        if stats[i, cv2.CC_STAT_AREA] >= size_thresh:
                            x = stats[i, cv2.CC_STAT_LEFT]
                            y = stats[i, cv2.CC_STAT_TOP]
                            w = stats[i, cv2.CC_STAT_WIDTH]
                            h = stats[i, cv2.CC_STAT_HEIGHT]

                            mask_i = semantic_c[:,:,c]>=self.thresholds[cs]
                            mask_i[:y] = False
                            mask_i[max(y+h,mask_i.shape[0]):] = False
                            mask_i[:,:x] = False
                            mask_i[:,min(x+w,mask_i.shape[1]):] = False

                            nz = np.nonzero(mask_i)

                            idx = np.argmax(nz[1])
                            i = max(nz[0][idx],80)
                            j = min(nz[1][idx],semantic.shape[1]-100)

                            score = np.mean(semantic_c[:,:,c][mask_i])

                            cv2.putText(semantic, f'{score:.4f}', (j,i), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0,0,0), 4, cv2.LINE_AA)
                            cv2.putText(semantic, f'{score:.4f}', (j,i), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2, cv2.LINE_AA)
                            
                            # score with uncertainty
                            score_wu = np.mean(semantic_c[:,:,c][mask_i]-alpha*uncertainty_c[:,:,c][mask_i])
                            cv2.putText(uncertainty, f'{score_wu:.4f}', (j,i), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0,0,0), 4, cv2.LINE_AA)
                            cv2.putText(uncertainty, f'{score_wu:.4f}', (j,i), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2, cv2.LINE_AA)

            return render_where_seen, semantic, uncertainty, instance_visual, depth, rgb
        else:
            return (
                np.zeros_like(obs.rgb),
                np.zeros_like(obs.rgb),
                np.zeros_like(obs.rgb),
                np.zeros_like(obs.rgb),
                np.zeros_like(obs.depth),
                np.zeros_like(obs.rgb)
            )

    def get_where_seen(self, obs: Observations, remove_bad_views=False):
        transform = np.array([[0, 0, 1, 0], [-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
        tf_camera_to_episodic = obs_to_tf(obs)

        pose = np.linalg.inv(tf_camera_to_episodic @ transform)

        return self.model.get_where_seen(pose, obs.rgb, obs.depth, remove_bad_views)

    def get_eig_path(self, poses: List[torch.Tensor], instance: int, visualize: bool=False, obj_class: int = 1) -> Tuple[torch.Tensor,torch.Tensor]:
        """Forward to ``model.get_eig_path`` with the arguments unchanged.

        Note that ``obj_class`` is passed through as given; unlike ``has_obj``
        and ``get_obj_locs`` it is not translated through ``class_mapping``.
        """
        return self.model.get_eig_path(poses, instance, visualize, obj_class)
    
    def get_expected_score(self, poses: List[torch.Tensor], instance: int, obj_class: int = 1) -> Tuple[torch.Tensor,torch.Tensor]:
        """Forward to ``model.get_expected_score`` with the arguments unchanged.

        Note that ``obj_class`` is passed through as given; unlike ``has_obj``
        and ``get_obj_locs`` it is not translated through ``class_mapping``.
        """
        return self.model.get_expected_score(poses, instance, obj_class)

    def stop_backend(self):
        if hasattr(self.model, 'stop'):
            self.model.stop()