# ============================================================
# diff_gaussian_rasterizer_shading.py
# ============================================================
import math
from dataclasses import dataclass

import threestudio
import torch
import torch.nn.functional as F
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from threestudio.models.background.base import BaseBackground
from threestudio.models.geometry.base import BaseGeometry
from threestudio.models.materials.base import BaseMaterial
from threestudio.models.renderers.base import Rasterizer
from threestudio.utils.typing import *

from ..materials.gaussian_material import GaussianDiffuseWithPointLightMaterial
from .gaussian_batch_renderer import GaussianBatchRenderer


class Depth2Normal(torch.nn.Module):
    """Estimate per-pixel surface normals from a position (XYZ) map.

    Input:  x [B,C,H,W] where C=3 (xyz map channels)
    Output: normal (unnormalized) [B,3,H,W], computed as
            cross(dP/dx, dP/dy) using Sobel derivatives.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Sobel kernels instead of plain central differences: besides taking
        # the derivative, Sobel applies [1, 2, 1] smoothing across the
        # derivative direction, which suppresses surface noise in 3DGS renders.
        self.delx = torch.tensor(
            [
                [-1.0, 0.0, 1.0],
                [-2.0, 0.0, 2.0],
                [-1.0, 0.0, 1.0],
            ],
            dtype=torch.float32,
        )
        self.dely = torch.tensor(
            [
                [-1.0, -2.0, -1.0],
                [ 0.0,  0.0,  0.0],
                [ 1.0,  2.0,  1.0],
            ],
            dtype=torch.float32,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the unnormalized normal map cross(dx, dy) for ``x``.

        Kernels are cast to ``x``'s device and dtype on every call.
        """
        B, C, H, W = x.shape
        kx = self.delx.view(1, 1, 3, 3).to(device=x.device, dtype=x.dtype)
        ky = self.dely.view(1, 1, 3, 3).to(device=x.device, dtype=x.dtype)

        # Per-channel convolution: treat each channel as its own 1-channel image.
        planes = x.reshape(B * C, 1, H, W)
        # Replicate-pad rather than zero-pad. With zero padding, border pixels
        # would be differenced against the origin (0,0,0). That yields
        # spurious normals on the image border, and sign-flipped ones on the
        # far edges.
        planes = F.pad(planes, (1, 1, 1, 1), mode="replicate")
        dx = F.conv2d(planes, kx).reshape(B, C, H, W)
        dy = F.conv2d(planes, ky).reshape(B, C, H, W)

        # normal = cross(dP/dx, dP/dy)  (sign may differ by convention)
        return torch.cross(dx, dy, dim=1)


