#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import torch
import torch.nn.functional as F
import math
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from scene.gaussian_model import GaussianModel
from scene.gaussian_model_latent_strands import GaussianModelHair
from scene.gaussian_render import GaussRenderer
from utils.sh_utils import eval_sh
from utils.general_utils import build_rotation
import time

def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0):
    """
    Render a single Gaussian model as seen from ``viewpoint_camera``.

    Background tensor (bg_color) must be on GPU!

    The colour (evaluated from spherical harmonics), label, a constant
    foreground flag, the projected direction, the orientation confidence and
    the depth of every Gaussian are concatenated into one feature vector,
    rasterized in a single pass as precomputed colours and split back into
    separate maps afterwards.

    Args:
        viewpoint_camera: Camera providing the image size, field of view,
            view / projection transforms and camera centre.
        pc: Model supplying the Gaussians and their per-view quantities.
        pipe: Pipeline options; only ``pipe.debug`` is read here.
        bg_color: Background passed to the rasterizer.
        scaling_modifier: Forwarded to ``pc.get_conic`` and to the
            rasterizer settings.

    Returns:
        dict with the keys
        ``"render"`` (first 3 rasterized channels, the colour),
        ``"mask"`` (next 2 channels: label and foreground flag, which
        presumes a 1-channel label - the split sizes below require it),
        ``"orient_angle"`` (angle of the rendered 2D direction divided by pi),
        ``"orient_conf"`` (rendered orientation confidence),
        ``"viewspace_points"`` (the unfiltered 2D means),
        ``"visibility_filter"`` (``radii > 0``) and
        ``"radii"`` (per-Gaussian radii, zero for filtered-out Gaussians).

    NOTE(review): this function unpacks two values from the rasterizer while
    the ``render_hair_weight*`` functions in this file unpack three - confirm
    which signature the installed rasterizer provides.
    """
    conic_precomp = pc.get_conic(viewpoint_camera, scaling_modifier)
    screenspace_points = pc.get_mean_2d(viewpoint_camera)
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # retain_grad() raises RuntimeError for tensors that do not require
        # grad; in that case there is simply no gradient to keep.
        pass

    # Set up rasterization configuration
    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=torch.tan(viewpoint_camera.FoVx * 0.5).item(),
        tanfovy=torch.tan(viewpoint_camera.FoVy * 0.5).item(),
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=True,
        debug=pipe.debug
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    means3D = pc.get_xyz
    means2D_precomp = screenspace_points
    opacity = pc.get_opacity

    # View-dependent colour: evaluate the SH coefficients along the unit
    # direction from the camera centre to each Gaussian.
    shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
    dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
    dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
    sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
    cov3D_precomp = pc.cov
    dir3D = pc.get_direction_2d(viewpoint_camera)

    # All per-Gaussian features share one rasterization pass; the channel
    # order here must match the split sizes used after rasterization.
    colors_precomp = torch.cat(
        [
            torch.clamp_min(sh2rgb + 0.5, 0.0),
            pc.get_label,
            torch.ones_like(pc.get_label), # foreground mask
            dir3D,
            pc.get_orient_conf,
            pc.get_depths(viewpoint_camera)
        ],
        dim=-1
    )

    # Only the Gaussians selected by the model's own filter are rasterized.
    points_mask = pc.filter_points(viewpoint_camera)

    means3D = means3D[points_mask]
    means2D_precomp = means2D_precomp[points_mask]
    colors_precomp = colors_precomp[points_mask]
    opacity = opacity[points_mask]
    cov3D_precomp = cov3D_precomp[points_mask]
    conic_precomp = conic_precomp[points_mask]

    # Radii are reported for every Gaussian of the model; filtered-out ones keep 0.
    radii = torch.zeros_like(pc.get_xyz[:, 0]).int()

    renders, _radii = rasterizer(
        means3D=means3D,
        means2D=means2D_precomp,
        shs=None,
        colors_precomp=colors_precomp,
        opacities=opacity,
        scales=None,
        rotations=None,
        cov3D_precomp=cov3D_precomp,
        conic_precomp=conic_precomp)

    radii[points_mask] = _radii

    rendered_image, rendered_mask, rendered_cov2D, rendered_orient_conf, _ = renders.split([3, 2, 3, 1, 1], dim=0)

    # Turn the rendered 2D direction into an angle. The second component is
    # negated wherever the first one is negative, so a direction and its
    # opposite yield the same angle.
    rendered_dir2D = F.normalize(rendered_cov2D[:2], dim=0)
    to_mirror = torch.ones_like(rendered_dir2D[[0]])
    to_mirror[rendered_dir2D[[0]] < 0] *= -1
    # The clamp keeps acos away from +-1, where its derivative is unbounded.
    rendered_orient_angle = torch.acos(rendered_dir2D[[1]].clamp(-1 + 1e-3, 1 - 1e-3) * to_mirror) / math.pi

    return {"render": rendered_image,
            "mask": rendered_mask,
            "orient_angle": rendered_orient_angle,
            "orient_conf": rendered_orient_conf,
            "viewspace_points": screenspace_points,
            "visibility_filter" : radii > 0,
            "radii": radii}
def render_hair(viewpoint_camera, pc : GaussianModel, pc_hair: GaussianModelHair, pipe, bg_color : torch.Tensor, render_state: str, scaling_modifier = 1.0):
    """
    Render the Gaussians cached on ``pc`` together with the hair Gaussians.

    Background tensor (bg_color) must be on GPU!

    The function reads ``pc.mask_precomp``, ``pc.xyz_precomp``,
    ``pc.opacity_precomp``, ``pc.scaling_precomp``, ``pc.rotation_precomp``
    and ``pc.shs_view`` without setting them, so they must have been
    populated before the call. The ``pc`` part is concatenated in front of
    the ``pc_hair`` part for every per-Gaussian quantity.

    NOTE(review): per-view quantities of ``pc`` (2D means, conics, filter,
    depths) are indexed with ``pc.mask_precomp`` while the cached
    ``*_precomp`` tensors are used as-is; this presumes the cached tensors
    already hold only the masked subset - confirm against the code that
    fills them.

    NOTE(review): two values are unpacked from the rasterizer here while the
    ``render_hair_weight*`` functions unpack three - confirm which signature
    the installed rasterizer provides.

    Args:
        viewpoint_camera: Camera providing image size, field of view,
            transforms, camera centre, ``time_step`` and ``num_time_steps``.
        pc: Model whose cached (``*_precomp``) Gaussians are rendered.
        pc_hair: Hair model.
        pipe: Pipeline options; only ``pipe.debug`` is read here.
        bg_color: Background passed to the rasterizer.
        render_state: ``"fine"`` takes the hair SH features and orientation
            confidence from ``pc_hair.set_deformation``; any other value
            uses the stored features and confidence.
        scaling_modifier: Forwarded to the conic computation and the
            rasterizer settings.

    Returns:
        dict with the keys ``"render"``, ``"mask"``, ``"orient_angle"``,
        ``"orient_conf"``, ``"viewspace_points"``, ``"visibility_filter"``
        and ``"radii"``. The ``pc`` Gaussians contribute zeros to the label,
        direction and orientation-confidence channels.
    """
    time_step = viewpoint_camera.time_step
    precomp_mask = pc.mask_precomp

    # Set up rasterization configuration
    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=math.tan(viewpoint_camera.FoVx * 0.5),
        tanfovy=math.tan(viewpoint_camera.FoVy * 0.5),
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc_hair.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=True,
        debug=pipe.debug
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    if render_state == "fine":
        shs_hair_final, orient_conf_hair_final = pc_hair.set_deformation(
            time_step, viewpoint_camera.num_time_steps)
    else:
        shs_hair_final = pc_hair.get_features.transpose(1, 2).view(-1, 3, (pc_hair.max_sh_degree+1)**2)
        orient_conf_hair_final = pc_hair.get_orient_conf

    # The pc part is detached, so only the hair part links back to model parameters.
    screenspace_points = torch.cat(
        [
            pc.get_mean_2d(viewpoint_camera, time_step)[precomp_mask].detach(),
            pc_hair.get_mean_2d(viewpoint_camera),
        ],
        dim=0,
    )
    screenspace_points.requires_grad_(True)
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # retain_grad() raises RuntimeError for tensors that do not require
        # grad; rendering does not depend on it.
        pass

    conic_precomp = torch.cat([
        pc.get_conic(viewpoint_camera, scaling_modifier, time_step)[precomp_mask],
        pc_hair.get_conic(viewpoint_camera, scaling_modifier)]
    )

    means3D = torch.cat([pc.xyz_precomp, pc_hair.get_xyz])
    means2D_precomp = screenspace_points
    opacity = torch.cat([pc.opacity_precomp, pc_hair.get_opacity])

    points_mask = torch.cat([
        pc.filter_points(viewpoint_camera, time_step)[precomp_mask],
        pc_hair.filter_points(viewpoint_camera)]
    )
    scales = torch.cat([pc.scaling_precomp, pc_hair.get_scaling])
    rotations = torch.cat([pc.rotation_precomp, pc_hair.get_rotation])

    # View-dependent colour: evaluate the SH coefficients along the unit
    # direction from the camera centre to each Gaussian.
    shs_view = torch.cat([pc.shs_view, shs_hair_final])
    dir_pp = (means3D - viewpoint_camera.camera_center.repeat(shs_view.shape[0], 1))
    dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
    sh2rgb = eval_sh(pc_hair.active_sh_degree, shs_view, dir_pp_normalized)
    rgb_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)

    # The pc Gaussians get zeros for label, direction and orientation confidence.
    label_precomp = torch.cat([torch.zeros_like(pc.xyz_precomp[:, :1]), pc_hair.get_label])
    cov2D = torch.cat([torch.zeros_like(pc.xyz_precomp), pc_hair.get_direction_2d(viewpoint_camera)])
    orient_conf = torch.cat([torch.zeros_like(pc.xyz_precomp[:, :1]), orient_conf_hair_final])
    depth = torch.cat([pc.get_depths(viewpoint_camera, time_step)[precomp_mask], pc_hair.get_depths(viewpoint_camera)])

    # Channel order must match the split sizes used after rasterization.
    colors_precomp = torch.cat([rgb_precomp, label_precomp, torch.ones_like(label_precomp), cov2D, orient_conf, depth], dim=-1)

    # Allocated before filtering so that radii cover every Gaussian;
    # filtered-out Gaussians keep a radius of 0.
    radii = torch.zeros_like(means3D[:, 0]).int()

    means3D = means3D[points_mask]
    means2D_precomp = means2D_precomp[points_mask]
    colors_precomp = colors_precomp[points_mask]
    opacity = opacity[points_mask]
    scales = scales[points_mask]
    rotations = rotations[points_mask]
    conic_precomp = conic_precomp[points_mask]

    renders, _radii = rasterizer(
        means3D=means3D,
        means2D=means2D_precomp,
        shs=None,
        colors_precomp=colors_precomp,
        opacities=opacity,
        scales=scales,
        rotations=rotations,
        cov3D_precomp=None,
        conic_precomp=conic_precomp)

    radii[points_mask] = _radii

    rendered_image, rendered_mask, rendered_cov2D, rendered_orient_conf, _ = renders.split([3, 2, 3, 1, 1], dim=0)

    # Turn the rendered 2D direction into an angle. The second component is
    # negated wherever the first one is negative, so a direction and its
    # opposite yield the same angle.
    rendered_dir2D = F.normalize(rendered_cov2D[:2], dim=0)
    to_mirror = torch.ones_like(rendered_dir2D[[0]])
    to_mirror[rendered_dir2D[[0]] < 0] *= -1
    # The clamp keeps acos away from +-1, where its derivative is unbounded.
    rendered_orient_angle = torch.acos(rendered_dir2D[[1]].clamp(-1 + 1e-3, 1 - 1e-3) * to_mirror) / math.pi

    return {"render": rendered_image,
            "mask": rendered_mask,
            "orient_angle": rendered_orient_angle,
            "orient_conf": rendered_orient_conf,
            "viewspace_points": screenspace_points,
            "visibility_filter" : radii > 0,
            "radii": radii}
