#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import torch
import numpy as np
from utils.general_utils import inverse_sigmoid
from torch import nn
import torch.nn.functional as F
import os
from utils.system_utils import mkdir_p
from plyfile import PlyData, PlyElement
from utils.general_utils import strip_symmetric, get_expon_lr_func, build_scaling_rotation, parallel_transport
import math
import pickle as pkl
import sys
import random
from scene.deformation import deform_network
from scene.regulation import compute_plane_smoothness
from scene.gaussian_render import GaussRenderer
from scene.gaussian_weight import GaussianWeightPred,DeepReverseMonotonicNorm,gaussian_weight_pred_pe
sys.path.append('../ext/NeuralHaircut/')
sys.path.append('../ext/NeuralHaircut/k-diffusion')
from src.hair_networks.optimizable_textured_strands import OptimizableTexturedStrands
from src.hair_networks.strand_prior import Decoder, Encoder



class GaussianModelHair:

    def setup_functions(self):
        """Attach the activation callables used by the rest of the model.

        Sets, on ``self``:
          * ``scaling_activation`` / ``scaling_inverse_activation``: exp / log
          * ``opacity_activation`` / ``inverse_opacity_activation``: sigmoid / inverse_sigmoid
          * ``label_activation`` / ``inverse_label_activation``: sigmoid / inverse_sigmoid
          * ``orient_conf_activation`` / ``orient_conf_inverse_activation``: exp / log
          * ``rotation_activation``: ``torch.nn.functional.normalize``
          * ``covariance_activation``: the covariance builder defined below
        """

        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation, return_full_covariance=False):
            # The matrix returned by build_scaling_rotation is combined as M^T @ M.
            transform = build_scaling_rotation(scaling_modifier * scaling, rotation)
            full_covariance = transform.transpose(1, 2) @ transform
            if return_full_covariance:
                return full_covariance
            # Compact form: let strip_symmetric reduce the symmetric matrix.
            return strip_symmetric(full_covariance)

        # exp / log pairs
        self.scaling_activation = torch.exp
        self.scaling_inverse_activation = torch.log
        self.orient_conf_activation = torch.exp
        self.orient_conf_inverse_activation = torch.log

        # sigmoid / inverse-sigmoid pairs
        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = inverse_sigmoid
        self.label_activation = torch.sigmoid
        self.inverse_label_activation = inverse_sigmoid

        self.rotation_activation = torch.nn.functional.normalize
        self.covariance_activation = build_covariance_from_scaling_rotation

    def __init__(self, data_dir, flame_mesh_dir, strands_config, deformation_config, texture_hidden_config, sh_degree, start_time_step, num_time_steps):
        """Create the hair Gaussian model and its sub-networks.

        Args:
            data_dir: forwarded to ``OptimizableTexturedStrands``.
            flame_mesh_dir: forwarded to ``OptimizableTexturedStrands``.
            strands_config: mapping; this constructor reads
                ``['extra_args']['num_guiding_strands']``, ``['textured_strands']``
                (expanded as keyword arguments) and ``['diffusion_prior']``.
            deformation_config: passed to ``deform_network``; its
                ``kplanes_config`` attribute is printed.
            texture_hidden_config: forwarded to ``OptimizableTexturedStrands``.
            sh_degree: stored as both the active and the maximum SH degree.
            start_time_step: forwarded to ``OptimizableTexturedStrands``.
            num_time_steps: forwarded to ``OptimizableTexturedStrands``.

        The strands generator, colour decoder and strand encoder are moved to
        CUDA, and the encoder weights are loaded from a path relative to the
        current working directory.
        """
        guiding_count = strands_config['extra_args']['num_guiding_strands']
        # The attribute is normalised to 0 when the config holds None; the raw
        # config value is what gets forwarded to the strands generator below.
        self.num_guiding_strands = 0 if guiding_count is None else guiding_count
        self.active_sh_degree = sh_degree
        self.max_sh_degree = sh_degree

        # Tensor-valued state starts out as independent empty placeholders.
        for placeholder in (
            '_pts',
            '_xyz',
            '_xyzs',
            'gaussian_hair_grad',
            '_features_dc',
            '_features_rest',
            '_scaling',
            '_rotation',
            '_opacity',
            '_orient_conf',
            '_label',
            'image_ROI_mask',
        ):
            setattr(self, placeholder, torch.empty(0))

        # Sub-networks are built in a fixed order.
        generator = OptimizableTexturedStrands(
            **strands_config['textured_strands'],
            diffusion_cfg=strands_config['diffusion_prior'],
            texture_hidden_config=texture_hidden_config,
            data_dir=data_dir,
            flame_mesh_dir=flame_mesh_dir,
            num_guiding_strands=guiding_count,
            start_time_step=start_time_step,
            num_time_steps=num_time_steps,
        )
        self.strands_generator = generator.cuda()

        # Decoder output width: 3 values per SH coefficient plus one extra channel.
        sh_coeff_count = (self.max_sh_degree + 1) ** 2
        self.color_decoder = Decoder(None, dim_hidden=128, num_layers=2, length=99, dim_out=3 * sh_coeff_count + 1).cuda()

        self.strands_encoder = Encoder(None).eval().cuda()
        encoder_ckpt = torch.load('../ext/NeuralHaircut/pretrained_models/strand_prior/strand_ckpt.pth')
        self.strands_encoder.load_state_dict(encoder_ckpt['encoder'])

        self.optimizer = None
        self.scheduler = None
        self.deformaton_scale = 0.00001
        self._deformation_table = torch.empty(0)
        print(deformation_config.kplanes_config)
        self._deformation = deform_network(deformation_config)
        self.gaussRender = GaussRenderer(active_sh_degree=sh_degree, white_bkgd=False)
        # Alternatives imported by this file: GaussianWeightPred, DeepReverseMonotonicNorm.
        self.gaussianWeight = gaussian_weight_pred_pe(input_dim=4, pe_feq=1)

        # Bookkeeping / statistics.
        self.wo_active_set_data = [None] * 16
        self.time_list = {}
        self.dirs_times = {}
        self.spatial_lr_scale = 0.1
        self.loss_pts = 0.0
        self.hair_smoothness = torch.zeros(1).cuda()
        self.training = True
        self.time_used = 0
        self.GPU_time_used = 0
        self.uvs_mask_count = 0
        self.uvs_mask_active_count = 0

        # Active-set masks / indices; None means "no selection".
        self.idx_active_mask = None
        self.idx_hair_active_mask = None
        self.points_mask_active_hair_indices = None
        self.points_mask_wo_active_hair_indices = None
        self.points_mask_hair_indices = None
        self.grid_gradient_accum_multi_view = None
        self.strand_generator_grid_params = None
        self.hair_active_num = 0
        self.grid_interval = 2000

        # One zero-initialised counter per suffix: uvs_mask_active_count_<suffix>.
        for suffix in ('50', '54', '43', '32', '21', '10', '01', '12',
                       '23', '34', '45', '56', '67', '78', '89', '90'):
            setattr(self, f'uvs_mask_active_count_{suffix}', 0)

        self.setup_functions()

    def capture(self):
        """Return a 12-element checkpoint tuple.

        Element order (``restore`` unpacks the tuple in the same order):
        ``_scaling``, ``active_sh_degree``, strands-generator state,
        colour-decoder state, the three optimizer states (strands, colour,
        deformation), the two scheduler states (strands, colour), the
        deformation-network state, ``_deformation_table`` and the
        gaussian-weight network state.
        """
        network_states = {
            'generator': self.strands_generator.state_dict(),
            'color': self.color_decoder.state_dict(),
            'deformation': self._deformation.state_dict(),
            'weight': self.gaussianWeight.state_dict(),
        }
        optimizer_states = [
            self.optimizer_strands.state_dict(),
            self.optimizer_color.state_dict(),
            self.optimizer_deformation.state_dict(),
        ]
        scheduler_states = [
            self.scheduler_strands.state_dict(),
            self.scheduler_color.state_dict(),
        ]
        return (
            self._scaling,
            self.active_sh_degree,
            network_states['generator'],
            network_states['color'],
            *optimizer_states,
            *scheduler_states,
            network_states['deformation'],
            self._deformation_table,
            network_states['weight'],
        )
    
    def restore(self, model_args, training_args):
        (
            self._scaling,
            self.active_sh_degree, 
            gen_dict,
            clr_dict,
            opt_strand_dict,
            opt_color_dict,
            opt_deformation_dict,
            shd_strands_dict,
            shd_color_dict,
            deform_dict,
            self._deformation_table,
            self.gaussian_weight_dict,
        ) = model_args
        # print(self.strands_generator.state_dict())
        current_dict = self.strands_generator.state_dict()
        for key in self.strands_generator.state_dict().keys():
            if 'local2world' not in key and 'origins' not in key and 'uvs' not in key:
                current_dict[key] = gen_dict[key]
        self.strands_generator.load_state_dict(current_dict)
        # self.strands_generator.load_state_dict(gen_dict)
        # self.strands_generator.load_state_dict(gen_dict)
        self.color_decoder.load_state_dict(clr_dict)
        self.gaussianWeight.load_state_dict(self.gaussian_weight_dict)
        self.training_setup(training_args[0], training_args[1])
        # self.optimizer_strands.load_state_dict(opt_strand_dict)
        # self.optimizer_color.load_state_dict(opt_color_dict)
        # self.optimizer_deformation.load_state_dict(opt_deformation_dict)
        # self.scheduler_strands.load_state_dict(shd_strands_dict)
        # self.scheduler_color.load_state_dict(shd_color_dict)
        # self._deformation.load_state_dict(deform_dict)
        # for param in self.strands_generator.parameters():
        #     param.requires_grad = False

    @property
    def get_scaling(self):
        """Return per-Gaussian scales with the same shape as ``get_xyz``.

        Column 0 is half the norm of the segment vector in ``self._dir``; the
        remaining columns are filled with ``self.scale`` (assigned outside this
        view). When ``points_mask_hair_indices`` is set, only the selected
        segments are used.
        """
        directions = self._dir
        # Identity check: the attribute is either None or an index tensor, and
        # ``tensor != None`` relies on comparison fallback behaviour.
        if self.points_mask_hair_indices is not None:
            directions = directions.index_select(0, self.points_mask_hair_indices)

        scaling = torch.ones_like(self.get_xyz)
        scaling[:, 0] = directions.norm(dim=-1) * 0.5
        scaling[:, 1:] = self.scale
        return scaling
    
    @property
    def get_rotation(self):
        """Return ``self._rotation`` passed through ``rotation_activation``.

        When ``points_mask_hair_indices`` is set, only the selected rows are
        activated and returned.
        """
        rotation = self._rotation
        # ``is not None`` instead of ``!= None``: the attribute may be a tensor.
        if self.points_mask_hair_indices is not None:
            rotation = rotation.index_select(0, self.points_mask_hair_indices)
        return self.rotation_activation(rotation)

    @property
    def get_xyz(self):
        """Return Gaussian centres, restricted to ``points_mask_hair_indices`` when it is set."""
        # ``is not None`` instead of ``!= None``: the attribute may be a tensor.
        if self.points_mask_hair_indices is None:
            return self._xyz
        return self._xyz.index_select(0, self.points_mask_hair_indices)
    
    @property
    def get_features(self):
        """Return DC and rest SH features concatenated along dim 1.

        When ``points_mask_hair_indices`` is set, both tensors are first
        restricted to the selected rows.
        """
        features_dc = self._features_dc
        features_rest = self._features_rest
        # ``is not None`` instead of ``!= None``: the attribute may be a tensor.
        if self.points_mask_hair_indices is not None:
            selected = self.points_mask_hair_indices
            features_dc = features_dc.index_select(0, selected)
            features_rest = features_rest.index_select(0, selected)
        return torch.cat((features_dc, features_rest), dim=1)

    @property
    def get_features_gdn(self):
        """Return ``_features_dc_gdn`` and ``_features_rest_gdn`` concatenated along dim 1."""
        return torch.cat([self._features_dc_gdn, self._features_rest_gdn], dim=1)

    @property
    def get_opacity(self):
        """Return an all-ones ``(N, 1)`` opacity tensor matching ``get_xyz``.

        The learned ``_opacity`` parameter is not used here.
        """
        first_column = self.get_xyz[:, :1]
        return torch.ones_like(first_column)

    @property
    def get_label(self):
        """Return an all-ones ``(N, 1)`` label tensor matching ``get_xyz``.

        The learned ``_label`` parameter is not used here.
        """
        first_column = self.get_xyz[:, :1]
        return torch.ones_like(first_column)
    
    @property
    def get_orient_conf(self):
        """Return ``self._orient_conf`` passed through ``orient_conf_activation``.

        When ``points_mask_hair_indices`` is set, only the selected rows are
        activated and returned.
        """
        raw_conf = self._orient_conf
        # ``is not None`` instead of ``!= None``: the attribute may be a tensor.
        if self.points_mask_hair_indices is not None:
            raw_conf = raw_conf.index_select(0, self.points_mask_hair_indices)
        return self.orient_conf_activation(raw_conf)

    def set_deformation_scale(self, lr_init, lr_final, step, max_steps):
        """Set ``self.deformaton_scale`` by log-linear interpolation.

        The progress ``step / max_steps`` is clipped to ``[0, 1]``; the result
        moves from ``lr_init`` (progress 0) to ``lr_final`` (progress 1) along
        a straight line in log space.
        """
        progress = np.clip(step / max_steps, 0, 1)
        log_start = np.log(lr_init)
        log_end = np.log(lr_final)
        self.deformaton_scale = np.exp(log_start * (1 - progress) + log_end * progress)
    @torch.no_grad()
    def filter_points(self, viewpoint_camera):
        """Compute, store and return a per-point boolean keep-mask.

        The commented CUDA reference this code was written against performs
        the same three checks (``in_frustum``, the determinant test and the
        ``getRect`` area test). A point is kept when:
          * its view-space depth is greater than 0.2,
          * its 2D covariance determinant ``self.det`` is non-zero, and
          * its screen-space bounding rectangle, measured in 16x16 tiles and
            clamped to the tile grid, has non-zero area.

        Reads ``self.cov`` / ``self.det`` (written by ``get_conic``) and
        ``self.p_proj`` (written by ``get_mean_2d``), so those must have been
        computed first.

        Args:
            viewpoint_camera: object exposing ``world_view_transform``,
                ``image_width`` and ``image_height``.

        Returns:
            Bool tensor of shape ``(N,)``, also stored in ``self.points_mask``.
        """
        BLOCK_X = 16
        BLOCK_Y = 16

        # View-space position (row-vector convention: p @ R + t).
        viewmatrix = viewpoint_camera.world_view_transform
        p_view = (self.get_xyz[:, None, :] @ viewmatrix[None, :3, :3] + viewmatrix[None, [3], :3])[:, 0]

        # Near-plane test (p_view.z > 0.2) and non-degenerate covariance.
        keep = torch.logical_and(p_view[:, [2]] > 0.2, self.det != 0)

        # Pixel radius from the largest eigenvalue of the 2D covariance:
        #   lambda = mid +/- sqrt(max(0.1, mid^2 - det)), radius = ceil(3 * sqrt(max lambda))
        mid = 0.5 * (self.cov[:, [0]] + self.cov[:, [2]])
        sqrt_disc = torch.clamp(mid ** 2 - self.det, min=0.1) ** 0.5
        radius = torch.ceil(3 * torch.maximum(mid + sqrt_disc, mid - sqrt_disc) ** 0.5)

        # NDC -> pixel: ((v + 1) * S - 1) * 0.5
        pixel_x = ((self.p_proj[:, [0]] + 1) * viewpoint_camera.image_width - 1.0) * 0.5
        pixel_y = ((self.p_proj[:, [1]] + 1) * viewpoint_camera.image_height - 1.0) * 0.5

        # Tile-grid size, rounded up.
        grid_x = (viewpoint_camera.image_width + BLOCK_X - 1) // BLOCK_X
        grid_y = (viewpoint_camera.image_height + BLOCK_Y - 1) // BLOCK_Y

        # Tile rectangle covered by the splat; .int() truncates toward zero
        # like the C cast in the reference, then the result is clamped to the grid.
        rect_min_x = torch.clamp(((pixel_x - radius) / BLOCK_X).int(), min=0, max=grid_x)
        rect_min_y = torch.clamp(((pixel_y - radius) / BLOCK_Y).int(), min=0, max=grid_y)
        rect_max_x = torch.clamp(((pixel_x + radius + BLOCK_X - 1) / BLOCK_X).int(), min=0, max=grid_x)
        rect_max_y = torch.clamp(((pixel_y + radius + BLOCK_Y - 1) / BLOCK_Y).int(), min=0, max=grid_y)
        touches_tile = (rect_max_x - rect_min_x) * (rect_max_y - rect_min_y) != 0

        # squeeze(-1), not squeeze(): with a single point a bare squeeze would
        # collapse the (1, 1) mask to a 0-d tensor instead of shape (1,).
        self.points_mask = torch.logical_and(keep, touches_tile).squeeze(-1)

        return self.points_mask

    def get_covariance(self, scaling_modifier = 1, return_full_covariance = False):
        """Build covariances through ``self.covariance_activation``.

        The raw (un-normalised) ``self._rotation`` is used, restricted to
        ``points_mask_hair_indices`` when that selection is set, together with
        ``self.get_scaling``.
        """
        rotation = self._rotation
        selected = self.points_mask_hair_indices
        if selected is not None:
            rotation = rotation.index_select(0, selected)
        return self.covariance_activation(self.get_scaling, scaling_modifier, rotation, return_full_covariance)
    def get_covariance_2d(self, viewpoint_camera, scaling_modifier = 1):
        """Return the packed 2D screen-space covariance of every Gaussian.

        The projection is delegated to ``get_covariance_2d_torch`` (which
        already adds 0.3 to the two screen-space diagonal entries), so the
        arithmetic lives in one place instead of being duplicated here.

        Args:
            viewpoint_camera: camera object accepted by ``get_covariance_2d_torch``.
            scaling_modifier: multiplier forwarded to the covariance builder.

        Returns:
            Tensor of shape ``(N, 3)`` holding ``[cov_xx, cov_xy, cov_yy]``.
        """
        cov = self.get_covariance_2d_torch(viewpoint_camera, scaling_modifier)
        return torch.stack([cov[:, 0, 0], cov[:, 0, 1], cov[:, 1, 1]], dim=-1)
    def get_covariance_2d_torch(self, viewpoint_camera, scaling_modifier = 1):
        """Project the 3D covariances to screen space and return ``(N, 3, 3)`` matrices.

        The perspective Jacobian is built in transposed form and combined as
        ``T^T @ Vrk^T @ T`` with ``T = W @ J``, where ``W`` is the upper-left
        3x3 block of ``world_view_transform``. 0.3 is added to the ``[0, 0]``
        and ``[1, 1]`` entries of the result.

        Args:
            viewpoint_camera: object exposing ``image_height``, ``image_width``,
                ``FoVx``, ``FoVy`` and ``world_view_transform``.
            scaling_modifier: multiplier forwarded to ``get_covariance``.
        """
        img_h = int(viewpoint_camera.image_height)
        img_w = int(viewpoint_camera.image_width)

        tan_half_x = math.tan(viewpoint_camera.FoVx * 0.5)
        tan_half_y = math.tan(viewpoint_camera.FoVy * 0.5)
        focal_x = img_w / (2.0 * tan_half_x)
        focal_y = img_h / (2.0 * tan_half_y)

        view = viewpoint_camera.world_view_transform
        centres = self.get_xyz
        cam_pts = (centres[:, None, :] @ view[None, :3, :3] + view[None, [3], :3])[:, 0]
        cam_x, cam_y, tz = cam_pts[:, 0], cam_pts[:, 1], cam_pts[:, 2]

        # Clamp x/z and y/z to 1.3x the half-FoV tangent before forming the
        # Jacobian, then scale back by z.
        bound_x = 1.3 * tan_half_x
        bound_y = 1.3 * tan_half_y
        tx = torch.clamp(cam_x / tz, min=-bound_x, max=bound_x) * tz
        ty = torch.clamp(cam_y / tz, min=-bound_y, max=bound_y) * tz

        zeros = torch.zeros_like(tz)
        # Each entry below becomes one COLUMN of J (stacking along the last dim).
        jac_columns = [
            torch.stack([focal_x / tz, zeros, -(focal_x * tx) / (tz * tz)], dim=-1),
            torch.stack([zeros, focal_y / tz, -(focal_y * ty) / (tz * tz)], dim=-1),
            torch.stack([zeros, zeros, zeros], dim=-1),
        ]
        J = torch.stack(jac_columns, dim=-1)

        T = view[None, :3, :3] @ J
        Vrk = self.get_covariance(scaling_modifier, return_full_covariance=True)
        cov = T.transpose(1, 2) @ Vrk.transpose(1, 2) @ T

        # Constant added to both screen-space diagonal entries.
        cov[:, 0, 0] += 0.3
        cov[:, 1, 1] += 0.3

        return cov

    def get_conic(self, viewpoint_camera, scaling_modifier = 1):
        """Return the inverse 2D covariance (conic) of every Gaussian, packed as ``(N, 3)``.

        Side effects: stores the packed 2D covariance in ``self.cov`` and its
        determinant, shape ``(N, 1)``, in ``self.det``.

        The inverse is ``[c, -b, a] / det`` for a covariance ``[a, b, c]``.
        Unlike the CUDA reference, which skips points with ``det == 0``, a
        small epsilon is added to the determinant here.
        """
        packed_cov = self.get_covariance_2d(viewpoint_camera, scaling_modifier)
        self.cov = packed_cov

        self.det = packed_cov[:, [0]] * packed_cov[:, [2]] - packed_cov[:, [1]]**2
        inv_det = 1. / (self.det + 0.0000001)

        adjugate = torch.stack([packed_cov[:, 2], -packed_cov[:, 1], packed_cov[:, 0]], dim=-1)
        return adjugate * inv_det

    def get_mean_2d(self, viewpoint_camera):
        """Project the Gaussian centres with ``full_proj_transform``.

        Follows the CUDA reference ``transformPoint4x4`` plus the perspective
        divide ``p_hom.xyz / (p_hom.w + 1e-7)``. The result, shape ``(N, 3)``,
        is cached in ``self.p_proj`` and returned.
        """
        proj = viewpoint_camera.full_proj_transform
        centres = self.get_xyz
        # Row-vector convention: [x y z] @ M[:3] + M[3]
        homogeneous = (centres[:, None, :] @ proj[None, :3, :] + proj[None, [3]])[:, 0]
        inv_w = 1.0 / (homogeneous[:, [3]] + 0.0000001)
        self.p_proj = homogeneous[:, :3] * inv_w

        return self.p_proj

    def get_depths(self, viewpoint_camera):
        """Return the view-space z coordinate of every Gaussian centre, shape ``(N, 1)``."""
        view = viewpoint_camera.world_view_transform
        centres = self.get_xyz
        # Row-vector convention: p @ R + t, then keep only the last coordinate.
        cam_pts = (centres[:, None, :] @ view[None, :3, :3] + view[None, [3], :3])[:, 0]
        return cam_pts[:, -1:]

    def get_direction_2d(self, viewpoint_camera):
        """Project the normalised segment directions into screen space.

        The same ``T = W @ J`` matrix used for the covariance projection is
        built, and each unit direction (as a row vector) is multiplied by it.
        When ``points_mask_hair_indices`` is set, only the selected segments
        are projected.

        Args:
            viewpoint_camera: object exposing ``image_height``, ``image_width``,
                ``FoVx``, ``FoVy`` and ``world_view_transform``.

        Returns:
            Tensor of shape ``(N, 3)``; the last component is always zero
            because the third column of ``J`` is zero.
        """
        mean = self.get_xyz

        height = int(viewpoint_camera.image_height)
        width = int(viewpoint_camera.image_width)

        tan_fovx = math.tan(viewpoint_camera.FoVx * 0.5)
        tan_fovy = math.tan(viewpoint_camera.FoVy * 0.5)

        focal_y = height / (2.0 * tan_fovy)
        focal_x = width / (2.0 * tan_fovx)

        viewmatrix = viewpoint_camera.world_view_transform

        # View-space position (row-vector convention: p @ R + t).
        t = (mean[:, None, :] @ viewmatrix[None, :3, :3] + viewmatrix[None, [3], :3])[:, 0]
        tx, ty, tz = t[:, 0], t[:, 1], t[:, 2]

        # Clamp x/z and y/z to 1.3x the half-FoV tangent before forming J.
        limx = 1.3 * tan_fovx
        limy = 1.3 * tan_fovy
        tx = torch.clamp(tx / tz, min=-limx, max=limx) * tz
        ty = torch.clamp(ty / tz, min=-limy, max=limy) * tz

        zeros = torch.zeros_like(tz)

        # Each stacked vector becomes one COLUMN of J.
        J = torch.stack(
            [
                torch.stack([focal_x / tz, zeros, -(focal_x * tx) / (tz * tz)], dim=-1),
                torch.stack([zeros, focal_y / tz, -(focal_y * ty) / (tz * tz)], dim=-1),
                torch.stack([zeros, zeros, zeros], dim=-1),
            ],
            dim=-1,
        )

        W = viewmatrix[None, :3, :3]
        T = W @ J

        directions = self._dir
        # ``is not None`` instead of ``!= None``: the attribute may be a tensor.
        if self.points_mask_hair_indices is not None:
            directions = directions.index_select(0, self.points_mask_hair_indices)
        dir3D = F.normalize(directions, dim=-1)
        dir2D = (dir3D[:, None, :] @ T)[:, 0]

        return dir2D
    def ROI_selelct(self, img_width, img_height, original_mask):
        """Pick a random region of interest inside the mask's bounding box.

        The bounding box of the positive entries of ``original_mask[0]`` is
        split into a 4x4 grid. Up to 16 cells are drawn at random (with
        replacement); the first one containing more than 100 positive entries
        is returned.

        Args:
            img_width: image width, used only for the fallback region.
            img_height: image height, used only for the fallback region.
            original_mask: indexable whose first element is the mask
                (assumed 2D, dim 0 = rows and dim 1 = columns — TODO confirm).

        Returns:
            CUDA tensor ``[start0, start1, end0, end1]`` where 0/1 are the
            mask's first and second dimensions. Falls back to
            ``[0, 0, img_height, img_width]`` when no drawn cell qualifies or
            the mask has no positive entry.
        """
        BLOCK_X = 4
        BLOCK_Y = 4

        hair_mask = original_mask[0]
        nonzero_coords = torch.nonzero(hair_mask > 0, as_tuple=False)

        # An empty mask has no bounding box; min/max over zero rows would raise,
        # so go straight to the full-image fallback.
        if nonzero_coords.shape[0] == 0:
            return torch.tensor([0, 0, img_height, img_width]).cuda()

        top_left = nonzero_coords.min(dim=0).values.tolist()
        bottom_right = nonzero_coords.max(dim=0).values.tolist()
        left_top_x, top_left_y = top_left[0], top_left[1]
        right_bottom_x, bottom_right_y = bottom_right[0], bottom_right[1]

        span_x = right_bottom_x - left_top_x
        span_y = bottom_right_y - top_left_y

        for _ in range(BLOCK_X * BLOCK_Y):
            index = random.randint(0, BLOCK_X * BLOCK_Y - 1)

            # Cell (index % BLOCK_X, index // BLOCK_X) of the grid.
            x_start = left_top_x + (index % BLOCK_X) * span_x // BLOCK_X
            y_start = top_left_y + (index // BLOCK_X) * span_y // BLOCK_Y
            x_end = x_start + span_x // BLOCK_X
            y_end = y_start + span_y // BLOCK_Y

            area = torch.sum(hair_mask[x_start:x_end, y_start:y_end] > 0)
            if area > 100:
                return torch.tensor([x_start, y_start, x_end, y_end]).cuda()

        return torch.tensor([0, 0, img_height, img_width]).cuda()
    def get_mean_2d_monocular(self, viewpoint_camera):
        """Project every point in ``self._xyzs`` with ``full_proj_transform``.

        Same arithmetic as the CUDA reference ``transformPoint4x4`` followed by
        the perspective divide ``p_hom.xyz / (p_hom.w + 1e-7)``, applied to
        ``self._xyzs`` flattened to ``(-1, 3)``. The result is cached in
        ``self.p_proj`` (overwriting any previous value) and returned.
        """
        proj = viewpoint_camera.full_proj_transform
        flat_points = self._xyzs.reshape(-1, 3)
        # Row-vector convention: [x y z] @ M[:3] + M[3]
        homogeneous = (flat_points[:, None, :] @ proj[None, :3, :] + proj[None, [3]])[:, 0]
        inv_w = 1.0 / (homogeneous[:, [3]] + 0.0000001)
        self.p_proj = homogeneous[:, :3] * inv_w

        return self.p_proj
    
    def generate_strands(self, iter, num_strands = -1, viewpoint_cam=None):
        """Generate strand polylines for one time step and derive per-segment Gaussian attributes.

        Args:
            iter: forwarded to ``self.strands_generator`` on the training path.
            num_strands: ``-1`` calls the generator's forward pass (which also
                returns a diffusion dict); any other value calls
                ``forward_inference(num_strands, time_step)``.
            viewpoint_cam: must expose ``time_step``. The default ``None`` is
                not usable: the first statement reads ``viewpoint_cam.time_step``.

        Returns:
            The diffusion dict from the generator (empty dict on the inference path).

        Side effects: sets ``_pts``, ``_xyz`` (segment midpoints), ``_dir``
        (segment vectors), ``_xyz_gdn`` / ``_dir_gdn``, ``_features_dc``,
        ``_features_rest``, ``_orient_conf``, ``_features_dc_gdn``,
        ``_features_rest_gdn``, ``num_strands`` and ``strand_length``; for
        ``time_step != 0`` also ``loss_pts`` and ``pts_wo_rot``.
        """
        time_step = viewpoint_cam.time_step
        if num_strands != -1:
            _p, uvs, local2world, p_local, _, z = self.strands_generator.forward_inference(num_strands, time_step)
            diffusion_dict = {}
        else:
            _p, uvs, local2world, p_local, _, z, diffusion_dict = self.strands_generator(iter, time_step)

        if time_step == 0:
            # Time step 0 uses the generated strands as-is.
            p = _p
        else:
            # Later time steps: deform the guiding strands (the first
            # num_guiding_strands entries), then rebuild the remaining strands
            # by interpolating the deformed guiding strands in local frames.
            L = _p.shape[1]
            p_world_gdn = _p[:self.num_guiding_strands].view(-1, 3).cuda()
            opacity_wolrd_gdn = torch.ones_like(p_world_gdn[:, :1])
            # Time fed to the deformation network: time_step / num_time_steps.
            time = time_step / self.strands_generator.num_time_steps
            time = torch.tensor(time).to(p_world_gdn.device).repeat(p_world_gdn.shape[0],1)
            new_p_world_gdn, _, _, _, _, _ = \
                self._deformation(p_world_gdn, opacity = opacity_wolrd_gdn,
                                  times_sel = time, deformaton_pts_scale = self.deformaton_scale, deformaton_color_scale=self.deformaton_scale)  
            # Mean displacement produced by the deformation, kept as a plain float.
            loss_pts = (new_p_world_gdn - p_world_gdn).norm(dim=-1).mean()
            self.loss_pts = loss_pts.item()
            self.pts_wo_rot = p_world_gdn.norm(dim=-1).mean().item()
            # Number of nearest guiding strands used for interpolation.
            K = 4
            new_p_world_gdn = new_p_world_gdn.view(-1, L, 3)
            # The first point of each strand is used as its origin.
            origins_int = _p[self.num_guiding_strands:][:, 0]
            origins_gdn = _p[:self.num_guiding_strands][:, 0]
            local2world_gdn = local2world[:self.num_guiding_strands]
            local2world_int = local2world[self.num_guiding_strands:]
            # Deformed guiding strands expressed in their local frames:
            # inverse(local2world) applied to (deformed point - undeformed origin).
            p_local_gdn = (torch.inverse(local2world_gdn[:, None]) @ (new_p_world_gdn - _p[:self.num_guiding_strands][:, :1])[..., None])[..., 0]
            uvs_int = uvs[self.num_guiding_strands:]
            uvs_gdn = uvs[:self.num_guiding_strands]

            # Squared UV distance between every interpolated and every guiding strand.
            dist = ((uvs_int.view(-1, 1, 2) - uvs_gdn.view(1, -1, 2))**2).sum(-1) # num_strands x num_guiding_strands
            # Full sort; only the first K columns are used below.
            knn_dist, knn_idx = torch.sort(dist, dim=1)
            # Inverse-distance weights over the K nearest guiding strands, normalised to sum to 1.
            w = 1 / (knn_dist[:, :K] + 1e-7)
            w = w / w.sum(dim=-1, keepdim=True)
            p_local_int_nearest = p_local_gdn[knn_idx[:, 0]]            
            p_local_int_bilinear = (p_local_gdn[knn_idx[:, :K]] * w[:, :, None, None]).sum(dim=1)
            # Calculate cosine similarity between neighbouring guiding strands to get blending alphas (eq. 4 of HAAR)
            v = p_local_gdn[:, 1:] - p_local_gdn[:, :-1]
            knn_v = v[knn_idx[:, :K]]
            # NOTE(review): the segment count 99 is hard-coded here; presumably it equals L - 1 — confirm.
            # NOTE(review): knn_idx has one row per interpolated strand (dist is uvs_int x uvs_gdn), so the
            # leading dim of csim_full looks like the interpolated-strand count, not num_guiding_strands as
            # the trailing comment says, while the indexing below uses arange(num_guiding_strands) — confirm.
            csim_full = torch.nn.functional.cosine_similarity(knn_v.view(-1, K, 1, 99, 3), knn_v.view(-1, 1, K, 99, 3), dim=-1).mean(-1) # num_guiding_strands x K x K
            # Upper-triangular (j, k) pairs, diagonal included, repeated for every row index i.
            j, k = torch.triu_indices(K, K, device=csim_full.device).split([1, 1], dim=0)
            i = torch.arange(self.num_guiding_strands, device=csim_full.device).repeat_interleave(j.shape[1])
            j = j[0].repeat(self.num_guiding_strands)
            k = k[0].repeat(self.num_guiding_strands)
            csim = csim_full[i, j, k].view(self.num_guiding_strands, -1).mean(-1)
            # Piecewise blending weight as a function of the mean cosine similarity.
            alpha = torch.where(csim <= 0.9, 1 - 1.63 * csim**5, 0.4 - 0.4 * csim)
            alpha_int = (alpha[knn_idx[:, :K]] * w).sum(dim=1)[:, None, None]
            # Blend nearest-neighbour and weighted-average interpolation.
            p_local_int = p_local_int_nearest * alpha_int + p_local_int_bilinear * (1 - alpha_int)
            new_origins = torch.cat([origins_gdn, origins_int])
            new_local2world = torch.cat([local2world_gdn, local2world_int])
            new_p_local = torch.cat([p_local_gdn, p_local_int])
            # Back to world space: local2world @ p_local + origin.
            p = (new_local2world[:, None] @ new_p_local[..., None])[:, :, :3, 0] + new_origins[:, None]
          
        self.num_strands = p.shape[0]
        self.strand_length = p.shape[1]

        self._pts = p
        # One Gaussian per segment: midpoint as centre, end-minus-start as direction.
        self._xyz = (p[:, 1:] + p[:, :-1]).view(-1, 3) * 0.5
        self._dir = (p[:, 1:] - p[:, :-1]).view(-1, 3)
        # Appearance code: everything in z after the first column.
        z_app = z[:, 1:]
        
        if self.num_guiding_strands:
            self._xyz_gdn = self._xyz.view(self.num_strands, self.strand_length - 1 , 3)[:self.num_guiding_strands].view(-1, 3)
            self._dir_gdn = self._dir.view(self.num_strands, self.strand_length - 1 , 3)[:self.num_guiding_strands].view(-1, 3)
        else:
            self._xyz_gdn = self._xyz
            self._dir_gdn = self._dir

        # Assign spherical harmonics features
        # NOTE(review): ``orient_conf`` is only assigned in the 64-wide branch below, so the first two
        # branches would raise NameError at the ``self._orient_conf`` assignment; if no branch matches,
        # ``features_dc`` is unbound as well — confirm these paths are never taken.
        if z_app.shape[1] == 3 * (self.max_sh_degree + 1) ** 2:
            # One SH vector per strand, repeated for every segment.
            features_dc, features_rest = z_app.view(self.num_strands, 1, 3 * (self.max_sh_degree + 1) ** 2).split([3, 3 * ((self.max_sh_degree + 1) ** 2 - 1)], dim=-1)
            features_dc = features_dc.repeat(1, self.strand_length - 1, 1)
            features_rest = features_rest.repeat(1, self.strand_length - 1, 1)        
        elif z_app.shape[1] == (self.strand_length - 1) * 3 * (self.max_sh_degree + 1) ** 2:
            # One SH vector per segment.
            features_dc, features_rest = z_app.view(self.num_strands, self.strand_length - 1, 3 * (self.max_sh_degree + 1) ** 2).split([3, 3 * ((self.max_sh_degree + 1) ** 2 - 1)], dim=-1)
        elif z_app.shape[1] == 64:
            # 64-wide latent: decode into DC, rest and one orientation-confidence channel.
            features_dc, features_rest, orient_conf = self.color_decoder(z_app).split([3, 3 * ((self.max_sh_degree + 1) ** 2 - 1), 1], dim=-1)

        self._features_dc = features_dc.reshape(self.num_strands * (self.strand_length - 1), 1, 3)
        self._features_rest = features_rest.reshape(self.num_strands * (self.strand_length - 1), (self.max_sh_degree + 1) ** 2 - 1, 3)
        self._orient_conf = orient_conf.reshape(self.num_strands * (self.strand_length - 1), 1)

        if self.num_guiding_strands:
            self._features_dc_gdn = features_dc[:self.num_guiding_strands].reshape(-1, 1, 3)
            self._features_rest_gdn = features_rest[:self.num_guiding_strands].reshape(-1, (self.max_sh_degree + 1) ** 2 - 1, 3)
        else:
            self._features_dc_gdn = self._features_dc
            self._features_rest_gdn = self._features_rest

        return diffusion_dict
    
    def generate_strands_local(self, iter, num_strands = -1, viewpoint_cam = None):
        """Generate strands for the camera's time step, deforming guiding strands in their local frame.

        The strand generator is queried first: ``forward_inference`` when
        ``num_strands`` is not ``-1``, the regular forward call otherwise.

        * ``time_step == 0``: the generated points and SH features are used
          unchanged and ``self.LDiff`` is read from ``diffusion_dict['L_diff']``
          when that key exists (``None`` otherwise).
        * ``time_step != 0``: the first ``self.num_guiding_strands`` strands are
          mapped through the inverse of their ``local2world`` matrices, every
          point except the root is passed through ``self._deformation``, and the
          remaining strands are rebuilt by blending the nearest deformed guiding
          strand with an inverse-distance weighted mix of the ``K = 4`` nearest
          guiding strands in UV space. The deformed guiding strands are then
          encoded with ``self.strands_encoder``, resampled on a regular UV grid
          and scored with ``strands_generator.model_ema.loss_wo_logvar``; the
          mean of that loss is stored in ``self.LDiff``.

        Attributes written: ``num_strands``, ``strand_length``, ``_pts``,
        ``hair_smoothness``, ``_xyz``, ``_dir``, ``_features_dc``,
        ``_features_rest``, ``_orient_conf``, ``LDiff`` and, for
        ``time_step != 0``, ``loss_pts`` and ``pts_wo_rot``.

        Args:
            iter: Iteration index forwarded to the strand generator.
            num_strands: Number of strands to request via ``forward_inference``;
                ``-1`` selects the generator's regular forward call.
            viewpoint_cam: Object exposing ``time_step``.
                NOTE(review): the default ``None`` is dereferenced on the first
                line, so callers must always pass a camera.

        Returns:
            The dictionary returned by the generator's forward call, or an empty
            dict when ``forward_inference`` was used.
        """
        time_step = viewpoint_cam.time_step
        if num_strands != -1:
            _p, uvs, local2world, p_local, _features_dc, _features_rest, orient_conf, z = self.strands_generator.forward_inference(num_strands, time_step)
            diffusion_dict = {}
        else:
            # _p, uvs, local2world, p_local, _, z, diffusion_dict = self.strands_generator(iter, time_step)
            _p, uvs, local2world, p_local, _features_dc, _features_rest, orient_conf, diffusion_dict = self.strands_generator(iter, time_step)

        self.num_strands = _p.shape[0]
        self.strand_length = _p.shape[1]
        if time_step == 0:
            # First time step: use the generator output as-is.
            p = _p
            features_dc = _features_dc
            features_rest = _features_rest
            if 'L_diff' in diffusion_dict.keys():
                self.LDiff = diffusion_dict['L_diff']
            else:
                self.LDiff = None
        else:

            # Split generator output into guiding strands (first num_guiding_strands rows)
            # and the remaining, interpolated strands. Column 0 of each strand is its root.
            origins_int = _p[self.num_guiding_strands:][:, 0]
            origins_gdn = _p[:self.num_guiding_strands][:, 0]
            local2world_gdn = local2world[:self.num_guiding_strands]
            local2world_int = local2world[self.num_guiding_strands:]
            p_world_gdn = _p[:self.num_guiding_strands].cuda()
            # Root-relative guiding points mapped through the inverse of local2world.
            p_local_gdn = (torch.inverse(local2world_gdn[:, None]) @ (p_world_gdn - _p[:self.num_guiding_strands][:, :1])[..., None])[..., 0]
            # L = p_local_gdn.shape[1]
            # Number of points per strand once the root point is excluded.
            L_wo_rot = p_local_gdn.shape[1] -1 
            p_local_gdn_wo_rot = p_local_gdn[:, 1:].contiguous().view(-1, 3).cuda()
            # p_local_gdn = p_local_gdn.view(-1,3).cuda()
            # DC and higher-order SH coefficients of the guiding strands, concatenated along dim 1.
            shs_view_hair_wo_rot = torch.cat([_features_dc[:self.num_guiding_strands].reshape(self.num_guiding_strands * (self.strand_length - 1), 1, 3), _features_rest[:self.num_guiding_strands].view(-1, (self.max_sh_degree + 1) ** 2 - 1, 3)], dim=1)
            opacity_local_gdn_wo_rot = torch.ones_like(p_local_gdn_wo_rot[:, :1])
            # Time step divided by the generator's num_time_steps, repeated once per point.
            time = torch.tensor(time_step / self.strands_generator.num_time_steps).to(p_local_gdn_wo_rot.device).repeat(p_local_gdn_wo_rot.shape[0],1)
            # Only the first (points) and fifth (SH features) outputs of the deformation call are used.
            new_p_local_gdn_wo_rot, _, _, _, features_all, _ = \
                self._deformation(p_local_gdn_wo_rot, opacity = opacity_local_gdn_wo_rot,
                                  times_sel = time, shs = shs_view_hair_wo_rot,
                                  deformaton_pts_scale = self.deformaton_scale, deformaton_color_scale=self.deformaton_scale)  
            # Re-attach the untouched root point in front of the deformed points.
            new_p_local_gdn = torch.cat([p_local_gdn[:, :1], new_p_local_gdn_wo_rot.view(-1, L_wo_rot, 3)], dim=1)
            features_dc_gdn = features_all[:, :1,:].reshape(self.num_guiding_strands, self.strand_length - 1, 3)
            features_rest_gdn = features_all[:, 1:,:].reshape(self.num_guiding_strands, self.strand_length - 1, ((self.max_sh_degree + 1) ** 2 - 1) * 3)
            
            # Diagnostics stored as Python floats: mean displacement introduced by the
            # deformation and mean norm of the undeformed local points.
            loss_pts = (new_p_local_gdn_wo_rot - p_local_gdn_wo_rot).norm(dim=-1).mean()
            self.loss_pts = loss_pts.item()
            self.pts_wo_rot = p_local_gdn_wo_rot.norm(dim=-1).mean().item()
            # Number of nearest guiding strands used for interpolation.
            K = 4

            uvs_int = uvs[self.num_guiding_strands:]
            uvs_gdn = uvs[:self.num_guiding_strands]
            # Squared UV distance from every interpolated strand to every guiding strand.
            dist = ((uvs_int.view(-1, 1, 2) - uvs_gdn.view(1, -1, 2))**2).sum(-1) # num_interpolated_strands x num_guiding_strands
            knn_dist, knn_idx = torch.sort(dist, dim=1)
            # Inverse-distance weights over the K nearest guiding strands, normalised to sum to 1.
            w = 1 / (knn_dist[:, :K] + 1e-7)
            w = w / w.sum(dim=-1, keepdim=True)
            p_local_int_nearest = new_p_local_gdn[knn_idx[:, 0]]            
            p_local_int_bilinear = (new_p_local_gdn[knn_idx[:, :K]] * w[:, :, None, None]).sum(dim=1)
            features_dc_int_nearest = features_dc_gdn[:self.num_guiding_strands][knn_idx[:, 0]]
            features_rest_int_nearest = features_rest_gdn[:self.num_guiding_strands][knn_idx[:, 0]]
            features_dc_int_bilinear = (features_dc_gdn[:self.num_guiding_strands][knn_idx[:, :K]] * w[:, :, None, None]).sum(dim=1)
            features_rest_int_bilinear = (features_rest_gdn[:self.num_guiding_strands][knn_idx[:, :K]] * w[:, :, None, None]).sum(dim=1)
            # Calculate cosine similarity between neighbouring guiding strands to get blending alphas (eq. 4 of HAAR)
            v = new_p_local_gdn[:, 1:] - new_p_local_gdn[:, :-1]
            knn_v = v[knn_idx[:, :K]]
            # NOTE(review): the leading dimension of csim_full follows knn_idx (one row per
            # interpolated strand), yet rows are selected below with arange(num_guiding_strands)
            # - confirm this is intended.
            csim_full = torch.nn.functional.cosine_similarity(knn_v.view(-1, K, 1, L_wo_rot, 3), knn_v.view(-1, 1, K, L_wo_rot, 3), dim=-1).mean(-1) # rows x K x K
            # Average the upper-triangular (j <= k) entries of each selected K x K similarity matrix.
            j, k = torch.triu_indices(K, K, device=csim_full.device).split([1, 1], dim=0)
            i = torch.arange(self.num_guiding_strands, device=csim_full.device).repeat_interleave(j.shape[1])
            j = j[0].repeat(self.num_guiding_strands)
            k = k[0].repeat(self.num_guiding_strands)
            csim = csim_full[i, j, k].view(self.num_guiding_strands, -1).mean(-1)
            # Piecewise mapping from mean similarity to the blending weight of the nearest strand.
            alpha = torch.where(csim <= 0.9, 1 - 1.63 * csim**5, 0.4 - 0.4 * csim)
            alpha_int = (alpha[knn_idx[:, :K]] * w).sum(dim=1)[:, None, None]
            # Blend nearest-neighbour and weighted interpolation for points and SH features.
            p_local_int = p_local_int_nearest * alpha_int + p_local_int_bilinear * (1 - alpha_int)
            features_dc_int = features_dc_int_nearest * alpha_int + features_dc_int_bilinear * (1 - alpha_int)
            features_rest_int = features_rest_int_nearest * alpha_int + features_rest_int_bilinear * (1 - alpha_int)
            # orient_conf_int = orient_conf_int_nearest * alpha_int + orient_conf_int_bilinear * (1 - alpha_int)
            new_origins = torch.cat([origins_gdn, origins_int])
            # new_uvs = torch.cat([uvs_gdn, uvs_int])
            new_local2world = torch.cat([local2world_gdn, local2world_int])
            new_p_local = torch.cat([new_p_local_gdn, p_local_int])
            # Apply local2world to the local points and add the strand roots back.
            p = (new_local2world[:, None] @ new_p_local[..., None])[:, :, :3, 0] + new_origins[:, None]
            features_dc = torch.cat([features_dc_gdn, features_dc_int])
            features_rest = torch.cat([features_rest_gdn, features_rest_int])
            
            # Encode the guiding strands into the latent vectors
            v_gdn_local = (new_p_local_gdn[:, 1:] - new_p_local_gdn[:, :-1]) * self.strands_generator.scale_decoder[time_step]
            # Only the first 64 encoder outputs are kept.
            z_gdn = self.strands_encoder(new_p_local_gdn  * self.strands_generator.scale_decoder[time_step])[:, :64]
            # Cell centres of a regular diffusion_input x diffusion_input grid over [-1, 1]^2.
            grid = torch.linspace(start=-1, end=1, steps=self.strands_generator.diffusion_input + 1, device="cuda")
            grid = (grid[1:] + grid[:-1]) / 2
            uvs_sds = torch.stack(torch.meshgrid(grid, grid, indexing='xy'), dim=-1).view(-1, 2)
            new_dist = ((uvs_sds.view(-1, 1, 2) - uvs_gdn.view(1, -1, 2))**2).sum(-1) # num_sds_strands x num_guiding_strands
            new_knn_dist, new_knn_idx = torch.sort(new_dist, dim=1)
            new_w = 1 / (new_knn_dist[:, :K] + 1e-7)
            new_w = new_w / new_w.sum(dim=-1, keepdim=True)
            z_sds_nearest = z_gdn[new_knn_idx[:, 0]]
            z_sds_bilinear = (z_gdn[new_knn_idx[:, :K]] * new_w[:, :, None]).sum(dim=1)
            # Calculate cosine similarity between neighbouring guiding strands to get blending alphas (eq. 4 of HAAR)
            new_knn_v = v_gdn_local[new_knn_idx[:, :K]]
            # NOTE(review): same row-count question as for csim_full above (rows follow new_knn_idx).
            new_csim_full = torch.nn.functional.cosine_similarity(new_knn_v.view(-1, K, 1, L_wo_rot, 3), new_knn_v.view(-1, 1, K, L_wo_rot, 3), dim=-1).mean(-1) # rows x K x K
            new_j, new_k = torch.triu_indices(K, K, device=new_csim_full.device).split([1, 1], dim=0)
            new_i = torch.arange(self.num_guiding_strands, device=new_csim_full.device).repeat_interleave(new_j.shape[1])
            new_j = new_j[0].repeat(self.num_guiding_strands)
            new_k = new_k[0].repeat(self.num_guiding_strands)
            new_csim = new_csim_full[new_i, new_j, new_k].view(self.num_guiding_strands, -1).mean(-1)
            
            new_alpha = torch.where(new_csim <= 0.9, 1 - 1.63 * new_csim**5, 0.4 - 0.4 * new_csim)
            alpha_sds = (new_alpha[new_knn_idx[:, :K]] * new_w).sum(dim=1)[:, None]
            z_sds = z_sds_nearest * alpha_sds + z_sds_bilinear * (1 - alpha_sds)

            # Arrange the resampled latents as a 1 x 64 x H x W texture for the diffusion model.
            diffusion_texture = z_sds.view(1, self.strands_generator.diffusion_input, self.strands_generator.diffusion_input, 64).permute(0, 3, 1, 2)

            noise = torch.randn_like(diffusion_texture)
            sigma = self.strands_generator.sample_density([diffusion_texture.shape[0]], device='cuda')
            # Optional per-time-step mask, resized to the texture resolution.
            mask = None
            if self.strands_generator.diffuse_mask[time_step] is not None:
                mask = torch.nn.functional.interpolate(
                    self.strands_generator.diffuse_mask[time_step][None][None], 
                    size=(self.strands_generator.diffusion_input, self.strands_generator.diffusion_input)
                )
            L_diff, _, _ = self.strands_generator.model_ema.loss_wo_logvar(diffusion_texture, noise, sigma, mask=mask, unet_cond=None)
            # self.Lsds = L_diff.mean()
            # NOTE(review): in this branch LDiff comes from the loss computed here;
            # any 'L_diff' entry in diffusion_dict is not consulted.
            self.LDiff = L_diff.mean()
        
        self._pts = p
        # Smoothness measure: mean norm of second differences divided by mean norm of first differences.
        first_difference = self._pts[:, 1:] - self._pts[:, :-1]
        self.hair_smoothness = (first_difference[:,1:,:]-first_difference[:,:-1,:]).norm(dim=-1).mean()/first_difference.norm(dim=-1).mean()
        # One entry per strand segment: midpoint and (unnormalised) direction.
        self._xyz = (p[:, 1:] + p[:, :-1]).view(-1, 3) * 0.5
        self._dir = (p[:, 1:] - p[:, :-1]).view(-1, 3)
        # self._label = z[:, :1].view(self.num_strands, 1, 1).repeat(1, self.strand_length - 1, 1).view(-1, 1)
        # z_app = z[:, 1:]

        # Flatten per-segment appearance to (num_strands * (strand_length - 1), ...).
        # orient_conf is taken from the generator unchanged, even when strands were deformed.
        self._features_dc = features_dc.reshape(self.num_strands * (self.strand_length - 1), 1, 3)
        self._features_rest = features_rest.reshape(self.num_strands * (self.strand_length - 1), (self.max_sh_degree + 1) ** 2 - 1, 3)
        self._orient_conf = orient_conf.reshape(self.num_strands * (self.strand_length - 1), 1)

        # if self.num_guiding_strands:
        #     self._features_dc_gdn = features_dc[:self.num_guiding_strands].reshape(-1, 1, 3)
        #     self._features_rest_gdn = features_rest[:self.num_guiding_strands].reshape(-1, (self.max_sh_degree + 1) ** 2 - 1, 3)
        # else:
        #     self._features_dc_gdn = self._features_dc
        #     self._features_rest_gdn = self._features_rest

        return diffusion_dict
    
    def generate_strands_coarse(self, iter, num_strands = -1, viewpoint_cam=None):
        if viewpoint_cam == None: 
            p, uvs, local2world, p_local, features_dc, features_rest, orient_conf, diffusion_dict = self.strands_generator(0, 0)
        else:
            if num_strands != -1:
                p, uvs, local2world, p_local, features_dc, features_rest, orient_conf, z = self.strands_generator.forward_inference(num_strands, viewpoint_cam.time_step)
                diffusion_dict = {}
            else:
                p, uvs, local2world, p_local, features_dc, features_rest, orient_conf, diffusion_dict = self.strands_generator(iter, viewpoint_cam.time_step, self.training)
        
 
        self.num_strands = p.shape[0]
        self.strand_length = p.shape[1]
        self.idx_active_mask = None
        
        self._pts = p
        self._xyz = (p[:, 1:] + p[:, :-1]).view(-1, 3) * 0.5
        self._dir = (p[:, 1:] - p[:, :-1]).view(-1, 3)
        
        if viewpoint_cam != None and num_strands == -1:
            self._xyzs = (p[:, 1:] + p[:, :-1]) * 0.5
        # self._label = z[:, :1].view(self.num_strands, 1, 1).repeat(1, self.strand_length - 1, 1).view(-1, 1)
        # z_app = z[:, 1:]
        
        if self.num_guiding_strands:
            self._xyz_gdn = self._xyz.view(self.num_strands, self.strand_length - 1 , 3)[:self.num_guiding_strands].view(-1, 3)
            self._dir_gdn = self._dir.view(self.num_strands, self.strand_length - 1 , 3)[:self.num_guiding_strands].view(-1, 3)
        else:
            self._xyz_gdn = self._xyz
            self._dir_gdn = self._dir

        self._features_dc = features_dc.reshape(self.num_strands * (self.strand_length - 1), 1, 3)
        self._features_rest = features_rest.reshape(self.num_strands * (self.strand_length - 1), (self.max_sh_degree + 1) ** 2 - 1, 3)
        self._orient_conf = orient_conf.reshape(self.num_strands * (self.strand_length - 1), 1)


        return diffusion_dict
    def generate_strands_coarse_sparse(self, iter, num_strands = -1, viewpoint_cam=None):

        if num_strands != -1:
            with torch.no_grad():
                p, uvs, local2world, p_local, features_dc, features_rest, orient_conf, z = self.strands_generator.forward_inference(num_strands, viewpoint_cam.time_step)
                diffusion_dict = {}
        else:
            p, uvs, local2world, p_local, features_dc, features_rest, orient_conf, diffusion_dict = self.strands_generator.forward_sparse(iter, viewpoint_cam.time_step, self.training)

        self.num_strands = p.shape[0]
        self.strand_length = p.shape[1]
        self.idx_hair_active_mask = self.strands_generator.idx_active_mask[:,None].repeat(1, self.strand_length-1).reshape(-1)
        self.hair_active_num = self.strands_generator.idx_active_mask.sum()
        

        self._pts = p
        self._xyz = (p[:, 1:] + p[:, :-1]).view(-1, 3) * 0.5
        self._dir = (p[:, 1:] - p[:, :-1]).view(-1, 3)
        
        # if viewpoint_cam != None and num_strands != -1:    
        #     self._xyzs = (p[:, 1:] + p[:, :-1]) * 0.5
        # self._label = z[:, :1].view(self.num_strands, 1, 1).repeat(1, self.strand_length - 1, 1).view(-1, 1)
        # z_app = z[:, 1:]
        
        if self.num_guiding_strands:
            self._xyz_gdn = self._xyz.view(self.num_strands, self.strand_length - 1 , 3)[:self.num_guiding_strands].view(-1, 3)
            self._dir_gdn = self._dir.view(self.num_strands, self.strand_length - 1 , 3)[:self.num_guiding_strands].view(-1, 3)
        else:
            self._xyz_gdn = self._xyz
            self._dir_gdn = self._dir

        self._features_dc = features_dc.reshape(self.num_strands * (self.strand_length - 1), 1, 3)
        self._features_rest = features_rest.reshape(self.num_strands * (self.strand_length - 1), (self.max_sh_degree + 1) ** 2 - 1, 3)
        self._orient_conf = orient_conf.reshape(self.num_strands * (self.strand_length - 1), 1)

        return diffusion_dict
    
    def generate_strands_inference(self, iter, num_strands = -1, time_step=0):
        """Generate strands without gradients and optionally deform them for ``time_step``.

        Points come from ``strands_generator.forward_inference``. For
        ``time_step != 0`` all strand points are flattened, passed through
        ``self._deformation`` together with unit opacities and a normalised
        time value, and reshaped back to per-strand form. Appearance is decoded
        from the latent ``z`` (everything after its first column).

        Attributes written: ``num_strands``, ``strand_length``, ``_pts``,
        ``_xyz``, ``_dir``, ``_xyz_gdn``, ``_dir_gdn``, ``_features_dc``,
        ``_features_rest``, ``_orient_conf``, ``_features_dc_gdn`` and
        ``_features_rest_gdn``.

        NOTE(review): ``forward_inference`` is unpacked into six values here,
        while the other call sites in this file unpack eight - confirm this
        method still matches the generator's current return signature.
        NOTE(review): ``orient_conf`` is only assigned in the 64-channel branch
        below and ``features_dc``/``features_rest`` only when one of the three
        shape checks matches; other inputs reach the reshape lines with those
        names unbound.

        Args:
            iter: Unused in this method.
            num_strands: Strand count forwarded to ``forward_inference``.
            time_step: Time step index; ``0`` skips the deformation.

        Returns:
            An empty dict.
        """
        with torch.no_grad():
            _p, _, _, _, _, z = self.strands_generator.forward_inference(num_strands, time_step)
        diffusion_dict = {}
        if time_step == 0:
            p = _p
        else:
        # with torch.no_grad():
            # Flatten every strand point and deform it at the normalised time.
            L = _p.shape[1]
            p_world_gdn = _p.view(-1, 3).cuda()
            opacity_wolrd_gdn = torch.ones_like(p_world_gdn[:, :1])
            time = time_step / self.strands_generator.num_time_steps
            # print("time:  ", time)
            time = torch.tensor(time).to(p_world_gdn.device).repeat(p_world_gdn.shape[0],1)
            # Only the first output (deformed points) of the deformation call is used.
            new_p_world_gdn, _, _, _, _, _ = \
                self._deformation(p_world_gdn, opacity = opacity_wolrd_gdn,
                                  times_sel = time, deformaton_pts_scale = self.deformaton_scale)  
            p = new_p_world_gdn.view(-1, L, 3)
        

        self.num_strands = p.shape[0]
        self.strand_length = p.shape[1]

        # One entry per strand segment: midpoint and (unnormalised) direction.
        self._pts = p
        self._xyz = (p[:, 1:] + p[:, :-1]).view(-1, 3) * 0.5
        self._dir = (p[:, 1:] - p[:, :-1]).view(-1, 3)
        # self._label = z[:, :1].view(self.num_strands, 1, 1).repeat(1, self.strand_length - 1, 1).view(-1, 1)
        # Appearance part of the latent: everything after the first column.
        z_app = z[:, 1:]
        
        if self.num_guiding_strands:
            # Guiding strands are the first num_guiding_strands rows.
            self._xyz_gdn = self._xyz.view(self.num_strands, self.strand_length - 1 , 3)[:self.num_guiding_strands].view(-1, 3)
            self._dir_gdn = self._dir.view(self.num_strands, self.strand_length - 1 , 3)[:self.num_guiding_strands].view(-1, 3)
        else:
            self._xyz_gdn = self._xyz
            self._dir_gdn = self._dir

        # Assign spherical harmonics features
        if z_app.shape[1] == 3 * (self.max_sh_degree + 1) ** 2:
            # One SH vector per strand, repeated for every segment.
            features_dc, features_rest = z_app.view(self.num_strands, 1, 3 * (self.max_sh_degree + 1) ** 2).split([3, 3 * ((self.max_sh_degree + 1) ** 2 - 1)], dim=-1)
            features_dc = features_dc.repeat(1, self.strand_length - 1, 1)
            features_rest = features_rest.repeat(1, self.strand_length - 1, 1)        
        elif z_app.shape[1] == (self.strand_length - 1) * 3 * (self.max_sh_degree + 1) ** 2:
            # One SH vector per segment.
            features_dc, features_rest = z_app.view(self.num_strands, self.strand_length - 1, 3 * (self.max_sh_degree + 1) ** 2).split([3, 3 * ((self.max_sh_degree + 1) ** 2 - 1)], dim=-1)
        elif z_app.shape[1] == 64:
            # 64-channel latent: decode SH features and orientation confidence with the colour decoder.
            features_dc, features_rest, orient_conf = self.color_decoder(z_app).split([3, 3 * ((self.max_sh_degree + 1) ** 2 - 1), 1], dim=-1)

        # Flatten per-segment appearance to (num_strands * (strand_length - 1), ...).
        self._features_dc = features_dc.reshape(self.num_strands * (self.strand_length - 1), 1, 3)
        self._features_rest = features_rest.reshape(self.num_strands * (self.strand_length - 1), (self.max_sh_degree + 1) ** 2 - 1, 3)
        self._orient_conf = orient_conf.reshape(self.num_strands * (self.strand_length - 1), 1)

        if self.num_guiding_strands:
            self._features_dc_gdn = features_dc[:self.num_guiding_strands].reshape(-1, 1, 3)
            self._features_rest_gdn = features_rest[:self.num_guiding_strands].reshape(-1, (self.max_sh_degree + 1) ** 2 - 1, 3)
        else:
            self._features_dc_gdn = self._features_dc
            self._features_rest_gdn = self._features_rest

        return diffusion_dict

    def initialize_gaussians_hair(self, iter, num_strands=-1, viewpoint_cam=None, state="coarse"):
        """Regenerate strands and derive per-segment rotations.

        ``state == "coarse"`` uses ``generate_strands_coarse``; any other value
        uses ``generate_strands_coarse_sparse``. ``self.LDiff`` is set to the
        ``'L_diff'`` entry of the returned dict, or ``None`` when it is absent.
        ``self._rotation`` is then computed with ``parallel_transport`` from the
        x-axis to each segment direction and reshaped to four columns.

        Args:
            iter: Iteration index forwarded to the strand generation method.
            num_strands: Strand count forwarded to the strand generation method.
            viewpoint_cam: Camera forwarded to the strand generation method.
            state: ``"coarse"`` or anything else (sparse path).
        """
        if state == "coarse":
            generate = self.generate_strands_coarse
        else:
            generate = self.generate_strands_coarse_sparse
        diffusion_dict = generate(iter, num_strands, viewpoint_cam)
        self.LDiff = diffusion_dict['L_diff'] if 'L_diff' in diffusion_dict.keys() else None

        # Source axis (1, 0, 0) for every segment; the transport rotates it onto self._dir.
        x_component = torch.ones_like(self._xyz[:, :1])
        yz_components = torch.zeros_like(self._xyz[:, :2])
        x_axis = torch.cat([x_component, yz_components], dim=-1)
        # rotation parameters that align x-axis with the segment direction
        self._rotation = parallel_transport(a=x_axis, b=self._dir).view(-1, 4)

    def oneupSHdegree(self):
        if self.active_sh_degree < self.max_sh_degree:
            self.active_sh_degree += 1

    def create_from_pcd(self, data_path, scale = 1e-3,spatial_lr_scale = 1):
        """Initialise strands, scene bounds and CUDA-resident helper modules.

        Runs the generator's ``int_para_precomp(0)`` and
        ``generate_strands_coarse(0)`` without gradients, sets the deformation
        network's bounding box to ``[-1, 1]`` on every axis, fixes
        ``self.scale`` to ``1e-4`` and moves the deformation, renderer and
        weight-prediction modules to CUDA.

        Args:
            data_path: Unused in this method.
            scale: Unused in this method (``self.scale`` is hard-coded).
            spatial_lr_scale: Stored as ``self.spatial_lr_scale``.
        """
        with torch.no_grad():
            self.strands_generator.int_para_precomp(0)
            self.generate_strands_coarse(0)

        # Axis-aligned bounds handed to the deformation network.
        upper_corner = np.ones(3) * 1.0
        lower_corner = np.ones(3) * -1.0
        self._deformation.deformation_net.set_aabb(upper_corner, lower_corner)

        self.scale = 1e-4 * torch.ones(1, device="cuda")
        self.spatial_lr_scale = spatial_lr_scale

        self._deformation = self._deformation.to("cuda")
        self.gaussRender = self.gaussRender.to("cuda")

        # Boolean table with one True entry per Gaussian.
        num_gaussians = self.get_xyz.shape[0]
        self._deformation_table = torch.gt(torch.ones((num_gaussians), device="cuda"), 0)

        self.gaussianWeight = self.gaussianWeight.to("cuda")

    def training_setup(self, training_args, training_args_hair):
        """Create the optimizers and learning-rate schedules used for hair training.

        Strand-generator parameters are split by name: those containing
        ``"grid"`` go to ``optimizer_strands_grid``, the rest to
        ``optimizer_strands``. Separate AdamW optimizers are built for the
        Gaussian weight predictor and the colour decoder, and an Adam optimizer
        with two named groups for the deformation network. Cosine schedules are
        attached to the four AdamW optimizers; exponential schedule functions
        are stored for the deformation groups. Finally every deformation
        parameter gets ``requires_grad = False``.

        Args:
            training_args: Object providing ``iterations``,
                ``position_lr_max_steps`` and the ``grid_coarse_*`` /
                ``deformation_coarse_*`` learning-rate settings.
            training_args_hair: Mapping providing ``['general']['lr']``.
        """
        named_params = list(self.strands_generator.named_parameters())
        grid_params = [weight for key, weight in named_params if "grid" in key]
        non_grid_params = [weight for key, weight in named_params if "grid" not in key]
        self.strand_generator_grid_params = grid_params

        base_lr = training_args_hair['general']['lr']
        self.optimizer_strands = torch.optim.AdamW(non_grid_params, base_lr)
        self.optimizer_strands_grid = torch.optim.AdamW(
            grid_params, training_args.grid_coarse_lr_init * self.spatial_lr_scale
        )
        self.optimizer_gaussianWeight = torch.optim.AdamW(self.gaussianWeight.parameters(), base_lr)
        self.optimizer_color = torch.optim.AdamW(self.color_decoder.parameters(), base_lr)

        # Deformation network: MLP and grid parameters get their own learning rates.
        deformation_groups = [
            {
                'params': list(self._deformation.get_mlp_parameters()),
                'lr': training_args.deformation_coarse_lr_init * self.spatial_lr_scale,
                "name": "deformation",
            },
            {
                'params': list(self._deformation.get_grid_parameters()),
                'lr': training_args.grid_coarse_lr_init * self.spatial_lr_scale,
                "name": "grid",
            },
        ]
        floor_lr = training_args_hair['general']['lr'] * 1e-1
        floor_lr_grid = training_args.grid_coarse_lr_final * self.spatial_lr_scale
        self.optimizer_deformation = torch.optim.Adam(deformation_groups, lr=0.0, eps=1e-15)

        def cosine(optimizer, eta_min):
            # Cosine annealing over the full training length.
            return torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=training_args.iterations, eta_min=eta_min
            )

        self.scheduler_strands = cosine(self.optimizer_strands, floor_lr)
        self.scheduler_strands_grid = cosine(self.optimizer_strands_grid, floor_lr_grid)
        self.scheduler_gaussianWeight = cosine(self.optimizer_gaussianWeight, floor_lr)
        self.scheduler_color = cosine(self.optimizer_color, floor_lr)

        self.deformation_scheduler_args = get_expon_lr_func(
            lr_init=training_args.deformation_coarse_lr_init * self.spatial_lr_scale,
            lr_final=training_args.deformation_coarse_lr_final * self.spatial_lr_scale,
            lr_delay_mult=training_args.deformation_coarse_lr_delay_mult,
            max_steps=training_args.position_lr_max_steps,
        )
        self.grid_scheduler_args = get_expon_lr_func(
            lr_init=training_args.grid_coarse_lr_init * self.spatial_lr_scale,
            lr_final=training_args.grid_coarse_lr_final * self.spatial_lr_scale,
            lr_delay_mult=training_args.deformation_coarse_lr_delay_mult,
            max_steps=training_args.position_lr_max_steps,
        )

        # Freeze the deformation network (MLP first, then grids).
        for fetch_params in (self._deformation.get_mlp_parameters, self._deformation.get_grid_parameters):
            for weight in fetch_params():
                weight.requires_grad = False
    def reset_optimizers(self,training_args, training_args_hair,decay=1e-2):
        for param_group in self.optimizer_strands.param_groups:
            param_group['lr'] = training_args_hair['general']['lr'] * decay
        lr_min = training_args_hair['general']['lr'] * 1e-1 * decay
        self.scheduler_strands = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer_strands,
            T_max=training_args.iterations,
            eta_min=lr_min
        )
        for param_group in self.optimizer_strands_grid.param_groups:
            param_group['lr'] = training_args.grid_coarse_lr_init * self.spatial_lr_scale * decay
        lr_min_grid = training_args.grid_coarse_lr_final * self.spatial_lr_scale * decay
        self.scheduler_strands_grid = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer_strands_grid,
            T_max=training_args.iterations,
            eta_min=lr_min_grid
        )
        
    def update_learning_rate(self, iteration, coarse_iteration):
        # self.scheduler.step()
        self.scheduler_strands.step()
        self.scheduler_strands_grid.step()
        self.scheduler_gaussianWeight.step()
        # self.scheduler_color.step()
        # for param_group in self.optimizer_deformation.param_groups:
        #     if param_group["name"] == "deformation":
        #         lr = self.deformation_scheduler_args(max(0,iteration-coarse_iteration))
        #         param_group['lr'] = lr
        #         # return lr
        #     elif "grid" in param_group["name"]:
        #         lr = self.grid_scheduler_args(max(0,iteration-coarse_iteration))
        #         param_group['lr'] = lr
                
    def construct_list_of_attributes(self):
        l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
        # All channels except the 3 DC
        for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
            l.append('f_dc_{}'.format(i))
        for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
            l.append('f_rest_{}'.format(i))
        l.append('opacity')
        for i in range(self._scaling.shape[1]):
            l.append('scale_{}'.format(i))
        for i in range(self._rotation.shape[1]):
            l.append('rot_{}'.format(i))
        return l
    def save_deformation(self, path):
        torch.save(self._deformation.state_dict(),os.path.join(path, "deformation.pth"))
        torch.save(self._deformation_table,os.path.join(path, "deformation_table.pth"))
    
    def save_ply(self, path):
        """Export the current Gaussians as a PLY ``vertex`` element.

        Each vertex stores position, a zero normal, flattened SH coefficients,
        an opacity of one, scale and rotation, in the order given by
        ``construct_list_of_attributes``. All values are written as ``f4``.
        The parent directory of ``path`` is created first.

        Args:
            path: Destination file path.
        """
        mkdir_p(os.path.dirname(path))

        def as_numpy(tensor):
            # Detach from autograd and copy to host memory.
            return tensor.detach().cpu().numpy()

        def flatten_sh(tensor):
            # Swap the last two axes, then flatten everything after the first axis.
            return tensor.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()

        positions = as_numpy(self._xyz)
        columns = (
            positions,
            np.zeros_like(positions),
            flatten_sh(self._features_dc),
            flatten_sh(self._features_rest),
            as_numpy(torch.ones_like(self._xyz[:, :1])),
            as_numpy(self._scaling),
            as_numpy(self._rotation),
        )

        vertex_dtype = [(field, 'f4') for field in self.construct_list_of_attributes()]
        vertices = np.empty(positions.shape[0], dtype=vertex_dtype)
        table = np.concatenate(columns, axis=1)
        vertices[:] = [tuple(row) for row in table]

        vertex_element = PlyElement.describe(vertices, 'vertex')
        PlyData([vertex_element]).write(path)
        
    def save_origins(self, uvs_index_select, uvs_index_wo_select, path):
        """Write two point sets to one binary PLY file, coloured red and blue.

        Points in ``uvs_index_select`` are coloured ``(255, 0, 0)`` and points
        in ``uvs_index_wo_select`` ``(0, 0, 255)``; the selected points come
        first in the file.

        Args:
            uvs_index_select: Tensor of points; each row must supply the three
                coordinates stored as ``x``, ``y``, ``z``.
            uvs_index_wo_select: Tensor of points with the same layout.
            path: Destination file path.
        """
        vertex_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]

        def colored_records(points, rgb):
            # Pair every point with the same colour and pack the rows into a structured array.
            coords = points.cpu().detach().numpy()
            colors = np.ones_like(coords) * np.array(rgb)
            records = np.empty(coords.shape[0], dtype=vertex_dtype)
            rows = np.concatenate((coords, colors), axis=1)
            records[:] = list(map(tuple, rows))
            return records

        selected = colored_records(uvs_index_select, [255, 0, 0])
        not_selected = colored_records(uvs_index_wo_select, [0, 0, 255])

        vertices = np.concatenate((selected, not_selected), axis=0)
        vertex_element = PlyElement.describe(vertices, 'vertex')
        PlyData([vertex_element], text=False).write(path)
        
    def _plane_regulation(self):
        multi_res_grids = self._deformation.deformation_net.grid.grids
        total = 0
        # model.grids is 6 x [1, rank * F_dim, reso, reso]
        for grids in multi_res_grids:
            if len(grids) == 3:
                time_grids = []
            else:
                time_grids =  [0,1,3]
            for grid_id in time_grids:
                total += compute_plane_smoothness(grids[grid_id])
        return total
    def _time_regulation(self):
        multi_res_grids = self._deformation.deformation_net.grid.grids
        total = 0
        # model.grids is 6 x [1, rank * F_dim, reso, reso]
        for grids in multi_res_grids:
            if len(grids) == 3:
                time_grids = []
            else:
                time_grids =[2, 4, 5]
            for grid_id in time_grids:
                total += compute_plane_smoothness(grids[grid_id])
        return total
    def _l1_regulation(self):
                # model.grids is 6 x [1, rank * F_dim, reso, reso]
        multi_res_grids = self._deformation.deformation_net.grid.grids

        total = 0.0
        for grids in multi_res_grids:
            if len(grids) == 3:
                continue
            else:
                # These are the spatiotemporal grids
                spatiotemporal_grids = [2, 4, 5]
            for grid_id in spatiotemporal_grids:
                total += torch.abs(1 - grids[grid_id]).mean()
        return total
    def compute_regulation(self, time_smoothness_weight, l1_time_planes_weight, plane_tv_weight):
        return plane_tv_weight * self._plane_regulation() + time_smoothness_weight * self._time_regulation() + l1_time_planes_weight * self._l1_regulation()
    def densify_and_prune(self, gaussian_hair_grad, num_threshold = 10000, iteration = 0,  densification_threshold = 5 * 1e-4):
            """Activate additional UV-grid nodes in cells crowded with high-gradient strands.

            Outline of what the code below does:

            1. Compute one L2 gradient norm per strand and keep the strands whose
               norm is at least ``densification_threshold``.
            2. Build the resolution pyramid ``res_init * 2**i`` and, for each kept
               strand, locate the cell containing its UV coordinate at every level.
            3. Assign each strand to a single (level, cell) pair chosen from the
               levels whose four cell corners are all set in ``dynamic_mgrid_mask``,
               and count strands per (level, cell).
            4. If any finest-level cell holds at least 20 strands, double the grid
               resolution (old mask values are copied onto the even indices).
            5. For every (level, cell) with at least 20 strands, set five nodes in
               the mask: the cell centre and the four edge midpoints.
            6. Write the grid, mask, active set and node count back to
               ``self.strands_generator`` / ``self.uvs_mask_count``.

            The method is CUDA-only (``torch.device('cuda')`` and ``.cuda()`` calls).

            Args:
                gaussian_hair_grad: Gradient tensor that is viewed as
                    ``(-1, strand_length - 1, 3)`` to obtain one block per strand.
                num_threshold: Not used by the live code (only by commented-out code).
                iteration: Not used by the live code (only by commented-out code).
                densification_threshold: Minimum per-strand gradient norm for a
                    strand to take part in densification.

            Returns:
                ``None`` in all cases. When the (possibly doubled) resolution exceeds
                256 the method returns early and nothing is written back.
            """
        # num_threshold = 15000
            # print(f"densify_grid {iteration}")
            device = torch.device('cuda')
            import time
        # if self.strands_generator.uvs_index_select.shape[0] > num_threshold and iteration >= coarse_iteration and iteration % densification_interval == 0 and iteration % densification_recover_interval != 0:
        # if self.strands_generator.uvs_index_select.shape[0] > num_threshold and iteration >= coarse_iteration and iteration % densification_interval == 0:
            # grad_norm = torch.norm(self.strands_generator.grad_uvs_p_storage, p = 2, dim = 1)
            # One L2 norm per strand, taken jointly over its segments and xyz components.
            strands_grad_mask = torch.norm(
                gaussian_hair_grad.view(-1, self.strands_generator.strand_length - 1, 3),
                p=2, dim=(1, 2)
            )

            # Keep only the strands whose gradient norm reaches the threshold.
            densify_index_mask = strands_grad_mask >= densification_threshold
            index_select = torch.arange(densify_index_mask.shape[0], device=device)
            index_select = index_select[densify_index_mask]
            uvs_select = self.strands_generator.uvs_select[index_select] # N x 2

            # Work on copies; the generator is only updated at the very end.
            uvs_grid = self.strands_generator.dynamic_mgrid.clone() # 2 x (res + 1) x (res + 1)
            uvs_mask = self.strands_generator.dynamic_mgrid_mask.clone() # (res + 1) x (res + 1)

            res_final = uvs_grid.shape[1] - 1
            res_init = self.strands_generator.res_init.item()
            # multi_res = [(res + 1) * res_init for res in range(res_final/res_init)]
            # M levels, each doubling the previous resolution, starting at res_init.
            M = math.ceil(math.log(res_final / res_init, 2)) + 1
            multi_res = [res_init * (2 ** i) for i in range(M)]

            # uvs_select_ad = ((uvs_select - torch.tensor([-1,-1]).cuda()) / torch.tensor([2,2]).cuda())[None,:,:]
            # Map UVs from [-1, 1] to [0, 1], then scale by each level's resolution -> M x N x 2.
            uvs_select_ad = ((uvs_select + 1) * 0.5).unsqueeze(0)
            multi_res_ad = torch.tensor(multi_res).cuda().unsqueeze(-1).unsqueeze(-1)
            uvs_select_ad =  uvs_select_ad * multi_res_ad

            # Per-level resolution, repeated once per cell corner (used as an upper clamp below).
            multi_res_ad_grid_four = multi_res_ad[:,:,None,:].int().repeat(1,1,4,1)
            uvs_select_ad_grid = torch.floor(uvs_select_ad).int() # M x N x 2
            # uvs_select_ad_grid_four = uvs_select_ad_grid.unsqueeze(2).repeat(1,1,4,1) # M x N x 4 x 2
            # uvs_select_ad_grid_four[:,:,1,0] = uvs_select_ad_grid_four[:,:,1,0] + 1
            # uvs_select_ad_grid_four[:,:,2,1] = uvs_select_ad_grid_four[:,:,2,1] + 1
            # uvs_select_ad_grid_four[:,:,3,:] = uvs_select_ad_grid_four[:,:,3,:] + 1
            # Four corner nodes of the containing cell, clamped to the level's last node index.
            offsets = torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=torch.int32).view(1, 1, 4, 2).cuda()
            uvs_select_ad_grid_four = uvs_select_ad_grid.unsqueeze(2) + offsets
            uvs_select_ad_grid_four = torch.minimum(uvs_select_ad_grid_four, multi_res_ad_grid_four)

            # import ipdb; ipdb.set_trace()
            # Convert each level's corner indices into indices of the current (finest) mask.
            scale_factors = torch.tensor([res_final // res for res in multi_res],
                             device=device,
                             dtype=torch.int32)[:, None, None, None]
            uvs_select_ad_grid_four_high_res = (uvs_select_ad_grid_four * scale_factors).int()
            mask_vals = uvs_mask[uvs_select_ad_grid_four_high_res[..., 0],
                     uvs_select_ad_grid_four_high_res[..., 1]]  # shape: (M, N, 4)
            # 1 where all four corners are set in the mask; transposed to N x M.
            uvs_select_ad_grid_four_high_res_mask = torch.all(mask_vals, dim=-1).transpose(0, 1).int()
            # uvs_select_ad_grid_four_high_res = (uvs_select_ad_grid_four * torch.tensor([int(res_final/res) for res in multi_res]).cuda()[:,None,None,None]).int()
            # uvs_select_ad_grid_four_high_res_mask = uvs_mask[uvs_select_ad_grid_four_high_res[:,:,:,0],uvs_select_ad_grid_four_high_res[:,:,:,1]].reshape(uvs_select_ad_grid_four.shape[0]*uvs_select_ad_grid_four.shape[1],4)
            # uvs_select_ad_grid_four_high_res_mask = torch.any(uvs_select_ad_grid_four_high_res_mask, dim=1).reshape(uvs_select_ad_grid_four.shape[0],uvs_select_ad_grid_four.shape[1]).transpose(1,0).int()
            # uvs_select_ad_grid_four_high_res_mask : N x M; uvs_res_select: N 
            # uvs_res_select = ( torch.argmax(uvs_select_ad_grid_four_high_res_mask, dim=1).int().cuda() + 1 ) * res_init
            # Level index per strand: argmax over the level-reversed flags, mapped back to the
            # original level order (i.e. the highest level index whose flag is 1).
            # NOTE(review): for a strand with no flagged level the argmax runs over an all-zero
            # row; this presumably yields the finest level (M - 1) -- confirm this is intended.
            uvs_grid_idx = M-1-torch.argmax(uvs_select_ad_grid_four_high_res_mask.flip(dims=[1]), dim=1).int()

            # Flattened cell index per level: x + y * res, with the corner clamped to res - 1.
            multi_res_tensor = torch.tensor(multi_res, device=device)[:, None]
            uvs_select_ad_grid_four_left_up = torch.minimum(uvs_select_ad_grid_four, multi_res_ad_grid_four - 1)
            left_index = uvs_select_ad_grid_four_left_up[:, :, 0, 0] + uvs_select_ad_grid_four_left_up[:, :, 0, 1] * multi_res_tensor

            uvs_grid_select = left_index.int().transpose(1,0)

            # Per-(level, flattened cell) strand counter; rows are sized for the finest level.
            uvs_grid_mask = torch.zeros(len(multi_res), res_final*res_final , device=device).int()
   
            # Pair each strand's chosen level with its cell index at that level, then count
            # strands per pair (accumulate=True sums duplicates).
            uvs_grid_index = torch.stack([uvs_grid_idx, uvs_grid_select[torch.arange(uvs_grid_idx.shape[0]),uvs_grid_idx]], dim=1).int().cuda()
            values = torch.ones(uvs_grid_index.shape[0]).int().cuda()
            uvs_grid_mask.index_put_((uvs_grid_index[:, 0], uvs_grid_index[:, 1]), values, accumulate=True)

            # uvs_grid_mask: M x res_final * res_final
            # A cell is densified once it holds at least this many selected strands.
            grid_density_threshold = 20
            uvs_grid_mask_mask = uvs_grid_mask >= grid_density_threshold
            # uvs_grid_mask_mask: M x res_final * res_final
            # If a finest-level cell is crowded, double the resolution; existing mask values
            # land on the even indices of the new mask and the new odd nodes start at 0.
            if torch.any(uvs_grid_mask_mask[-1]):
                uvs_grid = torch.stack(torch.meshgrid([torch.linspace(-1, 1, 2 * res_final + 1)]*2, indexing='xy')).cuda()
                new_uvs_mask = torch.zeros(2 * res_final + 1, 2 * res_final + 1).int().cuda()
                new_uvs_mask[::2,::2] = uvs_mask
                # uvs_grid = new_uvs_grid
                uvs_mask = new_uvs_mask
            new_res_final = uvs_grid.shape[1] - 1
            # Resolution cap: bail out before anything is written back to the generator.
            if new_res_final > 256:
                return None
  
            multi_res_tensor = torch.tensor(multi_res, device=device, dtype=torch.float)

            # (level, flattened cell) pairs that met the density threshold.
            mask_indices = torch.nonzero(uvs_grid_mask_mask)

            i_vals = mask_indices[:, 0]
            j_vals = mask_indices[:, 1].to(torch.float)

            # Resolution of the level each crowded cell belongs to.
            ms = multi_res_tensor[i_vals]

            # Undo the x + y * res flattening to recover the cell's first-corner coordinates.
            left_up_y = (j_vals / ms).int()
            left_up_x = torch.remainder(j_vals, ms).int()
            left_up = torch.stack([left_up_x, left_up_y], dim=1)  # shape: (N, 2)

            # Five nodes per crowded cell, in units of one cell: centre plus four edge midpoints.
            offsets = torch.tensor([[0.5, 0.5],   # middle
                        [0.0, 0.5],   # middle_up
                        [0.5, 0.0],   # middle_left
                        [1.0, 0.5],   # middle_bottom
                        [0.5, 1.0]],  # middle_right
                       device=device, dtype=torch.float)  # shape: (5, 2)
            all_coords_float = left_up.unsqueeze(1) + offsets.unsqueeze(0)
            # Rescale from the cell's own level to node indices of the (possibly doubled) grid.
            scale = (new_res_final / ms).unsqueeze(1).unsqueeze(2)  # shape: (N, 1, 1)
            scaled_coords = all_coords_float * scale  
            all_coords = torch.round(scaled_coords).long().reshape(-1, 2)

            # Switch the new nodes on and publish the result; the same mask tensor object is
            # stored as both dynamic_mgrid_mask and grid_active_set.
            uvs_mask[all_coords[:, 0], all_coords[:, 1]] = 1
            self.strands_generator.dynamic_mgrid = uvs_grid
            self.strands_generator.dynamic_mgrid_mask = uvs_mask
            self.strands_generator.grid_active_set = uvs_mask
            self.uvs_mask_count = torch.sum(uvs_mask)
    def update_grid_active_set(self, num_threshold = 10000, iteration = 0, sparse_iteration = 0, coarse_iteration = 1000, grid_active_interval = 100, grid_active_threshold = 1 * 1e-5):
        
        # grid_grad = torch.norm(self.strands_generator.grad_uvs_p_storage, p=2, dim=-1)
        grid_grad = self.grid_gradient_accum_multi_view.mean(dim=0).norm(p=2, dim=-1)
        
        grid_active_mask = grid_grad > grid_active_threshold  # shape: (N,)
        uvs_mask = self.strands_generator.dynamic_mgrid_mask.detach()
        uvs_mask_flat = uvs_mask.view(-1)  # shape: ((res+1)*(res+1),)
        selected_idx = torch.nonzero(uvs_mask_flat, as_tuple=True)[0]  # shape: (N,)
        grid_active_set = torch.zeros_like(uvs_mask_flat, dtype=torch.int)
        grid_active_set[selected_idx] = grid_active_mask.int()
        self.strands_generator.grid_active_set = grid_active_set.view(*uvs_mask.shape)
        self.uvs_mask_active_count = grid_active_set.sum()

    def densify_and_prune_v2(self, num_threshold = 10000, iteration = 0,  densification_threshold = 5 * 1e-4):
        import time
        uvs_grid = self.strands_generator.dynamic_mgrid.clone() # 2 x (res + 1) x (res + 1)
        res_final = uvs_grid.shape[1] - 1
        new_res_final = res_final * 2
        if new_res_final <= 32:
            uvs_grid = torch.stack(torch.meshgrid([torch.linspace(-1, 1, 2 * res_final + 1)]*2, indexing='xy')).cuda()
            new_uvs_mask = torch.ones(2 * res_final + 1, 2 * res_final + 1).int().cuda()
            self.strands_generator.dynamic_mgrid = uvs_grid
            self.strands_generator.dynamic_mgrid_mask = new_uvs_mask
            self.strands_generator.grid_active_set = new_uvs_mask
            self.uvs_mask_count = torch.sum(new_uvs_mask)
            self.strands_generator.int_para_precomp(0) 
        if self.strands_generator.num_strands*2 < 10000:
            self.strands_generator.num_strands *= 2
            self.grid_interval *=2
        if iteration % 100 == 0:
            print("uvs_mask_count: ",self.uvs_mask_count) 
    def add_densification_stats_average(self, cam_id):
        if self.grid_gradient_accum_multi_view is None:
            # self.grid_gradient_accum_multi_view = torch.zeros((16,self.strands_generator.self.dynamic_mgrid_mask.sum(), 52), device="cuda")
            grad_uvs_shape = self.strands_generator.grad_uvs_p_storage.shape
            self.grid_gradient_accum_multi_view = torch.zeros((16,grad_uvs_shape[0],grad_uvs_shape[1]), device="cuda")
        # self.grid_gradient_accum_multi_view[cam_id] = torch.norm(self.strands_generator.grad_uvs_p_storage, p=2, dim=-1) 
        self.grid_gradient_accum_multi_view[cam_id] = self.strands_generator.grad_uvs_p_storage

    def update_wo_active_set_data(self,iteration, viewpoint_cam,scene,render_hair_weight_wo_active,renderArgs):
        # Renders 16 training cameras with `render_hair_weight_wo_active` (gradients disabled)
        # and caches each result in `self.wo_active_set_data`, keyed by the camera's
        # `camera_index`.
        #
        # iteration: forwarded to `self.initialize_gaussians_hair`.
        # viewpoint_cam: forwarded to `self.initialize_gaussians_hair` on every loop pass.
        # scene: provides `getTrainCameras()` and `gaussians`.
        # render_hair_weight_wo_active: callable invoked as
        #     render_hair_weight_wo_active(viewpoint, scene.gaussians, self, *renderArgs).
        # renderArgs: extra positional arguments unpacked into that call.
        with torch.no_grad():
            # Cameras 0..15 of the training set; indices wrap around when there are fewer than 16.
            config = [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(0, 16, 1)]
            for idx, viewpoint in enumerate(config):
                state = 'coarse'
                # NOTE(review): this passes the outer `viewpoint_cam`, not the loop's `viewpoint`,
                # so the hair Gaussians are initialised identically on every pass -- confirm intended.
                self.initialize_gaussians_hair(iteration, num_strands = -1, viewpoint_cam = viewpoint_cam, state=state)
                render_pkg = render_hair_weight_wo_active(viewpoint, scene.gaussians, self, *renderArgs)
                # Later cameras sharing a camera_index overwrite earlier entries.
                self.wo_active_set_data[viewpoint.camera_index] = render_pkg
                # self.wo_active_set_data[viewpoint.time_step][idx] = render_pkg