import os
from dataclasses import dataclass, field

import torch
# Import the gradient-checkpointing utility
from torch.utils.checkpoint import checkpoint

import threestudio
from threestudio.systems.base import BaseLift3DSystem
from threestudio.utils.misc import cleanup, get_device
from threestudio.utils.ops import binary_cross_entropy, dot
from threestudio.utils.typing import *


@threestudio.register("mvdream-checkpoint-system")
class MVDreamSystem(BaseLift3DSystem):
    """MVDream text-to-3D system with optional gradient checkpointing of the renderer.

    Guidance weights are stripped from saved checkpoints (``on_save_checkpoint``)
    and re-injected from the live guidance module when a checkpoint without them
    is loaded (``on_load_checkpoint``).
    """

    @dataclass
    class Config(BaseLift3DSystem.Config):
        # Not read anywhere in this class. NOTE(review): confirm it is still needed.
        visualize_samples: bool = False
        # When True, the renderer forward pass in ``training_step`` runs under
        # ``torch.utils.checkpoint.checkpoint``: activations are recomputed during
        # backward instead of being kept in memory.
        use_gradient_checkpointing: bool = False

    cfg: Config

    def configure(self) -> None:
        """Set up geometry, material, background and renderer via the base class,
        then build the prompt processor and guidance model from their configured types."""
        super().configure()

        self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)(
            self.cfg.prompt_processor
        )
        # Evaluated once here; the resulting prompt utilities are reused every step.
        self.prompt_utils = self.prompt_processor()

        self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)

    def on_load_checkpoint(self, checkpoint):
        """Re-add guidance weights to a checkpoint that was saved without them.

        ``on_save_checkpoint`` removes every ``guidance.*`` entry. If the incoming
        state dict has no ``guidance.*`` key at all, the current in-memory guidance
        weights are merged in (presumably so state-dict loading does not report
        missing keys). If any ``guidance.*`` key exists, the checkpoint is untouched.
        """
        # NOTE: the parameter name shadows the module-level torch ``checkpoint``
        # function inside this hook only; it is kept for hook-signature compatibility.
        state_dict = checkpoint["state_dict"]
        if any(k.startswith("guidance.") for k in state_dict):
            return
        guidance_state_dict = {
            "guidance." + k: v for k, v in self.guidance.state_dict().items()
        }
        checkpoint["state_dict"] = {**state_dict, **guidance_state_dict}
        return

    def on_save_checkpoint(self, checkpoint):
        """Remove all ``guidance.*`` entries from the state dict before it is saved."""
        state_dict = checkpoint["state_dict"]
        # Collect the keys first; deleting while iterating the dict itself would fail.
        guidance_keys = [k for k in state_dict if k.startswith("guidance.")]
        for key in guidance_keys:
            state_dict.pop(key)
        return

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Render ``batch`` by forwarding its entries as keyword arguments to the renderer."""
        return self.renderer(**batch)

    def training_step(self, batch, batch_idx):
        """Render the batch, score it with the guidance model and return the total loss.

        Every guidance output is logged as ``train/<name>``. Outputs whose name
        starts with ``loss_`` are added to the total, weighted by the matching
        ``lambda_*`` entry of ``cfg.loss``. Optional regularizers (orientation,
        sparsity, opacity, z-variance, eikonal) are added when their weight is > 0.
        """
        if self.cfg.use_gradient_checkpointing:
            # Use non-reentrant checkpointing. The reentrant variant only tracks
            # tensors passed/returned directly, so with a dict input and a dict
            # output the renderer parameters would receive no gradients at all.
            # ``batch`` is passed as an argument rather than captured by the
            # closure, so checkpoint can see its tensors (recent PyTorch versions
            # inspect nested inputs, e.g. when stashing device RNG state for the
            # recomputation).
            out = checkpoint(lambda b: self.renderer(**b), batch, use_reentrant=False)
        else:
            out = self(batch)

        # NOTE(review): ``t_scheduler`` and ``noise_generator`` are not created in
        # this class; presumably they come from the base class or a mixin.
        schedule_out = self.t_scheduler(out["comp_rgb"].shape[0])
        noise_out = self.noise_generator(out, batch, winlose=False)

        guidance_out = self.guidance(
            out["comp_rgb"], self.prompt_utils, **batch, **schedule_out, **noise_out,
        )

        loss = 0.0

        for name, value in guidance_out.items():
            self.log(f"train/{name}", value)
            if name.startswith("loss_"):
                loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")])

        if self.C(self.cfg.loss.lambda_orient) > 0:
            if "normal" not in out:
                raise ValueError(
                    "Normal is required for orientation loss, no normal is found in the output."
                )
            # Penalize positive alignment between normal and t_dirs, weighted by the
            # detached sample weights and normalized by the number of pixels with
            # positive opacity. The count is clamped to >= 1 so an empty render
            # gives 0 instead of 0/0 = NaN.
            loss_orient = (
                out["weights"].detach()
                * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2
            ).sum() / (out["opacity"] > 0).sum().clamp_min(1)
            self.log("train/loss_orient", loss_orient)
            loss += loss_orient * self.C(self.cfg.loss.lambda_orient)

        if self.C(self.cfg.loss.lambda_sparsity) > 0:
            loss_sparsity = (out["opacity"] ** 2 + 0.01).sqrt().mean()
            self.log("train/loss_sparsity", loss_sparsity)
            loss += loss_sparsity * self.C(self.cfg.loss.lambda_sparsity)

        if self.C(self.cfg.loss.lambda_opaque) > 0:
            # BCE of the clamped opacity against itself, i.e. its binary entropy:
            # pushes opacities toward 0 or 1. Clamping keeps the log terms finite.
            opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3)
            loss_opaque = binary_cross_entropy(opacity_clamped, opacity_clamped)
            self.log("train/loss_opaque", loss_opaque)
            loss += loss_opaque * self.C(self.cfg.loss.lambda_opaque)

        if self.C(self.cfg.loss.lambda_z_variance) > 0:
            # NOTE(review): the mean is NaN when no pixel has opacity > 0.5.
            loss_z_variance = out["z_variance"][out["opacity"] > 0.5].mean()
            self.log("train/loss_z_variance", loss_z_variance)
            loss += loss_z_variance * self.C(self.cfg.loss.lambda_z_variance)

        # ``lambda_eikonal`` is optional in the loss config, hence the hasattr check.
        if hasattr(self.cfg.loss, "lambda_eikonal") and self.C(self.cfg.loss.lambda_eikonal) > 0:
            loss_eikonal = (
                (torch.linalg.norm(out["sdf_grad"], ord=2, dim=-1) - 1.0) ** 2
            ).mean()
            self.log("train/loss_eikonal", loss_eikonal)
            loss += loss_eikonal * self.C(self.cfg.loss.lambda_eikonal)

        # Log the current (possibly scheduled) value of every loss weight.
        for name, value in self.cfg.loss.items():
            self.log(f"train_params/{name}", self.C(value))

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Render one validation view and save an image grid of rgb / normal / opacity.

        The rgb and normal panels are included only when the renderer produced
        ``comp_rgb`` / ``comp_normal``; the opacity panel is always included.
        """
        out = self(batch)
        self.save_image_grid(
            f"it{self.true_global_step}-{batch['index'][0]}.png",
            (
                [
                    {
                        "type": "rgb",
                        "img": out["comp_rgb"][0],
                        "kwargs": {"data_format": "HWC"},
                    },
                ]
                if "comp_rgb" in out
                else []
            )
            + (
                [
                    {
                        "type": "rgb",
                        "img": out["comp_normal"][0],
                        "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                    }
                ]
                if "comp_normal" in out
                else []
            )
            + [
                {
                    "type": "grayscale",
                    "img": out["opacity"][0, :, :, 0],
                    "kwargs": {"cmap": None, "data_range": (0, 1)},
                },
            ],
            name="validation_step",
            step=self.true_global_step,
        )

    def on_validation_epoch_end(self):
        """No end-of-epoch validation work."""
        pass

    def test_step(self, batch, batch_idx):
        """Render one test view and write its per-view artifacts to disk.

        Under ``it<step>-test/`` this writes:
        - the opacity image;
        - the RGB image (``on_test_epoch_end`` turns these into a sequence);
        - the world-space normal map, only when ``comp_normal`` was rendered;
        - the raw ``batch`` as ``.npy``;
        - an image grid of rgb / normal (if present) / opacity.
        """
        import numpy as np

        out = self(batch)
        step_dir = f"it{self.true_global_step}-test"
        # ``.item()`` yields a plain Python scalar so the id formats as a bare number.
        view_id = str(batch["index"][0].item()).zfill(4)

        self.save_rgb_image(f"{step_dir}/opacity/{view_id}.png", img=out["opacity"][0])
        self.save_rgb_image(f"{step_dir}/rgb_images/{view_id}.png", img=out["comp_rgb"][0])
        if "comp_normal" in out:
            self.save_normal_map(
                f"{step_dir}/normal_world/{view_id}.png",
                img=out["comp_normal"][0].cpu().numpy(),
            )
        # ``batch`` is a dict, so np.save stores it as a pickled 0-d object array;
        # read it back with ``np.load(path, allow_pickle=True).item()``.
        np.save(self.get_save_path(f"{step_dir}/batch_data/{view_id}.npy"), batch)
        self.save_image_grid(
            f"{step_dir}/image_grid/{view_id}.png",
            [
                {
                    "type": "rgb",
                    "img": out["comp_rgb"][0],
                    "kwargs": {"data_format": "HWC"},
                },
            ]
            + (
                [
                    {
                        "type": "rgb",
                        "img": out["comp_normal"][0],
                        "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                    }
                ]
                if "comp_normal" in out
                else []
            )
            + [
                {
                    "type": "grayscale",
                    "img": out["opacity"][0, :, :, 0],
                    "kwargs": {"cmap": None, "data_range": (0, 1)},
                },
            ],
            name="test_step",
            step=self.true_global_step,
        )

    def on_test_epoch_end(self):
        """Save the per-view RGB test renders matching ``(\\d+).png`` as an mp4 at 30 fps."""
        # Raw string: same pattern value as before, without invalid-escape warnings.
        self.save_img_sequence(
            f"it{self.true_global_step}-test",
            f"it{self.true_global_step}-test/rgb_images/",
            r"(\d+)\.png",
            save_format="mp4",
            fps=30,
            name="test",
            step=self.true_global_step,
        )
        # Disabled: dump the guidance's learned text embedding per prompt.
        # embedding_dir = "./learned_negative"
        # os.makedirs(embedding_dir, exist_ok=True)
        # prompt = self.prompt_utils.prompt.replace(" ", "_")
        # torch.save(
        #     self.guidance.learnable_text,
        #     os.path.join(embedding_dir, f"{prompt}.pt")
        # )
