from __future__ import annotations

import numpy as np
from numpy.typing import ArrayLike
from open3d.geometry import PointCloud


class Normalize:
    """Center and rescale point positions in place.

    Positions are first shifted by ``center`` and then multiplied by
    ``scale``. When either is left as ``None``, it is derived from the data
    on every call:

    - ``center`` becomes the centroid of the positions.
    - ``scale`` becomes the factor that maps the largest absolute coordinate
      to just under 1, so every coordinate ends up strictly inside ``(-1, 1)``.

    Args:
        center: Value subtracted from every position. It is broadcast against
            the ``(N, 3)`` position array, so a scalar or a length-3 vector
            both work. ``None`` means "use the centroid of each input".
        scale: Factor applied after centering. ``None`` means "fit each input
            into ``(-1, 1)``".
    """

    # Largest absolute coordinate produced by automatic scaling. It is kept
    # just below 1 so results stay strictly inside the open interval (-1, 1).
    _AUTO_EXTENT = 0.999999

    def __init__(self, center: ArrayLike | None = None, scale: float | None = None):
        self.center = None if center is None else np.array(center)
        self.scale = None if scale is None else float(scale)

    def __call__(self, pcd: PointCloud | np.ndarray) -> PointCloud | np.ndarray:
        """Normalize ``pcd`` in place and return it.

        Args:
            pcd: Either an Open3D ``PointCloud`` or a 2-D NumPy array whose
                first three columns hold the positions. Any further columns
                are left untouched. A NumPy input needs a floating dtype,
                because the result is written back into it.

        Returns:
            The same object that was passed in, after modification.
        """
        if isinstance(pcd, np.ndarray):
            # Basic slicing returns a view, so writes below land in pcd.
            pos = pcd[:, :3]
        else:
            # NOTE(review): relies on np.asarray sharing memory with Open3D's
            # point buffer, so the in-place ops below update pcd.points.
            pos = np.asarray(pcd.points)

        if pos.shape[0] == 0:
            # Nothing to normalize. A mean or max over an empty array would
            # warn or raise.
            return pcd

        center = pos.mean(axis=0) if self.center is None else self.center
        pos[...] -= center

        scale = self.scale
        if scale is None:
            extent = np.abs(pos).max()
            # If all points coincide after centering, there is no extent to
            # fit. Dividing by zero would turn every coordinate into NaN, so
            # leave the (already zero) positions unscaled instead.
            scale = self._AUTO_EXTENT / extent if extent > 0 else 1.0
        pos[...] *= scale

        return pcd