def render_hair_weight(viewpoint_camera, pc : GaussianModel, pc_hair: GaussianModelHair, pipe, bg_color : torch.Tensor, render_state: str, scaling_modifier = 1.0):
    """
    Render cached ``pc`` Gaussians plus hair Gaussians with re-weighted opacity.

    Background tensor (bg_color) must be on GPU!

    Differences to ``render_hair``:

    * every opacity is multiplied by ``pc_hair.gaussianWeight`` evaluated on
      the concatenated view direction and depth of the Gaussian;
    * the rasterizer is expected to return three values, the second of which
      is named accumulated alpha here, and the rendered features are divided
      by it (plus 1e-8);
    * when ``pc_hair.idx_active_mask`` is set, the Gaussians are rasterized
      in two passes - an "active" pass that tracks gradients and a second
      pass under ``torch.no_grad()`` - whose features and accumulated alphas
      are summed before the division.

    The function reads ``pc.mask_precomp``, ``pc.xyz_precomp``,
    ``pc.opacity_precomp``, ``pc.scaling_precomp``, ``pc.rotation_precomp``
    and ``pc.shs_view`` without setting them, so they must have been
    populated before the call.

    NOTE(review): per-view quantities of ``pc`` are indexed with
    ``pc.mask_precomp`` while the cached ``*_precomp`` tensors are used
    as-is; this presumes they already hold only the masked subset - confirm.

    Args:
        viewpoint_camera: Camera providing image size, field of view,
            transforms, camera centre, ``time_step`` and ``num_time_steps``.
        pc: Model whose cached (``*_precomp``) Gaussians are rendered.
        pc_hair: Hair model; also provides ``gaussianWeight`` and
            ``idx_active_mask``.
        pipe: Pipeline options; only ``pipe.debug`` is read here.
        bg_color: Background passed to the rasterizer.
        render_state: ``"fine"`` takes the hair SH features and orientation
            confidence from ``pc_hair.set_deformation``; any other value
            uses the stored features and confidence.
        scaling_modifier: Forwarded to the conic computation and the
            rasterizer settings.

    Returns:
        dict with the keys ``"render"``, ``"mask"``, ``"orient_angle"``,
        ``"orient_conf"``, ``"viewspace_points"``, ``"colors_precomp"``,
        ``"opacity"`` (both unfiltered, with gradients retained),
        ``"visibility_filter"`` and ``"radii"``.
    """
    time_step = viewpoint_camera.time_step
    precomp_mask = pc.mask_precomp

    # Set up rasterization configuration
    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=math.tan(viewpoint_camera.FoVx * 0.5),
        tanfovy=math.tan(viewpoint_camera.FoVy * 0.5),
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc_hair.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=True,
        debug=pipe.debug
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    if render_state == "fine":
        shs_hair_final, orient_conf_hair_final = pc_hair.set_deformation(
            time_step, viewpoint_camera.num_time_steps)
    else:
        shs_hair_final = pc_hair.get_features.transpose(1, 2).view(-1, 3, (pc_hair.max_sh_degree+1)**2)
        orient_conf_hair_final = pc_hair.get_orient_conf

    # The pc part is detached, so only the hair part links back to model parameters.
    screenspace_points = torch.cat(
        [
            pc.get_mean_2d(viewpoint_camera, time_step)[precomp_mask].detach(),
            pc_hair.get_mean_2d(viewpoint_camera),
        ],
        dim=0,
    )
    screenspace_points.requires_grad_(True)
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # retain_grad() raises RuntimeError for tensors that do not require
        # grad; rendering does not depend on it.
        pass

    conic_precomp = torch.cat([
        pc.get_conic(viewpoint_camera, scaling_modifier, time_step)[precomp_mask],
        pc_hair.get_conic(viewpoint_camera, scaling_modifier)]
    )

    means3D = torch.cat([pc.xyz_precomp, pc_hair.get_xyz])
    means2D_precomp = screenspace_points
    opacity = torch.cat([pc.opacity_precomp, pc_hair.get_opacity])

    points_mask = torch.cat([
        pc.filter_points(viewpoint_camera, time_step)[precomp_mask],
        pc_hair.filter_points(viewpoint_camera)]
    )
    scales = torch.cat([pc.scaling_precomp, pc_hair.get_scaling])
    rotations = torch.cat([pc.rotation_precomp, pc_hair.get_rotation])

    # View-dependent colour: evaluate the SH coefficients along the unit
    # direction from the camera centre to each Gaussian.
    shs_view = torch.cat([pc.shs_view, shs_hair_final])
    dir_pp = (means3D - viewpoint_camera.camera_center.repeat(shs_view.shape[0], 1))
    dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
    sh2rgb = eval_sh(pc_hair.active_sh_degree, shs_view, dir_pp_normalized)
    rgb_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)

    # The pc Gaussians get zeros for label, direction and orientation confidence.
    label_precomp = torch.cat([torch.zeros_like(pc.xyz_precomp[:, :1]), pc_hair.get_label])
    cov2D = torch.cat([torch.zeros_like(pc.xyz_precomp), pc_hair.get_direction_2d(viewpoint_camera)])
    orient_conf = torch.cat([torch.zeros_like(pc.xyz_precomp[:, :1]), orient_conf_hair_final])
    depth = torch.cat([pc.get_depths(viewpoint_camera, time_step)[precomp_mask], pc_hair.get_depths(viewpoint_camera)])

    # Channel order must match the split sizes used after rasterization.
    colors_precomp = torch.cat([rgb_precomp, label_precomp, torch.ones_like(label_precomp), cov2D, orient_conf, depth], dim=-1)

    # Radii cover every Gaussian; those never rasterized keep a radius of 0.
    radii = torch.zeros_like(means3D[:, 0]).int()

    # Re-weight the opacity from the view direction and depth of each Gaussian.
    gaussian_weight_input = torch.cat([dir_pp_normalized, depth], dim=-1)
    opacity = pc_hair.gaussianWeight(gaussian_weight_input) * opacity

    # Keep the gradients of the full (unfiltered) tensors; they are returned to the caller.
    colors_precomp.requires_grad_(True)
    colors_precomp.retain_grad()
    opacity.requires_grad_(True)
    opacity.retain_grad()

    def _rasterize(idx):
        # Rasterize only the Gaussians selected by the index tensor `idx`.
        return rasterizer(
            means3D=means3D.index_select(0, idx),
            means2D=means2D_precomp.index_select(0, idx),
            shs=None,
            colors_precomp=colors_precomp.index_select(0, idx),
            opacities=opacity.index_select(0, idx),
            scales=scales.index_select(0, idx),
            rotations=rotations.index_select(0, idx),
            cov3D_precomp=None,
            conic_precomp=conic_precomp.index_select(0, idx))

    if pc_hair.idx_active_mask is not None:
        num_precomp = pc.xyz_precomp.shape[0]
        # NOTE(review): the pc Gaussians are assigned to the active or the
        # non-active pass at random on every call - confirm this is intended.
        mask_prefix = torch.randint(0, 2, (num_precomp,), device=points_mask.device).bool()
        active_mask = torch.cat([mask_prefix, pc_hair.idx_active_mask])

        active_idx = torch.nonzero(points_mask & active_mask, as_tuple=True)[0]
        wo_active_idx = torch.nonzero(points_mask & (~active_mask), as_tuple=True)[0]

        # Active pass: tracked by autograd.
        renders_active, accum_alpha_active, _radii_active = _rasterize(active_idx)
        # Non-active pass: contributes to the image but not to the gradients.
        with torch.no_grad():
            renders_wo_active, accum_alpha_wo_active, _radii_wo_active = _rasterize(wo_active_idx)

        radii.index_copy_(0, active_idx, _radii_active)
        radii.index_copy_(0, wo_active_idx, _radii_wo_active)

        # Sum both passes, then divide by the summed accumulated alpha.
        inv = (accum_alpha_active + accum_alpha_wo_active + 1e-8).reciprocal()
        renders = (renders_active + renders_wo_active).mul_(inv)
    else:
        active_idx = torch.nonzero(points_mask, as_tuple=True)[0]
        renders, accum_alpha_active, _radii_active = _rasterize(active_idx)
        inv_active = (accum_alpha_active + 1e-8).reciprocal()
        renders.mul_(inv_active)
        radii.index_copy_(0, active_idx, _radii_active)

    rendered_image, rendered_mask, rendered_cov2D, rendered_orient_conf, _ = renders.split([3, 2, 3, 1, 1], dim=0)

    # Turn the rendered 2D direction into an angle. The second component is
    # negated wherever the first one is negative, so a direction and its
    # opposite yield the same angle.
    rendered_dir2D = F.normalize(rendered_cov2D[:2], dim=0)
    to_mirror = torch.ones_like(rendered_dir2D[[0]])
    to_mirror[rendered_dir2D[[0]] < 0] *= -1
    # The clamp keeps acos away from +-1, where its derivative is unbounded.
    rendered_orient_angle = torch.acos(rendered_dir2D[[1]].clamp(-1 + 1e-3, 1 - 1e-3) * to_mirror) / math.pi

    return {"render": rendered_image,
            "mask": rendered_mask,
            "orient_angle": rendered_orient_angle,
            "orient_conf": rendered_orient_conf,
            "viewspace_points": screenspace_points,
            "colors_precomp": colors_precomp,
            "opacity": opacity,
            "visibility_filter" : radii > 0,
            "radii": radii}
    
