# utils/statistics.py
import torch, torch_scatter as ts
from typing import Dict, Tuple

# ---------- voxel hashing ----------
def coords_to_hash(coords: torch.Tensor) -> torch.Tensor:
    """Pack the first four integer columns of ``coords`` into one int64 key per row.

    The key is ``c0 << 48 + c1 << 32 + c2 << 16 + c3`` (presumably columns are
    (batch, x, y, z) — confirm against callers). Keys are guaranteed unique
    while columns 1-3 lie in [0, 65535] and column 0 in [0, 32767]; values
    outside those ranges can alias.

    Args:
        coords: integer tensor with at least 4 columns, one row per point.

    Returns:
        1-D int64 tensor with one key per row of ``coords``.
    """
    c = coords.long()  # widen before shifting so the 48-bit shift cannot overflow
    key = c[:, 0] << 48
    key = key + (c[:, 1] << 32)
    key = key + (c[:, 2] << 16)
    return key + c[:, 3]

# _hash_cache = {}                              # {id(coords.storage()): (dict, N)}
_hash_cache: Dict[Tuple[int, int], Tuple[dict, int]] = {}

@torch.no_grad()
def build_hash2row(ref_coords: torch.Tensor) -> Tuple[dict, int]:
    """Build (or fetch from ``_hash_cache``) a ``{coords_hash: row_index}`` table.

    Every row of ``ref_coords`` is hashed with ``coords_to_hash`` and mapped to
    its row position; if two rows share a hash, the later row wins.

    The cache key identifies the exact tensor view: device, offset-aware data
    pointer, shape, stride, and the in-place version counter. Each entry is
    evicted as soon as the tensor object is garbage-collected.

    A key of just ``(storage().data_ptr(), size(0))`` can return a stale table
    in three cases:

    * a freed tensor's memory is reused by a new tensor with the same row
      count (PyTorch's caching allocator reuses freed blocks);
    * two equal-length slices of one storage are passed;
    * the tensor was modified in place.

    Such a key also never releases memory.

    Args:
        ref_coords: reference coordinates, in the layout ``coords_to_hash``
            expects (one row per entry).

    Returns:
        ``(table, n_rows)``: ``table`` maps int hash -> row index and
        ``n_rows`` is the number of rows hashed (``ref_coords.size(0)``).
    """
    import weakref  # local import keeps this change self-contained

    try:
        version = ref_coords._version  # bumped by every in-place write
    except RuntimeError:  # inference-mode tensors do not track a version counter
        version = None

    key = (
        str(ref_coords.device),
        ref_coords.data_ptr(),  # includes storage_offset, unlike storage().data_ptr()
        tuple(ref_coords.shape),
        ref_coords.stride(),
        version,
    )
    cached = _hash_cache.get(key)
    if cached is not None:
        return cached

    hashes = coords_to_hash(ref_coords).cpu().numpy()
    table = {int(v): i for i, v in enumerate(hashes)}
    entry = (table, len(hashes))
    _hash_cache[key] = entry
    # Drop the entry when this tensor object dies: afterwards its memory may be
    # handed to an unrelated tensor that would produce an identical key.
    weakref.finalize(ref_coords, _hash_cache.pop, key, None)
    return entry

# ---------- feature gathering ----------
def gather_feat(coords, feats, hash2row, n_row):
    """Place rows of ``feats`` into an ``n_row``-row buffer via a hash lookup.

    Each row of ``coords`` is hashed with ``coords_to_hash`` and looked up in
    ``hash2row``. ``feats[i]`` is written to the row found for ``coords[i]``.
    Rows whose hash is missing, or maps outside ``[0, n_row)``, are dropped.
    Output rows that receive nothing stay zero. If several source rows map to
    the same target, which one is kept is unspecified (PyTorch index-put with
    duplicate indices).

    Args:
        coords: coordinates aligned row-for-row with ``feats``.
        feats: 2-D features; ``feats[i]`` belongs to ``coords[i]``.
        hash2row: mapping from int hash to target row index.
        n_row: number of rows in the output.

    Returns:
        Tensor of shape ``(n_row, feats.shape[1])`` with ``feats``' dtype and device.
    """
    keys = coords_to_hash(coords).cpu().numpy()
    lookup = [hash2row.get(int(k), -1) for k in keys]  # -1 marks "not found"
    target = torch.tensor(lookup, device=feats.device, dtype=torch.long)
    keep = (target >= 0) & (target < n_row)
    result = feats.new_zeros(n_row, feats.shape[1])
    result[target[keep]] = feats[keep]
    return result

# ---------- numerically safe mean/std on arbitrary rows ----------
def safe_row_stats(feats, rows, n_row, clip=1e4):
    """Per-row mean and population std of ``feats`` grouped by ``rows``.

    ``feats[i]`` is accumulated into output row ``rows[i]``; entries whose row
    index is outside ``[0, n_row)`` are ignored. Before accumulation, NaN and
    +/-inf are replaced by 0 and values are clamped to ``[-clip, clip]``. Sums
    are carried out in float64 to limit cancellation in ``E[x^2] - E[x]^2``.

    Args:
        feats: 2-D features ``(M, C)``.
        rows: 1-D integer target row for each feature row ``(M,)``.
        n_row: number of output rows.
        clip: absolute bound applied to feature values before accumulation.

    Returns:
        ``(mean, std)``, each float32 of shape ``(n_row, C)``. Rows that receive
        no valid entry are 0 in both outputs.
    """
    rows = rows.long()  # index ops require int64 indices
    valid = (rows >= 0) & (rows < n_row)
    if not bool(valid.any()):
        # Same dtype as the normal path (float32), independent of feats.dtype.
        n_col = feats.size(1)
        return (
            torch.zeros(n_row, n_col, dtype=torch.float32, device=feats.device),
            torch.zeros(n_row, n_col, dtype=torch.float32, device=feats.device),
        )

    feats, rows = feats[valid], rows[valid]
    feats = torch.nan_to_num(feats, nan=0.0, posinf=0.0,
                             neginf=0.0).clamp(-clip, clip)
    feats64 = feats.to(torch.float64)
    n_col = feats64.size(1)

    # Segment sums via native index_add_ (all indices are in range here).
    cnt = feats64.new_zeros(n_row, 1).index_add_(
        0, rows, feats64.new_ones(rows.size(0), 1))
    sums = feats64.new_zeros(n_row, n_col).index_add_(0, rows, feats64)
    sqs = feats64.new_zeros(n_row, n_col).index_add_(0, rows, feats64 * feats64)

    cnt_safe = cnt.clamp_min(1.0)  # avoid 0/0 for empty rows
    mean = sums / cnt_safe
    var = sqs / cnt_safe - mean.pow(2)
    # Empty rows need no extra masking: their sums and sqs are exactly 0,
    # so mean = 0 and std = sqrt(0) = 0 already.
    std = torch.sqrt(var.clamp_min(0.0))

    return mean.float(), std.float()
