#

#
from pytorch3d.transforms import quaternion_apply, quaternion_multiply
import torch
import math
from util.utils import get_rotation_matrix
from depth_diff_gaussian_rasterization_min_features import GaussianRasterizationSettings, GaussianRasterizer
# from depth_diff_gaussian_rasterization_min_features import GaussianRasterizationSettings, GaussianRasterizer
from scene.gaussian_model import GaussianModel
from utils.sh import eval_sh
import torch.nn.functional as F
from utils.general import build_rotation, rotation2normal
# from gsplat.rendering import rasterization, rasterization_2dgs

def render(viewpoint_camera, pc: GaussianModel, opt, bg_color: torch.Tensor, scaling_modifier=1.0, override_color=None, render_visible=False, render_seg=False, exclude_sky=False, exclude_fg=False, render_current=False, fg_only=False, remove_list=None):
    """
    Render the Gaussians stored in ``pc`` as seen from ``viewpoint_camera``.

    Background tensor (bg_color) must be on GPU!

    Args:
        viewpoint_camera: Camera object; ``FoVx``, ``FoVy``, ``image_height``,
            ``image_width``, ``world_view_transform``, ``full_proj_transform``
            and ``camera_center`` are read from it.
        pc: Gaussian model. The ``*_all`` accessors (positions, opacities,
            segmentation features, scalings, rotations, features) and the
            ``delete_mask_all`` / ``visibility_filter_all`` / ``is_sky_filter``
            / ``is_fg_filter`` masks are read from it.
        opt: Options object; ``debug``, ``compute_cov3D_python`` and
            ``convert_SHs_python`` are read from it.
        bg_color: Background colour handed to the rasterizer as ``bg``.
        scaling_modifier: Forwarded to the rasterizer as ``scale_modifier`` and
            to ``pc.get_covariance_all`` when covariances are precomputed.
        override_color: Optional per-Gaussian colours used instead of the
            model's features. It is indexed with the same mask as the
            positions, so it needs one row per entry of ``pc.get_xyz_all``.
        render_visible: Additionally require ``pc.visibility_filter_all``.
        render_seg: Enable the rasterizer's feature output and feed it the
            segmentation features.
        exclude_sky: Drop Gaussians flagged by ``pc.is_sky_filter``.
        exclude_fg: Drop Gaussians flagged by ``pc.is_fg_filter``.
        render_current: Keep only the first ``pc.get_xyz.shape[0]`` Gaussians.
        fg_only: Keep only Gaussians flagged by ``pc.is_fg_filter``.
        remove_list: Optional iterable of segmentation labels (argmax over the
            last axis of the segmentation features) whose Gaussians are dropped.

    Returns:
        dict with keys ``render``, ``render_seg`` (``None`` unless
        ``render_seg``), ``viewspace_points`` (the unfiltered screen-space
        tensor created here), ``visibility_filter`` (``radii > 0``), ``radii``,
        ``final_opacity``, ``depth`` and ``median_depth``.
    """
    means3D = pc.get_xyz_all

    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # retain_grad() refuses tensors that do not require grad, which is the
        # case when this function runs with autograd disabled.
        pass

    # Set up rasterization configuration
    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=math.tan(viewpoint_camera.FoVx * 0.5),
        tanfovy=math.tan(viewpoint_camera.FoVy * 0.5),
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=opt.debug,
        include_feature=render_seg,
    )
    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    means2D = screenspace_points
    opacity = pc.get_opacity_all
    segs = pc.get_seg_all

    # If precomputed 3d covariance is requested, use it. If not, then it will be
    # computed from scaling / rotation by the rasterizer.
    scales = None
    rotations = None
    cov3D_precomp = None
    if opt.compute_cov3D_python:
        cov3D_precomp = pc.get_covariance_all(scaling_modifier)
    else:
        scales = pc.get_scaling_all
        rotations = pc.get_rotation_all

    # Colour source: explicit override, colours activated in Python, or raw
    # features left for the rasterizer to convert.
    shs = None
    colors_precomp = None
    if override_color is not None:
        colors_precomp = override_color
    elif opt.convert_SHs_python:
        shs_view = pc.get_features_all.transpose(1, 2).view(-1, 3)
        colors_precomp = pc.color_activation(shs_view)
    else:
        shs = pc.get_features_all

    # Build the boolean mask of Gaussians that take part in this render.
    visibility_filter_all = ~pc.delete_mask_all
    if render_visible:
        visibility_filter_all = pc.visibility_filter_all & visibility_filter_all  # Seen in screen
    if exclude_sky:
        visibility_filter_all = visibility_filter_all & ~pc.is_sky_filter
    if fg_only:
        visibility_filter_all = visibility_filter_all & pc.is_fg_filter
    if exclude_fg:
        visibility_filter_all = visibility_filter_all & ~pc.is_fg_filter
    if render_current:
        current_mask = torch.zeros_like(means3D[:, 0]).bool()
        current_mask[:pc.get_xyz.shape[0]] = True
        visibility_filter_all = visibility_filter_all & current_mask
    if remove_list is not None and len(remove_list):
        label = segs.argmax(-1)
        remove_mask = torch.zeros_like(means3D[:, 0]).bool()
        for k in remove_list:
            remove_mask[label == k] = True
        visibility_filter_all = visibility_filter_all & (~remove_mask)

    # Apply the mask. Every optional input is guarded: scales / rotations are
    # None on the precomputed-covariance path and must not be indexed there.
    means3D = means3D[visibility_filter_all]
    means2D = means2D[visibility_filter_all]
    opacity = opacity[visibility_filter_all]
    segs = segs[visibility_filter_all]
    shs = None if shs is None else shs[visibility_filter_all]
    colors_precomp = None if colors_precomp is None else colors_precomp[visibility_filter_all]
    scales = None if scales is None else scales[visibility_filter_all]
    rotations = None if rotations is None else rotations[visibility_filter_all]
    cov3D_precomp = None if cov3D_precomp is None else cov3D_precomp[visibility_filter_all]

    # Rasterize the selected Gaussians to image, obtain their radii (on screen).
    raster_inputs = dict(
        means3D=means3D,
        means2D=means2D,
        shs=shs,
        colors_precomp=colors_precomp,
        opacities=opacity,
        scales=scales,
        rotations=rotations,
        cov3D_precomp=cov3D_precomp,
    )
    if render_seg:
        # The segmentation features are only handed over when requested.
        raster_inputs["language_feature_precomp"] = segs
    rendered_image, rendered_seg, radii, depth, median_depth, final_opacity = rasterizer(**raster_inputs)
    if not render_seg:
        rendered_seg = None

    return {"render": rendered_image,
            "render_seg": rendered_seg,
            "viewspace_points": screenspace_points,
            "visibility_filter": radii > 0,
            "radii": radii,
            "final_opacity": final_opacity,
            "depth": depth,
            "median_depth": median_depth,}

