# 12/18 修改了sh_degree的更新方式，使用安全方法


import os
import math
import torch
import numpy as np
import torch.nn.functional as F
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field

from gaussiansplatting.gaussian_renderer import render
from gaussiansplatting.scene import Scene, GaussianModel
from gaussiansplatting.arguments import ModelParams, PipelineParams, get_combined_args,OptimizationParams
from gaussiansplatting.scene.cameras import Camera
from gaussiansplatting.utils.sh_utils import SH2RGB
from gaussiansplatting.scene.gaussian_model import BasicPointCloud

import threestudio
# from threestudio.utils.loss import tv_loss
from threestudio.systems.base import BaseLift3DSystem
from threestudio.utils.ops import binary_cross_entropy
from threestudio.utils.typing import *


@threestudio.register("gaussiandreamer-system-1208")
class GaussianDreamer(BaseLift3DSystem):
    """Lift a prompt to 3D by optimizing a 3D Gaussian Splatting model.

    The Gaussians are initialized from a point cloud (``pcd_init``), rendered per
    camera in ``forward``, and optimized with an SDS or VSD loss produced by
    ``self.guidance``. Densification / pruning happens in
    ``on_before_optimizer_step``. Optionally the guidance model is first
    fine-tuned with DreamBooth in a child process (``on_fit_start``).
    """

    @dataclass
    class Config(BaseLift3DSystem.Config):
        # Camera distance used by the periodic orbit video (save_video_loop).
        radius: float = 4
        # Passed to GaussianModel(sh_degree=...).
        sh_degree: int = 0
        # 1 = fine-tune the guidance model with DreamBooth before training.
        dreambooth: int = 0
        # Point-cloud source for pcd_init: 0 vggt, 1 pcd, 2 smpl, 3 3dgs, 4 shap_e.
        load_type: int = 0
        # Input file handed to the selected point-cloud loader.
        load_path: str = "./load/shapes/stand.obj"
        # "sds" or "vsd"; selects the loss computed in training_step.
        loss_type: str = "sds"
        # RGB background used by the orbit renders (initial_views / save_video_loop).
        back_ground_color: Tuple[float, float, float] = (1, 1, 1)

    # ================== Initialization ==========================
    cfg: Config

    def configure(self):
        """Copy config values onto the instance and build the model and background."""
        self.loss_type = self.cfg.loss_type
        self.dreambooth = self.cfg.dreambooth
        self.radius = self.cfg.radius
        self.sh_degree = self.cfg.sh_degree
        self.load_type = self.cfg.load_type
        self.load_path = self.cfg.load_path

        self.gaussian = GaussianModel(sh_degree=self.sh_degree)
        # Background module (type chosen by cfg.background_type); forward()
        # composites it behind the Gaussians using (1 - alpha).
        self.background = threestudio.find(self.cfg.background_type)(self.cfg.background)
        # Constant background color for the orbit renders.
        self.background_tensor = torch.tensor(
            self.cfg.back_ground_color, dtype=torch.float32, device="cuda"
        )

    def pcd_init(self):
        """Load the initial point cloud and normalize it for 3DGS initialization.

        The loader is selected by ``self.load_type``:
        0 = vggt, 1 = pcd, 2 = smpl, 3 = 3dgs, 4 = shap_e.

        The points are translated so their centroid sits at the origin, then
        uniformly scaled so the largest axis-aligned extent (over x/y/z) equals 2.
        Because the centering uses the centroid rather than the bounding-box
        center, coordinates are not guaranteed to lie inside [-1, 1].

        Returns:
            BasicPointCloud with the normalized points, the loader's colors and
            zero normals.

        Raises:
            NotImplementedError: if ``load_type`` is not one of the values above.
            ValueError: if the loader returned no points.
        """
        if self.load_type == 4:  # shap_e
            from threestudio.systems.function.point_cloud import load_from_shape
            coords, rgb = load_from_shape(self.load_path)
        elif self.load_type == 1:  # pcd
            from threestudio.systems.function.point_cloud import load_from_pcd
            coords, rgb = load_from_pcd(self.load_path)
        elif self.load_type == 2:  # smpl
            from threestudio.systems.function.point_cloud import load_from_smpl
            coords, rgb = load_from_smpl(self.load_path)
        elif self.load_type == 3:  # 3dgs
            from threestudio.systems.function.point_cloud import load_from_3dgs
            coords, rgb = load_from_3dgs(self.load_path)
        elif self.load_type == 0:  # vggt
            from threestudio.systems.function.point_cloud import load_from_vggt
            save_path = self.get_save_path("instance_images/")
            coords, rgb = load_from_vggt(self.cfg, save_path)
        else:
            raise NotImplementedError(
                f"load_type {self.load_type} is not implemented, only support "
                "[0: vggt, 1: pcd, 2: smpl, 3: 3dgs, 4: shap_e]"
            )

        coords = np.asarray(coords)
        if coords.shape[0] == 0:
            raise ValueError(f"Point cloud loaded from {self.load_path!r} is empty")
        # In-place centering/scaling below requires a floating dtype.
        if not np.issubdtype(coords.dtype, np.floating):
            coords = coords.astype(np.float32)

        coords -= np.mean(coords, axis=0)
        xyz = coords[:, :3]
        extent = np.max(xyz, axis=0) - np.min(xyz, axis=0)
        half_extent = np.max(extent) / 2
        # A degenerate cloud (all points identical) is left unscaled.
        if half_extent > 0:
            coords /= half_extent
        return BasicPointCloud(
            points=coords, colors=rgb, normals=np.zeros((coords.shape[0], 3))
        )

    def on_fit_start(self):
        """Optionally DreamBooth-fine-tune, then build the prompt processor and guidance.

        When ``cfg.dreambooth == 1`` the fine-tuning script runs in a child
        process. If it succeeds, the prompt processor and guidance load the
        fine-tuned weights and the weight directory is deleted afterwards. If it
        fails, a warning is printed and the configured base model is used.
        """
        super().on_fit_start()

        personalization_dir = None
        if self.cfg.dreambooth == 1:
            personalization_dir = self.get_save_path("personalization/")
            # Only redirect model loading when training actually produced weights;
            # otherwise the (empty) output directory would be loaded.
            if self._run_dreambooth():
                self.cfg.prompt_processor.pretrained_model_name_or_path = personalization_dir
                self.cfg.guidance.pretrained_model_name_or_path = personalization_dir
                if self.loss_type == "vsd":
                    self.cfg.guidance.pretrained_model_name_or_path_lora = personalization_dir
            else:
                print(
                    "[WARN] DreamBooth fine-tuning failed; falling back to the base model "
                    f"{self.cfg.guidance.pretrained_model_name_or_path}"
                )

        self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)(
            self.cfg.prompt_processor
        )
        self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)

        # Assumes the prompt processor and guidance finished reading the
        # fine-tuned weights during construction, so the on-disk copy can be
        # removed to save space (best effort).
        if personalization_dir is not None and os.path.isdir(personalization_dir):
            import shutil

            try:
                shutil.rmtree(personalization_dir)
                print(f"[INFO] Removed DreamBooth personalization dir: {personalization_dir}")
            except OSError as e:
                print(f"[WARN] Failed to remove personalization dir {personalization_dir}: {e}")

    def _run_dreambooth(self) -> bool:
        """Run ``dreambooth_full.py`` in a child process.

        The child uses the current interpreter (``sys.executable``) so it sees the
        same environment; its output is not captured and goes straight to this
        process's console.

        Returns:
            True if the script exited with status 0, False otherwise.
        """
        import subprocess
        import sys

        cmd = [
            sys.executable, "threestudio/systems/function/dreambooth_full.py",
            "--pretrained_model_name_or_path", self.cfg.guidance.pretrained_model_name_or_path,
            "--enable_xformers_memory_efficient_attention",
            "--with_prior_preservation",
            "--instance_data_dir", self.get_save_path("instance_images/"),
            "--instance_prompt", self.cfg.prompt_processor.dreambooth_prompt,
            "--class_data_dir", self.get_save_path("class_samples/"),
            "--class_prompt", self.cfg.prompt_processor.prompt.replace("<kth>", ""),
            "--num_class_images", "100",
            "--validation_prompt", self.cfg.prompt_processor.prompt,
            "--output_dir", self.get_save_path("personalization/"),
            "--max_train_steps", str(400),
            "--train_batch_size", str(4),
            "--gradient_accumulation_steps", str(1),
        ]
        try:
            # check=True raises CalledProcessError on a non-zero exit status.
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            print("DreamBooth训练失败")
            print(f"返回码: {e.returncode}")
            return False
        print("DreamBooth训练成功完成")
        return True

    def configure_optimizers(self):
        """Initialize the Gaussians from the point cloud and return their optimizer.

        Side effects: writes ``init_3dgs.ply`` and ``init-color.ply`` and renders
        an orbit of the freshly initialized model into ``init_views``.
        """
        self.parser = ArgumentParser(description="Training script parameters")
        opt = OptimizationParams(self.parser)

        point_cloud = self.pcd_init()
        self.cameras_extent = 4.0
        self.gaussian.create_from_pcd(point_cloud, self.cameras_extent)

        # Start with SH degree 0; training_step raises it at a later step.
        try:
            self.gaussian.set_sh_degree(0)
            print("[INFO] Initialized active_sh_degree to 0")
        except ValueError as e:
            print(f"[WARNING] Failed to initialize active_sh_degree to 0: {e}")

        # Save the initial 3DGS and a colored point-cloud view of it for inspection.
        init_ply_path = self.get_save_path("init_3dgs.ply")
        self.gaussian.save_ply(init_ply_path)
        self._export_rgb_ply(init_ply_path, "init-color.ply")

        self.pipe = PipelineParams(self.parser)

        # Render an orbit of the initialized 3DGS (a quick test right after init).
        self.initial_views(num_views=64)

        self.gaussian.training_setup(opt)
        return {"optimizer": self.gaussian.optimizer}

    def _export_rgb_ply(self, gaussian_ply_path, color_ply_name):
        """Re-read a saved 3DGS .ply via load_from_3dgs and write it with save_ply.

        Args:
            gaussian_ply_path: path of a .ply written by ``GaussianModel.save_ply``.
            color_ply_name: output file name, resolved with ``get_save_path``.
        """
        from threestudio.systems.function.point_cloud import save_ply, load_from_3dgs

        coords, rgb = load_from_3dgs(gaussian_ply_path)
        save_ply(self.get_save_path(color_ply_name), coords, rgb)

    # ================== Iteration ==========================
    def forward(self, batch: Dict[str, Any], renderbackground=None) -> Dict[str, Any]:
        """Render every camera in ``batch`` and composite the background module.

        Args:
            batch: must provide per-view ``c2w_3dgs``, ``fovy`` and ``rays_d``
                (indexed along the first axis) plus ``height`` and ``width``.
            renderbackground: accepted for API compatibility (``test_step`` passes
                it) but currently ignored. The rasterizer always renders on black
                and ``self.background`` is added using ``1 - alpha``.

        Returns:
            The ``render`` output of the LAST view, extended with:
            ``comp_rgb`` (views stacked channels-last), ``depth`` (stacked
            ``depth_3dgs`` moved channels-last) and ``opacity``. Note that
            ``opacity`` is ``depth`` divided by its max, i.e. normalized depth,
            not the rendered alpha.

        Also stores ``viewspace_point_list``, ``radii`` (element-wise max over
        views) and ``visibility_filter`` for ``on_before_optimizer_step``.
        """
        # Black background with the same dtype/device as background_tensor.
        bg_color = self.background_tensor * 0
        images = []
        depths = []
        self.viewspace_point_list = []
        for view_idx in range(batch["c2w_3dgs"].shape[0]):
            viewpoint_cam = Camera(
                c2w=batch["c2w_3dgs"][view_idx],
                FoVy=batch["fovy"][view_idx],
                height=batch["height"],
                width=batch["width"],
            )
            render_pkg = render(viewpoint_cam, self.gaussian, self.pipe, bg_color)
            image = render_pkg["render"]
            radii = render_pkg["radii"]
            alpha = render_pkg["rendered_alpha"]
            self.viewspace_point_list.append(render_pkg["viewspace_points"])

            # Composite the background module behind the Gaussians. The reshape
            # assumes the background returns one RGB value per pixel ray.
            rays_d = batch["rays_d"][view_idx]
            comp_rgb_bg = self.background(dirs=rays_d.unsqueeze(0))
            _, H, W = image.shape
            image = image + (1 - alpha) * comp_rgb_bg.reshape(H, W, 3).permute(2, 0, 1)

            self.radii = radii if view_idx == 0 else torch.max(radii, self.radii)

            depths.append(render_pkg["depth_3dgs"].permute(1, 2, 0))
            images.append(image.permute(1, 2, 0))

        images = torch.stack(images, 0)
        depths = torch.stack(depths, 0)
        self.visibility_filter = self.radii > 0.0
        render_pkg["comp_rgb"] = images
        render_pkg["depth"] = depths
        render_pkg["opacity"] = depths / (depths.max() + 1e-5)
        return {**render_pkg}

    def training_step(self, batch, batch_idx):
        """Render, query the guidance and return the SDS or VSD loss.

        Also raises ``active_sh_degree`` by one at step 1000, runs guidance
        evaluation every 100 steps, logs the configured loss weights and renders
        an orbit video every 300 steps.

        Raises:
            ValueError: if ``loss_type`` is neither "sds" nor "vsd".
        """
        sh_increase_step = 1000
        eval_interval = 100
        video_interval = 300

        self.gaussian.update_learning_rate(self.true_global_step)

        # Raise the active SH degree once; set_sh_degree validates the value.
        if self.true_global_step == sh_increase_step:
            try:
                current_sh_degree = self.gaussian.active_sh_degree
                new_sh_degree = current_sh_degree + 1
                self.gaussian.set_sh_degree(new_sh_degree)
                print(f"[INFO] Step {self.true_global_step}: Updated active_sh_degree from {current_sh_degree} to {new_sh_degree}")
            except ValueError as e:
                print(f"[WARNING] Step {self.true_global_step}: Failed to update active_sh_degree: {e}")

        out = self(batch)
        prompt_utils = self.prompt_processor()
        images = out["comp_rgb"]
        guidance_eval = self.true_global_step % eval_interval == 0
        guidance_out = self.guidance(
            images, prompt_utils, **batch, rgb_as_latents=False, guidance_eval=guidance_eval
        )

        if self.loss_type == "sds":
            loss = self.compute_sds_loss(guidance_out, out)
            if guidance_eval:
                num_points = self.gaussian.get_xyz.shape[0]
                print(f"\n[INFO] Total Gaussian Points: {num_points}")
                self.guidance_evaluation_save(
                    out["comp_rgb"].detach()[: guidance_out["eval"]["bs"]],
                    guidance_out["eval"],
                )
        elif self.loss_type == "vsd":
            loss = self.compute_vsd_loss(guidance_out, out)
            if guidance_eval and "sample_image" in guidance_out:
                # Assumes sample_image is (N, C, H, W); converted to HWC per image.
                imgs = guidance_out["sample_image"].permute(0, 2, 3, 1).contiguous()
                for i in range(imgs.shape[0]):
                    self.save_image_grid(
                        f"it{self.true_global_step}-{i}.png",
                        [{"type": "rgb", "img": imgs[i], "kwargs": {"data_format": "HWC"}}],
                        name="guidance_sample",
                        step=self.true_global_step,
                    )
        else:
            raise ValueError(
                f"Unsupported loss_type {self.loss_type!r}; expected 'sds' or 'vsd'"
            )

        for name, value in self.cfg.loss.items():
            self.log(f"train_params/{name}", self.C(value))
        if self.true_global_step > 0 and self.true_global_step % video_interval == 0:
            self.save_video_loop(self.true_global_step)
        return {"loss": loss}

    def on_before_optimizer_step(self, optimizer):
        """Accumulate densification statistics and periodically densify / prune.

        Active only while ``true_global_step < 1200``. Every 100 steps it calls
        ``densify`` with a fixed gradient threshold, followed by ``random_prune``
        (given ``max_points=800_000``, presumably a cap on the number of
        Gaussians — confirm in GaussianModel). Relies on the
        ``viewspace_point_list`` / ``radii`` / ``visibility_filter`` stored by
        ``forward`` and on their gradients from the preceding backward pass.
        """
        densify_until_iter = 1200
        start_densify_iter = 0
        densify_interval = 100
        max_points = 800_000
        grad_threshold = 0.01
        # An earlier variant ramped grad_threshold with the step and pruned by an
        # opacity threshold (densify_and_prune) instead of random_prune.

        if self.true_global_step >= densify_until_iter:
            return

        with torch.no_grad():
            # Sum the screen-space gradients from every rendered view.
            viewspace_point_tensor_grad = torch.zeros_like(self.viewspace_point_list[0])
            for viewspace_points in self.viewspace_point_list:
                viewspace_point_tensor_grad = viewspace_point_tensor_grad + viewspace_points.grad

            # Track each visible point's largest screen-space radius.
            self.gaussian.max_radii2D[self.visibility_filter] = torch.max(
                self.gaussian.max_radii2D[self.visibility_filter],
                self.radii[self.visibility_filter],
            )

            # Accumulate the statistics densification needs.
            self.gaussian.add_densification_stats(
                viewspace_point_tensor_grad, self.visibility_filter
            )

            if (
                self.true_global_step >= start_densify_iter
                and self.true_global_step % densify_interval == 0
            ):
                size_threshold = 20 if self.true_global_step > 500 else None
                self.gaussian.densify(grad_threshold, self.cameras_extent)
                self.gaussian.random_prune(0.05, self.cameras_extent, size_threshold, max_points)

    # ================== Validation / test ==========================
    @staticmethod
    def _comparison_panels(batch, out, include_depth_opacity=False):
        """Build the ``save_image_grid`` panel list for the first item of a batch.

        Panels, in order: ground-truth ``rgb`` (if in batch), ``comp_rgb``,
        ``comp_normal`` (if in out), and — when ``include_depth_opacity`` — the
        ``depth`` (if in out) and ``opacity`` grayscale panels.
        """
        panels = []
        if "rgb" in batch:
            panels.append(
                {"type": "rgb", "img": batch["rgb"][0], "kwargs": {"data_format": "HWC"}}
            )
        panels.append(
            {"type": "rgb", "img": out["comp_rgb"][0], "kwargs": {"data_format": "HWC"}}
        )
        if "comp_normal" in out:
            panels.append(
                {
                    "type": "rgb",
                    "img": out["comp_normal"][0],
                    "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                }
            )
        if include_depth_opacity:
            if "depth" in out:
                panels.append({"type": "grayscale", "img": out["depth"][0], "kwargs": {}})
            panels.append(
                {
                    "type": "grayscale",
                    "img": out["opacity"][0, :, :, 0],
                    "kwargs": {"cmap": None, "data_range": (0, 1)},
                }
            )
        return panels

    def validation_step(self, batch, batch_idx):
        """Render the validation view and save it next to the ground truth (if any)."""
        out = self(batch)
        self.save_image_grid(
            f"it{self.true_global_step}-{batch['index'][0]}.png",
            self._comparison_panels(batch, out),
            name="validation_step",
            step=self.true_global_step,
        )

    def _orbit_c2w_3dgs(self, num_views, elevation_deg, camera_distance, device):
        """Build 3DGS camera-to-world matrices on a circular orbit.

        Azimuths are evenly spaced over 360 degrees. The pose construction follows
        RandomCameraDataset according to the original author — TODO confirm it
        stays in sync with ``threestudio.data.uncond_v2``.

        Returns:
            Tensor of shape (num_views, 4, 4) on ``device``.
        """
        from threestudio.data.uncond_v2 import pose_spherical

        azimuth_deg = np.linspace(0.0, 360.0, num_views, endpoint=False)
        phi = float(-elevation_deg)
        radius = float(camera_distance)
        bottom = torch.tensor([[0, 0, 0, 1]], dtype=torch.float32, device=device)

        c2w_list = []
        for azimuth in azimuth_deg:
            # The azimuth offset depends on load_type (camera convention per loader).
            theta = float(azimuth + 180.0 - self.load_type * 90.0)
            render_pose = pose_spherical(theta, phi, radius)
            render_pose_t = torch.tensor(render_pose, dtype=torch.float32, device=device)

            matrix = torch.linalg.inv(render_pose_t)
            R = -torch.transpose(matrix[:3, :3], 0, 1)
            R[:, 0] = -R[:, 0]
            T = -matrix[:3, 3]
            c2w = torch.cat([torch.cat([R, T[:, None]], dim=1), bottom], dim=0)
            c2w_list.append(c2w)
        return torch.stack(c2w_list, dim=0)

    def _render_orbit_frames(self, num_views, camera_distance, img_dir, name, step: Optional[int] = None):
        """Render a 512x512 orbit (elevation 15 deg, fovy 70 deg) into ``img_dir``.

        Frames are saved as ``{img_dir}/{i:03d}.png`` on ``self.background_tensor``.
        Rendering runs without autograd since the frames are only saved.

        Args:
            step: logging step for every frame; when None, the frame index is used.
        """
        device = torch.device("cuda")
        c2w_3dgs = self._orbit_c2w_3dgs(num_views, 15.0, camera_distance, device)
        fovy = torch.full(
            (num_views,), 70.0 * math.pi / 180.0, dtype=torch.float32, device=device
        )
        with torch.no_grad():
            for i in range(num_views):
                cam = Camera(c2w=c2w_3dgs[i], FoVy=fovy[i], height=512, width=512)
                img = render(cam, self.gaussian, self.pipe, self.background_tensor)["render"]
                self.save_image_grid(
                    f"{img_dir}/{i:03d}.png",
                    [{"type": "rgb", "img": img.permute(1, 2, 0), "kwargs": {"data_format": "HWC"}}],
                    name=name,
                    step=i if step is None else step,
                )

    def initial_views(self, num_views: Optional[int] = None):
        """Render an orbit of the current model into ``init_views`` plus an MP4.

        Args:
            num_views: number of frames; defaults to ``cfg.n_test_views`` if that
                attribute exists, otherwise 8.
        """
        if num_views is None:
            num_views = getattr(self.cfg, "n_test_views", 8)

        self._render_orbit_frames(num_views, 4.0, "init_views", name="init_test")
        self.save_img_sequence(
            "init_views.mp4",
            "init_views",
            r"(\d+)\.png",
            save_format="mp4",
            fps=30,
            name="init_views",
            step=0,
        )

    def on_validation_epoch_end(self):
        """No per-epoch validation aggregation."""
        pass

    def test_step(self, batch, batch_idx):
        """Render a test view; depth/opacity panels are added only if only_rgb is False."""
        only_rgb = True
        # Passed for API symmetry; forward() currently ignores it.
        testbackground_tensor = torch.tensor([0, 0, 0], dtype=torch.float32, device="cuda")

        out = self(batch, testbackground_tensor)
        self.save_image_grid(
            f"it{self.true_global_step}-test/{batch['index'][0]}.png",
            self._comparison_panels(batch, out, include_depth_opacity=not only_rgb),
            name="test_step",
            step=self.true_global_step,
        )

    def on_test_epoch_end(self):
        """Assemble the test video and export the final 3DGS and colored point cloud."""
        self.save_img_sequence(
            f"it{self.true_global_step}-test",
            f"it{self.true_global_step}-test",
            r"(\d+)\.png",
            save_format="mp4",
            fps=30,
            name="test",
            step=self.true_global_step,
        )
        num_points = self.gaussian.get_xyz.shape[0]
        print(f"\n[INFO] Total Gaussian Points: {num_points}")
        save_path = self.get_save_path("last_3dgs.ply")
        self.gaussian.save_ply(save_path)
        self._export_rgb_ply(save_path, f"it{self.true_global_step}-test-color.ply")

    def compute_sds_loss(self, guidance_out, out):
        """Combine the SDS loss with Gaussian-level regularizers.

        Each term is weighted with ``self.C`` applied to ``cfg.loss``:
        ``loss_sds`` * lambda_sds, opacity sparsity * lambda_sparsity, opacity
        binary entropy * lambda_opaque and — when ``lambda_scales`` is configured
        and its current value is positive — mean Gaussian scale * lambda_scales.
        """
        loss = 0.0

        # 1. SDS main loss.
        loss_sds = guidance_out["loss_sds"] * self.C(self.cfg.loss["lambda_sds"])
        loss = loss + loss_sds
        self.log("train/loss_sds", guidance_out["loss_sds"])

        # 2. Per-Gaussian sparsity: sqrt(o^2 + 0.01) is a smooth stand-in for |o|.
        gauss_opacity = self.gaussian.get_opacity
        loss_sparsity = (gauss_opacity ** 2 + 0.01).sqrt().mean()
        self.log("train/loss_sparsity_gauss", loss_sparsity)
        loss += loss_sparsity * self.C(self.cfg.loss.lambda_sparsity)

        # Monitoring only, not part of the loss. NOTE: forward() fills
        # out["opacity"] with normalized depth, not the rendered alpha.
        self.log("train/alpha_mean", out["opacity"].mean())

        # 3. Binary entropy of each opacity (BCE of p against itself) pushes
        #    opacities toward 0 or 1; clamping avoids log(0).
        gauss_opacity_clamped = gauss_opacity.clamp(1.0e-3, 1.0 - 1.0e-3)
        loss_opaque = binary_cross_entropy(gauss_opacity_clamped, gauss_opacity_clamped)
        self.log("train/loss_opaque_gauss", loss_opaque)
        loss += loss_opaque * self.C(self.cfg.loss.lambda_opaque)

        # Optional scale penalty; self.C resolves scheduled (list) weights too.
        if "lambda_scales" in self.cfg.loss:
            lambda_scales = self.C(self.cfg.loss["lambda_scales"])
            if lambda_scales > 0.0:
                loss_scales = self.gaussian.get_scaling.mean()
                self.log("train/loss_scales", loss_scales)
                loss += lambda_scales * loss_scales

        # 4. Total.
        self.log("train/loss_total", loss)
        return loss

    def compute_vsd_loss(self, guidance_out, out):
        """Sum every ``loss_*`` entry of ``guidance_out`` plus sparsity/opaque terms.

        Each ``loss_<x>`` is weighted by ``self.C(cfg.loss.lambda_<x>)``. Scalar
        entries (Python numbers or single-element tensors) are logged. The
        sparsity and opaque regularizers use ``out["opacity"]``, which
        ``forward`` fills with normalized depth.
        """
        loss = 0.0
        for name, value in guidance_out.items():
            is_scalar = isinstance(value, (int, float)) or (
                isinstance(value, torch.Tensor) and value.numel() == 1
            )
            if is_scalar:
                self.log(f"train/{name}", value)
            if name.startswith("loss_"):
                loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")])

        loss_sparsity = (out["opacity"] ** 2 + 0.01).sqrt().mean()
        self.log("train/loss_sparsity", loss_sparsity)
        loss += loss_sparsity * self.C(self.cfg.loss.lambda_sparsity)

        opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3)
        loss_opaque = binary_cross_entropy(opacity_clamped, opacity_clamped)
        self.log("train/loss_opaque", loss_opaque)
        loss += loss_opaque * self.C(self.cfg.loss.lambda_opaque)
        return loss

    def save_video_loop(self, step):
        """Render a 64-frame orbit at ``self.radius`` and save it as an MP4.

        Frames go to ``it{step}-video/``; the video is ``it{step}-loop.mp4``.
        """
        num_views = 64
        camera_distance = float(getattr(self, "radius", 4.0))
        video_dir = f"it{step}-video"

        self._render_orbit_frames(
            num_views, camera_distance, video_dir, name=f"video_step_{step}", step=step
        )
        self.save_img_sequence(
            f"it{step}-loop.mp4",
            video_dir,
            r"(\d+)\.png",
            save_format="mp4",
            fps=30,
            name="video_loop",
            step=step,
        )
        print(f"[INFO] Saved loop video for step {step}")