import numpy as np
import torch
from typing import List, Optional, Tuple
from dataclasses import dataclass


@dataclass
class GaussianBlock:
    """A group of Gaussian primitives stored as one tensor per attribute.

    ``to_tensor`` / ``from_tensor`` convert to and from a packed 2-D tensor
    whose columns are laid out as::

        xyz (3) | scales (3) | rotations (4) | opacity (1)
        | features_dc (3, optional) | features_rest (up to 45, optional)

    Note: ``from_tensor`` decides which optional slices exist purely from the
    column count. A block with ``features_rest`` but no ``features_dc`` is
    therefore not round-trippable: its first three ``features_rest`` columns
    would be read back as ``features_dc``.
    """

    block_id: int
    num_points: int
    xyz: torch.Tensor
    scales: torch.Tensor
    rotations: torch.Tensor  # 4 columns; presumably a quaternion — confirm convention
    opacity: torch.Tensor
    features_dc: Optional[torch.Tensor] = None
    features_rest: Optional[torch.Tensor] = None

    def to_tensor(self) -> torch.Tensor:
        """Concatenate all present attributes along the last dimension.

        Optional feature tensors that are ``None`` are simply omitted.
        """
        components = [self.xyz, self.scales, self.rotations, self.opacity]
        components.extend(
            feature
            for feature in (self.features_dc, self.features_rest)
            if feature is not None
        )
        return torch.cat(components, dim=-1)

    @staticmethod
    def from_tensor(block_id: int, tensor: torch.Tensor) -> 'GaussianBlock':
        """Unpack a ``(num_points, C)`` tensor in the layout written by ``to_tensor``.

        ``features_dc`` is ``None`` when there are no columns after opacity, and
        ``features_rest`` is ``None`` when there are no columns after
        ``features_dc``. ``features_rest`` takes at most 45 columns; any columns
        past that are ignored.

        Args:
            block_id: Identifier stored on the returned block.
            tensor: Packed 2-D tensor, one row per point.

        Returns:
            A ``GaussianBlock`` whose attribute tensors are column slices (views)
            of ``tensor``.

        Raises:
            ValueError: if ``tensor`` is not 2-D or has fewer than the 11
                mandatory columns (xyz, scales, rotations, opacity).
        """
        if tensor.dim() != 2:
            raise ValueError(
                f"expected a 2-D (num_points, C) tensor, got shape {tuple(tensor.shape)}"
            )
        width = tensor.shape[1]
        if width < 11:
            # Without this check the slices below would silently come back
            # empty or truncated.
            raise ValueError(f"expected at least 11 columns, got {width}")

        xyz = tensor[:, 0:3]
        scales = tensor[:, 3:6]
        rotations = tensor[:, 6:10]
        opacity = tensor[:, 10:11]
        features_dc = tensor[:, 11:14] if width > 11 else None
        features_rest = tensor[:, 14:59] if width > 14 else None
        return GaussianBlock(
            block_id, tensor.shape[0], xyz, scales, rotations, opacity,
            features_dc, features_rest,
        )


def compute_morton_code(xyz: torch.Tensor, global_min: torch.Tensor, global_max: torch.Tensor,
                        *, bits: int = 20) -> torch.Tensor:
    """Compute a 3-D Morton (Z-order) code for every point.

    Each point is normalized into ``[global_min, global_max]``, clamped to the
    unit cube (so out-of-range points land in boundary cells), and quantized
    to ``bits`` bits per axis. The three integer coordinates are then
    bit-interleaved with x in the lowest bit, then y, then z.

    Args:
        xyz: Point positions; column 0/1/2 are read as x/y/z.
        global_min: Lower corner of the quantization box (broadcast against xyz).
        global_max: Upper corner of the quantization box.
        bits: Bits per axis, 1..21. The default of 20 reproduces the original
            behavior; 21 is the maximum that fits in a 63-bit signed code.

    Returns:
        An int64 tensor with one code per row of ``xyz``.

    Raises:
        ValueError: if ``bits`` is outside ``[1, 21]``.
    """
    if not 1 <= bits <= 21:
        raise ValueError(f"bits must be in [1, 21], got {bits}")

    # The epsilon keeps a degenerate (zero-extent) axis from dividing by zero.
    normalized = (xyz - global_min) / (global_max - global_min + 1e-8)
    normalized = torch.clamp(normalized, 0.0, 1.0)
    max_cell = (1 << bits) - 1
    ix = (normalized[:, 0] * max_cell).long()
    iy = (normalized[:, 1] * max_cell).long()
    iz = (normalized[:, 2] * max_cell).long()

    def expand_bits(v):
        # Spread the low 21 bits of v so that input bit i lands at bit 3*i,
        # leaving two zero bits between consecutive input bits.
        v = (v | (v << 32)) & 0x1f00000000ffff
        v = (v | (v << 16)) & 0x1f0000ff0000ff
        v = (v | (v << 8)) & 0x100f00f00f00f00f
        v = (v | (v << 4)) & 0x10c30c30c30c30c3
        v = (v | (v << 2)) & 0x1249249249249249
        return v

    xx = expand_bits(ix)
    yy = expand_bits(iy)
    zz = expand_bits(iz)
    return xx | (yy << 1) | (zz << 2)