def render_hair_weight_coarse(viewpoint_camera, pc : GaussianModel, pc_hair: GaussianModelHair, pipe, bg_color : torch.Tensor, render_state: str, scaling_modifier = 1.0):
    """
    Render all ``pc`` Gaussians plus the hair Gaussians with re-weighted opacity.

    Background tensor (bg_color) must be on GPU!

    Unlike ``render_hair_weight`` this function fills the ``*_precomp``
    caches on ``pc`` itself, applies no point filter, passes no precomputed
    conics to the rasterizer and uses ``prefiltered=False``.

    Side effects:
        * ``viewpoint_camera.set_device("cuda")`` is called.
        * ``pc.mask_precomp``, ``pc.points_mask_head_indices``,
          ``pc.xyz_precomp``, ``pc.opacity_precomp``, ``pc.scaling_precomp``,
          ``pc.rotation_precomp``, ``pc.cov3D_precomp`` and ``pc.shs_view``
          are overwritten with detached values for the camera's time step.
        * When ``pc_hair.idx_hair_active_mask`` is set,
          ``pc_hair.idx_active_mask`` is overwritten with that mask prefixed
          by ``True`` for every ``pc`` Gaussian.

    Rendering:
        Every opacity is multiplied by ``pc_hair.gaussianWeight`` evaluated
        on the concatenated view direction and depth. The rasterizer is
        expected to return three values, the second of which is named
        accumulated alpha here, and the rendered features are divided by it
        (plus 1e-8). When ``pc_hair.idx_hair_active_mask`` is set, the
        active Gaussians are rasterized with gradients; the remaining ones
        are either taken from
        ``pc_hair.wo_active_set_data[viewpoint_camera.camera_index]`` (when
        that entry is not ``None``) or rasterized under ``torch.no_grad()``.
        This function only reads that cache, it never writes to it.

    Args:
        viewpoint_camera: Camera providing image size, field of view,
            transforms, camera centre, ``time_step``, ``num_time_steps`` and
            ``camera_index``.
        pc: Model whose Gaussians are evaluated at the camera's time step.
        pc_hair: Hair model.
        pipe: Pipeline options; only ``pipe.debug`` is read here.
        bg_color: Background passed to the rasterizer.
        render_state: ``"fine"`` takes the hair SH features and orientation
            confidence from ``pc_hair.set_deformation``; any other value
            uses the stored features and confidence.
        scaling_modifier: Forwarded to the rasterizer settings.

    Returns:
        dict with the keys ``"render"``, ``"mask"``, ``"orient_angle"``,
        ``"orient_conf"``, ``"viewspace_points"``, ``"means3D"``,
        ``"colors_precomp"``, ``"opacity"``, ``"scales"``, ``"rotations"``
        (the concatenated tensors with gradients retained),
        ``"visibility_filter"`` and ``"radii"``.
    """
    # to device
    viewpoint_camera.set_device("cuda")
    time_step = viewpoint_camera.time_step

    # Set up rasterization configuration
    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=math.tan(viewpoint_camera.FoVx * 0.5),
        tanfovy=math.tan(viewpoint_camera.FoVy * 0.5),
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc_hair.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=pipe.debug
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    # Cache the pc Gaussians for this time step; they are detached, so pc
    # receives no gradients from this render.
    with torch.no_grad():
        pc.mask_precomp = pc.get_label(time_step)[..., 0] < 0.6
        pc.points_mask_head_indices = pc.mask_precomp.nonzero(as_tuple=True)[0]
        pc.xyz_precomp = pc.get_xyz(time_step).detach()
        pc.opacity_precomp = pc.get_opacity(time_step).detach()
        pc.scaling_precomp = pc.get_scaling(time_step).detach()
        pc.rotation_precomp = pc.get_rotation(time_step).detach()
        pc.cov3D_precomp = pc.get_covariance(scaling_modifier = 1.0, time_index = time_step).detach()
        pc.shs_view = pc.get_features(time_step).detach().transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1)**2)

    if render_state == "fine":
        shs_hair_final, orient_conf_hair_final = pc_hair.set_deformation(
            time_step, viewpoint_camera.num_time_steps)
    else:
        shs_hair_final = pc_hair.get_features.transpose(1, 2).view(-1, 3, (pc_hair.max_sh_degree+1)**2)
        orient_conf_hair_final = pc_hair.get_orient_conf

    screenspace_points = torch.cat(
        [
            pc.get_mean_2d(viewpoint_camera, time_step).detach(),
            pc_hair.get_mean_2d(viewpoint_camera),
        ],
        dim=0,
    )
    screenspace_points.requires_grad_(True)
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # retain_grad() raises RuntimeError for tensors that do not require
        # grad; rendering does not depend on it.
        pass

    means3D = torch.cat([pc.xyz_precomp, pc_hair.get_xyz])
    means2D_precomp = screenspace_points
    opacity = torch.cat([pc.opacity_precomp, pc_hair.get_opacity])
    scales = torch.cat([pc.scaling_precomp, pc_hair.get_scaling])
    rotations = torch.cat([pc.rotation_precomp, pc_hair.get_rotation])

    # View-dependent colour: evaluate the SH coefficients along the unit
    # direction from the camera centre to each Gaussian.
    shs_view = torch.cat([pc.shs_view, shs_hair_final])
    dir_pp = (means3D - viewpoint_camera.camera_center[None,:])
    dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
    sh2rgb = eval_sh(pc_hair.active_sh_degree, shs_view, dir_pp_normalized)
    rgb_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)

    # The pc Gaussians get zeros for label, direction and orientation confidence.
    label_precomp = torch.cat([torch.zeros_like(pc.xyz_precomp[:, :1]), pc_hair.get_label])
    cov2D = torch.cat([torch.zeros_like(pc.xyz_precomp), pc_hair.get_direction_2d(viewpoint_camera)])
    orient_conf = torch.cat([torch.zeros_like(pc.xyz_precomp[:, :1]), orient_conf_hair_final])
    depth = torch.cat([pc.get_depths(viewpoint_camera, time_step), pc_hair.get_depths(viewpoint_camera)])

    # Channel order must match the split sizes used after rasterization.
    colors_precomp = torch.cat([rgb_precomp, label_precomp, torch.ones_like(label_precomp), cov2D, orient_conf, depth], dim=-1)
    radii = torch.zeros_like(means3D[:, 0]).int()

    # Re-weight the opacity from the view direction and depth of each Gaussian.
    gaussian_weight_input = torch.cat([dir_pp_normalized, depth], dim=-1)
    opacity = pc_hair.gaussianWeight(gaussian_weight_input) * opacity

    # Keep the gradients of the concatenated tensors; they are returned to the caller.
    means3D.requires_grad_(True)
    means3D.retain_grad()
    colors_precomp.requires_grad_(True)
    colors_precomp.retain_grad()
    opacity.requires_grad_(True)
    opacity.retain_grad()
    scales.requires_grad_(True)
    scales.retain_grad()
    rotations.requires_grad_(True)
    rotations.retain_grad()

    def _rasterize(idx):
        # Rasterize only the Gaussians selected by the index tensor `idx`.
        return rasterizer(
            means3D=means3D.index_select(0, idx),
            means2D=means2D_precomp.index_select(0, idx),
            shs=None,
            colors_precomp=colors_precomp.index_select(0, idx),
            opacities=opacity.index_select(0, idx),
            scales=scales.index_select(0, idx),
            rotations=rotations.index_select(0, idx),
            cov3D_precomp=None,
            conic_precomp=None)

    if pc_hair.idx_hair_active_mask is not None:
        # Every pc Gaussian is active; the hair part follows the hair mask.
        num_precomp = pc.xyz_precomp.shape[0]
        mask_prefix = torch.ones((num_precomp,), device='cuda').bool()
        pc_hair.idx_active_mask = torch.cat([mask_prefix, pc_hair.idx_hair_active_mask], dim=0)
        points_mask_active = pc_hair.idx_active_mask
        points_mask_wo_active = (~pc_hair.idx_active_mask)

        # Active pass: tracked by autograd.
        active_idx = torch.nonzero(points_mask_active, as_tuple=True)[0]
        renders_active, accum_alpha_active, _radii_active = _rasterize(active_idx)

        # Non-active pass: reuse the per-camera cached result when there is
        # one, otherwise rasterize without gradients.
        render_wo_active_data = pc_hair.wo_active_set_data[viewpoint_camera.camera_index]
        with torch.no_grad():
            if render_wo_active_data is not None:
                points_mask_wo_active = render_wo_active_data["points_mask_wo_active"]
                renders_wo_active = render_wo_active_data["renders_wo_active"]
                accum_alpha_wo_active = render_wo_active_data["accum_alpha_wo_active"]
                _radii_wo_active = render_wo_active_data["_radii_wo_active"]
                wo_active_idx = torch.nonzero(points_mask_wo_active, as_tuple=True)[0]
            else:
                wo_active_idx = torch.nonzero(points_mask_wo_active, as_tuple=True)[0]
                renders_wo_active, accum_alpha_wo_active, _radii_wo_active = _rasterize(wo_active_idx)

        radii.index_copy_(0, active_idx, _radii_active)
        radii.index_copy_(0, wo_active_idx, _radii_wo_active)

        # Sum both passes, then divide by the summed accumulated alpha.
        inv = (accum_alpha_active + accum_alpha_wo_active + 1e-8).reciprocal()
        renders = (renders_active + renders_wo_active).mul_(inv)
    else:
        renders, accum_alpha_active, _radii_active = rasterizer(
            means3D=means3D,
            means2D=means2D_precomp,
            shs=None,
            colors_precomp=colors_precomp,
            opacities=opacity,
            scales=scales,
            rotations=rotations,
            cov3D_precomp=None,
            conic_precomp=None)
        inv_active = (accum_alpha_active + 1e-8).reciprocal()
        renders.mul_(inv_active)
        # Every Gaussian took part in this pass, so the rasterizer's radii
        # describe the whole set. Without this, "radii" would stay all zero
        # and "visibility_filter" all False on this path.
        radii.copy_(_radii_active)

    rendered_image, rendered_mask, rendered_cov2D, rendered_orient_conf, _ = renders.split([3, 2, 3, 1, 1], dim=0)

    # Turn the rendered 2D direction into an angle. The second component is
    # negated wherever the first one is negative, so a direction and its
    # opposite yield the same angle.
    rendered_dir2D = F.normalize(rendered_cov2D[:2], dim=0)
    to_mirror = torch.ones_like(rendered_dir2D[[0]])
    to_mirror[rendered_dir2D[[0]] < 0] *= -1
    # The clamp keeps acos away from +-1, where its derivative is unbounded.
    rendered_orient_angle = torch.acos(rendered_dir2D[[1]].clamp(-1 + 1e-3, 1 - 1e-3) * to_mirror) / math.pi

    return {"render": rendered_image,
            "mask": rendered_mask,
            "orient_angle": rendered_orient_angle,
            "orient_conf": rendered_orient_conf,
            "viewspace_points": screenspace_points,
            "means3D": means3D,
            "colors_precomp": colors_precomp,
            "opacity": opacity,
            "scales": scales,
            "rotations": rotations,
            "visibility_filter" : radii > 0,
            "radii": radii}
    
