import torch
from numpy import ndarray
from typing import Any, Dict, Union


class Instances3D:
    """
    This class represents a list of instances in a scene.
    It stores the attributes of instances (e.g., boxes, masks, labels, scores) as "fields".
    All fields must have the same ``__len__`` which is the number of instances.

    All other (non-field) attributes of this class are considered private:
    they must start with '_'. Any attribute assigned without a leading
    underscore is stored as a field (see :meth:`__setattr__`).

    Some basic usage:

    1. Set/get/check a field:

       .. code-block:: python

          instances.gt_boxes = Boxes(...)
          print(instances.pred_masks)  # a tensor of shape (N, H, W)
          print('gt_masks' in instances)

    2. ``len(instances)`` returns the number of instances
    3. Indexing: ``instances[indices]`` will apply the indexing on all the fields
       and returns a new :class:`Instances3D`.
       Typically, ``indices`` is an integer vector of indices,
       or a binary mask of length ``num_instances``

       .. code-block:: python

          category_3_detections = instances[instances.pred_classes == 3]
          confident_detections = instances[instances.scores > 0.9]
    """

    def __init__(self, num_points: int, gt_instances: Union[ndarray, None] = None, **kwargs: Any):
        """
        Args:
            num_points: number of points of the scene.
            gt_instances: optional array stored as-is and handed unchanged to
                every object derived from this one (``to``, ``cuda``, indexing).
            kwargs: fields to add to this `Instances3D`.
        """
        # Names starting with '_' bypass the field machinery in __setattr__.
        self._num_points = num_points
        self._gt_instances = gt_instances
        self._fields: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.set(k, v)

    @property
    def num_points(self) -> int:
        """
        Returns:
            int: the ``num_points`` value given at construction.
        """
        return self._num_points

    @property
    def gt_instances(self) -> Union[ndarray, None]:
        """
        Returns:
            The ``gt_instances`` object given at construction, or None if
            none was provided.
        """
        return self._gt_instances

    def __setattr__(self, name: str, val: Any) -> None:
        """
        Store private attributes (leading '_') on the object itself and
        everything else as a field via :meth:`set`.
        """
        if name.startswith('_'):
            super().__setattr__(name, val)
        else:
            self.set(name, val)

    def __getattr__(self, name: str) -> Any:
        """
        Look up `name` among the fields. Only called when normal attribute
        lookup fails.

        Raises:
            AttributeError: if no field called `name` exists.
        """
        # The explicit '_fields' check avoids infinite recursion when
        # `_fields` itself has not been set yet (e.g. during copy/unpickle).
        if name == '_fields' or name not in self._fields:
            raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
        return self._fields[name]

    def __contains__(self, name: str) -> bool:
        """
        Support ``name in instances``.

        Without this method Python would fall back to ``__iter__`` for the
        ``in`` operator, which deliberately raises.

        Returns:
            bool: whether the field called `name` exists.
        """
        return name in self._fields

    def set(self, name: str, value: Any) -> None:
        """
        Set the field named `name` to `value`.
        The length of `value` must be the number of instances,
        and must agree with other existing fields in this object
        (checked with an ``assert``).
        """
        data_len = len(value)
        # The first field defines the length; later ones must match it.
        if self._fields:
            assert (len(self) == data_len), 'Adding a field of length {} to a Instances of length {}'.format(
                data_len, len(self))
        self._fields[name] = value

    def has(self, name: str) -> bool:
        """
        Returns:
            bool: whether the field called `name` exists.
        """
        return name in self._fields

    def remove(self, name: str) -> None:
        """
        Remove the field called `name`.

        Raises:
            KeyError: if the field does not exist.
        """
        del self._fields[name]

    def get(self, name: str) -> Any:
        """
        Returns the field called `name`.

        Raises:
            KeyError: if the field does not exist.
        """
        return self._fields[name]

    def get_fields(self) -> Dict[str, Any]:
        """
        Returns:
            dict: a dict which maps names (str) to data of the fields

        Modifying the returned dict will modify this instance.
        """
        return self._fields

    def _apply(self, method_name: str, *args: Any, **kwargs: Any) -> 'Instances3D':
        """
        Build a new `Instances3D` in which every field that has a method
        called `method_name` is replaced by the result of calling it with
        the given arguments; other fields are carried over unchanged.
        """
        ret = Instances3D(self._num_points, self._gt_instances)
        for k, v in self._fields.items():
            if hasattr(v, method_name):
                v = getattr(v, method_name)(*args, **kwargs)
            ret.set(k, v)
        return ret

    # Tensor-like methods
    def to(self, *args: Any, **kwargs: Any) -> 'Instances3D':
        """
        Returns:
            Instances3D: a new object where ``to(*args, **kwargs)`` has been
            called on every field that has this method.
        """
        return self._apply('to', *args, **kwargs)

    def cuda(self, *args: Any, **kwargs: Any) -> 'Instances3D':
        """
        Returns:
            Instances3D: a new object where ``cuda(*args, **kwargs)`` has been
            called on every field that has this method.
        """
        return self._apply('cuda', *args, **kwargs)

    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> 'Instances3D':
        """
        Args:
            item: an index-like object and will be used to index all the fields.

        Returns:
            A new `Instances3D` where all fields are indexed by `item`.
            A plain ``int`` index yields an object of length 1.

        Raises:
            IndexError: if `item` is an ``int`` outside ``[-len, len)``.
        """
        if type(item) is int:
            n = len(self)
            if item >= n or item < -n:
                raise IndexError('Instances index out of range!')
            # A slice starting at `item` with step `n` selects exactly one
            # element, so every field keeps its leading dimension.
            item = slice(item, None, n)

        ret = Instances3D(self._num_points, self._gt_instances)
        for k, v in self._fields.items():
            ret.set(k, v[item])
        return ret

    def __len__(self) -> int:
        """
        Returns:
            The length of the first field.

        Raises:
            NotImplementedError: if there are no fields.
        """
        for v in self._fields.values():
            # use __len__ because len() has to be int and is not friendly to tracing
            return v.__len__()
        raise NotImplementedError('Empty Instances does not support __len__!')

    def __iter__(self):
        """Iteration is intentionally unsupported; always raises."""
        raise NotImplementedError('`Instances` object is not iterable!')

    def __str__(self) -> str:
        """
        Returns:
            str: a summary with the number of instances, the number of
            points and every field. An object without fields is reported
            with ``num_instances=0`` so that printing never raises.
        """
        # len(self) raises when there are no fields; avoid that here.
        num_instances = len(self) if self._fields else 0
        s = self.__class__.__name__ + '('
        s += 'num_instances={}, '.format(num_instances)
        s += 'num_points={}, '.format(self._num_points)
        s += 'fields=[{}])'.format(', '.join((f'{k}: {v}' for k, v in self._fields.items())))
        return s

    __repr__ = __str__
