# 12/18 修改了sh_degree的更新方式，使用安全方法
# 12/19 修改了loss的计算方式，使用分开backward的方式

import os
import math
import torch
import numpy as np
import torch.nn.functional as F
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field

from gaussiansplatting.gaussian_renderer import render
from gaussiansplatting.scene import Scene, GaussianModel
from gaussiansplatting.arguments import ModelParams, PipelineParams, get_combined_args,OptimizationParams
from gaussiansplatting.scene.cameras import Camera
from gaussiansplatting.utils.sh_utils import SH2RGB
from gaussiansplatting.scene.gaussian_model import BasicPointCloud

import threestudio

from threestudio.systems.base import BaseLift3DSystem
from threestudio.utils.ops import binary_cross_entropy
from threestudio.utils.typing import *
from threestudio.utils.loss import tv_loss


@threestudio.register("gaussiandreamer-system-1219")
class GaussianDreamer(BaseLift3DSystem):
    @dataclass
    class Config(BaseLift3DSystem.Config):
        """System configuration (extends ``BaseLift3DSystem.Config``)."""

        # Copied to ``self.radius`` in configure(); not read elsewhere in the
        # code visible in this file.
        radius: float = 4
        # Maximum spherical-harmonics degree passed to GaussianModel.
        sh_degree: int = 0
        # 1 -> run DreamBooth fine-tuning in on_fit_start(); other values skip it.
        dreambooth: int = 0
        # Point-cloud source used by pcd_init():
        # 0: vggt, 1: pcd, 2: smpl, 3: 3dgs, 4: shap_e.
        load_type: int = 0
        # File handed to the point-cloud loader for load_type 1-4 (the vggt
        # loader receives the whole config instead).
        load_path: str = "./load/shapes/stand.obj"
        # "sds" or "vsd"; selects the loss branch in training_step().
        loss_type: str = "sds"
        # NOTE(review): configure() and test_step() hard-code a black background
        # and never read this field — confirm whether that is intended.
        back_ground_color: Tuple[float, float, float] = (1, 1, 1)


    # ================== Initialization ==========================
    cfg: Config
    def configure(self):
        """Cache config values on the system and build the Gaussian model.

        Also allocates the render background (always black, on CUDA) and turns
        off Lightning's automatic optimization, because the training losses
        run their own backward passes and optimizer steps.
        """
        cfg = self.cfg
        self.loss_type = cfg.loss_type
        self.dreambooth = cfg.dreambooth
        self.radius = cfg.radius
        self.sh_degree = cfg.sh_degree
        self.load_type = cfg.load_type
        self.load_path = cfg.load_path

        self.gaussian = GaussianModel(sh_degree=self.sh_degree)

        # NOTE: ``cfg.back_ground_color`` is intentionally not consulted here;
        # the background is fixed to black.
        black = [0, 0, 0]
        self.background_tensor = torch.tensor(black, dtype=torch.float32, device="cuda")

        self.automatic_optimization = False

    def pcd_init(self):
        """Load the initial point cloud and normalise it for 3DGS initialisation.

        The loader is selected by ``self.load_type``:
        0: vggt, 1: pcd, 2: smpl, 3: 3dgs, 4: shap_e.

        The points are translated so that their centroid (mean) sits at the
        origin. They are then scaled uniformly so that the longest axis of
        their bounding box has length 2. Centring uses the mean rather than
        the bbox midpoint, so the longest axis is not guaranteed to lie
        exactly in [-1, 1].

        Returns:
            BasicPointCloud with the normalised points, the loader's colours
            and all-zero normals.

        Raises:
            NotImplementedError: if ``load_type`` is not one of 0-4.
            ValueError: if the loader returned no points.
        """
        if self.load_type == 0:  # vggt
            from threestudio.systems.function.point_cloud import load_from_vggt
            save_path = self.get_save_path('instance_images/')
            coords, rgb = load_from_vggt(self.cfg, save_path)
        elif self.load_type == 1:  # pcd
            from threestudio.systems.function.point_cloud import load_from_pcd
            coords, rgb = load_from_pcd(self.load_path)
        elif self.load_type == 2:  # smpl
            from threestudio.systems.function.point_cloud import load_from_smpl
            coords, rgb = load_from_smpl(self.load_path)
        elif self.load_type == 3:  # 3dgs
            from threestudio.systems.function.point_cloud import load_from_3dgs
            coords, rgb = load_from_3dgs(self.load_path)
        elif self.load_type == 4:  # shap_e
            from threestudio.systems.function.point_cloud import load_from_shape
            coords, rgb = load_from_shape(self.load_path)
        else:
            raise NotImplementedError(
                f"load_type {self.load_type} is not implemented, only support "
                "[0: vggt, 1: pcd, 2: smpl, 3: 3dgs, 4: shap_e]"
            )

        coords = np.asarray(coords)
        if coords.shape[0] == 0:
            raise ValueError(f"load_type {self.load_type} produced an empty point cloud")
        # The in-place normalisation below needs a floating dtype.
        if not np.issubdtype(coords.dtype, np.floating):
            coords = coords.astype(np.float32)

        # Centre on the centroid.
        coords -= np.mean(coords, axis=0)
        # Per-axis extent (max - min) of the x, y, z columns.
        xyz = coords[:, :3]
        extents = xyz.max(axis=0) - xyz.min(axis=0)
        # Scale so the longest axis has length 2. Skip this for degenerate
        # clouds (all points identical).
        half_longest = extents.max() / 2
        if half_longest > 0:
            coords /= half_longest

        return BasicPointCloud(
            points=coords, colors=rgb, normals=np.zeros((coords.shape[0], 3))
        )
    
    def on_fit_start(self):
        """Optionally personalise the diffusion model, then build the prompt processor and guidance.

        When ``cfg.dreambooth == 1``:
          * DreamBooth fine-tuning runs as a subprocess (``_run_dreambooth``).
          * Only if it succeeds are ``cfg.prompt_processor`` / ``cfg.guidance``
            (and, for VSD, the LoRA path) pointed at the fine-tuned weights in
            ``personalization/``. On failure a warning is printed and the
            original pretrained model is kept, rather than loading from a
            missing or partial directory.
          * After the guidance is built, the personalization directory is
            deleted to save disk space.
        """
        super().on_fit_start()

        use_dreambooth = self.cfg.dreambooth == 1
        if use_dreambooth and self._run_dreambooth():
            # Redirect the models to the fine-tuned weights (training only).
            personalization_dir = self.get_save_path('personalization/')
            self.cfg.prompt_processor.pretrained_model_name_or_path = personalization_dir
            self.cfg.guidance.pretrained_model_name_or_path = personalization_dir
            if self.loss_type == "vsd":
                self.cfg.guidance.pretrained_model_name_or_path_lora = personalization_dir
        elif use_dreambooth:
            print("[WARN] DreamBooth fine-tuning failed; falling back to the base pretrained model.")

        self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)(
            self.cfg.prompt_processor
        )
        self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)

        # Assumes the guidance / prompt processor constructors above have already
        # loaded the fine-tuned weights into memory, so the on-disk copy can go.
        if use_dreambooth:
            self._remove_personalization_dir()

        # NOTE: for loss_type == "vsd", the guidance's trainable (LoRA) parameters
        # are not added to the Gaussian optimizer. A disabled attempt looked like:
        #   optimizer = self.gaussian.optimizer
        #   optimizer.add_param_group({
        #       'params': [p for p in self.guidance.submodules.pipe_lora.unet.parameters() if p.requires_grad],
        #       'lr': 1e-4,
        #       'weight_decay': 1e-2,
        #   })

    def _run_dreambooth(self):
        """Run ``dreambooth_full.py`` in a subprocess; return True if it exited cleanly.

        ``sys.executable`` is used so the script runs under the same interpreter
        and environment as training, not whatever ``python`` is first on PATH.
        Output is not captured; it streams straight to the console.
        """
        import subprocess
        import sys

        cmd = [
            sys.executable, "threestudio/systems/function/dreambooth_full.py",
            "--pretrained_model_name_or_path", self.cfg.guidance.pretrained_model_name_or_path,
            "--enable_xformers_memory_efficient_attention",
            "--with_prior_preservation",
            "--instance_data_dir", self.get_save_path('instance_images/'),
            "--instance_prompt", self.cfg.prompt_processor.dreambooth_prompt,
            "--class_data_dir", self.get_save_path('class_samples/'),
            "--class_prompt", self.cfg.prompt_processor.prompt.replace('<kth>', ''),
            "--num_class_images", "100",
            "--validation_prompt", self.cfg.prompt_processor.prompt,
            "--output_dir", self.get_save_path('personalization/'),
            "--max_train_steps", str(400),
            "--train_batch_size", str(4),
            "--gradient_accumulation_steps", str(1),
            # "--mixed_precision", "fp16"
        ]
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            print("DreamBooth训练失败")
            print(f"返回码: {e.returncode}")
            return False
        except OSError as e:
            # The interpreter or script could not be launched at all.
            print("DreamBooth训练失败")
            print(f"错误输出: {e}")
            return False
        print("DreamBooth训练成功完成")
        return True

    def _remove_personalization_dir(self):
        """Best-effort deletion of the DreamBooth ``personalization/`` output directory."""
        import shutil

        personalization_dir = self.get_save_path("personalization/")
        if not os.path.isdir(personalization_dir):
            return
        try:
            shutil.rmtree(personalization_dir)
        except OSError as e:
            print(f"[WARN] Failed to remove personalization dir {personalization_dir}: {e}")
        else:
            print(f"[INFO] Removed DreamBooth personalization dir: {personalization_dir}")

    def configure_optimizers(self):
        """Create the initial Gaussians and return the optimizer for Lightning.

        In order:
          1. build the Gaussians from ``pcd_init()`` (camera extent 4.0);
          2. try to start at SH degree 1;
          3. dump ``init_3dgs.ply`` plus an RGB copy ``init-color.ply``;
          4. create the render pipeline params and render 64 preview views;
          5. call ``training_setup`` so that ``self.gaussian.optimizer`` exists.
        """
        self.parser = ArgumentParser(description="Training script parameters")
        optim_params = OptimizationParams(self.parser)

        initial_pcd = self.pcd_init()
        self.cameras_extent = 4.0
        self.gaussian.create_from_pcd(initial_pcd, self.cameras_extent)

        # SH degree starts at 1; training_step raises it over the course of the run.
        try:
            self.gaussian.set_sh_degree(1)
        except ValueError as err:
            print(f"[WARNING] Failed to initialize active_sh_degree to 1: {err}")
        else:
            print("[INFO] Initialized active_sh_degree to 1")

        self._dump_initial_gaussians()

        # Pipeline params must exist before initial_views() renders anything.
        self.pipe = PipelineParams(self.parser)
        self.initial_views(num_views=64)

        self.gaussian.training_setup(optim_params)
        return {"optimizer": self.gaussian.optimizer}

    def _dump_initial_gaussians(self):
        """Save the freshly initialised Gaussians and an RGB point-cloud copy for inspection."""
        from threestudio.systems.function.point_cloud import save_ply, load_from_3dgs

        ply_path = self.get_save_path("init_3dgs.ply")
        self.gaussian.save_ply(ply_path)
        xyz, colors = load_from_3dgs(ply_path)
        save_ply(self.get_save_path("init-color.ply"), xyz, colors)
    
    # ================== 迭代相关 ==========================
    def forward(self, batch: Dict[str, Any], renderbackground=None) -> Dict[str, Any]:
        """Render every camera in ``batch`` with 3DGS and stack the results.

        Per-view ``viewspace_points`` / ``visibility_filter`` / ``radii`` are kept in
        ``self.viewspace_point_list`` / ``self.vis_list`` / ``self.radii_list`` for the
        densification statistics. The element-wise max radius over all views is
        stored in ``self.radii``, and ``self.visibility_filter = self.radii > 0``.

        Args:
            batch: must provide ``c2w_3dgs`` and ``fovy`` (indexed per view), plus
                ``height`` and ``width``.
            renderbackground: background colour tensor; defaults to
                ``self.background_tensor``.

        Returns:
            A copy of the LAST view's render package, extended with:
              * ``comp_rgb``: stacked renders, ``[B, H, W, 3]`` (assuming render
                returns ``[3, H, W]``);
              * ``comp_depth``: stacked depths, ``[B, H, W, 1]``;
              * ``opacity``: depth divided by its max (a depth-derived proxy,
                not alpha).

        Raises:
            ValueError: if the batch contains no cameras.
        """
        if renderbackground is None:
            renderbackground = self.background_tensor

        num_views = batch["c2w_3dgs"].shape[0]
        if num_views == 0:
            raise ValueError("forward() received a batch with no cameras")

        images, depths = [], []
        self.viewspace_point_list = []
        self.vis_list = []
        self.radii_list = []
        radii_max = None

        for i in range(num_views):
            viewpoint_cam = Camera(
                c2w=batch["c2w_3dgs"][i],
                FoVy=batch["fovy"][i],
                height=batch["height"],
                width=batch["width"],
            )
            render_pkg = render(viewpoint_cam, self.gaussian, self.pipe, renderbackground)

            # Gradients w.r.t. the viewspace points drive densification.
            vsp = render_pkg["viewspace_points"]
            if vsp.requires_grad:
                vsp.retain_grad()
            radii = render_pkg["radii"]

            # Per-view stats, consumed by collect_densify_stats_only().
            self.viewspace_point_list.append(vsp)
            self.vis_list.append(render_pkg["visibility_filter"])
            self.radii_list.append(radii)
            radii_max = radii if radii_max is None else torch.max(radii_max, radii)

            # Depth may come back as [H, W] or [1, H, W]; normalise it to
            # [1, H, W] so the HWC permute below works for both.
            depth = render_pkg["depth_3dgs"]
            if depth.dim() == 2:
                depth = depth.unsqueeze(0)

            images.append(render_pkg["render"].permute(1, 2, 0))  # -> [H, W, 3]
            depths.append(depth.permute(1, 2, 0))  # -> [H, W, 1]

        images = torch.stack(images, 0)
        depths = torch.stack(depths, 0)

        self.radii = radii_max
        self.visibility_filter = radii_max > 0.0

        # The last view's render package is used as the base of the output.
        out = dict(render_pkg)
        out["comp_rgb"] = images
        out["comp_depth"] = depths
        out["opacity"] = depths / (depths.max() + 1e-5)
        return out

    
    def training_step(self, batch, batch_idx):
        """Run one manually-optimized training step.

        1. Raise the active SH degree on a schedule: start at 1 and add 1 after
           every third of ``max_steps``, capped at ``gaussian.max_sh_degree``.
        2. Update the Gaussian learning rates. After step 500, narrow the
           guidance timestep range to [0.02, 0.55].
        3. Render the batch, query the guidance and apply the selected loss.

        ``automatic_optimization`` is False, so Lightning ignores the returned
        loss for optimization. Each loss branch must therefore run its own
        backward pass and optimizer step.

        Raises:
            ValueError: if ``loss_type`` is neither "sds" nor "vsd".
        """
        # Lightning uses max_steps == -1 for "no limit". Fall back to 1200 then,
        # and also when the attribute is missing or zero.
        max_steps = getattr(self.trainer, "max_steps", None)
        if not max_steps or max_steps <= 0:
            max_steps = 1200
        step_third = max_steps / 3.0

        # Active SH degree is only ever increased, never decreased.
        target_sh_degree = min(
            1 + int(self.true_global_step / step_third), self.gaussian.max_sh_degree
        )
        if target_sh_degree > self.gaussian.active_sh_degree:
            try:
                self.gaussian.set_sh_degree(target_sh_degree)
                print(f"[INFO] Step {self.true_global_step}: Updated active_sh_degree to {target_sh_degree}")
            except ValueError as e:
                print(f"[WARNING] Step {self.true_global_step}: Failed to update active_sh_degree: {e}")

        self.gaussian.update_learning_rate(self.true_global_step)
        if self.true_global_step > 500:
            self.guidance.set_min_max_steps(min_step_percent=0.02, max_step_percent=0.55)

        out = self(batch)

        prompt_utils = self.prompt_processor()
        images = out["comp_rgb"]
        # Request guidance evaluation outputs every 100 steps.
        guidance_eval = self.true_global_step % 100 == 0
        guidance_out = self.guidance(
            images, prompt_utils, **batch, rgb_as_latents=False, guidance_eval=guidance_eval
        )

        if self.loss_type == "sds":
            # compute_total_loss runs backward, the optimizer step and densification.
            loss = self.compute_total_loss(guidance_out, out)
            if guidance_eval:
                num_points = self.gaussian.get_xyz.shape[0]
                print(f"\n[INFO] Total Gaussian Points: {num_points}")
                self.guidance_evaluation_save(
                    out["comp_rgb"].detach()[: guidance_out["eval"]["bs"]],
                    guidance_out["eval"],
                )
        elif self.loss_type == "vsd":
            loss = self.compute_vsd_loss(guidance_out, out)
            # compute_vsd_loss only builds the loss. Under manual optimization
            # nothing else back-propagates or steps it, so do that here,
            # mirroring the SDS path: backward -> densify stats -> step ->
            # densify/prune. Unlike SDS, these densify stats also include the
            # sparsity/opacity regularizer gradients.
            opt = self.optimizers()
            if torch.is_tensor(loss) and loss.requires_grad:
                loss.backward()
                self.collect_densify_stats_only()
            opt.step()
            opt.zero_grad(set_to_none=True)
            self.apply_densify_prune_post_step()

            if guidance_eval and "sample_image" in guidance_out:
                imgs = guidance_out["sample_image"].permute(0, 2, 3, 1).contiguous()
                for i in range(imgs.shape[0]):
                    self.save_image_grid(
                        f"it{self.true_global_step}-{i}.png",
                        [{"type": "rgb", "img": imgs[i], "kwargs": {"data_format": "HWC"}}],
                        name="guidance_sample",
                        step=self.true_global_step,
                    )
        else:
            raise ValueError(
                f"Unknown loss_type {self.loss_type!r}; expected 'sds' or 'vsd'"
            )

        for name, value in self.cfg.loss.items():
            self.log(f"train_params/{name}", self.C(value))
        return {"loss": loss}

    # ================== 验证测试相关 ==========================
    def validation_step(self, batch, batch_idx):
        """Render one validation view and save it as an image grid.

        The grid contains the ground-truth ``rgb`` (when the batch has one),
        the render, and ``comp_normal`` (when the output has one).
        """
        out = self(batch)

        panels = []
        if "rgb" in batch:
            panels.append(
                {"type": "rgb", "img": batch["rgb"][0], "kwargs": {"data_format": "HWC"}}
            )
        panels.append(
            {"type": "rgb", "img": out["comp_rgb"][0], "kwargs": {"data_format": "HWC"}}
        )
        if "comp_normal" in out:
            panels.append(
                {
                    "type": "rgb",
                    "img": out["comp_normal"][0],
                    "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                }
            )

        self.save_image_grid(
            f"it{self.true_global_step}-{batch['index'][0]}.png",
            panels,
            name="validation_step",
            step=self.true_global_step,
        )

    def initial_views(
        self,
        num_views: int = None,
        elevation_deg: float = 15.0,
        camera_distance: float = 4.0,
        fovy_deg: float = 70.0,
        height: int = 512,
        width: int = 512,
    ):
        """Render a ring of preview views of the current Gaussians.

        Poses follow the construction the original comments attribute to
        ``RandomCameraDataset``: ``pose_spherical`` followed by the 3DGS axis
        flip, with the azimuth offset by ``180 - load_type * 90`` degrees.
        Frames are saved as ``init_views/NNN.png`` and assembled into
        ``init_views.mp4``.

        Rendering runs under ``torch.no_grad()``. These previews are never
        back-propagated, so building autograd graphs for them only wastes memory.

        Args:
            num_views: number of evenly spaced azimuths. Defaults to
                ``cfg.n_test_views`` if that attribute exists, else 8.
            elevation_deg: elevation shared by all views (degrees).
            camera_distance: camera distance shared by all views.
            fovy_deg: vertical field of view (degrees).
            height: render height in pixels.
            width: render width in pixels.
        """
        from threestudio.data.uncond import pose_spherical

        device = torch.device("cuda")
        if num_views is None:
            num_views = getattr(self.cfg, "n_test_views", 8)

        azimuth_deg = np.linspace(0.0, 360.0, num_views, endpoint=False)
        fovy = torch.tensor(float(fovy_deg) * math.pi / 180.0, device=device, dtype=torch.float32)
        bottom_row = torch.tensor([[0, 0, 0, 1]], dtype=torch.float32, device=device)
        phi = float(-elevation_deg)
        radius = float(camera_distance)
        bg = self.background_tensor

        with torch.no_grad():
            for i, azimuth in enumerate(azimuth_deg):
                # Pose construction kept identical to RandomCameraDataset.
                theta = float(azimuth + 180.0 - self.load_type * 90.0)
                render_pose = torch.tensor(
                    pose_spherical(theta, phi, radius), dtype=torch.float32, device=device
                )
                w2c = torch.linalg.inv(render_pose)
                R = -torch.transpose(w2c[:3, :3], 0, 1)
                R[:, 0] = -R[:, 0]
                T = -w2c[:3, 3]
                c2w = torch.cat([torch.cat([R, T[:, None]], dim=1), bottom_row], dim=0)

                cam = Camera(c2w=c2w, FoVy=fovy, height=height, width=width)
                img = render(cam, self.gaussian, self.pipe, bg)["render"]  # [3, H, W]

                self.save_image_grid(
                    f"init_views/{i:03d}.png",
                    [
                        {
                            "type": "rgb",
                            "img": img.permute(1, 2, 0),
                            "kwargs": {"data_format": "HWC"},
                        }
                    ],
                    name="init_test",
                    step=i,
                )

        # Assemble the saved frames into an MP4.
        self.save_img_sequence(
            "init_views.mp4",
            "init_views",
            r"(\d+)\.png",
            save_format="mp4",
            fps=30,
            name="init_views",
            step=0,
        )
    
    def on_validation_epoch_end(self):
        """Intentionally a no-op: nothing is aggregated at the end of validation."""

    def test_step(self, batch, batch_idx):
        """Render a test view on a black background and save it as an image grid.

        The grid always contains the ground-truth ``rgb`` (if in the batch), the
        render, and ``comp_normal`` (if in the output). ``only_rgb`` is fixed to
        True. Setting it to False would also append a ``depth`` panel (if present)
        and the ``opacity`` panel.
        """
        only_rgb = True
        background = torch.tensor([0, 0, 0], dtype=torch.float32, device="cuda")

        out = self(batch, background)

        panels = []
        if "rgb" in batch:
            panels.append(
                {"type": "rgb", "img": batch["rgb"][0], "kwargs": {"data_format": "HWC"}}
            )
        panels.append(
            {"type": "rgb", "img": out["comp_rgb"][0], "kwargs": {"data_format": "HWC"}}
        )
        if "comp_normal" in out:
            panels.append(
                {
                    "type": "rgb",
                    "img": out["comp_normal"][0],
                    "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                }
            )
        if not only_rgb:
            if "depth" in out:
                panels.append({"type": "grayscale", "img": out["depth"][0], "kwargs": {}})
            panels.append(
                {
                    "type": "grayscale",
                    "img": out["opacity"][0, :, :, 0],
                    "kwargs": {"cmap": None, "data_range": (0, 1)},
                }
            )

        self.save_image_grid(
            f"it{self.true_global_step}-test/{batch['index'][0]}.png",
            panels,
            name="test_step",
            step=self.true_global_step,
        )

    def on_test_epoch_end(self):
        """Assemble the test renders into a video and export the final Gaussians.

        Writes an mp4 from ``it{step}-test/*.png`` and prints the Gaussian count.
        It then saves ``last_3dgs.ply`` and an RGB-converted point cloud,
        ``it{step}-test-color.ply``.
        """
        step = self.true_global_step
        self.save_img_sequence(
            f"it{step}-test",
            f"it{step}-test",
            r"(\d+)\.png",  # raw string: "\d" is an invalid escape in a normal literal
            save_format="mp4",
            fps=30,
            name="test",
            step=step,
        )

        num_points = self.gaussian.get_xyz.shape[0]
        print(f"\n[INFO] Total Gaussian Points: {num_points}")

        save_path = self.get_save_path("last_3dgs.ply")
        self.gaussian.save_ply(save_path)

        # Also export the point cloud with its colours converted to RGB.
        from threestudio.systems.function.point_cloud import save_ply, load_from_3dgs
        coords, rgb = load_from_3dgs(save_path)
        save_ply(self.get_save_path(f"it{step}-test-color.ply"), coords, rgb)
    

    def compute_total_loss(self, guidance_out, out):
        """
        Run the full SDS update: SDS backward, densify-stat collection,
        regularization backward, optimizer step, then densify/prune.

        The SDS loss is back-propagated on its own first, so the viewspace-point
        gradients read by ``collect_densify_stats_only`` reflect only the SDS
        term. The regularization loss is back-propagated afterwards and only
        contributes to the parameter gradients used by ``opt.step()``.

        Args:
            guidance_out: guidance output dict; must contain ``"loss_sds"``.
            out: output of ``forward`` (consumed by the regularization terms).

        Returns:
            The weighted SDS loss, for logging. Backward and the optimizer step
            have already been run.
        """
        opt = self.optimizers()

        # ========= Stage 1: SDS =========
        loss_sds = guidance_out["loss_sds"] * self.C(self.cfg.loss["lambda_sds"])
        self.log("train/loss_sds", guidance_out["loss_sds"])
        # retain_graph: the regularization terms (e.g. TV on out["comp_rgb"]) may
        # back-propagate through the same render graph again below.
        loss_sds.backward(retain_graph=True)

        # Only accumulate densify stats here (no densify/prune). This must run
        # after the SDS backward and before the regularization backward.
        self.collect_densify_stats_only()

        # ========= Stage 2: Regularization =========
        loss_reg = self.compute_regularization_loss(out)
        # compute_regularization_loss returns the float 0.0 when no term is enabled.
        if loss_reg is not None and torch.is_tensor(loss_reg) and loss_reg.requires_grad:
            loss_reg.backward()

        # ========= Optimizer step =========
        opt.step()
        opt.zero_grad(set_to_none=True)

        # ========= Post-step: densify/prune =========
        self.apply_densify_prune_post_step()

        return loss_sds


    def compute_regularization_loss(self, out):
        """
        Regularization loss for Gaussian parameters.
        This loss MUST NOT be used for densification statistics.

        A term is enabled when its weight is > 0. Weights are evaluated through
        ``self.C``, so scheduled weights work, and a weight missing from
        ``cfg.loss`` counts as 0. The terms are:
          * ``lambda_position``: mean L2 norm of the Gaussian centres;
          * ``lambda_opacity``: sum of opacities, each weighted by the
            (detached) norm of its scaling vector;
          * ``lambda_scales``: sum of all scaling values;
          * ``lambda_tv_loss``: TV loss on ``out["comp_rgb"]``;
          * ``lambda_depth_tv_loss``: TV loss on ``out["comp_depth"]`` (if present).

        Returns:
            The weighted sum of the enabled terms, or the float ``0.0`` when no
            term is enabled.
        """
        def weight(key):
            # Evaluate through self.C before comparing. A raw scheduled weight
            # (e.g. a list) cannot be compared with 0.0.
            return self.C(self.cfg.loss.get(key, 0.0))

        loss = 0.0

        # ========== 1. position regularization ==========
        lambda_position = weight("lambda_position")
        if lambda_position > 0.0:
            loss_position = self.gaussian.get_xyz.norm(dim=-1).mean()
            self.log("train/loss_position", loss_position)
            loss = loss + lambda_position * loss_position

        # ========== 2. opacity regularization ==========
        lambda_opacity = weight("lambda_opacity")
        if lambda_opacity > 0.0:
            scaling = self.gaussian.get_scaling.norm(dim=-1)
            loss_opacity = (
                scaling.detach().unsqueeze(-1) * self.gaussian.get_opacity
            ).sum()
            self.log("train/loss_opacity", loss_opacity)
            loss = loss + lambda_opacity * loss_opacity

        # ========== 3. scale regularization ==========
        lambda_scales = weight("lambda_scales")
        if lambda_scales > 0.0:
            scale_sum = torch.sum(self.gaussian.get_scaling)
            self.log("train/scales", scale_sum)
            loss = loss + lambda_scales * scale_sum

        # ========== 4. RGB TV loss ==========
        lambda_tv = weight("lambda_tv_loss")
        if lambda_tv > 0.0:
            loss_tv = tv_loss(out["comp_rgb"].permute(0, 3, 1, 2))
            self.log("train/loss_tv", loss_tv)
            loss = loss + lambda_tv * loss_tv

        # ========== 5. depth TV loss ==========
        lambda_depth_tv = weight("lambda_depth_tv_loss")
        if "comp_depth" in out and lambda_depth_tv > 0.0:
            loss_depth_tv = tv_loss(out["comp_depth"].permute(0, 3, 1, 2))
            self.log("train/loss_depth_tv", loss_depth_tv)
            loss = loss + lambda_depth_tv * loss_depth_tv

        self.log("train/loss_reg_total", loss)
        return loss


    def compute_vsd_loss(self, guidance_out, out):
        """Sum the weighted guidance losses plus sparsity/opaque regularizers (VSD path).

        Each ``loss_<name>`` entry of ``guidance_out`` is weighted by
        ``cfg.loss["lambda_<name>"]``. Scalar entries (Python/NumPy numbers or
        single-element tensors) are logged as ``train/<key>``. Other values,
        such as images or nested dicts, are skipped because ``self.log`` only
        accepts scalars.

        The sparsity and opaque terms act on ``out["opacity"]``, which
        ``forward`` derives from normalised depth.

        This only computes the loss; it runs no backward pass or optimizer step.
        """
        import numbers

        loss = 0.0
        for name, value in guidance_out.items():
            is_scalar = isinstance(value, numbers.Number) or (
                isinstance(value, torch.Tensor) and value.numel() == 1
            )
            if is_scalar:
                self.log(f"train/{name}", value)
            if name.startswith("loss_"):
                # Only the prefix is renamed: loss_<x> -> lambda_<x>.
                lambda_key = "lambda_" + name[len("loss_"):]
                loss = loss + value * self.C(self.cfg.loss[lambda_key])

        loss_sparsity = (out["opacity"] ** 2 + 0.01).sqrt().mean()
        self.log("train/loss_sparsity", loss_sparsity)
        loss = loss + loss_sparsity * self.C(self.cfg.loss.lambda_sparsity)

        opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3)
        loss_opaque = binary_cross_entropy(opacity_clamped, opacity_clamped)
        self.log("train/loss_opaque", loss_opaque)
        loss = loss + loss_opaque * self.C(self.cfg.loss.lambda_opaque)
        return loss

    @torch.no_grad()
    def collect_densify_stats_only(self):
        """Accumulate per-view densification statistics from the last forward().

        A view whose viewspace points received a gradient does two things:
        it folds that view's radii into ``gaussian.max_radii2D`` and passes
        the gradient to ``gaussian.add_densification_stats_grad``. Views
        without a gradient (e.g. before any backward) are skipped entirely.
        No densification or pruning happens here.
        """
        for view_idx, viewspace_pts in enumerate(self.viewspace_point_list):
            if viewspace_pts.grad is not None:
                view_radii = self.radii_list[view_idx].float()
                self.gaussian.max_radii2D = torch.max(self.gaussian.max_radii2D, view_radii)
                self.gaussian.add_densification_stats_grad(viewspace_pts, self.vis_list[view_idx])


    @torch.no_grad()
    def apply_densify_prune_post_step(self):
        """Densify and prune the Gaussians on a fixed schedule (called after opt.step()).

        Runs every 100 steps while ``0 <= step < 0.75 * max_steps``. It calls
        ``gaussian.densify`` with a fixed gradient threshold of 2e-4, then
        ``gaussian.random_prune(0.05, extent, size_threshold, max_points)``.
        ``size_threshold`` is 20 after step 500 and None before, and
        ``max_points`` is 800_000.
        """
        # Lightning uses max_steps == -1 for "no limit", which would make
        # densify_until_iter 0 and silently disable densification. Fall back to
        # 1200 then, and also when the attribute is missing or zero.
        max_steps = getattr(self.trainer, "max_steps", None)
        if not max_steps or max_steps <= 0:
            max_steps = 1200

        start_densify_iter = 0
        densify_until_iter = int(max_steps * 0.75)  # e.g. stop growing at step 900
        densify_interval = 100
        max_points = 800_000

        if (
            self.true_global_step >= start_densify_iter
            and self.true_global_step % densify_interval == 0
            and self.true_global_step < densify_until_iter
        ):
            # Fixed gradient threshold (2e-4) for the whole run.
            grad_threshold = 0.0002

            # Later in training, pass a size threshold to limit splitting of
            # large Gaussians.
            size_threshold = 20 if self.true_global_step > 500 else None

            print(f"[INFO] Step {self.true_global_step}: Densifying... (Threshold: {grad_threshold})")

            self.gaussian.densify(grad_threshold, self.cameras_extent)
            self.gaussian.random_prune(0.05, self.cameras_extent, size_threshold, max_points)
            