@threestudio.register("diff-gaussian-rasterizer-shading")
class DiffGaussian(Rasterizer, GaussianBatchRenderer):
    """Gaussian-splatting renderer with point-light shading.

    The Gaussians are splatted once to obtain colour, depth and alpha. A
    surface normal is derived from the XYZ map implied by depth or, when the
    geometry config enables ``pred_normal``, from a second splat of the
    per-Gaussian normals. The un-premultiplied colour is shaded as albedo by
    the point-light material and composited over the neural background.
    """

    @dataclass
    class Config(Rasterizer.Config):
        # Forwarded to the CUDA rasterizer's ``debug`` flag.
        debug: bool = False
        # Stored as ``self.background_tensor`` by configure(); forward()
        # composites over the neural background instead.
        back_ground_color: Tuple[float, float, float] = (1, 1, 1)

        # Lower bound on alpha when dividing the premultiplied colour by alpha
        # to recover albedo (keeps the division stable at silhouettes). The
        # original note also mentions TV-loss weights; those are not computed
        # in this class.
        alpha_floor: float = 0.05

    cfg: Config

    def configure(
        self,
        geometry: BaseGeometry,
        material: BaseMaterial,
        background: BaseBackground,
    ) -> None:
        """Check the material type, then set up normal estimation.

        Raises:
            NotImplementedError: if ``material`` is not a
                ``GaussianDiffuseWithPointLightMaterial``.
        """
        if not isinstance(material, GaussianDiffuseWithPointLightMaterial):
            raise NotImplementedError(
                "diff-gaussian-rasterizer-shading only support Gaussian material."
            )
        super().configure(geometry, material, background)
        self.normal_module = Depth2Normal()
        self.background_tensor = torch.tensor(
            self.cfg.back_ground_color, dtype=torch.float32, device="cuda"
        )

    def _make_raster_settings(
        self, viewpoint_camera, bg_color, scaling_modifier, sh_degree
    ):
        """Build rasterization settings for ``viewpoint_camera``.

        ``sh_degree`` is explicit so that the auxiliary normal pass can use
        degree 0, matching the single SH coefficient it supplies per Gaussian.
        """
        return GaussianRasterizationSettings(
            image_height=int(viewpoint_camera.image_height),
            image_width=int(viewpoint_camera.image_width),
            tanfovx=math.tan(viewpoint_camera.FoVx * 0.5),
            tanfovy=math.tan(viewpoint_camera.FoVy * 0.5),
            bg=bg_color,
            scale_modifier=scaling_modifier,
            viewmatrix=viewpoint_camera.world_view_transform,
            projmatrix=viewpoint_camera.full_proj_transform,
            sh_degree=sh_degree,
            campos=viewpoint_camera.camera_center,
            prefiltered=False,
            debug=bool(self.cfg.debug),
        )

    @staticmethod
    def _background_to_chw(comp_rgb_bg, H, W):
        """Reshape the background module's output into a [3,H,W] image.

        A 3-D output is treated as [1, H*W, 3]; any other rank is indexed
        with ``[0]`` and expected to be [1, H, W, 3].
        """
        if comp_rgb_bg.ndim == 3:
            comp_rgb_bg_hw = comp_rgb_bg.reshape(1, H, W, 3)[0]
        else:
            comp_rgb_bg_hw = comp_rgb_bg[0]
        return comp_rgb_bg_hw.permute(2, 0, 1)

    def forward(
        self,
        viewpoint_camera,
        bg_color: torch.Tensor,
        scaling_modifier=1.0,
        override_color=None,
        **kwargs
    ) -> Dict[str, Any]:
        """Render one view with point-light shading.

        Args:
            viewpoint_camera: camera exposing image size, FoVx/FoVy, view and
                projection matrices and the camera centre.
            bg_color: rasterizer background; zeroed below because the result
                is composited over the neural background.
            scaling_modifier: forwarded as the rasterizer's ``scale_modifier``.
            override_color: optional precomputed per-Gaussian colours used
                instead of SH features.
            **kwargs: must contain ``batch_idx`` plus ``rays_o``, ``rays_d``
                and ``light_positions``, each indexed by ``batch_idx``.

        Returns:
          - render: [3,H,W] shaded foreground composited over the background
          - normal: [3,H,W]  (VIS in [0,1], blended to 0.5 by (1 - alpha))
          - normal_raw: [H,W,3] in [-1,1]  (TRAIN)
          - depth: [H,W,1]  (TRAIN)
          - alpha: [H,W,1] rendered alpha, for weighting losses
          - pred_normal: [3,H,W] splatted predicted normals, or None
          - mask: [1,H,W] alpha
          - viewspace_points: screen-space points (densification gradients)
          - visibility_filter: radii > 0
          - radii: per-Gaussian screen-space radii from the rasterizer
        """
        # NOTE: with a neural background the raster bg must be zeros, so the
        # splatted image is premultiplied over black.
        bg_color = bg_color * 0

        pc = self.geometry

        # Screen-space points (for grad + densification stats). Created on the
        # same device as the Gaussian means.
        screenspace_points = torch.zeros_like(
            pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True
        )
        try:
            screenspace_points.retain_grad()
        except RuntimeError:
            # Best effort, as in the reference 3DGS implementation.
            pass

        rasterizer = GaussianRasterizer(
            raster_settings=self._make_raster_settings(
                viewpoint_camera, bg_color, scaling_modifier, pc.active_sh_degree
            )
        )

        means3D = pc.get_xyz
        means2D = screenspace_points
        opacity = pc.get_opacity
        scales = pc.get_scaling
        rotations = pc.get_rotation
        cov3D_precomp = None

        shs = None
        colors_precomp = None
        if override_color is None:
            shs = pc.get_features
        else:
            colors_precomp = override_color

        # --------- Rays (important: normalize!) ----------
        batch_idx = kwargs["batch_idx"]
        rays_o = kwargs["rays_o"][batch_idx]
        # Unit directions: the neural background expects them, and the XYZ map
        # below treats depth as a distance along the ray.
        rays_d = F.normalize(kwargs["rays_d"][batch_idx], dim=-1, eps=1e-6)

        comp_rgb_bg = self.background(dirs=rays_d.unsqueeze(0))

        # --------- Rasterize ----------
        rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
            means3D=means3D,
            means2D=means2D,
            shs=shs,
            colors_precomp=colors_precomp,
            opacities=opacity,
            scales=scales,
            rotations=rotations,
            cov3D_precomp=cov3D_precomp,
        )
        # rendered_image: [3,H,W]
        _, H, W = rendered_image.shape

        # --------- XYZ map (assumes rendered_depth is ray t) ----------
        # If your rendered_depth is view-space z, you must unproject instead.
        xyz_map = rays_o + rendered_depth.permute(1, 2, 0) * rays_d  # [H,W,3]

        # --------- Normal from xyz map ----------
        n = self.normal_module(xyz_map.permute(2, 0, 1).unsqueeze(0))[0]  # [3,H,W]
        n = F.normalize(n, dim=0, eps=1e-6)  # [-1,1]

        # Optional predicted normal from geometry
        if getattr(pc.cfg, "pred_normal", False):
            # Exactly one SH coefficient per Gaussian is supplied here (the
            # normal), so this pass must use sh_degree=0. Reusing the active
            # degree would make SH evaluation index higher-order coefficients
            # that were never provided.
            normal_rasterizer = GaussianRasterizer(
                raster_settings=self._make_raster_settings(
                    viewpoint_camera, bg_color, scaling_modifier, 0
                )
            )
            pred_normal_map, _, _, _ = normal_rasterizer(
                means3D=means3D,
                means2D=torch.zeros_like(means2D),
                shs=pc.get_normal.unsqueeze(1),
                colors_precomp=None,
                opacities=opacity,
                scales=scales,
                rotations=rotations,
                cov3D_precomp=cov3D_precomp,
            )
        else:
            pred_normal_map = None

        # --------- Shading normal selection ----------
        if pred_normal_map is not None:
            # Map the splatted values back to [-1,1] and renormalize; detached,
            # so the shading loss does not train the predicted normals.
            shading_normal = pred_normal_map.permute(1, 2, 0).detach() * 2.0 - 1.0
            shading_normal = F.normalize(shading_normal, dim=2, eps=1e-6)
        else:
            shading_normal = n.permute(1, 2, 0)  # [H,W,3]

        # --------- Stable albedo ----------
        alpha_safe = rendered_alpha.clamp_min(float(self.cfg.alpha_floor))  # [1,H,W]
        albedo = (rendered_image / alpha_safe).clamp(0, 1).permute(1, 2, 0)  # [H,W,3]

        # --------- Light positions ----------
        light_positions = kwargs["light_positions"][batch_idx, None, None, :].expand(
            H, W, -1
        )

        # --------- Material shading ----------
        rgb_fg = self.material(
            positions=xyz_map,
            shading_normal=shading_normal,
            albedo=albedo,
            light_positions=light_positions,
        ).permute(2, 0, 1)  # [3,H,W]

        # --------- Composite ----------
        bg_chw = self._background_to_chw(comp_rgb_bg, H, W)  # [3,H,W]
        rendered_image = rgb_fg * rendered_alpha + (1.0 - rendered_alpha) * bg_chw
        rendered_image = rendered_image.clamp(0, 1)

        # --------- Outputs for TRAIN vs VIS ----------
        # TRAIN: raw normal [-1,1], depth, alpha
        normal_raw = n.permute(1, 2, 0)                # [H,W,3]
        depth_train = rendered_depth.permute(1, 2, 0)   # [H,W,1]
        alpha_train = rendered_alpha.permute(1, 2, 0)   # [H,W,1]

        # VIS: normal in [0,1] blended with gray background
        normal_vis = n * 0.5 + 0.5  # [3,H,W]
        normal_vis = normal_vis * rendered_alpha + (1.0 - rendered_alpha) * 0.5
        # IMPORTANT: DO NOT detach training tensors here.

        return {
            "render": rendered_image,
            "normal": normal_vis,         # [3,H,W] for visualization
            "normal_raw": normal_raw,     # [H,W,3] for training
            "depth": depth_train,         # [H,W,1] for training
            "alpha": alpha_train,         # [H,W,1] for weighting losses
            "pred_normal": pred_normal_map,
            "mask": rendered_alpha,       # [1,H,W]
            "viewspace_points": screenspace_points,
            "visibility_filter": radii > 0,
            "radii": radii,
        }