"""Scan Utilities: Morton and Hilbert Curve Encoding"""

import torch


def morton_encode(xyz, bits=10):
    """Compute Morton Code (Z-order curve).

    Coordinates are min/max-normalized per axis onto an integer grid with
    ``2**bits`` cells per axis. The bits of the three grid indices are then
    interleaved: bit ``i`` of x goes to bit ``3*i``, of y to ``3*i + 1`` and
    of z to ``3*i + 2``.

    Args:
        xyz: Point coordinates, read as ``xyz[:, 0]``, ``xyz[:, 1]`` and
            ``xyz[:, 2]``; presumably an ``(N, 3)`` floating-point tensor.
        bits: Grid resolution in bits per axis. The code occupies
            ``3 * bits`` bits of an int64, so ``bits`` must be in ``1..21``.

    Returns:
        int64 tensor of shape ``(N,)`` on ``xyz.device``.

    Raises:
        ValueError: If ``bits`` is outside ``1..21``.
    """
    if not 1 <= bits <= 21:
        raise ValueError(f"bits must be in [1, 21] to fit an int64 code, got {bits}")

    device = xyz.device
    n = xyz.shape[0]
    if n == 0:
        # min/max over dim 0 raise on an empty tensor; an empty cloud has no codes.
        return torch.zeros(0, dtype=torch.long, device=device)

    max_cell = 2**bits - 1

    # Normalize each axis to [0, max_cell]; the clamp avoids division by zero
    # on axes where all points share the same coordinate.
    xyz_min = xyz.min(dim=0)[0]
    xyz_max = xyz.max(dim=0)[0]
    xyz_range = (xyz_max - xyz_min).clamp(min=1e-6)

    xyz_norm = ((xyz - xyz_min) / xyz_range * max_cell).long()
    xyz_norm = xyz_norm.clamp(0, max_cell)

    x, y, z = xyz_norm[:, 0], xyz_norm[:, 1], xyz_norm[:, 2]

    # Interleave bits: x -> 3i, y -> 3i + 1, z -> 3i + 2.
    code = torch.zeros(n, dtype=torch.long, device=device)
    for i in range(bits):
        code |= ((x >> i) & 1) << (3 * i)
        code |= ((y >> i) & 1) << (3 * i + 1)
        code |= ((z >> i) & 1) << (3 * i + 2)

    return code


def morton_sort(xyz, features, bits=10):
    """Sort point cloud by Morton Code.

    Args:
        xyz: Point coordinates, encoded with ``morton_encode``.
        features: Per-point data indexed along dim 0 in the same order as
            ``xyz``.
        bits: Grid resolution in bits per axis, forwarded to
            ``morton_encode`` (default matches its default).

    Returns:
        ``(sorted_xyz, sorted_features, inverse_indices)``, where
        ``sorted_xyz[inverse_indices]`` restores the original point order.
        Points with equal codes may come out in any relative order, because
        ``torch.argsort`` is called without ``stable=True``.
    """
    codes = morton_encode(xyz, bits=bits)
    sorted_indices = torch.argsort(codes)

    # sorted_indices is a permutation, so its inverse is a single O(N)
    # scatter: inverse[sorted_indices[k]] = k (same result as argsort again).
    inverse_indices = torch.empty_like(sorted_indices)
    inverse_indices[sorted_indices] = torch.arange(
        sorted_indices.numel(), device=sorted_indices.device
    )

    sorted_xyz = xyz[sorted_indices]
    sorted_features = features[sorted_indices]

    return sorted_xyz, sorted_features, inverse_indices


def hilbert_encode(xyz, bits=10, *, legacy=False):
    """Compute 3D Hilbert Code.

    Coordinates are min/max-normalized per axis onto an integer grid with
    ``2**bits`` cells per axis. Each cell is then mapped to its index along a
    3D Hilbert curve using Skilling's transpose algorithm ("Programming the
    Hilbert curve", 2004). Consecutive codes belong to face-adjacent cells,
    so sorting by code keeps spatial neighbours close in the sequence.

    Args:
        xyz: Point coordinates, read as ``xyz[:, 0]``, ``xyz[:, 1]`` and
            ``xyz[:, 2]``; presumably an ``(N, 3)`` floating-point tensor.
        bits: Grid resolution in bits per axis. The code occupies
            ``3 * bits`` bits of an int64, so ``bits`` must be in ``1..21``.
        legacy: If True, reproduce the mapping used by earlier versions of
            this function. That mapping applies the forward Gray transform to
            each level's octant index independently, with no subcube
            rotations. It is NOT a Hilbert curve: even consecutive codes
            within one level can be non-adjacent. Keep it only to reproduce
            orderings that existing models were trained with.

    Returns:
        int64 tensor of shape ``(N,)`` on ``xyz.device``.

    Raises:
        ValueError: If ``bits`` is outside ``1..21``.
    """
    if not 1 <= bits <= 21:
        raise ValueError(f"bits must be in [1, 21] to fit an int64 code, got {bits}")

    device = xyz.device
    N = xyz.shape[0]
    if N == 0:
        # min/max over dim 0 raise on an empty tensor; an empty cloud has no codes.
        return torch.zeros(0, dtype=torch.long, device=device)

    # Normalize to [0, 2^bits - 1]; the clamp avoids division by zero on
    # axes where all points share the same coordinate.
    xyz_min = xyz.min(dim=0)[0]
    xyz_max = xyz.max(dim=0)[0]
    xyz_range = (xyz_max - xyz_min).clamp(min=1e-6)

    xyz_norm = ((xyz - xyz_min) / xyz_range * (2**bits - 1)).long()
    xyz_norm = xyz_norm.clamp(0, 2**bits - 1)

    if legacy:
        return _legacy_gray_codes(xyz_norm, bits)
    return _hilbert_codes(xyz_norm, bits)