def render_single_obj(viewpoint_camera, pc: GaussianModel, opt, bg_color: torch.Tensor, center, scale, angles, scaling_modifier=1.0, override_color=None, render_visible=False, render_seg=False, exclude_sky=False, exclude_fg=False, render_current=False, remove_list=None):
    """
    Render the selected Gaussians of ``pc`` after re-posing them as one object.

    The selected positions are centred on their own mean (and detached from
    the graph), multiplied by ``exp(scale)``, rotated by
    ``get_rotation_matrix(angles)`` and translated by ``center``.

    Background tensor (bg_color) must be on GPU!

    Args:
        viewpoint_camera: Camera object; ``FoVx``, ``FoVy``, ``image_height``,
            ``image_width``, ``world_view_transform``, ``full_proj_transform``
            and ``camera_center`` are read from it.
        pc: Gaussian model providing the ``*_all`` accessors and filter masks.
        opt: Options object; ``debug``, ``compute_cov3D_python`` and
            ``convert_SHs_python`` are read from it.
        bg_color: Background colour handed to the rasterizer as ``bg``.
        center: Translation added to the transformed positions.
        scale: Log-scale tensor; ``torch.exp`` is applied before use.
        angles: Passed to ``get_rotation_matrix`` to obtain the rotation.
        scaling_modifier: Forwarded to the rasterizer as ``scale_modifier`` and
            to ``pc.get_covariance_all`` when covariances are precomputed.
        override_color: Optional per-Gaussian colours used instead of the
            model's features; indexed with the same mask as the positions.
        render_visible: Additionally require ``pc.visibility_filter_all``.
        render_seg: Enable the rasterizer's feature output and feed it the
            segmentation features.
        exclude_sky: Drop Gaussians flagged by ``pc.is_sky_filter``.
        exclude_fg: Drop Gaussians flagged by ``pc.is_fg_filter``.
        render_current: Keep only the first ``pc.get_xyz.shape[0]`` Gaussians.
        remove_list: Optional iterable of segmentation labels (argmax over the
            last axis of the segmentation features) whose Gaussians are dropped.

    Returns:
        dict with keys ``render``, ``render_seg`` (``None`` unless
        ``render_seg``), ``viewspace_points``, ``visibility_filter``
        (``radii > 0``), ``radii``, ``final_opacity``, ``depth`` and
        ``median_depth``.

    Raises:
        NotImplementedError: if ``opt.compute_cov3D_python`` is set and
            ``scale`` has more than one element.
    """
    means3D = pc.get_xyz_all

    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # retain_grad() refuses tensors that do not require grad, which is the
        # case when this function runs with autograd disabled.
        pass

    # Set up rasterization configuration
    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=math.tan(viewpoint_camera.FoVx * 0.5),
        tanfovy=math.tan(viewpoint_camera.FoVy * 0.5),
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=opt.debug,
        include_feature=render_seg,
    )
    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    means2D = screenspace_points
    opacity = pc.get_opacity_all
    segs = pc.get_seg_all

    # If precomputed 3d covariance is requested, use it. If not, then it will be
    # computed from scaling / rotation by the rasterizer.
    scales = None
    rotations = None
    cov3D_precomp = None
    if opt.compute_cov3D_python:
        cov3D_precomp = pc.get_covariance_all(scaling_modifier)
    else:
        scales = pc.get_scaling_all
        rotations = pc.get_rotation_all

    # Colour source: explicit override, colours activated in Python, or raw
    # features left for the rasterizer to convert.
    shs = None
    colors_precomp = None
    if override_color is not None:
        colors_precomp = override_color
    elif opt.convert_SHs_python:
        shs_view = pc.get_features_all.transpose(1, 2).view(-1, 3)
        colors_precomp = pc.color_activation(shs_view)
    else:
        shs = pc.get_features_all

    # Build the boolean mask of Gaussians that take part in this render.
    visibility_filter_all = ~pc.delete_mask_all
    if render_visible:
        visibility_filter_all = pc.visibility_filter_all & visibility_filter_all  # Seen in screen
    if exclude_sky:
        visibility_filter_all = visibility_filter_all & ~pc.is_sky_filter
    if exclude_fg:
        visibility_filter_all = visibility_filter_all & ~pc.is_fg_filter
    if render_current:
        current_mask = torch.zeros_like(means3D[:, 0]).bool()
        current_mask[:pc.get_xyz.shape[0]] = True
        visibility_filter_all = visibility_filter_all & current_mask
    if remove_list is not None and len(remove_list):
        label = segs.argmax(-1)
        remove_mask = torch.zeros_like(means3D[:, 0]).bool()
        for k in remove_list:
            remove_mask[label == k] = True
        visibility_filter_all = visibility_filter_all & (~remove_mask)

    # Apply the mask. Every optional input is guarded: scales / rotations are
    # None on the precomputed-covariance path and must not be indexed there.
    means3D = means3D[visibility_filter_all]
    means2D = means2D[visibility_filter_all]
    opacity = opacity[visibility_filter_all]
    segs = segs[visibility_filter_all]
    shs = None if shs is None else shs[visibility_filter_all]
    colors_precomp = None if colors_precomp is None else colors_precomp[visibility_filter_all]
    scales = None if scales is None else scales[visibility_filter_all]
    rotations = None if rotations is None else rotations[visibility_filter_all]
    cov3D_precomp = None if cov3D_precomp is None else cov3D_precomp[visibility_filter_all]

    # Object transform: centre on the selection's mean, scale, rotate, translate.
    # The centred positions are detached, so no gradient reaches pc's positions
    # through this render; `scale`, `angles` and `center` stay in the graph.
    scale = torch.exp(scale)
    means3D = (means3D - means3D.mean(0, keepdim=True)).detach()
    means3D = (means3D * scale) @ get_rotation_matrix(angles).T
    means3D = means3D + center

    # NOTE(review): only the positions are rotated by `angles`; the per-Gaussian
    # orientations are passed on unchanged (an earlier quaternion composition
    # was disabled). Confirm this is intended.
    if scales is not None:
        scales = scales * scale
    elif scale.numel() == 1:
        # A uniform scale s maps a covariance C to s^2 * C, whichever layout
        # the precomputed covariance uses.
        cov3D_precomp = cov3D_precomp * scale ** 2
    else:
        raise NotImplementedError(
            "render_single_obj: a non-uniform scale is not supported together "
            "with opt.compute_cov3D_python"
        )

    # Rasterize the selected Gaussians to image, obtain their radii (on screen).
    raster_inputs = dict(
        means3D=means3D,
        means2D=means2D,
        shs=shs,
        colors_precomp=colors_precomp,
        opacities=opacity,
        scales=scales,
        rotations=rotations,
        cov3D_precomp=cov3D_precomp,
    )
    if render_seg:
        # The segmentation features are only handed over when requested.
        raster_inputs["language_feature_precomp"] = segs
    rendered_image, rendered_seg, radii, depth, median_depth, final_opacity = rasterizer(**raster_inputs)
    if not render_seg:
        rendered_seg = None

    return {"render": rendered_image,
            "render_seg": rendered_seg,
            "viewspace_points": screenspace_points,
            "visibility_filter": radii > 0,
            "radii": radii,
            "final_opacity": final_opacity,
            "depth": depth,
            "median_depth": median_depth,}