def render_hair_weight_fine(viewpoint_camera, pc : GaussianModel, pc_hair: GaussianModelHair, pipe, bg_color : torch.Tensor, render_state: str, scaling_modifier = 1.0):
    """
    Render the head Gaussians (pc) and the hair Gaussians (pc_hair) together.

    Background tensor (bg_color) must be on GPU!

    The head attributes for ``viewpoint_camera.time_step`` are evaluated under
    ``torch.no_grad()``, detached, and stored on ``pc`` as ``*_precomp`` /
    ``shs_view`` attributes (side effect). Hair attributes are read from
    ``pc_hair``. Each Gaussian carries a 10-channel feature vector
    (rgb, label, ones, 2D direction, orientation confidence, depth) that is
    rasterized and then divided by the accumulated alpha.

    If ``pc_hair.idx_active_mask`` is set, the Gaussians selected by the mask
    are rendered with gradients. The remaining ones are either taken from the
    per-camera cache ``pc_hair.wo_active_set_data[camera_index]`` or rendered
    under ``torch.no_grad()``. Both passes are summed and normalised by the
    summed accumulated alpha. Otherwise all Gaussians are rendered in one pass.

    Args:
        viewpoint_camera: camera to render from; moved to "cuda" via
            ``set_device`` as a side effect.
        pc: head model, evaluated at ``viewpoint_camera.time_step``.
        pc_hair: hair model.
        pipe: pipeline options; only ``pipe.debug`` is read here.
        bg_color: background colour handed to the rasterizer.
        render_state: ``"fine"`` takes hair SH features and orientation
            confidence from ``pc_hair.set_deformation``; any other value uses
            the features stored on ``pc_hair``.
        scaling_modifier: forwarded to the rasterizer as ``scale_modifier``.

    Returns:
        dict with the rendered image, mask, orientation angle and confidence,
        the tensors fed to the rasterizer (kept so callers can read their
        gradients), the visibility filter (``radii > 0``) and the radii.
    """
    # to device
    viewpoint_camera.set_device("cuda")

    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc_hair.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=pipe.debug
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    # Head attributes are treated as constants here: computed without
    # gradients and detached before being cached on the head model.
    with torch.no_grad():
        pc.mask_precomp = pc.get_label(viewpoint_camera.time_step)[..., 0] < 0.6
        pc.points_mask_head_indices = pc.mask_precomp.nonzero(as_tuple=True)[0]
        pc.xyz_precomp = pc.get_xyz(viewpoint_camera.time_step).detach()
        pc.opacity_precomp = pc.get_opacity(viewpoint_camera.time_step).detach()
        pc.scaling_precomp = pc.get_scaling(viewpoint_camera.time_step).detach()
        pc.rotation_precomp = pc.get_rotation(viewpoint_camera.time_step).detach()
        pc.cov3D_precomp = pc.get_covariance(scaling_modifier = 1.0,time_index = viewpoint_camera.time_step).detach()
        pc.shs_view = pc.get_features(viewpoint_camera.time_step).detach().transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1)**2)

    # Hair appearance: time-dependent deformation in the "fine" state,
    # otherwise the features stored on the hair model.
    if render_state == "fine":
        time_step  = viewpoint_camera.time_step
        num_time_steps = viewpoint_camera.num_time_steps
        if pc_hair.idx_active_mask is not None:
            # idx_active_mask covers head + hair; the hair part starts after
            # the head entries.
            nums_gaussian_head = pc.mask_precomp.shape[0]
            points_mask_active_hair = (pc_hair.idx_active_mask)[nums_gaussian_head:]
            points_mask_wo_active_hair = (~pc_hair.idx_active_mask)[nums_gaussian_head:]
            points_mask_active_hair_indices = points_mask_active_hair.nonzero(as_tuple=True)[0]
            points_mask_wo_active_hair_indices = points_mask_wo_active_hair.nonzero(as_tuple=True)[0]
            shs_hair_final, orient_conf_hair_final = pc_hair.set_deformation(time_step, num_time_steps, points_mask_active_hair_indices,points_mask_wo_active_hair_indices)
        else:
            shs_hair_final, orient_conf_hair_final = pc_hair.set_deformation(time_step, num_time_steps)
    else:
        shs_hair_final = pc_hair.get_features.transpose(1, 2).view(-1, 3, (pc_hair.max_sh_degree+1)**2)
        orient_conf_hair_final = pc_hair.get_orient_conf

    # Zero tensor used only to receive the gradients of the 2D (screen-space) means.
    screenspace_points = torch.zeros_like(torch.cat([pc.xyz_precomp, pc_hair.get_xyz], dim=0),requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # retain_grad() raises when the tensor does not require grad (e.g. the
        # caller runs under torch.no_grad()); rendering itself still works.
        pass

    # Concatenate head (first) and hair (second) attributes.
    means3D = torch.cat([pc.xyz_precomp, pc_hair.get_xyz])
    means2D_precomp = screenspace_points
    opacity = torch.cat([pc.opacity_precomp, pc_hair.get_opacity])
    scales = torch.cat([pc.scaling_precomp, pc_hair.get_scaling])
    rotations = torch.cat([pc.rotation_precomp, pc_hair.get_rotation])

    # View-dependent colour from spherical harmonics.
    shs_view = torch.cat([pc.shs_view, shs_hair_final])
    dir_pp = (means3D - viewpoint_camera.camera_center[None,:])
    dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
    sh2rgb = eval_sh(pc_hair.active_sh_degree, shs_view, dir_pp_normalized)
    rgb_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)

    # Head Gaussians get zero label / direction / orientation confidence.
    label_precomp = torch.cat([torch.zeros_like(pc.xyz_precomp[:, :1]), pc_hair.get_label])
    cov2D = torch.cat([torch.zeros_like(pc.xyz_precomp), pc_hair.get_direction_2d(viewpoint_camera)])
    orient_conf = torch.cat([torch.zeros_like(pc.xyz_precomp[:, :1]), orient_conf_hair_final])
    depth = torch.cat([pc.get_depths(viewpoint_camera,viewpoint_camera.time_step), pc_hair.get_depths(viewpoint_camera)])
    colors_precomp = torch.cat([rgb_precomp, label_precomp, torch.ones_like(label_precomp), cov2D, orient_conf, depth], dim=-1)

    # Modulate the opacity with a weight predicted from view direction and depth.
    gaussian_weight_input = torch.cat([dir_pp_normalized, depth],dim=-1)
    opacity = pc_hair.gaussianWeight(gaussian_weight_input) * opacity

    # Keep gradients on the rasterizer inputs so callers can inspect them.
    means3D.requires_grad_(True)
    means3D.retain_grad()
    colors_precomp.requires_grad_(True)
    colors_precomp.retain_grad()
    opacity.requires_grad_(True)
    opacity.retain_grad()
    scales.requires_grad_(True)
    scales.retain_grad()
    rotations.requires_grad_(True)
    rotations.retain_grad()

    if pc_hair.idx_active_mask is not None:
        radii = torch.zeros_like(means3D[:, 0]).int()

        # Active subset: rendered with gradients.
        points_mask_active = pc_hair.idx_active_mask
        active_idx = torch.nonzero(points_mask_active, as_tuple=True)[0]
        means3D_active = means3D.index_select(0,active_idx)
        means2D_precomp_active = means2D_precomp.index_select(0,active_idx)
        colors_precomp_active = colors_precomp.index_select(0,active_idx)
        opacity_active = opacity.index_select(0,active_idx)
        scales_active = scales.index_select(0,active_idx)
        rotations_active = rotations.index_select(0,active_idx)
        renders_active, accum_alpha_active, _radii_active = rasterizer(
            means3D = means3D_active,
            means2D = means2D_precomp_active,
            shs = None,
            colors_precomp = colors_precomp_active,
            opacities = opacity_active,
            scales = scales_active,
            rotations = rotations_active,
            cov3D_precomp = None,
            conic_precomp = None)

        # Inactive subset: reuse the per-camera cache when present, otherwise
        # render it here; never differentiated.
        points_mask_wo_active = (~pc_hair.idx_active_mask)
        render_wo_active_data = pc_hair.wo_active_set_data[viewpoint_camera.camera_index]
        with torch.no_grad():
            if render_wo_active_data is not None:
                points_mask_wo_active = render_wo_active_data["points_mask_wo_active"]
                renders_wo_active = render_wo_active_data["renders_wo_active"]
                accum_alpha_wo_active = render_wo_active_data["accum_alpha_wo_active"]
                _radii_wo_active = render_wo_active_data["_radii_wo_active"]
                wo_active_idx = torch.nonzero(points_mask_wo_active, as_tuple=True)[0]
            else:
                wo_active_idx = torch.nonzero(points_mask_wo_active, as_tuple=True)[0]
                means3D_wo_active = means3D.index_select(0,wo_active_idx)
                means2D_precomp_wo_active = means2D_precomp.index_select(0,wo_active_idx)
                colors_precomp_wo_active = colors_precomp.index_select(0,wo_active_idx)
                opacity_wo_active = opacity.index_select(0,wo_active_idx)
                scales_wo_active = scales.index_select(0,wo_active_idx)
                rotations_wo_active = rotations.index_select(0,wo_active_idx)
                renders_wo_active, accum_alpha_wo_active, _radii_wo_active = rasterizer(
                    means3D = means3D_wo_active,
                    means2D = means2D_precomp_wo_active,
                    shs = None,
                    colors_precomp = colors_precomp_wo_active,
                    opacities = opacity_wo_active,
                    scales = scales_wo_active,
                    rotations = rotations_wo_active,
                    cov3D_precomp = None,
                    conic_precomp = None)

        # Scatter the per-subset radii back and merge both passes.
        radii.index_copy_(0, active_idx, _radii_active)
        radii.index_copy_(0, wo_active_idx, _radii_wo_active)
        inv = (accum_alpha_active + accum_alpha_wo_active + 1e-8).reciprocal()
        renders = (renders_active + renders_wo_active).mul_(inv)
    else:
        renders, accum_alpha_active, _radii_active = rasterizer(
            means3D = means3D,
            means2D = means2D_precomp,
            shs = None,
            colors_precomp = colors_precomp,
            opacities = opacity,
            scales = scales,
            rotations = rotations,
            cov3D_precomp = None,
            conic_precomp = None)
        # Normalise by the accumulated alpha (epsilon avoids division by zero).
        inv_active = (accum_alpha_active + 1e-8).reciprocal()
        renders.mul_(inv_active)
        radii = _radii_active

    # Channel layout: rgb (3), mask (2), 2D direction (3), orientation
    # confidence (1), depth (1, unused here).
    rendered_image, rendered_mask, rendered_cov2D, rendered_orient_conf, _ = renders.split([3, 2, 3, 1, 1], dim=0)
    rendered_dir2D = F.normalize(rendered_cov2D[:2], dim=0)
    # Flip the sign where the first direction component is negative, then map
    # the second component to an angle scaled by 1/pi.
    to_mirror = torch.ones_like(rendered_dir2D[[0]])
    to_mirror[rendered_dir2D[[0]] < 0] *= -1
    rendered_orient_angle = torch.acos(rendered_dir2D[[1]].clamp(-1 + 1e-3, 1 - 1e-3) * to_mirror) / math.pi

    return {"render": rendered_image,
            "mask": rendered_mask,
            "orient_angle": rendered_orient_angle,
            "orient_conf": rendered_orient_conf,
            "viewspace_points": screenspace_points,
            "means3D": means3D,
            "colors_precomp": colors_precomp,
            "opacity": opacity,
            "scales": scales,
            "rotations": rotations,
            "visibility_filter" : radii > 0,
            "radii": radii}
