import torch
import numpy as np
from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
from torch import nn
import os
from utils.system_utils import mkdir_p
from plyfile import PlyData, PlyElement
from utils.sh_utils import RGB2SH
from simple_knn._C import distCUDA2
from utils.graphics_utils import BasicPointCloud
from utils.general_utils import strip_symmetric, build_scaling_rotation
import utils.general_utils as utils
from utils.loss_utils import entity_binary_convert_tensor
import torch.distributed as dist
import math

# from gsplat import quat_scale_to_covar_preci

# Learning-rate scaling rules keyed by mode name: "linear" returns its input
# unchanged, "sqrt" returns its square root. NOTE(review): not referenced in the
# visible code; GaussianModel.training_setup applies the same two rules to the
# batch size inline — presumably the intended input here as well.
lr_scale_fns = dict(
    linear=lambda x: x,
    sqrt=lambda x: np.sqrt(x),
)


class GaussianModel:

    def setup_functions(self):
        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
            L = build_scaling_rotation(scaling_modifier * scaling, rotation)
            actual_covariance = L @ L.transpose(1, 2)
            symm = strip_symmetric(actual_covariance)
            return symm

        self.scaling_activation = torch.exp
        self.scaling_inverse_activation = torch.log

        self.covariance_activation = build_covariance_from_scaling_rotation

        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = inverse_sigmoid
        self.entity_activation = torch.sigmoid

        self.rotation_activation = torch.nn.functional.normalize

    def __init__(self, sh_degree: int, args):
        self.active_sh_degree = 0
        self.max_sh_degree = sh_degree
        self._xyz = torch.empty(0)
        self._features_dc = torch.empty(0)
        self._features_rest = torch.empty(0)
        self._scaling = torch.empty(0)
        self._rotation = torch.empty(0)
        self._opacity = torch.empty(0)
        self.max_radii2D = torch.empty(0)
        self.xyz_gradient_accum = torch.empty(
            0
        )
        self.denom = torch.empty(0)
        self.optimizer = None
        self.percent_dense = 0
        self.spatial_lr_scale = 0

        self._semantic_feature = torch.empty(0)
        self.entity_ids = torch.empty(0)
        self.entity_cls = torch.empty(0)
        self.entity_cls_num = torch.empty(0)
        self.is_control_init = False
        self.is_entity_init = False
        self.semantic_feature_dim = args.semantic_feature_dim
        self.use_truncated_binary = args.use_truncated_binary
        self.entity_dim = args.entity_dim
        self.entity_cls_num = args.entity_cls_num
        self.setup_functions()

    def entity_init(self):
        self.gs_binary_ids = torch.round(self.get_entity).to(torch.long)
        if self.use_truncated_binary:
            self.entity_cls = entity_binary_convert_tensor(
                torch.arange(0, self.entity_cls_num).unsqueeze(0).unsqueeze(0).cuda()).permute(2, 1, 0)
        else:
            self.entity_cls = torch.unique(self.gs_binary_ids, dim=0)
        self.entity_cls_num = self.entity_cls.shape[0]
        self.is_entity_init = True

    def entity_extract_mask(self, entity_id):
        entity_mask = torch.where(self.entity_ids == entity_id)[0]
        return entity_mask

    def capture(self):
        return (
            self.active_sh_degree,
            self._xyz,
            self._features_dc,
            self._features_rest,
            self._scaling,
            self._rotation,
            self._opacity,
            self.max_radii2D,
            self._semantic_feature,
            self.entity_ids,
            self.xyz_gradient_accum,
            self.denom,
            self.optimizer.state_dict(),
            self.spatial_lr_scale,
        )

    def restore(self, model_args, training_args):
        (
            self.active_sh_degree,
            self._xyz,
            self._features_dc,
            self._features_rest,
            self._scaling,
            self._rotation,
            self._opacity,
            self.max_radii2D,
            self._semantic_feature,
            self.entity_ids,
            xyz_gradient_accum,
            denom,
            opt_dict,
            self.spatial_lr_scale,
        ) = model_args

        # self.eneity_cls = torch.unique(self.entity_ids)
        # self.entity_cls_num = self.entity_cls.shape[0]
        self.training_setup(training_args)
        self.xyz_gradient_accum = (
            xyz_gradient_accum
        )
        self.denom = denom
        if opt_dict is not None:
            self.optimizer.load_state_dict(opt_dict)

    @property
    def get_entity(self):
        return self.entity_activation(self.entity_ids)

    @property
    def get_scaling(self):
        return self.scaling_activation(self._scaling)

    @property
    def get_rotation(self):
        return self.rotation_activation(self._rotation)

    @property
    def get_xyz(self):
        return self._xyz

    @property
    def get_features(self):
        features_dc = self._features_dc
        features_rest = self._features_rest
        return torch.cat((features_dc, features_rest), dim=1)

    @property
    def get_opacity(self):
        return self.opacity_activation(self._opacity)

    @property
    def get_semantic_feature(self):
        return self._semantic_feature

    def get_covariance(self, scaling_modifier=1, d_rotation=None, gs_rot_bias=None):
        if d_rotation is not None:
            rotation = quaternion_multiply(self._rotation, d_rotation)
        else:
            rotation = self._rotation
        if gs_rot_bias is not None:
            rotation = rotation / rotation.norm(dim=-1, keepdim=True)
            rotation = quaternion_multiply(gs_rot_bias, rotation)
        return self.covariance_activation(self.get_scaling, scaling_modifier, rotation)

    def update_sem_centroid(self):
        self.entity_init()
        self.semantic_centroid = torch.zeros([self.entity_cls_num, self._semantic_feature.shape[-1]],
                                             device=self._xyz.device)
        # self.xyz_centroid = torch.zeros([self.entity_cls_num, 3])
        for i in range(self.entity_cls_num):
            cur_entity_id = self.entity_cls[i]
            row_index = torch.where(torch.eq(self.gs_binary_ids, cur_entity_id).squeeze(1).all(dim=1))[0]
            self.semantic_centroid[i, :] = torch.mean(self._semantic_feature[row_index], dim=0)
            # self.xyz_centroid[i, :] = torch.mean(self._xyz[row_index], dim=0)

    def oneupSHdegree(self):
        if self.active_sh_degree < self.max_sh_degree:
            self.active_sh_degree += 1

    def get_rotation_bias(self, rotation_bias=None, gs_detach=False, entity_index=None):
        rotation_bias = rotation_bias if rotation_bias is not None else 0.
        if not gs_detach:
            if entity_index is not None:
                return self.rotation_activation(self._rotation.index_add(0, entity_index, rotation_bias))
            else:
                return self.rotation_activation(self._rotation)
        else:
            if entity_index is not None:
                return self.rotation_activation(self._rotation.detach().index_add(0, entity_index, rotation_bias))
            else:
                return self.rotation_activation(self._rotation.detach() + rotation_bias)

    def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float):
        log_file = utils.get_log_file()
        # loading could replicated on all ranks.
        self.spatial_lr_scale = spatial_lr_scale

        fused_point_cloud = (
            torch.tensor(np.asarray(pcd.points)).float().cuda()
        )  # It is not contiguous
        fused_point_cloud = fused_point_cloud.contiguous()  # Now it's contiguous
        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
        features = (
            torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2))
            .float()
            .cuda()
        )
        features[:, :3, 0] = fused_color
        features[:, 3:, 1:] = 0.0

        # entity_feature = RGB2SH(torch.rand((fused_point_cloud.shape[0], self.entity_dim), device="cuda"))
        entity_feature = torch.ones((fused_point_cloud.shape[0], self.entity_dim), device="cuda") * 0.5
        entity_feature = entity_feature[:, :, None]
        semantic_feature = RGB2SH(torch.rand((fused_point_cloud.shape[0], self.semantic_feature_dim), device="cuda"))
        semantic_feature = semantic_feature[:, :, None]

        if utils.GLOBAL_RANK == 0:
            print(
                "Number of points before initialization : ", fused_point_cloud.shape[0]
            )

        dist2 = torch.clamp_min(
            distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()),
            0.0000001,
        )
        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
        rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
        rots[:, 0] = 1

        opacities = inverse_sigmoid(
            0.1
            * torch.ones(
                (fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"
            )
        )

        args = utils.get_args()
        if (
                args.gaussians_distribution
        ):  # shard 3dgs storage across all GPU including dp and mp groups.
            shard_world_size = utils.DEFAULT_GROUP.size()
            shard_rank = utils.DEFAULT_GROUP.rank()

            point_ind_l, point_ind_r = utils.get_local_chunk_l_r(
                fused_point_cloud.shape[0], shard_world_size, shard_rank
            )
            fused_point_cloud = fused_point_cloud[point_ind_l:point_ind_r].contiguous()
            features = features[point_ind_l:point_ind_r].contiguous()
            scales = scales[point_ind_l:point_ind_r].contiguous()
            rots = rots[point_ind_l:point_ind_r].contiguous()
            opacities = opacities[point_ind_l:point_ind_r].contiguous()
            semantic_feature = semantic_feature[point_ind_l:point_ind_r].contiguous()
            entity_feature = entity_feature[point_ind_l:point_ind_r].contiguous()
            log_file.write(
                "rank: {}, Number of initialized points: {}\n".format(
                    utils.GLOBAL_RANK, fused_point_cloud.shape[0]
                )
            )

        if args.drop_initial_3dgs_p > 0.0:
            # drop each point with probability args.drop_initial_3dgs_p
            drop_mask = (
                    np.random.rand(fused_point_cloud.shape[0]) > args.drop_initial_3dgs_p
            )
            fused_point_cloud = fused_point_cloud[drop_mask]
            features = features[drop_mask]
            scales = scales[drop_mask]
            rots = rots[drop_mask]
            opacities = opacities[drop_mask]
            log_file.write(
                "rank: {}, Number of initialized points after random drop: {}\n".format(
                    utils.GLOBAL_RANK, fused_point_cloud.shape[0]
                )
            )
            # print("rank", utils.GLOBAL_RANK, "Number of initialized points after random drop : ", fused_point_cloud.shape[0])

        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
        self._features_dc = nn.Parameter(
            features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True)
        )
        self._features_rest = nn.Parameter(
            features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)
        )
        self._scaling = nn.Parameter(scales.requires_grad_(True))
        self._rotation = nn.Parameter(rots.requires_grad_(True))
        self._opacity = nn.Parameter(opacities.requires_grad_(True))
        self._semantic_feature = nn.Parameter(semantic_feature.transpose(1, 2).contiguous().requires_grad_(True))
        self.entity_ids = nn.Parameter(entity_feature.transpose(1, 2).contiguous().requires_grad_(True))
        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
        self.sum_visible_count_in_one_batch = torch.zeros(
            (self.get_xyz.shape[0]), device="cuda"
        )

    def add_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float):
        log_file = utils.get_log_file()
        # loading could replicated on all ranks.
        fused_point_cloud = (
            torch.tensor(np.asarray(pcd.points)).float().cuda()
        )  # It is not contiguous
        fused_point_cloud = fused_point_cloud.contiguous()  # Now it's contiguous
        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
        features = (
            torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2))
            .float()
            .cuda()
        )
        features[:, :3, 0] = fused_color
        features[:, 3:, 1:] = 0.0

        # entity_feature = RGB2SH(torch.rand((fused_point_cloud.shape[0], self.entity_dim), device="cuda"))
        entity_feature = torch.ones((fused_point_cloud.shape[0], self.entity_dim), device="cuda") * 0.5
        entity_feature = entity_feature[:, :, None]
        # entity_feature = torch.tensor([[[-1.,-1,-1,1,1,1]]], dtype=torch.float32).cuda().repeat(fused_point_cloud.shape[0], 1, 1).permute(0,2,1,).contiguous()

        semantic_feature = RGB2SH(torch.rand((fused_point_cloud.shape[0], self.semantic_feature_dim), device="cuda"))
        semantic_feature = semantic_feature[:, :, None]

        if utils.GLOBAL_RANK == 0:
            print(
                "Number of points before initialization : ", fused_point_cloud.shape[0]
            )

        dist2 = torch.clamp_min(
            distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()),
            0.0000001,
        )
        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
        rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
        rots[:, 0] = 1

        opacities = inverse_sigmoid(
            0.1
            * torch.ones(
                (fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"
            )
        )

        args = utils.get_args()
        if (
                args.gaussians_distribution
        ):  # shard 3dgs storage across all GPU including dp and mp groups.
            shard_world_size = utils.DEFAULT_GROUP.size()
            shard_rank = utils.DEFAULT_GROUP.rank()

            point_ind_l, point_ind_r = utils.get_local_chunk_l_r(
                fused_point_cloud.shape[0], shard_world_size, shard_rank
            )
            fused_point_cloud = fused_point_cloud[point_ind_l:point_ind_r].contiguous()
            features = features[point_ind_l:point_ind_r].contiguous()
            scales = scales[point_ind_l:point_ind_r].contiguous()
            rots = rots[point_ind_l:point_ind_r].contiguous()
            opacities = opacities[point_ind_l:point_ind_r].contiguous()
            semantic_feature = semantic_feature[point_ind_l:point_ind_r].contiguous()
            entity_feature = entity_feature[point_ind_l:point_ind_r].contiguous()
            log_file.write(
                "rank: {}, Number of initialized points: {}\n".format(
                    utils.GLOBAL_RANK, fused_point_cloud.shape[0]
                )
            )
            # print("rank", utils.GLOBAL_RANK, "Number of initialized points after gaussians_distribution : ", fused_point_cloud.shape[0])

        if args.drop_initial_3dgs_p > 0.0:
            # drop each point with probability args.drop_initial_3dgs_p
            drop_mask = (
                    np.random.rand(fused_point_cloud.shape[0]) > args.drop_initial_3dgs_p
            )
            fused_point_cloud = fused_point_cloud[drop_mask]
            features = features[drop_mask]
            scales = scales[drop_mask]
            rots = rots[drop_mask]
            opacities = opacities[drop_mask]
            log_file.write(
                "rank: {}, Number of initialized points after random drop: {}\n".format(
                    utils.GLOBAL_RANK, fused_point_cloud.shape[0]
                )
            )

        new_xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
        new_features_dc = nn.Parameter(
            features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True)
        )
        new_features_rest = nn.Parameter(
            features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)
        )
        new_scaling = nn.Parameter(scales.requires_grad_(True))
        new_rotation = nn.Parameter(rots.requires_grad_(True))
        new_opacity = nn.Parameter(opacities.requires_grad_(True))
        new_semantic_feature = nn.Parameter(semantic_feature.transpose(1, 2).contiguous().requires_grad_(True))
        new_entity_ids = nn.Parameter(entity_feature.transpose(1, 2).contiguous().requires_grad_(True))
        new_send_to_gpui_cnt = torch.zeros((new_xyz.shape[0], shard_world_size), dtype=torch.int, device="cuda")

        self.densification_postfix(
            new_xyz,
            new_features_dc,
            new_features_rest,
            new_opacity,
            new_scaling,
            new_rotation,
            new_semantic_feature,
            new_entity_ids,
            new_send_to_gpui_cnt,
        )

    def all_parameters(self):
        return [
            self._xyz,
            self._features_dc,
            self._features_rest,
            self._scaling,
            self._rotation,
            self._opacity,
            self._semantic_feature,
            self.entity_ids
        ]

    def training_setup(self, training_args):
        self.percent_dense = training_args.percent_dense
        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")

        shard_world_size = self.group_for_redistribution().size()
        self.send_to_gpui_cnt = torch.zeros(
            (self.get_xyz.shape[0], shard_world_size), dtype=torch.int, device="cuda"
        )

        args = utils.get_args()
        log_file = utils.get_log_file()

        l = [
            {
                "params": [self._xyz],
                "lr": training_args.position_lr_init
                      * self.spatial_lr_scale
                      * args.lr_scale_pos_and_scale,
                "name": "xyz",
            },
            {
                "params": [self._features_dc],
                "lr": training_args.feature_lr,
                "name": "f_dc",
            },
            {
                "params": [self._features_rest],
                "lr": training_args.feature_lr / 20.0,
                "name": "f_rest",
            },
            {
                "params": [self._opacity],
                "lr": training_args.opacity_lr,
                "name": "opacity",
            },
            {
                "params": [self._scaling],
                "lr": training_args.scaling_lr * args.lr_scale_pos_and_scale,
                "name": "scaling",
            },
            {
                "params": [self._rotation],
                "lr": training_args.rotation_lr,
                "name": "rotation",
            },
            {
                "params": [self.entity_ids],
                "lr": training_args.feature_lr,
                "name": "entity_f",
            },
            {
                "params": [self._semantic_feature],
                "lr": training_args.feature_lr,
                "name": "sem_f",
            }
        ]

        self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
        # self.optimizer = torch.optim.SGD(l, lr=0.0, momentum=0.1)

        bsz = utils.get_args().bsz
        for param_group in self.optimizer.param_groups:
            if training_args.lr_scale_mode == "linear":
                lr_scale = bsz
                param_group["lr"] *= lr_scale
            elif training_args.lr_scale_mode == "sqrt":
                lr_scale = np.sqrt(bsz)
                param_group["lr"] *= lr_scale
                if "eps" in param_group:  # Adam
                    param_group["eps"] /= lr_scale
                    param_group["betas"] = [beta ** bsz for beta in param_group["betas"]]
                    # utils.print_rank_0(param_group["name"] + " betas: " + str(param_group["betas"]))
                    log_file.write(
                        param_group["name"]
                        + " betas: "
                        + str(param_group["betas"])
                        + "\n"
                    )
            elif training_args.lr_scale_mode == "accumu":
                lr_scale = 1
            else:
                assert (
                    False
                ), f"lr_scale_mode {training_args.lr_scale_mode} not supported."

        self.xyz_scheduler_args = get_expon_lr_func(
            lr_init=training_args.position_lr_init * self.spatial_lr_scale * lr_scale * args.lr_scale_pos_and_scale,
            lr_final=training_args.position_lr_final * self.spatial_lr_scale * lr_scale * args.lr_scale_pos_and_scale,
            lr_delay_mult=training_args.position_lr_delay_mult,
            max_steps=args.position_lr_max_steps,
        )

        utils.check_initial_gpu_memory_usage("after training_setup")

    def training_setup_mcmc(self, args):
        params = [
            # name, value, lr
            ("xyz", torch.nn.Parameter(self._xyz), 1.6e-4 * self.spatial_lr_scale),
            ("f_dc", torch.nn.Parameter(self._features_dc), args.feature_lr),
            ("f_rest", torch.nn.Parameter(self._features_rest), args.feature_lr / 20),
            ("scaling", torch.nn.Parameter(self._scaling), args.scaling_lr),
            ("rotation", torch.nn.Parameter(self._rotation), args.rotation_lr),
            ("opacity", torch.nn.Parameter(self._opacity), args.opacity_lr),
            ("entity_f", torch.nn.Parameter(self.entity_ids), args.feature_lr),
            ("sem_f", torch.nn.Parameter(self._semantic_feature), args.feature_lr)
        ]
        self.splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to("cuda")
        BS = 3 * utils.DEFAULT_GROUP.size()
        optimizer_class = None
        optimizer_class = torch.optim.Adam
        self.optimizers = {
            name: optimizer_class(
                [{"params": self.splats[name], "lr": lr * math.sqrt(BS), "name": name}],
                eps=1e-15 / math.sqrt(BS),
                betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
            )
            for name, _, lr in params
        }
        self.cfg.strategy.check_sanity(self.splats, self.optimizers)
        self.strategy_state = self.cfg.strategy.initialize_state()
        self.schedulers = [
            # means has a learning rate schedule, that end at 0.01 of the initial value
            torch.optim.lr_scheduler.ExponentialLR(
                self.optimizers["means"], gamma=0.01 ** (1.0 / args.iterations)
            ),
        ]

    def log_gaussian_stats(self):
        # log the statistics of the gaussian model
        # number of total 3dgs on this rank
        num_3dgs = self._xyz.shape[0]
        # average size of 3dgs
        avg_size = torch.mean(torch.max(self.get_scaling, dim=1).values).item()
        # average opacity
        avg_opacity = torch.mean(self.get_opacity).item()
        stats = {
            "num_3dgs": num_3dgs,
            "avg_size": avg_size,
            "avg_opacity": avg_opacity,
        }

        # get the exp_avg, exp_avg_sq state for all parameters
        exp_avg_dict = {}
        exp_avg_sq_dict = {}
        for group in self.optimizer.param_groups:
            stored_state = self.optimizer.state.get(group["params"][0], None)
            if stored_state is not None:
                if "exp_avg" in stored_state:
                    exp_avg_dict[group["name"]] = torch.mean(
                        torch.norm(stored_state["exp_avg"], dim=-1)
                    ).item()
                    exp_avg_sq_dict[group["name"]] = torch.mean(
                        torch.norm(stored_state["exp_avg_sq"], dim=-1)
                    ).item()
        return stats, exp_avg_dict, exp_avg_sq_dict

    def sync_gradients_for_replicated_3dgs_storage(self, batched_screenspace_pkg):
        args = utils.get_args()

        if "visible_count" in args.grad_normalization_mode:
            # allgather visibility filder from all dp workers, so that each worker contains the visibility filter of all data points.
            batched_locally_preprocessed_visibility_filter_int = [
                x.int()
                for x in batched_screenspace_pkg[
                    "batched_locally_preprocessed_visibility_filter"
                ]
            ]
            sum_batched_locally_preprocessed_visibility_filter_int = torch.sum(
                torch.stack(batched_locally_preprocessed_visibility_filter_int), dim=0
            )
            batched_screenspace_pkg[
                "sum_batched_locally_preprocessed_visibility_filter_int"
            ] = sum_batched_locally_preprocessed_visibility_filter_int

        if args.sync_grad_mode == "dense":
            sync_func = sync_gradients_densely
        elif args.sync_grad_mode == "sparse":
            sync_func = sync_gradients_sparsely
        elif args.sync_grad_mode == "fused_dense":
            sync_func = sync_gradients_fused_densely
        elif args.sync_grad_mode == "fused_sparse":
            sync_func = sync_gradients_fused_sparsely
        else:
            assert False, f"sync_grad_mode {args.sync_grad_mode} not supported."

        if not args.gaussians_distribution and utils.DEFAULT_GROUP.size() > 1:
            sync_func(self, utils.DEFAULT_GROUP)

    def update_learning_rate(self, iteration):
        """Learning rate scheduling per step"""
        for param_group in self.optimizer.param_groups:
            if param_group["name"] == "xyz":
                lr = self.xyz_scheduler_args(iteration)
                param_group["lr"] = lr
                return lr

    def construct_list_of_attributes(self):
        l = ["x", "y", "z", "nx", "ny", "nz"]
        # All channels except the 3 DC
        for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]):
            l.append("f_dc_{}".format(i))
        for i in range(self._features_rest.shape[1] * self._features_rest.shape[2]):
            l.append("f_rest_{}".format(i))
        l.append("opacity")
        for i in range(self._scaling.shape[1]):
            l.append("scale_{}".format(i))
        for i in range(self._rotation.shape[1]):
            l.append("rot_{}".format(i))
        for i in range(self._semantic_feature.shape[1] * self._semantic_feature.shape[2]):
            l.append('sem_{}'.format(i))
        for i in range(self.entity_ids.shape[1] * self.entity_ids.shape[2]):
            l.append('entity_{}'.format(i))
        return l

    def save_ply(
            self, path
    ):  # here, we should be in torch.no_grad() context. train.py ensures that.
        args = utils.get_args()
        # _xyz = _features_dc = _features_rest = _opacity = _scaling = _rotation = None
        _xyz = _features_dc = _features_rest = _opacity = _scaling = _rotation = _semantic_feature = entity_ids = None
        utils.log_cpu_memory_usage("start save_ply")
        group = utils.DEFAULT_GROUP
        if args.gaussians_distribution and not args.distributed_save:
            # gather all gaussians at rank 0
            def gather_uneven_tensors(tensor):
                # gather size of tensors on different ranks
                tensor_sizes = torch.zeros(
                    (group.size()), dtype=torch.int, device="cuda"
                )
                tensor_sizes[group.rank()] = tensor.shape[0]
                dist.all_reduce(tensor_sizes, op=dist.ReduceOp.SUM)
                # move tensor_sizes to CPU and convert to int list
                tensor_sizes = tensor_sizes.cpu().numpy().tolist()

                # NOTE: Internal implementation of gather could not gather tensors of different sizes.
                # So, I do not use dist.gather(tensor, dst=0) but use dist.send(tensor, dst=0) and dist.recv(tensor, src=i) instead.

                # gather tensors on different ranks using grouped send/recv
                gathered_tensors = []
                if group.rank() == 0:
                    for i in range(group.size()):
                        if i == group.rank():
                            gathered_tensors.append(tensor)
                        else:
                            tensor_from_rk_i = torch.zeros(
                                (tensor_sizes[i],) + tensor.shape[1:],
                                dtype=tensor.dtype,
                                device="cuda",
                            )
                            dist.recv(tensor_from_rk_i, src=i)
                            gathered_tensors.append(tensor_from_rk_i)
                    gathered_tensors = torch.cat(gathered_tensors, dim=0)
                else:
                    dist.send(tensor, dst=0)
                # concatenate gathered tensors

                return (
                    gathered_tensors if group.rank() == 0 else None
                )  # only return gather tensors at rank 0

            _xyz = gather_uneven_tensors(self._xyz)
            _features_dc = gather_uneven_tensors(self._features_dc)
            _features_rest = gather_uneven_tensors(self._features_rest)
            _opacity = gather_uneven_tensors(self._opacity)
            _scaling = gather_uneven_tensors(self._scaling)
            _rotation = gather_uneven_tensors(self._rotation)
            _semantic_feature = gather_uneven_tensors(self._semantic_feature)
            entity_ids = gather_uneven_tensors(self.entity_ids)

            if group.rank() != 0:
                return

        elif args.gaussians_distribution and args.distributed_save:
            assert (
                    utils.DEFAULT_GROUP.size() > 1
            ), "distributed_save should be used with more than 1 rank."
            _xyz = self._xyz
            _features_dc = self._features_dc
            _features_rest = self._features_rest
            _opacity = self._opacity
            _scaling = self._scaling
            _rotation = self._rotation
            _semantic_feature = self._semantic_feature
            entity_ids = self.entity_ids
            if path.endswith(".ply"):
                path = (
                        path[:-4]
                        + "_rk"
                        + str(utils.GLOBAL_RANK)
                        + "_ws"
                        + str(utils.WORLD_SIZE)
                        + ".ply"
                )
        elif not args.gaussians_distribution:
            if group.rank() != 0:
                return
            _xyz = self._xyz
            _features_dc = self._features_dc
            _features_rest = self._features_rest
            _opacity = self._opacity
            _scaling = self._scaling
            _rotation = self._rotation
            _semantic_feature = self._semantic_feature
            entity_ids = self.entity_ids
            if path.endswith(".ply"):
                path = (
                        path[:-4]
                        + "_rk"
                        + str(utils.GLOBAL_RANK)
                        + "_ws"
                        + str(utils.WORLD_SIZE)
                        + ".ply"
                )

        mkdir_p(os.path.dirname(path))

        xyz = _xyz.detach().cpu().numpy()
        normals = np.zeros_like(xyz)
        f_dc = (
            _features_dc.detach()
            .transpose(1, 2)
            .flatten(start_dim=1)
            .contiguous()
            .cpu()
            .numpy()
        )
        f_rest = (
            _features_rest.detach()
            .transpose(1, 2)
            .flatten(start_dim=1)
            .contiguous()
            .cpu()
            .numpy()
        )
        opacities = _opacity.detach().cpu().numpy()
        scale = _scaling.detach().cpu().numpy()
        rotation = _rotation.detach().cpu().numpy()
        _semantic_feature = self._semantic_feature.detach().transpose(1, 2).flatten(
            start_dim=1).contiguous().cpu().numpy()
        entity_ids = self.entity_ids.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()

        utils.log_cpu_memory_usage("after change gpu tensor to cpu numpy")

        dtype_full = [
            (attribute, "f4") for attribute in self.construct_list_of_attributes()
        ]

        elements = np.empty(xyz.shape[0], dtype=dtype_full)
        attributes = np.concatenate(
            (xyz, normals, f_dc, f_rest, opacities, scale, rotation, _semantic_feature, entity_ids), axis=1
        )
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, "vertex")

        utils.log_cpu_memory_usage(
            "after change numpy to plyelement before writing ply file"
        )
        PlyData([el]).write(path)
        utils.log_cpu_memory_usage("finish write ply file")
        # remark: max_radii2D, xyz_gradient_accum and denom are not saved here; they are save elsewhere.

    def reset_opacity(self):
        """Clamp the opacity of densifiable Gaussians to at most 0.01.

        Gaussians where ``self.densify_mask`` is False keep their current
        opacity unchanged. The new values replace the "opacity" parameter in
        the optimizer via ``replace_tensor_to_optimizer``.
        """
        utils.LOG_FILE.write("Resetting opacity to 0.01\n")
        # ``_opacity`` holds pre-activation logits: the inverse_sigmoid result
        # below is what gets stored back into it. ``get_opacity`` is compared
        # with 0.01, i.e. it lives in activated space.
        opacities_new = inverse_sigmoid(
            torch.min(self.get_opacity, torch.ones_like(self.get_opacity) * 0.01)
        )
        keep = ~self.densify_mask
        # Frozen Gaussians must keep their raw logit. Copying the activated
        # value (sigmoid(logit)) here would silently change their opacity.
        opacities_new[keep] = self._opacity[keep].detach()
        optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
        self._opacity = optimizable_tensors["opacity"]

    def prune_based_on_opacity(self, min_opacity):
        """Prune every Gaussian whose activated opacity is below ``min_opacity``.

        Logs the pruned percentage to ``utils.LOG_FILE`` and delegates the
        removal to ``prune_points``.
        """
        # squeeze(-1) instead of squeeze(): with a single Gaussian the latter
        # would yield a 0-d mask (assumes get_opacity is (N, 1) like the
        # loaded ``opacities[..., np.newaxis]`` arrays — TODO confirm).
        prune_mask = (self.get_opacity < min_opacity).squeeze(-1)
        num_points = prune_mask.shape[0]
        # Guard against an empty model, which previously divided by zero.
        percent = 100 * prune_mask.sum().item() / num_points if num_points > 0 else 0.0
        utils.LOG_FILE.write(
            "Pruning based on opacity. Percent: {:.2f}\n".format(percent)
        )
        self.prune_points(prune_mask)

    def distributed_load_ply(self, folder):
        """Load a checkpoint stored as one PLY file per saving rank.

        ``folder`` must contain ``point_cloud_rk{rank}_ws{world_size}.ply``
        files. This is the naming the distributed save path produces. The
        saving world size is parsed from the first matching file name (in
        sorted order). Every per-rank file is read with ``load_raw_ply``,
        which may itself shard or subsample the points. The results are
        concatenated along the point axis and turned into float32 CUDA leaf
        parameters. ``active_sh_degree`` is then set to ``max_sh_degree``.

        Args:
            folder: directory holding the per-rank PLY files.
        """
        import re

        # Match only the per-rank naming scheme so unrelated files that merely
        # contain "_ws" cannot break the world-size parse.
        per_rank_name = re.compile(r"^point_cloud_rk\d+_ws(\d+)\.ply$")
        world_size = -1
        for f in sorted(os.listdir(folder)):
            match = per_rank_name.match(f)
            if match is not None:
                world_size = int(match.group(1))
                break
        assert world_size > 0, (
            "No point_cloud_rk*_ws*.ply files found in {}.".format(folder)
        )

        # One tuple of arrays per saving rank, in rank order.
        per_rank = [
            self.load_raw_ply(
                os.path.join(
                    folder, "point_cloud_rk{}_ws{}.ply".format(rk, world_size)
                )
            )
            for rk in range(world_size)
        ]
        # Regroup per attribute and concatenate along the point axis.
        (
            catted_xyz,
            catted_features_dc,
            catted_features_rest,
            catted_opacity,
            catted_scaling,
            catted_rotation,
            catted_semantic_feature,
            catted_entity_ids,
        ) = [np.concatenate(parts, axis=0) for parts in zip(*per_rank)]
        per_rank = None  # release the per-rank copies before moving to GPU

        def _to_param(array, transpose=False):
            # float32 CUDA leaf parameter, optionally with axes 1 and 2 swapped.
            tensor = torch.tensor(array, dtype=torch.float, device="cuda")
            if transpose:
                tensor = tensor.transpose(1, 2).contiguous()
            return nn.Parameter(tensor.requires_grad_(True))

        self._xyz = _to_param(catted_xyz)
        self._features_dc = _to_param(catted_features_dc, transpose=True)
        self._features_rest = _to_param(catted_features_rest, transpose=True)
        self._opacity = _to_param(catted_opacity)
        self._scaling = _to_param(catted_scaling)
        self._rotation = _to_param(catted_rotation)
        self._semantic_feature = _to_param(catted_semantic_feature, transpose=True)
        self.entity_ids = _to_param(catted_entity_ids, transpose=True)

        self.active_sh_degree = self.max_sh_degree

    def load_raw_ply(self, path):
        """Read one PLY point cloud into numpy arrays, then shard/subsample it.

        Reads the following from the first PLY element:
        - positions and opacity;
        - DC and higher-order SH coefficients;
        - ``scale_*`` and ``rot*`` columns;
        - the ``sem_*`` and ``entity_*`` feature columns.

        When ``args.gaussians_distribution`` is set and ``utils.WORLD_SIZE > 1``,
        only this rank's contiguous chunk of points is kept. When
        ``args.drop_initial_3dgs_p > 0``, each remaining point is dropped with
        that probability.

        Returns:
            ``(xyz, features_dc, features_extra, opacities, scales, rots,
            semantic_feature, entity_ids)`` numpy arrays sharing the same
            first (point) axis P. As built below:
            - xyz ``(P, 3)``
            - features_dc ``(P, 3, 1)``
            - features_extra ``(P, 3, (max_sh_degree + 1) ** 2 - 1)``
            - opacities ``(P, 1)``
            - scales ``(P, #scale_*)``
            - rots ``(P, #rot*)``
            - semantic_feature ``(P, #sem_*, 1)``
            - entity_ids ``(P, #entity_*, 1)``
        """
        print("Loading ", path)
        plydata = PlyData.read(path)
        vertex = plydata.elements[0]

        xyz = np.stack(
            (
                np.asarray(vertex["x"]),
                np.asarray(vertex["y"]),
                np.asarray(vertex["z"]),
            ),
            axis=1,
        )
        num_points = xyz.shape[0]

        def _sorted_names(prefix):
            # Property names starting with ``prefix``, ordered by their
            # trailing integer so that e.g. "f_rest_10" follows "f_rest_9".
            names = [p.name for p in vertex.properties if p.name.startswith(prefix)]
            return sorted(names, key=lambda x: int(x.split("_")[-1]))

        def _fill_columns(names):
            # (P, len(names)) float64 array with one column per named property.
            out = np.zeros((num_points, len(names)))
            for idx, attr_name in enumerate(names):
                out[:, idx] = np.asarray(vertex[attr_name])
            return out

        def _stack_indexed(prefix):
            # Stack "<prefix>0", "<prefix>1", ... as columns and append a
            # singleton axis -> (P, count, 1). A file without any such
            # property yields an empty (P, 0, 1) array; np.stack would raise
            # on an empty list.
            count = sum(1 for name in vertex.data.dtype.names if name.startswith(prefix))
            if count == 0:
                return np.zeros((num_points, 0, 1), dtype=np.float32)
            stacked = np.stack(
                [np.asarray(vertex[f"{prefix}{i}"]) for i in range(count)], axis=1
            )
            return np.expand_dims(stacked, axis=-1)

        opacities = np.asarray(vertex["opacity"])[..., np.newaxis]

        features_dc = np.zeros((num_points, 3, 1))
        features_dc[:, 0, 0] = np.asarray(vertex["f_dc_0"])
        features_dc[:, 1, 0] = np.asarray(vertex["f_dc_1"])
        features_dc[:, 2, 0] = np.asarray(vertex["f_dc_2"])

        extra_f_names = _sorted_names("f_rest_")
        assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3
        # Reshape (P, F*SH_coeffs) to (P, F, SH_coeffs except DC)
        features_extra = _fill_columns(extra_f_names).reshape(
            (num_points, 3, (self.max_sh_degree + 1) ** 2 - 1)
        )

        semantic_feature = _stack_indexed("sem_")
        entity_ids = _stack_indexed("entity_")

        scales = _fill_columns(_sorted_names("scale_"))
        rots = _fill_columns(_sorted_names("rot"))

        arrays = [
            xyz,
            features_dc,
            features_extra,
            opacities,
            scales,
            rots,
            semantic_feature,
            entity_ids,
        ]

        args = utils.get_args()
        # The above computation/memory is replicated on all ranks. Because
        # initialization is small, it's ok. Split the point cloud across ranks.
        if args.gaussians_distribution and utils.WORLD_SIZE > 1:
            # Index chunks by the *global* rank. The number of chunks is the
            # global world size, so a node-local rank would make ranks on
            # different nodes load the same chunk and leave others unloaded.
            chunk = num_points // utils.WORLD_SIZE + 1
            point_ind_l = chunk * utils.GLOBAL_RANK
            point_ind_r = min(chunk * (utils.GLOBAL_RANK + 1), num_points)
            arrays = [np.ascontiguousarray(a[point_ind_l:point_ind_r]) for a in arrays]

        if args.drop_initial_3dgs_p > 0.0:
            # Drop each point with probability args.drop_initial_3dgs_p.
            drop_mask = np.random.rand(arrays[0].shape[0]) > args.drop_initial_3dgs_p
            arrays = [a[drop_mask] for a in arrays]

        return tuple(arrays)

    def one_file_load_ply(self, folder):
        """Load ``<folder>/point_cloud.ply`` into this model's parameters.

        Arrays come from ``load_raw_ply`` and become float32 CUDA leaf
        ``nn.Parameter``s. SH coefficients, semantic features and entity ids
        have axes 1 and 2 swapped on the way in. ``active_sh_degree`` is then
        raised to ``max_sh_degree``.
        """
        (
            xyz,
            features_dc,
            features_extra,
            opacities,
            scales,
            rots,
            semantic_feature,
            entity_ids,
        ) = self.load_raw_ply(os.path.join(folder, "point_cloud.ply"))

        def as_param(values, swap_axes=False):
            t = torch.tensor(values, dtype=torch.float, device="cuda")
            if swap_axes:
                t = t.transpose(1, 2).contiguous()
            return nn.Parameter(t.requires_grad_(True))

        self._xyz = as_param(xyz)
        self._features_dc = as_param(features_dc, swap_axes=True)
        self._features_rest = as_param(features_extra, swap_axes=True)
        self._opacity = as_param(opacities)
        self._scaling = as_param(scales)
        self._rotation = as_param(rots)
        self._semantic_feature = as_param(semantic_feature, swap_axes=True)
        self.entity_ids = as_param(entity_ids, swap_axes=True)
        self.active_sh_degree = self.max_sh_degree

    def load_ply(self, path):
        if os.path.exists(os.path.join(path, "point_cloud.ply")):
            self.one_file_load_ply(path)
        else:
            self.distributed_load_ply(path)

    def replace_tensor_to_optimizer(self, tensor, name):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            if group["name"] == name:
                stored_state = self.optimizer.state.get(group["params"][0], None)
                if "exp_avg" not in stored_state:
                    stored_state["momentum_buffer"] = torch.zeros_like(tensor)
                else:
                    stored_state["exp_avg"] = torch.zeros_like(tensor)
                    stored_state["exp_avg_sq"] = torch.zeros_like(tensor)

                del self.optimizer.state[group["params"][0]]
                group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
                self.optimizer.state[group["params"][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def _prune_optimizer(self, mask):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            stored_state = self.optimizer.state.get(group["params"][0], None)
            if stored_state is not None:
                if "exp_avg" not in stored_state:
                    stored_state["momentum_buffer"] = stored_state["momentum_buffer"][
                        mask
                    ]
                else:
                    stored_state["exp_avg"] = stored_state["exp_avg"][mask]
                    stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]

                del self.optimizer.state[group["params"][0]]
                group["params"][0] = nn.Parameter(
                    (group["params"][0][mask].requires_grad_(True))
                )
                self.optimizer.state[group["params"][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
            else:
                group["params"][0] = nn.Parameter(
                    group["params"][0][mask].requires_grad_(True)
                )
                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def prune_points(self, mask):
        valid_points_mask = ~mask
        optimizable_tensors = self._prune_optimizer(valid_points_mask)

        self._xyz = optimizable_tensors["xyz"]
        self._features_dc = optimizable_tensors["f_dc"]
        self._features_rest = optimizable_tensors["f_rest"]
        self._opacity = optimizable_tensors["opacity"]
        self._scaling = optimizable_tensors["scaling"]
        self._rotation = optimizable_tensors["rotation"]

        self._semantic_feature = optimizable_tensors["sem_f"]
        self.entity_ids = optimizable_tensors["entity_f"]

        self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]

        self.send_to_gpui_cnt = self.send_to_gpui_cnt[valid_points_mask]

        self.denom = self.denom[valid_points_mask]
        self.max_radii2D = self.max_radii2D[valid_points_mask]
        self.sum_visible_count_in_one_batch = self.sum_visible_count_in_one_batch[
            valid_points_mask
        ]

        self.densify_mask = self.densify_mask[valid_points_mask]

    def cat_tensors_to_optimizer(self, tensors_dict):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            assert len(group["params"]) == 1
            extension_tensor = tensors_dict[group["name"]]
            stored_state = self.optimizer.state.get(group["params"][0], None)
            if stored_state is not None:
                if "exp_avg" not in stored_state:
                    stored_state["momentum_buffer"] = torch.cat(
                        (
                            stored_state["momentum_buffer"],
                            torch.zeros_like(extension_tensor),
                        ),
                        dim=0,
                    )
                else:
                    stored_state["exp_avg"] = torch.cat(
                        (stored_state["exp_avg"], torch.zeros_like(extension_tensor)),
                        dim=0,
                    )
                    stored_state["exp_avg_sq"] = torch.cat(
                        (
                            stored_state["exp_avg_sq"],
                            torch.zeros_like(extension_tensor),
                        ),
                        dim=0,
                    )

                del self.optimizer.state[group["params"][0]]
                group["params"][0] = nn.Parameter(
                    torch.cat(
                        (group["params"][0], extension_tensor), dim=0
                    ).requires_grad_(True)
                )
                self.optimizer.state[group["params"][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
            else:
                group["params"][0] = nn.Parameter(
                    torch.cat(
                        (group["params"][0], extension_tensor), dim=0
                    ).requires_grad_(True)
                )
                optimizable_tensors[group["name"]] = group["params"][0]

        return optimizable_tensors

    def densification_postfix(
            self,
            new_xyz,
            new_features_dc,
            new_features_rest,
            new_opacities,
            new_scaling,
            new_rotation,
            new_semantic_feature,
            new_entity_ids,
            new_send_to_gpui_cnt,
    ):
        """Append freshly densified Gaussians and reset densification stats.

        The new attributes are concatenated onto the optimizer-managed
        parameters via ``cat_tensors_to_optimizer``. The following are then
        re-created as zeros sized to the new point count:
        - ``xyz_gradient_accum`` and ``denom``;
        - ``max_radii2D``;
        - ``sum_visible_count_in_one_batch``.
        ``send_to_gpui_cnt`` is extended with ``new_send_to_gpui_cnt``.
        """
        appended = self.cat_tensors_to_optimizer(
            {
                "xyz": new_xyz,
                "f_dc": new_features_dc,
                "f_rest": new_features_rest,
                "opacity": new_opacities,
                "scaling": new_scaling,
                "rotation": new_rotation,
                "sem_f": new_semantic_feature,
                "entity_f": new_entity_ids,
            }
        )
        self._xyz = appended["xyz"]
        self._features_dc = appended["f_dc"]
        self._features_rest = appended["f_rest"]
        self._opacity = appended["opacity"]
        self._scaling = appended["scaling"]
        self._rotation = appended["rotation"]
        self._semantic_feature = appended["sem_f"]
        self.entity_ids = appended["entity_f"]

        n_points = self.get_xyz.shape[0]
        self.xyz_gradient_accum = torch.zeros((n_points, 1), device="cuda")
        self.denom = torch.zeros((n_points, 1), device="cuda")
        self.max_radii2D = torch.zeros((n_points), device="cuda")
        self.sum_visible_count_in_one_batch = torch.zeros((n_points), device="cuda")

        self.send_to_gpui_cnt = torch.cat(
            (self.send_to_gpui_cnt, new_send_to_gpui_cnt), dim=0
        )

    def densify_and_split(self, grads, grad_threshold, scene_extent, densify_mask, N=2):
        """Split large, high-gradient Gaussians into ``N`` smaller children each.

        A Gaussian is selected when all of the following hold:
        - its (zero-padded) gradient entry is at least ``grad_threshold``;
        - its largest activated scale exceeds
          ``self.percent_dense * scene_extent``;
        - ``self.densify_mask`` is True for it.

        Each selected Gaussian spawns ``N`` children. Their offsets are drawn
        from a zero-mean normal with std equal to the parent's activated
        scale, rotated by ``build_rotation`` of the parent's raw rotation
        (presumably quaternion -> 3x3 matrix; TODO confirm), and added to the
        parent's position. Children get the parent's scale divided by
        ``0.8 * N`` and copies of its other attributes. They are appended via
        ``densification_postfix``, after which the selected parents are
        pruned.

        NOTE(review): the ``densify_mask`` argument is not used; the method
        reads ``self.densify_mask`` instead. ``densify_and_prune`` passes
        ``self.densify_mask`` here, so both are the same tensor in that caller.
        """
        n_init_points = self.get_xyz.shape[0]
        # Extract points that satisfy the gradient condition
        # ``grads`` may be shorter than the current point count (in
        # densify_and_prune it is computed before densify_and_clone appends
        # clones), so pad it with zeros up to n_init_points.
        padded_grad = torch.zeros((n_init_points), device="cuda")
        padded_grad[: grads.shape[0]] = grads.squeeze()
        selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
        selected_pts_mask = torch.logical_and(
            selected_pts_mask,
            torch.max(self.get_scaling, dim=1).values
            > self.percent_dense * scene_extent,
        )

        selected_pts_mask = torch.logical_and(selected_pts_mask, self.densify_mask)
        stds = self.get_scaling[selected_pts_mask].repeat(N, 1)
        means = torch.zeros((stds.size(0), 3), device="cuda")
        samples = torch.normal(mean=means, std=stds)
        # [N * number of selected points, 3]

        utils.get_log_file().write(
            "Number of split gaussians: {}\n".format(selected_pts_mask.sum().item())
        )
        rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1)
        # Rotate each sampled offset into the parent's frame, then translate.
        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[
            selected_pts_mask
        ].repeat(N, 1)
        # Children are smaller than the parent: activated scale / (0.8 * N),
        # mapped back to the raw (pre-activation) parameterization.
        new_scaling = self.scaling_inverse_activation(
            self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8 * N)
        )
        new_rotation = self._rotation[selected_pts_mask].repeat(N, 1)
        new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1)
        new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1)
        new_opacity = self._opacity[selected_pts_mask].repeat(N, 1)

        new_semantic_feature = self._semantic_feature[selected_pts_mask].repeat(N, 1, 1)
        new_entity_ids = self.entity_ids[selected_pts_mask].repeat(N, 1, 1)

        new_send_to_gpui_cnt = self.send_to_gpui_cnt[selected_pts_mask].repeat(N, 1)

        self.densification_postfix(
            new_xyz,
            new_features_dc,
            new_features_rest,
            new_opacity,
            new_scaling,
            new_rotation,
            new_semantic_feature,
            new_entity_ids,
            new_send_to_gpui_cnt,
        )

        # Remove the selected parents; the N * k appended children are kept.
        prune_filter = torch.cat(
            (
                selected_pts_mask,
                torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool),
            )
        )
        # Children are densifiable. The mask must be extended before
        # prune_points indexes it with the (longer) keep-mask.
        self.densify_mask = torch.cat(
            (self.densify_mask, torch.ones(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
        self.prune_points(prune_filter)

    def densify_and_clone(self, grads, grad_threshold, scene_extent, densify_mask):
        """Duplicate small, high-gradient Gaussians.

        A Gaussian is cloned when all of the following hold:
        - the norm of its ``grads`` row is at least ``grad_threshold``;
        - its largest activated scale is at most
          ``self.percent_dense * scene_extent``;
        - ``densify_mask`` is True for it.

        Clones copy every raw attribute and are appended through
        ``densification_postfix``. ``self.densify_mask`` is extended with
        True for each clone.
        """
        grad_ok = torch.norm(grads, dim=-1) >= grad_threshold
        small_enough = (
            torch.max(self.get_scaling, dim=1).values
            <= self.percent_dense * scene_extent
        )
        chosen = torch.logical_and(torch.logical_and(grad_ok, small_enough), densify_mask)

        n_clones = chosen.sum().item()
        utils.get_log_file().write(
            "Number of cloned gaussians: {}\n".format(n_clones)
        )

        self.densification_postfix(
            self._xyz[chosen],
            self._features_dc[chosen],
            self._features_rest[chosen],
            self._opacity[chosen],
            self._scaling[chosen],
            self._rotation[chosen],
            self._semantic_feature[chosen],
            self.entity_ids[chosen],
            self.send_to_gpui_cnt[chosen],
        )
        self.densify_mask = torch.cat(
            (self.densify_mask, torch.ones(n_clones, device="cuda", dtype=bool))
        )

    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size, total_gs_num, densify_mask=None,
                          max_num=2000_000, skip_densify=False):
        """Run one densification (clone + split) and pruning step.

        Args:
            max_grad: gradient threshold passed to clone/split.
            min_opacity: Gaussians with activated opacity below this are pruned.
            extent: scene extent used by the size criteria.
            max_screen_size: when truthy, also prune Gaussians whose
                ``max_radii2D`` exceeds it or whose largest scale exceeds
                ``0.1 * extent``, restricted to ``densify_mask``.
            total_gs_num: current total Gaussian count; clone/split only run
                while it is ``<= max_num``.
            densify_mask: bool mask of Gaussians eligible for densification.
                ``None`` means every Gaussian is eligible. Previously ``None``
                always crashed in clone/split or in ``prune_points``.
            max_num: cap on ``total_gs_num`` for densification.
            skip_densify: skip clone/split but still prune.
        """
        args = utils.get_args()
        if densify_mask is None:
            densify_mask = torch.ones(
                self.get_xyz.shape[0], dtype=torch.bool, device=self.get_xyz.device
            )
        self.densify_mask = densify_mask
        if not args.gaussians_distribution and utils.DEFAULT_GROUP.size() > 1:
            # Gaussians are not sharded here, so merge the per-rank statistics
            # over the DP group before thresholding.
            torch.distributed.all_reduce(
                self.max_radii2D, op=dist.ReduceOp.MAX, group=utils.DP_GROUP
            )
            torch.distributed.all_reduce(
                self.xyz_gradient_accum, op=dist.ReduceOp.SUM, group=utils.DP_GROUP
            )
            torch.distributed.all_reduce(
                self.denom, op=dist.ReduceOp.SUM, group=utils.DP_GROUP
            )

        # Mean accumulated gradient; 0/0 for never-updated points becomes 0.
        grads = self.xyz_gradient_accum / self.denom
        grads[grads.isnan()] = 0.0

        if total_gs_num <= max_num and not skip_densify:
            self.densify_and_clone(grads, max_grad, extent, self.densify_mask)
            self.densify_and_split(grads, max_grad, extent, self.densify_mask)

        prune_mask = (self.get_opacity < min_opacity).squeeze(-1)
        if max_screen_size:
            big_points_vs = self.max_radii2D > max_screen_size
            if total_gs_num <= max_num:
                # NOTE: per the assert messages, max_radii2D is expected to be
                # all zeros here, so the screen-space criterion never fires.
                # Within this chunk, densification_postfix resets it to zeros
                # and the stats updaters never write it.
                assert torch.all(
                    self.max_radii2D == 0
                ), "In its implementation, max_radii2D is all 0. This is a bug."
                assert torch.all(
                    big_points_vs == False
                ), "In its implementation, big_points_vs is all False. This is a bug."
            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
            prune_mask = torch.logical_or(
                torch.logical_or(prune_mask, big_points_vs), big_points_ws
            )
            prune_mask = torch.logical_and(prune_mask, self.densify_mask)
        self.prune_points(prune_mask)

        torch.cuda.empty_cache()

    def add_densification_stats_v2(
            self, viewspace_point_tensor, gs_update_filter, mean2d_update_filter
    ):  # the :2] is a weird implementation. It is because viewspace_point_tensor is (N, 3) tensor.
        self.xyz_gradient_accum[gs_update_filter] += torch.norm(
            viewspace_point_tensor.grad[mean2d_update_filter, :2], dim=-1, keepdim=True
        )
        self.denom[gs_update_filter] += 1

    def add_densification_stats(
            self, viewspace_point_tensor, update_filter
    ):  # the :2] is a weird implementation. It is because viewspace_point_tensor is (N, 3) tensor.
        self.xyz_gradient_accum[update_filter] += torch.norm(
            viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True
        )
        self.denom[update_filter] += 1

    def gsplat_add_densification_stats(
            self, viewspace_point_tensor_grad, update_filter, width, height
    ):  # the :2] is a weird implementation. It is because viewspace_point_tensor is (N, 3) tensor.
        grad = viewspace_point_tensor_grad  # (N, 2)
        # Normalize the gradients to [-1, 1] screen size
        grad[:, 0] *= width * 0.5
        grad[:, 1] *= height * 0.5
        self.xyz_gradient_accum[update_filter] += torch.norm(
            grad[update_filter, :2], dim=-1, keepdim=True
        )
        self.denom[update_filter] += 1

    def group_for_redistribution(self):
        """Return the process group used to redistribute Gaussians.

        ``utils.DEFAULT_GROUP`` when ``args.gaussians_distribution`` is set;
        otherwise a new ``utils.SingleGPUGroup()``.
        """
        sharded = utils.get_args().gaussians_distribution
        return utils.DEFAULT_GROUP if sharded else utils.SingleGPUGroup()

    def all2all_gaussian_state(self, state, destination, i2j_send_size):
        """Exchange rows of a per-Gaussian tensor between ranks via all_to_all.

        Args:
            state: tensor whose first axis indexes this rank's Gaussians.
            destination: per-row target rank; rows with ``destination == j``
                are sent to rank ``j``.
            i2j_send_size: ``i2j_send_size[j][r]`` is read as the number of
                rows rank ``j`` sends to rank ``r``. This rank uses column
                ``rank()`` to size its receive buffers.

        Returns:
            Rows received from ranks ``0..size-1``, concatenated along dim 0.
        """
        comm_group = self.group_for_redistribution()
        world = comm_group.size()
        my_rank = comm_group.rank()

        send_buffers = [
            state[destination == j, ...].contiguous() for j in range(world)
        ]
        # Receive buffers must share the sender's dtype/device: all_to_all
        # performs no conversion, so a hard-coded float32 buffer breaks for
        # non-float32 states (e.g. integer counters or half precision).
        recv_buffers = [
            torch.zeros(
                (i2j_send_size[j][my_rank], *state.shape[1:]),
                dtype=state.dtype,
                device=state.device,
            )
            for j in range(world)
        ]

        torch.distributed.all_to_all(recv_buffers, send_buffers, group=comm_group)

        return torch.cat(recv_buffers, dim=0).contiguous()

    def all2all_tensors_in_optimizer_implementation_1(self, destination, i2j_send_size):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            assert len(group["params"]) == 1
            stored_state = self.optimizer.state.get(group["params"][0], None)
            if stored_state is not None:

                if "exp_avg" not in stored_state:
                    stored_state["momentum_buffer"] = self.all2all_gaussian_state(
                        stored_state["momentum_buffer"], destination, i2j_send_size
                    )
                else:
                    stored_state["exp_avg"] = self.all2all_gaussian_state(
                        stored_state["exp_avg"], destination, i2j_send_size
                    )
                    stored_state["exp_avg_sq"] = self.all2all_gaussian_state(
                        stored_state["exp_avg_sq"], destination, i2j_send_size
                    )

                del self.optimizer.state[group["params"][0]]
                group["params"][0] = nn.Parameter(
                    self.all2all_gaussian_state(
                        group["params"][0], destination, i2j_send_size
                    ),
                    requires_grad=True,
                )
                self.optimizer.state[group["params"][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
            else:
                group["params"][0] = nn.Parameter(
                    self.all2all_gaussian_state(
                        group["params"][0], destination, i2j_send_size
                    ),
                    requires_grad=True,
                )
                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def get_all_optimizer_states(self):
        all_tensors = []
        all_shapes = []
        for group in self.optimizer.param_groups:
            assert len(group["params"]) == 1
            stored_state = self.optimizer.state.get(group["params"][0], None)
            if stored_state is not None:
                if "exp_avg" not in stored_state:
                    all_tensors.append(stored_state["momentum_buffer"])
                    all_shapes.append(stored_state["momentum_buffer"].shape)
                else:
                    all_tensors.append(stored_state["exp_avg"])
                    all_shapes.append(stored_state["exp_avg"].shape)

                    all_tensors.append(stored_state["exp_avg_sq"])
                    all_shapes.append(stored_state["exp_avg_sq"].shape)

                all_tensors.append(group["params"][0])
                all_shapes.append(group["params"][0].shape)
            else:
                all_tensors.append(group["params"][0])
                all_shapes.append(group["params"][0].shape)

        return all_tensors, all_shapes

    def update_all_optimizer_states(self, updated_tensors):
        """Install new optimizer states and parameters from a flat list.

        ``updated_tensors`` must be laid out exactly as produced by
        ``get_all_optimizer_states``. Tensors are popped from the front of the
        list, so the caller's list is emptied in place. Each tensor is made
        contiguous and then installed:
        - state tensors go back into the group's stored optimizer state;
        - each parameter is re-wrapped in a fresh ``nn.Parameter`` that replaces
          the old one in its param group.

        Stored state is re-keyed from the old parameter object to the new one.

        Returns:
            dict mapping each param group's ``name`` to its new parameter.
        """

        def take_next():
            return updated_tensors.pop(0).contiguous()

        new_params = {}
        for group in self.optimizer.param_groups:
            assert len(group["params"]) == 1
            old_param = group["params"][0]
            state = self.optimizer.state.get(old_param, None)
            if state is None:
                group["params"][0] = nn.Parameter(take_next(), requires_grad=True)
            else:
                if "exp_avg" in state:
                    state["exp_avg"] = take_next()
                    state["exp_avg_sq"] = take_next()
                else:
                    state["momentum_buffer"] = take_next()
                # The optimizer keys its state by parameter object; move the
                # state entry over to the replacement parameter.
                del self.optimizer.state[old_param]
                group["params"][0] = nn.Parameter(take_next(), requires_grad=True)
                self.optimizer.state[group["params"][0]] = state
            new_params[group["name"]] = group["params"][0]
        return new_params

    def all2all_tensors_in_optimizer_implementation_2(self, destination, i2j_send_size):
        """Redistribute all parameters and optimizer states with ONE all2all.

        Steps:
        1. Flatten every tensor from ``get_all_optimizer_states`` to 2-D
           (rows kept, trailing dims merged).
        2. Concatenate them column-wise into a single buffer.
        3. Exchange the buffer through one ``all2all_gaussian_state`` call
           (a single kernel launch instead of one per tensor).
        4. Split the result and restore each tensor's original trailing shape.
        5. Install the tensors with ``update_all_optimizer_states``.

        Args:
            destination: forwarded unchanged to ``all2all_gaussian_state``.
            i2j_send_size: forwarded unchanged to ``all2all_gaussian_state``.

        Returns:
            dict of param-group name -> new ``nn.Parameter``.
        """
        states, shapes = self.get_all_optimizer_states()
        # Number of columns each tensor occupies in the packed buffer.
        widths = [shape[1:].numel() for shape in shapes]

        packed = torch.cat(
            [state.flatten(start_dim=1) for state in states], dim=1
        ).contiguous()

        received = self.all2all_gaussian_state(packed, destination, i2j_send_size)
        del packed  # release the send buffer before unpacking

        pieces = torch.split(received, widths, dim=1)
        del received
        unpacked = [
            piece.view(piece.shape[:1] + shape[1:])
            for piece, shape in zip(pieces, shapes)
        ]
        del pieces

        return self.update_all_optimizer_states(unpacked)

    def all2all_tensors_in_optimizer(self, destination, i2j_send_size):
        """Redistribute parameters and optimizer states across ranks.

        Always dispatches to ``all2all_tensors_in_optimizer_implementation_1``.

        ``implementation_2`` fuses everything into one all2all but is not used.
        The original author's note: with cross-node all2all on perl,
        implementation_2 got stuck at 1600 iterations, reason unknown.
        """
        impl = self.all2all_tensors_in_optimizer_implementation_1
        return impl(destination, i2j_send_size)

    def get_destination_1(self, world_size):
        """Draw a uniformly random destination rank for every local Gaussian.

        Returns:
            CUDA int64 tensor with one entry per row of ``get_xyz``; each entry
            is in ``[0, world_size)``.
        """
        # norm p=0
        num_points = self.get_xyz.shape[0]
        return torch.randint(low=0, high=world_size, size=(num_points,), device="cuda")

    def need_redistribute_gaussians(self, group):
        """Decide whether Gaussians should be rebalanced across ``group``.

        Returns:
            False for a single-rank group.
            True when ``utils.get_denfify_iter()`` equals
            ``args.redistribute_gaussians_frequency`` (redistribute after the
            first densification).
            Otherwise True iff ``min(counts) * redistribute_gaussians_threshold
            < max(counts)``, where ``counts`` are the local Gaussian counts
            gathered from every rank.

        The ``all_gather_object`` is a collective, so presumably every rank of
        ``group`` reaches it together. Confirm against callers.
        """
        args = utils.get_args()
        world = group.size()
        if world == 1:
            return False
        if utils.get_denfify_iter() == args.redistribute_gaussians_frequency:
            return True

        counts = [None] * world
        torch.distributed.all_gather_object(counts, self.get_xyz.shape[0], group=group)
        return max(counts) > min(counts) * args.redistribute_gaussians_threshold

    def redistribute_gaussians(self):
        """Rebalance locally stored Gaussians across the redistribution group.

        Returns without doing anything when ``args.redistribute_gaussians_mode``
        is ``"no_redistribute"`` or when ``need_redistribute_gaussians`` says no.

        Otherwise:
        1. Pick a destination rank per Gaussian. Only ``"random_redistribute"``
           is supported; any other mode raises ``ValueError``.
        2. All-gather per-rank send counts.
        3. Move parameters plus optimizer state with
           ``all2all_tensors_in_optimizer``.
        4. Re-allocate the per-Gaussian bookkeeping buffers at the new local size.

        Note: this issues collectives, so presumably every rank in the group
        must call it at the same point. Confirm against callers.
        """
        args = utils.get_args()
        if args.redistribute_gaussians_mode == "no_redistribute":
            return

        comm_group_for_redistribution = self.group_for_redistribution()
        if not self.need_redistribute_gaussians(comm_group_for_redistribution):
            return

        # Get each 3dgs' destination GPU.
        if args.redistribute_gaussians_mode == "random_redistribute":
            # random redistribution to balance the number of gaussians on each GPU.
            destination = self.get_destination_1(comm_group_for_redistribution.size())
        else:
            raise ValueError(
                "Invalid redistribute_gaussians_mode: "
                + args.redistribute_gaussians_mode
            )

        # Count the number of 3dgs to be sent to each GPU.
        # local2j_send_size[j] = number of local Gaussians whose destination is rank j.
        local2j_send_size = torch.bincount(
            destination, minlength=comm_group_for_redistribution.size()
        ).int()
        assert (
                len(local2j_send_size) == comm_group_for_redistribution.size()
        ), "local2j_send_size: " + str(local2j_send_size)

        # Stacked all-gather: row i of the (W, W) output is rank i's
        # local2j_send_size, so i2j_send_size[i][j] = #Gaussians rank i sends to rank j.
        i2j_send_size = torch.zeros(
            (
                comm_group_for_redistribution.size(),
                comm_group_for_redistribution.size(),
            ),
            dtype=torch.int,
            device="cuda",
        )
        torch.distributed.all_gather_into_tensor(
            i2j_send_size, local2j_send_size, group=comm_group_for_redistribution
        )
        # Converted to a nested Python list before being handed to the all2all helpers.
        i2j_send_size = i2j_send_size.cpu().numpy().tolist()
        # print("rank", utils.LOCAL_RANK, "local2j_send_size: ", local2j_send_size, "i2j_send_size: ", i2j_send_size)

        # Move parameters and their optimizer state; rebind the model's tensors
        # to the new (re-sized) parameters returned per param-group name.
        optimizable_tensors = self.all2all_tensors_in_optimizer(
            destination, i2j_send_size
        )
        self._xyz = optimizable_tensors["xyz"]
        self._features_dc = optimizable_tensors["f_dc"]
        self._features_rest = optimizable_tensors["f_rest"]
        self._opacity = optimizable_tensors["opacity"]
        self._scaling = optimizable_tensors["scaling"]
        self._rotation = optimizable_tensors["rotation"]
        self._semantic_feature = optimizable_tensors["sem_f"]
        self.entity_ids = optimizable_tensors["entity_f"]

        # Per-Gaussian statistics are re-created (zeroed) at the new local size.
        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
        self.sum_visible_count_in_one_batch = torch.zeros(
            (self.get_xyz.shape[0]), device="cuda"
        )
        # NOTE: This function is called right after desify_and_prune. Therefore self.xyz_gradient_accum, self.denom and self.max_radii2D are all zero.
        # We do not need to all2all them here.

        self.send_to_gpui_cnt = torch.zeros(
            (self.get_xyz.shape[0], comm_group_for_redistribution.size()),
            dtype=torch.int,
            device="cuda",
        )

        # Return cached blocks freed by the reallocation above to the CUDA driver.
        torch.cuda.empty_cache()


def quaternion_multiply(q1, q2):
    """Return the Hamilton product ``q1 * q2`` of quaternions laid out as (w, x, y, z).

    Components are read from indices 0..3 of the last axis, and the arithmetic
    is elementwise, so leading dimensions broadcast. The product is returned
    with its (w, x, y, z) components stacked on a new last axis.
    """
    a_w, a_x, a_y, a_z = (q1[..., k] for k in range(4))
    b_w, b_x, b_y, b_z = (q2[..., k] for k in range(4))

    real = a_w * b_w - a_x * b_x - a_y * b_y - a_z * b_z
    i_part = a_w * b_x + a_x * b_w + a_y * b_z - a_z * b_y
    j_part = a_w * b_y - a_x * b_z + a_y * b_w + a_z * b_x
    k_part = a_w * b_z + a_x * b_y - a_y * b_x + a_z * b_w

    return torch.stack([real, i_part, j_part, k_part], dim=-1)


def get_sparse_ids(tensors):
    """Return the sorted, de-duplicated row indices that hold any non-zero entry.

    Row ``r`` is included when at least one tensor in ``tensors`` has a
    non-zero element whose first-axis index is ``r``.

    Args:
        tensors: iterable of tensors on the same device (they are concatenated).

    Returns:
        1-D int64 tensor of unique row indices in ascending order. An empty
        ``tensors`` yields an empty CPU int64 tensor; previously this raised,
        because ``torch.unique`` was called on ``None``.
    """
    with torch.no_grad():
        # torch.nonzero yields one coordinate row per non-zero element;
        # column 0 is the first-axis (row) index.
        row_ids = [torch.nonzero(tensor)[:, 0] for tensor in tensors]
        if not row_ids:
            return torch.empty(0, dtype=torch.long)
        # Concatenate once instead of re-concatenating inside the loop.
        return torch.unique(torch.cat(row_ids), sorted=True)


def sync_gradients_sparsely(gaussians, group):
    """All-reduce (SUM) gradients only on rows that are non-zero on some rank.

    Steps:
    1. Each rank marks the rows whose ``_xyz`` gradient has any non-zero entry.
    2. The marks are combined across ``group``.
    3. Every listed parameter's gradient rows under the combined mask are
       summed across ``group`` in place. Rows outside the mask keep their
       local gradient.
    4. The fraction of synced rows is appended to the log file.

    Fix: the ratio is guarded against a zero row count, which previously
    raised ZeroDivisionError for a rank holding no Gaussians.

    NOTE(review): the mask is derived from ``_xyz`` gradients only. A row whose
    ``_xyz`` gradient is zero on every rank but whose other gradients are not
    would not be synced. Confirm this cannot happen in practice.
    """
    with torch.no_grad():
        # Rows with a non-zero _xyz gradient on this rank.
        sparse_ids = get_sparse_ids([gaussians._xyz.grad.data])
        sparse_ids_mask = torch.zeros(
            (gaussians._xyz.shape[0]), dtype=torch.bool, device="cuda"
        )
        sparse_ids_mask[sparse_ids] = True

        # A bool SUM behaves as a logical OR of the per-rank masks
        # (PyTorch's NCCL backend maps SUM on bool to MAX).
        torch.distributed.all_reduce(sparse_ids_mask, op=dist.ReduceOp.SUM, group=group)

        def sync_grads(data):
            # contiguous() memory is needed for collective communication.
            sparse_grads = data.grad.data[sparse_ids_mask].contiguous()
            torch.distributed.all_reduce(
                sparse_grads, op=dist.ReduceOp.SUM, group=group
            )
            data.grad.data[sparse_ids_mask] = sparse_grads

        # One all_reduce per parameter; a fused variant would avoid the extra
        # kernel-launch overhead.
        sync_grads(gaussians._xyz)
        sync_grads(gaussians._features_dc)
        sync_grads(gaussians._features_rest)
        sync_grads(gaussians._opacity)
        sync_grads(gaussians._scaling)
        sync_grads(gaussians._rotation)
        sync_grads(gaussians._semantic_feature)
        sync_grads(gaussians.entity_ids)

    log_file = utils.get_log_file()
    non_zero_indices_cnt = sparse_ids_mask.sum().item()
    total_indices_cnt = sparse_ids_mask.shape[0]
    ratio = non_zero_indices_cnt / total_indices_cnt if total_indices_cnt > 0 else 0.0
    log_file.write(
        "iterations: [{}, {}) non_zero_indices_cnt: {} total_indices_cnt: {} ratio: {}\n".format(
            utils.get_cur_iter(),
            utils.get_cur_iter() + utils.get_args().bsz,
            non_zero_indices_cnt,
            total_indices_cnt,
            ratio,
        )
    )


def sync_gradients_densely(gaussians, group):
    """Sum every Gaussian parameter's full gradient across ``group``, in place.

    Issues one ``all_reduce`` (SUM) per parameter, in a fixed order that all
    ranks share.
    """
    param_names = (
        "_xyz",
        "_features_dc",
        "_features_rest",
        "_opacity",
        "_scaling",
        "_rotation",
        "_semantic_feature",
        "entity_ids",
    )
    with torch.no_grad():
        for name in param_names:
            grad = getattr(gaussians, name).grad.data
            torch.distributed.all_reduce(grad, op=dist.ReduceOp.SUM, group=group)


def sync_gradients_fused_densely(gaussians, group):
    """Sum all Gaussian gradients across ``group`` with a single all_reduce.

    Steps:
    1. View every gradient as 2-D ``(rows, -1)`` and concatenate column-wise
       into one buffer.
    2. all_reduce (SUM) the buffer.
    3. Split it back and copy each slice into its gradient, reshaped to the
       gradient's original shape.

    Fix: the trailing dimensions are now flattened before concatenating.
    The gradients do not all have the same rank (e.g. a 2-D ``_xyz`` next to
    a 3-D ``_features_dc``), and ``torch.cat`` raises on such inputs. All
    gradients must share the same first dimension.
    """
    with torch.no_grad():
        grads = [
            param.grad.data
            for param in (
                gaussians._xyz,
                gaussians._features_dc,
                gaussians._features_rest,
                gaussians._opacity,
                gaussians._scaling,
                gaussians._rotation,
                gaussians._semantic_feature,
                gaussians.entity_ids,
            )
        ]
        flat_grads = [grad.reshape(grad.shape[0], -1) for grad in grads]
        widths = [flat.shape[1] for flat in flat_grads]
        fused = torch.cat(flat_grads, dim=1).contiguous()
        torch.distributed.all_reduce(fused, op=dist.ReduceOp.SUM, group=group)
        for grad, chunk in zip(grads, torch.split(fused, widths, dim=1)):
            grad.copy_(chunk.reshape(grad.shape))


def sync_gradients_fused_sparsely(gaussians, group):
    """Placeholder for a fused sparse gradient sync.

    Raises:
        NotImplementedError: always.
    """
    message = "Fused sparse sync gradients is not implemented yet."
    raise NotImplementedError(message)


class ControlGaussianModel:

    def setup_functions(self):
        """Bind the activation / inverse-activation callables used by this model.

        - scaling: ``torch.exp`` / ``torch.log``
        - opacity: ``torch.sigmoid`` / ``inverse_sigmoid``
        - entity: ``torch.sigmoid`` / ``inverse_sigmoid``
        - rotation: ``torch.nn.functional.normalize``
        - covariance: ``L @ L^T`` with ``L`` from ``build_scaling_rotation``,
          reduced by ``strip_symmetric``
        """

        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
            transform = build_scaling_rotation(scaling_modifier * scaling, rotation)
            covariance = transform @ transform.transpose(1, 2)
            return strip_symmetric(covariance)

        self.rotation_activation = torch.nn.functional.normalize
        self.covariance_activation = build_covariance_from_scaling_rotation

        self.scaling_activation, self.scaling_inverse_activation = torch.exp, torch.log
        self.opacity_activation, self.inverse_opacity_activation = torch.sigmoid, inverse_sigmoid
        self.entity_activation, self.inverse_entity_activation = torch.sigmoid, inverse_sigmoid

    def __init__(self, sh_degree: int, semantic_feature_dim: int, with_motion_mask=False, vis_control_gs=False):
        """Create an empty control-point model; tensors are populated later.

        Args:
            sh_degree: stored as ``max_sh_degree``.
            semantic_feature_dim: stored as ``semantic_feature_dim``.
            with_motion_mask: accepted but not used by this constructor.
            vis_control_gs: when True, also allocates placeholders for the
                feature, scaling, rotation, opacity and semantic-feature
                attributes. When False, those attributes are not created.
        """
        self.max_sh_degree = sh_degree
        self.active_sh_degree = 0

        self.vis_control_gs = vis_control_gs
        self._xyz = torch.empty(0)
        if vis_control_gs:
            for attr_name in (
                "_features_dc",
                "_features_rest",
                "_scaling",
                "_rotation",
                "_opacity",
                "_semantic_feature",
            ):
                setattr(self, attr_name, torch.empty(0))

        # Training / densification bookkeeping.
        self.max_radii2D = torch.empty(0)
        self.xyz_gradient_accum = torch.empty(0)
        self.denom = torch.empty(0)
        self.optimizer = None
        self.percent_dense = 0
        self.spatial_lr_scale = 0

        # Entity bookkeeping.
        self.entity_ids = torch.empty(0)
        self.entity_cls = torch.empty(0)
        self.entity_cls_num = torch.empty(0)
        self.is_control_init = False
        self.is_entity_init = False
        self.semantic_feature_dim = semantic_feature_dim

        self._node_radius = torch.empty(0)
        self._node_weight = torch.empty(0)

        self.setup_functions()

    def entity_init(self):
        self.gs_binary_ids = torch.round(self.get_entity).to(torch.long)
        if self.use_truncated_binary:
            self.entity_cls = entity_binary_convert_tensor(
                torch.arange(0, self.entity_cls_num).unsqueeze(0).unsqueeze(0).cuda()).permute(2, 1, 0)
        else:
            self.entity_cls = torch.unique(self.gs_binary_ids, dim=0)
        self.entity_cls_num = self.entity_cls.shape[0]
        self.is_entity_init = True

    def update_sem_centroid(self):
        self.entity_init()
        self.semantic_centroid = torch.zeros([self.entity_cls_num, self._semantic_feature.shape[-1]],
                                             device=self._xyz.device)
        # self.xyz_centroid = torch.zeros([self.entity_cls_num, 3])
        for i in range(self.entity_cls_num):
            cur_entity_id = self.entity_cls[i]
            row_index = torch.where(torch.eq(self.gs_binary_ids, cur_entity_id).squeeze(1).all(dim=1))[0]
            self.semantic_centroid[i, :] = torch.mean(self._semantic_feature[row_index], dim=0)
            # self.xyz_centroid[i, :] = torch.mean(self._xyz[row_index], dim=0)

    def semantic_centroid(self):
        self.entity_init()
        self.semantic_centroid = torch.zeros([self.entity_cls_num, self._semantic_feature.shape[-1]])
        for i in range(self.entity_cls_num):
            cur_entity_id = self.eneity_cls[i]
            row_index = torch.where(torch.eq(self.gs_binary_ids, cur_entity_id).all(dim=1))[0]
            self.semantic_centroid[i, :] = torch.mean(self._semantic_feature[row_index], dim=0)

    def entity_extract_mask(self, entity_id):
        entity_mask = torch.where(self.entity_ids == entity_id)[0]
        return entity_mask

    def capture(self):
        """Snapshot the model state for checkpointing.

        Layout when ``vis_control_gs`` is True (16 entries):
            (active_sh_degree, _xyz, _features_dc, _features_rest, _scaling,
             _rotation, _opacity, max_radii2D, _semantic_feature, entity_ids,
             xyz_gradient_accum, denom, optimizer_state_dict, spatial_lr_scale,
             _node_radius, _node_weight)

        Layout when ``vis_control_gs`` is False (10 entries):
            (active_sh_degree, _xyz, max_radii2D, entity_ids,
             xyz_gradient_accum, denom, optimizer_state_dict, spatial_lr_scale,
             _node_radius, _node_weight)

        ``_node_radius`` and ``_node_weight`` have their own optimizer groups
        (see ``training_setup``) but used to be omitted, so a restored model
        lost them. They are appended at the end, so every existing field keeps
        its index.
        """
        if self.vis_control_gs:
            snapshot = (
                self.active_sh_degree,
                self._xyz,
                self._features_dc,
                self._features_rest,
                self._scaling,
                self._rotation,
                self._opacity,
                self.max_radii2D,
                self._semantic_feature,
                self.entity_ids,
                self.xyz_gradient_accum,
                self.denom,
                self.optimizer.state_dict(),
                self.spatial_lr_scale,
            )
        else:
            snapshot = (
                self.active_sh_degree,
                self._xyz,
                self.max_radii2D,
                self.entity_ids,
                self.xyz_gradient_accum,
                self.denom,
                self.optimizer.state_dict(),
                self.spatial_lr_scale,
            )
        return snapshot + (self._node_radius, self._node_weight)

    def restore(self, model_args, training_args):
        """Load a snapshot produced by ``capture`` and rebuild the optimizer.

        Accepts both the current layout (with trailing ``_node_radius`` and
        ``_node_weight``) and the older layout without them. For older
        snapshots, the node tensors currently on ``self`` are left untouched.

        After unpacking:
        1. re-derive entity classes (``entity_init``);
        2. rebuild the optimizer (``training_setup``);
        3. restore the gradient statistics;
        4. load the optimizer state when it is not None.

        Raises:
            ValueError: if ``model_args`` has an unexpected length.
        """
        model_args = tuple(model_args)
        n_core = 14 if self.vis_control_gs else 8
        core, extra = model_args[:n_core], model_args[n_core:]
        if self.vis_control_gs:
            (
                self.active_sh_degree,
                self._xyz,
                self._features_dc,
                self._features_rest,
                self._scaling,
                self._rotation,
                self._opacity,
                self.max_radii2D,
                self._semantic_feature,
                self.entity_ids,
                xyz_gradient_accum,
                denom,
                opt_dict,
                self.spatial_lr_scale,
            ) = core
        else:
            (
                self.active_sh_degree,
                self._xyz,
                self.max_radii2D,
                self.entity_ids,
                xyz_gradient_accum,
                denom,
                opt_dict,
                self.spatial_lr_scale,
            ) = core
        if extra:
            # Newer snapshots: node parameters must be in place before
            # training_setup builds their optimizer groups.
            self._node_radius, self._node_weight = extra
        self.entity_init()
        self.training_setup(training_args)
        self.xyz_gradient_accum = xyz_gradient_accum
        self.denom = denom
        if opt_dict is not None:
            self.optimizer.load_state_dict(opt_dict)

    def param_names(self):
        if self.vis_control_gs:
            return ['_xyz', '_features_dc', '_features_rest', '_scaling', '_rotation', '_opcaity', 'max_radii2D',
                    'xyz_gradient_accum', 'entity_ids', '_node_radius', '_node_weight']
        else:
            return ['_xyz', 'max_radii2D', 'xyz_gradient_accum', 'entity_ids', '_node_radius', '_node_weight']

    @property
    def get_entity(self):
        return self.entity_activation(self.entity_ids)

    @property
    def get_scaling(self):
        return self.scaling_activation(self._scaling)

    @property
    def get_rotation(self):
        return self.rotation_activation(self._rotation)

    @property
    def get_xyz(self):
        return self._xyz

    @property
    def get_features(self):
        features_dc = self._features_dc
        features_rest = self._features_rest
        return torch.cat((features_dc, features_rest), dim=1)

    @property
    def get_opacity(self):
        return self.opacity_activation(self._opacity)

    @property
    def get_semantic_feature(self):
        return self._semantic_feature

    def get_covariance(self, scaling_modifier=1, d_rotation=None, gs_rot_bias=None):
        if d_rotation is not None:
            rotation = quaternion_multiply(self._rotation, d_rotation)
        else:
            rotation = self._rotation
        if gs_rot_bias is not None:
            rotation = rotation / rotation.norm(dim=-1, keepdim=True)
            rotation = quaternion_multiply(gs_rot_bias, rotation)
        return self.covariance_activation(self.get_scaling, scaling_modifier, rotation)

    def oneupSHdegree(self):
        if self.active_sh_degree < self.max_sh_degree:
            self.active_sh_degree += 1

    def get_rotation_bias(self, rotation_bias=None, gs_detach=False, entity_index=None):
        rotation_bias = rotation_bias if rotation_bias is not None else 0.
        if not gs_detach:
            if entity_index is not None:
                return self.rotation_activation(self._rotation.index_add(0, entity_index, rotation_bias))
            else:
                return self.rotation_activation(self._rotation)
        else:
            if entity_index is not None:
                return self.rotation_activation(self._rotation.detach().index_add(0, entity_index, rotation_bias))
            else:
                return self.rotation_activation(self._rotation.detach() + rotation_bias)

    # def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float):
    def create_from_scnene_gs(self, all_xyz, spatial_lr_scale, control_gs_indices, control_gs_entity_idxs, scene_range,
                              extra_attr):
        """Initialize control points by selecting a subset of scene Gaussians.

        Args:
            all_xyz: positions of all scene Gaussians; the rows selected by
                ``control_gs_indices`` become the control points.
            spatial_lr_scale: stored for the position learning-rate schedule.
            control_gs_indices: indexer applied to ``all_xyz`` (and to every
                tensor in ``extra_attr``).
            control_gs_entity_idxs: per-control-point entity values. Must be a
                floating-point tensor, because it becomes a trainable parameter.
            scene_range: initializes ``_node_radius`` to
                ``log(0.1 * scene_range + 1e-7)``. Presumably a tensor, since it
                goes through ``torch.log``.
            extra_attr: ``None`` or the tuple (feature_dc, feature_rest,
                semantic_feature, scaling, rotation, opacity) for all scene
                Gaussians.

        When ``args.gaussians_distribution`` is set, each rank keeps only its
        contiguous chunk (``utils.get_local_chunk_l_r``) of the control points.

        Fix: tensors are moved to CUDA *before* being wrapped in
        ``nn.Parameter``. ``nn.Parameter(t).to("cuda")`` returns a plain,
        non-leaf tensor whenever ``t`` is not already on the GPU, and
        ``torch.optim`` rejects non-leaf tensors. Inputs already on CUDA behave
        exactly as before.
        """

        def as_cuda_param(tensor):
            # Move first, then wrap, so the leaf Parameter itself lives on the GPU.
            return nn.Parameter(tensor.to("cuda").contiguous().requires_grad_(True))

        log_file = utils.get_log_file()
        # loading could replicated on all ranks.
        self.spatial_lr_scale = spatial_lr_scale
        control_gs_xyz = all_xyz[control_gs_indices].clone()
        control_extra = None
        if extra_attr is not None:
            all_feature_dc, all_feature_rest, all_semantic_feature, all_scaling, all_rotation, all_opacity = extra_attr
            # Order: feature_dc, feature_rest, semantic_feature, scaling, rotation, opacity.
            control_extra = [
                attr[control_gs_indices].clone()
                for attr in (all_feature_dc, all_feature_rest, all_semantic_feature,
                             all_scaling, all_rotation, all_opacity)
            ]

        if utils.GLOBAL_RANK == 0:
            print(
                "Number of control_points before initialization : ", all_xyz.shape[0]
            )

        args = utils.get_args()
        if args.gaussians_distribution:
            # Shard 3dgs storage across all GPUs, including dp and mp groups.
            shard_world_size = utils.DEFAULT_GROUP.size()
            shard_rank = utils.DEFAULT_GROUP.rank()

            point_ind_l, point_ind_r = utils.get_local_chunk_l_r(
                control_gs_xyz.shape[0], shard_world_size, shard_rank
            )
            control_gs_xyz = control_gs_xyz[point_ind_l:point_ind_r].contiguous()
            control_gs_entity_idxs = control_gs_entity_idxs[point_ind_l:point_ind_r].contiguous()
            if control_extra is not None:
                control_extra = [t[point_ind_l:point_ind_r].contiguous() for t in control_extra]
            log_file.write(
                "rank: {}, Number of initialized points: {}\n".format(
                    utils.GLOBAL_RANK, control_gs_xyz.shape[0]
                )
            )

        self._xyz = as_cuda_param(control_gs_xyz)
        self.entity_ids = as_cuda_param(control_gs_entity_idxs.clone())
        if control_extra is not None:
            (
                self._features_dc,
                self._features_rest,
                self._semantic_feature,
                self._scaling,
                self._rotation,
                self._opacity,
            ) = [as_cuda_param(t) for t in control_extra]
        self._node_radius = nn.Parameter(
            torch.log(.1 * scene_range + 1e-7) * torch.ones_like(self._xyz[:, :1]).float().to("cuda"))
        self._node_weight = nn.Parameter(torch.ones_like(self._xyz[:, :1]).to("cuda"))
        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
        self.sum_visible_count_in_one_batch = torch.zeros(
            (self.get_xyz.shape[0]), device="cuda"
        )

    def all_parameters(self):
        if self.vis_control_gs:
            return [
                self._xyz,
                self._features_dc,
                self._features_rest,
                self._scaling,
                self._rotation,
                self._opacity,
                self._semantic_feature,
                self.entity_ids,
                self._node_radius,
                self._node_weight
            ]
        else:
            return [
                self._xyz,
                self.entity_ids,
                self._node_radius,
                self._node_weight
            ]

    def training_setup(self, training_args):
        """Reset training statistics, then build the Adam optimizer and xyz LR schedule.

        - Zeroes ``xyz_gradient_accum`` and ``denom`` (one row per point) and
          ``send_to_gpui_cnt`` (points x redistribution-group size).
        - Creates one named Adam param group per tensor. Group order matters
          for ``optimizer.load_state_dict``:
          xyz, [f_dc, f_rest, opacity, scaling, rotation,] entity_f, [sem_f,]
          node_radius, node_weight. Bracketed groups exist only with
          ``vis_control_gs``.
        - Scales every group's lr by the batch size according to
          ``training_args.lr_scale_mode``:
          "linear": lr x bsz.
          "sqrt": lr x sqrt(bsz), eps / sqrt(bsz), betas ** bsz.
          "accumu": unchanged.
          Any other mode fails an assertion.
        - Builds ``xyz_scheduler_args`` with the same lr scale.
        """
        self.percent_dense = training_args.percent_dense
        num_points = self.get_xyz.shape[0]
        self.xyz_gradient_accum = torch.zeros((num_points, 1), device="cuda")
        self.denom = torch.zeros((num_points, 1), device="cuda")

        shard_world_size = self.group_for_redistribution().size()
        self.send_to_gpui_cnt = torch.zeros(
            (num_points, shard_world_size), dtype=torch.int, device="cuda"
        )

        args = utils.get_args()
        log_file = utils.get_log_file()
        feature_lr = training_args.feature_lr

        def make_group(tensor, lr, name):
            return {"params": [tensor], "lr": lr, "name": name}

        param_groups = [
            make_group(
                self._xyz,
                training_args.position_lr_init * self.spatial_lr_scale * args.lr_scale_pos_and_scale,
                "xyz",
            ),
        ]
        if self.vis_control_gs:
            param_groups += [
                make_group(self._features_dc, feature_lr, "f_dc"),
                make_group(self._features_rest, feature_lr / 20.0, "f_rest"),
                make_group(self._opacity, training_args.opacity_lr, "opacity"),
                make_group(self._scaling, training_args.scaling_lr * args.lr_scale_pos_and_scale, "scaling"),
                make_group(self._rotation, training_args.rotation_lr, "rotation"),
            ]
        param_groups.append(make_group(self.entity_ids, feature_lr, "entity_f"))
        if self.vis_control_gs:
            param_groups.append(make_group(self._semantic_feature, feature_lr, "sem_f"))
        param_groups += [
            make_group(self._node_radius, training_args.node_radius_lr, "node_radius"),
            make_group(self._node_weight, training_args.node_weight_lr, "node_weight"),
        ]

        self.optimizer = torch.optim.Adam(param_groups, lr=0.0, eps=1e-15)

        bsz = utils.get_args().bsz
        mode = training_args.lr_scale_mode
        if mode == "linear":
            lr_scale = bsz
        elif mode == "sqrt":
            lr_scale = np.sqrt(bsz)
        elif mode == "accumu":
            lr_scale = 1
        else:
            assert (
                False
            ), f"lr_scale_mode {training_args.lr_scale_mode} not supported."

        for param_group in self.optimizer.param_groups:
            if mode == "accumu":
                continue
            param_group["lr"] *= lr_scale
            if mode == "sqrt" and "eps" in param_group:  # Adam
                param_group["eps"] /= lr_scale
                param_group["betas"] = [beta ** bsz for beta in param_group["betas"]]
                log_file.write(
                    param_group["name"] + " betas: " + str(param_group["betas"]) + "\n"
                )

        self.xyz_scheduler_args = get_expon_lr_func(
            lr_init=training_args.position_lr_init
                    * self.spatial_lr_scale
                    * lr_scale
                    * args.lr_scale_pos_and_scale,
            lr_final=training_args.position_lr_final
                     * self.spatial_lr_scale
                     * lr_scale
                     * args.lr_scale_pos_and_scale,
            lr_delay_mult=training_args.position_lr_delay_mult,
            max_steps=training_args.position_lr_max_steps,
        )

        utils.check_initial_gpu_memory_usage("after training_setup")

    def log_gaussian_stats(self):
        # log the statistics of the gaussian model
        # number of total 3dgs on this rank
        num_3dgs = self._xyz.shape[0]
        # average size of 3dgs
        # avg_size = torch.mean(torch.max(self.get_scaling, dim=1).values).item()
        # average opacity
        # avg_opacity = torch.mean(self.get_opacity).item()
        stats = {
            "num_3dgs": num_3dgs,
            # "avg_size": avg_size,
            # "avg_opacity": avg_opacity,
        }

        # get the exp_avg, exp_avg_sq state for all parameters
        exp_avg_dict = {}
        exp_avg_sq_dict = {}
        for group in self.optimizer.param_groups:
            stored_state = self.optimizer.state.get(group["params"][0], None)
            if stored_state is not None:
                if "exp_avg" in stored_state:
                    exp_avg_dict[group["name"]] = torch.mean(
                        torch.norm(stored_state["exp_avg"], dim=-1)
                    ).item()
                    exp_avg_sq_dict[group["name"]] = torch.mean(
                        torch.norm(stored_state["exp_avg_sq"], dim=-1)
                    ).item()
        return stats, exp_avg_dict, exp_avg_sq_dict

    def sync_gradients_for_replicated_3dgs_storage(self, batched_screenspace_pkg):
        """All-reduce gradients when every rank stores the full set of Gaussians.

        1. If ``args.grad_normalization_mode`` contains "visible_count", the
           entries of ``batched_locally_preprocessed_visibility_filter`` are
           cast to int and summed elementwise into
           ``sum_batched_locally_preprocessed_visibility_filter_int``.
        2. The sync routine is chosen from ``args.sync_grad_mode``; an unknown
           mode fails an assertion.
        3. The routine runs over ``utils.DEFAULT_GROUP`` only when Gaussians
           are not distributed and that group has more than one rank.

        NOTE(review): the module-level sync routines access ``_features_dc``
        ... ``_semantic_feature``, which exist only with ``vis_control_gs``.
        They also skip ``_node_radius`` / ``_node_weight``. Confirm this path
        is only used where that is acceptable.
        """
        args = utils.get_args()

        if "visible_count" in args.grad_normalization_mode:
            filters = batched_screenspace_pkg["batched_locally_preprocessed_visibility_filter"]
            stacked = torch.stack([visibility.int() for visibility in filters])
            batched_screenspace_pkg[
                "sum_batched_locally_preprocessed_visibility_filter_int"
            ] = torch.sum(stacked, dim=0)

        sync_funcs = {
            "dense": sync_gradients_densely,
            "sparse": sync_gradients_sparsely,
            "fused_dense": sync_gradients_fused_densely,
            "fused_sparse": sync_gradients_fused_sparsely,
        }
        sync_func = sync_funcs.get(args.sync_grad_mode)
        if sync_func is None:
            assert False, f"sync_grad_mode {args.sync_grad_mode} not supported."

        if not args.gaussians_distribution and utils.DEFAULT_GROUP.size() > 1:
            sync_func(self, utils.DEFAULT_GROUP)

    def update_learning_rate(self, iteration):
        """Learning rate scheduling per step"""
        for param_group in self.optimizer.param_groups:
            if param_group["name"] == "xyz":
                lr = self.xyz_scheduler_args(iteration)
                param_group["lr"] = lr
                return lr

    def construct_list_of_attributes(self):
        l = ["x", "y", "z", "nx", "ny", "nz"]
        # All channels except the 3 DC
        if self.vis_control_gs:
            for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]):
                l.append("f_dc_{}".format(i))
            for i in range(self._features_rest.shape[1] * self._features_rest.shape[2]):
                l.append("f_rest_{}".format(i))
            l.append("opacity")
            for i in range(self._scaling.shape[1]):
                l.append("scale_{}".format(i))
            for i in range(self._rotation.shape[1]):
                l.append("rot_{}".format(i))
            for i in range(self._semantic_feature.shape[1] * self._semantic_feature.shape[2]):
                l.append('sem_{}'.format(i))
        for i in range(self.entity_ids.shape[1] * self.entity_ids.shape[2]):
            l.append('entity_{}'.format(i))
        l.append('node_radius')
        l.append('node_weight')
        return l

    def save_ply(
            self, path
    ):  # here, we should be in torch.no_grad() context. train.py ensures that.
        """Write the Gaussian point cloud as a PLY file (one ``vertex`` element).

        The mode is selected by ``args.gaussians_distribution`` / ``args.distributed_save``:

        * distributed storage, single-file save: every rank sends its Gaussians
          to group rank 0 (point-to-point), which writes ``path``. Other ranks
          return without writing. All ranks must call this method.
        * distributed storage, distributed save: each rank writes its own shard
          to ``<path minus .ply>_rk{GLOBAL_RANK}_ws{WORLD_SIZE}.ply``.
        * replicated storage: only group rank 0 writes, using the same
          ``_rk{GLOBAL_RANK}_ws{WORLD_SIZE}`` suffix.

        Columns follow ``construct_list_of_attributes()``. Appearance attributes
        (SH features, opacity, scale, rotation, semantic features) are present
        only when ``self.vis_control_gs`` is set. ``max_radii2D``,
        ``xyz_gradient_accum`` and ``denom`` are not saved here.
        """
        args = utils.get_args()
        utils.log_cpu_memory_usage("start save_ply")
        group = utils.DEFAULT_GROUP

        def local_tensors():
            # This rank's tensors keyed by role. The insertion order is also the
            # order in which they are gathered, so it must be identical on every rank.
            tensors = {"xyz": self._xyz}
            if self.vis_control_gs:
                tensors.update(
                    features_dc=self._features_dc,
                    features_rest=self._features_rest,
                    opacity=self._opacity,
                    scaling=self._scaling,
                    rotation=self._rotation,
                    semantic_feature=self._semantic_feature,
                )
            tensors.update(
                node_radius=self._node_radius,
                node_weight=self._node_weight,
                entity_ids=self.entity_ids,
            )
            return tensors

        def with_rank_suffix(p):
            # "foo.ply" -> "foo_rk{GLOBAL_RANK}_ws{WORLD_SIZE}.ply"; other paths pass through.
            if not p.endswith(".ply"):
                return p
            return (
                    p[:-4]
                    + "_rk"
                    + str(utils.GLOBAL_RANK)
                    + "_ws"
                    + str(utils.WORLD_SIZE)
                    + ".ply"
            )

        if args.gaussians_distribution and not args.distributed_save:
            # gather all gaussians at rank 0
            def gather_uneven_tensors(tensor):
                # gather size of tensors on different ranks
                tensor_sizes = torch.zeros(
                    (group.size()), dtype=torch.int, device="cuda"
                )
                tensor_sizes[group.rank()] = tensor.shape[0]
                dist.all_reduce(tensor_sizes, op=dist.ReduceOp.SUM)
                # move tensor_sizes to CPU and convert to int list
                tensor_sizes = tensor_sizes.cpu().numpy().tolist()

                # NOTE: Internal implementation of gather could not gather tensors of different sizes.
                # So, I do not use dist.gather(tensor, dst=0) but use dist.send(tensor, dst=0) and dist.recv(tensor, src=i) instead.

                # gather tensors on different ranks using grouped send/recv
                gathered_tensors = []
                if group.rank() == 0:
                    for i in range(group.size()):
                        if i == group.rank():
                            gathered_tensors.append(tensor)
                        else:
                            tensor_from_rk_i = torch.zeros(
                                (tensor_sizes[i],) + tensor.shape[1:],
                                dtype=tensor.dtype,
                                device="cuda",
                            )
                            dist.recv(tensor_from_rk_i, src=i)
                            gathered_tensors.append(tensor_from_rk_i)
                    # concatenate gathered tensors
                    gathered_tensors = torch.cat(gathered_tensors, dim=0)
                else:
                    dist.send(tensor, dst=0)

                return (
                    gathered_tensors if group.rank() == 0 else None
                )  # only return gather tensors at rank 0

            tensors = {
                role: gather_uneven_tensors(local)
                for role, local in local_tensors().items()
            }
            if group.rank() != 0:
                return

        elif args.gaussians_distribution and args.distributed_save:
            assert (
                    utils.DEFAULT_GROUP.size() > 1
            ), "distributed_save should be used with more than 1 rank."
            tensors = local_tensors()
            path = with_rank_suffix(path)
        else:  # not args.gaussians_distribution: every rank holds the same replica.
            if group.rank() != 0:
                return
            tensors = local_tensors()
            # NOTE(review): the rank/world-size suffix is added even though only
            # rank 0 writes. With WORLD_SIZE > 1, distributed_load_ply will then
            # look for shards rk1.. that never exist. Confirm this is intended.
            path = with_rank_suffix(path)

        mkdir_p(os.path.dirname(path))

        def to_numpy(t):
            return t.detach().cpu().numpy()

        def to_flat_numpy(t):
            # (P, A, B) -> (P, B, A) -> (P, A*B); load_raw_ply undoes this layout.
            return t.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()

        # Use the (possibly gathered) tensors selected above for EVERY column; the
        # old code read semantic features / entity ids from self, which on rank 0 of
        # the gather mode did not match the gathered row count.
        xyz = to_numpy(tensors["xyz"])
        columns = [xyz, np.zeros_like(xyz)]  # xyz, unused normals
        if self.vis_control_gs:
            columns += [
                to_flat_numpy(tensors["features_dc"]),
                to_flat_numpy(tensors["features_rest"]),
                to_numpy(tensors["opacity"]),
                to_numpy(tensors["scaling"]),
                to_numpy(tensors["rotation"]),
                to_flat_numpy(tensors["semantic_feature"]),
            ]
        # Order must match construct_list_of_attributes(): ..., entity_*, node_radius, node_weight.
        # (node_weight and node_radius used to be written swapped.)
        columns += [
            to_flat_numpy(tensors["entity_ids"]),
            to_numpy(tensors["node_radius"]),
            to_numpy(tensors["node_weight"]),
        ]

        utils.log_cpu_memory_usage("after change gpu tensor to cpu numpy")

        dtype_full = [
            (attribute, "f4") for attribute in self.construct_list_of_attributes()
        ]

        elements = np.empty(xyz.shape[0], dtype=dtype_full)
        attributes = np.concatenate(columns, axis=1)
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, "vertex")

        utils.log_cpu_memory_usage(
            "after change numpy to plyelement before writing ply file"
        )
        PlyData([el]).write(path)
        utils.log_cpu_memory_usage("finish write ply file")
        # remark: max_radii2D, xyz_gradient_accum and denom are not saved here; they are save elsewhere.

    def reset_opacity(self):
        """Cap every Gaussian's activated opacity at 0.01 and restart its optimizer moments.

        Writes a note to ``utils.LOG_FILE``. The capped values are mapped back
        through ``inverse_sigmoid`` (presumably the inverse of the opacity
        activation; confirm against ``get_opacity``). They are then installed
        via ``replace_tensor_to_optimizer``, which zeroes the stored optimizer
        buffers for the ``"opacity"`` group.
        """
        utils.LOG_FILE.write("Resetting opacity to 0.01\n")
        current_opacity = self.get_opacity
        ceiling = torch.full_like(current_opacity, 0.01)
        capped_raw = inverse_sigmoid(torch.min(current_opacity, ceiling))
        self._opacity = self.replace_tensor_to_optimizer(capped_raw, "opacity")["opacity"]

    def prune_based_on_opacity(self, min_opacity):
        """Remove every Gaussian whose activated opacity is below ``min_opacity``.

        The pruned percentage is written to ``utils.LOG_FILE``. A rank holding no
        Gaussians (possible when Gaussians are distributed) logs 0.00% instead
        of raising ZeroDivisionError.
        """
        # reshape(-1) instead of squeeze(): squeeze() collapses a single Gaussian's
        # (1, 1) opacity to a 0-d tensor, which prune_points cannot use as a row mask.
        prune_mask = (self.get_opacity < min_opacity).reshape(-1)
        num_gaussians = prune_mask.shape[0]
        percent = (
            100 * prune_mask.sum().item() / num_gaussians if num_gaussians > 0 else 0.0
        )
        utils.LOG_FILE.write(
            "Pruning based on opacity. Percent: {:.2f}\n".format(percent)
        )
        self.prune_points(prune_mask)

    def distributed_load_ply(self, folder):
        """Load a checkpoint stored as shards ``point_cloud_rk{r}_ws{W}.ply`` in ``folder``.

        ``W`` is inferred from the first matching shard name. Every shard
        ``r = 0 .. W-1`` is read with ``load_raw_ply`` (which may itself keep only
        this rank's chunk and randomly drop points, depending on ``args``), and
        the shards are concatenated row-wise. The results become CUDA
        ``nn.Parameter``s on this model. Appearance attributes are loaded only
        when ``self.vis_control_gs`` is set.

        Raises:
            AssertionError: if no shard file is found in ``folder``.
        """
        # Infer the world size the checkpoint was written with, e.g. "point_cloud_rk0_ws4.ply" -> 4.
        # Only shard names are considered, since those are the only files read below.
        world_size = -1
        for f in os.listdir(folder):
            if f.startswith("point_cloud_rk") and "_ws" in f and f.endswith(".ply"):
                world_size = int(f.split("_ws")[1].split(".")[0])
                break
        assert world_size > 0, (
            "world_size should be greater than 0 "
            f"(no point_cloud_rk*_ws*.ply shard found in {folder})."
        )

        # Field names in the exact order load_raw_ply returns them.
        if self.vis_control_gs:
            fields = ("xyz", "features_dc", "features_rest", "opacity", "scaling",
                      "rotation", "semantic_feature", "entity_ids", "node_radius", "node_weight")
        else:
            fields = ("xyz", "entity_ids", "node_radius", "node_weight")

        shards = {name: [] for name in fields}
        for rk in range(world_size):
            one_checkpoint_path = os.path.join(
                folder, "point_cloud_rk" + str(rk) + "_ws" + str(world_size) + ".ply"
            )
            loaded = self.load_raw_ply(one_checkpoint_path)
            assert len(loaded) == len(fields), "unexpected number of arrays from load_raw_ply"
            for name, array in zip(fields, loaded):
                shards[name].append(array)
        catted = {name: np.concatenate(arrays, axis=0) for name, arrays in shards.items()}

        def to_param(array, transpose=False):
            # float32 CUDA leaf parameter; transpose(1, 2) reverses the transpose
            # save_ply applies to 3-D feature tensors before flattening them.
            tensor = torch.tensor(array, dtype=torch.float, device="cuda")
            if transpose:
                tensor = tensor.transpose(1, 2)
            return nn.Parameter(tensor.contiguous().requires_grad_(True))

        self._xyz = to_param(catted["xyz"])
        if self.vis_control_gs:
            self._features_dc = to_param(catted["features_dc"], transpose=True)
            self._features_rest = to_param(catted["features_rest"], transpose=True)
            self._opacity = to_param(catted["opacity"])
            self._scaling = to_param(catted["scaling"])
            self._rotation = to_param(catted["rotation"])
            self._semantic_feature = to_param(catted["semantic_feature"], transpose=True)
        # load_raw_ply already yields entity ids as (P, 1, K); no transpose here.
        self.entity_ids = to_param(catted["entity_ids"])
        self._node_radius = to_param(catted["node_radius"])
        self._node_weight = to_param(catted["node_weight"])
        self.active_sh_degree = self.max_sh_degree

    def load_raw_ply(self, path):
        """Read one PLY checkpoint into numpy arrays.

        Returns ``(xyz, features_dc, features_extra, opacities, scales, rots,
        semantic_feature, entity_ids, node_radius, node_weight)`` when
        ``self.vis_control_gs`` is set, otherwise ``(xyz, entity_ids,
        node_radius, node_weight)``. Shapes as built below (P = number of points):
        xyz (P, 3); features_dc (P, 3, 1); features_extra
        (P, 3, (max_sh_degree+1)**2 - 1); opacities (P, 1); scales (P, #scale_*);
        rots (P, #rot*); semantic_feature (P, #sem_*, 1);
        entity_ids (P, 1, #entity_*); node_radius / node_weight (P, 1).

        When ``args.gaussians_distribution`` is set and ``utils.WORLD_SIZE > 1``,
        only the contiguous chunk indexed by ``utils.LOCAL_RANK`` is kept. Then,
        if ``args.drop_initial_3dgs_p > 0``, each remaining point is dropped
        independently with that probability (global numpy RNG, one draw per point).
        """
        print("Loading ", path)
        vertex = PlyData.read(path).elements[0]

        def column(name):
            return np.asarray(vertex[name])

        def indexed_property_names(prefix):
            # Property names with ``prefix``, ordered by their trailing integer index.
            matching = [prop.name for prop in vertex.properties if prop.name.startswith(prefix)]
            return sorted(matching, key=lambda name: int(name.split("_")[-1]))

        def count_fields(prefix):
            return sum(1 for name in vertex.data.dtype.names if name.startswith(prefix))

        xyz = np.stack([column(axis_name) for axis_name in ("x", "y", "z")], axis=1)
        num_points = xyz.shape[0]

        def gather_float64(names):
            # (P, len(names)) float64 matrix filled column by column.
            matrix = np.zeros((num_points, len(names)))
            for col_idx, name in enumerate(names):
                matrix[:, col_idx] = column(name)
            return matrix

        sh_rest = (self.max_sh_degree + 1) ** 2 - 1
        if self.vis_control_gs:
            opacities = column("opacity")[..., np.newaxis]
            features_dc = np.zeros((num_points, 3, 1))
            for channel in range(3):
                features_dc[:, channel, 0] = column(f"f_dc_{channel}")

            rest_names = indexed_property_names("f_rest_")
            assert len(rest_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3
            # (P, F*SH_coeffs) -> (P, F, SH_coeffs without DC)
            features_extra = gather_float64(rest_names).reshape((num_points, 3, sh_rest))

            sem_count = count_fields("sem_")
            semantic_feature = np.expand_dims(
                np.stack([column(f"sem_{i}") for i in range(sem_count)], axis=1), axis=-1
            )

        entity_count = count_fields("entity_")
        entity_ids = np.expand_dims(
            np.stack([column(f"entity_{i}") for i in range(entity_count)], axis=1), axis=1
        )

        if self.vis_control_gs:
            scales = gather_float64(indexed_property_names("scale_"))
            rots = gather_float64(indexed_property_names("rot"))

        node_radius = column("node_radius")[..., np.newaxis]
        node_weight = column("node_weight")[..., np.newaxis]

        args = utils.get_args()

        # Per-point arrays in return order; every one is sliced / masked identically.
        if self.vis_control_gs:
            per_point = {
                "xyz": xyz, "features_dc": features_dc, "features_extra": features_extra,
                "opacities": opacities, "scales": scales, "rots": rots,
                "semantic_feature": semantic_feature, "entity_ids": entity_ids,
                "node_radius": node_radius, "node_weight": node_weight,
            }
        else:
            per_point = {
                "xyz": xyz, "entity_ids": entity_ids,
                "node_radius": node_radius, "node_weight": node_weight,
            }

        # The read above is replicated on all ranks (cheap for initialization);
        # with distributed Gaussians each rank keeps only its own chunk.
        if args.gaussians_distribution and utils.WORLD_SIZE > 1:
            chunk = num_points // utils.WORLD_SIZE + 1
            start = chunk * utils.LOCAL_RANK
            stop = min(chunk * (utils.LOCAL_RANK + 1), num_points)
            per_point = {key: np.ascontiguousarray(arr[start:stop]) for key, arr in per_point.items()}

        if args.drop_initial_3dgs_p > 0.0:
            # Keep each point with probability 1 - drop_initial_3dgs_p.
            survivors = np.random.rand(per_point["xyz"].shape[0]) > args.drop_initial_3dgs_p
            per_point = {key: arr[survivors] for key, arr in per_point.items()}

        return tuple(per_point.values())

    def one_file_load_ply(self, folder):
        """Load the single-file checkpoint ``<folder>/point_cloud.ply`` into this model.

        Arrays come from ``load_raw_ply`` (which may keep only this rank's chunk
        and randomly drop points, depending on ``args``). They become float32
        CUDA ``nn.Parameter``s. Appearance attributes are loaded only when
        ``self.vis_control_gs`` is set.
        """
        path = os.path.join(folder, "point_cloud.ply")
        if self.vis_control_gs:
            # load_raw_ply returns 10 arrays in this mode. The previous 8-name
            # unpacking raised ValueError and never bound node_radius/node_weight.
            (xyz, features_dc, features_extra, opacities, scales, rots,
             semantic_feature, entity_ids, node_radius, node_weight) = self.load_raw_ply(path)
        else:
            xyz, entity_ids, node_radius, node_weight = self.load_raw_ply(path)

        def to_param(array, transpose=False):
            # float32 CUDA leaf parameter; transpose(1, 2) reverses the transpose
            # save_ply applies to 3-D feature tensors before flattening them.
            tensor = torch.tensor(array, dtype=torch.float, device="cuda")
            if transpose:
                tensor = tensor.transpose(1, 2)
            return nn.Parameter(tensor.contiguous().requires_grad_(True))

        self._xyz = to_param(xyz)
        if self.vis_control_gs:
            self._features_dc = to_param(features_dc, transpose=True)
            self._features_rest = to_param(features_extra, transpose=True)
            self._opacity = to_param(opacities)
            self._scaling = to_param(scales)
            self._rotation = to_param(rots)
            self._semantic_feature = to_param(semantic_feature, transpose=True)
        # load_raw_ply already returns entity ids as (P, 1, K). They are kept
        # as-is, exactly like distributed_load_ply does, so both loaders produce
        # the same layout from the same file format. (Previously transposed to
        # (P, K, 1) here. NOTE(review): confirm (P, 1, K) is the training layout.)
        self.entity_ids = to_param(entity_ids)
        self._node_radius = to_param(node_radius)
        self._node_weight = to_param(node_weight)
        self.active_sh_degree = self.max_sh_degree

    def load_ply(self, path):
        if os.path.exists(os.path.join(path, "point_cloud.ply")):
            self.one_file_load_ply(path)
        else:
            self.distributed_load_ply(path)

    def replace_tensor_to_optimizer(self, tensor, name):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            if group["name"] == name:
                stored_state = self.optimizer.state.get(group["params"][0], None)
                if "exp_avg" not in stored_state:
                    stored_state["momentum_buffer"] = torch.zeros_like(tensor)
                else:
                    stored_state["exp_avg"] = torch.zeros_like(tensor)
                    stored_state["exp_avg_sq"] = torch.zeros_like(tensor)

                del self.optimizer.state[group["params"][0]]
                group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
                self.optimizer.state[group["params"][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def _prune_optimizer(self, mask):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            stored_state = self.optimizer.state.get(group["params"][0], None)
            if stored_state is not None:
                if "exp_avg" not in stored_state:
                    stored_state["momentum_buffer"] = stored_state["momentum_buffer"][
                        mask
                    ]
                else:
                    stored_state["exp_avg"] = stored_state["exp_avg"][mask]
                    stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]

                del self.optimizer.state[group["params"][0]]
                group["params"][0] = nn.Parameter(
                    (group["params"][0][mask].requires_grad_(True))
                )
                self.optimizer.state[group["params"][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
            else:
                group["params"][0] = nn.Parameter(
                    group["params"][0][mask].requires_grad_(True)
                )
                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def prune_points(self, mask):
        valid_points_mask = ~mask
        optimizable_tensors = self._prune_optimizer(valid_points_mask)

        self._xyz = optimizable_tensors["xyz"]
        if self.vis_control_gs:
            self._features_dc = optimizable_tensors["f_dc"]
            self._features_rest = optimizable_tensors["f_rest"]
            self._opacity = optimizable_tensors["opacity"]
            self._scaling = optimizable_tensors["scaling"]
            self._rotation = optimizable_tensors["rotation"]

            self._semantic_feature = optimizable_tensors["sem_f"]
        self.entity_ids = optimizable_tensors["entity_f"]
        self._node_radius = optimizable_tensors["node_radius"]
        self._node_weight = optimizable_tensors["node_weight"]

        self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]

        self.send_to_gpui_cnt = self.send_to_gpui_cnt[valid_points_mask]

        self.denom = self.denom[valid_points_mask]
        self.max_radii2D = self.max_radii2D[valid_points_mask]
        self.sum_visible_count_in_one_batch = self.sum_visible_count_in_one_batch[
            valid_points_mask
        ]

    def cat_tensors_to_optimizer(self, tensors_dict):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            assert len(group["params"]) == 1
            extension_tensor = tensors_dict[group["name"]]
            stored_state = self.optimizer.state.get(group["params"][0], None)
            if stored_state is not None:
                if "exp_avg" not in stored_state:
                    stored_state["momentum_buffer"] = torch.cat(
                        (
                            stored_state["momentum_buffer"],
                            torch.zeros_like(extension_tensor),
                        ),
                        dim=0,
                    )
                else:
                    stored_state["exp_avg"] = torch.cat(
                        (stored_state["exp_avg"], torch.zeros_like(extension_tensor)),
                        dim=0,
                    )
                    stored_state["exp_avg_sq"] = torch.cat(
                        (
                            stored_state["exp_avg_sq"],
                            torch.zeros_like(extension_tensor),
                        ),
                        dim=0,
                    )

                del self.optimizer.state[group["params"][0]]
                group["params"][0] = nn.Parameter(
                    torch.cat(
                        (group["params"][0], extension_tensor), dim=0
                    ).requires_grad_(True)
                )
                self.optimizer.state[group["params"][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
            else:
                group["params"][0] = nn.Parameter(
                    torch.cat(
                        (group["params"][0], extension_tensor), dim=0
                    ).requires_grad_(True)
                )
                optimizable_tensors[group["name"]] = group["params"][0]

        return optimizable_tensors

    def densification_postfix(
            self,
            new_xyz,
            new_features_dc,
            new_features_rest,
            new_opacities,
            new_scaling,
            new_rotation,
            new_semantic_feature,
            new_entity_ids,
            new_node_radius,
            new_node_weight,
            new_send_to_gpui_cnt,
    ):
        """Append newly created Gaussians to the model and its optimizer.

        Appearance arguments are used only when ``self.vis_control_gs`` is set;
        otherwise they are ignored. After appending, the per-Gaussian statistics
        ``xyz_gradient_accum``, ``denom``, ``max_radii2D`` and
        ``sum_visible_count_in_one_batch`` are reset to zeros for ALL Gaussians.
        ``send_to_gpui_cnt`` is extended with ``new_send_to_gpui_cnt``.
        """
        additions = {"xyz": new_xyz}
        if self.vis_control_gs:
            additions.update({
                "f_dc": new_features_dc,
                "f_rest": new_features_rest,
                "opacity": new_opacities,
                "scaling": new_scaling,
                "rotation": new_rotation,
                "sem_f": new_semantic_feature,
            })
        additions.update({
            "entity_f": new_entity_ids,
            "node_radius": new_node_radius,
            "node_weight": new_node_weight,
        })

        extended = self.cat_tensors_to_optimizer(additions)

        # (model attribute, optimizer param-group name)
        bindings = [("_xyz", "xyz")]
        if self.vis_control_gs:
            bindings += [
                ("_features_dc", "f_dc"),
                ("_features_rest", "f_rest"),
                ("_opacity", "opacity"),
                ("_scaling", "scaling"),
                ("_rotation", "rotation"),
                ("_semantic_feature", "sem_f"),
            ]
        bindings += [
            ("entity_ids", "entity_f"),
            ("_node_radius", "node_radius"),
            ("_node_weight", "node_weight"),
        ]
        for attr_name, group_name in bindings:
            setattr(self, attr_name, extended[group_name])

        total = self.get_xyz.shape[0]
        self.xyz_gradient_accum = torch.zeros((total, 1), device="cuda")
        self.denom = torch.zeros((total, 1), device="cuda")
        self.max_radii2D = torch.zeros((total), device="cuda")
        self.sum_visible_count_in_one_batch = torch.zeros((total), device="cuda")

        self.send_to_gpui_cnt = torch.cat(
            (self.send_to_gpui_cnt, new_send_to_gpui_cnt), dim=0
        )

    def densify_and_split(self, grads, grad_threshold, scene_extent=None, N=2):
        """Split high-gradient points into ``N`` sampled children and remove the parents.

        A point is selected when its (zero-padded) gradient statistic is
        ``>= grad_threshold``; when ``self.vis_control_gs`` is set it must also have
        a largest scale strictly above ``self.percent_dense * scene_extent``.
        Each selected point yields ``N`` children whose positions are sampled from
        a zero-mean normal with the parent's scales as std, rotated by the
        parent's rotation and offset by the parent's position. The children are
        appended through ``densification_postfix`` and the selected parents are
        afterwards removed with ``prune_points``.

        Args:
            grads: Per-point gradient statistic. It is ``squeeze()``-d into the
                first ``grads.shape[0]`` entries of a zero vector sized to the
                current point count, so it may be shorter than that count
                (``densify_and_prune`` calls ``densify_and_clone`` first, which
                appends points).
            grad_threshold: Minimum gradient value for a point to be split.
            scene_extent: Scene scale; only read when ``vis_control_gs`` is set.
            N: Number of children generated per selected point.
        """
        n_init_points = self.get_xyz.shape[0]
        # Extract points that satisfy the gradient condition.
        # Points appended after `grads` was computed get a padded gradient of 0.
        padded_grad = torch.zeros((n_init_points), device="cuda")
        padded_grad[: grads.shape[0]] = grads.squeeze()
        selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
        if self.vis_control_gs:
            # Only "large" points are split; densify_and_clone handles the
            # complementary `<=` case.
            selected_pts_mask = torch.logical_and(
                selected_pts_mask,
                torch.max(self.get_scaling, dim=1).values
                > self.percent_dense * scene_extent,
            )

        # NOTE(review): get_scaling and _rotation are read here even when
        # vis_control_gs is False, while the non-vis densification_postfix call
        # below passes no scaling/rotation — confirm these stay the same length
        # as _xyz in that mode.
        stds = self.get_scaling[selected_pts_mask].repeat(N, 1)
        means = torch.zeros((stds.size(0), 3), device="cuda")
        samples = torch.normal(mean=means, std=stds)
        # [N * number of selected points, 3]

        utils.get_log_file().write(
            "Number of split gaussians: {}\n".format(selected_pts_mask.sum().item())
        )
        # Rotate the local offsets into world space and add the parent position.
        rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1)
        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[
            selected_pts_mask
        ].repeat(N, 1)
        if self.vis_control_gs:
            # Children are shrunk: activated parent scale divided by 0.8 * N,
            # then mapped back to the stored (pre-activation) space.
            new_scaling = self.scaling_inverse_activation(
                self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8 * N)
            )
            new_rotation = self._rotation[selected_pts_mask].repeat(N, 1)
            new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1)
            new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1)
            new_opacity = self._opacity[selected_pts_mask].repeat(N, 1)

            new_semantic_feature = self._semantic_feature[selected_pts_mask].repeat(N, 1, 1)

        # NOTE(review): `repeat(N, 1, 1)` assumes entity_ids is 3-D (a 2-D tensor
        # would gain a leading dim of size N) — confirm its shape.
        new_entity_ids = self.entity_ids[selected_pts_mask].repeat(N, 1, 1)
        new_node_radius = self._node_radius[selected_pts_mask].repeat(N, 1)
        new_node_weight = self._node_weight[selected_pts_mask].repeat(N, 1)

        new_send_to_gpui_cnt = self.send_to_gpui_cnt[selected_pts_mask].repeat(N, 1)

        if self.vis_control_gs:
            self.densification_postfix(
                new_xyz,
                new_features_dc,
                new_features_rest,
                new_opacity,
                new_scaling,
                new_rotation,
                new_semantic_feature,
                new_entity_ids,
                new_node_radius,
                new_node_weight,
                new_send_to_gpui_cnt,
            )
        else:
            self.densification_postfix(
                new_xyz,
                new_entity_ids,
                new_node_radius,
                new_node_weight,
                new_send_to_gpui_cnt,
            )

        # Drop the original parents; the N * selected children just appended are kept.
        prune_filter = torch.cat(
            (
                selected_pts_mask,
                torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool),
            )
        )
        self.prune_points(prune_filter)

    def densify_and_clone(self, grads, grad_threshold, scene_extent=None, selected_pts_mask=None):
        """Duplicate selected points in place (same attributes, appended at the end).

        Args:
            grads: Per-point gradient statistic; its norm over the last dim is
                compared against ``grad_threshold``. Ignored when
                ``selected_pts_mask`` is given.
            grad_threshold: Minimum gradient norm for a point to be cloned.
            scene_extent: Scene scale; only read when computing the mask with
                ``vis_control_gs`` enabled (points must then have a largest scale
                ``<= self.percent_dense * scene_extent``).
            selected_pts_mask: Optional precomputed boolean mask of points to clone.
        """
        if selected_pts_mask is None:
            selected_pts_mask = torch.norm(grads, dim=-1) >= grad_threshold
            if self.vis_control_gs:
                # Only "small" points are cloned; large ones go to densify_and_split.
                is_small = (
                    torch.max(self.get_scaling, dim=1).values
                    <= self.percent_dense * scene_extent
                )
                selected_pts_mask = torch.logical_and(selected_pts_mask, is_small)

        mask = selected_pts_mask
        utils.get_log_file().write(
            "Number of cloned gaussians: {}\n".format(mask.sum().item())
        )

        # Attributes that are copied regardless of vis_control_gs, in the order
        # densification_postfix expects them after the optional visual ones.
        shared_attrs = (
            self.entity_ids[mask],
            self._node_radius[mask],
            self._node_weight[mask],
            self.send_to_gpui_cnt[mask],
        )
        cloned_xyz = self._xyz[mask]

        if not self.vis_control_gs:
            self.densification_postfix(cloned_xyz, *shared_attrs)
            return

        self.densification_postfix(
            cloned_xyz,
            self._features_dc[mask],
            self._features_rest[mask],
            self._opacity[mask],
            self._scaling[mask],
            self._rotation[mask],
            self._semantic_feature[mask],
            *shared_attrs,
        )

    def control_point_densify(self, selected_pts_mask=None, new_pts=None):
        """Append control points at ``new_pts``, inheriting attributes from selected points.

        Every non-position attribute of the new points is copied from the points
        chosen by ``selected_pts_mask`` (in mask order). ``new_pts`` is expected
        to hold one row per selected point; densify_and_prune_v2 passes
        ``node_avg_x[selected_pts_mask]``.

        Args:
            selected_pts_mask: Boolean mask over the current points.
            new_pts: Positions for the appended points.
        """
        mask = selected_pts_mask
        utils.get_log_file().write(
            "Number of cloned gaussians: {}\n".format(mask.sum().item())
        )

        inherited = (
            self.entity_ids[mask],
            self._node_radius[mask],
            self._node_weight[mask],
            self.send_to_gpui_cnt[mask],
        )

        if not self.vis_control_gs:
            self.densification_postfix(new_pts, *inherited)
            return

        self.densification_postfix(
            new_pts,
            self._features_dc[mask],
            self._features_rest[mask],
            self._opacity[mask],
            self._scaling[mask],
            self._rotation[mask],
            self._semantic_feature[mask],
            *inherited,
        )

    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size, total_gs_num, densify_mask=None,
                          max_num=2000_000, skip_densify=False, min_gs_num=2000):
        """Clone/split high-gradient points, then prune transparent or oversized ones.

        Args:
            max_grad: Gradient threshold passed to clone and split.
            min_opacity: Points with opacity below this are pruned.
            extent: Scene extent used by clone/split and the world-space size test.
            max_screen_size: If truthy, also prune points whose ``max_radii2D``
                exceeds it.
            total_gs_num: Current total point count; gates densification
                (``<= max_num``) and pruning (``>= min_gs_num``).
            densify_mask: Unused.
            max_num: Upper bound on ``total_gs_num`` for densification to run.
            skip_densify: Disable clone/split for this call.
            min_gs_num: Lower bound on ``total_gs_num`` for pruning to run.
        """
        args = utils.get_args()
        if not args.gaussians_distribution and utils.DEFAULT_GROUP.size() > 1:
            # Without gaussians_distribution, merge the statistics over the
            # data-parallel group before thresholding.
            reductions = (
                (self.max_radii2D, dist.ReduceOp.MAX),
                (self.xyz_gradient_accum, dist.ReduceOp.SUM),
                (self.denom, dist.ReduceOp.SUM),
            )
            for stat, reduce_op in reductions:
                torch.distributed.all_reduce(stat, op=reduce_op, group=utils.DP_GROUP)

        grads = self.xyz_gradient_accum / self.denom
        grads[grads.isnan()] = 0.0  # points never seen have denom == 0

        densification_stats = {
            "view_space_grad": grads.mean().item(),
            "view_space_grad_max": grads.max().item(),
        }

        densify_now = total_gs_num <= max_num and not skip_densify
        if densify_now:
            self.densify_and_clone(grads, max_grad, extent)
            self.densify_and_split(grads, max_grad, extent)

        if total_gs_num < min_gs_num:
            torch.cuda.empty_cache()
            return

        prune_mask = (self.get_opacity < min_opacity).squeeze()
        if max_screen_size:
            big_points_vs = self.max_radii2D > max_screen_size
            if densify_now:
                # NOTE: upstream quirk — clone/split go through
                # densification_postfix, which resets max_radii2D to zeros, so
                # the screen-space test can never fire right after densifying.
                assert torch.all(
                    self.max_radii2D == 0
                ), "In its implementation, max_radii2D is all 0. This is a bug."
                assert not torch.any(
                    big_points_vs
                ), "In its implementation, big_points_vs is all False. This is a bug."
            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
            prune_mask = torch.logical_or(
                prune_mask, torch.logical_or(big_points_vs, big_points_ws)
            )
        self.prune_points(prune_mask)

        torch.cuda.empty_cache()

    def densify_and_prune_v2(self, max_grad, scene_gs: GaussianModel, K=3):
        """Densify/prune control points from the gradients of the attached scene Gaussians.

        Node importance and weighted mean positions come from
        ``cal_node_importance`` (computed over all ranks' nodes). Each rank then
        keeps its own chunk of nodes. Nodes whose importance exceeds
        ``max_grad`` (and whose mean position is finite) spawn a new control
        point at that mean position. Nodes with no attached Gaussians are pruned.

        Args:
            max_grad: Importance threshold for spawning a new control point.
            scene_gs: Scene Gaussian model providing positions, entity ids and
                accumulated gradient statistics.
            K: Forwarded to ``cal_node_importance``.
        """
        args = utils.get_args()
        if not args.gaussians_distribution and utils.DEFAULT_GROUP.size() > 1:
            reductions = (
                (scene_gs.max_radii2D, dist.ReduceOp.MAX),
                (scene_gs.xyz_gradient_accum, dist.ReduceOp.SUM),
                (scene_gs.denom, dist.ReduceOp.SUM),
            )
            for stat, reduce_op in reductions:
                torch.distributed.all_reduce(stat, op=reduce_op, group=utils.DP_GROUP)

        grads = scene_gs.xyz_gradient_accum / scene_gs.denom
        grads[grads.isnan()] = 0.0

        densification_stats = {
            "view_space_grad": grads.mean().item(),
            "view_space_grad_max": grads.max().item(),
        }
        importance, avg_x, edge_count = self.cal_node_importance(
            x=scene_gs._xyz,
            x_entity_ids=scene_gs.get_entity,
            K=K,
            weights=grads,
        )

        # Restrict the globally gathered node arrays to this rank's chunk.
        lo, hi = utils.get_control_gs_global_chunk_l_r(self._xyz.shape[0])
        importance = importance[lo:hi]
        avg_x = avg_x[lo:hi]
        edge_count = edge_count[lo:hi]

        has_finite_mean = ~avg_x.isnan().any(dim=-1)
        selected_pts_mask = torch.logical_and(importance > max_grad, has_finite_mean)
        self.nodes_color_visualization = torch.ones_like(self._xyz)

        pruned_pts_mask = edge_count == 0
        n_selected = selected_pts_mask.sum()
        n_pruned = pruned_pts_mask.sum()
        if n_selected == 0 and n_pruned == 0:
            return
        utils.get_log_file().write("Number of cloned gaussians: {}\n".format(n_selected.item()))
        utils.get_log_file().write("Number of pruned gaussians: {}\n".format(n_pruned.item()))

        new_pts = avg_x[selected_pts_mask]

        if n_selected > 0:
            self.control_point_densify(selected_pts_mask, new_pts)
            # Points appended by control_point_densify are never pruned this round.
            keep_new = torch.zeros(
                new_pts.shape[0], dtype=torch.bool, device=selected_pts_mask.device
            )
            pruned_pts_mask = torch.cat([pruned_pts_mask, keep_new], dim=0)
        if n_pruned > 0:
            self.prune_points(pruned_pts_mask)
        torch.cuda.empty_cache()

    def gather_multiscale_tensor(self, cur_tensor):
        """All-gather a tensor whose number of rows (dim 0) differs between ranks.

        Each rank pads ``cur_tensor`` along dim 0 with zeros up to the largest row
        count in the group, all-gathers the equally-shaped buffers, strips the
        padding again and concatenates the pieces in rank order.

        Works for any rank >= 1. The previous version only handled 2-D input
        correctly. It mis-handled 1-D input (broadcasting it into an
        ``(n, 1, n)`` buffer), and it crashed for anything else. It also always
        padded in float32.

        Args:
            cur_tensor: Local tensor; trailing dimensions must match on all ranks.

        Returns:
            Tensor of shape ``(total_rows, *cur_tensor.shape[1:])`` with the same
            dtype and device as ``cur_tensor``.
        """
        world_size = utils.DEFAULT_GROUP.size()

        # Share every rank's row count so each rank knows how much padding to strip.
        local_count = torch.tensor([cur_tensor.shape[0]], device=cur_tensor.device)
        count_buffers = [torch.empty_like(local_count) for _ in range(world_size)]
        torch.distributed.all_gather(count_buffers, local_count)
        row_counts = torch.cat(count_buffers, dim=0).cpu().tolist()

        # all_gather requires identical shapes: pad to the max row count while
        # keeping the input's dtype/device (new_zeros) and trailing shape.
        padded = cur_tensor.new_zeros((max(row_counts), *cur_tensor.shape[1:]))
        padded[: cur_tensor.shape[0]] = cur_tensor

        gathered = [torch.empty_like(padded) for _ in range(world_size)]
        dist.all_gather(gathered, padded)

        return torch.cat(
            [chunk[:count] for chunk, count in zip(gathered, row_counts)], dim=0
        )

    @torch.no_grad()
    def cal_node_importance(self, x: torch.Tensor, x_entity_ids: torch.Tensor, K=None, weights=None):
        """Accumulate, per control node, the weighted influence of the attached Gaussians.

        Local Gaussian ``i`` is linked to the nodes ``self.gs2cp_nn_idx[i]``
        (indices into the node set gathered from all ranks) with weights
        ``self.gs2cp_nn_weight[i]`` (one row of ``gs2cp_nn_weight.shape[1]``
        entries). Per-node sums are reduced over ``utils.DP_GROUP``.

        Args:
            x: Gaussian positions; row ``i`` pairs with ``gs2cp_nn_*[i]``.
            x_entity_ids: Currently unused.
            K: Defaults to ``self.K``; currently not used otherwise.
            weights: Optional per-Gaussian weights (e.g. densification gradients)
                of shape ``(n,)`` or ``(n, 1)``; defaults to all ones. A 1-D
                vector is treated as one weight per Gaussian.

        Returns:
            ``(node_importance, avg_affected_x, node_edge_counts)`` over all
            gathered nodes:

            - ``node_importance``: sum of ``nn_weight * weights`` per node,
              divided by its edge count.
            - ``avg_affected_x``: ``nn_weight * weights``-weighted mean position of
              the attached Gaussians (NaN/inf where that weight sum is zero).
            - ``node_edge_counts``: number of (Gaussian, node) links per node.
        """
        K = self.K if K is None else K  # kept for interface compatibility

        all_nodes = self.gather_multiscale_tensor(self._xyz)
        node_importance = torch.zeros_like(all_nodes[:, 0]).view(-1)
        node_edge_counts = torch.zeros_like(all_nodes[:, 0]).view(-1)
        avg_affected_x = torch.zeros_like(all_nodes)

        # The neighbour weights are (n, K'); a per-Gaussian weight must be a
        # column to broadcast per row. The old default ones_like(x[:, 0]) was
        # (n,), which only broadcast when n == K' and then along the wrong axis.
        if weights is None:
            weights = torch.ones_like(x[:, :1])
        elif weights.dim() == 1:
            weights = weights[:, None]

        n_neighbours = self.gs2cp_nn_weight.shape[1]
        edge_index = self.gs2cp_nn_idx.view(-1)
        edge_weight = self.gs2cp_nn_weight * weights

        node_importance.index_add_(dim=0, index=edge_index, source=edge_weight.view(-1))
        node_edge_counts.index_add_(
            dim=0, index=edge_index, source=torch.ones_like(self.gs2cp_nn_weight).view(-1)
        )
        # Each Gaussian position is repeated once per linked node, row-aligned
        # with the flattened edge weights.
        avg_affected_x.index_add_(
            dim=0,
            index=edge_index,
            source=edge_weight.view(-1, 1) * x.repeat_interleave(n_neighbours, dim=0),
        )

        # Combine per-rank partial sums.
        for buffer in (node_importance, node_edge_counts, avg_affected_x):
            dist.all_reduce(buffer, op=dist.ReduceOp.SUM, group=utils.DP_GROUP)

        # Weighted mean; nodes with a zero weight sum become NaN/inf here
        # (densify_and_prune_v2 filters NaN rows).
        avg_affected_x = avg_affected_x / node_importance[:, None]
        node_importance = node_importance / (node_edge_counts + 1e-7)

        return node_importance, avg_affected_x, node_edge_counts

    def add_densification_stats(
            self, viewspace_point_tensor, update_filter
    ):
        """Accumulate the 2-D gradient norm of the filtered points and bump their counters.

        Only the first two columns of ``viewspace_point_tensor.grad`` are used.
        That tensor carries more columns (the original note calls it an
        ``(N, 3)`` tensor), hence the ``[:, :2]`` slice.
        """
        planar_grad = viewspace_point_tensor.grad[update_filter, :2]
        grad_norm = planar_grad.norm(dim=-1, keepdim=True)
        self.xyz_gradient_accum[update_filter] += grad_norm
        self.denom[update_filter] += 1

    def gsplat_add_densification_stats(
            self, viewspace_point_tensor_grad, update_filter, width, height
    ):
        """Accumulate rescaled 2-D gradient norms for the filtered points.

        The x/y gradient components are multiplied by half the image width and
        height respectively (the rescaling gsplat's own trainer applies) before
        taking the norm.

        The previous version multiplied ``viewspace_point_tensor_grad`` in place,
        on every row. That silently corrupted the caller's tensor for any later
        use. The rescaling is now applied to a copy of the selected rows only;
        the accumulated values are identical.

        Args:
            viewspace_point_tensor_grad: ``(N, >=2)`` gradient tensor; left unmodified.
            update_filter: Mask or index selecting the points to update.
            width: Image width in pixels.
            height: Image height in pixels.
        """
        # Advanced indexing returns a copy, and the product below is a new tensor,
        # so the caller's gradient is never written to.
        selected = viewspace_point_tensor_grad[update_filter, :2]
        half_extent = selected.new_tensor([width * 0.5, height * 0.5])
        self.xyz_gradient_accum[update_filter] += torch.norm(
            selected * half_extent, dim=-1, keepdim=True
        )
        self.denom[update_filter] += 1

    def group_for_redistribution(self):
        """Return the communication group used to move Gaussians between ranks.

        This is ``utils.DEFAULT_GROUP`` when ``gaussians_distribution`` is
        enabled. Otherwise it is a ``utils.SingleGPUGroup()``.
        """
        distributed = utils.get_args().gaussians_distribution
        return utils.DEFAULT_GROUP if distributed else utils.SingleGPUGroup()

    def all2all_gaussian_state(self, state, destination, i2j_send_size):
        """Exchange rows of ``state`` between ranks according to ``destination``.

        Args:
            state: Tensor whose dim 0 indexes local Gaussians, ``(N, ...)``.
            destination: ``(N,)`` tensor with the target rank of each row.
            i2j_send_size: Nested list; ``i2j_send_size[i][j]`` is the number of
                rows rank ``i`` sends to rank ``j``.

        Returns:
            Contiguous tensor of the rows received by this rank, ordered by
            source rank.

        Receive buffers now take ``state``'s dtype and device. They were
        previously hard-coded to float32 on "cuda", which breaks the exchange
        for any other dtype.
        """
        comm_group = self.group_for_redistribution()
        my_rank = comm_group.rank()
        world_size = comm_group.size()

        # Rows destined for rank j, in their local order.
        send_chunks = [
            state[destination == j, ...].contiguous() for j in range(world_size)
        ]
        # Sized by what each rank j announced it sends to this rank.
        recv_chunks = [
            state.new_zeros((i2j_send_size[j][my_rank], *state.shape[1:]))
            for j in range(world_size)
        ]

        torch.distributed.all_to_all(recv_chunks, send_chunks, group=comm_group)

        return torch.cat(recv_chunks, dim=0).contiguous()

    def all2all_tensors_in_optimizer_implementation_1(self, destination, i2j_send_size):
        """Redistribute every optimizer parameter and its moment buffers, one tensor at a time.

        For each single-parameter group, the optimizer state is exchanged first
        (``exp_avg`` then ``exp_avg_sq`` for Adam-style states, otherwise
        ``momentum_buffer``), then the parameter itself. The parameter is
        rewrapped as a new ``nn.Parameter`` and the state is re-keyed to it.

        Returns:
            Dict mapping each param group's ``"name"`` to its new parameter.
        """
        def exchange(tensor):
            return self.all2all_gaussian_state(tensor, destination, i2j_send_size)

        result = {}
        for param_group in self.optimizer.param_groups:
            assert len(param_group["params"]) == 1
            old_param = param_group["params"][0]
            state = self.optimizer.state.get(old_param, None)

            if state is None:
                param_group["params"][0] = nn.Parameter(exchange(old_param), requires_grad=True)
            else:
                moment_keys = (
                    ("exp_avg", "exp_avg_sq") if "exp_avg" in state else ("momentum_buffer",)
                )
                for key in moment_keys:
                    state[key] = exchange(state[key])

                # The state dict is keyed by the parameter object, so move it over.
                del self.optimizer.state[old_param]
                param_group["params"][0] = nn.Parameter(exchange(old_param), requires_grad=True)
                self.optimizer.state[param_group["params"][0]] = state

            result[param_group["name"]] = param_group["params"][0]
        return result

    def get_all_optimizer_states(self):
        """Collect every optimizer-managed tensor in a fixed, replayable order.

        For each single-parameter group the order is: ``exp_avg`` and
        ``exp_avg_sq`` (Adam-style state) or ``momentum_buffer`` (otherwise),
        then the parameter. Groups without any state contribute only the
        parameter.

        Returns:
            ``(tensors, shapes)`` — the tensors and their ``.shape``s.
        """
        tensors = []
        for param_group in self.optimizer.param_groups:
            assert len(param_group["params"]) == 1
            param = param_group["params"][0]
            state = self.optimizer.state.get(param, None)
            if state is not None:
                moment_keys = (
                    ("exp_avg", "exp_avg_sq") if "exp_avg" in state else ("momentum_buffer",)
                )
                tensors.extend(state[key] for key in moment_keys)
            tensors.append(param)

        shapes = [tensor.shape for tensor in tensors]
        return tensors, shapes

    def update_all_optimizer_states(self, updated_tensors):
        """Install replacement tensors, consuming ``updated_tensors`` front-to-back.

        ``updated_tensors`` must be in the order produced by
        ``get_all_optimizer_states``. Entries are popped from the front of the
        caller's list. Each parameter is rewrapped as a new ``nn.Parameter``,
        and its optimizer state is re-keyed to the new object.

        Returns:
            Dict mapping each param group's ``"name"`` to its new parameter.
        """
        def take_next():
            return updated_tensors.pop(0).contiguous()

        result = {}
        for param_group in self.optimizer.param_groups:
            assert len(param_group["params"]) == 1
            old_param = param_group["params"][0]
            state = self.optimizer.state.get(old_param, None)

            if state is None:
                param_group["params"][0] = nn.Parameter(take_next(), requires_grad=True)
            else:
                moment_keys = (
                    ("exp_avg", "exp_avg_sq") if "exp_avg" in state else ("momentum_buffer",)
                )
                for key in moment_keys:
                    state[key] = take_next()

                del self.optimizer.state[old_param]
                param_group["params"][0] = nn.Parameter(take_next(), requires_grad=True)
                self.optimizer.state[param_group["params"][0]] = state

            result[param_group["name"]] = param_group["params"][0]
        return result

    def all2all_tensors_in_optimizer_implementation_2(self, destination, i2j_send_size):
        """Redistribute all optimizer tensors with one fused all2all.

        Every tensor from ``get_all_optimizer_states`` is flattened past dim 0 and
        concatenated column-wise into a single matrix. That matrix is exchanged
        once, split back into the original per-row shapes, and installed with
        ``update_all_optimizer_states``.

        Returns:
            Dict mapping each param group's ``"name"`` to its new parameter.
        """
        tensors, shapes = self.get_all_optimizer_states()

        # One (N, total_columns) matrix -> a single all2all launch.
        packed = torch.cat(
            [tensor.flatten(start_dim=1) for tensor in tensors], dim=1
        ).contiguous()

        received = self.all2all_gaussian_state(packed, destination, i2j_send_size)
        del packed  # release memory before unpacking

        column_widths = [shape[1:].numel() for shape in shapes]
        pieces = torch.split(received, column_widths, dim=1)
        del received
        restored = [
            piece.view(piece.shape[:1] + shape[1:])
            for piece, shape in zip(pieces, shapes)
        ]
        del pieces

        optimizable_tensors = self.update_all_optimizer_states(restored)
        del restored

        return optimizable_tensors

    def all2all_tensors_in_optimizer(self, destination, i2j_send_size):
        """Redistribute all optimizer-managed tensors; delegates to implementation 1.

        Implementation 2 (one fused all2all) is kept but not used. According to
        the original note, cross-node all2all on perl got stuck at around 1600
        iterations with it, for an unknown reason.
        """
        per_tensor_exchange = self.all2all_tensors_in_optimizer_implementation_1
        return per_tensor_exchange(destination, i2j_send_size)

    def get_destination_1(self, world_size):
        """Draw a uniformly random destination rank in ``[0, world_size)`` for every local point."""
        num_points = self.get_xyz.shape[0]
        return torch.randint(low=0, high=world_size, size=(num_points,), device="cuda")

    def need_redistribute_gaussians(self, group):
        """Decide whether Gaussians should be rebalanced across ``group``.

        Returns False for a single-rank group. Returns True when the densify
        iteration counter equals ``redistribute_gaussians_frequency`` (i.e.
        right after the first densification). Otherwise returns True only when
        the largest per-rank point count exceeds the smallest one multiplied by
        ``redistribute_gaussians_threshold``. This last check all-gathers the
        local counts over ``group``.
        """
        args = utils.get_args()
        if group.size() == 1:
            return False
        if utils.get_denfify_iter() == args.redistribute_gaussians_frequency:
            return True

        per_rank_counts = [None] * group.size()
        torch.distributed.all_gather_object(
            per_rank_counts, self.get_xyz.shape[0], group=group
        )
        return max(per_rank_counts) > min(per_rank_counts) * args.redistribute_gaussians_threshold

    def redistribute_gaussians(self):
        """Rebalance local points and their optimizer state across ranks.

        Nothing happens when the mode is ``"no_redistribute"`` or when
        ``need_redistribute_gaussians`` says no. In ``"random_redistribute"``
        mode every point gets a random destination rank; any other mode raises
        ``ValueError``. After the all2all exchange, the attribute tensors are
        rebound and the per-point densification buffers are reset to zeros
        for the new local point count.
        """
        args = utils.get_args()
        if args.redistribute_gaussians_mode == "no_redistribute":
            return

        group = self.group_for_redistribution()
        if not self.need_redistribute_gaussians(group):
            return

        world_size = group.size()
        if args.redistribute_gaussians_mode != "random_redistribute":
            raise ValueError(
                "Invalid redistribute_gaussians_mode: "
                + args.redistribute_gaussians_mode
            )
        # Random assignment balances the expected count per rank.
        destination = self.get_destination_1(world_size)

        # How many local points go to each rank.
        local_send_counts = torch.bincount(destination, minlength=world_size).int()
        assert (
                len(local_send_counts) == world_size
        ), "local2j_send_size: " + str(local_send_counts)

        # Row i of the gathered matrix holds rank i's per-destination send counts.
        send_matrix = torch.zeros(
            (world_size, world_size), dtype=torch.int, device="cuda"
        )
        torch.distributed.all_gather_into_tensor(send_matrix, local_send_counts, group=group)
        i2j_send_size = send_matrix.cpu().numpy().tolist()

        tensors = self.all2all_tensors_in_optimizer(destination, i2j_send_size)
        self._xyz = tensors["xyz"]
        if self.vis_control_gs:
            self._features_dc = tensors["f_dc"]
            self._features_rest = tensors["f_rest"]
            self._opacity = tensors["opacity"]
            self._scaling = tensors["scaling"]
            self._rotation = tensors["rotation"]
            self._semantic_feature = tensors["sem_f"]
        self.entity_ids = tensors["entity_f"]
        self._node_radius = tensors["node_radius"]
        self._node_weight = tensors["node_weight"]

        # Per-point statistics are reset for the new local point set.
        n_points = self.get_xyz.shape[0]
        self.xyz_gradient_accum = torch.zeros((n_points, 1), device="cuda")
        self.denom = torch.zeros((n_points, 1), device="cuda")
        self.max_radii2D = torch.zeros((n_points), device="cuda")
        self.sum_visible_count_in_one_batch = torch.zeros((n_points), device="cuda")
        self.send_to_gpui_cnt = torch.zeros(
            (n_points, world_size), dtype=torch.int, device="cuda"
        )

        torch.cuda.empty_cache()
