#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#
import math
import os
import random
import sys
from dataclasses import dataclass, field
from datetime import datetime
from typing import NamedTuple

import numpy as np
import threestudio
import torch
import torch.nn as nn
import torch.nn.functional as F
from plyfile import PlyData, PlyElement
from simple_knn._C import distCUDA2
from threestudio.models.geometry.base import BaseGeometry
from threestudio.utils.misc import C
from threestudio.utils.typing import *

from .gaussian_io import GaussianIO

# Zeroth-order (l = 0) real spherical-harmonic basis constant, 1 / (2 * sqrt(pi)).
# RGB2SH / SH2RGB use it to convert between colors and DC SH coefficients.
C0 = 0.28209479177387814


def RGB2SH(rgb):
    """Convert color values to the DC spherical-harmonic coefficient.

    Inverse of ``SH2RGB``: a color of 0.5 maps to a coefficient of 0.
    Works on anything supporting ``-`` and ``/`` (tensors or numpy arrays).
    """
    centered = rgb - 0.5
    return centered / C0


def SH2RGB(sh):
    """Convert a DC spherical-harmonic coefficient back to a color value.

    Inverse of ``RGB2SH``: a coefficient of 0 maps to 0.5.
    """
    scaled = sh * C0
    return scaled + 0.5


def inverse_sigmoid(x):
    """Logit function, the inverse of ``torch.sigmoid``: log(x / (1 - x))."""
    odds = x / (1 - x)
    return torch.log(odds)


def strip_lowerdiag(L):
    """Pack the upper triangle (including the diagonal) of a batch of 3x3 matrices.

    Despite the name, the entries read are L[:, i, j] with i <= j; for the
    symmetric matrices this is used on, that fully determines the matrix.

    Args:
        L: tensor indexed as ``L[:, i, j]`` with i, j in {0, 1, 2}
            (presumably shape (N, 3, 3)).

    Returns:
        float32 tensor of shape (N, 6) holding
        [L00, L01, L02, L11, L12, L22], on the same device as ``L``
        (previously always allocated on "cuda").
    """
    rows = [0, 0, 0, 1, 1, 2]
    cols = [0, 1, 2, 1, 2, 2]
    # Advanced indexing gathers all six entries in one op and stays on L's device.
    return L[:, rows, cols].to(torch.float)


def strip_symmetric(sym):
    """Pack a batch of symmetric 3x3 matrices into their 6 unique entries.

    Delegates to ``strip_lowerdiag``, which reads the upper triangle plus the
    diagonal; for a symmetric matrix that is all the information there is.
    """
    packed = strip_lowerdiag(sym)
    return packed


def gaussian_3d_coeff(xyzs, covs):
    """Evaluate an unnormalised 3D Gaussian at per-row offsets.

    Args:
        xyzs: [N, 3] offsets (only columns 0-2 are read).
        covs: [N, 6] packed symmetric covariances (a, b, c, d, e, f) standing for
            the matrix [[a, b, c], [b, d, e], [c, e, f]].

    Returns:
        [N] tensor exp(-0.5 * v^T Sigma^-1 v). Rows whose exponent comes out
        positive (abnormal, e.g. a covariance that is not positive definite)
        get weight 0.
    """
    x, y, z = (xyzs[:, k] for k in range(3))
    a, b, c, d, e, f = (covs[:, k] for k in range(6))

    # Closed-form inverse via cofactors. The epsilon must stay tiny, otherwise
    # it distorts legitimately small determinants.
    det = a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24
    inv_det = 1 / det
    p00 = (d * f - e**2) * inv_det
    p01 = (e * c - b * f) * inv_det
    p02 = (e * b - c * d) * inv_det
    p11 = (a * f - c**2) * inv_det
    p12 = (b * c - e * a) * inv_det
    p22 = (a * d - b**2) * inv_det

    diagonal_terms = x**2 * p00 + y**2 * p11 + z**2 * p22
    power = -0.5 * diagonal_terms - x * y * p01 - x * z * p02 - y * z * p12

    # A positive exponent signals an abnormal covariance; force the weight to ~0.
    power[power > 0] = -1e10
    return torch.exp(power)


def build_rotation(r):
    """Convert a batch of quaternions into 3x3 rotation matrices.

    Args:
        r: tensor indexed as ``r[:, k]`` for k in 0..3 and read as (w, x, y, z),
            with the real part first. It need not be unit length; it is
            normalized here.

    Returns:
        (N, 3, 3) rotation matrices on the same device as ``r`` (previously
        always allocated on "cuda"), in the default float dtype.
    """
    norm = torch.sqrt(
        r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]
    )
    q = r / norm[:, None]

    R = torch.zeros((q.size(0), 3, 3), device=q.device)

    w = q[:, 0]
    x = q[:, 1]
    y = q[:, 2]
    z = q[:, 3]

    R[:, 0, 0] = 1 - 2 * (y * y + z * z)
    R[:, 0, 1] = 2 * (x * y - w * z)
    R[:, 0, 2] = 2 * (x * z + w * y)
    R[:, 1, 0] = 2 * (x * y + w * z)
    R[:, 1, 1] = 1 - 2 * (x * x + z * z)
    R[:, 1, 2] = 2 * (y * z - w * x)
    R[:, 2, 0] = 2 * (x * z - w * y)
    R[:, 2, 1] = 2 * (y * z + w * x)
    R[:, 2, 2] = 1 - 2 * (x * x + y * y)
    return R


def build_scaling_rotation(s, r):
    """Return R @ diag(s) for a batch of scales and quaternions.

    Args:
        s: per-row scales; columns 0-2 are used (presumably shape (N, 3)).
        r: quaternions accepted by ``build_rotation``.

    Returns:
        (N, 3, 3) float32 matrices. The diagonal scale matrix is allocated on
        ``s``'s device instead of a hard-coded "cuda".
    """
    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=s.device)
    R = build_rotation(r)

    L[:, 0, 0] = s[:, 0]
    L[:, 1, 1] = s[:, 1]
    L[:, 2, 2] = s[:, 2]

    L = R @ L
    return L


def safe_state(silent):
    """Wrap stdout with a timestamping writer and seed all RNGs with 0.

    When ``silent`` is true, everything written to the replacement stdout is
    dropped. Otherwise every newline in a chunk that ends with a newline is
    suffixed with " [dd/mm HH:MM:SS]". Also seeds ``random``, numpy and torch,
    and selects CUDA device 0.
    """
    original_stdout = sys.stdout

    class _TimestampedStdout:
        def __init__(self, silent):
            self.silent = silent

        def write(self, text):
            if self.silent:
                return
            if not text.endswith("\n"):
                original_stdout.write(text)
                return
            stamp = str(datetime.now().strftime("%d/%m %H:%M:%S"))
            original_stdout.write(text.replace("\n", " [{}]\n".format(stamp)))

        def flush(self):
            original_stdout.flush()

    sys.stdout = _TimestampedStdout(silent)

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(0)
    torch.cuda.set_device(torch.device("cuda:0"))


class BasicPointCloud(NamedTuple):
    """Point cloud used to seed the Gaussians.

    In this module it is built from numpy arrays of shape (N, 3) and consumed by
    ``create_from_pcd``, which passes each field through ``np.asarray``.
    Fields are annotated with the array *type* ``np.ndarray`` (``np.array`` is
    a factory function, not a type).
    """

    points: np.ndarray
    colors: np.ndarray
    normals: np.ndarray