def render_hair_weight_fine_test(viewpoint_camera, pc : GaussianModel, pc_hair: GaussianModelHair, pipe, bg_color : torch.Tensor, render_state: str, scaling_modifier = 1.0):
    """
    Interactive finite-difference check of the rasterizer gradient w.r.t. means3D.

    Background tensor (bg_color) must be on GPU!

    Builds the same head + hair inputs as the regular renderers, renders once,
    backpropagates ``renders.mean()`` and then repeatedly perturbs the x
    coordinate of one randomly chosen Gaussian by ``eps``. The numerical
    derivative is printed next to the analytic gradient and an assertion fails
    when the two differ by more than 1e-2 relative to ``renders.mean()``.

    This is a debugging harness: it stops in ``ipdb`` before the check and the
    check loop only ends through a failed assertion, so the function never
    returns normally. The code after the loop is unreachable and is kept only
    so the loop can be disabled by hand.

    Args:
        viewpoint_camera: camera to render from.
        pc: head model, evaluated at ``viewpoint_camera.time_step``; its
            ``*_precomp`` / ``shs_view`` attributes are overwritten.
        pc_hair: hair model.
        pipe: pipeline options; only ``pipe.debug`` is read here.
        bg_color: background colour handed to the rasterizer.
        render_state: ``"fine"`` takes hair features from
            ``pc_hair.set_deformation``; any other value uses the stored ones.
        scaling_modifier: forwarded to the rasterizer as ``scale_modifier``.
    """
    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc_hair.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=True,
        debug=pipe.debug
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    # Head attributes are constants for this check: no gradients, detached.
    with torch.no_grad():
        pc.mask_precomp = pc.get_label(viewpoint_camera.time_step)[..., 0] < 0.6
        pc.points_mask_head_indices = pc.mask_precomp.nonzero(as_tuple=True)[0]
        pc.xyz_precomp = pc.get_xyz(viewpoint_camera.time_step).detach()
        pc.opacity_precomp = pc.get_opacity(viewpoint_camera.time_step).detach()
        pc.scaling_precomp = pc.get_scaling(viewpoint_camera.time_step).detach()
        pc.rotation_precomp = pc.get_rotation(viewpoint_camera.time_step).detach()
        pc.cov3D_precomp = pc.get_covariance(scaling_modifier = 1.0,time_index = viewpoint_camera.time_step).detach()
        pc.shs_view = pc.get_features(viewpoint_camera.time_step).detach().transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1)**2)

    if render_state == "fine":
        time_step  = viewpoint_camera.time_step
        num_time_steps = viewpoint_camera.num_time_steps
        shs_hair_final, orient_conf_hair_final = pc_hair.set_deformation(time_step, num_time_steps)
    else:
        shs_hair_final = pc_hair.get_features.transpose(1, 2).view(-1, 3, (pc_hair.max_sh_degree+1)**2)
        orient_conf_hair_final = pc_hair.get_orient_conf

    screenspace_points = torch.cat([pc.get_mean_2d(viewpoint_camera,viewpoint_camera.time_step).detach(), pc_hair.get_mean_2d(viewpoint_camera)], dim=0)
    screenspace_points.requires_grad_(True)
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # retain_grad() raises when the tensor does not require grad.
        pass

    # Concatenate head (first) and hair (second) attributes.
    means3D = torch.cat([pc.xyz_precomp, pc_hair.get_xyz])
    means2D_precomp = screenspace_points
    opacity = torch.cat([pc.opacity_precomp, pc_hair.get_opacity])
    scales = torch.cat([pc.scaling_precomp, pc_hair.get_scaling])
    rotations = torch.cat([pc.rotation_precomp, pc_hair.get_rotation])
    shs_view = torch.cat([pc.shs_view, shs_hair_final])
    dir_pp = (means3D - viewpoint_camera.camera_center[None,:])
    dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
    sh2rgb = eval_sh(pc_hair.active_sh_degree, shs_view, dir_pp_normalized)
    rgb_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
    label_precomp = torch.cat([torch.zeros_like(pc.xyz_precomp[:, :1]), pc_hair.get_label])
    cov2D = torch.cat([torch.zeros_like(pc.xyz_precomp), pc_hair.get_direction_2d(viewpoint_camera)])
    orient_conf = torch.cat([torch.zeros_like(pc.xyz_precomp[:, :1]), orient_conf_hair_final])
    depth = torch.cat([pc.get_depths(viewpoint_camera,viewpoint_camera.time_step), pc_hair.get_depths(viewpoint_camera)])
    colors_precomp = torch.cat([rgb_precomp, label_precomp, torch.ones_like(label_precomp), cov2D, orient_conf, depth], dim=-1)
    gaussian_weight_input = torch.cat([dir_pp_normalized, depth],dim=-1)
    opacity = pc_hair.gaussianWeight(gaussian_weight_input) * opacity
    means3D.requires_grad_(True)
    means3D.retain_grad()

    # Reference render, normalised by the accumulated alpha.
    renders, accum_alpha_active, _radii_active = rasterizer(
        means3D = means3D,
        means2D = means2D_precomp,
        shs = None,
        colors_precomp = colors_precomp,
        opacities = opacity,
        scales = scales,
        rotations = rotations,
        cov3D_precomp = None,
        conic_precomp = None)
    inv_active = (accum_alpha_active + 1e-8).reciprocal()
    renders.mul_(inv_active)
    radii = _radii_active

    import random
    eps =  1e-4
    import ipdb;ipdb.set_trace()
    # Analytic gradient of the mean rendered value w.r.t. the 3D means.
    renders.mean().backward()
    grad_p = means3D.grad.clone()
    while True:
        random.seed(time.time())
        # randrange excludes the upper bound; randint(0, N) could return N and
        # index one past the last Gaussian.
        i = random.randrange(means3D.shape[0])
        # Perturb through .data so autograd does not record the edit.
        means3D.data[i,0] += eps
        renders1, accum_alpha_active1, _radii_active1 = rasterizer(
            means3D = means3D,
            means2D = means2D_precomp,
            shs = None,
            colors_precomp = colors_precomp,
            opacities = opacity,
            scales = scales,
            rotations = rotations,
            cov3D_precomp = None,
            conic_precomp = None)
        inv_active1 = (accum_alpha_active1 + 1e-8).reciprocal()
        renders1.mul_(inv_active1)
        # Forward finite difference.
        pred = (renders1.mean() - renders.mean())/eps
        print("index:                      ",i)
        print("renders1_final.mean():      ",renders1.mean().item())
        print("renders_final.mean():       ",renders.mean().item())
        print("pred:                       ",pred.item())
        print("grad_n[i]:                  ",grad_p[i,0].item())
        # Undo the perturbation before the next sample.
        means3D.data[i,0] -= eps
        rel_err = abs(pred.item() - grad_p[i,0].item())/renders.mean().item()
        print(rel_err)
        assert rel_err < 1e-2
    # NOTE: everything below is unreachable while the loop above is in place.
    assert False


    radii = _radii_active
    rendered_image, rendered_mask, rendered_cov2D, rendered_orient_conf, _ = renders.split([3, 2, 3, 1, 1], dim=0)
    rendered_dir2D = F.normalize(rendered_cov2D[:2], dim=0)
    to_mirror = torch.ones_like(rendered_dir2D[[0]])
    to_mirror[rendered_dir2D[[0]] < 0] *= -1
    rendered_orient_angle = torch.acos(rendered_dir2D[[1]].clamp(-1 + 1e-3, 1 - 1e-3) * to_mirror) / math.pi

    return {"render": rendered_image,
            "renders": renders,
            "mask": rendered_mask,
            "orient_angle": rendered_orient_angle,
            "orient_conf": rendered_orient_conf,
            "viewspace_points": screenspace_points,
            "means3D": means3D,
            "colors_precomp": colors_precomp,
            "opacity": opacity,
            "scales": scales,
            "rotations": rotations,
            "visibility_filter" : radii > 0,
            "radii": radii}