def compute_block_bounds(xyz: np.ndarray, block_size: int) -> np.ndarray:
    """Compute the axis-aligned bounding box of each consecutive run of points.

    Points are grouped by row order: block ``b`` covers rows
    ``[b * block_size, min((b + 1) * block_size, N))``, so only the last block
    may be shorter than ``block_size``.

    Args:
        xyz: ``(N, 3)`` array of point positions.
        block_size: Number of points per block; must be positive.

    Returns:
        A ``(ceil(N / block_size), 6)`` float32 array whose rows are
        ``[min_x, min_y, min_z, max_x, max_y, max_z]``. Empty when ``N == 0``.

    Raises:
        ValueError: if ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    n_points = xyz.shape[0]
    if n_points == 0:
        # reduceat needs at least one start index, so handle the empty case here.
        return np.zeros((0, 6), dtype=np.float32)

    # First row of every block. reduceat reduces each [starts[i], starts[i + 1])
    # slice, and the final slice runs to the end of the array.
    starts = np.arange(0, n_points, block_size)
    block_bounds = np.empty((len(starts), 6), dtype=np.float32)
    block_bounds[:, :3] = np.minimum.reduceat(xyz, starts, axis=0)
    block_bounds[:, 3:6] = np.maximum.reduceat(xyz, starts, axis=0)
    return block_bounds


class FrustumCuller:
    """Cull per-block axis-aligned bounding boxes against a camera frustum.

    Args:
        block_bounds: ``(num_blocks, 6)`` array of rows
            ``[min_x, min_y, min_z, max_x, max_y, max_z]``.
        scene_radius: Stored on the instance; not used by the culling code here.
    """

    def __init__(self, block_bounds: np.ndarray, scene_radius: float = 100.0):
        self.block_bounds = block_bounds
        self.num_blocks = len(block_bounds)
        self.scene_radius = scene_radius

    def cull(self, view_matrix: np.ndarray, proj_matrix: np.ndarray) -> List[int]:
        """Return the ids of blocks that are not entirely outside the frustum.

        The test is conservative: a block is rejected only if its whole box
        lies behind at least one frustum plane, so some off-screen boxes near
        frustum corners can still be reported as visible.

        Args:
            view_matrix: 4x4 view matrix.
            proj_matrix: 4x4 projection matrix.

        Returns:
            Visible block ids as Python ints, in ascending order.
        """
        planes = self._extract_frustum_planes(view_matrix, proj_matrix)
        if self.num_blocks == 0:
            return []

        bounds = np.asarray(self.block_bounds)[: self.num_blocks]
        min_pt = bounds[:, None, :3]    # (B, 1, 3)
        max_pt = bounds[:, None, 3:6]   # (B, 1, 3)
        normals = planes[None, :, :3]   # (1, 6, 3)
        # p-vertex: for every (block, plane) pair, the box corner farthest
        # along the plane normal.
        p_vertex = np.where(normals >= 0, max_pt, min_pt)  # (B, 6, 3)
        signed_dist = (p_vertex * normals).sum(axis=-1) + planes[None, :, 3]  # (B, 6)
        # Reject only on a strictly negative distance, exactly as
        # _test_aabb_planes does (NaN distances never reject).
        outside = (signed_dist < 0).any(axis=1)
        return np.flatnonzero(~outside).tolist()

    def _extract_frustum_planes(self, view: np.ndarray, proj: np.ndarray) -> np.ndarray:
        """Extract the six frustum planes from ``proj @ view``.

        Each row of the returned ``(6, 4)`` float32 array is ``(a, b, c, d)``,
        scaled so that ``(a, b, c)`` has unit length. A point ``p`` is on the
        inner side when ``a*x + b*y + c*z + d >= 0``. Planes come from row sums
        and differences of the combined matrix, which assumes clip coordinates
        are computed as ``proj @ view @ [x, y, z, 1]``.
        """
        vp = proj @ view
        planes = np.zeros((6, 4), dtype=np.float32)
        planes[0] = vp[3] + vp[0]  # Left:   x_clip >= -w_clip
        planes[1] = vp[3] - vp[0]  # Right:  x_clip <=  w_clip
        planes[2] = vp[3] + vp[1]  # Bottom: y_clip >= -w_clip
        planes[3] = vp[3] - vp[1]  # Top:    y_clip <=  w_clip
        # NOTE(review): near = w + z assumes a -w <= z_clip clip range
        # (OpenGL-style); a [0, w] depth range would need near = vp[2] — confirm.
        planes[4] = vp[3] + vp[2]  # Near
        planes[5] = vp[3] - vp[2]  # Far

        norms = np.linalg.norm(planes[:, :3], axis=1, keepdims=True)
        # Epsilon guards against a degenerate plane with a zero normal.
        planes = planes / (norms + 1e-8)
        return planes

    def _test_aabb_planes(self, block_id: int, planes: np.ndarray) -> bool:
        """Scalar p-vertex test for one block; True unless it is fully outside a plane."""
        bounds = self.block_bounds[block_id]
        min_pt = bounds[:3]
        max_pt = bounds[3:6]

        for i in range(6):
            normal = planes[i, :3]
            d = planes[i, 3]
            p_vertex = np.where(normal >= 0, max_pt, min_pt)
            if np.dot(normal, p_vertex) + d < 0:
                return False
        return True
