from dataclasses import dataclass
from typing import Tuple, Union

import numpy as np
import torch
from pytorch3d.renderer import (
    PointsRasterizationSettings,
    PointsRasterizer,
    NormWeightedCompositor,
)
from pytorch3d.utils import cameras_from_opencv_projection
from pytorch3d.structures import Pointclouds


@dataclass
class ProjectionConfig:
    """Settings that control the resolution at which point clouds are projected."""

    # Either the string 'original' (keep the source image size) or an (H, W)
    # pair (tuple or list) giving the target projection resolution in pixels.
    projection_resolution: Union[str, Tuple[int, int]] = (720, 1280)
    # NOTE(review): not read by ProjectionConfig itself — confirm where it is used.
    adaptive_fill: bool = True
    # NOTE(review): not read by ProjectionConfig itself — confirm where it is used.
    adaptive_radius: int = 1

    def _explicit_resolution(self) -> Union[Tuple[int, int], None]:
        """Return the configured (H, W) as ints, or None when set to 'original'.

        Raises:
            ValueError: if ``projection_resolution`` is neither 'original' nor a
                length-2 tuple/list.
        """
        res = self.projection_resolution
        if res == 'original':
            return None
        if isinstance(res, (tuple, list)) and len(res) == 2:
            return int(res[0]), int(res[1])
        raise ValueError(f"无效的 projection_resolution 配置: {self.projection_resolution}")

    def get_target_resolution(self, original_size: Tuple[int, int]) -> Tuple[int, int]:
        """Return the (H, W) that projections should be rendered at.

        Args:
            original_size: (H, W) of the source image; returned as-is in
                'original' mode.

        Raises:
            ValueError: if ``projection_resolution`` is invalid.
        """
        target = self._explicit_resolution()
        if target is None:
            return original_size[0], original_size[1]
        return target

    def scale_intrinsics(self, K: np.ndarray, original_size: Tuple[int, int]) -> np.ndarray:
        """Return a copy of ``K`` rescaled from ``original_size`` to the target resolution.

        fx and cx are multiplied by W_target / W_orig, and fy and cy by
        H_target / H_orig. The input ``K`` is never modified. In 'original'
        mode an unchanged copy is returned.

        Args:
            K: intrinsics matrix with fx, fy, cx, cy at [0, 0], [1, 1], [0, 2], [1, 2].
            original_size: (H, W) of the image ``K`` was calibrated for.

        Raises:
            ValueError: if ``projection_resolution`` is invalid.
        """
        target = self._explicit_resolution()
        if target is None:
            return K.copy()
        # Scale to the same integer resolution get_target_resolution() reports,
        # so the intrinsics match the image that is actually rendered.
        H_target, W_target = target
        H_orig, W_orig = original_size
        scale_x = float(W_target) / float(W_orig)
        scale_y = float(H_target) / float(H_orig)
        # Integer intrinsics would silently truncate the scaled entries (numpy
        # casts back to the array dtype on assignment), so promote them first.
        float_dtype = K.dtype if np.issubdtype(K.dtype, np.floating) else np.float64
        K_scaled = np.array(K, dtype=float_dtype, copy=True)
        K_scaled[0, 0] *= scale_x
        K_scaled[1, 1] *= scale_y
        K_scaled[0, 2] *= scale_x
        K_scaled[1, 2] *= scale_y
        return K_scaled