def render_hair_weight_sparse(viewpoint_camera, pc : GaussianModel, pc_hair: GaussianModelHair, pipe, bg_color : torch.Tensor, render_state: str, scaling_modifier = 1.0):
    """
    Render only the active Gaussians and merge them with a cached inactive render.

    Background tensor (bg_color) must be on GPU!

    The active selection is published on the models through
    ``pc.points_mask_head_indices`` and ``pc_hair.points_mask_hair_indices``
    while the inputs are gathered, and both attributes are reset to ``None``
    before returning. NOTE(review): the model getters presumably honour these
    indices so that only active Gaussians are returned -- confirm in the model
    classes.

    The inactive part is not rendered here: it is read from
    ``pc_hair.wo_active_set_data[viewpoint_camera.camera_index]``, which must
    already hold the keys ``points_mask_wo_active``, ``renders_wo_active``,
    ``accum_alpha_wo_active`` and ``_radii_wo_active``. Both renders are summed
    and divided by the summed accumulated alpha.

    Args:
        viewpoint_camera: camera to render from.
        pc: head model, evaluated at ``viewpoint_camera.time_step``; its
            ``*_precomp`` / ``shs_view`` attributes are overwritten.
        pc_hair: hair model holding the active masks and the inactive cache.
        pipe: pipeline options; only ``pipe.debug`` is read here.
        bg_color: background colour handed to the rasterizer.
        render_state: ``"fine"`` takes hair features from
            ``pc_hair.set_deformation``; any other value uses the stored ones.
        scaling_modifier: forwarded to the rasterizer as ``scale_modifier``.

    Returns:
        dict with the rendered image, mask, orientation angle and confidence,
        the tensors fed to the rasterizer, the visibility filter
        (``radii > 0``) and the radii.
    """
    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc_hair.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=True,
        debug=pipe.debug
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    with torch.no_grad():
        pc.mask_precomp = pc.get_label(viewpoint_camera.time_step)[..., 0] < 0.6
        nums_gaussian_head = pc.mask_precomp.shape[0]
        if pc_hair.idx_hair_active_mask is not None:
            # A hair-only mask means no head Gaussian is active.
            points_mask_active_hair = pc_hair.idx_hair_active_mask
            points_mask_active_head = torch.zeros(nums_gaussian_head).bool().cuda()
        else:
            # idx_active_mask covers head entries first, then hair entries.
            points_mask_active_head = pc_hair.idx_active_mask[:nums_gaussian_head]
            points_mask_active_hair = pc_hair.idx_active_mask[nums_gaussian_head:]
        # Prefer the indices cached on the hair model; otherwise derive them
        # from the mask. This value must not be overwritten afterwards, or the
        # fallback is lost when the cached attribute is None.
        if pc_hair.points_mask_active_hair_indices is not None:
            points_mask_active_hair_indices = pc_hair.points_mask_active_hair_indices
        else:
            points_mask_active_hair_indices = points_mask_active_hair.nonzero(as_tuple=True)[0]
        pc.mask_precomp = pc.mask_precomp & points_mask_active_head
        pc.points_mask_head_indices = pc.mask_precomp.nonzero(as_tuple=True)[0]
        pc.xyz_precomp = pc.get_xyz(viewpoint_camera.time_step).detach()
        pc.opacity_precomp = pc.get_opacity(viewpoint_camera.time_step).detach()
        pc.scaling_precomp = pc.get_scaling(viewpoint_camera.time_step).detach()
        pc.rotation_precomp = pc.get_rotation(viewpoint_camera.time_step).detach()
        pc.cov3D_precomp = pc.get_covariance(scaling_modifier = 1.0,time_index = viewpoint_camera.time_step).detach()
        pc.shs_view = pc.get_features(viewpoint_camera.time_step).detach().transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1)**2)

    if render_state == "fine":
        time_step  = viewpoint_camera.time_step
        num_time_steps = viewpoint_camera.num_time_steps
        shs_hair_final, orient_conf_hair_final = pc_hair.set_deformation(time_step, num_time_steps)
    else:
        shs_hair_final = pc_hair.get_features.transpose(1, 2).view(-1, 3, (pc_hair.max_sh_degree+1)**2)
        orient_conf_hair_final = pc_hair.get_orient_conf
    # Keep only the active hair entries.
    shs_hair_final = shs_hair_final.index_select(0, points_mask_active_hair_indices)
    orient_conf_hair_final = orient_conf_hair_final.index_select(0, points_mask_active_hair_indices)
    # Publish the selection on the hair model for the getters used below.
    pc_hair.points_mask_hair_indices = points_mask_active_hair_indices

    screenspace_points = torch.cat([pc.get_mean_2d(viewpoint_camera,viewpoint_camera.time_step).detach(), pc_hair.get_mean_2d(viewpoint_camera)], dim=0)
    screenspace_points.requires_grad_(True)
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # retain_grad() raises when the tensor does not require grad.
        pass

    # Concatenate head (first) and hair (second) attributes.
    means3D = torch.cat([pc.xyz_precomp, pc_hair.get_xyz])
    means2D_precomp = screenspace_points
    opacity = torch.cat([pc.opacity_precomp, pc_hair.get_opacity])

    scales = torch.cat([pc.scaling_precomp, pc_hair.get_scaling])
    rotations = torch.cat([pc.rotation_precomp, pc_hair.get_rotation])

    # View-dependent colour from spherical harmonics.
    shs_view = torch.cat([pc.shs_view, shs_hair_final])
    dir_pp = (means3D - viewpoint_camera.camera_center[None,:])
    dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
    sh2rgb = eval_sh(pc_hair.active_sh_degree, shs_view, dir_pp_normalized)
    rgb_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)

    # Head Gaussians get zero label / direction / orientation confidence.
    label_precomp = torch.cat([torch.zeros_like(pc.xyz_precomp[:, :1]), pc_hair.get_label])
    cov2D = torch.cat([torch.zeros_like(pc.xyz_precomp), pc_hair.get_direction_2d(viewpoint_camera)])
    orient_conf = torch.cat([torch.zeros_like(pc.xyz_precomp[:, :1]), orient_conf_hair_final])
    depth = torch.cat([pc.get_depths(viewpoint_camera,viewpoint_camera.time_step), pc_hair.get_depths(viewpoint_camera)])

    colors_precomp = torch.cat([rgb_precomp, label_precomp, torch.ones_like(label_precomp), cov2D, orient_conf, depth], dim=-1)

    # Modulate the opacity with a weight predicted from view direction and depth.
    gaussian_weight_input = torch.cat([dir_pp_normalized, depth],dim=-1)
    opacity = pc_hair.gaussianWeight(gaussian_weight_input) * opacity

    # Radii are reported for the full (head + hair) set; the active entries
    # are scattered back to their global positions below.
    points_mask_active_global = torch.zeros_like(pc_hair.idx_active_mask, dtype=torch.bool)
    radii = torch.zeros_like(pc_hair.idx_active_mask).int()
    points_mask_active_global[pc_hair.idx_active_mask] = True

    active_idx_global = torch.nonzero(points_mask_active_global, as_tuple=True)[0]

    renders_active, accum_alpha_active, _radii_active = rasterizer(
        means3D = means3D,
        means2D = means2D_precomp,
        shs = None,
        colors_precomp = colors_precomp,
        opacities = opacity,
        scales = scales,
        rotations = rotations,
        cov3D_precomp = None,
        conic_precomp = None)

    # Inactive part comes from the per-camera cache.
    render_wo_active_data = pc_hair.wo_active_set_data[viewpoint_camera.camera_index]
    with torch.no_grad():
        points_mask_wo_active = render_wo_active_data["points_mask_wo_active"]
        renders_wo_active = render_wo_active_data["renders_wo_active"]
        accum_alpha_wo_active = render_wo_active_data["accum_alpha_wo_active"]
        _radii_wo_active = render_wo_active_data["_radii_wo_active"]
        wo_active_idx = torch.nonzero(points_mask_wo_active, as_tuple=True)[0]

    # Merge both passes and normalise by the summed accumulated alpha.
    radii.index_copy_(0, active_idx_global, _radii_active)
    radii.index_copy_(0, wo_active_idx, _radii_wo_active)
    inv = (accum_alpha_active + accum_alpha_wo_active + 1e-8).reciprocal()
    renders = (renders_active + renders_wo_active).mul_(inv)

    # Channel layout: rgb (3), mask (2), 2D direction (3), orientation
    # confidence (1), depth (1, unused here).
    rendered_image, rendered_mask, rendered_cov2D, rendered_orient_conf, _ = renders.split([3, 2, 3, 1, 1], dim=0)
    rendered_dir2D = F.normalize(rendered_cov2D[:2], dim=0)
    to_mirror = torch.ones_like(rendered_dir2D[[0]])
    to_mirror[rendered_dir2D[[0]] < 0] *= -1
    rendered_orient_angle = torch.acos(rendered_dir2D[[1]].clamp(-1 + 1e-3, 1 - 1e-3) * to_mirror) / math.pi

    # Clear the temporary selection published on the models.
    pc_hair.points_mask_hair_indices = None
    pc.points_mask_head_indices = None

    return {"render": rendered_image,
            "mask": rendered_mask,
            "orient_angle": rendered_orient_angle,
            "orient_conf": rendered_orient_conf,
            "viewspace_points": screenspace_points,
            "means3D": means3D,
            "colors_precomp": colors_precomp,
            "opacity": opacity,
            "scales": scales,
            "rotations": rotations,
            "visibility_filter" : radii > 0,
            "radii": radii}