def render_with_mask(viewpoint_camera, pc: GaussianModel,  bg_color: torch.Tensor, scaling_modifier=1.0, override_color=None, render_visible=False,render_seg = False, mask = None, center = None, scale = None):
    """
    Render the Gaussians of ``pc`` selected by ``mask``, optionally recentred
    and rescaled.

    Background tensor (bg_color) must be on GPU!

    Unlike the other renderers in this file this one takes no options object:
    the rasterizer always receives the model's features as ``shs`` together
    with scalings / rotations, and ``debug`` is fixed to ``False``.

    Args:
        viewpoint_camera: Camera object; ``FoVx``, ``FoVy``, ``image_height``,
            ``image_width``, ``world_view_transform``, ``full_proj_transform``
            and ``camera_center`` are read from it.
        pc: Gaussian model providing the ``*_all`` accessors and filter masks.
        bg_color: Background colour handed to the rasterizer as ``bg``.
        scaling_modifier: Forwarded to the rasterizer as ``scale_modifier``.
        override_color: Accepted for signature compatibility; not used.
        render_visible: Additionally require ``pc.visibility_filter_all``.
        render_seg: Enable the rasterizer's feature output and feed it the
            segmentation features.
        mask: Optional boolean mask AND-ed into the selection.
        center: Optional tensor subtracted (after ``unsqueeze(0)``) from every
            position.
        scale: Optional divisor applied to both positions and scalings.

    Returns:
        dict with keys ``render``, ``render_seg`` (``None`` unless
        ``render_seg``), ``viewspace_points``, ``visibility_filter``
        (``radii > 0``), ``radii``, ``final_opacity``, ``depth`` and
        ``median_depth``.
    """
    means3D = pc.get_xyz_all

    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # retain_grad() refuses tensors that do not require grad, which is the
        # case when this function runs with autograd disabled.
        pass

    # Set up rasterization configuration
    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=math.tan(viewpoint_camera.FoVx * 0.5),
        tanfovy=math.tan(viewpoint_camera.FoVy * 0.5),
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=False,
        include_feature=render_seg,
    )
    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    means2D = screenspace_points
    opacity = pc.get_opacity_all
    segs = pc.get_seg_all
    scales = pc.get_scaling_all
    rotations = pc.get_rotation_all
    shs = pc.get_features_all

    # Build the boolean mask of Gaussians that take part in this render.
    visibility_filter_all = ~pc.delete_mask_all
    if render_visible:
        visibility_filter_all = pc.visibility_filter_all & visibility_filter_all  # Seen in screen
    if mask is not None:
        visibility_filter_all = visibility_filter_all & mask

    # Recentre / rescale out of place. The tensors come straight from the
    # model's accessors, so in-place arithmetic here could modify model state
    # or invalidate values autograd has saved for the backward pass.
    if center is not None:
        means3D = means3D - center.unsqueeze(0)
    if scale is not None:
        means3D = means3D / scale
        scales = scales / scale

    means3D = means3D[visibility_filter_all]
    means2D = means2D[visibility_filter_all]
    shs = shs[visibility_filter_all]
    opacity = opacity[visibility_filter_all]
    scales = scales[visibility_filter_all]
    rotations = rotations[visibility_filter_all]
    segs = segs[visibility_filter_all]

    # Rasterize the selected Gaussians to image, obtain their radii (on screen).
    # Colours and covariances are never precomputed in this function.
    raster_inputs = dict(
        means3D=means3D,
        means2D=means2D,
        shs=shs,
        colors_precomp=None,
        opacities=opacity,
        scales=scales,
        rotations=rotations,
        cov3D_precomp=None,
    )
    if render_seg:
        # The segmentation features are only handed over when requested.
        raster_inputs["language_feature_precomp"] = segs
    rendered_image, rendered_seg, radii, depth, median_depth, final_opacity = rasterizer(**raster_inputs)
    if not render_seg:
        rendered_seg = None

    return {"render": rendered_image,
            "render_seg": rendered_seg,
            "viewspace_points": screenspace_points,
            "visibility_filter": radii > 0,
            "radii": radii,
            "final_opacity": final_opacity,
            "depth": depth,
            "median_depth": median_depth,}