class PointCloudProjector:
    """Render coloured point clouds to RGB and depth images with PyTorch3D.

    Points are splatted as disks by ``PointsRasterizer`` and blended with
    Gaussian weights through ``NormWeightedCompositor``. Each projection runs
    two passes: a small-radius pass for detail, and a larger-radius pass whose
    result is only used for pixels the small pass left empty.
    """

    # Disk radii, in pixels, of the detail pass and the hole-filling pass.
    # NOTE(review): ``self.config`` (including ``adaptive_fill`` /
    # ``adaptive_radius``) is stored but never read by this class — confirm the
    # fixed radii are intended.
    _DETAIL_RADIUS_PX = 1.5
    _FILL_RADIUS_PX = 3.0

    def __init__(self, config: "ProjectionConfig"):
        self.config = config
        self.device = self._get_device()
        # One compositor instance, reused by every render call.
        self._norm_compositor = NormWeightedCompositor()

    def _get_device(self) -> str:
        """Return ``'cuda'`` when a CUDA device is available, otherwise ``'cpu'``."""
        return 'cuda' if torch.cuda.is_available() else 'cpu'

    @staticmethod
    def _blank_outputs(H: int, W: int) -> Tuple[np.ndarray, np.ndarray]:
        """Return an all-black uint8 (H, W, 3) image and an all-zero float32 (H, W) depth map."""
        return np.zeros((H, W, 3), dtype=np.uint8), np.zeros((H, W), dtype=np.float32)

    @staticmethod
    def _pixel_radius_to_ndc(pxl_radius: float, H: int, W: int) -> float:
        """Convert a radius in pixels to PyTorch3D NDC units for an H x W image.

        PyTorch3D's NDC space spans [-1, 1] along the *shorter* image side (the
        longer side spans a proportionally larger range), so one pixel measures
        2 / min(H, W) NDC units in both directions.
        """
        return float(pxl_radius) * 2.0 / float(min(H, W))

    @staticmethod
    def _gaussian_weights(
        idx: torch.Tensor, dists: torch.Tensor, radius_ndc: float
    ) -> torch.Tensor:
        """Compute per-slot Gaussian compositing weights in (N, K, H, W) layout.

        Args:
            idx: rasterizer point indices, (N, H, W, K); negative marks an empty slot.
            dists: squared NDC distances matching ``idx``, (N, H, W, K).
            radius_ndc: splat radius in NDC units (the Gaussian's sigma).

        Returns:
            Weights with empty slots set to exactly 0.
        """
        valid = (idx >= 0).permute(0, 3, 1, 2)
        # Empty slots carry a negative sentinel distance; clamping keeps exp()
        # from overflowing to inf (which becomes nan once multiplied by 0).
        dists2 = dists.permute(0, 3, 1, 2).clamp(min=0.0)
        weights = torch.exp(-dists2 / (2.0 * (radius_ndc * radius_ndc) + 1e-12))
        return weights * valid

    def project_to_image(
        self,
        points: np.ndarray,
        colors: np.ndarray,
        K: np.ndarray,
        camera_to_world: np.ndarray,
        H: int,
        W: int,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Project a world-space point cloud into an H x W RGB image and depth map.

        Args:
            points: (P, 3) world-space coordinates.
            colors: (P, 3) per-point colours, expected in [0, 1] (values are clamped).
            K: OpenCV-style intrinsics matrix for the H x W output image.
            camera_to_world: 4x4 camera-to-world pose; it is inverted to obtain
                the world-to-camera transform.
            H: output image height in pixels.
            W: output image width in pixels.

        Returns:
            ``(rgb, depth)``: a uint8 (H, W, 3) image and a float32 (H, W) map of
            the camera-space z of each pixel's nearest point (0 where empty).
            Both are all-zero when there are no points or none lie in front of
            the camera.
        """
        if points.shape[0] == 0:
            return self._blank_outputs(H, W)

        device = self.device
        points_t = torch.tensor(points, dtype=torch.float32, device=device)
        colors_t = torch.tensor(colors, dtype=torch.float32, device=device).clamp(0.0, 1.0)
        world_to_camera = torch.tensor(
            np.linalg.inv(camera_to_world).astype(np.float32), dtype=torch.float32, device=device
        )

        # Camera-space z of every point (third row of the world-to-camera
        # transform); only points strictly in front of the camera are rendered.
        z_cam_all = points_t @ world_to_camera[2, :3] + world_to_camera[2, 3]
        front = z_cam_all > 0.0
        if not torch.any(front):
            return self._blank_outputs(H, W)

        pts_front = points_t[front]
        cols_front = colors_t[front]
        z_front = z_cam_all[front]

        cameras = cameras_from_opencv_projection(
            world_to_camera[:3, :3].unsqueeze(0),
            world_to_camera[:3, 3].unsqueeze(0),
            torch.tensor(K, dtype=torch.float32, device=device).unsqueeze(0),
            torch.tensor([[H, W]], dtype=torch.float32, device=device),
        )
        pcl = Pointclouds(points=[pts_front], features=[cols_front])

        rgb_fine, depth_fine, hit_fine = self._render_with_radius(
            pcl, cameras, H, W, self._DETAIL_RADIUS_PX, z_front
        )
        rgb_fill, depth_fill, _ = self._render_with_radius(
            pcl, cameras, H, W, self._FILL_RADIUS_PX, z_front
        )

        # Keep the sharper small-radius result wherever it covered a pixel and
        # use the large-radius render only to fill the remaining holes.
        rgb = torch.where(hit_fine.unsqueeze(-1), rgb_fine, rgb_fill)
        depth = torch.where(hit_fine, depth_fine, depth_fill)

        return rgb.cpu().numpy(), depth.clamp(min=0.0).cpu().numpy()

    def _render_with_radius(
        self,
        pcl: "Pointclouds",
        cameras,
        H: int,
        W: int,
        pxl_radius: float,
        z_cam_valid: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Rasterize ``pcl`` with disks of ``pxl_radius`` pixels and composite colours.

        Args:
            pcl: single point cloud whose packed point order matches ``z_cam_valid``.
            cameras: PyTorch3D cameras to render from.
            H: output image height in pixels.
            W: output image width in pixels.
            pxl_radius: splat radius in pixels.
            z_cam_valid: camera-space z of each packed point, used for depth.

        Returns:
            ``(rgb, depth, hit)``: uint8 (H, W, C) colours, float32 (H, W) depth
            taken from each pixel's first fragment slot (0 where empty), and a
            bool (H, W) mask of pixels covered by at least one point.
        """
        radius_ndc = self._pixel_radius_to_ndc(pxl_radius, H, W)
        raster_settings = PointsRasterizationSettings(
            image_size=(H, W),
            radius=radius_ndc,
            bin_size=None,
            max_points_per_bin=512 * 512,
        )
        frags = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)(pcl)

        weights = self._gaussian_weights(frags.idx, frags.dists, radius_ndc)
        # The compositor expects indices as (N, K, H, W) and features as (C, P).
        images_nchw = self._norm_compositor(
            frags.idx.long().permute(0, 3, 1, 2),
            weights,
            pcl.features_packed().permute(1, 0),
        )
        rgb = (images_nchw[0].permute(1, 2, 0).clamp(0.0, 1.0) * 255.0).to(torch.uint8)

        # Slot 0 holds the point PyTorch3D sorts first (nearest in z) per pixel.
        first_idx = frags.idx[0, ..., 0]
        hit = first_idx >= 0
        depth = torch.zeros((H, W), dtype=torch.float32, device=self.device)
        if hit.any():
            depth[hit] = z_cam_valid[first_idx[hit].long()]
        return rgb, depth, hit

    @staticmethod
    def depth_to_pointcloud(
        rgb_image: np.ndarray,
        depth_image: np.ndarray,
        K: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Back-project a depth map into a camera-space point cloud.

        Every pixel with ``depth > 0`` is lifted with the pinhole model using its
        integer (column, row) index as the pixel coordinate (no +0.5 offset).

        Args:
            rgb_image: (H, W, C) colour image aligned with ``depth_image``.
            depth_image: (H, W) depth map; non-positive (and NaN) entries are skipped.
            K: intrinsics with fx, fy, cx, cy at [0, 0], [1, 1], [0, 2], [1, 2].

        Returns:
            ``(points, colors)``: (N, 3) camera-space XYZ and (N, C) float32
            colours equal to the pixel values divided by 255.
        """
        H, W = depth_image.shape
        fx, fy = K[0, 0], K[1, 1]
        cx, cy = K[0, 2], K[1, 2]

        rows, cols = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
        keep = depth_image > 0
        z = depth_image[keep]

        points = np.stack(
            [(cols[keep] - cx) / fx * z, (rows[keep] - cy) / fy * z, z],
            axis=-1,
        )
        colors = rgb_image[keep].astype(np.float32) / 255.0
        return points, colors


