import os
import torch
import numpy as np
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field

from gaussiansplatting.gaussian_renderer import render
from gaussiansplatting.scene import Scene, GaussianModel
from gaussiansplatting.arguments import ModelParams, PipelineParams, get_combined_args,OptimizationParams
from gaussiansplatting.scene.cameras import Camera
from gaussiansplatting.utils.sh_utils import SH2RGB
from gaussiansplatting.scene.gaussian_model import BasicPointCloud

import threestudio
from threestudio.systems.base import BaseLift3DSystem
from threestudio.utils.ops import binary_cross_entropy
from threestudio.utils.typing import *


@threestudio.register("gaussiandreamer-vsd-system")
class GaussianDreamer(BaseLift3DSystem):
    @dataclass
    class Config(BaseLift3DSystem.Config):
        # Scale factor multiplied onto the loaded point coordinates in pcd_init.
        radius: float = 4
        # Spherical-harmonics degree passed to GaussianModel in configure.
        sh_degree: int = 0
        # When 1, dreambooth_finetine runs the DreamBooth script before the
        # prompt processor and guidance are instantiated.
        dreambooth: int = 0
        # Selects the point-cloud loader in pcd_init:
        # 0 = vggt, 1 = pcd, 2 = smpl, 3 = 3dgs, 4 = shap_e.
        load_type: int = 0
        # Source path handed to the loader for load_type 1-4.
        load_path: str = "./load/shapes/stand.obj"
        # Selects the branch taken in training_step ("sds" or "vsd").
        loss_type: str = "vsd"


    # ================== Initialization ==========================
    cfg: Config
    def configure(self) -> None:
        """Mirror config values onto the instance and create the Gaussian model.

        Automatic optimization is switched off; training_step calls the
        optimizers itself.
        """
        cfg = self.cfg
        self.automatic_optimization = False
        self.loss_type = cfg.loss_type
        self.dreambooth = cfg.dreambooth
        self.radius = cfg.radius
        self.sh_degree = cfg.sh_degree
        self.load_type = cfg.load_type
        self.load_path = cfg.load_path

        self.gaussian = GaussianModel(sh_degree=self.sh_degree)
        # Black render background used by forward when none is supplied.
        self.background_tensor = torch.tensor(
            [0, 0, 0], dtype=torch.float32, device="cuda"
        )

    def pcd_init(self) -> BasicPointCloud:
        """Load the initial point cloud and wrap it as a ``BasicPointCloud``.

        The loader is chosen by ``self.load_type``:
        0 = vggt, 1 = pcd, 2 = smpl, 3 = 3dgs, 4 = shap_e.
        Loaded coordinates are multiplied by ``self.radius``; normals are
        filled with zeros.

        Returns:
            The point cloud built from the loaded coordinates and colors.

        Raises:
            NotImplementedError: if ``self.load_type`` is not one of 0-4.
        """
        # load_type -> name of the loader in threestudio.systems.function.point_cloud
        # that takes the source path as its only argument.
        path_loaders = {
            1: "load_from_pcd",
            2: "load_from_smpl",
            3: "load_from_3dgs",
            4: "load_from_shape",  # shap_e
        }

        if self.load_type == 0:  # vggt
            from threestudio.systems.function.point_cloud import load_from_vggt
            save_path = self.get_save_path('instance_images/')
            coords, rgb = load_from_vggt(self.cfg, save_path)
        elif self.load_type in path_loaders:
            # Imported lazily, as before, so the loaders are only pulled in
            # when a point cloud is actually built.
            from threestudio.systems.function import point_cloud
            loader = getattr(point_cloud, path_loaders[self.load_type])
            coords, rgb = loader(self.load_path)
        else:
            raise NotImplementedError(
                f"load_type {self.load_type} is not implemented, only support "
                "[0: vggt, 1: pcd, 2: smpl, 3: 3dgs, 4: shap_e]"
            )

        return BasicPointCloud(
            points=coords * self.radius,
            colors=rgb,
            normals=np.zeros((coords.shape[0], 3)),
        )
    
    def dreambooth_finetine(self):
        """Optionally run DreamBooth, then build the prompt processor and guidance.

        When ``cfg.dreambooth == 1`` the DreamBooth script is executed as a
        subprocess and the prompt-processor / guidance configs are redirected
        to its output directory. The prompt processor and guidance are
        instantiated afterwards in either case.

        Raises:
            RuntimeError: if the DreamBooth subprocess exits with a non-zero
                status. Continuing would point the model paths at an
                incomplete output directory.
        """
        if self.cfg.dreambooth == 1:
            import subprocess

            personalization_dir = self.get_save_path('personalization/')
            cmd = [
                "python", "threestudio/systems/function/dreambooth.py",
                "--pretrained_model_name_or_path", self.cfg.guidance.pretrained_model_name_or_path,
                "--enable_xformers_memory_efficient_attention",
                "--with_prior_preservation",
                "--instance_data_dir", self.get_save_path('instance_images/'),
                "--instance_prompt", self.cfg.prompt_processor.dreambooth_prompt,
                "--class_data_dir", self.get_save_path('class_samples/'),
                "--class_prompt", self.cfg.prompt_processor.prompt.replace('<kth>', ''),
                "--num_class_images", "100",
                "--validation_prompt", self.cfg.prompt_processor.prompt,
                "--output_dir", personalization_dir,
                "--max_train_steps", str(250),
                "--train_batch_size", str(8),
                "--gradient_accumulation_steps", str(1),
                # "--mixed_precision", "fp16"
            ]
            # The child's stdout/stderr are not captured, so its output goes
            # straight to the console; check=True raises on a non-zero exit.
            try:
                subprocess.run(cmd, check=True)
            except subprocess.CalledProcessError as e:
                print("DreamBooth训练失败")
                print(f"返回码: {e.returncode}")
                raise RuntimeError(
                    f"DreamBooth fine-tuning failed with exit code {e.returncode}"
                ) from e
            print("DreamBooth训练成功完成")

            # Only used in training: load the fine-tuned weights from here on.
            self.cfg.prompt_processor.pretrained_model_name_or_path = personalization_dir
            self.cfg.guidance.pretrained_model_name_or_path = personalization_dir
            if self.loss_type == "vsd":
                self.cfg.guidance.pretrained_model_name_or_path_lora = personalization_dir

        self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)(
            self.cfg.prompt_processor
        )
        self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)
    
    def on_fit_start(self) -> None:
        """Run the base-class fit-start hook; no additional setup happens here."""
        super().on_fit_start()

        # Optimizer (disabled): an earlier approach that appended the LoRA UNet
        # parameters to the 3DGS optimizer as an extra param group.
        # if self.loss_type == "vsd":
        #     optimizer = self.gaussian.optimizer
        #     new_param_group = {
        #         'params': [p for p in self.guidance.submodules.pipe_lora.unet.parameters() if p.requires_grad],
        #         'lr': 1e-4,
        #         'weight_decay': 1e-2
        #     }
        #     optimizer.add_param_group(new_param_group)

    def configure_optimizers(self):
        """Initialise the Gaussians and return the 3DGS and LoRA optimizers.

        Steps, in order:
          1. build the Gaussian model from the point cloud given by pcd_init;
          2. save the initial Gaussians plus an RGB point-cloud copy for inspection;
          3. run dreambooth_finetine, which creates the prompt processor and guidance;
          4. set up the Gaussian optimizer and a separate AdamW for the LoRA
             parameters exposed by the guidance.

        Returns:
            ``[optimizer_3dgs, optimizer_lora]`` (training_step unpacks them
            in this order).
        """
        self.parser = ArgumentParser(description="Training script parameters")
        opt = OptimizationParams(self.parser)

        point_cloud = self.pcd_init()
        self.cameras_extent = 4.0
        self.gaussian.create_from_pcd(point_cloud, self.cameras_extent)

        # Dump the freshly converted Gaussians so the initialisation can be checked.
        init_path = self.get_save_path("init_3dgs.ply")
        self.gaussian.save_ply(init_path)
        # Also save a copy converted back to an RGB point cloud.
        from threestudio.systems.function.point_cloud import save_ply, load_from_3dgs
        coords, rgb = load_from_3dgs(init_path)
        save_ply(self.get_save_path("init-color.ply"), coords, rgb)

        self.dreambooth_finetine()

        # Prepare training. training_setup is called exactly once; the original
        # called it twice and discarded the first result.
        self.pipe = PipelineParams(self.parser)
        self.gaussian.training_setup(opt)
        optimizer_3dgs = self.gaussian.optimizer

        # LoRA gets its own AdamW; the learning rate falls back to 1e-5 when
        # the optimization params do not define ``lora_lr``.
        optimizer_lora = torch.optim.AdamW(
            self.guidance.get_lora_parameters(),
            lr=getattr(opt, "lora_lr", 1e-5),
            betas=(0.9, 0.999),
            weight_decay=0.0,
        )

        return [optimizer_3dgs, optimizer_lora]
    
    
    # ================== Training iteration ==========================
    def forward(self, batch: Dict[str, Any], renderbackground=None) -> Dict[str, Any]:
        """Render every camera in ``batch`` and collect the results.

        Side effects: refreshes ``self.viewspace_point_list`` (one entry per
        view), ``self.radii`` (element-wise max over the views) and
        ``self.visibility_filter`` (``self.radii > 0``).

        Args:
            batch: reads ``c2w_3dgs``, ``fovy``, ``height`` and ``width``.
            renderbackground: background passed to the renderer; defaults to
                ``self.background_tensor``.

        Returns:
            A copy of the last view's render package with ``comp_rgb``,
            ``depth`` and ``opacity`` added. ``opacity`` is the stacked depth
            divided by its global maximum (plus a small epsilon).
        """
        background = self.background_tensor if renderbackground is None else renderbackground

        rgb_frames = []
        depth_frames = []
        self.viewspace_point_list = []
        num_views = batch['c2w_3dgs'].shape[0]
        for view_idx in range(num_views):
            camera = Camera(
                c2w=batch['c2w_3dgs'][view_idx],
                FoVy=batch['fovy'][view_idx],
                height=batch['height'],
                width=batch['width'],
            )
            render_pkg = render(camera, self.gaussian, self.pipe, background)
            self.viewspace_point_list.append(render_pkg["viewspace_points"])

            # Track the largest radius seen for each Gaussian across views.
            view_radii = render_pkg["radii"]
            self.radii = view_radii if view_idx == 0 else torch.max(view_radii, self.radii)

            # Move the leading axis last (presumably CHW -> HWC; confirm with the renderer).
            rgb_frames.append(render_pkg["render"].permute(1, 2, 0))
            depth_frames.append(render_pkg["depth_3dgs"].permute(1, 2, 0))

        stacked_rgb = torch.stack(rgb_frames, 0)
        stacked_depth = torch.stack(depth_frames, 0)
        self.visibility_filter = self.radii > 0.0

        render_pkg["comp_rgb"] = stacked_rgb
        render_pkg["depth"] = stacked_depth
        render_pkg["opacity"] = stacked_depth / (stacked_depth.max() + 1e-5)
        return dict(render_pkg)
    
    def training_step(self, batch, batch_idx):
        """Run one manually optimized training iteration.

        Optimization is manual (``automatic_optimization`` is False), so this
        method backpropagates and steps the optimizers itself:

        * ``"sds"``: the SDS loss updates the 3DGS optimizer.
        * ``"vsd"``: the VSD loss plus the 3DGS regularizers update the 3DGS
          optimizer, then the LoRA loss updates the LoRA optimizer.

        Returns:
            ``{"loss": <detached 3DGS loss>}``.

        Raises:
            NotImplementedError: if ``self.loss_type`` is neither "sds" nor "vsd".
        """
        if self.loss_type not in ("sds", "vsd"):
            # Previously an unknown loss type silently trained nothing.
            raise NotImplementedError(
                f"loss_type {self.loss_type!r} is not supported, expected 'sds' or 'vsd'"
            )

        # Counts batches, not optimizer.step() calls.
        self.current_step = self.trainer.fit_loop.epoch_loop._batches_that_stepped
        self.gaussian.update_learning_rate(self.current_step)
        if self.current_step > 500:
            self.guidance.set_min_max_steps(min_step_percent=0.02, max_step_percent=0.55)

        out = self(batch)

        prompt_utils = self.prompt_processor()
        images = out["comp_rgb"]

        guidance_eval = self.current_step % 200 == 0
        guidance_out = self.guidance(
            images, prompt_utils, **batch, rgb_as_latents=False, guidance_eval=guidance_eval
        )

        optimizers = self.optimizers()
        if not isinstance(optimizers, (list, tuple)):
            optimizers = [optimizers]
        opt_gs = optimizers[0]

        if self.loss_type == "sds":
            loss = self.compute_sds_loss(guidance_out, out)
            if guidance_eval:
                self.guidance_evaluation_save(
                    out["comp_rgb"].detach()[: guidance_out["eval"]["bs"]],
                    guidance_out["eval"],
                )
            # Under manual optimization the loss must be backpropagated and
            # stepped here; the original computed it and then discarded it.
            self.manual_backward(loss)
            opt_gs.step()
            opt_gs.zero_grad(set_to_none=True)
            return {"loss": loss.detach()}

        # ---- vsd ----
        opt_lora = optimizers[1]
        if guidance_eval and "sample_image" in guidance_out:
            # Reorder axes so each sample can be saved with data_format "HWC".
            imgs = guidance_out["sample_image"].permute(0, 2, 3, 1).contiguous()
            for i in range(imgs.shape[0]):
                self.save_image_grid(
                    f"it{self.current_step}-batch-{i}.png",
                    [{"type": "rgb", "img": imgs[i], "kwargs": {"data_format": "HWC"}}],
                    name="guidance_sample",
                    step=self.current_step,
                )

        # 3DGS update: VSD loss plus the 3DGS-only regularizers.
        loss_gs = guidance_out["loss_vsd"] + self._regularize_3dgs(out)
        self.manual_backward(loss_gs)
        opt_gs.step()
        opt_gs.zero_grad(set_to_none=True)

        # LoRA update: driven only by its own loss.
        loss_lora = guidance_out["loss_lora"]
        self.manual_backward(loss_lora)
        opt_lora.step()
        opt_lora.zero_grad(set_to_none=True)

        self.log("train/loss_gs", loss_gs.detach(), prog_bar=True)
        self.log("train/loss_lora", loss_lora.detach(), prog_bar=True)

        # The returned value is informational only under manual optimization.
        return {"loss": loss_gs.detach()}

    def on_before_optimizer_step(self, optimizer):
        """Update densification statistics before the 3DGS optimizer steps.

        Does nothing for any optimizer other than ``self.gaussian.optimizer``
        and nothing once ``self.current_step`` reaches 3000. Otherwise it:
        raises the SH degree at steps 500 and 1000, accumulates the
        view-space point gradients of all rendered views, records the maximum
        image-space radii of the visible Gaussians, and densifies/prunes
        every 100 steps.
        """
        if optimizer is not self.gaussian.optimizer:
            return
        step = self.current_step
        if step >= 3000:  # originally 900
            return

        with torch.no_grad():
            if step in (500, 1000):
                self.gaussian.oneupSHdegree()

            # Sum the view-space gradients over every view rendered by forward.
            grad_sum = torch.zeros_like(self.viewspace_point_list[0])
            for screen_points in self.viewspace_point_list:
                grad_sum = grad_sum + screen_points.grad

            # Keep track of max radii in image-space for pruning.
            visible = self.visibility_filter
            self.gaussian.max_radii2D[visible] = torch.max(
                self.gaussian.max_radii2D[visible], self.radii[visible]
            )
            self.gaussian.add_densification_stats(grad_sum, visible)

            # An earlier variant limited this to 300 < step < 900 and otherwise
            # ran separate densify (every 200 steps) and prune (every 100
            # steps) passes with a size threshold of 20.
            if step % 100 == 0:
                size_threshold = 20 if step > 500 else None
                self.gaussian.densify_and_prune(
                    0.0002, 0.05, self.cameras_extent, size_threshold
                )

            # ===== LoRA 梯度检查 =====
        # if optimizer is not self.gaussian.optimizer:
        #     total_norm = 0.0
        #     count = 0
        #     max_mean = 0.0
        #     max_name = None

        #     for name, p in self.guidance.unet_lora.named_parameters():
        #         if "lora_" in name and p.requires_grad and p.grad is not None:
        #             g = p.grad.detach()
        #             norm = g.norm().item()
        #             mean = g.abs().mean().item()
        #             total_norm += norm
        #             count += 1
        #             if mean > max_mean:
        #                 max_mean = mean
        #                 max_name = name

        #     if count == 0:
        #         print("[LoRA] no grad (no lora_ params with grad)")
        #     else:
        #         print(f"[LoRA] num_params_with_grad={count}, total_grad_norm={total_norm:.6f}, "
        #             f"max_mean_grad={max_mean:.6e} at {max_name}")


    # ================== Validation and testing ==========================
    def validation_step(self, batch, batch_idx):
        """Render the validation batch and save a comparison image grid.

        The grid holds, left to right: the reference ``batch["rgb"]`` (if
        present), the rendered ``comp_rgb``, and ``comp_normal`` (if the
        render output contains it). Only the first item of the batch is saved.
        """
        out = self(batch)

        panels = []
        if "rgb" in batch:
            panels.append(
                {"type": "rgb", "img": batch["rgb"][0], "kwargs": {"data_format": "HWC"}}
            )
        panels.append(
            {"type": "rgb", "img": out["comp_rgb"][0], "kwargs": {"data_format": "HWC"}}
        )
        if "comp_normal" in out:
            panels.append(
                {
                    "type": "rgb",
                    "img": out["comp_normal"][0],
                    "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                }
            )

        self.save_image_grid(
            f"it{self.current_step}-{batch['index'][0]}.png",
            panels,
            name="validation_step",
            step=self.current_step,
        )
        # save_path = self.get_save_path(f"it{self.current_step}-val.ply")
        # self.gaussian.save_ply(save_path)
        # load_ply(save_path,self.get_save_path(f"it{self.current_step}-val-color.ply"))

    def on_validation_epoch_end(self):
        """Do nothing at the end of a validation epoch."""
        return None

    def test_step(self, batch, batch_idx):
        """Render the test batch on a black background and save an image grid.

        With ``only_rgb`` set (the current hard-coded choice) the grid holds
        the reference ``rgb`` (if present), ``comp_rgb`` and ``comp_normal``
        (if present). When it is cleared, depth (if present) and opacity
        panels are appended as well. Only the first batch item is saved.
        """
        only_rgb = True

        def rgb_panel(img, **extra_kwargs):
            # Each panel gets its own kwargs dict.
            return {
                "type": "rgb",
                "img": img,
                "kwargs": {"data_format": "HWC", **extra_kwargs},
            }

        testbackground_tensor = torch.tensor(
            [0, 0, 0], dtype=torch.float32, device="cuda"
        )
        out = self(batch, testbackground_tensor)

        panels = []
        if "rgb" in batch:
            panels.append(rgb_panel(batch["rgb"][0]))
        panels.append(rgb_panel(out["comp_rgb"][0]))
        if "comp_normal" in out:
            panels.append(rgb_panel(out["comp_normal"][0], data_range=(0, 1)))

        if not only_rgb:
            if "depth" in out:
                panels.append({"type": "grayscale", "img": out["depth"][0], "kwargs": {}})
            panels.append(
                {
                    "type": "grayscale",
                    "img": out["opacity"][0, :, :, 0],
                    "kwargs": {"cmap": None, "data_range": (0, 1)},
                }
            )

        self.save_image_grid(
            f"it{self.current_step}-test/{batch['index'][0]}.png",
            panels,
            name="test_step",
            step=self.current_step,
        )

    def on_test_epoch_end(self):
        """Assemble the test frames into a video and export the final Gaussians.

        Saves an mp4 built from the ``it{step}-test`` frames, writes the
        Gaussians to ``last_3dgs.ply``, and stores a copy converted to an RGB
        point cloud.
        """
        self.save_img_sequence(
            f"it{self.current_step}-test",
            f"it{self.current_step}-test",
            # Raw string: "\d" and "\." are invalid escapes in a normal literal.
            r"(\d+)\.png",
            save_format="mp4",
            fps=30,
            name="test",
            step=self.current_step,
        )
        save_path = self.get_save_path("last_3dgs.ply")
        self.gaussian.save_ply(save_path)
        # Also save the point cloud converted to RGB space.
        from threestudio.systems.function.point_cloud import save_ply, load_from_3dgs
        coords, rgb = load_from_3dgs(save_path)
        save_ply(self.get_save_path(f"it{self.current_step}-test-color.ply"), coords, rgb)
    

    def compute_sds_loss(self, guidance_out, out):
        """Combine the SDS loss with the sparsity and opacity terms.

        Each term is weighted by the matching ``lambda_*`` entry of
        ``cfg.loss`` (resolved through ``self.C``). The sparsity and opacity
        terms are logged before being added.

        Returns:
            The weighted sum of the three terms.
        """
        loss_cfg = self.cfg.loss
        opacity = out["opacity"]

        sds_term = guidance_out['loss_sds'] * self.C(loss_cfg['lambda_sds'])

        loss_sparsity = (opacity ** 2 + 0.01).sqrt().mean()
        self.log("train/loss_sparsity", loss_sparsity)
        sparsity_term = loss_sparsity * self.C(loss_cfg.lambda_sparsity)

        # The clamped opacity is passed as both input and target.
        clamped = opacity.clamp(1.0e-3, 1.0 - 1.0e-3)
        loss_opaque = binary_cross_entropy(clamped, clamped)
        self.log("train/loss_opaque", loss_opaque)
        opaque_term = loss_opaque * self.C(loss_cfg.lambda_opaque)

        return sds_term + sparsity_term + opaque_term
    
    def compute_vsd_loss(self, guidance_out, out):
        """Sum every ``loss_*`` guidance entry plus the sparsity and opacity terms.

        Every guidance entry is logged under ``train/<name>`` unless it is a
        ``torch.Tensor`` with more than one element. Entries whose name
        starts with ``loss_`` are weighted by the ``lambda_*`` config value of
        the same suffix and accumulated.

        Returns:
            The accumulated weighted loss.
        """
        loss_cfg = self.cfg.loss
        total = 0.0
        for key, val in guidance_out.items():
            multi_element_tensor = type(val) is torch.Tensor and val.numel() > 1
            if not multi_element_tensor:
                self.log(f"train/{key}", val)
            if key.startswith("loss_"):
                weight = self.C(loss_cfg[key.replace("loss_", "lambda_")])
                total = total + val * weight

        opacity = out["opacity"]
        loss_sparsity = (opacity ** 2 + 0.01).sqrt().mean()
        self.log("train/loss_sparsity", loss_sparsity)
        total = total + loss_sparsity * self.C(loss_cfg.lambda_sparsity)

        # The clamped opacity is passed as both input and target.
        clamped = opacity.clamp(1.0e-3, 1.0 - 1.0e-3)
        loss_opaque = binary_cross_entropy(clamped, clamped)
        self.log("train/loss_opaque", loss_opaque)
        total = total + loss_opaque * self.C(loss_cfg.lambda_opaque)

        return total
    
    def _regularize_3dgs(self, out):
        """Return the 3DGS-only regularizers (sparsity and opacity prior).

        No LoRA-related term is included. The opacity prior is a binary
        cross-entropy between the clamped opacity and a constant target,
        ``cfg.loss.opacity_target`` (0.5 when that entry is absent), rather
        than a cross-entropy of the opacity against itself.
        """
        loss_cfg = self.cfg.loss
        opacity = out["opacity"]

        # Sparsity regularizer.
        loss_sparsity = (opacity ** 2 + 0.01).sqrt().mean()
        self.log("train/loss_sparsity", loss_sparsity)
        sparsity_term = loss_sparsity * self.C(loss_cfg.lambda_sparsity)

        # Opacity prior towards a fixed target value.
        clamped = opacity.clamp(1.0e-3, 1.0 - 1.0e-3)
        target = torch.full_like(clamped, getattr(loss_cfg, "opacity_target", 0.5))
        loss_opaque = binary_cross_entropy(clamped, target)
        self.log("train/loss_opaque", loss_opaque)
        opaque_term = loss_opaque * self.C(loss_cfg.lambda_opaque)

        return 0.0 + sparsity_term + opaque_term
    
    # def on_after_backward(self):
    #     # 只检查第一个 LoRA 参数
    #     for name, p in self.guidance.unet_lora.named_parameters():
    #         if "lora_" in name and p.requires_grad:
    #             if p.grad is None:
    #                 print(f"[X] LoRA NOT training: {name} grad=None")
    #             else:
    #                 print(f"[OK] LoRA training: {name} grad mean={p.grad.abs().mean().item():.6f}")
    #             break
    # def on_before_zero_grad(self, optimizer):
    #     # 这里是在 optimizer.step() 之后被调用的
    #     import torch
    #     for name, p in self.guidance.unet_lora.named_parameters():
    #         if "lora_" in name and p.requires_grad:
    #             finite = torch.isfinite(p).all().item()
    #             mean = p.detach().abs().mean().item()
    #             print(f"[param] LoRA {name}: finite={finite}, mean={mean:.6f}")