def render_precomp(viewpoint_camera,pc, means3D, opacity,shs, covariance,visibility_filter_all, opt, bg_color: torch.Tensor, scaling_modifier=1.0, ):
    """
    Render Gaussians whose attributes are supplied explicitly by the caller.

    Background tensor (bg_color) must be on GPU!

    Args:
        viewpoint_camera: Camera object; ``FoVx``, ``FoVy``, ``image_height``,
            ``image_width``, ``world_view_transform``, ``full_proj_transform``
            and ``camera_center`` are read from it.
        pc: Only ``pc.color_activation`` is used, to turn ``shs`` into colours.
        means3D: Gaussian positions.
        opacity: Gaussian opacities.
        shs: Colour features; reshaped with ``transpose(1, 2).view(-1, 3)``
            before activation.
        covariance: Precomputed 3D covariances, passed to the rasterizer as
            ``cov3D_precomp`` (scales and rotations are passed as ``None``).
        visibility_filter_all: Mask / index applied to every per-Gaussian input.
        opt: Options object; only ``debug`` is read.
        bg_color: Background colour handed to the rasterizer as ``bg``.
        scaling_modifier: Forwarded to the rasterizer as ``scale_modifier``.

    Returns:
        dict with keys ``render``, ``viewspace_points``, ``visibility_filter``
        (``radii > 0``), ``radii``, ``final_opacity``, ``depth`` and
        ``median_depth``. There is no ``render_seg`` entry.
    """
    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # retain_grad() refuses tensors that do not require grad, which is the
        # case when this function runs with autograd disabled.
        pass

    # Set up rasterization configuration. Colours are precomputed below, so the
    # SH degree is fixed to 0 and the feature output is disabled.
    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=math.tan(viewpoint_camera.FoVx * 0.5),
        tanfovy=math.tan(viewpoint_camera.FoVy * 0.5),
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=0,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=opt.debug,
        include_feature=False,
    )
    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    # Colours are always activated in Python here; the rasterizer gets no SHs.
    colors_precomp = pc.color_activation(shs.transpose(1, 2).view(-1, 3))

    means3D = means3D[visibility_filter_all]
    means2D = screenspace_points[visibility_filter_all]
    colors_precomp = colors_precomp[visibility_filter_all]
    opacity = opacity[visibility_filter_all]
    cov3D_precomp = None if covariance is None else covariance[visibility_filter_all]

    rendered_image, _, radii, depth, median_depth, final_opacity = rasterizer(
        means3D=means3D,
        means2D=means2D,
        shs=None,
        colors_precomp=colors_precomp,
        opacities=opacity,
        scales=None,
        rotations=None,
        cov3D_precomp=cov3D_precomp)

    return {"render": rendered_image,
            "viewspace_points": screenspace_points,
            "visibility_filter": radii > 0,
            "radii": radii,
            "final_opacity": final_opacity,
            "depth": depth,
            "median_depth": median_depth,}


