from __future__ import annotations

import numpy as np
from open3d.geometry import PointCloud

from pc_rl.utils.o3d import np_to_o3d, o3d_to_np


class VoxelGridDownsampling:
    """Callable transform that voxel-downsamples a point cloud.

    The instance holds a single setting, ``voxel_grid_size``, which is passed
    unchanged to the point cloud's ``voxel_down_sample`` method on every call.
    Input may be either a NumPy array or an object exposing
    ``voxel_down_sample`` (annotated as an Open3D ``PointCloud``); the result
    comes back in the same representation that was passed in.
    """

    def __init__(self, voxel_grid_size: float) -> None:
        """Remember the voxel size to use for every subsequent call.

        Args:
            voxel_grid_size: Value forwarded to ``voxel_down_sample``. It is
                stored as-is; no validation happens here.
        """
        self.voxel_grid_size = voxel_grid_size

    def __call__(self, pcd: PointCloud | np.ndarray) -> PointCloud | np.ndarray:
        """Downsample ``pcd`` and return it in the representation it came in.

        Args:
            pcd: Either a NumPy array or a point-cloud object. Arrays are
                converted with ``np_to_o3d`` first (array layout is whatever
                that helper expects -- TODO confirm against
                ``pc_rl.utils.o3d``).

        Returns:
            The result of ``voxel_down_sample``: converted back with
            ``o3d_to_np`` when the input was an array, otherwise returned
            directly.
        """
        # Anything that is not an ndarray is treated as already being a
        # point-cloud object, so it is downsampled and handed back directly.
        if not isinstance(pcd, np.ndarray):
            return pcd.voxel_down_sample(self.voxel_grid_size)

        # Array input: round-trip through the point-cloud representation so
        # the caller gets an array back.
        downsampled = np_to_o3d(pcd).voxel_down_sample(self.voxel_grid_size)
        return o3d_to_np(downsampled)
