import os
import math
from dataclasses import dataclass

import numpy as np
import threestudio
import torch
from threestudio.systems.base import BaseLift3DSystem
from threestudio.systems.utils import parse_optimizer, parse_scheduler
from threestudio.utils.loss import tv_loss
from threestudio.utils.ops import get_cam_info_gaussian
from threestudio.utils.typing import *
from torch.cuda.amp import autocast
from scipy.spatial import cKDTree
from threestudio.models.geometry.gaussian_base import BasicPointCloud, Camera


@threestudio.register("gaussian-splatting-system")
class GaussianSplatting(BaseLift3DSystem):
    """Text-to-3D system that optimises 3D Gaussians with diffusion guidance.

    The Gaussians are initialised from a point cloud (``pcd_init``), the diffusion
    model is optionally personalised with DreamBooth (``on_fit_start``), and the
    parameters are then trained with manual optimisation (``training_step``).
    """

    @dataclass
    class Config(BaseLift3DSystem.Config):
        visualize_samples: bool = False
        radius: float = 4
        sh_degree: int = 0
        # 1 runs DreamBooth fine-tuning in on_fit_start; any other value skips it.
        dreambooth: int = 0
        # Initial point-cloud source: 0 vggt, 1 pcd, 2 smpl, 3 3dgs, 4 shap_e.
        load_type: int = 0
        # Input file read by load_type 1 (pcd), 2 (smpl) and 3 (3dgs).
        load_path: str = "./load/shapes/stand.obj"
        # "vsd" additionally redirects the guidance LoRA weights after DreamBooth.
        loss_type: str = "sds"
        # Passed to geometry.create_from_pcd both at initialisation and when a
        # checkpoint is restored, so the two code paths stay consistent.
        spatial_lr_scale: float = 4.0
        # Clouds with more points than this are voxel-downsampled in pcd_init.
        max_init_points: int = 100000

    cfg: Config

    def configure(self) -> None:
        """Set up geometry/material/background/renderer, guidance and prompts."""
        super().configure()
        # Optimizer steps are issued manually in training_step.
        self.automatic_optimization = False
        self.dreambooth = self.cfg.dreambooth
        self.radius = self.cfg.radius
        self.sh_degree = self.cfg.sh_degree
        self.load_type = self.cfg.load_type
        self.loss_type = self.cfg.loss_type
        self.load_path = self.cfg.load_path
        self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)
        self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)(
            self.cfg.prompt_processor
        )
        self.prompt_utils = self.prompt_processor()

    # ------------------------------------------------------------------
    # Point-cloud initialisation
    # ------------------------------------------------------------------
    def pcd_init(self) -> BasicPointCloud:
        """Load the initial point cloud and convert it to a 3DGS ``BasicPointCloud``.

        Steps:
          1. load raw points/colours according to ``cfg.load_type``;
          2. centre on the centroid and scale so the longest bounding-box axis
             spans 2 units;
          3. if there are more than ``cfg.max_init_points`` points, voxel-downsample
             and drop isolated outliers.

        Returns:
            BasicPointCloud with zero normals.

        Raises:
            NotImplementedError: for an unknown ``load_type``.
            ValueError: if the loaded point cloud is empty.
        """
        coords, rgb = self._load_raw_points()
        coords = self._normalize_points(coords)

        target_num = self.cfg.max_init_points
        if coords.shape[0] > target_num:
            coords, rgb = self._downsample_points(coords, rgb, target_num)

        r = np.linalg.norm(coords, axis=1)
        threestudio.info(f"[PCD FINAL] n={coords.shape[0]} | r_min={r.min():.3f}, r_mean={r.mean():.3f}, r_max={r.max():.3f}")

        normals = np.zeros((coords.shape[0], 3), dtype=np.float32)
        return BasicPointCloud(points=coords, colors=rgb, normals=normals)

    def _load_raw_points(self):
        """Return raw ``(coords, rgb)`` arrays for the configured ``load_type``.

        This dataset has no COLMAP reconstruction, so points come from one of the
        loaders in ``threestudio.systems.function.point_cloud``.
        """
        if self.load_type == 4:  # shap_e
            from threestudio.systems.function.point_cloud import load_from_shape
            from threestudio.systems.function.point_cloud import load_from_vggt
            save_path = self.get_save_path('instance_images/')
            # NOTE(review): the VGGT result is discarded and replaced by the
            # Shap-E points below. The call is kept, presumably for its side effects
            # on `instance_images/` (used later as DreamBooth instance data);
            # confirm before removing it.
            load_from_vggt(self.cfg, save_path)
            coords, rgb = load_from_shape(self.cfg.prompt_processor.prompt)
        elif self.load_type == 1:  # pcd
            from threestudio.systems.function.point_cloud import load_from_pcd
            coords, rgb = load_from_pcd(self.load_path)
            coords = coords[:, [0, 2, 1]]  # [x, z, y] -> [x, y, z]
        elif self.load_type == 2:  # smpl
            from threestudio.systems.function.point_cloud import load_from_smpl
            coords, rgb = load_from_smpl(self.load_path)
        elif self.load_type == 3:  # 3dgs
            from threestudio.systems.function.point_cloud import load_from_3dgs
            coords, rgb = load_from_3dgs(self.load_path)
        elif self.load_type == 0:  # vggt
            from threestudio.systems.function.point_cloud import load_from_vggt
            save_path = self.get_save_path('instance_images/')
            coords, rgb = load_from_vggt(self.cfg, save_path)
            # Swap the y and z axes to fix the coordinate system.
            coords = coords[:, [0, 2, 1]]  # [x, z, y] -> [x, y, z]
        else:
            raise NotImplementedError(f"load_type {self.load_type} is not implemented, only support [0: vggt, 1: pcd, 2: smpl, 3: 3dgs, 4: shap_e]")
        return coords, rgb

    @staticmethod
    def _normalize_points(coords):
        """Centre ``coords`` on their centroid and scale the longest axis to span 2.

        Centring uses the mean rather than the bounding-box centre, so the longest
        axis has length 2 but is not necessarily exactly [-1, 1]. The input array
        is not modified.
        """
        coords = np.asarray(coords)
        if coords.shape[0] == 0:
            raise ValueError("Loaded point cloud is empty; cannot initialise Gaussians")
        # Subtract out of place so the loader's array is never mutated.
        coords = coords - np.mean(coords, axis=0)
        # Per-axis extent (max - min); halve the largest to get the scale factor.
        extent = np.max(coords, axis=0) - np.min(coords, axis=0)
        max_length = np.max(extent) / 2
        if max_length > 0:  # a single point or fully coincident points stay as-is
            coords /= max_length
        return coords

    @staticmethod
    def _downsample_points(coords, rgb, target_num):
        """Voxel-downsample to at most ``target_num`` points, then drop outliers.

        The voxel size comes from the median nearest-neighbour distance (d50),
        ``clip(0.85 * d50, 1e-4, 0.02)``. If there are still too many occupied
        voxels, voxels are randomly sub-sampled instead of growing the voxel size
        (which would enlarge point spacing).
        """
        # k=2 because the first neighbour of every query point is itself.
        tree = cKDTree(coords)
        dists, _ = tree.query(coords, k=2)
        d50 = float(np.median(dists[:, 1]))
        voxel_size = float(np.clip(0.85 * d50, 1e-4, 0.02))
        threestudio.info(f'[VOXEL INIT] nn_median={d50:.4e}, voxel_init={voxel_size:.4e}')

        # Keep one representative point per occupied voxel.
        min_bound = np.min(coords, axis=0)
        q = np.floor((coords - min_bound) / voxel_size).astype(np.int32)
        _, unique_indices = np.unique(q, axis=0, return_index=True)
        if len(unique_indices) > target_num:
            np.random.shuffle(unique_indices)
            unique_indices = unique_indices[:target_num]

        coords = coords[unique_indices]
        rgb = rgb[unique_indices]
        threestudio.info(f"Downsampled to {coords.shape[0]} points with voxel_size={voxel_size:.4e}")

        # Remove isolated points whose NN distance is far above the 99th percentile.
        tree2 = cKDTree(coords)
        d2, _ = tree2.query(coords, k=2)
        nn2 = d2[:, 1]
        threestudio.info(f"[NN AFTER] median={float(np.median(nn2)):.4e}, mean={float(np.mean(nn2)):.4e}, max={float(np.max(nn2)):.4e}")

        p99 = float(np.quantile(nn2, 0.99))
        keep = nn2 < (3.0 * p99)
        kept = int(np.sum(keep))
        if kept < coords.shape[0]:
            coords = coords[keep]
            rgb = rgb[keep]
            threestudio.info(f"[OUTLIER] kept={kept}/{keep.shape[0]} (thr=3*p99={3.0*p99:.4e})")

        # The count is already <= target_num here (voxel truncation above, and
        # outlier removal only shrinks it), so no final cap is needed. A slightly
        # smaller cloud after outlier removal is acceptable.
        return coords, rgb

    # ------------------------------------------------------------------
    # Lightning hooks
    # ------------------------------------------------------------------
    def configure_optimizers(self):
        """Initialise Gaussians from the point cloud and return ``[optimizer]``.

        Also writes ``init_3dgs.ply`` and a coloured ``init-color.ply`` snapshot.
        If ``cfg.optimizer`` has a ``name``, that optimizer is merged into the
        geometry optimizer.
        """
        point_cloud = self.pcd_init()
        # For background-free 3D object generation, 4.0 is a standard value: it
        # keeps Gaussians from growing large enough to destroy detail.
        self.geometry.create_from_pcd(point_cloud, self.cfg.spatial_lr_scale)

        with torch.no_grad():
            s = self.geometry.get_scaling
            threestudio.info(f"[INIT SCALE] min={s.min().item():.4e}, mean={s.mean().item():.4e}, max={s.max().item():.4e}")
        self.geometry.training_setup()

        save_path = self.get_save_path("init_3dgs.ply")
        self.geometry.save_ply(save_path)
        from threestudio.systems.function.point_cloud import save_ply, load_from_3dgs
        coords, rgb = load_from_3dgs(save_path)
        save_ply(self.get_save_path("init-color.ply"), coords, rgb)

        optim = self.geometry.optimizer
        if hasattr(self, "merged_optimizer"):
            # The merge decision was already made on an earlier call.
            return [optim]
        if hasattr(self.cfg.optimizer, "name"):
            net_optim = parse_optimizer(self.cfg.optimizer, self)
            optim = self.geometry.merge_optimizer(net_optim)
            self.merged_optimizer = True
        else:
            self.merged_optimizer = False
        return [optim]

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Update the geometry LR schedule and render ``batch``."""
        self.geometry.update_learning_rate(self.global_step)
        outputs = self.renderer.batch_forward(batch)
        return outputs

    def on_fit_start(self) -> None:
        """Optionally run DreamBooth, then rebuild the prompt processor and guidance.

        With ``cfg.dreambooth == 1`` the diffusion model is fine-tuned on the images
        in ``instance_images/``. On success, the prompt processor and guidance (and,
        for VSD, the LoRA weights) are pointed at ``personalization/``. On failure,
        a warning is logged and the base model is kept.
        """
        super().on_fit_start()
        if self.cfg.dreambooth == 1:
            personalization_dir = self.get_save_path('personalization/')
            if self._run_dreambooth():
                # Only used in training.
                self.cfg.prompt_processor.pretrained_model_name_or_path = personalization_dir
                self.cfg.guidance.pretrained_model_name_or_path = personalization_dir
                if self.loss_type == "vsd":
                    self.cfg.guidance.pretrained_model_name_or_path_lora = personalization_dir
            else:
                # Pointing the models at an empty/partial output dir would only fail
                # later with a confusing load error, so keep the base model.
                threestudio.warn(
                    "DreamBooth fine-tuning failed; continuing with base model "
                    f"{self.cfg.guidance.pretrained_model_name_or_path}"
                )

        self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)(
            self.cfg.prompt_processor
        )
        self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)
        # configure() computed prompt_utils from the previous processor instance.
        # Refresh it so training_step uses the processor built here.
        self.prompt_utils = self.prompt_processor()

        if self.cfg.dreambooth == 1:
            self._remove_personalization_dir()

    def _run_dreambooth(self) -> bool:
        """Run DreamBooth fine-tuning in a child process; return True on success."""
        import subprocess
        import sys

        cmd = [
            # Use the running interpreter: a bare "python" on PATH may resolve to a
            # different environment.
            sys.executable, "threestudio/systems/function/dreambooth_full.py",
            # NOTE(review): the script path is relative to the working directory.
            "--pretrained_model_name_or_path", self.cfg.guidance.pretrained_model_name_or_path,
            "--enable_xformers_memory_efficient_attention",
            "--with_prior_preservation",
            "--instance_data_dir", self.get_save_path('instance_images/'),
            "--instance_prompt", self.cfg.prompt_processor.dreambooth_prompt,
            "--class_data_dir", self.get_save_path('class_samples/'),
            "--class_prompt", self.cfg.prompt_processor.prompt.replace('<kth>', ''),
            "--num_class_images", "100",
            "--validation_prompt", self.cfg.prompt_processor.prompt,
            "--output_dir", self.get_save_path('personalization/'),
            "--max_train_steps", str(400),
            "--train_batch_size", str(4),
            "--gradient_accumulation_steps", str(1),
        ]
        try:
            # Output is not captured, so the child's progress streams to the console.
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            print("DreamBooth训练失败")
            print(f"返回码: {e.returncode}")
            return False
        print("DreamBooth训练成功完成")
        return True

    def _remove_personalization_dir(self) -> None:
        """Delete the DreamBooth output dir to save disk space (best effort).

        Assumes the prompt processor and guidance constructors have already loaded
        the weights into memory; confirm this if they ever load lazily.
        """
        import shutil

        personalization_dir = self.get_save_path("personalization/")
        if os.path.isdir(personalization_dir):
            try:
                shutil.rmtree(personalization_dir)
                print(f"[INFO] Removed DreamBooth personalization dir: {personalization_dir}")
            except OSError as e:
                print(f"[WARN] Failed to remove personalization dir {personalization_dir}: {e}")

    def training_step(self, batch, batch_idx):
        """Run one manual optimisation step.

        The weighted guidance loss is back-propagated first (retaining the graph).
        ``geometry.update_states`` is then called while only guidance-loss gradients
        have been accumulated. The regularisers are back-propagated next, and finally
        the optimizer steps and clears gradients.

        Returns:
            ``{"loss": loss_sds}``, the weighted guidance loss.
        """
        opt = self.optimizers()
        out = self(batch)

        visibility_filter = out["visibility_filter"]
        radii = out["radii"]
        viewspace_point_tensor = out["viewspace_points"]
        guidance_inp = out["comp_rgb"]
        # Request guidance evaluation outputs every 100 steps.
        guidance_eval = self.true_global_step % 100 == 0
        guidance_out = self.guidance(
            guidance_inp, self.prompt_utils, **batch, rgb_as_latents=False, guidance_eval=guidance_eval
        )

        loss_sds = 0.0  # weighted sum of the guidance's loss_* terms
        loss = 0.0  # weighted sum of the regularisers below
        if guidance_eval:
            self.guidance_evaluation_save(
                out["comp_rgb"].detach()[: guidance_out["eval"]["bs"]],
                guidance_out["eval"],
            )
        self.log(
            "gauss_num",
            int(self.geometry.get_xyz.shape[0]),
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        for name, value in guidance_out.items():
            # Skip 'eval': it holds dicts and tensors that self.log cannot record.
            if name == "eval":
                continue
            self.log(f"train/{name}", value)
            if name.startswith("loss_"):
                loss_sds += value * self.C(
                    self.cfg.loss[name.replace("loss_", "lambda_")]
                )

        if self.cfg.loss["lambda_position"] > 0.0:
            # Mean distance of the Gaussian centres from the origin.
            loss_position = self.geometry.get_xyz.norm(dim=-1).mean()
            self.log("train/loss_position", loss_position)
            loss += self.C(self.cfg.loss["lambda_position"]) * loss_position

        if self.cfg.loss["lambda_opacity"] > 0.0:
            # Opacity weighted by (detached) scale norm: penalises large opaque blobs.
            scaling = self.geometry.get_scaling.norm(dim=-1)
            loss_opacity = (
                scaling.detach().unsqueeze(-1) * self.geometry.get_opacity
            ).sum()
            self.log("train/loss_opacity", loss_opacity)
            loss += self.C(self.cfg.loss["lambda_opacity"]) * loss_opacity

        if self.cfg.loss["lambda_scales"] > 0.0:
            scale_sum = torch.sum(self.geometry.get_scaling)
            self.log("train/scales", scale_sum)
            loss += self.C(self.cfg.loss["lambda_scales"]) * scale_sum

        if self.cfg.loss["lambda_tv_loss"] > 0.0:
            loss_tv = self.C(self.cfg.loss["lambda_tv_loss"]) * tv_loss(
                out["comp_rgb"].permute(0, 3, 1, 2)
            )
            self.log("train/loss_tv", loss_tv)
            loss += loss_tv

        if "comp_depth" in out and self.cfg.loss["lambda_depth_tv_loss"] > 0.0:
            alpha = out["comp_alpha"].permute(0, 3, 1, 2)  # [B,1,H,W]
            loss_depth_tv = self.C(self.cfg.loss["lambda_depth_tv_loss"]) * (
                tv_loss(out["comp_normal"].permute(0, 3, 1, 2), weight=alpha) +
                tv_loss(out["comp_depth"].permute(0, 3, 1, 2), weight=alpha)
            )
            self.log("train/loss_depth_tv", loss_depth_tv)
            loss += loss_depth_tv

        for name, value in self.cfg.loss.items():
            self.log(f"train_params/{name}", self.C(value))

        # retain_graph: the regulariser backward below may reuse the render graph.
        loss_sds.backward(retain_graph=True)
        iteration = self.global_step
        self.geometry.update_states(
            iteration,
            visibility_filter,
            radii,
            viewspace_point_tensor,
        )
        # `loss` stays the float 0.0 when no regulariser is enabled.
        if loss > 0:
            loss.backward()
        opt.step()
        opt.zero_grad(set_to_none=True)
        return {"loss": loss_sds}

    def validation_step(self, batch, batch_idx):
        """Save the RGB render (plus the normal visualisation, if rendered)."""
        out = self(batch)
        images = [
            {
                "type": "rgb",
                "img": out["comp_rgb"][0],
                "kwargs": {"data_format": "HWC"},
            },
        ]
        # Guard on the key that is actually read below.
        if "comp_normal_vis" in out:
            images.append(
                {
                    "type": "rgb",
                    "img": out["comp_normal_vis"][0],
                    "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                }
            )
        self.save_image_grid(
            f"it{self.global_step}-{batch['index'][0]}.png",
            images,
            name="validation_step",
            step=self.global_step,
        )

    def on_validation_epoch_end(self):
        pass

    def test_step(self, batch, batch_idx):
        """Save test renders; write ``point_cloud.ply`` when batch index 0 is seen."""
        out = self(batch)
        self.save_image_grid(
            f"it{self.global_step}-test/rgb_{batch['index'][0]}.png",
            [
                {
                    "type": "rgb",
                    "img": out["comp_rgb"][0],
                    "kwargs": {"data_format": "HWC"},
                },
            ]
            + (
                [
                    {
                        "type": "rgb",
                        "img": out["comp_normal"][0],
                        "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                    }
                ]
                if "comp_normal" in out
                else []
            ),
            name="test_step",
            step=self.global_step,
        )
        if batch["index"][0] == 0:
            save_path = self.get_save_path("point_cloud.ply")
            self.geometry.save_ply(save_path)

    def on_test_epoch_end(self):
        """Assemble the per-view test renders into an mp4."""
        self.save_img_sequence(
            f"it{self.global_step}-test",
            f"it{self.global_step}-test",
            # Raw string: "\d" in a normal literal is an invalid escape sequence.
            r"rgb_(\d+)\.png",
            save_format="mp4",
            fps=30,
            name="test",
            step=self.global_step,
        )

    def on_load_checkpoint(self, ckpt_dict) -> None:
        """Resize the Gaussian parameters to the checkpoint's point count.

        A zero point cloud with the checkpoint's number of points is created so
        parameter shapes match the state dict. The actual values are presumably
        restored from ``state_dict`` afterwards. The same ``spatial_lr_scale`` as
        ``configure_optimizers`` is used so resumed training keeps the same schedule.
        """
        num_pts = ckpt_dict["state_dict"]["geometry._xyz"].shape[0]
        pcd = BasicPointCloud(
            points=np.zeros((num_pts, 3)),
            colors=np.zeros((num_pts, 3)),
            normals=np.zeros((num_pts, 3)),
        )
        self.geometry.create_from_pcd(pcd, self.cfg.spatial_lr_scale)
        self.geometry.training_setup()
        super().on_load_checkpoint(ckpt_dict)