Camera = NamedTuple(
    "Camera",
    [
        ("FoVx", torch.Tensor),
        ("FoVy", torch.Tensor),
        ("camera_center", torch.Tensor),
        ("image_width", int),
        ("image_height", int),
        ("world_view_transform", torch.Tensor),
        ("full_proj_transform", torch.Tensor),
    ],
)
Camera.__doc__ = (
    "Immutable bundle of per-view camera data: fields of view, camera center, "
    "image size and the view / full projection transforms."
)


@threestudio.register("gaussian-splatting-debug")
class GaussianBaseModel(BaseGeometry, GaussianIO):
    @dataclass
    class Config(BaseGeometry.Config):
        # Hyper-parameters for this Gaussian geometry. Fields that are not read in
        # the code visible here are marked "presumably" — they are likely consumed
        # by the densification / pruning driver elsewhere; TODO confirm.
        max_num: int = 500000  # presumably a cap on the number of Gaussians
        sh_degree: int = 0  # maximum SH degree -> self.max_sh_degree
        # Learning rates are evaluated with threestudio's C(value, 0, step, ...),
        # so each may be a constant or a schedule that C() accepts.
        position_lr: Any = 0.001  # "xyz" group
        scale_lr: Any = 0.003  # NOTE(review): not read here; the "scaling" group uses scaling_lr
        feature_lr: Any = 0.01  # "f_dc" group; "f_rest" uses feature_lr / 20
        opacity_lr: Any = 0.05  # "opacity" group
        scaling_lr: Any = 0.005  # "scaling" group
        rotation_lr: Any = 0.005  # "rotation" group
        pred_normal: bool = False  # add a learnable per-point _normal parameter
        normal_lr: Any = 0.001  # "normal" group (only used when pred_normal)

        # Densification / pruning schedule: presumably iteration counts and
        # thresholds read by the training loop.
        densification_interval: int = 50
        prune_interval: int = 50
        opacity_reset_interval: int = 100000
        densify_from_iter: int = 100
        prune_from_iter: int = 100
        densify_until_iter: int = 2000
        prune_until_iter: int = 2000
        densify_grad_threshold: Any = 0.01
        min_opac_prune: Any = 0.005
        split_thresh: Any = 0.01
        radii2d_thresh: Any = 1000

        sphere: bool = False  # average the three scales so every Gaussian is isotropic
        prune_big_points: bool = False  # presumably enables pruning of oversized Gaussians
        color_clip: Any = 2.0  # bound on |DC SH feature| in get_features; may be scheduled

        # Initialisation source and parameters (see configure).
        geometry_convert_from: str = ""  # "shap-e:<prompt>", "lrm:<prompt>", a .ckpt/.ply path, or "" for random
        load_ply_only_vertex: bool = False  # .ply: read only x/y/z (+ optional red/green/blue)
        init_num_pts: int = 100  # random init: number of points
        pc_init_radius: float = 0.8  # random init: radius of the sampling ball
        opacity_init: float = 0.1  # initial opacity used by create_from_pcd

        shap_e_guidance_config: dict = field(default_factory=dict)  # also passed to "lrm-guidance"

    cfg: Config  # type annotation for self.cfg, read throughout this class

    def setup_functions(self):
        """Install the activation / inverse-activation callables used by the getters."""

        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
            # Sigma = (R S)(R S)^T, packed down to its 6 unique entries.
            rs = build_scaling_rotation(scaling_modifier * scaling, rotation)
            return strip_symmetric(rs @ rs.transpose(1, 2))

        # Scales are stored in log space, opacities as logits.
        self.scaling_activation, self.scaling_inverse_activation = torch.exp, torch.log
        self.covariance_activation = build_covariance_from_scaling_rotation
        self.opacity_activation, self.inverse_opacity_activation = (
            torch.sigmoid,
            inverse_sigmoid,
        )
        self.rotation_activation = torch.nn.functional.normalize

    def configure(self) -> None:
        """Create empty Gaussian state and optionally initialise it.

        The initial point cloud is chosen by ``cfg.geometry_convert_from``:
          * ``"shap-e:<prompt>"``: points/colors from the "shap-e-guidance"
            module, then ``create_from_pcd`` and ``training_setup``.
          * ``"lrm:<prompt>"``: points/colors from the "lrm-guidance" module
            (configured with ``cfg.shap_e_guidance_config``). NOTE(review): the
            ``create_from_pcd`` / ``training_setup`` calls are commented out, so
            the point cloud is built and then discarded — confirm intended.
          * an existing ``*.ckpt`` path: allocate as many points as the
            checkpoint's ``geometry._xyz`` and copy every matching
            ``geometry.<key>`` entry into this module's state dict.
          * an existing ``*.ply`` path: positions (+ optional RGB) only when
            ``cfg.load_ply_only_vertex``, otherwise ``self.load_ply``.
            An existing path with any other extension initialises nothing.
          * otherwise: random points uniform in a ball of radius
            ``cfg.pc_init_radius`` with near-grey colors. NOTE(review): the
            initialisation calls are commented out here too, so this branch
            creates no Gaussians and no optimizer — confirm intended.
        """
        super().configure()
        # Empty placeholders; replaced once a point cloud is loaded
        # (e.g. by create_from_pcd).
        self.active_sh_degree = 0
        self.max_sh_degree = self.cfg.sh_degree
        self._xyz = torch.empty(0)
        self._features_dc = torch.empty(0)
        self._features_rest = torch.empty(0)
        self._scaling = torch.empty(0)
        self._rotation = torch.empty(0)
        self._opacity = torch.empty(0)
        self.max_radii2D = torch.empty(0)
        self.xyz_gradient_accum = torch.empty(0)
        self.denom = torch.empty(0)

        if self.cfg.pred_normal:
            self._normal = torch.empty(0)
        self.optimizer = None
        self.setup_functions()

        if self.cfg.geometry_convert_from.startswith("shap-e:"):
            shap_e_guidance = threestudio.find("shap-e-guidance")(
                self.cfg.shap_e_guidance_config
            )
            prompt = self.cfg.geometry_convert_from[len("shap-e:") :]
            xyz, color = shap_e_guidance(prompt)

            pcd = BasicPointCloud(
                points=xyz, colors=color, normals=np.zeros((xyz.shape[0], 3))
            )
            self.create_from_pcd(pcd, 10)
            self.training_setup()

        # Support Initialization from OpenLRM, Please see https://github.com/Adamdad/threestudio-lrm
        elif self.cfg.geometry_convert_from.startswith("lrm:"):
            lrm_guidance = threestudio.find("lrm-guidance")(
                self.cfg.shap_e_guidance_config
            )
            prompt = self.cfg.geometry_convert_from[len("lrm:") :]
            xyz, color = lrm_guidance(prompt)

            pcd = BasicPointCloud(
                points=xyz, colors=color, normals=np.zeros((xyz.shape[0], 3))
            )
            # NOTE(review): initialisation disabled, so `pcd` above is unused.
            # self.create_from_pcd(pcd, 10)
            # self.training_setup()

        elif os.path.exists(self.cfg.geometry_convert_from):
            threestudio.info(
                "Loading point cloud from %s" % self.cfg.geometry_convert_from
            )
            if self.cfg.geometry_convert_from.endswith(".ckpt"):
                # NOTE: torch.load unpickles the file; only load trusted checkpoints.
                ckpt_dict = torch.load(self.cfg.geometry_convert_from)
                num_pts = ckpt_dict["state_dict"]["geometry._xyz"].shape[0]
                # Placeholder cloud of the right size so every parameter has the
                # shape the checkpoint expects before load_state_dict.
                pcd = BasicPointCloud(
                    points=np.zeros((num_pts, 3)),
                    colors=np.zeros((num_pts, 3)),
                    normals=np.zeros((num_pts, 3)),
                )
                self.create_from_pcd(pcd, 10)
                self.training_setup()
                new_ckpt_dict = {}
                for key in self.state_dict():
                    if ckpt_dict["state_dict"].__contains__("geometry." + key):
                        new_ckpt_dict[key] = ckpt_dict["state_dict"]["geometry." + key]
                    else:
                        # Keep the current value for keys the checkpoint lacks.
                        new_ckpt_dict[key] = self.state_dict()[key]
                self.load_state_dict(new_ckpt_dict)
            elif self.cfg.geometry_convert_from.endswith(".ply"):
                if self.cfg.load_ply_only_vertex:
                    plydata = PlyData.read(self.cfg.geometry_convert_from)
                    vertices = plydata["vertex"]
                    positions = np.vstack(
                        [vertices["x"], vertices["y"], vertices["z"]]
                    ).T
                    if vertices.__contains__("red"):
                        # Assumes 8-bit (0-255) vertex colors; rescaled to [0, 1].
                        colors = (
                            np.vstack(
                                [vertices["red"], vertices["green"], vertices["blue"]]
                            ).T
                            / 255.0
                        )
                    else:
                        # No colors in the file: near-grey colors, same formula
                        # as the random branch. This local C0 equals the
                        # module-level constant.
                        shs = np.random.random((positions.shape[0], 3)) / 255.0
                        C0 = 0.28209479177387814
                        colors = shs * C0 + 0.5
                    normals = np.zeros_like(positions)
                    pcd = BasicPointCloud(
                        points=positions, colors=colors, normals=normals
                    )
                    self.create_from_pcd(pcd, 10)
                else:
                    self.load_ply(self.cfg.geometry_convert_from)
                self.training_setup()
        else:
            threestudio.info("Geometry not found, initilization with random points")
            num_pts = self.cfg.init_num_pts
            # Uniform sampling inside a ball: uniform azimuth, uniform cos(theta),
            # and radius = R * cbrt(u) so density is uniform in volume.
            phis = np.random.random((num_pts,)) * 2 * np.pi
            costheta = np.random.random((num_pts,)) * 2 - 1
            thetas = np.arccos(costheta)
            mu = np.random.random((num_pts,))
            radius = self.cfg.pc_init_radius * np.cbrt(mu)
            x = radius * np.sin(thetas) * np.cos(phis)
            y = radius * np.sin(thetas) * np.sin(phis)
            z = radius * np.cos(thetas)
            xyz = np.stack((x, y, z), axis=1)

            # Colors in [0.5, 0.5 + C0 / 255): essentially uniform grey.
            shs = np.random.random((num_pts, 3)) / 255.0
            C0 = 0.28209479177387814
            color = shs * C0 + 0.5
            pcd = BasicPointCloud(
                points=xyz, colors=color, normals=np.zeros((num_pts, 3))
            )

            # NOTE(review): initialisation disabled, so `pcd` above is unused.
            # self.create_from_pcd(pcd, 10)
            # self.training_setup()

    @property
    def get_scaling(self):
        """Activated per-Gaussian scales, floored at 1e-4.

        With ``cfg.sphere`` the three raw scales are averaged first, so every
        Gaussian becomes isotropic.
        """
        raw_scales = self._scaling
        if self.cfg.sphere:
            raw_scales = raw_scales.mean(dim=-1, keepdim=True).repeat(1, 3)
        # Floor that keeps Gaussians from shrinking to near-zero size. 1e-4 is an
        # empirical value tied to the scene scale; try 1e-3 if renders still look
        # grainy. NOTE: clamp passes zero gradient to entries below the floor.
        return self.scaling_activation(raw_scales).clamp(min=1e-4)

    @property
    def get_rotation(self):
        """Rotation quaternions after ``rotation_activation`` (normalization by default)."""
        activate = self.rotation_activation
        return activate(self._rotation)

    @property
    def get_xyz(self):
        """Gaussian centers: the raw ``_xyz`` parameter, with no activation applied."""
        centers = self._xyz
        return centers

    @property
    def get_features(self):
        """SH color features: the clipped DC term concatenated with the rest along dim 1.

        The DC term is clipped to [-color_clip, color_clip]. ``self.color_clip``
        is refreshed by ``update_learning_rate``. Before the first such call it
        is not set, so fall back to ``cfg.color_clip`` evaluated at step 0
        instead of raising AttributeError (e.g. when rendering before training).
        """
        color_clip = getattr(self, "color_clip", None)
        if color_clip is None:
            color_clip = C(self.cfg.color_clip, 0, 0)
        features_dc = self._features_dc.clip(-color_clip, color_clip)
        return torch.cat((features_dc, self._features_rest), dim=1)

    @property
    def get_opacity(self):
        """Per-Gaussian opacity after ``opacity_activation`` (sigmoid by default)."""
        activate = self.opacity_activation
        return activate(self._opacity)

    @property
    def get_normal(self):
        """Predicted per-point normals (raw ``_normal`` parameter).

        Raises:
            ValueError: if ``cfg.pred_normal`` is disabled.
        """
        if not self.cfg.pred_normal:
            raise ValueError("Normal is not predicted")
        return self._normal

    def get_covariance(self, scaling_modifier=1):
        """Packed 6-entry covariances from the activated scales and raw rotations.

        ``scaling_modifier`` multiplies the scales before the covariance is
        built. The raw ``_rotation`` is fine here because ``build_rotation``
        normalizes the quaternions itself.
        """
        scales = self.get_scaling
        return self.covariance_activation(scales, scaling_modifier, self._rotation)

    def oneupSHdegree(self):
        """Raise the active SH degree by one, never past ``max_sh_degree``."""
        print(self.max_sh_degree)  # debug output of this debug geometry
        if self.active_sh_degree >= self.max_sh_degree:
            return
        self.active_sh_degree += 1
        threestudio.info(f"Upgrading SH degree to {self.active_sh_degree}")

    def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float):
        """Initialise every Gaussian parameter from a point cloud, one Gaussian per point.

        * position = the point;
        * DC SH coefficient = RGB2SH(color); higher-order coefficients = 0;
        * isotropic raw scale = log(sqrt(dist2)), where dist2 comes from
          ``distCUDA2`` (presumably a per-point mean squared neighbour distance
          from simple_knn — TODO confirm), floored at 1e-7;
        * identity quaternion (1, 0, 0, 0);
        * opacity = ``cfg.opacity_init``, stored through inverse_sigmoid.

        CPU copies of the initial tensors are kept on ``self``, and
        ``self.cameras_extent`` is set to the 95th percentile of the points'
        distance from the origin.

        Args:
            pcd: point cloud whose ``points`` / ``colors`` are array-likes of shape (N, 3).
            spatial_lr_scale: stored as ``self.spatial_lr_scale``.
        """
        self.spatial_lr_scale = spatial_lr_scale
        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
        num_points = fused_point_cloud.shape[0]

        # Allocate directly on the target device. Only the DC coefficient
        # (index 0 of the last axis) is seeded; the higher orders stay zero.
        features = torch.zeros(
            (fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2),
            dtype=torch.float,
            device=fused_color.device,
        )
        features[:, :3, 0] = fused_color

        threestudio.info(f"Number of points at initialisation:{num_points}")

        # Reuse the tensor already on the GPU instead of converting pcd.points again.
        dist2 = torch.clamp_min(distCUDA2(fused_point_cloud), 0.0000001)
        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
        rots = torch.zeros((num_points, 4), device="cuda")
        rots[:, 0] = 1  # identity rotation, real part first

        opacities = inverse_sigmoid(
            self.cfg.opacity_init
            * torch.ones((num_points, 1), dtype=torch.float, device="cuda")
        )

        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
        self._features_dc = nn.Parameter(
            features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True)
        )
        self._features_rest = nn.Parameter(
            features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)
        )
        self._scaling = nn.Parameter(scales.requires_grad_(True))
        self._rotation = nn.Parameter(rots.requires_grad_(True))
        self._opacity = nn.Parameter(opacities.requires_grad_(True))
        if self.cfg.pred_normal:
            normals = torch.zeros((num_points, 3), device="cuda")
            self._normal = nn.Parameter(normals.requires_grad_(True))
        self.max_radii2D = torch.zeros((self._xyz.shape[0]), device="cuda")

        # CPU snapshots of the initial state.
        self.fused_point_cloud = fused_point_cloud.cpu().clone().detach()
        self.features = features.cpu().clone().detach()
        self.scales = scales.cpu().clone().detach()
        self.rots = rots.cpu().clone().detach()
        self.opacities = opacities.cpu().clone().detach()
        print("max_sh_degree", self.max_sh_degree)  # debug output
        with torch.no_grad():
            # Scene-size proxy used by densify_and_split: 95th percentile of the
            # distance from the origin (e.g. roughly 1.0).
            r = torch.norm(self.get_xyz, dim=1)
            self.cameras_extent = torch.quantile(r, 0.95).item()

    def training_setup(self):
        """Create the gradient accumulators and an Adam optimizer over named param groups.

        Each group's initial lr is its config value evaluated with C(value, 0, 0);
        ``update_learning_rate`` reschedules them later. ``self.optimize_params``
        lists the group names whose tensors grow or shrink with the point count.
        It now includes "normal" when ``cfg.pred_normal``, so the normal parameter
        is pruned and extended together with the others. Previously it was
        skipped, and prune_points / densification_postfix failed with
        KeyError("normal").
        """
        training_args = self.cfg
        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")

        l = [
            {
                "params": [self._xyz],
                "lr": C(training_args.position_lr, 0, 0),
                "name": "xyz",
            },
            {
                "params": [self._features_dc],
                "lr": C(training_args.feature_lr, 0, 0),
                "name": "f_dc",
            },
            {
                "params": [self._features_rest],
                "lr": C(training_args.feature_lr, 0, 0) / 20.0,
                "name": "f_rest",
            },
            {
                "params": [self._opacity],
                "lr": C(training_args.opacity_lr, 0, 0),
                "name": "opacity",
            },
            {
                "params": [self._scaling],
                "lr": C(training_args.scaling_lr, 0, 0),
                "name": "scaling",
            },
            {
                "params": [self._rotation],
                "lr": C(training_args.rotation_lr, 0, 0),
                "name": "rotation",
            },
        ]
        if self.cfg.pred_normal:
            l.append(
                {
                    "params": [self._normal],
                    "lr": C(training_args.normal_lr, 0, 0),
                    "name": "normal",
                },
            )

        self.optimize_params = [
            "xyz",
            "f_dc",
            "f_rest",
            "opacity",
            "scaling",
            "rotation",
        ]
        if self.cfg.pred_normal:
            # Keep the per-point normal in sync with the other per-point tensors
            # during pruning and densification.
            self.optimize_params.append("normal")
        self.optimize_list = l
        self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)

    def merge_optimizer(self, net_optimizer):
        """Rebuild ``self.optimizer`` so it also drives ``net_optimizer``'s parameters.

        Each external param group contributes only its ``params`` and ``lr``. The
        Gaussian groups come from ``self.optimize_list``. The new Adam is stored
        on ``self.optimizer`` and returned.

        NOTE(review): the external groups are appended to ``self.optimize_list``
        in place, so calling this twice re-adds them (duplicate parameters).
        """
        merged_groups = self.optimize_list
        merged_groups.extend(
            {"params": group["params"], "lr": group["lr"]}
            for group in net_optimizer.param_groups
        )
        self.optimizer = torch.optim.Adam(merged_groups, lr=0.0)
        return self.optimizer

    def update_learning_rate(self, iteration):
        """Set each named param group's lr for ``iteration`` and refresh ``self.color_clip``.

        Every lr is C(value, 0, iteration, interpolation="exp"); "f_rest" uses
        the feature schedule divided by 20. Unnamed groups (e.g. ones added by
        ``merge_optimizer``) and unknown names are left untouched.
        """
        cfg = self.cfg
        lr_sources = {
            "xyz": cfg.position_lr,
            "scaling": cfg.scaling_lr,
            "f_dc": cfg.feature_lr,
            "f_rest": cfg.feature_lr,
            "opacity": cfg.opacity_lr,
            "rotation": cfg.rotation_lr,
            "normal": cfg.normal_lr,
        }
        for group in self.optimizer.param_groups:
            group_name = group.get("name")
            if group_name not in lr_sources:
                continue
            lr = C(lr_sources[group_name], 0, iteration, interpolation="exp")
            if group_name == "f_rest":
                lr = lr / 20.0
            group["lr"] = lr
        self.color_clip = C(cfg.color_clip, 0, iteration)

    def reset_opacity(self):
        """Cap every opacity at 0.01 and install the result as a fresh optimizer parameter.

        ``replace_tensor_to_optimizer`` also zeroes the Adam moments for it.
        """
        capped = torch.clamp(self.get_opacity, max=0.01)
        new_logits = inverse_sigmoid(capped)
        # Alternative tried earlier: inverse_sigmoid(self.get_opacity * 0.9)
        replaced = self.replace_tensor_to_optimizer(new_logits, "opacity")
        self._opacity = replaced["opacity"]

    def to(self, device="cpu"):
        """Move the Gaussian parameter tensors to ``device`` and return ``self``.

        Only ``_xyz``, ``_features_dc``, ``_features_rest``, ``_opacity``,
        ``_scaling``, ``_rotation`` (and ``_normal`` when ``cfg.pred_normal``)
        are moved; gradient accumulators and optimizer state are not.
        Returning ``self`` matches ``nn.Module.to``, so
        ``model = model.to(device)`` no longer yields None.

        NOTE(review): when these attributes are registered nn.Parameters and the
        device actually changes, ``Tensor.to`` returns a plain tensor, which
        ``nn.Module.__setattr__`` may refuse to assign — confirm this path.
        """
        self._xyz = self._xyz.to(device)
        self._features_dc = self._features_dc.to(device)
        self._features_rest = self._features_rest.to(device)
        self._opacity = self._opacity.to(device)
        self._scaling = self._scaling.to(device)
        self._rotation = self._rotation.to(device)
        if self.cfg.pred_normal:
            self._normal = self._normal.to(device)
        return self

    def replace_tensor_to_optimizer(self, tensor, name):
        """Swap the parameter of the group called ``name`` for ``tensor``.

        If the optimizer already holds state for the old parameter, that state
        is moved to the new parameter with its Adam moments (``exp_avg``,
        ``exp_avg_sq``) zeroed. If there is no state yet (no optimizer step
        taken), the parameter is swapped without touching state. The original
        code crashed here with a TypeError on ``None``; ``_prune_optimizer`` and
        ``cat_tensors_to_optimizer`` already handle this case the same way.

        Returns:
            dict mapping ``name`` to the new nn.Parameter (empty if no group matches).
        """
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            if ("name" in group) and group["name"] == name:
                old_param = group["params"][0]
                stored_state = self.optimizer.state.get(old_param, None)
                new_param = nn.Parameter(tensor.requires_grad_(True))
                if stored_state is not None:
                    stored_state["exp_avg"] = torch.zeros_like(tensor)
                    stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
                    del self.optimizer.state[old_param]
                    self.optimizer.state[new_param] = stored_state
                group["params"][0] = new_param
                optimizable_tensors[group["name"]] = new_param
        return optimizable_tensors

    def _prune_optimizer(self, mask):
        """Keep only the rows selected by boolean ``mask`` in every per-point param group.

        Applies to groups whose name is in ``self.optimize_params``. Adam
        moments are sliced with the same mask when state exists.
        Returns a dict {group name: new nn.Parameter}.
        """
        kept = {}
        for group in self.optimizer.param_groups:
            group_name = group.get("name")
            if group_name not in self.optimize_params:
                continue
            old_param = group["params"][0]
            state = self.optimizer.state.get(old_param, None)
            new_param = nn.Parameter((old_param[mask].requires_grad_(True)))
            if state is not None:
                for moment in ("exp_avg", "exp_avg_sq"):
                    state[moment] = state[moment][mask]
                del self.optimizer.state[old_param]
                self.optimizer.state[new_param] = state
            group["params"][0] = new_param
            kept[group_name] = new_param
        return kept

    def prune_points(self, mask):
        """Remove the Gaussians where boolean ``mask`` is True.

        Parameters and their optimizer state are pruned via ``_prune_optimizer``.
        The per-point buffers (gradient accumulators, max 2D radii) are sliced
        the same way.
        """
        keep = ~mask
        pruned = self._prune_optimizer(keep)

        attribute_for_group = {
            "xyz": "_xyz",
            "f_dc": "_features_dc",
            "f_rest": "_features_rest",
            "opacity": "_opacity",
            "scaling": "_scaling",
            "rotation": "_rotation",
        }
        if self.cfg.pred_normal:
            attribute_for_group["normal"] = "_normal"
        for group_name, attribute in attribute_for_group.items():
            setattr(self, attribute, pruned[group_name])

        self.xyz_gradient_accum = self.xyz_gradient_accum[keep]
        self.denom = self.denom[keep]
        self.max_radii2D = self.max_radii2D[keep]


    def cat_tensors_to_optimizer(self, tensors_dict):
        """Append new rows to every per-point param group.

        For each group named in ``self.optimize_params``, ``tensors_dict[name]``
        is concatenated along dim 0. Existing Adam moments are zero-padded to
        the new length. Returns a dict {group name: new nn.Parameter}.
        """
        extended = {}
        for group in self.optimizer.param_groups:
            group_name = group.get("name")
            if group_name not in self.optimize_params:
                continue
            addition = tensors_dict[group_name]
            old_param = group["params"][0]
            state = self.optimizer.state.get(old_param, None)
            new_param = nn.Parameter(
                torch.cat((old_param, addition), dim=0).requires_grad_(True)
            )
            if state is not None:
                for moment in ("exp_avg", "exp_avg_sq"):
                    state[moment] = torch.cat(
                        (state[moment], torch.zeros_like(addition)), dim=0
                    )
                del self.optimizer.state[old_param]
                self.optimizer.state[new_param] = state
            group["params"][0] = new_param
            extended[group_name] = new_param
        return extended

    def densification_postfix(
        self,
        new_xyz,
        new_features_dc,
        new_features_rest,
        new_opacities,
        new_scaling,
        new_rotation,
        new_normal=None,
    ):
        """Append new Gaussians and reset all per-point densification statistics.

        ``new_normal`` is only used when ``cfg.pred_normal``. After the append,
        ``xyz_gradient_accum``, ``denom`` and ``max_radii2D`` are re-created as
        zeros for every point (old and new).
        """
        additions = {
            "xyz": new_xyz,
            "f_dc": new_features_dc,
            "f_rest": new_features_rest,
            "opacity": new_opacities,
            "scaling": new_scaling,
            "rotation": new_rotation,
        }
        attribute_for_group = {
            "xyz": "_xyz",
            "f_dc": "_features_dc",
            "f_rest": "_features_rest",
            "opacity": "_opacity",
            "scaling": "_scaling",
            "rotation": "_rotation",
        }
        if self.cfg.pred_normal:
            additions["normal"] = new_normal
            attribute_for_group["normal"] = "_normal"

        extended = self.cat_tensors_to_optimizer(additions)
        for group_name, attribute in attribute_for_group.items():
            setattr(self, attribute, extended[group_name])

        point_count = self._xyz.shape[0]
        self.xyz_gradient_accum = torch.zeros((point_count, 1), device="cuda")
        self.denom = torch.zeros((point_count, 1), device="cuda")
        self.max_radii2D = torch.zeros(point_count, device="cuda")

        # Debug check that the gradient accumulators were cleared:
        # threestudio.info(f"[Densify Postfix] xyz_gradient_accum sum: {self.xyz_gradient_accum.sum().item():.4f}, denom sum: {self.denom.sum().item():.4f}")
    @torch.no_grad()
    def densify_and_split(self, grads, grad_threshold, N=2,
                          split_scale_quantile: float = 0.90, # unused; kept so callers passing it keep working
                          fallback_relax: bool = True):
        """Split high-gradient, sufficiently large Gaussians into N smaller ones.

        A point is split when its (zero-padded) gradient is >= ``grad_threshold``
        AND its largest activated scale is >= ``size_threshold``. That threshold
        is 0.003, raised to 0.01 * ``self.cameras_extent`` when that attribute
        exists. Each selected point produces N children:
          * positions sampled from N(0, its scales), rotated by its quaternion
            and offset by its center;
          * scales divided by 0.8 * N;
          * all other attributes copied.
        The children are appended, then the parent points are pruned.

        NOTE(review): ``densify_and_clone`` selects by a scale quantile while
        this uses an absolute threshold, so a point can meet both criteria —
        confirm whether double densification is intended.

        Args:
            grads: per-point gradient statistic, (M, 1) or flattenable to (M,),
                with M <= current number of points; missing entries count as 0.
            grad_threshold: minimum gradient for a point to be split.
            N: number of children per split point.
            split_scale_quantile: unused.
            fallback_relax: unused.
        """
        
        # 1. Flatten the gradient statistic to shape (M,).
        if grads.dim() == 2 and grads.shape[1] == 1:
            g1 = grads.squeeze(1)
        else:
            g1 = grads.reshape(-1)
            
        # Points beyond len(grads) (presumably added after grads was computed)
        # get a zero gradient.
        padded_grad = torch.zeros((self._xyz.shape[0],), device=self._xyz.device, dtype=g1.dtype)
        padded_grad[: g1.shape[0]] = g1
        
        # [Key change 1] Gradient pre-filter: keep points with large gradients.
        # Note: grad_threshold comes from the YAML config; the author planned to
        # lower it there.
        grad_mask = padded_grad >= grad_threshold

        # 2. Largest activated scale of every point.
        scale_max = self.get_scaling.max(dim=1).values
        
        # [Key change 2] Quantile removed; an absolute size threshold is used instead.
        # Logic: any point that is not tiny (>= 0.003) and has a large gradient is split.
        # 0.003 is a safe lower bound chosen from logged data (large points were ~0.01).
        size_threshold = 0.003 
        
        # Optionally scale the threshold with the scene extent.
        if hasattr(self, "cameras_extent"):
             size_threshold = max(size_threshold, 0.01 * self.cameras_extent)

        # 3. Split mask: large gradient AND large enough.
        split_mask = grad_mask & (scale_max >= size_threshold)

        # Nothing to split: return early.
        if split_mask.sum() == 0:
            return

        # 4. Perform the split (standard 3DGS logic).
        # Sample N new points around each parent's position.
        stds = self.get_scaling[split_mask].repeat(N, 1)
        means = torch.zeros((stds.size(0), 3), device=self._xyz.device, dtype=stds.dtype)
        samples = torch.normal(mean=means, std=stds)
        
        rots = build_rotation(self._rotation[split_mask]).repeat(N, 1, 1)
        # New positions: rotate the local samples and offset by the parent center.
        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self._xyz[split_mask].repeat(N, 1)
        
        # Shrink the children's scales by 0.8 * N (i.e. 1.6 for N = 2).
        new_scaling = self.scaling_inverse_activation(
            self.get_scaling[split_mask].repeat(N, 1) / (0.8 * N)
        )
        
        # Copy the remaining attributes.
        new_rotation = self._rotation[split_mask].repeat(N, 1)
        new_features_dc = self._features_dc[split_mask].repeat(N, 1, 1)
        new_features_rest = self._features_rest[split_mask].repeat(N, 1, 1)
        new_opacity = self._opacity[split_mask].repeat(N, 1)
        
        new_normal = None
        if self.cfg.pred_normal:
            new_normal = self._normal[split_mask].repeat(N, 1)

        # Register the children with the optimizer (appended at the end).
        self.densification_postfix(
            new_xyz, new_features_dc, new_features_rest,
            new_opacity, new_scaling, new_rotation, new_normal
        )

        # 5. [Important] Remove the parents that were split.
        # Once a point is split into smaller ones, the original large point must
        # go, otherwise overlapping "circles" appear. The children were appended
        # after the existing points, so they get False in the mask.
        prune_filter = torch.cat(
            (split_mask, torch.zeros(N * int(split_mask.sum()), device=self._xyz.device, dtype=torch.bool)),
            dim=0
        )
        self.prune_points(prune_filter)


    @torch.no_grad()
    def densify_and_clone(self, grads, grad_threshold,
                        clone_scale_quantile: float = 0.90):
        """
        Clone high-gradient Gaussians whose size is in the lower part of the selection.

        A point is cloned when both hold:
          * its gradient (``grads`` zero-padded to the current point count) is
            >= ``grad_threshold``;
          * its largest activated scale is <= the ``clone_scale_quantile``
            quantile (via ``self._safe_quantile``) of the largest scales among
            the gradient-selected points. For example, 0.90 keeps roughly the
            smallest 90%.
        Clones are exact copies of the raw parameters. They are appended through
        ``densification_postfix``, which also resets the gradient statistics.

        NOTE: ``densify_and_split`` in this class selects by an absolute size
        threshold, not the complementary quantile, so the clone and split
        selections are NOT guaranteed to be disjoint.

        Args:
            grads: (M, 1) or flattenable-to-(M,) gradient statistic, M <= #points.
            grad_threshold: minimum gradient for densification.
            clone_scale_quantile: quantile bounding which scales are cloned.
        """
        device = self._xyz.device
        n_init_points = self._xyz.shape[0]

        if grads.dim() == 2 and grads.shape[1] == 1:
            g1 = grads.squeeze(1)
        else:
            g1 = grads.reshape(-1)

        padded_grad = torch.zeros((n_init_points,), device=device, dtype=g1.dtype)
        padded_grad[: g1.shape[0]] = g1

        grad_mask = padded_grad >= grad_threshold
        if not grad_mask.any():
            return

        scale_max = self.get_scaling.max(dim=1).values.detach()
        s_sel = scale_max[grad_mask]
        # Keep scales at or below the q-quantile of the gradient-selected scales.
        s_th = self._safe_quantile(s_sel.float(), clone_scale_quantile, default=-float("inf"))
        clone_mask = grad_mask & (scale_max <= s_th)

        if not clone_mask.any():
            return

        new_xyz = self._xyz[clone_mask]
        new_features_dc = self._features_dc[clone_mask]
        new_features_rest = self._features_rest[clone_mask]
        new_opacities = self._opacity[clone_mask]
        new_scaling = self._scaling[clone_mask]
        new_rotation = self._rotation[clone_mask]

        if self.cfg.pred_normal:
            new_normal = self._normal[clone_mask]
        else:
            new_normal = None

        self.densification_postfix(
            new_xyz, new_features_dc, new_features_rest,
            new_opacities, new_scaling, new_rotation, new_normal
        )
    # def densify_and_split(self, grads, grad_threshold, N=2,split_scale_quantile: float = 0.90,
    #                   fallback_relax: bool = True):

    #     n_init_points = self._xyz.shape[0]
    #     # Extract points that satisfy the gradient condition
    #     padded_grad = torch.zeros((n_init_points), device="cuda")
    #     padded_grad[: grads.shape[0]] = grads.squeeze()
    #     # --- 调试代码 Start ---
    #     # 计算满足梯度条件的点的最大 scaling
    #     max_scale_values = torch.max(self.get_scaling, dim=1).values
    #     mask_grad_only = torch.where(padded_grad >= grad_threshold, True, False)

    #     if mask_grad_only.any():
    #         current_max_scale = max_scale_values[mask_grad_only].max().item()
    #         current_avg_scale = max_scale_values[mask_grad_only].mean().item()
    #         threshold_val = 0.1 * self.cameras_extent
            
    #         print(f"[Debug] Threshold (0.1 * extent): {threshold_val}")
    #         print(f"[Debug] Max Scale of grad points: {current_max_scale}")
    #         print(f"[Debug] Avg Scale of grad points: {current_avg_scale}")
            
    #         if current_max_scale < threshold_val:
    #             print("[Debug] >>> 所有点都太小了，无法触发 Split！")
    #     # --- 调试代码 End ---
    #     selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
        
    #     # 统计满足梯度条件的点数比例
    #     ratio_over = selected_pts_mask.float().mean().item()
        
    #     selected_pts_mask = torch.logical_and(
    #         selected_pts_mask,
    #         torch.max(self.get_scaling, dim=1).values > 0.1*self.cameras_extent
    #     )
        
    #     # 打印分裂点数
    #     num_split_selected = selected_pts_mask.sum().item()
    #     threestudio.info(f"[Densify-Split] ratio_over_thresh={ratio_over:.4%}, num_split={num_split_selected}")

    #     # divide N to enhance robustness
    #     stds = self.get_scaling[selected_pts_mask].repeat(N, 1)
    #     means = torch.zeros((stds.size(0), 3), device="cuda")
    #     samples = torch.normal(mean=means, std=stds)
    #     rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1)
    #     new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self._xyz[
    #         selected_pts_mask
    #     ].repeat(N, 1)
    #     new_scaling = self.scaling_inverse_activation(
    #         self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8 * N)
    #     )
    #     new_rotation = self._rotation[selected_pts_mask].repeat(N, 1)
    #     new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1)
    #     new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1)
    #     new_opacity = self._opacity[selected_pts_mask].repeat(N, 1)
    #     if self.cfg.pred_normal:
    #         new_normal = self._normal[selected_pts_mask].repeat(N, 1)
    #     else:
    #         new_normal = None

    #     self.densification_postfix(
    #         new_xyz,
    #         new_features_dc,
    #         new_features_rest,
    #         new_opacity,
    #         new_scaling,
    #         new_rotation,
    #         new_normal,
    #     )

    #     prune_filter = torch.cat(
    #         (
    #             selected_pts_mask,
    #             torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool),
    #         )
    #     )
    #     self.prune_points(prune_filter)

    # def densify_and_clone(self, grads, grad_threshold):
    #     # Extract points that satisfy the gradient condition
    #     selected_pts_mask = torch.where(
    #         torch.norm(grads, dim=-1) >= grad_threshold, True, False
    #     )
        
    #     # 统计满足梯度条件的点数比例
    #     ratio_over = selected_pts_mask.float().mean().item()
        
    #     selected_pts_mask = torch.logical_and(
    #         selected_pts_mask,
    #         torch.max(self.get_scaling, dim=1).values <= 0.1*self.cameras_extent
    #     )
        
    #     # 打印克隆点数
    #     num_clone_selected = selected_pts_mask.sum().item()
    #     threestudio.info(f"[Densify-Clone] ratio_over_thresh={ratio_over:.4%}, num_clone={num_clone_selected}")

    #     new_xyz = self._xyz[selected_pts_mask]
    #     new_features_dc = self._features_dc[selected_pts_mask]
    #     new_features_rest = self._features_rest[selected_pts_mask]
    #     new_opacities = self._opacity[selected_pts_mask]
    #     new_scaling = self._scaling[selected_pts_mask]
    #     new_rotation = self._rotation[selected_pts_mask]
    #     if self.cfg.pred_normal:
    #         new_normal = self._normal[selected_pts_mask]
    #     else:
    #         new_normal = None

    #     self.densification_postfix(
    #         new_xyz,
    #         new_features_dc,
    #         new_features_rest,
    #         new_opacities,
    #         new_scaling,
    #         new_rotation,
    #         new_normal,
    #     )

    def densify(self, max_grad):
        """Run one densification round (clone, then split) from accumulated gradients.

        The per-point average screen-space gradient is
        ``xyz_gradient_accum / denom``. Points that were never observed
        (``denom == 0`` with zero accumulation) yield ``0 / 0 = NaN``; those
        entries are zeroed so they never pass the threshold.

        Args:
            max_grad: gradient threshold forwarded to both
                ``densify_and_clone`` and ``densify_and_split``.
        """
        grads = self.xyz_gradient_accum / self.denom
        # Unobserved points produce NaN; treat them as having zero gradient.
        grads[grads.isnan()] = 0.0

        # Clone runs before split; both receive the same ``grads`` tensor.
        self.densify_and_clone(grads, max_grad)
        self.densify_and_split(grads, max_grad)

    def prune(self, min_opacity, extent, max_screen_size):
        """Delete low-opacity points and shrink (rather than delete) oversized ones.

        Steps:
        1. Delete points whose opacity is below ``min_opacity``. If
           ``cfg.prune_big_points`` is set, also delete points whose
           ``max_radii2D`` exceeds three times the mean radius.
        2. If ``max_screen_size`` is truthy and points remain, halve the
           activated scale of every point that is too large in screen space
           (``max_radii2D > max_screen_size``) or in world space (largest
           scale axis > ``0.1 * extent``). The halved scale is mapped back
           through ``scaling_inverse_activation`` and swapped into the
           optimizer via ``replace_tensor_to_optimizer``.

        Args:
            min_opacity: opacity below which points are deleted.
            extent: scene extent used for the world-space size test.
            max_screen_size: screen-space radius threshold; falsy disables
                step 2.
        """
        prune_mask = (self.get_opacity < min_opacity).squeeze()
        if self.cfg.prune_big_points:
            big_points_vs = self.max_radii2D > (torch.mean(self.max_radii2D) * 3)
            prune_mask = torch.logical_or(prune_mask, big_points_vs)
        self.prune_points(prune_mask)

        # Oversized points are only shrunk here, not deleted.
        if max_screen_size and self.get_xyz.shape[0] > 0:
            big_points_vs = self.max_radii2D > max_screen_size
            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
            big_mask = torch.logical_or(big_points_vs, big_points_ws)

            if big_mask.any():
                # Halve in activated space, then map back to parameter space.
                current_scaling = self.get_scaling
                new_scaling = current_scaling.clone()
                new_scaling[big_mask] = current_scaling[big_mask] / 2.0

                new_scaling_param = self.scaling_inverse_activation(new_scaling)
                optimizable_tensors = self.replace_tensor_to_optimizer(
                    new_scaling_param, "scaling"
                )
                self._scaling = optimizable_tensors["scaling"]
        torch.cuda.empty_cache()
    

    def random_prune(self, min_opacity, extent, max_screen_size, max_points: int):
        """Enforce a hard cap on the number of points in two passes.

        1. Cleanup: delete points whose opacity is below ``min_opacity``. When
           ``max_screen_size`` is truthy, also delete points that are too large
           in screen space (``max_radii2D > max_screen_size``) or in world
           space (largest scale axis > ``0.1 * extent``). Oversized points are
           deleted here rather than shrunk.
        2. Truncation: if more than ``max_points`` remain, delete a uniformly
           random subset so that exactly ``max_points`` are left.

        Args:
            min_opacity: opacity below which points are deleted.
            extent: scene extent used for the world-space size test.
            max_screen_size: screen-space radius threshold; falsy disables the
                size tests.
            max_points: maximum number of points to keep.
        """
        device = self._xyz.device

        # Pass 1: cleanup of low-opacity and oversized points.
        doomed = (self.get_opacity < min_opacity).squeeze()
        if max_screen_size:
            too_big_on_screen = self.max_radii2D > max_screen_size
            too_big_in_world = self.get_scaling.max(dim=1).values > 0.1 * extent
            doomed = torch.logical_or(
                doomed, torch.logical_or(too_big_on_screen, too_big_in_world)
            )

        if doomed.any():
            self.prune_points(doomed)
            torch.cuda.empty_cache()

        # Pass 2: random truncation down to max_points.
        remaining = self.get_xyz.shape[0]
        excess = remaining - max_points
        if excess > 0:
            # Point order carries no meaning, so drop the first `excess` entries
            # of a random permutation.
            order = torch.randperm(remaining, device=device)
            drop = torch.zeros(remaining, dtype=torch.bool, device=device)
            drop[order[:excess]] = True
            self.prune_points(drop)
            print(f"[Random Prune] Hard cap triggered. Randomly removed {excess} points.")

        torch.cuda.empty_cache()

    def add_densification_stats(self, viewspace_point_tensor, update_filter):
        """Accumulate screen-space gradient magnitudes for the selected points.

        For each point picked by ``update_filter``, the L2 norm of the x/y
        components of ``viewspace_point_tensor.grad`` is added to
        ``xyz_gradient_accum``, and ``denom`` is incremented by one. Their
        ratio is the average gradient that ``densify`` later thresholds.

        Args:
            viewspace_point_tensor: tensor whose ``.grad`` holds per-point
                screen-space gradients (only the first two columns are used).
            update_filter: mask or index selecting the points to update.
        """
        screen_grad_xy = viewspace_point_tensor.grad[update_filter, :2]
        grad_norm = screen_grad_xy.norm(dim=-1, keepdim=True)

        self.xyz_gradient_accum[update_filter] += grad_norm
        self.denom[update_filter] += 1
    # def add_densification_stats(self, viewspace_point_tensor, update_filter):
    #     # 1. 取 screen-space 梯度（x,y）
    #     g = torch.norm(
    #         viewspace_point_tensor.grad[update_filter, :2],
    #         dim=-1,
    #         keepdim=True,
    #     )
    #     # ===== DEBUG: 打印原始梯度分布（低频）=====
    #     # 只在前 200 step，每 50 step 打印一次，防止刷屏
    #     if (
    #         hasattr(self, "global_step")
    #         and self.global_step < 600
    #         and self.global_step % 50 == 0
    #         and g.numel() > 0
    #     ):
    #         with torch.no_grad():
    #             g_flat = g.detach().flatten()
    #             print(
    #                 f"[DENSIFY-DEBUG][step {self.global_step}] "
    #                 f"raw_grad | "
    #                 f"mean={g_flat.mean():.4e}, "
    #                 f"max={g_flat.max():.4e}, "
    #                 f"p90={torch.quantile(g_flat, 0.9):.4e}, "
    #                 f"p99={torch.quantile(g_flat, 0.99):.4e}"
    #             )

    #     # 2. 梯度裁剪（防止 SDS + 大 guidance 爆炸）
    #     # 经验安全值：0.03 ~ 0.1，质量优先建议 0.05
    #     g = torch.clamp(g, max=0.05)

    #     # 3. 归一化（让 densify 看“相对重要性”，而不是绝对尺度）
    #     g = g / (g.mean() + 1e-6)

    #     # 4. 累计统计量
    #     self.xyz_gradient_accum[update_filter] += g
    #     self.denom[update_filter] += 1

    # @torch.no_grad()
    # def update_states(
    #     self,
    #     iteration,
    #     visibility_filter,
    #     radii,
    #     viewspace_point_tensor,
    # ):
    #     if self._xyz.shape[0] >= self.cfg.max_num + 100:
    #         prune_mask = torch.randperm(self._xyz.shape[0]).to(self._xyz.device)
    #         prune_mask = prune_mask > self.cfg.max_num
    #         self.prune_points(prune_mask)
    #         return
    #     # Keep track of max radii in image-space for pruning
    #     # loop over batch
    #     bs = len(viewspace_point_tensor)
    #     for i in range(bs):
    #         radii_i = radii[i]
    #         visibility_filter_i = visibility_filter[i]
    #         viewspace_point_tensor_i = viewspace_point_tensor[i]
    #         self.max_radii2D = torch.max(self.max_radii2D, radii_i.float())

    #         self.add_densification_stats(viewspace_point_tensor_i, visibility_filter_i)

    #     if (
    #         iteration > self.cfg.prune_from_iter
    #         and iteration < self.cfg.prune_until_iter
    #         and iteration % self.cfg.prune_interval == 0
    #     ):
    #         self.prune(self.cfg.min_opac_prune, self.cfg.radii2d_thresh)
    #         if iteration % self.cfg.opacity_reset_interval == 0:
    #             self.reset_opacity()

    #     if (
    #         iteration > self.cfg.densify_from_iter
    #         and iteration < self.cfg.densify_until_iter
    #         and iteration % self.cfg.densification_interval == 0
    #     ):
    #         self.densify(self.cfg.densify_grad_threshold)
    @torch.no_grad()
    def update_states(
        self,
        iteration: int,
        visibility_filter,
        radii,
        viewspace_point_tensor,
    ):
        """Update per-point statistics, then run scheduled prune/densify and the hard cap.

        Order of operations (the point set must not change before step 1 ends):
        1) For every view in the batch, fold ``radii`` into ``max_radii2D`` for
           the visible points and accumulate densification gradient stats.
           The SH degree is raised at iterations 1000, 2000 and 3000.
        2) Every ``prune_interval`` iterations strictly inside
           (``prune_from_iter``, ``prune_until_iter``): call ``prune``, and
           call ``reset_opacity`` on multiples of ``opacity_reset_interval``.
        3) Every ``densification_interval`` iterations strictly inside
           (``densify_from_iter``, ``densify_until_iter``): call ``densify``,
           unless the point count already exceeds ``int(0.98 * max_num)``.
        4) If the count exceeds ``int(1.05 * max_num)``, call ``random_prune``
           to cut back to ``max_num``.

        Args:
            iteration: current training step.
            visibility_filter: per-view masks (or indices) selecting visible points.
            radii: per-view screen-space radii, indexed like ``max_radii2D``.
            viewspace_point_tensor: per-view tensors whose ``.grad`` holds
                screen-space positional gradients.
        """
        # Reset the stats counter attribute (kept for any external readers).
        self._densify_stats_count = 0

        max_num = int(self.cfg.max_num)
        # Only hard-cap above this soft limit, to avoid prune/densify oscillation.
        cap_soft = int(max_num * 1.05)
        # Densify is skipped once the count exceeds this. It is <= max_num <= cap_soft,
        # so it also covers the "over max_num" and "over soft cap" cases.
        densify_cap = int(max_num * 0.98)

        # -------------------------
        # 1) Update stats (no topology change before this)
        # -------------------------
        for i, vpts_i in enumerate(viewspace_point_tensor):
            radii_i = radii[i]
            visibility_filter_i = visibility_filter[i]

            # Assumes radii_i is indexed like the current point set.
            self.max_radii2D[visibility_filter_i] = torch.max(
                self.max_radii2D[visibility_filter_i],
                radii_i.float()[visibility_filter_i],
            )
            self.add_densification_stats(vpts_i, visibility_filter_i)

        if iteration in (1000, 2000, 3000):
            self.oneupSHdegree()

        # Screen-space size test only after densification has started.
        size_threshold = 20 if iteration > self.cfg.densify_from_iter else None
        extent = getattr(self, "cameras_extent", 4.0)

        # -------------------------
        # 2) Scheduled prune (topology may change from here on)
        # -------------------------
        do_prune = (
            self.cfg.prune_from_iter < iteration < self.cfg.prune_until_iter
            and iteration % self.cfg.prune_interval == 0
        )
        if do_prune:
            self.prune(self.cfg.min_opac_prune, extent, size_threshold)
            if iteration % self.cfg.opacity_reset_interval == 0:
                self.reset_opacity()

        # -------------------------
        # 3) Scheduled densify (skipped when at/near capacity)
        # -------------------------
        do_densify = (
            int(self._xyz.shape[0]) <= densify_cap
            and self.cfg.densify_from_iter < iteration < self.cfg.densify_until_iter
            and iteration % self.cfg.densification_interval == 0
        )
        if do_densify:
            self.densify(self.cfg.densify_grad_threshold)

        # -------------------------
        # 4) Final hard cap on the number of points
        # -------------------------
        if int(self._xyz.shape[0]) > cap_soft:
            self.random_prune(
                min_opacity=self.cfg.min_opac_prune,
                extent=extent,
                max_screen_size=size_threshold,
                max_points=max_num,
            )
            # Topology changed: the incoming visibility/radii tensors are now stale.

        torch.cuda.empty_cache()

    def _safe_quantile(self, x: torch.Tensor, q: float, default: float = 0.0) -> float:
        """x: 1D float tensor on cuda"""
        if x.numel() == 0:
            return default
        q = float(min(max(q, 0.0), 1.0))
        return torch.quantile(x, q).item()