def _legacy_gray_codes(grid, bits):
    """Reproduce the old per-level Gray-code mapping (not a Hilbert curve)."""
    codes = torch.zeros(grid.shape[0], dtype=torch.long, device=grid.device)
    for i in range(bits - 1, -1, -1):
        octant = (
            (((grid[:, 0] >> i) & 1) << 2)
            | (((grid[:, 1] >> i) & 1) << 1)
            | ((grid[:, 2] >> i) & 1)
        )
        codes = (codes << 3) | (octant ^ (octant >> 1))
    return codes


def _hilbert_codes(grid, bits):
    """Map integer grid coordinates to Hilbert indices (Skilling's AxesToTranspose)."""
    axes = [grid[:, d] for d in range(3)]  # rebound below, never mutated in place
    top = 1 << (bits - 1)

    # Inverse undo: from the coarsest level down, either invert the lower
    # bits of axis 0 or exchange them with axis d. This applies the
    # reflection/rotation of each Hilbert subcube.
    q = top
    while q > 1:
        low = q - 1
        for d in range(3):
            bit_set = (axes[d] & q) != 0
            swap = (axes[0] ^ axes[d]) & low  # always zero when d == 0
            new_first = torch.where(bit_set, axes[0] ^ low, axes[0] ^ swap)
            axes[d] = torch.where(bit_set, axes[d], axes[d] ^ swap)
            axes[0] = new_first  # assigned last so d == 0 keeps the inverted value
        q >>= 1

    # Gray-encode the transposed index.
    axes[1] = axes[1] ^ axes[0]
    axes[2] = axes[2] ^ axes[1]
    flip = torch.zeros_like(axes[2])
    q = top
    while q > 1:
        flip = torch.where((axes[2] & q) != 0, flip ^ (q - 1), flip)
        q >>= 1
    axes = [a ^ flip for a in axes]

    # Read the transposed index MSB-first: each level contributes (x, y, z) bits.
    codes = torch.zeros_like(axes[0])
    for i in range(bits - 1, -1, -1):
        for a in axes:
            codes = (codes << 1) | ((a >> i) & 1)
    return codes


def hilbert_sort(xyz, features, bits=10):
    """Sort point cloud by Hilbert Code.

    Args:
        xyz: Point coordinates, encoded with ``hilbert_encode``.
        features: Per-point data indexed along dim 0 in the same order as
            ``xyz``.
        bits: Grid resolution in bits per axis, forwarded to
            ``hilbert_encode`` (default matches its default).

    Returns:
        ``(sorted_xyz, sorted_features, inverse_indices)``, where
        ``sorted_xyz[inverse_indices]`` restores the original point order.
        Points with equal codes may come out in any relative order, because
        ``torch.argsort`` is called without ``stable=True``.
    """
    codes = hilbert_encode(xyz, bits=bits)
    sorted_indices = torch.argsort(codes)

    # sorted_indices is a permutation, so its inverse is a single O(N)
    # scatter: inverse[sorted_indices[k]] = k (same result as argsort again).
    inverse_indices = torch.empty_like(sorted_indices)
    inverse_indices[sorted_indices] = torch.arange(
        sorted_indices.numel(), device=sorted_indices.device
    )

    sorted_xyz = xyz[sorted_indices]
    sorted_features = features[sorted_indices]

    return sorted_xyz, sorted_features, inverse_indices


def l2norm(x, dim=-1, eps=1e-6):
    """Scale ``x`` to unit L2 norm along ``dim`` (aligned with FLA library).

    ``eps`` is added to the squared norm *inside* the reciprocal square root
    rather than to the norm itself. With ``eps > 0`` an all-zero slice
    therefore maps to zeros instead of NaN.
    """
    squared_norm = (x * x).sum(dim=dim, keepdim=True)
    inv_norm = torch.rsqrt(squared_norm + eps)
    return x * inv_norm