def render_gsplat(viewpoint_camera, pc: GaussianModel, opt, bg_color: torch.Tensor, scaling_modifier=1.0, override_color=None, render_visible=False,render_seg = False, exclude_sky=False, remove_list = []):
    """
    Render the Gaussians of ``pc`` through a ``rasterization(...)`` call
    instead of the ``GaussianRasterizer`` used by the other renderers.

    Background tensor (bg_color) must be on GPU!

    Args:
        viewpoint_camera: Camera object; ``world_view_transform``, ``K``,
            ``image_width`` and ``image_height`` are read from it.
        pc: Gaussian model providing the ``*_all`` accessors and filter masks.
        opt: Options object; only ``compute_cov3D_python`` is read.
        bg_color: Background colour, passed as ``backgrounds`` after
            ``unsqueeze(0)``.
        scaling_modifier: Only forwarded to ``pc.get_covariance_all``.
        override_color: Accepted but never used in this function.
        render_visible: Additionally require ``pc.visibility_filter_all``.
        render_seg: Run a second rasterization pass over the segmentation
            features.
        exclude_sky: Drop Gaussians flagged by ``pc.is_sky_filter``.
        remove_list: Segmentation labels (argmax over the last axis of the
            segmentation features) whose Gaussians are dropped.

    Returns:
        dict with keys ``render``, ``render_seg`` (``None`` unless
        ``render_seg``), ``viewspace_points`` (``meta['means2d']``),
        ``visibility_filter`` (``radii > 0``), ``radii``, ``final_opacity``
        and ``median_depth``. There is no ``depth`` entry.

    NOTE(review): ``rasterization`` is not imported in the part of the file
    visible here (the ``gsplat.rendering`` import is commented out), so calling
    this function would raise ``NameError`` unless the name is provided
    elsewhere - confirm.
    """
 
    # The screen-space zero tensor used by the other renderers is not needed
    # here; the 2D means are taken from the rasterizer's `meta` output instead.
    # screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
    # screenspace_points = torch.zeros_like(pc.get_xyz_all, dtype=pc.get_xyz_all.dtype, requires_grad=True, device="cuda") + 0
    # try:
    #     screenspace_points.retain_grad()
    # except:
    #     pass

    
    # means3D = pc.get_xyz
    means3D = pc.get_xyz_all
    # means2D = screenspace_points
    # opacity = pc.get_opacity_with_3D_filter
    opacity = pc.get_opacity_all

    segs = pc.get_seg_all
    
        
    # opacity = pc.get_opacity_with_3D_filter_all

    # If precomputed 3d covariance is requested it is computed here, but note
    # that it is never passed to `rasterization` below.
    # NOTE(review): on that path `scales` and `rotations` stay None and the
    # indexing `scales[visibility_filter_all]` further down would raise
    # TypeError - this function appears to require compute_cov3D_python=False.
    scales = None
    rotations = None
    cov3D_precomp = None
    if opt.compute_cov3D_python:
        # cov3D_precomp = pc.get_covariance(scaling_modifier)
        cov3D_precomp = pc.get_covariance_all(scaling_modifier)
    else:
        # scales = pc.get_scaling_with_3D_filter
        # rotations = pc.get_rotation
        # scales = pc.get_scaling_with_3D_filter_all
        scales = pc.get_scaling_all
        rotations = pc.get_rotation_all

    # The model's features are always handed to the rasterizer as-is; colours
    # are never precomputed in this function.
    
    shs = pc.get_features_all
    colors_precomp = None



    # Build the boolean mask of Gaussians that take part in this render.
    if render_visible:
        visibility_filter_all = pc.visibility_filter_all & ~pc.delete_mask_all  # Seen in screen
    else:
        visibility_filter_all = ~pc.delete_mask_all

    if exclude_sky:
        visibility_filter_all = visibility_filter_all & ~pc.is_sky_filter

    if len(remove_list):
        # Drop every Gaussian whose strongest segmentation channel is listed.
        label = segs.argmax(-1)
        remove_mask = torch.zeros_like(means3D[:,0]).bool()
        for k in remove_list:
            remove_mask[label == k] = 1
        visibility_filter_all = visibility_filter_all & (~remove_mask)

    means3D = means3D[visibility_filter_all]
    # means2D = means2D[visibility_filter_all]
    shs = None if shs is None else shs[visibility_filter_all]
    colors_precomp = None if colors_precomp is None else colors_precomp[visibility_filter_all]
    opacity = opacity[visibility_filter_all]
    scales = scales[visibility_filter_all]
    rotations = rotations[visibility_filter_all]
    cov3D_precomp = None if cov3D_precomp is None else cov3D_precomp[visibility_filter_all]
    segs = segs[visibility_filter_all]
        
    # Disabled alternative using the 2DGS rasterizer, kept for reference:
    # rendered_image, alphas, normals, surf_normals, distort, median_depth, meta = rasterization_2dgs(
    #     means3D, 
    #     rotations, 
    #     scales,
    #     opacity.squeeze(), 
    #     shs, 
    #     viewpoint_camera.world_view_transform.unsqueeze(0), 
    #     viewpoint_camera.K.unsqueeze(0), 
    #     width = int(viewpoint_camera.image_width), 
    #     height = int(viewpoint_camera.image_height),
    #     sh_degree = 0,
    #     backgrounds = bg_color.unsqueeze(0),
    #     render_mode = "RGB"
    # )
    # Main pass: a single-camera batch (view matrix, intrinsics and background
    # each get a leading axis via unsqueeze(0)).
    color, alphas, meta = rasterization(
        means3D, 
        rotations, 
        scales,
        opacity.squeeze(), 
        shs, 
        viewpoint_camera.world_view_transform.unsqueeze(0), 
        viewpoint_camera.K.unsqueeze(0), 
        width = int(viewpoint_camera.image_width), 
        height = int(viewpoint_camera.image_height),
        sh_degree = 0,
        backgrounds = bg_color.unsqueeze(0),
        render_mode = "RGB+ED"
    )
    radii, final_opacity = meta['radii'].squeeze().max(-1).values, alphas.squeeze().unsqueeze(0)
    # The first three channels of the output are the image, the last one depth.
    # NOTE(review): "ED" presumably means expected depth in gsplat, yet the
    # value is returned under the key "median_depth" - confirm against callers.
    rendered_image = color[...,:3]
    median_depth = color[...,-1:]
    

    # Disabled normal rendering, kept for reference:
    # R = torch.tensor(viewpoint_camera.R, device=means3D.device, dtype=torch.float32)
    # point_normals_in_world = rotation2normal(rotations)
    # point_normals_in_screen = point_normals_in_world @ R

    # render_normal, _, _, _, _ = rasterizer(
    #     means3D = means3D,
    #     means2D = means2D,
    #     shs = None,
    #     colors_precomp = point_normals_in_screen,
    #     opacities = opacity,
    #     scales = scales,
    #     rotations = rotations,
    #     cov3D_precomp = cov3D_precomp)
    # render_normal = F.normalize(render_normal, dim = 0)   
    # Optional second pass over the segmentation features. Positions,
    # rotations and opacities are detached for it; `scales` and `segs` are not.
    rendered_seg = None
    if render_seg:
        rendered_seg, _, _= rasterization(
        means3D.detach(), 
        rotations.detach(), 
        scales,
        opacity.squeeze().detach(), 
        segs, 
        viewpoint_camera.world_view_transform.unsqueeze(0), 
        viewpoint_camera.K.unsqueeze(0), 
        width = int(viewpoint_camera.image_width), 
        height = int(viewpoint_camera.image_height),
        # backgrounds = torch.zeros([1,segs.shape[-1]],dtype = torch.float32,device = means3D.device),
        render_mode = "RGB",
        )
        
            
        # print(rendered_seg.max(),rendered_seg.min())
        # print(rendered_seg.sum(0))
        # NOTE(review): the normalisation sums over dim 0, while the return
        # statement below treats dim 0 as a singleton axis (squeeze(0)) and the
        # last axis as channels. If so, every value is divided by itself here;
        # the channel axis (-1) was presumably intended - confirm.
        rendered_seg = rendered_seg / (rendered_seg.sum(0, keepdim=True) + 1e-5)
        
        # Keep the values strictly inside (0, 1).
        rendered_seg = torch.clamp(rendered_seg,min = 1e-5, max = 1 - 1e-5)    
        # print(rendered_seg.shape,rendered_seg.min(),rendered_seg.max(),segs.max(),segs.min())

    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
    # They will be excluded from value updates used in the splitting criteria.
    # Unlike the other renderers there is no try/except around retain_grad()
    # here, so a failure of that call propagates to the caller.
    viewspace_points = meta['means2d'] #meta['gradient_2dgs']
    viewspace_points.retain_grad()
    
    return {"render": rendered_image.squeeze(0).permute(2,0,1), #[H,W,C] -->[C,H,W] to align with the original code
            "render_seg" : rendered_seg.squeeze(0).permute(2,0,1) if rendered_seg is not None else None,  
            "viewspace_points": viewspace_points,
            "visibility_filter" : radii > 0,
            "radii": radii,
            "final_opacity": final_opacity,
            # "means2d": meta['means2d'],
            # "depth": depth,
            "median_depth": median_depth.squeeze().unsqueeze(0),}