from __future__ import annotations

from typing import Sequence

import numpy as np
import open3d as o3d
from open3d.geometry import PointCloud


class AxisAlignedCrop:
    """Keep only the points that lie inside an axis-aligned box.

    Accepts either a NumPy array of points or an Open3D ``PointCloud``.
    The box is closed: on the NumPy path a point lying exactly on a face is
    kept (``>=`` / ``<=`` comparisons).

    Args:
        min_bound: Lower corner of the box. Normally one value per axis; a
            scalar (or length-1 sequence) is applied to every axis.
        max_bound: Upper corner of the box, same conventions as ``min_bound``.

    Raises:
        ValueError: If ``min_bound`` exceeds ``max_bound`` on any axis, if
            either bound contains NaN, or if the two bounds cannot be
            broadcast against each other.
    """

    def __init__(
        self,
        min_bound: np.ndarray | Sequence[float],
        max_bound: np.ndarray | Sequence[float],
    ) -> None:
        # np.array (unlike np.asarray) always takes a private float64 copy, so
        # later in-place edits to the caller's arrays cannot move the box.
        self.min_bound = np.array(min_bound, dtype=np.float64)
        self.max_bound = np.array(max_bound, dtype=np.float64)

        # An inverted or NaN box can never contain a point; fail loudly here
        # rather than silently returning an empty crop at call time.
        if not np.all(self.min_bound <= self.max_bound):
            raise ValueError(
                f"min_bound {self.min_bound.tolist()} must be <= "
                f"max_bound {self.max_bound.tolist()} on every axis"
            )

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(min_bound={self.min_bound.tolist()}, "
            f"max_bound={self.max_bound.tolist()})"
        )

    def __call__(self, pcd: PointCloud | np.ndarray) -> PointCloud | np.ndarray:
        """Crop ``pcd`` to the box.

        Args:
            pcd: Either a 2-D NumPy array whose first (up to) three columns
                are compared against the bounds (any further columns are
                carried along unchanged), or an Open3D ``PointCloud``.

        Returns:
            For an array, the rows that fall inside the box (a new array
            produced by boolean indexing). For a ``PointCloud``, the result
            of ``PointCloud.crop`` with the equivalent bounding box.
        """
        if isinstance(pcd, np.ndarray):
            xyz = pcd[:, :3]
            inside = np.all((xyz >= self.min_bound) & (xyz <= self.max_bound), axis=-1)
            return pcd[inside]

        # Open3D needs exactly three coordinates per corner; expand scalar /
        # length-1 bounds so they behave the same as on the NumPy path.
        bounding_box = o3d.geometry.AxisAlignedBoundingBox(
            min_bound=self._as_vec3(self.min_bound),
            max_bound=self._as_vec3(self.max_bound),
        )
        return pcd.crop(bounding_box)

    @staticmethod
    def _as_vec3(bound: np.ndarray) -> np.ndarray:
        """Return ``bound`` as a contiguous float64 array of shape (3,)."""
        # reshape(-1) turns 0-d scalars into shape (1,) so they can broadcast;
        # astype returns a fresh contiguous copy of the read-only broadcast view.
        return np.broadcast_to(bound.reshape(-1), (3,)).astype(np.float64)
