# Adapted from SplaTAM: https://github.com/spla-tam/SplaTAM/
import os
import cv2
import numpy as np
import torch
import torch.nn.functional as F

from helios.agent.gaze.gaussian_representation.slam_utils import get_pointcloud, initialize_params, initialize_new_params, save_params_ckpt, keyframe_selection_overlap

from gsplat.rendering import rasterization, rasterization_2dgs
import nvtx
from torch_geometric.nn.pool import voxel_grid
from torch_geometric.nn.aggr import MinAggregation

class BackEnd:
    """Gaussian-splatting mapping back end (adapted from SplaTAM).

    Frames arrive with known poses (:meth:`first_frame` for the first one,
    :meth:`step` afterwards).  Each frame may add new Gaussians where the
    current map does not explain the observed depth, may be stored as a
    keyframe, and triggers ``config["mapping"]["num_iters"]`` Adam iterations
    (:meth:`refine`) over a window of keyframes.  Per-Gaussian semantic
    concentration parameters (``params["semantic_c"]``) are accumulated from
    per-pixel semantic measurements, and optional instance ids
    (``params["instances"]``) are maintained.

    NOTE(review): there is no ``__init__`` here.  The methods read
    ``config``, ``time_idx``, ``gt_w2c_all_frames``, ``keyframe_list``,
    ``keyframe_time_indices``, ``num_frames``, ``n_semantic_channels``,
    ``use_instances``, ``device`` and ``output_dir`` from ``self``; these are
    presumably initialised by the owner of the object -- confirm before reuse.

    Instance ids encode their semantic class: the k-th instance (k >= 1) of
    class ``c`` gets id ``c * INSTANCE_ID_STRIDE + k``; id 0 means "none".
    """

    # Spacing between the instance-id ranges of consecutive semantic classes.
    INSTANCE_ID_STRIDE = 10000

    def step(self, dataset, not_overlapping):
        """Integrate one posed RGB-D frame into the map and optimise it.

        Args:
            dataset: 7-tuple ``(color, depth, semantic_c, intrinsics, pose,
                where_update, instances)``.  ``color``, ``depth`` and
                ``semantic_c`` are permuted from (H, W, C) to (C, H, W) here.
                ``pose`` is inverted to obtain the world-to-camera matrix,
                which is used as-is (no pose tracking happens here).  The
                ``intrinsics`` entry is ignored (``self.intrinsics`` from
                :meth:`first_frame` is used); ``instances`` is read later by
                :meth:`refine`.
            not_overlapping: score of how much of the frame is new to the map.
                New Gaussians are added when it is >= 200; a keyframe is
                created when it is >= 4000 (>= 1000 during the first 10
                frames, and always at frame 1).  Presumably a pixel count --
                TODO confirm with the caller.
        """
        config = self.config
        self.c_dataset = dataset  # re-read by refine() on every iteration
        self.c_iter = 0

        color, depth, semantic_c, _, pose, where_update, _ = dataset
        gt_w2c = torch.linalg.inv(pose)
        color = color.permute(2, 0, 1) / 255  # presumably 8-bit colour -> [0, 1]
        depth = depth.permute(2, 0, 1)
        semantic_c = semantic_c.permute(2, 0, 1)

        self.gt_w2c_all_frames.append(gt_w2c)
        time_idx = self.time_idx

        curr_data = {
            "im": color,
            "depth": depth,
            "semantic_c": semantic_c,
            "id": time_idx,
            "intrinsics": self.intrinsics,
            "w2c": gt_w2c,
            "iter_gt_w2c_list": self.gt_w2c_all_frames,
            "where_update": where_update,
        }

        if time_idx > 0 and not_overlapping >= 200:
            # Densify where the current map does not explain the observed depth.
            self.add_new_gaussians(
                curr_data,
                config["mapping"]["sil_thres"],
                time_idx,
                config["mean_sq_dist_method"],
                config["gaussian_distribution"],
                self.n_semantic_channels,
                gt_w2c,
            )

        with torch.no_grad():
            # Mapping window: keyframes picked by keyframe_selection_overlap
            # (the newest keyframe excluded), then the newest keyframe, then
            # the current frame, which is encoded as -1 (see refine()).
            num_keyframes = config["mapping_window_size"] - 2
            selected_keyframes = keyframe_selection_overlap(
                depth, gt_w2c, self.intrinsics, self.keyframe_list[:-1], num_keyframes
            )
            if self.keyframe_list:
                selected_keyframes.append(len(self.keyframe_list) - 1)
            selected_keyframes.append(-1)
        self.selected_keyframes = selected_keyframes

        # Fresh optimizer so that it covers any Gaussians added above.
        self.optimizer = self.initialize_optimizer(config["mapping"]["lrs"])

        wants_keyframe = (
            time_idx == 1
            or (time_idx < 10 and not_overlapping >= 1000)
            or not_overlapping >= 4000
        )
        # Never store a keyframe with a NaN/inf pose.
        if wants_keyframe and torch.isfinite(gt_w2c).all():
            self.keyframe_list.append(
                {
                    "id": time_idx,
                    "est_w2c": gt_w2c,
                    "color": color,
                    "depth": depth,
                    "where_update": where_update,
                }
            )
            self.keyframe_time_indices.append(time_idx)

        while self.c_iter < config["mapping"]["num_iters"]:
            self.refine()

        self.time_idx += 1

    @torch.no_grad()
    def save_model(self, episode_key: str):
        """Write a checkpoint of the map to ``<output_dir>/<episode_key>``.

        Nothing is written before the first frame has been processed.  In the
        saved copy, ``semantic_c`` is normalised to per-Gaussian class
        probabilities ``p``, and an extra ``"Uncertainty"`` entry
        ``sqrt(p * (1 - p) / (1 + sum(semantic_c)))`` is added.  That is the
        marginal standard deviation if ``semantic_c`` holds Dirichlet
        concentration parameters.  ``self.params`` itself is not modified.
        """
        if not self.gt_w2c_all_frames:
            return
        params = self.params.copy()  # shallow copy: only two entries are replaced
        probs = self._semantic_probabilities()
        total = torch.sum(self.params["semantic_c"], dim=-1, keepdim=True)
        params["semantic_c"] = probs
        params["Uncertainty"] = torch.sqrt(probs * (1 - probs) / (1 + total))
        save_params_ckpt(params, os.path.join(self.output_dir, episode_key), self.time_idx)

    def first_frame(self, dataset):
        """Initialise the map from the first frame, then run :meth:`step` on it.

        ``dataset`` has the same 7-tuple layout as in :meth:`step`.  Here the
        ``intrinsics`` entry is used; only its top-left 3x3 block is kept.
        Gaussians are created for every pixel with positive depth where
        ``where_update`` is set.
        """
        color, depth, _, intrinsics, pose, where_update, _ = dataset

        color = color.permute(2, 0, 1) / 255  # (H, W, C) -> (C, H, W)
        depth = depth.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)

        intrinsics = intrinsics[:3, :3]
        w2c = torch.linalg.inv(pose)

        self.img_w = color.shape[2]
        self.img_h = color.shape[1]

        # Skip invalid (non-positive) depth and pixels not flagged for update.
        mask = ((depth > 0) * where_update).reshape(-1)
        init_pt_cld, mean3_sq_dist = get_pointcloud(
            color,
            depth,
            intrinsics,
            w2c,
            mask=mask,
            compute_mean_sq_dist=True,
            mean_sq_dist_method=self.config["mean_sq_dist_method"],
        )

        self.params, self.variables = initialize_params(
            init_pt_cld,
            self.num_frames,
            mean3_sq_dist,
            self.config["gaussian_distribution"],
            self.n_semantic_channels,
            self.config["concentration_param_scaling_factor"],
        )
        if self.use_instances:
            self.params["instances"] = init_pt_cld.new_zeros(init_pt_cld.shape[0], dtype=torch.int32)

        # Scene-radius estimate used to prune oversized Gaussians.
        self.variables["scene_radius"] = torch.max(depth) / self.config["scene_radius_depth_ratio"]

        self.intrinsics = intrinsics
        self.first_frame_w2c = w2c

        # Instances allocated so far, per semantic class.  It must have one
        # slot for every class index (update_instances_3d indexes all classes
        # < n_semantic_channels); 4 is kept as the historical minimum size.
        self.n_semantic_instances = torch.zeros(
            max(4, self.n_semantic_channels), dtype=torch.int32, device="cuda"
        )

        self.step(dataset, not_overlapping=0)

    def add_new_gaussians(
        self,
        curr_data,
        sil_thres,
        time_idx,
        mean_sq_dist_method,
        gaussian_distribution,
        n_semantic_channels,
        curr_w2c,
    ):
        """Add Gaussians for pixels that the current map does not explain.

        A pixel is selected when it has valid (positive) GT depth, is flagged
        in ``curr_data["where_update"]``, and either:

        * its rendered depth differs from the GT depth by more than 1, or
        * nothing is rendered there, after a 5x5 erosion that drops isolated
          pixels and thin slivers.

        The new Gaussians are appended to ``self.params``.  The optimizer is
        NOT updated here; :meth:`step` creates a fresh one afterwards.

        ``sil_thres`` is accepted for interface compatibility but unused.
        """
        with torch.no_grad():
            data = {
                "w2c": curr_data["w2c"].unsqueeze(0),
                "intrinsics": self.intrinsics.unsqueeze(0),
                "width": self.img_w,
                "height": self.img_h,
            }
            renders = self.rendering_function(
                data, self.params["rgb_colors"].unsqueeze(0), render_mode="RGB+D"
            )
            render_depth = renders[:, :, :, 3].squeeze(0)

        gt_depth = curr_data["depth"][0, :, :]
        valid_depth_mask = gt_depth > 0
        depth_error = torch.abs(gt_depth - render_depth) * valid_depth_mask

        unrendered = ((depth_error > 1e-3) & (render_depth == 0)).cpu().numpy().astype(np.uint8)
        kernel = np.ones((5, 5), np.uint8)
        # Bool (not uint8) so the final mask is a proper boolean index mask.
        unrendered_eroded = torch.as_tensor(
            cv2.erode(unrendered, kernel, iterations=1), device=depth_error.device
        ).bool()
        non_presence_mask = ((depth_error > 1) | unrendered_eroded) * curr_data["where_update"]
        non_presence_mask = non_presence_mask.reshape(-1)

        if torch.sum(non_presence_mask) <= 0:
            return

        non_presence_mask = non_presence_mask & valid_depth_mask.reshape(-1)
        new_pt_cld, mean3_sq_dist = get_pointcloud(
            curr_data["im"],
            curr_data["depth"],
            curr_data["intrinsics"],
            curr_w2c,
            mask=non_presence_mask,
            compute_mean_sq_dist=True,
            mean_sq_dist_method=mean_sq_dist_method,
        )

        new_params = initialize_new_params(
            new_pt_cld,
            mean3_sq_dist,
            gaussian_distribution,
            n_semantic_channels,
            self.config["concentration_param_scaling_factor"],
        )
        if self.use_instances:
            new_params["instances"] = new_pt_cld.new_zeros(new_pt_cld.shape[0], dtype=torch.int32)

        for key, value in new_params.items():
            merged = torch.cat((self.params[key], value), dim=0)
            # 'instances' and 'semantic_c' are plain tensors, not optimised.
            if key not in ("instances", "semantic_c"):
                merged = torch.nn.Parameter(merged.requires_grad_(True))
            self.params[key] = merged

        num_pts = self.params["means3D"].shape[0]
        self.variables["means2D_gradient_accum"] = torch.zeros(num_pts, device="cuda", dtype=torch.float32)
        self.variables["denom"] = torch.zeros(num_pts, device="cuda", dtype=torch.float32)
        self.variables["max_2D_radius"] = torch.zeros(num_pts, device="cuda", dtype=torch.float32)
        new_timestep = torch.full(
            (new_pt_cld.shape[0],), float(time_idx), device="cuda", dtype=torch.float32
        )
        self.variables["timestep"] = torch.cat((self.variables["timestep"], new_timestep), dim=0)

    @torch.enable_grad()
    def get_updated_concentration_params(self, data, semantic_meas):
        """Return, per Gaussian and channel, the render-weighted sum of ``semantic_meas``.

        The rasterised colour is linear in the per-Gaussian colours, so the
        gradient of ``sum(semantic_meas * render(ones))`` with respect to the
        all-ones colour tensor has a simple meaning.  For each Gaussian and
        channel, it is the measurement summed over the pixels that Gaussian
        contributes to, weighted by its blending weight.

        Grad mode is forced on so this also works when called under
        ``torch.no_grad()``.  Only the dummy colour tensor receives gradients
        (``backward(inputs=...)``); ``self.params`` grads are untouched.

        Args:
            data: single-camera render dict (see :meth:`rendering_function`).
            semantic_meas: per-pixel measurements, channels last (presumably
                (H, W, C) -- TODO confirm), broadcastable against the render.

        Returns:
            Tensor of shape (N, C), with size-1 dims squeezed (so (N,) when C == 1).
        """
        dummy_var = torch.ones(
            (self.params["means3D"].shape[0], semantic_meas.shape[-1]),
            requires_grad=True,
            device=self.device,
            dtype=self.params["semantic_c"].dtype,
        )
        renders_s = self.rendering_function(data, dummy_var.unsqueeze(0), render_mode="RGB")
        weighted_total = torch.sum(semantic_meas * renders_s[0])
        weighted_total.backward(inputs=dummy_var)
        return dummy_var.grad.squeeze()

    def refine(self):
        """Run one mapping iteration (``self.c_iter``) for the frame stored by :meth:`step`.

        The first two iterations always use the current frame; later ones use
        a random entry of ``self.selected_keyframes`` (-1 = current frame).
        The RGB + depth loss is minimised with one Adam step.  On iteration 0,
        the current frame's semantic (and instance) measurements are also
        fused into the map.  Pruning is applied if configured.
        """
        color, depth, semantic_c, _, _, where_update, instances = self.c_dataset
        color = color.permute(2, 0, 1) / 255
        depth = depth.permute(2, 0, 1)
        semantic_c = semantic_c.permute(2, 0, 1)
        optimizer = self.optimizer
        config = self.config

        # Always draw, so the RNG stream is the same regardless of c_iter.
        rand_idx = np.random.randint(0, len(self.selected_keyframes))
        keyframe_idx = self.selected_keyframes[rand_idx]
        if self.c_iter < 2:
            keyframe_idx = -1

        if keyframe_idx == -1:
            iter_time_idx = self.time_idx
            iter_color, iter_depth, iter_where_update = color, depth, where_update
            w2c = self.gt_w2c_all_frames[-1]
        else:
            keyframe = self.keyframe_list[keyframe_idx]
            iter_time_idx = keyframe["id"]
            iter_color = keyframe["color"]
            iter_depth = keyframe["depth"]
            iter_where_update = keyframe["where_update"]
            w2c = keyframe["est_w2c"]

        iter_data = {
            "im": iter_color,
            "depth": iter_depth,
            "id": iter_time_idx,
            "intrinsics": self.intrinsics,
            "w2c": w2c,
            "iter_gt_w2c_list": self.gt_w2c_all_frames[: iter_time_idx + 1],
            "where_update": iter_where_update,
        }

        loss, _ = self.get_loss(iter_data, config["mapping"]["loss_weights"])
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        # This is the only step needed per iteration: the semantic/instance
        # updates below only backprop into a dummy tensor, so they leave no
        # parameter gradients for the optimizer to apply.

        if self.c_iter == 0:
            # iter_data is the current frame here (forced for c_iter < 2).
            self._accumulate_frame_semantics(iter_data, semantic_c, instances)

        if config["mapping"]["prune_gaussians"]:
            with torch.no_grad():
                self.prune_gaussians(optimizer, config["mapping"]["pruning_dict"])

        self.c_iter += 1

    def _accumulate_frame_semantics(self, iter_data, semantic_c, instances):
        """Fuse one frame's semantic measurements (and instances) into the map.

        ``semantic_c`` is the channel-first semantic image of the frame
        described by ``iter_data``.  It is masked by ``where_update``, clamped
        to [0, 1], splatted onto the Gaussians, and added to
        ``params["semantic_c"]`` with the configured scaling factor.
        """
        data = self._frame_render_inputs(iter_data)
        mask = iter_data["where_update"].detach()  # pixels where an object was detected
        semantic_meas = semantic_c.permute(1, 2, 0)  # back to channels-last
        semantic_meas = semantic_meas * torch.tile(mask.unsqueeze(-1), (1, 1, self.n_semantic_channels))
        semantic_meas = torch.clamp(semantic_meas, 0, 1)
        update = self.get_updated_concentration_params(data, semantic_meas)
        with torch.no_grad():
            self.params["semantic_c"] += self.config["concentration_param_scaling_factor"] * update

        if self.use_instances:
            if self.config.instances.use_rendering_for_instances_update:
                iter_data["semantic_c"] = semantic_c
                iter_data["instances"] = instances
                self.update_instances_rendering(iter_data)
            else:
                self.update_instances_3d(iter_data)

    def _allocate_instance_id(self, semantic_class):
        """Reserve the next instance id of ``semantic_class`` and return it (0-d tensor)."""
        self.n_semantic_instances[semantic_class] += 1
        return self.n_semantic_instances[semantic_class] + semantic_class * self.INSTANCE_ID_STRIDE

    def update_instances_rendering(self, iter_data):
        """Assign instance ids from 2D instance detections.

        Each entry of ``iter_data["instances"]`` is indexed as ``det[0]``
        (semantic class) and ``det[1]`` (presumably an (H, W) mask -- TODO
        confirm).  For each detection, the Gaussians with non-zero render
        weight inside the mask are found.  Then:

        * If enough of them already carry an instance id (more than 500, more
          than 50 for class 1, or more than 80% of them), they are merged
          into the first existing instance that qualifies.  An instance
          qualifies if most of its Gaussians are of this class and it
          overlaps the detection by the same criteria.  All other qualifying
          instances are merged into it as well.  If no instance qualifies, a
          new one is created.
        * Otherwise, if the detection covers more than 200 Gaussians (more
          than 50 for class 1), a new instance is created.
        * Smaller detections are ignored.
        """
        stride = self.INSTANCE_ID_STRIDE
        data = self._frame_render_inputs(iter_data)
        semantic_classes = torch.argmax(self.params["semantic_c"], dim=-1)

        for detection in iter_data["instances"]:
            sem_class = detection[0]
            gaussians_update = (
                self.get_updated_concentration_params(data, detection[1].float().unsqueeze(-1)) > 0
            )
            with torch.no_grad():
                inst = self.params["instances"]  # modified in place below
                n_gu = torch.sum(gaussians_update)
                has_instance = inst > stride
                n_overlap = torch.sum(has_instance & gaussians_update)

                if n_overlap > 500 or (sem_class == 1 and n_overlap > 50) or n_overlap > 0.8 * n_gu:
                    candidates = []
                    for inst_id in torch.unique(inst[has_instance & gaussians_update]):
                        members = inst == inst_id
                        n_members = torch.sum(members)
                        n_same_class = torch.sum(semantic_classes[members] == sem_class)
                        n_inst_updated = torch.sum(members & gaussians_update)
                        if n_same_class > 0.5 * n_members and (
                            n_inst_updated > 500
                            or (sem_class == 1 and n_inst_updated > 50)
                            or n_inst_updated > 0.8 * n_gu
                        ):
                            candidates.append(inst_id)

                    if candidates:
                        target = candidates[0]
                        inst[gaussians_update] = target
                        for other in candidates[1:]:
                            inst[inst == other] = target
                    else:
                        inst[gaussians_update] = self._allocate_instance_id(sem_class)
                elif n_gu > 200 or (sem_class == 1 and n_gu > 50):
                    inst[gaussians_update] = self._allocate_instance_id(sem_class)

    @nvtx.annotate("[Backend.refine()] instance")
    def update_instances_3d(self, iter_data):
        """Re-derive instance ids by spatially clustering Gaussians per class.

        The affected set is made of two parts: the Gaussians visible
        (non-zero render weight) in this view, plus every Gaussian of any
        class >= 1 instance that is partly visible.  For each semantic
        class >= 1 (by argmax of ``semantic_c``), the affected Gaussians of
        that class are clustered spatially.  Each cluster inherits the
        existing instance it overlaps most (greedy one-to-one matching);
        unmatched clusters get new ids.
        """
        stride = self.INSTANCE_ID_STRIDE
        data = self._frame_render_inputs(iter_data)
        ones = torch.ones(iter_data["depth"].shape, device=self.device).unsqueeze(-1)
        gaussians_update = self.get_updated_concentration_params(data, ones) > 0

        with torch.no_grad():
            inst = self.params["instances"]
            affected_instances = torch.unique(inst[(inst > stride) & gaussians_update])
            affected_mask = gaussians_update | torch.isin(inst, affected_instances)

            min_aggregation = MinAggregation()
            semantic_classes = torch.argmax(self.params["semantic_c"], dim=-1)
            for semantic_class in range(1, self.n_semantic_channels):
                semantic_mask = (semantic_classes == semantic_class) & affected_mask
                if not torch.any(semantic_mask):
                    continue

                point_instances = self.params["instances"][semantic_mask]
                # Ids belonging to another class count as unassigned.
                point_instances[point_instances // stride != semantic_class] = 0

                n_clusters, cluster_inverses = self._voxel_connected_clusters(
                    self.params["means3D"][semantic_mask], min_aggregation
                )
                cluster_instances = self._match_clusters_to_instances(
                    point_instances, n_clusters, cluster_inverses
                )

                # Ids start at n + 1, consistent with _allocate_instance_id.
                is_new = cluster_instances == 0
                n_new = torch.sum(is_new)
                cluster_instances[is_new] = (
                    self.n_semantic_instances[semantic_class]
                    + 1
                    + torch.arange(n_new, device=is_new.device, dtype=torch.int32)
                    + semantic_class * stride
                )
                self.n_semantic_instances[semantic_class] += n_new

                self.params["instances"][semantic_mask] = cluster_instances[cluster_inverses]

    @staticmethod
    def _voxel_connected_clusters(points, min_aggregation, voxel_size=0.5, max_obj_size=10):
        """Label ``points`` (M, 3) by spatial proximity.

        Every point starts with its own label.  Each round replaces every
        label with the minimum label in its voxel, first on a grid anchored
        at the points' minimum corner, then on one shifted by half a voxel.
        There are ``int(max_obj_size / voxel_size)`` rounds.

        Returns:
            ``(n_clusters, inverse)``, where ``inverse[i]`` is the 0-based
            cluster index of point ``i``.
        """
        labels = torch.arange(points.shape[0], device=points.device, dtype=torch.long)
        start = torch.min(points, dim=0).values
        _, index = torch.unique(voxel_grid(points, voxel_size, start=start), return_inverse=True)
        _, shifted_index = torch.unique(
            voxel_grid(points, voxel_size, start=start - voxel_size / 2), return_inverse=True
        )
        for _ in range(int(max_obj_size / voxel_size)):
            labels = min_aggregation(labels, index, dim=0)[index]
            labels = min_aggregation(labels, shifted_index, dim=0)[shifted_index]
        uniques, inverse = torch.unique(labels, return_inverse=True)
        return len(uniques), inverse

    @staticmethod
    def _match_clusters_to_instances(point_instances, n_clusters, cluster_inverses):
        """Give each cluster the existing instance id it shares the most points with.

        The matching is greedy.  It repeatedly takes the (instance, cluster)
        pair with the highest point count, assigns it, and removes that row
        and column, so each instance and each cluster is used at most once.
        Id 0 ("none") is never assigned.  Assumes ids are non-negative.

        Returns:
            One id per cluster (dtype of ``point_instances``); 0 marks
            clusters that matched no existing instance.
        """
        # Prepend a 0 so id 0 is always the first unique row and can be dropped.
        with_zero = torch.cat([point_instances.new_zeros(1), point_instances], dim=0)
        instance_uniques, instance_inverses = torch.unique(with_zero, return_inverse=True)
        instance_inverses = instance_inverses[1:]
        n_instances = len(instance_uniques)

        # Co-occurrence counts.  bincount accumulates repeated pairs; the
        # former `counts[i, j] += 1` only recorded presence.
        counts = torch.bincount(
            instance_inverses * n_clusters + cluster_inverses,
            minlength=n_instances * n_clusters,
        ).view(n_instances, n_clusters)[1:]

        cluster_instances = point_instances.new_zeros(n_clusters)
        while torch.any(counts):
            flat_idx = torch.argmax(counts.flatten())
            row, col = flat_idx // n_clusters, flat_idx % n_clusters
            counts[row] = 0
            counts[:, col] = 0
            cluster_instances[col] = instance_uniques[row + 1]  # +1 skips the 0 row
        return cluster_instances

    def _semantic_probabilities(self):
        """Return ``params["semantic_c"]`` normalised per Gaussian (NaN from 0/0 mapped to 0)."""
        concentrations = self.params["semantic_c"]
        return torch.nan_to_num(concentrations / torch.sum(concentrations, dim=-1, keepdim=True))

    def prune_gaussians_sem(self, optimizer, prune_dict):
        """Remove Gaussians whose probability for every class >= 1 is below a threshold.

        The threshold is ``final_removal_semantic_threshold`` on iteration
        ``stop_after``, and ``removal_semantic_threshold`` otherwise.
        """
        if self.c_iter == prune_dict["stop_after"]:
            remove_threshold = prune_dict["final_removal_semantic_threshold"]
        else:
            remove_threshold = prune_dict["removal_semantic_threshold"]
        probs = self._semantic_probabilities()
        to_remove = torch.all(probs[:, 1:] < remove_threshold, dim=1)
        if torch.any(to_remove):
            # remove_points takes a KEEP mask and updates self.params in place.
            self.remove_points(~to_remove, optimizer)

    def initialize_optimizer(self, lrs_dict):
        """Create an Adam optimizer with one named group per param listed in ``lrs_dict``.

        Groups are named after their key in ``self.params`` (remove_points
        looks them up by name).  Params without a learning rate, such as
        ``'instances'``, are not optimised.
        """
        param_groups = [
            {"params": [value], "name": key, "lr": lrs_dict[key]}
            for key, value in self.params.items()
            if key in lrs_dict
        ]
        return torch.optim.Adam(param_groups)

    @torch.enable_grad()
    def get_loss(self, curr_data, loss_weights):
        """Compute the weighted L1 depth and colour losses for one frame.

        Both losses are averaged over pixels where ``where_update`` is set and
        the GT depth is below 100 (presumably to drop far/invalid readings --
        TODO confirm units).

        Returns:
            ``(total_loss, weighted_losses)``, where ``weighted_losses`` maps
            ``"depth"`` and ``"im"`` to ``loss * loss_weights[key]``.
        """
        data = self._frame_render_inputs(curr_data)
        renders = self.rendering_function(
            data, self.params["rgb_colors"].unsqueeze(0), render_mode="RGB+D"
        )
        im = renders[:, :, :, :3].squeeze(0)
        depth = renders[:, :, :, 3]

        mask = (curr_data["where_update"] * (curr_data["depth"] < 100)).detach()

        losses = {}
        losses["depth"] = torch.abs(curr_data["depth"].squeeze() - depth.squeeze())[mask.squeeze()].mean()
        color_mask = torch.tile(mask, (3, 1, 1)).detach()
        losses["im"] = torch.abs(curr_data["im"].permute((1, 2, 0)) - im)[color_mask.permute((1, 2, 0))].mean()

        weighted_losses = {k: v * loss_weights[k] for k, v in losses.items()}
        loss = sum(weighted_losses.values())
        return loss, weighted_losses

    @staticmethod
    def _frame_render_inputs(frame):
        """Build the single-camera ``data`` dict for :meth:`rendering_function`.

        ``frame`` needs ``"w2c"`` (4x4), ``"intrinsics"`` (3x3) and
        ``"depth"``.  The image size is read from ``depth.shape[1]`` (height)
        and ``depth.shape[2]`` (width), i.e. a channel-first depth map.
        """
        return {
            "w2c": frame["w2c"].unsqueeze(0),
            "intrinsics": frame["intrinsics"].unsqueeze(0),
            "width": frame["depth"].shape[2],
            "height": frame["depth"].shape[1],
        }

    def rendering_function(self, data, colors, render_mode, backgrounds=None):
        """Rasterise the current Gaussians with gsplat and return the rendered images.

        ``config.use_2dgs`` chooses ``rasterization_2dgs`` or (3D) ``rasterization``.
        ``colors`` may have any number of channels; ``data`` is built by
        ``_frame_render_inputs`` or an equivalent dict.
        """
        common_kwargs = dict(
            means=self.params["means3D"],  # [N, 3]
            quats=self.params["unnorm_rotations"],  # [N, 4]
            scales=torch.exp(self.params["log_scales"]),  # [N, 3]
            opacities=torch.sigmoid(self.params["logit_opacities"].squeeze()),  # [N]
            colors=colors,  # [(C,) N, D] or [(C,) N, K, 3]
            viewmats=data["w2c"],  # [C, 4, 4]
            Ks=data["intrinsics"],  # [C, 3, 3]
            width=data["width"],
            height=data["height"],
            near_plane=0.01,
            far_plane=1e10,
            radius_clip=0.0,
            eps2d=0.3,
            sh_degree=None,
            packed=False,
            tile_size=16,
            backgrounds=backgrounds,
            render_mode=render_mode,
            sparse_grad=False,
            absgrad=False,
        )
        if self.config.use_2dgs:
            renders, _, _, _, _, _, _ = rasterization_2dgs(**common_kwargs)
        else:
            renders, _, _ = rasterization(
                **common_kwargs,
                rasterize_mode="classic",  # ["classic", "antialiased"]
                covars=None,
            )
        return renders

    def prune_gaussians(self, optimizer, prune_dict):
        """Periodically remove transparent or oversized Gaussians.

        Runs when ``start_after <= c_iter <= stop_after`` and ``c_iter`` is a
        multiple of ``prune_every``.  A Gaussian is removed if its opacity is
        below the (final, on ``stop_after``) opacity threshold.  From
        ``remove_big_after`` on, it is also removed if its largest scale
        exceeds 0.1 * scene_radius.
        """
        c_iter = self.c_iter
        if c_iter > prune_dict["stop_after"]:
            return
        if c_iter < prune_dict["start_after"] or c_iter % prune_dict["prune_every"] != 0:
            return

        if c_iter == prune_dict["stop_after"]:
            remove_threshold = prune_dict["final_removal_opacity_threshold"]
        else:
            remove_threshold = prune_dict["removal_opacity_threshold"]

        # reshape(-1), not squeeze(): keeps a 1-D mask even for a single Gaussian.
        to_remove = (torch.sigmoid(self.params["logit_opacities"]) < remove_threshold).reshape(-1)
        if c_iter >= prune_dict["remove_big_after"]:
            big_points_ws = (
                torch.exp(self.params["log_scales"]).max(dim=1).values
                > 0.1 * self.variables["scene_radius"]
            )
            to_remove = torch.logical_or(to_remove, big_points_ws)

        if torch.any(to_remove):
            # remove_points also filters the non-optimised 'instances'/'semantic_c'.
            self.remove_points(~to_remove, optimizer)
        torch.cuda.empty_cache()

    def remove_points(self, to_keep, optimizer):
        """Keep only the Gaussians selected by the boolean mask ``to_keep`` (in place).

        ``semantic_c`` and ``instances`` are plain tensors and are simply
        sliced.  Every other per-Gaussian param (camera params excluded) is
        re-wrapped as a new ``nn.Parameter``; its optimizer group and Adam
        moments (``exp_avg``/``exp_avg_sq``) are sliced to match.  The
        per-Gaussian bookkeeping in ``self.variables`` is sliced too.
        Returns None.
        """
        for key in ("semantic_c", "instances"):
            if key in self.params:
                self.params[key] = self.params[key][to_keep]

        skipped = {"cam_unnorm_rots", "cam_trans", "semantic_c", "instances"}
        for key in [k for k in self.params if k not in skipped]:
            group = next((g for g in optimizer.param_groups if g["name"] == key), None)
            if group is None:
                # No optimizer group (no learning rate configured): slice directly.
                old = self.params[key]
                sliced = old[to_keep]
                if isinstance(old, torch.nn.Parameter):
                    sliced = torch.nn.Parameter(sliced.requires_grad_(True))
                self.params[key] = sliced
                continue

            old_param = group["params"][0]
            new_param = torch.nn.Parameter(old_param[to_keep].requires_grad_(True))
            stored_state = optimizer.state.pop(old_param, None)
            if stored_state is not None:
                stored_state["exp_avg"] = stored_state["exp_avg"][to_keep]
                stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][to_keep]
                optimizer.state[new_param] = stored_state
            group["params"][0] = new_param
            self.params[key] = new_param

        for name in ("means2D_gradient_accum", "denom", "max_2D_radius"):
            self.variables[name] = self.variables[name][to_keep]
        if "timestep" in self.variables:
            self.variables["timestep"] = self.variables["timestep"][to_keep]