def render_hair_weight_wo_active(viewpoint_camera, pc : GaussianModel, pc_hair: GaussianModelHair, pipe, bg_color : torch.Tensor, render_state: str, scaling_modifier = 1.0):
    """
    Rasterize the head Gaussians together with the non-active hair Gaussians.

    Background tensor (bg_color) must be on GPU!

    Head quantities are evaluated at ``viewpoint_camera.time_step`` under
    ``torch.no_grad()`` and detached; hair quantities are read from ``pc_hair``
    while ``pc_hair.points_mask_hair_indices`` holds the selected hair indices.
    NOTE(review): presumably the ``pc`` / ``pc_hair`` getters honour
    ``points_mask_head_indices`` / ``points_mask_hair_indices`` to return only
    the selected subset - confirm against the model classes.

    Args:
        viewpoint_camera: Camera supplying FoVx/FoVy, image size, view and
            projection transforms, ``camera_center``, ``time_step`` and, for
            ``render_state == "fine"``, ``num_time_steps``.
        pc: Head model. Its ``mask_precomp``, ``xyz_precomp``,
            ``opacity_precomp``, ``scaling_precomp``, ``rotation_precomp``,
            ``cov3D_precomp`` and ``shs_view`` attributes are overwritten.
        pc_hair: Hair model. ``idx_active_mask`` must be a tensor.
        pipe: Pipeline options; only ``pipe.debug`` is read here.
        bg_color: Background colour handed to the rasterizer settings.
        render_state: ``"fine"`` takes SH features and orientation confidence
            from ``pc_hair.set_deformation``; any other value uses
            ``pc_hair.get_features`` / ``pc_hair.get_orient_conf``.
        scaling_modifier: Forwarded to the rasterizer settings. The head
            covariance is always computed with a modifier of 1.0.

    Returns:
        dict with keys ``"points_mask_wo_active"`` (bool mask shaped like
        ``pc_hair.idx_active_mask``, True where that mask is not set),
        ``"renders_wo_active"``, ``"accum_alpha_wo_active"`` and
        ``"_radii_wo_active"`` (the three rasterizer outputs).

    Raises:
        TypeError: if ``pc_hair.idx_active_mask`` is None.
    """
    # The active mask is dereferenced unconditionally further down, so a
    # missing mask could never produce a render. Fail up front, before any
    # model state is touched; TypeError is what the tensor ops raised before.
    if pc_hair.idx_active_mask is None:
        raise TypeError("render_hair_weight_wo_active requires pc_hair.idx_active_mask to be a tensor, got None")

    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc_hair.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=True,
        debug=pipe.debug
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    try:
        with torch.no_grad():
            # Head Gaussians whose first label channel is below 0.6 are kept.
            pc.mask_precomp = pc.get_label(viewpoint_camera.time_step)[..., 0] < 0.6
            nums_gaussian_head = pc.mask_precomp.shape[0]
            if pc_hair.idx_hair_active_mask is not None:
                # Hair-only active mask: every head Gaussian counts as non-active.
                points_mask_active_hair = ~pc_hair.idx_hair_active_mask
                # Allocate directly as bool on the device of the mask it is
                # AND-ed with, instead of a CPU float tensor that is cast and copied.
                points_mask_active_head = torch.ones(
                    nums_gaussian_head, dtype=torch.bool, device=pc.mask_precomp.device
                )
            else:
                # Joint mask laid out as [head..., hair...].
                not_active = ~pc_hair.idx_active_mask
                points_mask_active_head = not_active[:nums_gaussian_head]
                points_mask_active_hair = not_active[nums_gaussian_head:]

            if pc_hair.points_mask_wo_active_hair_indices is not None:
                points_mask_wo_active_hair_indices = pc_hair.points_mask_wo_active_hair_indices
            else:
                print("wo_active_set_data is None")
                points_mask_wo_active_hair_indices = points_mask_active_hair.nonzero(as_tuple=True)[0]

            pc.mask_precomp = pc.mask_precomp & points_mask_active_head
            # mask_precomp already contains the AND with the head mask.
            pc.points_mask_head_indices = pc.mask_precomp.nonzero(as_tuple=True)[0]
            pc.xyz_precomp = pc.get_xyz(viewpoint_camera.time_step).detach()
            pc.opacity_precomp = pc.get_opacity(viewpoint_camera.time_step).detach()
            pc.scaling_precomp = pc.get_scaling(viewpoint_camera.time_step).detach()
            pc.rotation_precomp = pc.get_rotation(viewpoint_camera.time_step).detach()
            pc.cov3D_precomp = pc.get_covariance(scaling_modifier = 1.0,time_index = viewpoint_camera.time_step).detach()
            pc.shs_view = pc.get_features(viewpoint_camera.time_step).detach().transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1)**2)

        # Hair appearance is fetched before the hair index selection is
        # installed on pc_hair, then narrowed explicitly with index_select.
        if render_state == "fine":
            time_step = viewpoint_camera.time_step
            num_time_steps = viewpoint_camera.num_time_steps
            shs_hair_final, orient_conf_hair_final = pc_hair.set_deformation(time_step, num_time_steps)
        else:
            shs_hair_final = pc_hair.get_features.transpose(1, 2).view(-1, 3, (pc_hair.max_sh_degree+1)**2)
            orient_conf_hair_final = pc_hair.get_orient_conf
        shs_hair_final = shs_hair_final.index_select(0, points_mask_wo_active_hair_indices)
        orient_conf_hair_final = orient_conf_hair_final.index_select(0, points_mask_wo_active_hair_indices)
        pc_hair.points_mask_hair_indices = points_mask_wo_active_hair_indices

        screenspace_points = torch.cat([pc.get_mean_2d(viewpoint_camera,viewpoint_camera.time_step).detach(), pc_hair.get_mean_2d(viewpoint_camera)], dim=0)
        screenspace_points.requires_grad_(True)

        try:
            screenspace_points.retain_grad()
        except RuntimeError:
            # Best effort only: autograd may refuse to retain the gradient.
            pass
        means2D_precomp = screenspace_points

        # Concatenate head (first) and hair (second) per-Gaussian attributes.
        means3D = torch.cat([pc.xyz_precomp, pc_hair.get_xyz])
        opacity = torch.cat([pc.opacity_precomp, pc_hair.get_opacity])

        scales = torch.cat([pc.scaling_precomp, pc_hair.get_scaling])
        rotations = torch.cat([pc.rotation_precomp, pc_hair.get_rotation])

        shs_view = torch.cat([pc.shs_view, shs_hair_final])
        dir_pp = (means3D - viewpoint_camera.camera_center.repeat(shs_view.shape[0], 1))
        dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
        sh2rgb = eval_sh(pc_hair.active_sh_degree, shs_view, dir_pp_normalized)
        rgb_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)

        # Head Gaussians get zeros for label, 2D direction and orientation confidence.
        head_zeros_1 = torch.zeros_like(pc.xyz_precomp[:, :1])
        label_precomp = torch.cat([head_zeros_1, pc_hair.get_label])
        cov2D = torch.cat([torch.zeros_like(pc.xyz_precomp), pc_hair.get_direction_2d(viewpoint_camera)])
        orient_conf = torch.cat([head_zeros_1, orient_conf_hair_final])
        depth = torch.cat([pc.get_depths(viewpoint_camera,viewpoint_camera.time_step), pc_hair.get_depths(viewpoint_camera)])
        # Channel order: rgb, label, ones, 2D direction, orientation confidence, depth.
        colors_precomp = torch.cat([rgb_precomp, label_precomp, torch.ones_like(label_precomp), cov2D, orient_conf, depth], dim=-1)

        # Scale opacity by the weight predicted from view direction and depth.
        gaussian_weight_input = torch.cat([dir_pp_normalized, depth],dim=-1)
        opacity = pc_hair.gaussianWeight(gaussian_weight_input) * opacity

        # Mask over all points, True where idx_active_mask is not set.
        points_mask_wo_active_global = torch.zeros_like(pc_hair.idx_active_mask, dtype=torch.bool)
        points_mask_wo_active_global[~pc_hair.idx_active_mask] = True

        renders_wo_active, accum_alpha_wo_active, _radii_wo_active = rasterizer(
            means3D = means3D,
            means2D = means2D_precomp,
            shs = None,
            colors_precomp = colors_precomp,
            opacities = opacity,
            scales = scales,
            rotations = rotations,
            cov3D_precomp = None,
            conic_precomp = None)
    finally:
        # Always clear the temporary index selections so a failed render
        # cannot leave them set on the models for the next call.
        pc.points_mask_head_indices = None
        pc_hair.points_mask_hair_indices = None

    return {"points_mask_wo_active": points_mask_wo_active_global,
            "renders_wo_active": renders_wo_active,
            "accum_alpha_wo_active": accum_alpha_wo_active,
            "_radii_wo_active" : _radii_wo_active}
