from typing import Tuple, List
import os
import cv2
import numpy as np
import torch

class SLAM:
    def __init__(
        self,
        config: dict,
        max_frames: int,
        n_semantic_channels: int,
    ):
        """Create the output directory, pick the device and set empty map state.

        Args:
            config: Run configuration. It is read by key
                (``config["workdir"]``, ``config["primary_device"]``) and by
                attribute (``config.instances.use_instance``), so despite the
                ``dict`` annotation it must support both access styles.
            max_frames: Stored as ``self.num_frames``.
            n_semantic_channels: Number of semantic channels, stored as is.
        """
        self.config = config
        self.num_frames = max_frames
        self.n_semantic_channels = n_semantic_channels

        # Object outputs are written below <workdir>/objects.
        self.output_dir = os.path.join(config["workdir"], "objects")
        os.makedirs(self.output_dir, exist_ok=True)

        self.device = torch.device(config["primary_device"])
        self.use_instances = config.instances.use_instance

        # Per-run state. reset() clears these same attributes.
        self.time_idx = 0
        self.obj_locs = {}
        self.obj_scores = {}
        self.params = None
        # reset() assigns `variables`, so define it at construction too;
        # otherwise it only exists after the first reset().
        self.variables = None
        self.first_frame_w2c = None
        self.intrinsics = None

    def reset(self):
        """Clear the per-run map state so the instance can be reused.

        Configuration-derived attributes (device, output directory, channel
        count, ``use_instances``) are left untouched.
        """
        self.time_idx = 0

        # Two distinct empty dicts for object bookkeeping.
        self.obj_locs, self.obj_scores = {}, {}

        # Everything below is rebuilt once a new map is started.
        self.params = self.variables = None
        self.first_frame_w2c = self.intrinsics = None

    def first_frame(self, dataset):
        """Not implemented in this class; always raises ``NotImplementedError``."""
        raise NotImplementedError()


    def step(self, dataset, not_overlapping):
        """Not implemented in this class; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    @torch.no_grad()
    def has_obj(self, obj_class: int, threshold: float, use_uncertainty: bool):
        """Return whether any map point scores above ``threshold`` for ``obj_class``.

        Args:
            obj_class: Index of the semantic class to look for.
            threshold: Score a point must exceed to count as a detection.
            use_uncertainty: Forwarded to ``get_obj_locs`` to select the
                scoring rule.

        Returns:
            The first element of the tuple returned by ``get_obj_locs``.
        """
        # get_obj_locs needs both return_* flags; only the presence flag is
        # used here, so neither scores nor instance ids are requested.
        return self.get_obj_locs(
            obj_class,
            threshold,
            use_uncertainty,
            return_scores=False,
            return_instances=False,
        )[0]

    @torch.no_grad()
    def get_obj_score(self, obj_class: int, use_uncertainty: bool) -> torch.Tensor:
        """Compute a per-point score for ``obj_class``.

        Args:
            obj_class: Column index of the semantic class to score.
            use_uncertainty: If true, the score is
                ``semantic_scaled[:, obj_class]`` minus
                ``config.objective.alpha_uncertainty`` times
                ``uncertainty[:, obj_class]``. Otherwise it is the class
                concentration ``params["semantic_c"][:, obj_class]`` divided by
                the per-point sum over classes, with NaNs replaced by 0.

        Returns:
            A new 1-D tensor with one score per point. When
            ``config.instances.use_instances_for_score`` is set, every point
            carries the mean score of the points sharing its instance id.
        """
        # The decorator already disables autograd for the whole call.
        if use_uncertainty:
            penalty = self.config.objective.alpha_uncertainty * self.uncertainty[:, obj_class]
            score = self.semantic_scaled[:, obj_class] - penalty
        else:
            concentrations = self.params["semantic_c"]
            score = torch.nan_to_num(
                concentrations[:, obj_class] / torch.sum(concentrations, dim=-1)
            )

        if self.config.instances.use_instances_for_score:
            instance_ids = self.params["instances"]
            for instance_id in torch.unique(instance_ids):
                members = instance_ids == instance_id
                score[members] = torch.mean(score[members])

        return score
            
    
    @torch.no_grad()
    def get_obj_locs(
        self,
        obj_class: int,
        threshold: float,
        use_uncertainty: bool,
        return_scores: bool = False,
        return_instances: bool = False,
    ):
        """Find the map points whose score for ``obj_class`` exceeds ``threshold``.

        Args:
            obj_class: Index of the semantic class to look for.
            threshold: Points with a score strictly above this value are kept.
            use_uncertainty: Forwarded to ``get_obj_score``.
            return_scores: Append the scores to the returned tuple.
            return_instances: Append the instance ids to the returned tuple.

        Returns:
            ``(has_obj, means3d)`` followed by ``scores`` and/or ``instances``
            when the corresponding flag is set. ``means3d`` and ``instances``
            are ``None`` when nothing passes the threshold or no map exists.
            ``scores`` is ``None`` when no map exists; when a map exists but no
            point passes the threshold it is the unfiltered score tensor.
        """
        found = False
        means3d = None
        scores = None
        instances = None

        # params is None before a map is created (see __init__/reset), so test
        # for that before asking for its length.
        if self.params is not None and len(self.params) > 0:
            scores = self.get_obj_score(obj_class, use_uncertainty)
            selected = scores > threshold
            if torch.any(selected):
                found = True
                means3d = self.params["means3D"][selected].detach()
                if return_scores:
                    scores = scores[selected].detach()
                if return_instances:
                    instances = self.params["instances"][selected].detach()

        result = [found, means3d]
        if return_scores:
            result.append(scores)
        if return_instances:
            result.append(instances)
        return tuple(result)

    @torch.no_grad()
    def get_where_seen(self, w2c, rgb, depth, remove_bad_views=False):
        """Return a mask of pixels whose observed depth agrees with the map.

        The map is rendered from ``w2c`` and a pixel is kept when the absolute
        difference between ``depth`` and the rendered depth is below 0.5. The
        raw mask is then cleaned with a 3x3 morphological close followed by a
        3x3 open.

        Args:
            w2c: Camera pose, forwarded unchanged to ``get_renders``.
            rgb: Not used by the current implementation.
            depth: Observed depth image. Presumably a NumPy array with the
                same height and width as the render - TODO confirm.
            remove_bad_views: Not used by the current implementation (the
                border / distance filtering it used to select is disabled).

        Returns:
            Boolean array, true where the depth agrees after clean-up.
        """
        # Only the depth channel of the render is needed.
        _, _, _, _, depth_rend = self.get_renders(
            w2c,
            return_instances=False,
            return_uncertainty=False,
            return_semantics=False,
        )

        # config.use_2dgs used to select between two branches that were
        # identical, so a single code path is used for both settings.
        agrees = (np.abs(depth - depth_rend) < 0.5).astype(np.uint8)
        kernel = np.ones((3, 3), np.uint8)
        closed = cv2.morphologyEx(255 * agrees, cv2.MORPH_CLOSE, kernel)
        opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel)
        return opened > 0
    
    @torch.no_grad()
    def get_renders(self, w2c, return_instances, return_uncertainty=True, return_semantics=True, backgrounds=None):
        """Render colour, depth and optional per-point features from one pose.

        Args:
            w2c: Camera pose as a tensor or NumPy array. Arrays are moved to
                ``self.device``; tensors are used on whatever device they are
                already on. The pose is cast to float32 either way.
            return_instances: Render an instance-id image.
            return_uncertainty: Render ``self.uncertainty``.
            return_semantics: Render ``self.semantic_scaled``.
            backgrounds: Optional value used to fill the background of the
                colour/depth render only; the feature renders always use the
                renderer default.

        Returns:
            ``(im, semantic, uncertainty, instances, depth)`` as NumPy arrays.
            ``im`` is the first three render channels scaled by 255 and cast
            to uint8, ``depth`` is the fourth channel. Outputs that were not
            requested are ``None``.
        """
        if isinstance(w2c, np.ndarray):
            # Follow the configured device instead of the default CUDA device,
            # so CPU runs and non-default GPUs work as well.
            w2c = torch.from_numpy(w2c).to(self.device)
        w2c = w2c.float()

        data = {
            "w2c": w2c.unsqueeze(0),
            "intrinsics": self.intrinsics.unsqueeze(0),
            "width": self.img_w,
            "height": self.img_h,
        }

        backgrounds_expanded = None
        if backgrounds is not None:
            # NOTE(review): built as a 3-channel uint8 tensor although it is
            # passed to an "RGB+D" render - confirm the renderer accepts this.
            backgrounds_expanded = torch.ones((1, 3), dtype=torch.uint8, device=self.device) * backgrounds

        renders = self.rendering_function(
            data,
            self.params["rgb_colors"].unsqueeze(0),
            render_mode="RGB+D",
            backgrounds=backgrounds_expanded,
        )
        im = (renders[:, :, :, :3].squeeze(0).detach().cpu().numpy() * 255).astype(np.uint8)
        depth = renders[:, :, :, 3].squeeze(0).detach().cpu().numpy()

        instances = None
        if return_instances:
            point_instances: torch.Tensor = self.params["instances"]
            # Prepend a 0 id so id 0 always gets a column; presumably pixels
            # that no point covers are meant to resolve to it - TODO confirm.
            padded = torch.cat([point_instances.new_zeros((1,)), point_instances], dim=0)
            instance_uniques, instance_inverses = torch.unique(padded, return_inverse=True)
            instance_inverses = instance_inverses[1:]  # drop the padding entry again
            # One-hot membership per point, rendered like a colour image.
            one_hot = point_instances.new_zeros((point_instances.shape[0], instance_uniques.shape[0]))
            one_hot.scatter_(1, instance_inverses[:, None], 1)
            instance_logits = self.rendering_function(
                data, one_hot.float().unsqueeze(0), render_mode="RGB"
            ).squeeze(0)
            # The strongest channel per pixel decides the instance id.
            instances = instance_uniques[torch.argmax(instance_logits, dim=-1)]
            instances = instances.detach().cpu().numpy()

        pool_by_instance = None
        semantic = None
        if return_semantics:
            pool_by_instance = self.config.instances.use_instances_for_score
            features = self.semantic_scaled
            if pool_by_instance:
                features = self._mean_per_instance(features)
            renders = self.rendering_function(data, features.unsqueeze(0), render_mode="RGB", backgrounds=None)
            semantic = renders.squeeze(0).detach().cpu().numpy()

        uncertainty = None
        if return_uncertainty:
            if pool_by_instance is None:
                pool_by_instance = self.config.instances.use_instances_for_score
            features = self.uncertainty
            if pool_by_instance:
                features = self._mean_per_instance(features)
            renders = self.rendering_function(data, features.unsqueeze(0), render_mode="RGB", backgrounds=None)
            uncertainty = renders.squeeze(0).detach().cpu().numpy()

        return im, semantic, uncertainty, instances, depth

    def _mean_per_instance(self, values: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``values`` where each point holds its instance's mean."""
        pooled = values.clone()
        instance_ids = self.params["instances"]
        for instance_id in torch.unique(instance_ids):
            members = instance_ids == instance_id
            # Average over the points only (dim 0) so every channel keeps its
            # own value; a mean over all elements would make all channels equal.
            pooled[members] = pooled[members].mean(dim=0)
        return pooled
    
    def get_eig_path(self, poses: List[torch.Tensor], instance: int, visualize: bool=False, obj_class: int = 2) -> Tuple[np.ndarray, List[np.ndarray]]:
        """Estimate how much a path of views would reduce an instance's entropy.

        For every pose a synthetic measurement that assigns ``obj_class`` to
        every pixel is passed to ``get_updated_concentration_params``; the
        resulting updates for the points of ``instance`` are accumulated on a
        copy of their concentration parameters. The stored map is not written
        by this method.

        Args:
            poses: Camera-to-world matrices; each one is inverted before use.
            instance: Instance id selecting the points to evaluate.
            visualize: If true, also render one annotated colour image per
                pose, labelled with the mean positive entropy reduction of
                that pose alone.
            obj_class: Channel set to 1 in the synthetic measurement.

        Returns:
            ``(gain, viz_imgs)``. ``gain`` is a 0-d NumPy array holding the
            mean of the strictly positive per-point entropy reductions over
            the whole path (0 if there are none). ``viz_imgs`` is the list of
            annotated images, empty unless ``visualize`` is set.
        """
        viz_imgs = []

        instance_mask = self.params["instances"] == instance

        semantic_meas = torch.zeros((self.img_h, self.img_w, self.n_semantic_channels), device=self.device)
        semantic_meas[:, :, obj_class] = 1

        # Snapshot of the instance's current concentrations (mask indexing
        # copies), used as the baseline for every entropy comparison.
        current = self.params["semantic_c"][instance_mask].detach()
        entropy_orig = self._dirichlet_entropy(current)
        semantic_cp = current.clone()

        for c2w in poses:
            w2c = torch.linalg.inv(c2w)

            data = {
                "w2c": w2c.unsqueeze(0),
                "intrinsics": self.intrinsics.unsqueeze(0),
                "width": self.img_w,
                "height": self.img_h,
            }

            update = self.get_updated_concentration_params(data, semantic_meas)[instance_mask]
            semantic_cp += update

            if visualize:
                # `update` is already restricted to the instance's points, so
                # it is added directly; masking it a second time would index
                # a shorter tensor with the full-length mask.
                uncertainty = entropy_orig - self._dirichlet_entropy(current + update)

                renders_s = self.rendering_function(data, self.params["rgb_colors"].unsqueeze(0), render_mode="RGB")

                img_u = torch.mean(uncertainty[uncertainty > 0]).detach().cpu().numpy()
                color = (255, 255, 255)
                img_np = (renders_s[0] * 255).detach().cpu().numpy().astype(np.uint8)
                viz_imgs.append(cv2.putText(img_np, f'Uncertainty: {img_u}', (10, 35), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2, cv2.LINE_AA))

        u_diff = entropy_orig - self._dirichlet_entropy(semantic_cp)

        # TODO: need to give coords of obj as input (easier with instances)
        return torch.nan_to_num(torch.mean(u_diff[u_diff > 0])).detach().cpu().numpy(), viz_imgs

    def _dirichlet_entropy(self, concentrations: torch.Tensor) -> torch.Tensor:
        """Return the per-row Dirichlet entropy of ``concentrations`` (NaN -> 0)."""
        alpha = concentrations.detach()
        total = torch.sum(alpha, dim=-1)
        # log of the multivariate Beta function: sum(lgamma(a_k)) - lgamma(sum a_k)
        log_beta = torch.sum(torch.special.gammaln(alpha), dim=-1) - torch.special.gammaln(total)
        weighted_digamma = torch.sum((alpha - 1) * torch.digamma(alpha), dim=-1)
        return torch.nan_to_num(
            log_beta + (total - self.n_semantic_channels) * torch.digamma(total) - weighted_digamma
        )

        
    def get_expected_score(self, poses: List[torch.Tensor], instance: int, obj_class: int = 1) -> Tuple[torch.Tensor,torch.Tensor]:
        """Predict an instance's score after observing ``obj_class`` along a path.

        A synthetic measurement with channel ``obj_class`` set to 1 everywhere
        is fed to ``get_updated_concentration_params`` for each pose, and the
        updates for the points of ``instance`` are summed onto a copy of their
        concentration parameters.

        Args:
            poses: Camera-to-world matrices; each one is inverted before use.
            instance: Instance id selecting the points to evaluate.
            obj_class: Class channel of the synthetic measurement and of the
                predicted score.

        Returns:
            ``(expected, current)``. ``expected`` is the mean over the
            instance's points of the predicted class probability minus
            ``config.objective.alpha_uncertainty`` times its spread.
            ``current`` is the mean of ``semantic_scaled - alpha * uncertainty``
            over every element of those stored tensors.
            NOTE(review): ``current`` is not restricted to ``instance`` or
            ``obj_class`` - confirm that this is the intended baseline.
        """
        in_instance = self.params["instances"] == instance

        # Hypothetical observation: every pixel votes for obj_class.
        hypothetical_obs = torch.zeros(
            (self.img_h, self.img_w, self.n_semantic_channels), device=self.device
        )
        hypothetical_obs[..., obj_class] = 1

        accumulated = self.params["semantic_c"][in_instance].detach().clone()

        for cam_to_world in poses:
            view = {
                "w2c": torch.linalg.inv(cam_to_world).unsqueeze(0),
                "intrinsics": self.intrinsics.unsqueeze(0),
                "width": self.img_w,
                "height": self.img_h,
            }
            accumulated += self.get_updated_concentration_params(view, hypothetical_obs)[in_instance]

        total = torch.sum(accumulated, dim=-1)
        expected_prob = torch.nan_to_num(accumulated[:, obj_class] / total)
        expected_spread = torch.sqrt(expected_prob * (1 - expected_prob) / (1 + total))

        alpha = self.config.objective.alpha_uncertainty
        predicted = torch.mean(expected_prob - alpha * expected_spread)
        baseline = torch.mean(self.semantic_scaled - alpha * self.uncertainty)
        return predicted, baseline


    @torch.no_grad()
    def save_model(self, episode_key: str):
        """Not implemented in this class; always raises ``NotImplementedError``."""
        raise NotImplementedError()
