#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import torch
import numpy as np
from utils.general_utils import inverse_sigmoid,build_rotation
from torch import nn
import torch.nn.functional as F
import os
from utils.general_utils import strip_symmetric, get_expon_lr_func, build_scaling_rotation, parallel_transport
from utils.system_utils import mkdir_p
# from gaussian_renderer import render_hair, network_gui, render_hair_torch, render_hair_torch_partial, render_hair_torch_fine, render_hair_weight, render_hair_weight_fine, render_hair_weight_wo_active
from plyfile import PlyData, PlyElement
import math
import pickle as pkl
import sys
import trimesh
from scene.deformation import deform_network
from scene.regulation import compute_plane_smoothness
from scene.gaussian_weight import GaussianWeightPred,gaussian_weight_pred_pe
from scene.gaussian_render import GaussRenderer
from pysdf import SDF
sys.path.append('../ext/NeuralHaircut/')
sys.path.append('../ext/NeuralHaircut/k-diffusion')
from src.hair_networks.optimizable_textured_strands import OptimizableTexturedStrands
from src.hair_networks.strand_prior import Decoder, Encoder
from torch.func import vjp
from torch.autograd import Function
from simple_knn._C import distCUDA2


class GaussianModelCurves:
    def setup_functions(self):
        """Install the activation / inverse-activation callables on ``self``.

        Binds the scaling, covariance, opacity, label, rotation and
        orientation-confidence activations that the accessor methods and
        ``get_covariance`` call through ``self``.
        """
        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation, return_full_covariance=False):
            # Factor M built from the (modified) scales and the rotation;
            # the covariance is M^T M.
            factor = build_scaling_rotation(scaling_modifier * scaling, rotation)
            covariance = factor.transpose(1, 2) @ factor
            if not return_full_covariance:
                # Reduce the symmetric matrix with strip_symmetric.
                covariance = strip_symmetric(covariance)
            return covariance

        # (activation, inverse activation) pairs.
        self.scaling_activation, self.scaling_inverse_activation = torch.exp, torch.log
        self.covariance_activation = build_covariance_from_scaling_rotation
        self.opacity_activation, self.inverse_opacity_activation = torch.sigmoid, inverse_sigmoid
        self.label_activation, self.inverse_label_activation = torch.sigmoid, inverse_sigmoid
        self.rotation_activation = torch.nn.functional.normalize
        self.orient_conf_activation, self.orient_conf_inverse_activation = torch.exp, torch.log

    def __init__(self, data_dir, flame_mesh_dir, strands_config, texture_hidden_config, deformation_config, sh_degree, start_time_step, num_time_steps):
        """Build the strand-based Gaussian hair model and its auxiliary networks.

        Args:
            data_dir: forwarded to ``OptimizableTexturedStrands``.
            flame_mesh_dir: forwarded to ``OptimizableTexturedStrands`` and kept
                on ``self.flame_mesh_dir``.
            strands_config: mapping read here at
                ``['extra_args']['num_guiding_strands']``,
                ``['textured_strands']`` (unpacked as keyword arguments) and
                ``['diffusion_prior']``.
            texture_hidden_config: forwarded to ``OptimizableTexturedStrands``.
            deformation_config: config object for ``deform_network``. It is used
                once as received to build ``self._deformation_coarse`` and is
                then MUTATED IN PLACE (k-planes resolution / output dim, net
                width, depth, multires) before building ``self._deformation``.
            sh_degree: spherical-harmonics degree; sets both
                ``active_sh_degree`` and ``max_sh_degree``.
            start_time_step, num_time_steps: forwarded to
                ``OptimizableTexturedStrands``.

        Note:
            The strand generator, colour decoder and strand encoder are moved
            to CUDA here, and the encoder weights are loaded from a path that is
            relative to the current working directory.
        """
        num_guiding_strands = strands_config['extra_args']['num_guiding_strands']
        # None is normalized to 0 for the attribute, but the raw value (possibly
        # None) is what gets forwarded to the strand generator below.
        self.num_guiding_strands = num_guiding_strands if num_guiding_strands is not None else 0
        self.active_sh_degree = sh_degree
        self.max_sh_degree = sh_degree
        # Per-point parameter tensors start empty; restore() and
        # set_deformation() assign most of them later.
        self._xyz_static = torch.empty(0)
        self._pts_wo_root_static = torch.empty(0)
        self._xyz = torch.empty(0)
        self._features_dc = torch.empty(0)
        self._features_rest = torch.empty(0)
        self._scaling = torch.empty(0)
        self._rotation = torch.empty(0)
        self._opacity = torch.empty(0)
        self._orient_conf = torch.empty(0)
        self._label = torch.empty(0)
        self._opacity_uncertainty = torch.empty(0)
        self._opacity_view = torch.empty(0)
        # self._relative_length_loss = torch.empty(0)
        self.strands_generator = OptimizableTexturedStrands(
            **strands_config['textured_strands'],
            diffusion_cfg=strands_config['diffusion_prior'],
            texture_hidden_config = texture_hidden_config,
            data_dir=data_dir,
            flame_mesh_dir=flame_mesh_dir,
            num_guiding_strands=num_guiding_strands,
            start_time_step=start_time_step,
            num_time_steps=num_time_steps
        ).cuda()
        self.flame_mesh_dir = flame_mesh_dir
        # Output width: 3 * (max_sh_degree + 1)**2 SH colour coefficients plus
        # one extra channel (its meaning is defined by the caller — TODO confirm).
        self.color_decoder = Decoder(None, dim_hidden=128, num_layers=2, dim_out=3*(self.max_sh_degree+1)**2 + 1).cuda()
        # Strand encoder is put in eval mode; weights come from the 'encoder'
        # entry of the NeuralHaircut strand-prior checkpoint (path relative to CWD).
        self.strands_encoder = Encoder(None).eval().cuda()
        self.strands_encoder.load_state_dict(torch.load(f'../ext/NeuralHaircut/pretrained_models/strand_prior/strand_ckpt.pth')['encoder'])
        self.optimizer = None
        # Gates the diffusion-prior loss in compute_lsds().
        self.use_sds = True
        # Per-frame caches; set_deformation() reads time_list[0] and time_list[time_step].
        self.time_list = {}
        self.time_sparse_list = {}
        self.dirs_times = {}
        # Output scales handed to the deformation network (pts / colour scales
        # are passed to self._deformation in set_deformation()).
        self.deformaton_pts_scale = 0.00001 # self.dirs
        self.deformaton_color_scale = 0.00001 # self.dirs
        self.deformaton_hf_pts_scale = 0.00001 # self.dirs
        self.deformaton_hf_color_scale = 0.00001 # self.dirs
        self.deformaton_coarse_scale = 1e-2 # self.points
        self.max_radii2D = torch.empty(0)
        self.xyz_gradient_accum = torch.empty(0)
        self.xyz_gradient_accum_multi_view = torch.empty(0)
        self.denom_list = torch.empty(0)
        self.denom_multi_view = torch.empty(0)
        # Coarse deformation network, built from deformation_config as received.
        self._deformation_coarse = deform_network(deformation_config)
        # self.gaussianWeight = GaussianWeightPred(input_dim=4, hidden_dims=[128,128,128])
        # self.gaussianWeight = GaussianWeightPred(input_dim=4, hidden_dims=[32,64,32])
        self.gaussianWeight = gaussian_weight_pred_pe(input_dim=4,pe_feq=1)
        self.gaussRender = GaussRenderer(active_sh_degree=sh_degree, white_bkgd=False)
        # self._deformation_hf = deform_network(deformation_config)
        # High-frequency deformation is disabled (identity module).
        self._deformation_hf = nn.Identity()
        # NOTE(review): the following lines mutate the caller's deformation_config
        # in place; the caller will observe these values afterwards. Order matters:
        # _deformation_coarse above must be built before this mutation.
        deformation_config.kplanes_config.resolution = [1024, 1024, 1024,60]
        deformation_config.kplanes_config.output_coordinate_dim = 16
        # deformation_config.kplanes_config.grid_dimensions = 4
        deformation_config.net_width = 128
        deformation_config.defor_depth = 2
        deformation_config.multires = [1, 2]
        # deformation_config.multires = [1, 2, 4, 8]
        self._deformation = deform_network(deformation_config)
        self.percent_dense = 0
        self.spatial_lr_scale = 0
        self._deformation_table = torch.empty(0)
        self.deformation_state = 'lf'
        self.iteration = 0
        self.time_used = 0
        self.GPU_time_used = 0
        self.hair_smoothness = None
        self.points_mask_active_hair = None
        # When not None, the get_* accessors index_select rows with these indices.
        self.points_mask_hair_indices = None
        # Scalar diagnostics written by set_deformation().
        self.loss_pts = 0
        self.pts_wo_rot = 0
        self.loss_feat = 0
        self.shs_view_hair = 0
        self._d_xyz_norm = 0
        # Sixteen placeholder slots (filled outside this chunk — TODO confirm usage).
        self.wo_active_set_data = [None] * 16
        # self.grid_active_set = None
        self.idx_active_mask = None
        self.idx_hair_active_mask = None
        self.points_mask_wo_active_hair_indices = None
        self.points_mask_active_hair_indices = None
        # Counters not used in this chunk; presumably updated by training code
        # elsewhere — TODO confirm.
        self.uvs_mask_active_count = 0
        self.uvs_mask_active_count_50 = 0
        self.uvs_mask_active_count_54 = 0
        self.uvs_mask_active_count_43 = 0
        self.uvs_mask_active_count_32 = 0
        self.uvs_mask_active_count_21 = 0
        self.uvs_mask_active_count_10 = 0
        self.uvs_mask_active_count_01 = 0
        self.uvs_mask_active_count_12 = 0
        self.uvs_mask_active_count_23 = 0
        self.uvs_mask_active_count_34 = 0
        self.uvs_mask_active_count_45 = 0
        self.uvs_mask_active_count_56 = 0
        self.uvs_mask_active_count_67 = 0
        self.uvs_mask_active_count_78 = 0
        self.uvs_mask_active_count_89 = 0
        self.uvs_mask_active_count_90 = 0
        self.gaussian_hair_pos_grad_average = 0
        self.gaussian_hair_features_grad_average = 0
        self.gaussian_hair_orient_conf_grad_average = 0
        self.setup_functions()

    def capture(self):
        """Snapshot the model state as a flat 21-element tuple for checkpointing.

        The element order is the contract with ``restore`` and must not change:
        strand/SH tensors and optimizer state, then per-Gaussian tensors and
        densification statistics, then network state dicts and caches.
        """
        strand_state = (
            self._pts,
            self._features_dc,
            self._features_rest,
            self._orient_conf,
            self.active_sh_degree,
            self.optimizer.state_dict(),
        )
        gaussian_state = (
            self._xyz,
            self._scaling,
            self._rotation,
            self._opacity,
            self.max_radii2D,
            self.spatial_lr_scale,
            self.xyz_gradient_accum,
            self.denom,
        )
        network_state = (
            self._deformation.state_dict(),
            self._deformation_hf.state_dict(),
            self._deformation_table,
            self.time_list,
            self._xyz_static,
            self._pts_wo_root_static,
            self.gaussianWeight.state_dict(),
        )
        return strand_state + gaussian_state + network_state
    
    def restore(self, model_args, training_args):
        (
            self._pts,
            self._features_dc,
            self._features_rest,
            self._orient_conf,
            self.active_sh_degree, 
            opt_dict,
            self._xyz,
            self._scaling, 
            self._rotation, 
            self._opacity,
            self.max_radii2D,
            self.spatial_lr_scale,
            self.xyz_gradient_accum,
            self.denom,
            deform_dict,
            deform_hf_dict,
            self._deformation_table,
            self.time_list,
            self._xyz_static,
            self._pts_wo_root_static,
            gaussian_weight,
        ) = model_args
        self.pts_origins = self._pts[:, :1]
        self._dirs = self._pts[:, 1:] - self._pts[:, :-1]
        # self._orient_conf = torch.ones_like(self._features_dc[:, :1, 0])
        try:
            self.training_setup(training_args)
            self.optimizer.load_state_dict(opt_dict)
        except:
            print('Failed to load optimizer')
        self._deformation.load_state_dict(deform_dict)
        self._deformation_hf.load_state_dict(deform_hf_dict)
        self.gaussianWeight.load_state_dict(gaussian_weight)
        # import ipdb;ipdb.set_trace()

    @property
    def get_scaling(self):
        if self.points_mask_hair_indices != None:
            return self._scaling.index_select(0, self.points_mask_hair_indices)
        else:
            return self._scaling
        # if self.points_mask_hair_indices != None:
        #     return self.scaling_activation(self._scaling.index_select(0, self.points_mask_hair_indices))
        # else:
        #     return self.scaling_activation(self._scaling)
    
    @property
    def get_rotation(self):
        if self.points_mask_hair_indices != None:
            return self.rotation_activation(self._rotation.index_select(0, self.points_mask_hair_indices))
        else:
            return self.rotation_activation(self._rotation)

    @property
    def get_xyz(self):
        if self.points_mask_hair_indices != None:
            return self._xyz.index_select(0, self.points_mask_hair_indices)
        else:
            return self._xyz
    
    @property
    def get_features(self):
        if self.points_mask_hair_indices != None:
            features_dc = self._features_dc.index_select(0, self.points_mask_hair_indices)
            features_rest = self._features_rest.index_select(0, self.points_mask_hair_indices)
        else:
            features_dc = self._features_dc
            features_rest = self._features_rest
        return torch.cat((features_dc, features_rest), dim=1)

    @property
    def get_opacity(self):
        return torch.ones_like(self.get_xyz[:, :1])
        # return self.opacity_activation(self._opacity)
        # return self._opacity

    @property
    def get_label(self):
        return torch.ones_like(self.get_xyz[:, :1])
    
    @property
    def get_orient_conf(self):
        if self.points_mask_hair_indices != None:
            return self.orient_conf_activation(self._orient_conf.index_select(0, self.points_mask_hair_indices))
        else:
            return self.orient_conf_activation(self._orient_conf)
    
    def set_deformation_scale(self, lr_init, lr_final, step, max_steps):
        """Set ``self.deformaton_scale`` on a log-linear schedule.

        ``step / max_steps`` is clamped to [0, 1]. At 0 the result is
        ``lr_init``, at 1 it is ``lr_final``, and in between the value is
        interpolated linearly in log space (geometric decay). The attribute
        spelling matches the other ``deformaton_*`` attributes of this class.
        """
        progress = np.clip(step / max_steps, 0, 1)
        log_start = np.log(lr_init)
        log_end = np.log(lr_final)
        self.deformaton_scale = np.exp(log_start * (1 - progress) + log_end * progress)
        
    def set_deformation(self,time_step, num_time_steps,sparse_index = None, wo_sparse_index=None):
        """Pose the hair strands for ``time_step`` and refresh the per-segment Gaussians.

        Reads the cached tracking data in ``self.time_list`` for frame 0 and for
        ``time_step``. In the code path that is currently active, it offsets the
        static strands and SH features by the tracked change between those two
        frames.

        Side effects: overwrites ``self._pts``, ``self._xyz``, ``self._dir``,
        ``self._rotation``, ``self._scaling``, ``self.hair_smoothness``,
        ``self.loss_pts`` and ``self.loss_feat``.

        Args:
            time_step: frame index; also used as a key into ``self.time_list``.
            num_time_steps: number of frames, used to normalize ``time_step``
                into ``_time`` for the deformation network.
            sparse_index, wo_sparse_index: index tensors splitting the points
                into a subset deformed with gradients and a subset deformed
                under ``no_grad``. Only read by the deformation-network branch,
                which is currently unreachable (see the ``or True`` note below).

        Returns:
            ``(features, orient_conf)``. ``features`` is the SH feature tensor
            transposed/viewed to ``(-1, 3, (max_sh_degree + 1) ** 2)``.

        NOTE(review): ``torch.ones_like(self.get_xyz)`` and ``self.get_features``
        honour ``self.points_mask_hair_indices``, while ``self._dir`` and the
        ``time_list`` features do not. This method therefore appears to require
        ``points_mask_hair_indices`` to be None — confirm with callers.
        """
        opacity_hair = self.get_opacity
        shs_view_hair_static = self.get_features
        orient_conf_stiatic = self.get_orient_conf
        # Normalized time fed to the deformation network.
        _time = time_step / num_time_steps

        # Tracked data for the reference frame (0) and the current frame. Per
        # the indexing below, entry [0] holds strand points (strands x points x 3)
        # and entries [3] / [4] hold the SH DC / higher-order features.
        with torch.no_grad():
            pts_init = self.time_list[0][0].cuda()
            pts = self.time_list[time_step][0].cuda()
            pts_origins = pts[:, :1]
            features_dc_init = self.time_list[0][3].cuda()
            features_dc = self.time_list[time_step][3].cuda()
            features_rest_init = self.time_list[0][4].cuda()
            features_rest = self.time_list[time_step][4].cuda()
            N = pts[:, 1:].shape[0]
            # Number of non-root points per strand.
            L = pts[:, 1:].shape[1]
        # NOTE(review): ``or True`` makes this condition always true, so the
        # deformation-network branch in ``else`` is unreachable and the static
        # strands / features / orientation confidences are used unchanged.
        if time_step == -1 or True:
            _pts_wo_root_dynamic = self._pts_wo_root_static
            features_dynamic = shs_view_hair_static
            orient_conf_final = orient_conf_stiatic
        else:
            if sparse_index != None:
                # Deform the sparse subset with gradients enabled...
                _time_sparse = torch.tensor(_time).to(self._pts_wo_root_static.device).repeat(self._pts_wo_root_static.index_select(0,sparse_index).shape[0],1)
                _pts_wo_root_dynamic_sparse, _, _, _, features_dynamic_sparse, orient_conf_final_sparse = \
                    self._deformation(self._pts_wo_root_static.index_select(0,sparse_index),
                                    None,
                                    None,
                                    opacity_hair.index_select(0,sparse_index),
                                    shs_view_hair_static.index_select(0,sparse_index),
                                    orient_conf_stiatic.index_select(0,sparse_index),
                                    _time_sparse,
                                    self.deformaton_pts_scale,
                                    self.deformaton_color_scale)
                _pts_wo_root_dynamic = torch.empty_like(self._pts_wo_root_static)
                features_dynamic = torch.empty_like(shs_view_hair_static)
                orient_conf_final = torch.empty_like(orient_conf_stiatic)
                # ...and the complementary subset without gradients.
                with torch.no_grad():
                    _time_wo_sparse = torch.tensor(_time).to(self._pts_wo_root_static.device).repeat(self._pts_wo_root_static.index_select(0,wo_sparse_index).shape[0],1)
                    _pts_wo_root_dynamic_wo_sparse, _, _, _, features_dynamic_wo_sparse, orient_conf_final_wo_sparse = \
                        self._deformation(self._pts_wo_root_static.index_select(0,wo_sparse_index),
                                        None,
                                        None,
                                        opacity_hair.index_select(0,wo_sparse_index),
                                        shs_view_hair_static.index_select(0,wo_sparse_index),
                                        orient_conf_stiatic.index_select(0,wo_sparse_index),
                                        _time_wo_sparse,
                                        self.deformaton_pts_scale,
                                        self.deformaton_color_scale)
                # NOTE(review): these three buffers were already allocated just
                # above; the first allocation is discarded.
                _pts_wo_root_dynamic = torch.empty_like(self._pts_wo_root_static)
                features_dynamic = torch.empty_like(shs_view_hair_static)
                orient_conf_final = torch.empty_like(orient_conf_stiatic)
                # Scatter both subsets back into full-size tensors.
                _pts_wo_root_dynamic.index_copy_(0,sparse_index,_pts_wo_root_dynamic_sparse)
                features_dynamic.index_copy_(0,sparse_index,features_dynamic_sparse)
                orient_conf_final.index_copy_(0,sparse_index,orient_conf_final_sparse)
                _pts_wo_root_dynamic.index_copy_(0,wo_sparse_index,_pts_wo_root_dynamic_wo_sparse)
                features_dynamic.index_copy_(0,wo_sparse_index,features_dynamic_wo_sparse)
                orient_conf_final.index_copy_(0,wo_sparse_index,orient_conf_final_wo_sparse)
            else:
                _time = torch.tensor(_time).to(self._pts_wo_root_static.device).repeat(self._pts_wo_root_static.shape[0],1)
                _pts_wo_root_dynamic, _, _, _, features_dynamic, orient_conf_final = \
                    self._deformation(self._pts_wo_root_static,
                                    None,
                                    None,
                                    opacity_hair,
                                    shs_view_hair_static,
                                    orient_conf_stiatic,
                                    _time,
                                    self.deformaton_pts_scale,
                                    self.deformaton_color_scale)
        # SH features = (possibly deformed) static features plus the tracked
        # feature change from frame 0 to frame ``time_step``.
        features_final = features_dynamic.clone() + torch.cat([features_dc.reshape(-1, 1, 3), features_rest.reshape(-1, (self.max_sh_degree + 1) ** 2 - 1, 3)],dim=1) - torch.cat([features_dc_init.reshape(-1, 1, 3), features_rest_init.reshape(-1, (self.max_sh_degree + 1) ** 2 - 1, 3)],dim=1)
        # Diagnostics only, stored as Python floats (.item() copies to host).
        loss_pts = (_pts_wo_root_dynamic - self._pts_wo_root_static).norm(dim=-1).mean()
        loss_feat = (features_final - shs_view_hair_static).norm(dim=-1).mean()
        self.loss_pts = loss_pts.item()
        self.loss_feat = loss_feat.item()
        # Root from the current frame. The remaining points are the static
        # non-root points displaced by the tracked (frame t - frame 0) motion.
        self._pts = torch.cat([pts_origins, _pts_wo_root_dynamic.reshape(-1,L,3) + pts[:,1:] - pts_init[:,1:]], dim=1)
        # One Gaussian per segment: centre at the segment midpoint, _dir = segment vector.
        self._xyz = (self._pts[:, 1:] + self._pts[:, :-1]).view(-1, 3) * 0.5
        self._dir = (self._pts[:, 1:] - self._pts[:, :-1]).view(-1, 3)
        first_difference = self._pts[:, 1:] - self._pts[:, :-1]
        # Mean second-difference norm relative to mean segment length (lower = smoother).
        self.hair_smoothness = (first_difference[:,1:,:]-first_difference[:,:-1,:]).norm(dim=-1).mean()/first_difference.norm(dim=-1).mean()
        # Per-segment rotation from a = +x axis to b = segment direction,
        # viewed as (N, 4) quaternions (semantics of parallel_transport assumed
        # from its a/b arguments — TODO confirm).
        self._rotation = parallel_transport(
            a=torch.cat(
                [
                    torch.ones_like(self._xyz[:, :1]),
                    torch.zeros_like(self._xyz[:, :2])
                ],
                dim=-1
            ),
            b=self._dir
        ).view(-1, 4)

        # Gaussian extent: half the segment length along local x; ``self.scale``
        # (set outside this chunk) for the two cross-section axes.
        self._scaling = torch.ones_like(self.get_xyz)
        self._scaling[:, 0] = self._dir.norm(dim=-1) * 0.5
        self._scaling[:, 1:] = self.scale

        return features_final.transpose(1, 2).view(-1, 3, (self.max_sh_degree+1)**2), orient_conf_final

    def compute_lsds(self,time_step):
        """Compute the strand-prior diffusion (SDS-style) loss for ``time_step`` into ``self.Lsds``.

        Steps:
            1. Sample 1000 random guiding strands.
            2. Encode them into 64-d latents with ``self.strands_encoder``.
            3. Resample those latents onto a regular
               ``diffusion_input x diffusion_input`` UV grid. Each texel blends
               its nearest guiding strand's latent with an inverse-distance-
               weighted mix of the K nearest, depending on how coherent those
               neighbours' directions are.
            4. Score the resulting latent texture with the EMA diffusion
               model's ``loss_wo_logvar``.

        Does nothing when ``self.use_sds`` is False.

        NOTE(review): everything runs under ``torch.no_grad()``, so
        ``self.Lsds`` is detached and cannot back-propagate into the strands on
        its own. The commented-out ``VJPReg.apply`` line at the end looks like
        the gradient-carrying alternative — confirm how ``Lsds`` is consumed.
        """
        if self.use_sds:
            # Encode the guiding strands into the latent vectors
            with torch.no_grad():
                NUM_GUIDING_STRANDS = 1000
                # Random strand indices in [0, num_strands), sampled with replacement.
                idx = torch.randint(low=0, high=self.num_strands, size=(NUM_GUIDING_STRANDS,), device="cuda")
                uvs_gdn = self.uvs[idx]
                # Cell centres of a uniform grid over [-1, 1] with diffusion_input cells per side.
                grid = torch.linspace(start=-1, end=1, steps=self.strands_generator.diffusion_input + 1, device="cuda")
                grid = (grid[1:] + grid[:-1]) / 2
                # (diffusion_input**2, 2) UV positions of the latent-texture texels.
                uvs_sds = torch.stack(torch.meshgrid(grid, grid, indexing='xy'), dim=-1).view(-1, 2)
                # Number of nearest guiding strands used per texel.
                K = 4
                dist = ((uvs_sds.view(-1, 1, 2) - uvs_gdn.view(1, -1, 2))**2).sum(-1) # num_sds_strands x num_guiding_strands
                # Full ascending sort per texel; only the first K columns are used below.
                knn_dist, knn_idx = torch.sort(dist, dim=1)
                # Normalized inverse-squared-distance weights over the K nearest strands.
                w = 1 / (knn_dist[:, :K] + 1e-7)
                w = w / w.sum(dim=-1, keepdim=True)
                # Guiding strands relative to their roots, rotated by inverse(local2world).
                pts_gdn_local = (torch.inverse(self.local2world[idx][:, None]) @ (self._pts[idx] - self.pts_origins[idx])[..., None])[..., 0]
                # Per-segment vectors, scaled by this frame's decoder scale.
                v_gdn_local = (pts_gdn_local[:, 1:] - pts_gdn_local[:, :-1]) * self.strands_generator.scale_decoder[time_step]
                # with torch.no_grad():
                #     z_gdn = self.strands_encoder(pts_gdn_local  * self.strands_generator.scale_decoder[time_step])[:, :64]
                # Encoder latents, truncated to the first 64 channels.
                z_gdn = self.strands_encoder(pts_gdn_local  * self.strands_generator.scale_decoder[time_step])[:, :64]

                # Two candidate latents per texel: the nearest strand's latent and an
                # inverse-distance-weighted blend of the K nearest (named "bilinear").
                z_sds_nearest = z_gdn[knn_idx[:, 0]]
                z_sds_bilinear = (z_gdn[knn_idx[:, :K]] * w[:, :, None]).sum(dim=1)
                knn_v = v_gdn_local[knn_idx[:, :K]]
                hair_length = v_gdn_local.shape[1]
                # Mean (over segments) cosine similarity between every pair of the K neighbours.
                csim_full = torch.nn.functional.cosine_similarity(knn_v.view(-1, K, 1, hair_length, 3), knn_v.view(-1, 1, K, hair_length, 3), dim=-1).mean(-1) # num_sds_strands x K x K
                # Upper-triangular pairs INCLUDING the diagonal (offset 0), so each
                # neighbour's self-similarity of 1 is part of the mean below.
                j, k = torch.triu_indices(K, K, device=csim_full.device).split([1, 1], dim=0)
                # NOTE(review): rows of csim_full are texels (diffusion_input**2 of them),
                # but only the first NUM_GUIDING_STRANDS rows are read here. The resulting
                # ``alpha`` is then indexed by ``knn_idx`` (guiding-strand indices) below,
                # which mixes texel and guiding-strand indexing. It also requires
                # diffusion_input**2 >= NUM_GUIDING_STRANDS. Confirm intent.
                i = torch.arange(NUM_GUIDING_STRANDS, device=csim_full.device).repeat_interleave(j.shape[1])
                j = j[0].repeat(NUM_GUIDING_STRANDS)
                k = k[0].repeat(NUM_GUIDING_STRANDS)
                csim = csim_full[i, j, k].view(NUM_GUIDING_STRANDS, -1).mean(-1)

                # Weight on the nearest-neighbour latent. It decreases as neighbour
                # coherence csim increases and reaches 0 at csim = 1.
                alpha = torch.where(csim <= 0.9, 1 - 1.63 * csim**5, 0.4 - 0.4 * csim)
                alpha_sds = (alpha[knn_idx[:, :K]] * w).sum(dim=1)[:, None]
                z_sds = z_sds_nearest * alpha_sds + z_sds_bilinear * (1 - alpha_sds)

                # Latent texture laid out as (1, 64, D, D) for the diffusion model.
                diffusion_texture = z_sds.view(1, self.strands_generator.diffusion_input, self.strands_generator.diffusion_input, 64).permute(0, 3, 1, 2)

                noise = torch.randn_like(diffusion_texture)
                sigma = self.strands_generator.sample_density([diffusion_texture.shape[0]], device='cuda')
                # Optional per-frame mask, resized to the latent resolution.
                mask = None
                if self.strands_generator.diffuse_mask[time_step] is not None:
                    mask = torch.nn.functional.interpolate(
                        self.strands_generator.diffuse_mask[time_step][None][None],
                        size=(self.strands_generator.diffusion_input, self.strands_generator.diffusion_input)
                    )
                # pred_image / noised_image are not used here.
                L_diff, pred_image, noised_image = self.strands_generator.model_ema.loss_wo_logvar(diffusion_texture, noise, sigma, mask=mask, unet_cond=None)

                self.Lsds = L_diff.mean()
            # self.Lsds = VJPReg.apply(self._pts,self,time_step)
    @torch.no_grad()
    def filter_points(self, viewpoint_camera):
        # __forceinline__ __device__ bool in_frustum(int idx,
        #     const float* orig_points,
        #     const float* viewmatrix,
        #     const float* projmatrix,
        #     bool prefiltered,
        #     float3& p_view)
        # {
        #     float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] };

        #     // Bring points to screen space
        #     float4 p_hom = transformPoint4x4(p_orig, projmatrix);
        #     float p_w = 1.0f / (p_hom.w + 0.0000001f);
        #     float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };
        #     p_view = transformPoint4x3(p_orig, viewmatrix);

        #     if (p_view.z <= 0.2f)
        #     {
        #         return false;
        #     }
        #     return true;
        # }        
        mean = self.get_xyz
        viewmatrix = viewpoint_camera.world_view_transform
        p_view = (mean[:, None, :] @ viewmatrix[None, :3, :3] + viewmatrix[None, [3], :3])[:, 0]
        
        mask = p_view[:, [2]] > 0.2

        mask = torch.logical_and(mask, self.det != 0)

        # float mid = 0.5f * (cov.x + cov.z);
        # float lambda1 = mid + sqrt(max(0.1f, mid * mid - det));
        # float lambda2 = mid - sqrt(max(0.1f, mid * mid - det));
        # float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2)));

        mid = 0.5 * (self.cov[:, [0]] + self.cov[:, [2]])
        sqrtD = (torch.clamp(mid**2 - self.det, min=0.1))**0.5
        lambda1 = mid + sqrtD
        lambda2 = mid - sqrtD
        my_radius = torch.ceil(3 * (torch.maximum(lambda1, lambda2))**0.5)

        # float2 point_image = { ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H) };

        # __forceinline__ __device__ float ndc2Pix(float v, int S)
        # {
        #     return ((v + 1.0) * S - 1.0) * 0.5;
        # }
        
        point_image_x = ((self.p_proj[:, [0]] + 1) * viewpoint_camera.image_width - 1.0) * 0.5
        point_image_y = ((self.p_proj[:, [1]] + 1) * viewpoint_camera.image_height - 1.0) * 0.5

        # dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);

        BLOCK_X = 16
        BLOCK_Y = 16

        grid_x = (viewpoint_camera.image_width + BLOCK_X - 1) // BLOCK_X
        grid_y = (viewpoint_camera.image_height + BLOCK_Y - 1) // BLOCK_Y

        # uint2 rect_min, rect_max;
        # getRect(point_image, my_radius, rect_min, rect_max, grid);

        # __forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid)
        # {
        #     rect_min = {
        #         min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))),
        #         min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y)))
        #     };
        #     rect_max = {
        #         min(grid.x, max((int)0, (int)((p.x + max_radius + BLOCK_X - 1) / BLOCK_X))),
        #         min(grid.y, max((int)0, (int)((p.y + max_radius + BLOCK_Y - 1) / BLOCK_Y)))
        #     };
        # }
        
        rect_min_x = torch.clamp(((point_image_x - my_radius) / BLOCK_X).int(), min=0, max=grid_x)
        rect_min_y = torch.clamp(((point_image_y - my_radius) / BLOCK_Y).int(), min=0, max=grid_y)

        rect_max_x = torch.clamp(((point_image_x + my_radius + BLOCK_X - 1) / BLOCK_X).int(), min=0, max=grid_x)
        rect_max_y = torch.clamp(((point_image_y + my_radius + BLOCK_Y - 1) / BLOCK_Y).int(), min=0, max=grid_y)

        # if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0)
        #     return;
        
        self.points_mask = torch.logical_and(mask, (rect_max_x - rect_min_x) * (rect_max_y - rect_min_y) != 0).squeeze()

        return self.points_mask

    def get_covariance(self, scaling_modifier = 1, return_full_covariance = False):
        if self.points_mask_hair_indices is None:
            return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation, return_full_covariance)
        else:
            return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation.index_select(0,self.points_mask_hair_indices), return_full_covariance)

    def get_covariance_2d(self, viewpoint_camera, scaling_modifier = 1):
        mean = self.get_xyz

        height = int(viewpoint_camera.image_height)
        width = int(viewpoint_camera.image_width)
    
        tan_fovx = math.tan(viewpoint_camera.FoVx * 0.5)
        tan_fovy = math.tan(viewpoint_camera.FoVy * 0.5)

        focal_y = height / (2.0 * tan_fovy)
        focal_x = width / (2.0 * tan_fovx)

        viewmatrix = viewpoint_camera.world_view_transform

        t = (mean[:, None, :] @ viewmatrix[None, :3, :3] + viewmatrix[None, [3], :3])[:, 0]
        tx, ty, tz = t[:, 0], t[:, 1], t[:, 2]

        limx = 1.3 * tan_fovx
        limy = 1.3 * tan_fovy
        txtz = tx / tz
        tytz = ty / tz
        
        tx = torch.clamp(txtz, min=-limx, max=limx) * tz
        ty = torch.clamp(tytz, min=-limy, max=limy) * tz

        zeros = torch.zeros_like(tz)

        J = torch.stack(
            [
                torch.stack([focal_x / tz,        zeros, -(focal_x * tx) / (tz * tz)], dim=-1), # 1st column
                torch.stack([       zeros, focal_y / tz, -(focal_y * ty) / (tz * tz)], dim=-1), # 2nd column
                torch.stack([       zeros,        zeros,                       zeros], dim=-1)  # 3rd column
            ],
            dim=-1 # stack columns into rows
        )

        W = viewmatrix[None, :3, :3]

        T = W @ J

        Vrk = self.get_covariance(scaling_modifier, return_full_covariance=True)

        cov = T.transpose(1, 2) @ Vrk.transpose(1, 2) @ T

        # J = torch.stack(
        #     [
        #         torch.stack([focal_x / tz,        zeros, -(focal_x * tx) / (tz * tz)], dim=-1), # 1st row
        #         torch.stack([       zeros, focal_y / tz, -(focal_y * ty) / (tz * tz)], dim=-1), # 2nd row
        #         torch.stack([       zeros,        zeros,                       zeros], dim=-1)  # 3rd row
        #     ],
        #     dim=-1 # stack rows into columns
        # )

        # W = viewmatrix[None, :3, :3]

        # T = J @ W

        # Vrk = self.get_covariance(viewpoint_camera, scaling_modifier, return_full_covariance=True)

        # cov = T @ Vrk @ T.transpose(1, 2)

        cov[:, 0, 0] += 0.3
        cov[:, 1, 1] += 0.3

        return torch.stack([cov[:, 0, 0], cov[:, 0, 1], cov[:, 1, 1]], dim=-1)
    def get_covariance_2d_torch(self, viewpoint_camera, scaling_modifier = 1):
        mean = self.get_xyz

        height = int(viewpoint_camera.image_height)
        width = int(viewpoint_camera.image_width)
    
        tan_fovx = math.tan(viewpoint_camera.FoVx * 0.5)
        tan_fovy = math.tan(viewpoint_camera.FoVy * 0.5)

        focal_y = height / (2.0 * tan_fovy)
        focal_x = width / (2.0 * tan_fovx)

        viewmatrix = viewpoint_camera.world_view_transform

        t = (mean[:, None, :] @ viewmatrix[None, :3, :3] + viewmatrix[None, [3], :3])[:, 0]
        tx, ty, tz = t[:, 0], t[:, 1], t[:, 2]

        limx = 1.3 * tan_fovx
        limy = 1.3 * tan_fovy
        txtz = tx / tz
        tytz = ty / tz
        
        tx = torch.clamp(txtz, min=-limx, max=limx) * tz
        ty = torch.clamp(tytz, min=-limy, max=limy) * tz

        zeros = torch.zeros_like(tz)

        J = torch.stack(
            [
                torch.stack([focal_x / tz,        zeros, -(focal_x * tx) / (tz * tz)], dim=-1), # 1st column
                torch.stack([       zeros, focal_y / tz, -(focal_y * ty) / (tz * tz)], dim=-1), # 2nd column
                torch.stack([       zeros,        zeros,                       zeros], dim=-1)  # 3rd column
            ],
            dim=-1 # stack columns into rows
        )

        W = viewmatrix[None, :3, :3]

        T = W @ J

        Vrk = self.get_covariance(scaling_modifier, return_full_covariance=True)

        cov = T.transpose(1, 2) @ Vrk.transpose(1, 2) @ T

        # J = torch.stack(
        #     [
        #         torch.stack([focal_x / tz,        zeros, -(focal_x * tx) / (tz * tz)], dim=-1), # 1st row
        #         torch.stack([       zeros, focal_y / tz, -(focal_y * ty) / (tz * tz)], dim=-1), # 2nd row
        #         torch.stack([       zeros,        zeros,                       zeros], dim=-1)  # 3rd row
        #     ],
        #     dim=-1 # stack rows into columns
        # )

        # W = viewmatrix[None, :3, :3]

        # T = J @ W

        # Vrk = self.get_covariance(viewpoint_camera, scaling_modifier, return_full_covariance=True)

        # cov = T @ Vrk @ T.transpose(1, 2)

        cov[:, 0, 0] += 0.3
        cov[:, 1, 1] += 0.3

        return cov

    def get_conic(self, viewpoint_camera, scaling_modifier = 1):
        """Return the conic (inverse 2D covariance) used for EWA splatting.

        The covariance from ``get_covariance_2d`` is cached on ``self.cov``
        and read as packed ``(a, b, c)`` components along dim 1, i.e. the
        symmetric matrix ``[[a, b], [b, c]]``. The determinant ``a*c - b^2``
        (with a trailing singleton dim) is cached on ``self.det``, and the
        conic ``(c, -b, a) / (det + 1e-7)`` is returned. A small epsilon is
        added to the determinant rather than skipping degenerate Gaussians.

        NOTE(review): the covariance routine just above this method appears
        to return a full ``(N, 3, 3)`` matrix instead of packed ``(N, 3)``
        components. With that shape, the indexing below selects matrix rows
        rather than scalar entries. Confirm the expected shape.
        """
        cov = self.get_covariance_2d(viewpoint_camera, scaling_modifier)
        self.cov = cov

        a, b, c = cov[:, 0], cov[:, 1], cov[:, 2]
        # Keep a singleton dim so det broadcasts against the stacked conic.
        self.det = (a * c - b ** 2)[:, None]
        det_inv = 1. / (self.det + 0.0000001)

        return torch.stack([c, -b, a], dim=-1) * det_inv

    def get_mean_2d(self, viewpoint_camera):
        """Project Gaussian centres through the full projection matrix.

        This is the PyTorch counterpart of the CUDA rasterizer's
        ``transformPoint4x4`` followed by the perspective divide. The CUDA
        code reads ``matrix[0], matrix[4], matrix[8], matrix[12]`` for x,
        which for this row-major tensor is ``P[0, 0], P[1, 0], P[2, 0], P[3, 0]``,
        i.e. ``x_h @ P``. The same row-vector convention (translation in row 3)
        is used for ``world_view_transform`` throughout this class.

        Args:
            viewpoint_camera: camera exposing ``full_proj_transform`` (4x4).

        Returns:
            Tensor of shape (N, 3) with ``p_hom[:, :3] / w``. The result is
            also cached on ``self.p_proj``.
        """
        xyz = self.get_xyz
        projmatrix = viewpoint_camera.full_proj_transform
        ones = torch.ones((xyz.shape[0], 1), device=xyz.device, dtype=xyz.dtype)
        xyz_h = torch.cat([xyz, ones], dim=1)
        # Row-vector convention: right-multiply by P itself. The previous
        # ``projmatrix.T`` applied the transposed projection.
        p_hom = xyz_h @ projmatrix
        # The CUDA reference divides by (w + 1e-7); clamping instead keeps the
        # divisor positive, so points with w <= 0 map to very large
        # coordinates rather than mirrored ones.
        w = p_hom[:, 3:].clamp(min=1e-7)
        self.p_proj = p_hom[:, :3] / w
        return self.p_proj

    def get_depths(self, viewpoint_camera):
        """Return the camera-space z coordinate of every Gaussian centre.

        ``world_view_transform`` is applied in the row-vector convention:
        rotation in the top-left 3x3 block and translation in row 3.

        Returns:
            Tensor of shape (N, 1).
        """
        view = viewpoint_camera.world_view_transform
        rotation = view[None, :3, :3]
        translation = view[None, [3], :3]
        camera_space = (self.get_xyz[:, None, :] @ rotation + translation)[:, 0]
        return camera_space[:, 2:3]

    def get_direction_2d(self, viewpoint_camera):
        """Project the normalized per-Gaussian directions into screen space.

        Uses the same local affine (Jacobian) approximation of the perspective
        projection as the 2D-covariance path. Each direction is mapped as
        ``dir2D = dir3D @ W @ J^T``, where ``W`` is the rotational block of
        ``world_view_transform`` (row-vector convention) and ``J`` is the
        per-point Jacobian evaluated at the frustum-clamped camera-space mean.

        Args:
            viewpoint_camera: camera providing ``image_height``,
                ``image_width``, ``FoVx``, ``FoVy`` and
                ``world_view_transform``.

        Returns:
            Tensor of shape (N, 3). The third component is always zero because
            the Jacobian's depth entries are zero.

        NOTE(review): ``mean`` covers every row of ``self.get_xyz``, while
        ``dir3D`` may be a subset selected by ``points_mask_hair_indices``.
        The batched matmul only works if the two lengths broadcast;
        presumably ``get_xyz`` applies the same selection. Confirm.
        """
        mean = self.get_xyz

        height = int(viewpoint_camera.image_height)
        width = int(viewpoint_camera.image_width)

        tan_fovx = math.tan(viewpoint_camera.FoVx * 0.5)
        tan_fovy = math.tan(viewpoint_camera.FoVy * 0.5)

        focal_y = height / (2.0 * tan_fovy)
        focal_x = width / (2.0 * tan_fovx)

        viewmatrix = viewpoint_camera.world_view_transform

        # Camera-space means (row-vector convention: translation in row 3).
        t = (mean[:, None, :] @ viewmatrix[None, :3, :3] + viewmatrix[None, [3], :3])[:, 0]
        tx, ty, tz = t[:, 0], t[:, 1], t[:, 2]

        # Evaluate the Jacobian at a point clamped to 1.3x the half-FoV
        # tangent, so it stays bounded near and outside the image borders.
        limx = 1.3 * tan_fovx
        limy = 1.3 * tan_fovy
        txtz = tx / tz
        tytz = ty / tz

        tx = torch.clamp(txtz, min=-limx, max=limx) * tz
        ty = torch.clamp(tytz, min=-limy, max=limy) * tz

        zeros = torch.zeros_like(tz)

        # Each inner stack becomes a column (dim=-1), so J holds J_true^T.
        J = torch.stack(
            [
                torch.stack([focal_x / tz,        zeros, -(focal_x * tx) / (tz * tz)], dim=-1), # 1st column
                torch.stack([       zeros, focal_y / tz, -(focal_y * ty) / (tz * tz)], dim=-1), # 2nd column
                torch.stack([       zeros,        zeros,                       zeros], dim=-1)  # 3rd column
            ],
            dim=-1 # stack columns into rows
        )

        W = viewmatrix[None, :3, :3]

        T = W @ J
        if self.points_mask_hair_indices is not None:
            dir3D = F.normalize(self._dir.index_select(0, self.points_mask_hair_indices), dim=-1)
        else:
            dir3D = F.normalize(self._dir, dim=-1)
        dir2D = (dir3D[:, None, :] @ T)[:, 0]

        return dir2D
    def get_direction_2d_point(self, viewpoint_camera):
        """Project per-Gaussian directions into screen space.

        The computation is identical to ``get_direction_2d``, so this method
        delegates to it instead of duplicating the Jacobian construction.
        Kept as a separate name for existing callers.

        Returns:
            The (N, 3) tensor produced by ``get_direction_2d``.
        """
        return self.get_direction_2d(viewpoint_camera)

    def initialize_gaussians_hair(self, num_strands, time_step):
        """Load the strand geometry for ``time_step``, generating and caching it if needed.

        On a cache miss, strands are produced by
        ``self.strands_generator.forward_inference(self.num_strands, time_step)``
        and a CPU copy of ``[pts, uvs, local2world, features_dc,
        features_rest, orient_conf]`` is stored in ``self.time_list``. On a hit,
        the cached tensors are moved back to CUDA. In both cases this sets
        ``self.pts``, ``self.pts_origins`` (first point of every strand),
        ``self.uvs``, ``self.local2world`` and ``self._orient_conf``
        (flattened to a column). Everything runs under ``torch.no_grad()``.

        NOTE(review): the ``num_strands`` argument is not used; generation
        uses ``self.num_strands``. Confirm this is intended.
        """
        with torch.no_grad():
            if time_step not in self.time_list:
                (strands, uvs, local2world, _, features_dc, features_rest,
                 orient_conf, _) = self.strands_generator.forward_inference(self.num_strands, time_step)

                points_per_strand = strands.shape[1]
                strands = strands.view(-1, 3).view(-1, points_per_strand, 3)

                self.pts_origins = strands[:, :1]
                self.pts = strands
                self.uvs = uvs
                self.local2world = local2world
                self.time_list[time_step] = [
                    tensor.detach().cpu()
                    for tensor in (strands, uvs, local2world, features_dc, features_rest, orient_conf)
                ]
                self._orient_conf = orient_conf[:, None, :].reshape(-1, 1).contiguous().clone()
            else:
                cached = self.time_list[time_step]
                self.pts = cached[0].cuda()
                self.pts_origins = self.pts[:, :1].cuda()
                self.uvs = cached[1].cuda()
                self.local2world = cached[2].cuda()
                self._orient_conf = cached[5].cuda()[:, None, :].reshape(-1, 1).contiguous().clone()

    def oneupSHdegree(self):
        """Raise the active SH degree by one, saturating at ``max_sh_degree``."""
        if self.active_sh_degree >= self.max_sh_degree:
            return
        self.active_sh_degree += 1

    def create_from_pcd(self, data_path, model_params, num_strands, spatial_lr_scale):
        """Initialise the strand-based Gaussians from a saved checkpoint tuple.

        Despite the name, no point cloud is read here. The method:

        1. Unpacks ``model_params``.
        2. Loads the state dicts of ``gaussianWeight``, ``strands_generator``
           (non-strict), ``color_decoder`` and ``_deformation_coarse``.
        3. Samples strands at time 0 with
           ``strands_generator.forward_inference(num_strands, 0)``.
        4. Places one Gaussian at the midpoint of every strand segment.
        5. Registers the trainable parameters ``_xyz_static`` (aliased as
           ``_xyz``), ``_pts_wo_root_static``, ``_features_dc``,
           ``_features_rest`` and ``_orient_conf``.

        Args:
            data_path: not used by the active code (only by a commented-out line).
            model_params: 12-tuple unpacked below. Entries, in order: scaling,
                active SH degree, generator / color-decoder state dicts, three
                ``opt_*`` dicts (presumably optimizer states; unused here), two
                ``shd_*`` dicts (unused here), the deformation state dict
                (loaded into ``_deformation_coarse``), a deformation table
                (overwritten later in this method), and the gaussianWeight
                state dict.
            num_strands: number of strands requested from the generator.
            spatial_lr_scale: stored on ``self``; ``training_setup`` multiplies
                the position/deformation/grid learning rates by it.
        """
        with torch.no_grad():
            (
                self._scaling,
                self.active_sh_degree, 
                gen_dict,
                clr_dict,
                opt_strand_dict,
                opt_color_dict,
                opt_deformation_dict,
                shd_strands_dict,
                shd_color_dict,
                deform_dict,
                self._deformation_table,
                gaussian_weight_dict,
            ) = model_params
            # current_dict = self.strands_generator.state_dict()
            # for key in self.strands_generator.state_dict().keys():
            #     if 'local2world' not in key and 'origins' not in key and 'uvs' not in key:
            #         current_dict[key] = gen_dict[key]
            # self.strands_generator.load_state_dict(current_dict)
            # import ipdb;ipdb.set_trace()
            self.gaussianWeight.load_state_dict(gaussian_weight_dict)
            # print(gen_dict.keys())
            # strict=False: keys missing from / extra in gen_dict are tolerated.
            self.strands_generator.load_state_dict(gen_dict,strict=False)
            self.color_decoder.load_state_dict(clr_dict)
            # NOTE(review): deform_dict goes into the *coarse* deformation net,
            # not self._deformation — confirm this is intended.
            self._deformation_coarse.load_state_dict(deform_dict)

            self.spatial_lr_scale = spatial_lr_scale
            
            with torch.no_grad():
                # Sample the strands at time 0.
                pts, uvs, local2world, p_local, features_dc, features_rest, orient_conf, z = self.strands_generator.forward_inference(num_strands,0)
            # import ipdb;ipdb.set_trace()
            # step = 10
            # pts = torch.cat([pts[:,:1,:], pts[:,1::step,:]], dim=1)
            # p_local = torch.cat([p_local[:,:1,:], p_local[:,1::step,:]], dim=1)
            # features_dc = features_dc[:,::step,:]
            # features_rest = features_rest[:,::step,:]
            # orient_conf = orient_conf[:,::step,:]
      
            # for i in range(2):
            #     N = pts.shape[0]
            #     L = pts.shape[1]
            #     new_pts = torch.zeros((N, L * 2 - 1, 3), device="cuda")
            #     new_pts[:,::2,:] = pts
            #     new_pts[:,1::2,:] = (pts[:,1:] + pts[:,:-1]) * 0.5
            #     pts = new_pts
            #     new_features_dc = features_dc.clone()
            #     new_features_dc = new_features_dc.repeat_interleave(2, dim=1)
            #     new_features_rest = features_rest.clone()
            #     new_features_rest = new_features_rest.repeat_interleave(2, dim=1)
            #     new_orient_conf = orient_conf.repeat_interleave(2, dim=1)
            #     features_dc = new_features_dc
            #     orient_conf = new_orient_conf
            #     features_rest = new_features_rest
            
            
            # Strand roots (first point of each strand).
            self.pts_origins = pts[:, :1]
            self.uvs = uvs
            self.local2world = local2world
            # One Gaussian centre per segment: midpoint of consecutive strand points.
            xyz = (pts[:, 1:] + pts[:, :-1]).view(-1, 3) * 0.5
            
            # noise = (torch.rand_like(self._xyz).cuda() * 2 - 1) * 0.1
            # self._xyz = self._xyz + noise
            
            # self.z_geom = z_geom
            self.p_local = p_local
            self.pts = pts
            # Per-segment direction vectors (unnormalized).
            self.dirs = pts[:, 1:] - pts[:, :-1]
            self.num_strands = pts.shape[0]
            self.strand_length = pts.shape[1]
            
            # label = z[:, 0]
            # NOTE(review): z_app is computed but not used by the active code.
            z_app = z[:, 1:]

            # features_dc, features_rest, orient_conf = self.color_decoder(z_app).split([3, 3 * ((self.max_sh_degree + 1) ** 2 - 1), 1], dim=-1)

        # # Prune hair strands with low label and the ones that intersect the FLAME mesh
        # mesh = trimesh.load(f'{self.flame_mesh_dir}/stage_3/mesh_final.obj')
        # sdf_handle = SDF(mesh.vertices, mesh.faces)
        # L = pts.shape[1]

        # p_npy = pts.detach().cpu().numpy()
        # sdf = sdf_handle(p_npy.reshape(-1, 3))
        # mask = (sdf.reshape(-1, L) < 0).mean(axis=1) >= 0.5
        L = pts.shape[1]
        # All strand points except each strand's root, flattened to (-1, 3).
        pts_wo_root = pts[:,1:,:].reshape(-1, 3)
        # Pruning is currently disabled: the mask keeps every strand, so the
        # message below always reports 0.
        mask = torch.ones_like(pts[..., 0, 0]).bool()
        print(f'Pruning {sum(~mask)} strands that intersect the head mesh')

        # mask = torch.logical_and(label >= 0.5, torch.from_numpy(mask).cuda())
        # NOTE(review): builtin sum() over a tensor yields a 0-d tensor, not a
        # Python int; this value is later passed to forward_inference.
        self.num_strands = sum(mask)

        self.pts_origins = self.pts_origins[mask]
        self.uvs = self.uvs[mask]
        self.local2world = self.local2world[mask]
        # self.z_geom = z_geom[mask]
        self.p_local = p_local[mask]
        # NOTE(review): self.dirs (and self.pts) are not masked; this is only
        # consistent while the mask above keeps every strand.
        self.dir =  self.dirs.view(-1, 3)
        # self._dirs = nn.Parameter(dirs[mask].contiguous().clone().requires_grad_(True))
        # self._d_xyz = nn.Parameter(torch.zeros_like(self._dirs).contiguous().clone().requires_grad_(True))
        # Expand the per-strand mask to one entry per segment (L-1 per strand).
        xyz_mask = mask[:,None].repeat(1, L-1).view(-1)
        self._xyz_static = nn.Parameter(xyz[xyz_mask].reshape(-1, 3).contiguous().clone().requires_grad_(True))
        self._pts_wo_root_static = nn.Parameter(pts_wo_root[xyz_mask].reshape(-1, 3).contiguous().clone().requires_grad_(True))
        # self._pts_wo_root_static = nn.Parameter(torch.zeros_like(self._dirs).contiguous().clone().requires_grad_(True))
        # self._d_pts_wo_root_static = nn.Parameter(torch.zeros_like(pts_wo_root[xyz_mask]).reshape(-1, 3).contiguous().clone().requires_grad_(True))
        # self._d_pts_wo_root_static = nn.Parameter(torch.zeros_like(self.dirs[mask]).contiguous().clone().requires_grad_(True))
        # _xyz is the same Parameter object as _xyz_static (an alias, not a copy).
        self._xyz = self._xyz_static
        self._features_dc = nn.Parameter(features_dc[mask].reshape(-1, 1, 3).contiguous().clone().requires_grad_(True))
        self._features_rest = nn.Parameter(features_rest[mask].reshape(-1, (self.max_sh_degree + 1) ** 2 - 1, 3).contiguous().clone().requires_grad_(True))
        self._orient_conf = nn.Parameter(orient_conf[mask][:, None, :].reshape(-1, 1).contiguous().clone().requires_grad_(True))
        # self._opacity_uncertainty = nn.Parameter(10.0 * torch.ones(self._features_dc.shape[0],1).cuda().requires_grad_(True))
        # self._opacity_view = nn.Parameter(0.1 * torch.ones(self._features_dc.shape[0],1).cuda().requires_grad_(True))
        # scene_transform = pkl.load(open(os.path.join(data_path, "scale.pickle"), 'rb'))
        # self.scale = 1e-4 * scene_transform['scale'] * torch.ones(1, device="cuda")
        self.scale = 1e-4 * torch.ones(1, device="cuda")
        # Bounding box of all strand points (roots included) for the deformation field.
        xyz_max = pts.reshape(-1,3).detach().cpu().numpy().max(axis=0)
        xyz_min = pts.reshape(-1,3).detach().cpu().numpy().min(axis=0)
        self._deformation.deformation_net.set_aabb(xyz_max,xyz_min)
        self._deformation = self._deformation.to("cuda")
        self.gaussianWeight = self.gaussianWeight.to("cuda")
        
        # NOTE(review): this replaces the deformation table unpacked from
        # model_params with an all-True mask.
        self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0)
        # Cache time-0 geometry/appearance on the CPU (unmasked tensors), in the
        # same list layout that initialize_gaussians_hair reads.
        self.time_list[0] = [pts.detach().cpu(), self.uvs.detach().cpu(), self.local2world.detach().cpu(), features_dc.detach().cpu(), features_rest.detach().cpu(), orient_conf.detach().cpu()]
        # self.time_sparse_list[0] = [pts[:,1:,:].reshape(-1, 3).clone().detach().cpu(), 
        #                             self.get_features.clone().detach().cpu(),
        #                             self.get_orient_conf.clone().detach().cpu(),
        #                             pts[:,1:,:].reshape(-1, 3).clone().detach().cpu(), 
        #                             self.get_features.clone().detach().cpu(),
        #                             self.get_orient_conf.clone().detach().cpu(),
        #                             ]
        
        # self._pts = pts
        # self._xyz = nn.Parameter(xyz[xyz_mask].reshape(-1, 3).contiguous().clone().requires_grad_(True))
        # self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
        # self._features_dc = nn.Parameter(features_dc[mask].reshape(-1, 1, 3).contiguous().clone().requires_grad_(True))
        # self._features_rest = nn.Parameter(features_rest[mask].reshape(-1, (self.max_sh_degree + 1) ** 2 - 1, 3).contiguous().clone().requires_grad_(True))
        # self._opacity = nn.Parameter(self.inverse_opacity_activation(torch.ones_like(self.get_xyz[:, :1])).clone().detach().requires_grad_(True))
        # self._dir = self.dir
        # _rotation = parallel_transport(
        #     a=torch.cat(
        #         [
        #             torch.ones_like(self._xyz[:, :1]),
        #             torch.zeros_like(self._xyz[:, :2])
        #         ],
        #         dim=-1
        #     ),
        #     b=self.dir
        # ).view(-1, 4) 
        # dist2 = torch.clamp_min(distCUDA2(self.get_xyz.detach()), 0.0000001)
        # scales = torch.sqrt(dist2)[...,None].repeat(1, 3) 
        # _scaling = torch.ones_like(self.get_xyz)
        # _scaling[:, 0] = self.dir.norm(dim=-1) * 0.5
        # _scaling[:, 1:] = self.scale
        # self._scaling = nn.Parameter(torch.log(scales).clone().detach().requires_grad_(True))
        # self._rotation = nn.Parameter(_rotation.clone().detach().requires_grad_(True))
        # self._xyz_init = self._xyz.clone().detach()
        # self._opacity_init = self._opacity.clone().detach()
        # self._rotation_init = self._rotation.clone().detach()
        # self._scaling_init = self._scaling.clone().detach()
        # self._features_dc_init = self._features_dc.clone().detach()
        # self._features_rest_init = self._features_rest.clone().detach()

    def training_setup(self, training_args):
        """Create the optimizers, LR schedules and gradient accumulators.

        Sets up:

        - ``self.xyz_gradient_accum`` / ``self.denom``: (N, 1) CUDA zero tensors.
        - ``self.optimizer_gaussianWeight``: AdamW (lr 1e-4) over
          ``self.gaussianWeight``. It is annealed by
          ``self.scheduler_gaussianWeight`` (cosine, ``T_max=iterations``,
          ``eta_min=1e-5``).
        - ``self.optimizer``: Adam (eps 1e-15) with named groups ``"xyz"``
          (strand points without roots), ``"deformation"`` (MLP), ``"grid"``,
          ``"f_dc"``, ``"f_rest"`` and ``"orient_conf"``.
        - ``self.xyz_scheduler_args`` / ``deformation_scheduler_args`` /
          ``grid_scheduler_args``: exponential schedules, all scaled by
          ``self.spatial_lr_scale``.
        """
        num_points = self.get_xyz.shape[0]
        self.xyz_gradient_accum = torch.zeros((num_points, 1), device="cuda")
        self.denom = torch.zeros((num_points, 1), device="cuda")

        self.optimizer_gaussianWeight = torch.optim.AdamW(self.gaussianWeight.parameters(), 1e-4)

        scale = self.spatial_lr_scale
        param_groups = [
            {'params': [self._pts_wo_root_static], 'lr': training_args.position_lr_init * scale, "name": "xyz"},
            {'params': list(self._deformation.get_mlp_parameters()), 'lr': training_args.deformation_lr_init * scale, "name": "deformation"},
            {'params': list(self._deformation.get_grid_parameters()), 'lr': training_args.grid_lr_init * scale, "name": "grid"},
            {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
            {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
            {'params': [self._orient_conf], 'lr': training_args.orient_conf_lr, "name": "orient_conf"},
        ]
        # Every group sets its own lr, so the default lr=0.0 is never used.
        self.optimizer = torch.optim.Adam(param_groups, lr=0.0, eps=1e-15)
        self.scheduler_gaussianWeight = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer_gaussianWeight, T_max=training_args.iterations, eta_min=1e-5)

        def exponential_schedule(lr_init, lr_final, lr_delay_mult):
            # All schedules share the position schedule's step budget.
            return get_expon_lr_func(lr_init=lr_init * scale,
                                     lr_final=lr_final * scale,
                                     lr_delay_mult=lr_delay_mult,
                                     max_steps=training_args.position_lr_max_steps)

        self.xyz_scheduler_args = exponential_schedule(
            training_args.position_lr_init,
            training_args.position_lr_final,
            training_args.position_lr_delay_mult)
        self.deformation_scheduler_args = exponential_schedule(
            training_args.deformation_lr_init,
            training_args.deformation_lr_final,
            training_args.deformation_lr_delay_mult)
        # NOTE(review): the grid schedule reuses deformation_lr_delay_mult; confirm intended.
        self.grid_scheduler_args = exponential_schedule(
            training_args.grid_lr_init,
            training_args.grid_lr_final,
            training_args.deformation_lr_delay_mult)
        
    def freeze_deformation(self):
        """Disable gradients for the deformation network's MLP and grid parameters."""
        for parameter_source in (self._deformation.get_mlp_parameters,
                                 self._deformation.get_grid_parameters):
            for param in parameter_source():
                param.requires_grad = False

    def unfreeze_deformation(self,training_args):
        """Re-enable gradients for the deformation network's MLP and grid parameters.

        ``training_args`` is accepted for interface compatibility but not used.
        """
        for parameter_source in (self._deformation.get_mlp_parameters,
                                 self._deformation.get_grid_parameters):
            for param in parameter_source():
                param.requires_grad = True
        
    def update_learning_rate(self, iteration, iteration_fine):
        """Advance the LR schedules by one step.

        Steps the cosine scheduler of the gaussianWeight optimizer. Then, for
        the ``"xyz"``, ``"grid"`` and ``"deformation"`` groups of
        ``self.optimizer``, sets ``lr`` from the matching
        ``*_scheduler_args(iteration)``. Other groups keep their lr.
        ``iteration_fine`` is accepted but not used.
        """
        self.scheduler_gaussianWeight.step()
        schedule_attr_by_group = {
            "xyz": "xyz_scheduler_args",
            "grid": "grid_scheduler_args",
            "deformation": "deformation_scheduler_args",
        }
        for group in self.optimizer.param_groups:
            schedule_attr = schedule_attr_by_group.get(group["name"])
            if schedule_attr is None:
                continue
            group['lr'] = getattr(self, schedule_attr)(iteration)
    def construct_list_of_attributes(self):
        """Return the PLY vertex property names, in the column order used by ``save_ply``.

        The order is: position, (zero) normals, flattened DC and higher-order
        SH coefficients, opacity, scale components, rotation components.
        """
        dc_count = self._features_dc.shape[1] * self._features_dc.shape[2]
        rest_count = self._features_rest.shape[1] * self._features_rest.shape[2]

        names = ['x', 'y', 'z', 'nx', 'ny', 'nz']
        names.extend('f_dc_{}'.format(i) for i in range(dc_count))
        names.extend('f_rest_{}'.format(i) for i in range(rest_count))
        names.append('opacity')
        names.extend('scale_{}'.format(i) for i in range(self._scaling.shape[1]))
        names.extend('rot_{}'.format(i) for i in range(self._rotation.shape[1]))
        return names
    def load_model(self, path):
        """Restore the deformation network and its bookkeeping tensors from ``path``.

        Loads ``deformation.pth`` into ``self._deformation`` and moves it to
        CUDA. Then initialises an all-True ``_deformation_table`` and a zero
        (N, 3) ``_deformation_accum``, replacing each with its saved
        ``.pth`` file when present. Finally resets ``max_radii2D`` to zeros.
        """
        print("loading model from exists{}".format(path))
        deformation_weights = torch.load(os.path.join(path, "deformation.pth"), map_location="cuda")
        self._deformation.load_state_dict(deformation_weights)
        self._deformation = self._deformation.to("cuda")

        num_points = self.get_xyz.shape[0]
        self._deformation_table = torch.gt(torch.ones((num_points,), device="cuda"), 0)
        self._deformation_accum = torch.zeros((num_points, 3), device="cuda")

        # Optional files: override the defaults above only when they exist.
        for attr_name, file_name in (("_deformation_table", "deformation_table.pth"),
                                     ("_deformation_accum", "deformation_accum.pth")):
            saved_file = os.path.join(path, file_name)
            if os.path.exists(saved_file):
                setattr(self, attr_name, torch.load(saved_file, map_location="cuda"))

        self.max_radii2D = torch.zeros((self.get_xyz.shape[0],), device="cuda")
    def save_deformation(self, path):
        """Write the deformation weights, table and accumulator into directory ``path``.

        Files are written in this order: ``deformation.pth``,
        ``deformation_table.pth``, ``deformation_accum.pth``.
        """
        # Values are fetched lazily so each file is written before the next is read.
        outputs = (
            ("deformation.pth", lambda: self._deformation.state_dict()),
            ("deformation_table.pth", lambda: self._deformation_table),
            ("deformation_accum.pth", lambda: self._deformation_accum),
        )
        for file_name, fetch in outputs:
            torch.save(fetch(), os.path.join(path, file_name))
    def save_ply(self, path):
        """Export the Gaussians as a PLY vertex cloud at ``path``.

        Creates the parent directory if needed. Columns follow
        ``construct_list_of_attributes``: positions, zero normals, flattened
        SH coefficients, a constant opacity of 1, scales and rotations, all
        stored as float32.
        """
        mkdir_p(os.path.dirname(path))

        def as_numpy(tensor):
            return tensor.detach().cpu().numpy()

        def flatten_sh(features):
            # (N, K, 3) -> (N, 3*K), channel-major, matching the f_* property names.
            return features.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()

        positions = as_numpy(self._xyz)
        columns = (
            positions,
            np.zeros_like(positions),
            flatten_sh(self._features_dc),
            flatten_sh(self._features_rest),
            as_numpy(torch.ones_like(self._xyz[:, :1])),
            as_numpy(self._scaling),
            as_numpy(self._rotation),
        )

        vertex_dtype = [(name, 'f4') for name in self.construct_list_of_attributes()]
        vertices = np.empty(positions.shape[0], dtype=vertex_dtype)
        vertices[:] = list(map(tuple, np.concatenate(columns, axis=1)))
        PlyData([PlyElement.describe(vertices, 'vertex')]).write(path)
        
    def print_deformation_weight_grad(self):
        """Debug helper: print gradient statistics of the deformation network.

        For every parameter that requires grad, prints ``name : None`` when
        no gradient exists. Otherwise prints the gradient's mean/min/max, but
        only when the mean is non-zero. A separator line is printed at the end.
        """
        for name, weight in self._deformation.named_parameters():
            if not weight.requires_grad:
                continue
            grad = weight.grad
            if grad is None:
                print(name, " :", grad)
            elif grad.mean() != 0:
                print(name, " :", grad.mean(), grad.min(), grad.max())
        print("-" * 50)
        
    def _plane_regulation(self, deformation_state = 'lf'):
        if deformation_state == 'hf':
            multi_res_grids = self._deformation_hf.deformation_net.grid.grids
        else:
            multi_res_grids = self._deformation.deformation_net.grid.grids
        total = 0
        # model.grids is 6 x [1, rank * F_dim, reso, reso]
        for grids in multi_res_grids:
            if len(grids) == 3:
                time_grids = []
            else:
                time_grids =  [0,1,3]
            for grid_id in time_grids:
                total += compute_plane_smoothness(grids[grid_id])
        return total
    
    def _time_regulation(self, deformation_state = 'lf'):
        if deformation_state == 'hf':
            multi_res_grids = self._deformation_hf.deformation_net.grid.grids
        else:
            multi_res_grids = self._deformation.deformation_net.grid.grids
        total = 0
        # model.grids is 6 x [1, rank * F_dim, reso, reso]
        for grids in multi_res_grids:
            if len(grids) == 3:
                time_grids = []
            else:
                time_grids =[2, 4, 5]
            for grid_id in time_grids:
                total += compute_plane_smoothness(grids[grid_id])
        return total
    
    def _l1_regulation(self, deformation_state = 'lf'):
        # model.grids is 6 x [1, rank * F_dim, reso, reso]
        if deformation_state == 'hf':
            multi_res_grids = self._deformation_hf.deformation_net.grid.grids
        else:
            multi_res_grids = self._deformation.deformation_net.grid.grids

        total = 0.0
        for grids in multi_res_grids:
            if len(grids) == 3:
                continue
            else:
                # These are the spatiotemporal grids
                spatiotemporal_grids = [2, 4, 5]
            for grid_id in spatiotemporal_grids:
                total += torch.abs(1 - grids[grid_id]).mean()
        return total
    
    def compute_regulation(self, time_smoothness_weight, l1_time_planes_weight, plane_tv_weight):
        loss = plane_tv_weight * self._plane_regulation() + time_smoothness_weight * self._time_regulation() + l1_time_planes_weight * self._l1_regulation()
        if self.deformation_state == 'hf':
            loss += plane_tv_weight * self._plane_regulation(self.deformation_state) + time_smoothness_weight * self._time_regulation(self.deformation_state) + l1_time_planes_weight * self._l1_regulation(self.deformation_state)
        return loss
    
    def update_active_set(self, nums_gaussian_head = 10000, iteration = 0, sparse_iteration = 0, coarse_iteration = 1000, grid_active_interval = 100, grid_active_threshold = 1 * 1e-5):
        # gaussian_hair_grad = self.xyz_gradient_accum / self.denom
        # gaussian_hair_grad[gaussian_hair_grad.isnan()] = 0.0
        if (iteration > coarse_iteration and iteration % grid_active_interval == 0) or iteration == coarse_iteration:
        # if iteration > coarse_iteration and iteration % grid_active_interval == 0 and iteration == coarse_iteration + grid_active_interval:
        # if iteration == coarse_iteration and False:
            # gaussian_hair_grad = self.xyz_gradient_accum / self.denom
            # gaussian_hair_grad = self.xyz_gradient_accum_multi_view.sum(dim=0) / self.denom_multi_view.sum(dim=0)
            # gaussian_hair_grad[gaussian_hair_grad.isnan()] = 0.0
            # gaussian_hair_grad = gaussian_hair_grad.view(-1)
            # gaussian_hair_grad = self.xyz_gradient_accum_multi_view.sum(dim=0)
            # gaussian_hair_grad = torch.norm(gaussian_hair_grad, dim=1,p=2)
            gaussian_hair_grad = self.xyz_gradient_accum_multi_view
            gaussian_hair_pos_grad = torch.norm(gaussian_hair_grad[:,:,:3], dim=2,p=2)
            gaussian_hair_features_grad = torch.norm(gaussian_hair_grad[:,:,3:51], dim=2,p=2)
            gaussian_hair_orient_conf_grad = torch.norm(gaussian_hair_grad[:,:,51:], dim=2,p=2)
            gaussian_hair_pos_grad = gaussian_hair_pos_grad.mean(dim=0)
            gaussian_hair_features_grad = gaussian_hair_features_grad.mean(dim=0)
            gaussian_hair_orient_conf_grad = gaussian_hair_orient_conf_grad.mean(dim=0)
            
            
            # gaussian_hair_pos_grad_threshold = 1e-6
            # gaussian_hair_color_grad_threshold = 1e-6
            # gaussian_hair_mask_grad_threshold = 1e-7
            # gaussian_hair_cov2D_grad_threshold = 1e-10
            # gaussian_hair_opacity_grad_threshold = 1e-1
            
            # gaussian_hair_pos_grad_threshold = 1e-5
            # gaussian_hair_color_grad_threshold = 1e-5
            # gaussian_hair_mask_grad_threshold = 1e-6
            # gaussian_hair_cov2D_grad_threshold = 1e-9
            # gaussian_hair_opacity_grad_threshold = 5*1e-1
            
            # gaussian_hair_pos_grad_threshold = 8*1e-6
            # gaussian_hair_color_grad_threshold = 8*1e-6
            # gaussian_hair_mask_grad_threshold = 8*1e-7
            # gaussian_hair_cov2D_grad_threshold = 8*1e-10
            # gaussian_hair_opacity_grad_threshold = 2*1e-1
            # import ipdb; ipdb.set_trace()
            # gaussian_hair_pos_grad_threshold = 4*1e-4
            # gaussian_hair_color_grad_threshold = 1e-5
            # gaussian_hair_mask_grad_threshold = 1e-6
            # gaussian_hair_cov2D_grad_threshold = 1e-9
            # gaussian_hair_opacity_grad_threshold = 4*1e-1
            # gaussian_hair_pos_grad_threshold = 1 * 1e-4
            # gaussian_hair_features_grad_threshold = 5 * 1e-8
            # large +
            gaussian_hair_pos_grad_threshold = 3 * 1e-4
            gaussian_hair_features_grad_threshold = 10 * 1e-8
            gaussian_hair_orient_conf_grad_threshold = 0
            # large 
            # gaussian_hair_pos_grad_threshold = 1 * 1e-4
            # gaussian_hair_features_grad_threshold = 5 * 1e-8
            # gaussian_hair_orient_conf_grad_threshold = 0
            # middle
            # gaussian_hair_pos_grad_threshold = 0.1 * 1e-4
            # gaussian_hair_features_grad_threshold = 0.5 * 1e-8
            # gaussian_hair_orient_conf_grad_threshold = 0
            # small
            # gaussian_hair_pos_grad_threshold = 0.01 * 1e-4
            # gaussian_hair_features_grad_threshold = 0.05 * 1e-8
            # gaussian_hair_orient_conf_grad_threshold = 0
            # import ipdb; ipdb.set_trace()
            grid_active_pos_mask = gaussian_hair_pos_grad > gaussian_hair_pos_grad_threshold
            grid_hair_active_pos_mask = grid_active_pos_mask[nums_gaussian_head:].reshape(30000,-1)
            # print("iteration: ", iteration)
            # print(f"grid_active_pos_mask: {grid_active_pos_mask.sum()}")
            grid_hair_active_pos_mask[:,1:] = grid_hair_active_pos_mask[:,1:] | grid_hair_active_pos_mask[:,:-1]
            grid_active_pos_mask[nums_gaussian_head:] = grid_hair_active_pos_mask.reshape(-1)
            # print(f"grid_active_pos_mask: {grid_active_pos_mask.sum()}")
            # import ipdb; ipdb.set_trace()
            # grid_active_mask = (gaussian_hair_pos_grad > gaussian_hair_pos_grad_threshold) | (gaussian_hair_features_grad > gaussian_hair_features_grad_threshold) \
            #     | (gaussian_hair_orient_conf_grad > gaussian_hair_orient_conf_grad_threshold)
            grid_active_mask = (grid_active_pos_mask) | (gaussian_hair_features_grad > gaussian_hair_features_grad_threshold) \
                | (gaussian_hair_orient_conf_grad > gaussian_hair_orient_conf_grad_threshold)
            self.uvs_mask_active_count = grid_active_mask.sum()
            # import ipdb; ipdb.set_trace()
            # grid_active_mask = (gaussian_hair_pos_grad > gaussian_hair_pos_grad_threshold) | (gaussian_hair_color_grad > gaussian_hair_color_grad_threshold) \
            #     | (gaussian_hair_mask_grad > gaussian_hair_mask_grad_threshold) | (gaussian_hair_cov2D_grad > gaussian_hair_cov2D_grad_threshold) \
            #     | (gaussian_hair_opacity_grad > gaussian_hair_opacity_grad_threshold)
            # grid_active_mask = gaussian_hair_grad > grid_active_threshold  # shape: (N,)
            # import ipdb; ipdb.set_trace()
            self.idx_active_mask = grid_active_mask
            # points_mask_active_head = self.idx_active_mask[:nums_gaussian_head]
            points_mask_active_hair = self.idx_active_mask[nums_gaussian_head:]
            self.points_mask_active_hair_indices = points_mask_active_hair.nonzero(as_tuple=True)[0]
            points_mask_wo_active_hair = (~self.idx_active_mask)[nums_gaussian_head:]
            self.points_mask_wo_active_hair_indices = points_mask_wo_active_hair.nonzero(as_tuple=True)[0]
            # points_mask_active_hair_reshape = points_mask_active_hair.reshape(-1, 99)
            # points_mask_active_hair_reshape_indices = points_mask_active_hair_reshape.nonzero(as_tuple=True)[0]
            # self._d_xyz.requires_grad = False
            # self._features_dc.requires_grad = False
            # self._features_rest.requires_grad = False
            # self._orient_conf.requires_grad = False
            # self._dirs.requires_grad = False
            # self._d_xyz[points_mask_active_hair_reshape_indices].requires_grad = False
            # self._features_dc[self.points_mask_wo_active_hair_indices].requires_grad = False
            # self._features_rest[self.points_mask_wo_active_hair_indices].requires_grad = False
            # self._scaling[self.points_mask_wo_active_hair_indices].requires_grad = False
            
            self.uvs_mask_active_count = grid_active_mask.sum()
        elif iteration > coarse_iteration and False:
            # gaussian_hair_grad = self.xyz_gradient_accum / self.denom
            # import ipdb; ipdb.set_trace()
            # gaussian_hair_grad = self.xyz_gradient_accum_multi_view.sum(dim=0) / self.denom_multi_view.sum(dim=0)
            # gaussian_hair_grad[gaussian_hair_grad.isnan()] = 0.0
            gaussian_hair_grad = self.xyz_gradient_accum_multi_view.sum(dim=0)
            gaussian_hair_grad = torch.norm(gaussian_hair_grad, dim=1,p=2)
            # gaussian_hair_grad = torch.where(self.denom != 0, self.xyz_gradient_accum / self.denom, torch.zeros_like(self.xyz_gradient_accum)).view(-1)
            grid_active_mask = gaussian_hair_grad > grid_active_threshold  # shape: (N,)
            self.uvs_mask_active_count = grid_active_mask.sum()
            grid_active_mask_50 = gaussian_hair_grad > 1e+5
            grid_active_mask_40 = gaussian_hair_grad > 1e+4
            grid_active_mask_30 = gaussian_hair_grad > 1e+3
            grid_active_mask_20 = gaussian_hair_grad > 1e+2
            grid_active_mask_10 = gaussian_hair_grad > 1e+1
            grid_active_mask_00 = gaussian_hair_grad > 1e+0
            grid_active_mask_01 = gaussian_hair_grad > 1e-1
            grid_active_mask_02 = gaussian_hair_grad > 1e-2
            grid_active_mask_03 = gaussian_hair_grad > 1e-3
            grid_active_mask_04 = gaussian_hair_grad > 1e-4
            grid_active_mask_05 = gaussian_hair_grad > 1e-5
            self.uvs_mask_active_count_50 = grid_active_mask_50.sum()
            self.uvs_mask_active_count_54 = (grid_active_mask_40 & ~grid_active_mask_50).sum()
            self.uvs_mask_active_count_43 = (grid_active_mask_30 & ~grid_active_mask_40).sum()
            self.uvs_mask_active_count_32 = (grid_active_mask_20 & ~grid_active_mask_30).sum()
            self.uvs_mask_active_count_21 = (grid_active_mask_10 & ~grid_active_mask_20).sum()
            self.uvs_mask_active_count_10 = (grid_active_mask_00 & ~grid_active_mask_10).sum()
            self.uvs_mask_active_count_01 = (grid_active_mask_01 & ~grid_active_mask_00).sum()
            self.uvs_mask_active_count_12 = (grid_active_mask_02 & ~grid_active_mask_01).sum()
            self.uvs_mask_active_count_23 = (grid_active_mask_03 & ~grid_active_mask_02).sum()
            self.uvs_mask_active_count_34 = (grid_active_mask_04 & ~grid_active_mask_03).sum()
            self.uvs_mask_active_count_45 = (grid_active_mask_05 & ~grid_active_mask_04).sum()
            self.uvs_mask_active_count_56 = (~grid_active_mask_05).sum()
            # self.idx_active_mask = grid_active_mask
    def vis_active_set(self, nums_gaussian_head = 10000, iteration = 0, sparse_iteration = 0, coarse_iteration = 1000, grid_active_interval = 100, grid_active_threshold = 1 * 1e-5):
        gaussian_hair_grad = self.xyz_gradient_accum_multi_view
        # import ipdb; ipdb.set_trace()
        gaussian_hair_pos_grad = torch.norm(gaussian_hair_grad[:,:,:3], dim=2,p=2)
        gaussian_hair_features_grad = torch.norm(gaussian_hair_grad[:,:,3:51], dim=2,p=2)
        gaussian_hair_orient_conf_grad = torch.norm(gaussian_hair_grad[:,:,51:], dim=2,p=2)
        gaussian_hair_pos_grad = gaussian_hair_pos_grad.mean(dim=0)
        gaussian_hair_features_grad = gaussian_hair_features_grad.mean(dim=0)
        gaussian_hair_orient_conf_grad = gaussian_hair_orient_conf_grad.mean(dim=0)
        # import ipdb; ipdb.set_trace()
        
        gaussian_hair_grad_average = self.xyz_gradient_accum_multi_view.mean(dim=0)
        gaussian_hair_pos_grad_average = torch.norm(gaussian_hair_grad_average[:,:3], dim=1,p=2)
        gaussian_hair_features_grad_average = torch.norm(gaussian_hair_grad_average[:,3:51], dim=1,p=2)
        gaussian_hair_orient_conf_grad_average = torch.norm(gaussian_hair_grad_average[:,51:], dim=1,p=2)
    
        # gaussian_hair_pos_grad_threshold = 1*1e-5
        gaussian_hair_pos_grad_threshold = 10*1e-4
        gaussian_hair_features_grad_threshold = 3*1e-8
        gaussian_hair_orient_conf_grad_threshold = 0
        # import ipdb; ipdb.set_trace()
        
        grid_active_mask = (gaussian_hair_pos_grad > gaussian_hair_pos_grad_threshold) | (gaussian_hair_features_grad > gaussian_hair_features_grad_threshold) \
            | (gaussian_hair_orient_conf_grad > gaussian_hair_orient_conf_grad_threshold)
            
        self.uvs_mask_active_count = grid_active_mask.sum()
        self.gaussian_hair_pos_grad_average = gaussian_hair_pos_grad_average.mean().item()
        self.gaussian_hair_features_grad_average = gaussian_hair_features_grad_average.mean().item()
        self.gaussian_hair_orient_conf_grad_average = gaussian_hair_orient_conf_grad_average.mean().item()
        grid_active_mask_50 = gaussian_hair_pos_grad_average > 1e+5
        grid_active_mask_40 = gaussian_hair_pos_grad_average > 1e+4
        grid_active_mask_30 = gaussian_hair_pos_grad_average > 1e+3
        grid_active_mask_20 = gaussian_hair_pos_grad_average > 1e+2
        grid_active_mask_10 = gaussian_hair_pos_grad_average > 1e+1
        grid_active_mask_00 = gaussian_hair_pos_grad_average > 1e+0
        grid_active_mask_01 = gaussian_hair_pos_grad_average > 1e-1
        grid_active_mask_02 = gaussian_hair_pos_grad_average > 1e-2
        grid_active_mask_03 = gaussian_hair_pos_grad_average > 1e-3
        grid_active_mask_04 = gaussian_hair_pos_grad_average > 1e-4
        grid_active_mask_05 = gaussian_hair_pos_grad_average > 1e-5
        grid_active_mask_06 = gaussian_hair_pos_grad_average > 1e-6
        grid_active_mask_07 = gaussian_hair_pos_grad_average > 1e-7
        grid_active_mask_08 = gaussian_hair_pos_grad_average > 1e-8
        grid_active_mask_09 = gaussian_hair_pos_grad_average > 1e-9
        # self.gaussian_hair_pos_grad_average = gaussian_hair_pos_grad.mean().item()
        # self.gaussian_hair_features_grad_average = gaussian_hair_features_grad.mean().item()
        # self.gaussian_hair_orient_conf_grad_average = gaussian_hair_orient_conf_grad.mean().item()
        # grid_active_mask_50 = gaussian_hair_pos_grad > 1e+5
        # grid_active_mask_40 = gaussian_hair_pos_grad > 1e+4
        # grid_active_mask_30 = gaussian_hair_pos_grad > 1e+3
        # grid_active_mask_20 = gaussian_hair_pos_grad > 1e+2
        # grid_active_mask_10 = gaussian_hair_pos_grad > 1e+1
        # grid_active_mask_00 = gaussian_hair_pos_grad > 1e+0
        # grid_active_mask_01 = gaussian_hair_pos_grad > 1e-1
        # grid_active_mask_02 = gaussian_hair_pos_grad > 1e-2
        # grid_active_mask_03 = gaussian_hair_pos_grad > 1e-3
        # grid_active_mask_04 = gaussian_hair_pos_grad > 1e-4
        # grid_active_mask_05 = gaussian_hair_pos_grad > 1e-5
        # grid_active_mask_06 = gaussian_hair_pos_grad > 1e-6
        # grid_active_mask_07 = gaussian_hair_pos_grad > 1e-7
        # grid_active_mask_08 = gaussian_hair_pos_grad > 1e-8
        # grid_active_mask_09 = gaussian_hair_pos_grad > 1e-9
        self.uvs_mask_active_count_50 = grid_active_mask_50.sum()
        self.uvs_mask_active_count_54 = (grid_active_mask_40 & ~grid_active_mask_50).sum()
        self.uvs_mask_active_count_43 = (grid_active_mask_30 & ~grid_active_mask_40).sum()
        self.uvs_mask_active_count_32 = (grid_active_mask_20 & ~grid_active_mask_30).sum()
        self.uvs_mask_active_count_21 = (grid_active_mask_10 & ~grid_active_mask_20).sum()
        self.uvs_mask_active_count_10 = (grid_active_mask_00 & ~grid_active_mask_10).sum()
        self.uvs_mask_active_count_01 = (grid_active_mask_01 & ~grid_active_mask_00).sum()
        self.uvs_mask_active_count_12 = (grid_active_mask_02 & ~grid_active_mask_01).sum()
        self.uvs_mask_active_count_23 = (grid_active_mask_03 & ~grid_active_mask_02).sum()
        self.uvs_mask_active_count_34 = (grid_active_mask_04 & ~grid_active_mask_03).sum()
        self.uvs_mask_active_count_45 = (grid_active_mask_05 & ~grid_active_mask_04).sum()
        self.uvs_mask_active_count_56 = (grid_active_mask_06 & ~grid_active_mask_05).sum()
        self.uvs_mask_active_count_67 = (grid_active_mask_07 & ~grid_active_mask_06).sum()
        self.uvs_mask_active_count_78 = (grid_active_mask_08 & ~grid_active_mask_07).sum()
        self.uvs_mask_active_count_89 = (grid_active_mask_09 & ~grid_active_mask_08).sum()
        self.uvs_mask_active_count_90 = (~grid_active_mask_09).sum()
        # self.gaussian_hair_pos_grad_average = gaussian_hair_pos_grad.mean().item()
        # self.gaussian_hair_features_grad_average = gaussian_hair_features_grad.mean().item()
        # self.gaussian_hair_orient_conf_grad_average = gaussian_hair_orient_conf_grad.mean().item()
            
    def add_densification_stats(self, viewspace_point_tensor, update_filter):
        # import ipdb; ipdb.set_trace()
        self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor[update_filter,:2], dim=-1, keepdim=True)
        self.denom[update_filter] += 1
        # self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
        # self.denom[update_filter] += 1 
        # gradients = viewspace_point_tensor.grad[update_filter, :2]
        # norm_vals = gradients.pow(2).sum(dim=-1, keepdim=True).sqrt()
        # self.xyz_gradient_accum[update_filter].add_(norm_vals)
        # self.denom[update_filter].add_(1)
    def add_densification_stats_average(self, cam_id,viewspace_point_tensor):
        self.xyz_gradient_accum_multi_view[cam_id] = viewspace_point_tensor
        self.denom_multi_view[cam_id] = 1
        
    def update_wo_active_set_data(self,num_strands,scene,render_hair_weight_wo_active,renderArgs):
        with torch.no_grad():
            config = [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(0, 16, 1)]
            for idx, viewpoint in enumerate(config):
                # scene.gaussians.mask_precomp = scene.gaussians.get_label(viewpoint.time_step)[..., 0] < 0.6
                # scene.gaussians.xyz_precomp = scene.gaussians.get_xyz(viewpoint.time_step)[scene.gaussians.mask_precomp].detach()
                # scene.gaussians.opacity_precomp = scene.gaussians.get_opacity(viewpoint.time_step)[scene.gaussians.mask_precomp].detach()
                # scene.gaussians.scaling_precomp = scene.gaussians.get_scaling(viewpoint.time_step)[scene.gaussians.mask_precomp].detach()
                # scene.gaussians.rotation_precomp = scene.gaussians.get_rotation(viewpoint.time_step)[scene.gaussians.mask_precomp].detach()
                # scene.gaussians.cov3D_precomp = scene.gaussians.get_covariance(scaling_modifier = 1.0,time_index = viewpoint.time_step)[scene.gaussians.mask_precomp].detach()
                # scene.gaussians.shs_view = scene.gaussians.get_features(viewpoint.time_step)[scene.gaussians.mask_precomp].detach().transpose(1, 2).view(-1, 3, (scene.gaussians.max_sh_degree + 1)**2)
                self.initialize_gaussians_hair(num_strands, time_step = viewpoint.time_step)
                render_pkg = render_hair_weight_wo_active(viewpoint, scene.gaussians, self, *renderArgs)
                self.wo_active_set_data[viewpoint.camera_index] = render_pkg
                # self.wo_active_set_data[viewpoint.time_step][idx] = render_pkg
    def create_tensors_to_optimizer(self, tensors_dict):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            assert len(group["params"]) == 1
            extension_tensor = tensors_dict[group["name"]]
            stored_state = self.optimizer.state.get(group['params'][0], None)
            if stored_state is not None:
                stored_state["exp_avg"] = torch.zeros_like(extension_tensor)
                stored_state["exp_avg_sq"] = torch.zeros_like(extension_tensor)
                del self.optimizer.state[group['params'][0]]
                group["params"][0] = nn.Parameter(extension_tensor.requires_grad_(True))
                self.optimizer.state[group['params'][0]] = stored_state
                optimizable_tensors[group["name"]] = group["params"][0]
            else:
                group["params"][0] = nn.Parameter(extension_tensor.requires_grad_(True))
                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors
    def densification_postfix_create(self, new_xyz, new_features_dc, new_features_rest, new_orient_conf):
        d = {"xyz": new_xyz,
        "f_dc": new_features_dc,
        "f_rest": new_features_rest,
        "orient_conf": new_orient_conf
        }
        optimizable_tensors = self.create_tensors_to_optimizer(d)
        self._xyz_static = optimizable_tensors["xyz"]
        self._features_dc = optimizable_tensors["f_dc"]
        self._features_rest = optimizable_tensors["f_rest"]
        self._orient_conf = optimizable_tensors["orient_conf"]
        
    def densify_and_split_hair(self,time_step):
        self.pts = self.time_list[0][0].cuda()
        pts_wo_rot = self.pts[:, 1:]
        N = pts_wo_rot.shape[0]
        L = pts_wo_rot.shape[1]
        if 2 * L >128:
            return
        sign = 1 - 2 * (torch.arange(L, device='cuda') % 2)
        alt_s = self._xyz_static.view(-1, L, 3) * sign.view(1, L, 1)       # (N, M, 3)
        cs = alt_s.cumsum(dim=1)             # (N, M, 3)
        p0_flat = self.pts_origins.squeeze(1)              # (N, 3)
        sign_k  = sign.view(1, L, 1)         # broadcast 
        ps_tail = sign_k * (2 * cs - p0_flat[:, None, :])
        self._pts = torch.cat([self.pts_origins, ps_tail], dim=1)
        self._dir = (self._pts[:, 1:] - self._pts[:, :-1]).view(-1, 3)

        # (self._pts[:, 1:] + self._pts[:, :-1]).view(-1, 3)
        new_pts = torch.zeros((N, L * 2 + 1, 3), device="cuda")
        new_pts[:,::2,:] = self._pts
        new_pts[:,1::2,:] = self._xyz_static.reshape(-1, L, 3)
        (new_pts[:,0,:] + new_pts[:,2,:] - 2 * new_pts[:,1,:])
        self._pts = new_pts
        new_xyz = (self._pts[:, 1:] + self._pts[:, :-1]).view(-1, 3) * 0.5
        # new_features_dc = torch.empty((N, L * 2, self._features_dc.shape[1], self._features_dc.shape[2]), device="cuda")
        # new_features_rest = torch.empty((N, L * 2, self._features_rest.shape[1], self._features_rest.shape[2]), device="cuda")
        # new_orient_conf = torch.empty((N, L * 2), device="cuda")
        # new_features_dc[:,::2,:] = self._features_dc.view(-1, L, self._features_dc.shape[1], self._features_dc.shape[2])
        # new_features_dc[:,1::2,:] = self._features_dc.view(-1, L, self._features_dc.shape[1], self._features_dc.shape[2])
        # new_features_rest[:,::2,:] = self._features_rest.view(-1, L, self._features_rest.shape[1], self._features_rest.shape[2])
        # new_features_rest[:,1::2,:] = self._features_rest.view(-1, L, self._features_rest.shape[1], self._features_rest.shape[2])
        # new_orient_conf[:,::2] = self._orient_conf.view(-1, L)
        # new_orient_conf[:,1::2] = self._orient_conf.view(-1, L)
        # new_features_dc = new_features_dc.view(-1, self._features_dc.shape[1], self._features_dc.shape[2])
        # new_features_rest = new_features_rest.view(-1, self._features_rest.shape[1], self._features_rest.shape[2])
        # new_orient_conf = new_orient_conf.view(-1)
        features_dc_shape = self._features_dc.shape
        features_rest_shape = self._features_rest.shape
        new_features_dc = self._features_dc.view(-1, L, features_dc_shape[1], features_dc_shape[2])
        new_features_dc = new_features_dc.repeat_interleave(2, dim=1).view(-1, features_dc_shape[1], features_dc_shape[2])
        new_features_rest = self._features_rest.view(-1, L, features_rest_shape[1], features_rest_shape[2])
        new_features_rest = new_features_rest.repeat_interleave(2, dim=1).view(-1, features_rest_shape[1], features_rest_shape[2])
        new_orient_conf = self._orient_conf.view(-1, L).repeat_interleave(2, dim=1).view(-1,1)
        self.densification_postfix_create(new_xyz, new_features_dc, new_features_rest, new_orient_conf)
        # import ipdb; ipdb.set_trace()
        self.time_list[time_step][0] = self._pts.detach().cpu()
        self.time_list[time_step][4] = self._features_dc.detach().cpu()
        self.time_list[time_step][5] = self._features_rest.detach().cpu()
        self.time_list[time_step][6] = self._orient_conf.detach().cpu()
    def densify(self,time_step):
        self.pts = self.time_list[0][0].cuda()
        pts_wo_rot = self.pts[:, 1:]
        N = pts_wo_rot.shape[0]
        L = pts_wo_rot.shape[1]
        if 2 * L >128:
            return
        sign = 1 - 2 * (torch.arange(L, device='cuda') % 2)
        alt_s = self._xyz_static.view(-1, L, 3) * sign.view(1, L, 1)       # (N, M, 3)
        cs = alt_s.cumsum(dim=1)             # (N, M, 3)
        p0_flat = self.pts_origins.squeeze(1)              # (N, 3)
        sign_k  = sign.view(1, L, 1)         # broadcast 
        ps_tail = sign_k * (2 * cs - p0_flat[:, None, :])
        self._pts = torch.cat([self.pts_origins, ps_tail], dim=1)
        self._dir = (self._pts[:, 1:] - self._pts[:, :-1]).view(-1, 3)

        # (self._pts[:, 1:] + self._pts[:, :-1]).view(-1, 3)
        new_pts = torch.zeros((N, L * 2 + 1, 3), device="cuda")
        new_pts[:,::2,:] = self._pts
        new_pts[:,1::2,:] = self._xyz_static.reshape(-1, L, 3)
        (new_pts[:,0,:] + new_pts[:,2,:] - 2 * new_pts[:,1,:])
        self._pts = new_pts
        new_xyz = (self._pts[:, 1:] + self._pts[:, :-1]).view(-1, 3) * 0.5
        features_dc_shape = self._features_dc.shape
        features_rest_shape = self._features_rest.shape
        new_features_dc = self._features_dc.view(-1, L, features_dc_shape[1], features_dc_shape[2])
        new_features_dc = new_features_dc.repeat_interleave(2, dim=1).view(-1, features_dc_shape[1], features_dc_shape[2])
        new_features_rest = self._features_rest.view(-1, L, features_rest_shape[1], features_rest_shape[2])
        new_features_rest = new_features_rest.repeat_interleave(2, dim=1).view(-1, features_rest_shape[1], features_rest_shape[2])
        new_orient_conf = self._orient_conf.view(-1, L).repeat_interleave(2, dim=1).view(-1,1)
        self.densification_postfix_create(new_xyz, new_features_dc, new_features_rest, new_orient_conf)
        # import ipdb; ipdb.set_trace()
        self.time_list[time_step][0] = self._pts.detach().cpu()
        self.time_list[time_step][4] = self._features_dc.detach().cpu()
        self.time_list[time_step][5] = self._features_rest.detach().cpu()
        self.time_list[time_step][6] = self._orient_conf.detach().cpu()

    def replace_tensor_to_optimizer(self, tensor, name):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            if group["name"] == name:
                stored_state = self.optimizer.state.get(group['params'][0], None)
                stored_state["exp_avg"] = torch.zeros_like(tensor)
                stored_state["exp_avg_sq"] = torch.zeros_like(tensor)

                del self.optimizer.state[group['params'][0]]
                group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
                self.optimizer.state[group['params'][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def _prune_optimizer(self, mask):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            stored_state = self.optimizer.state.get(group['params'][0], None)
            if stored_state is not None:
                stored_state["exp_avg"] = stored_state["exp_avg"][mask]
                stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]

                del self.optimizer.state[group['params'][0]]
                group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
                self.optimizer.state[group['params'][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
            else:
                group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def prune_points(self, mask):
        valid_points_mask = ~mask
        optimizable_tensors = self._prune_optimizer(valid_points_mask)

        self._xyz = optimizable_tensors["xyz"]
        self._features_dc = optimizable_tensors["f_dc"]
        self._features_rest = optimizable_tensors["f_rest"]
        self._opacity = optimizable_tensors["opacity"]
        self._scaling = optimizable_tensors["scaling"]
        self._rotation = optimizable_tensors["rotation"]

        self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]

        self.denom = self.denom[valid_points_mask]
        self.max_radii2D = self.max_radii2D[valid_points_mask]
        self.tmp_radii = self.tmp_radii[valid_points_mask]

    def cat_tensors_to_optimizer(self, tensors_dict):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            assert len(group["params"]) == 1
            extension_tensor = tensors_dict[group["name"]]
            stored_state = self.optimizer.state.get(group['params'][0], None)
            if stored_state is not None:

                stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
                stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)

                del self.optimizer.state[group['params'][0]]
                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
                self.optimizer.state[group['params'][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
            else:
                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
                optimizable_tensors[group["name"]] = group["params"][0]

        return optimizable_tensors

    def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_tmp_radii):
        """Append newly created Gaussians and reset the densification statistics.

        The new attribute rows are concatenated into the optimizer groups ``xyz``, ``f_dc``,
        ``f_rest``, ``opacity``, ``scaling`` and ``rotation`` and stored on the matching
        private attributes. ``new_tmp_radii`` is appended to ``self.tmp_radii``.
        ``xyz_gradient_accum``, ``denom`` and ``max_radii2D`` are re-created as zeros sized
        to the new point count, on the ``"cuda"`` device.
        """
        grown = self.cat_tensors_to_optimizer({
            "xyz": new_xyz,
            "f_dc": new_features_dc,
            "f_rest": new_features_rest,
            "opacity": new_opacities,
            "scaling" : new_scaling,
            "rotation" : new_rotation,
        })
        for attr_name, group_name in (("_xyz", "xyz"),
                                      ("_features_dc", "f_dc"),
                                      ("_features_rest", "f_rest"),
                                      ("_opacity", "opacity"),
                                      ("_scaling", "scaling"),
                                      ("_rotation", "rotation")):
            setattr(self, attr_name, grown[group_name])

        self.tmp_radii = torch.cat((self.tmp_radii, new_tmp_radii))
        point_count = self.get_xyz.shape[0]
        self.xyz_gradient_accum = torch.zeros((point_count, 1), device="cuda")
        self.denom = torch.zeros((point_count, 1), device="cuda")
        self.max_radii2D = torch.zeros((point_count), device="cuda")

    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
        """Split large, high-gradient Gaussians into ``N`` smaller children each.

        A point is selected when both of these hold:
          * its gradient statistic is ``>= grad_threshold``;
          * its largest scale exceeds ``self.percent_dense * scene_extent``.
        For each selected point, ``N`` children are created:
          * positions are the parent position plus an offset drawn from a zero-mean normal
            with the parent's scales as std, rotated by the parent's rotation;
          * scales are the parent's divided by ``0.8 * N``;
          * all other attributes are copied.
        The children are appended via ``densification_postfix`` and the parents are pruned.

        Args:
            grads: per-point gradient statistic. Presumably ``(M, 1)`` with M <= the current
                point count; entries beyond M count as zero. TODO confirm with caller.
            grad_threshold: minimum gradient for a point to be split.
            scene_extent: scene scale used together with ``self.percent_dense``.
            N: number of children per selected point.
        """
        point_count = self.get_xyz.shape[0]
        grad_per_point = torch.zeros((point_count), device="cuda")
        grad_per_point[:grads.shape[0]] = grads.squeeze()
        split_mask = torch.where(grad_per_point >= grad_threshold, True, False)
        split_mask = torch.logical_and(
            split_mask,
            torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent,
        )

        parent_scales = self.get_scaling[split_mask].repeat(N,1)
        offsets = torch.normal(mean=torch.zeros((parent_scales.size(0), 3), device="cuda"), std=parent_scales)
        parent_rotations = build_rotation(self._rotation[split_mask]).repeat(N,1,1)
        child_xyz = torch.bmm(parent_rotations, offsets.unsqueeze(-1)).squeeze(-1) + self.get_xyz[split_mask].repeat(N, 1)
        child_scaling = self.scaling_inverse_activation(self.get_scaling[split_mask].repeat(N,1) / (0.8*N))
        child_rotation = self._rotation[split_mask].repeat(N,1)
        child_features_dc = self._features_dc[split_mask].repeat(N,1,1)
        child_features_rest = self._features_rest[split_mask].repeat(N,1,1)
        child_opacity = self._opacity[split_mask].repeat(N,1)
        child_tmp_radii = self.tmp_radii[split_mask].repeat(N)

        self.densification_postfix(child_xyz, child_features_dc, child_features_rest, child_opacity,
                                   child_scaling, child_rotation, child_tmp_radii)

        # The children were appended after the originals; prune only the parents.
        prune_filter = torch.cat((split_mask, torch.zeros(N * split_mask.sum(), device="cuda", dtype=bool)))
        self.prune_points(prune_filter)

    def densify_and_clone(self, grads, grad_threshold, scene_extent):
        # Extract points that satisfy the gradient condition
        selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
        selected_pts_mask = torch.logical_and(selected_pts_mask,
                                              torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
        
        new_xyz = self._xyz[selected_pts_mask]
        new_features_dc = self._features_dc[selected_pts_mask]
        new_features_rest = self._features_rest[selected_pts_mask]
        new_opacities = self._opacity[selected_pts_mask]
        new_scaling = self._scaling[selected_pts_mask]
        new_rotation = self._rotation[selected_pts_mask]

        new_tmp_radii = self.tmp_radii[selected_pts_mask]

        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_tmp_radii)

    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size, radii):
        # import ipdb; ipdb.set_trace()
        grads = self.xyz_gradient_accum / self.denom
        grads[grads.isnan()] = 0.0

        self.tmp_radii = radii
        self.densify_and_clone(grads, max_grad, extent)
        self.densify_and_split(grads, max_grad, extent)

        prune_mask = (self.get_opacity < min_opacity).squeeze()
        if max_screen_size:
            big_points_vs = self.max_radii2D > max_screen_size
            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
            prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
        self.prune_points(prune_mask)
        tmp_radii = self.tmp_radii
        self.tmp_radii = None

        torch.cuda.empty_cache()
        self._xyz_init = self._xyz.clone().detach()
        self._opacity_init = self._opacity.clone().detach()
        self._rotation_init = self._rotation.clone().detach()
        self._scaling_init = self._scaling.clone().detach()
        self._features_dc_init = self._features_dc.clone().detach()
        self._features_rest_init = self._features_rest.clone().detach()
    def reset_opacity(self):
        """Clamp every Gaussian's opacity to at most 0.01.

        The clamped opacities are mapped back through
        ``self.inverse_opacity_activation`` (presumably the inverse of the
        activation applied by ``get_opacity``; confirm). The result is installed
        as the "opacity" optimizer tensor via ``replace_tensor_to_optimizer``.
        """
        # Element-wise min: opacities already below 0.01 are left unchanged.
        opacities_new = self.inverse_opacity_activation(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
        # NOTE(review): replace_tensor_to_optimizer is not visible here; it presumably
        # swaps the parameter in the "opacity" param group (and its optimizer state).
        optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
        self._opacity = optimizable_tensors["opacity"]