# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Custom collate function that includes cases for nerfstudio types.
"""

import collections
import collections.abc
import re
from typing import Any, Callable, Dict, Union

import torch
import torch.utils.data

from nerfstudio.cameras.cameras import Cameras

# Error template raised by ``nerfstudio_collate`` for elements it cannot batch.
# It contains two ``{}`` placeholders, so ``str.format`` must be given two arguments.
NERFSTUDIO_COLLATE_ERR_MSG_FORMAT = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts, lists or anything in {}; found {}"
)
# Matches numpy dtype strings (``ndarray.dtype.str``, e.g. "<U5", "|S3", "|O") whose kind
# character marks string, bytes or object arrays.
np_str_obj_array_pattern = re.compile(r"[SaUO]")


def _collate_cameras(batch: Any) -> "Cameras":
    """Merge a batch of :class:`Cameras` into a single :class:`Cameras` object.

    Cameras without a batch dimension (``shape == ()``) are stacked to create a new leading
    dimension; cameras that already have a batch dimension are concatenated along dimension 0.
    Every per-camera field (poses, intrinsics, size, camera type, distortion, times, metadata)
    is combined with the same operation, so all of them keep matching leading dimensions.

    Args:
        batch: non-empty sequence of :class:`Cameras`. ``batch[0]`` decides whether optional
            fields (distortion parameters, times, metadata) are expected on every camera.

    Returns:
        A single :class:`Cameras` containing every camera in ``batch``.

    Raises:
        AssertionError: if an element is not a :class:`Cameras`, or the cameras disagree on the
            presence of distortion parameters or times, or on their metadata keys.
    """
    first = batch[0]
    assert all(isinstance(cam, Cameras) for cam in batch), "All elements of the batch must be Cameras."
    assert all(cam.distortion_params is None for cam in batch) or all(
        cam.distortion_params is not None for cam in batch
    ), (
        "All cameras must have distortion parameters or none of them should have distortion parameters. "
        "Generalized batching will be supported in the future."
    )

    if first.metadata is not None:
        metadata_keys = first.metadata.keys()
        # Check for None first so a camera lacking metadata fails the assertion with a clear
        # message instead of raising AttributeError on `.keys()`.
        assert all(
            cam.metadata is not None and cam.metadata.keys() == metadata_keys for cam in batch
        ), "All cameras must have the same metadata keys."
    else:
        assert all(cam.metadata is None for cam in batch), "All cameras must have the same metadata keys."

    if first.times is not None:
        assert all(cam.times is not None for cam in batch), "All cameras must have times present or absent."
    else:
        assert all(cam.times is None for cam in batch), "All cameras must have times present or absent."

    # No batch dimension yet -> stack to create one on dim 0; otherwise concatenate along dim 0.
    op = torch.stack if first.shape == () else torch.cat

    if first.metadata is not None:
        metadata = {key: op([cam.metadata[key] for cam in batch], dim=0) for key in first.metadata.keys()}
    else:
        metadata = None

    if first.distortion_params is not None:
        distortion_params = op([cam.distortion_params for cam in batch], dim=0)
    else:
        distortion_params = None

    # Use the same op as every other field. An unconditional torch.stack would add an extra
    # dimension for already-batched cameras, so times would no longer line up with the rest.
    times = op([cam.times for cam in batch], dim=0) if first.times is not None else None

    return Cameras(
        op([cam.camera_to_worlds for cam in batch], dim=0),
        op([cam.fx for cam in batch], dim=0),
        op([cam.fy for cam in batch], dim=0),
        op([cam.cx for cam in batch], dim=0),
        op([cam.cy for cam in batch], dim=0),
        height=op([cam.height for cam in batch], dim=0),
        width=op([cam.width for cam in batch], dim=0),
        distortion_params=distortion_params,
        camera_type=op([cam.camera_type for cam in batch], dim=0),
        times=times,
        metadata=metadata,
    )


def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], None] = None) -> Any:
    r"""
    This is the default pytorch collate function, but with support for nerfstudio types. The generic
    branches follow pytorch's default_collate function (python version 3.8.13, pytorch version
    '1.12.1+cu113'). The differences are:

    * the ``extra_mappings`` argument, a mapping from types to callables (types being like int or
      float or the return value of type(3.), etc.) used to handle custom types;
    * the rename from default_collate to nerfstudio_collate;
    * the ``NERFSTUDIO_COLLATE_ERR_MSG_FORMAT`` error template;
    * :class:`Cameras` support, handled after the generic types and before ``extra_mappings``.

    Function that takes in a batch of data and puts the elements within the batch
    into a tensor with an additional outer dimension - batch size. The exact output type can be
    a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
    Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type.
    This is used as the default function for collation when
    `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.

    Here is the general input type (based on the type of the element within the batch) to output type mapping:

        * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
        * NumPy Arrays -> :class:`torch.Tensor`
        * `float` -> :class:`torch.Tensor`
        * `int` -> :class:`torch.Tensor`
        * `str` -> `str` (unchanged)
        * `bytes` -> `bytes` (unchanged)
        * `Mapping[K, V_i]` -> `Mapping[K, nerfstudio_collate([V_1, V_2, ...])]`
        * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[nerfstudio_collate([V1_1, V1_2, ...]),
          nerfstudio_collate([V2_1, V2_2, ...]), ...]`
        * `Sequence[V1_i, V2_i, ...]` -> `Sequence[nerfstudio_collate([V1_1, V1_2, ...]),
          nerfstudio_collate([V2_1, V2_2, ...]), ...]`
        * :class:`Cameras` -> a single :class:`Cameras` (stacked if unbatched, concatenated otherwise)
        * any instance of a key of `extra_mappings` -> `extra_mappings[key](batch)`

    Args:
        batch: a single, non-empty batch to be collated. The type of ``batch[0]`` decides how the
            whole batch is handled.
        extra_mappings: optional mapping from a type to a callable. It is consulted only when no
            built-in rule applies. The first key (in dict order) that ``batch[0]`` is an instance of
            has its callable invoked with the whole batch, and that result is returned.

    Returns:
        The collated batch, as described by the mapping above.

    Raises:
        TypeError: if the element type is not supported, including numpy arrays of strings, bytes
            or objects.
        RuntimeError: if the sequence elements of the batch have different lengths.
        AssertionError: if a batch of :class:`Cameras` is inconsistent.

    Examples:
        >>> # Example with a batch of `int`s:
        >>> nerfstudio_collate([0, 1, 2, 3])
        tensor([0, 1, 2, 3])
        >>> # Example with a batch of `str`s:
        >>> nerfstudio_collate(['a', 'b', 'c'])
        ['a', 'b', 'c']
        >>> # Example with `Map` inside the batch:
        >>> nerfstudio_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
        {'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
        >>> # Example with `NamedTuple` inside the batch:
        >>> Point = namedtuple('Point', ['x', 'y'])
        >>> nerfstudio_collate([Point(0, 0), Point(1, 1)])
        Point(x=tensor([0, 1]), y=tensor([0, 1]))
        >>> # Example with `Tuple` inside the batch:
        >>> nerfstudio_collate([(0, 1), (2, 3)])
        [tensor([0, 2]), tensor([1, 3])]
        >>> # Example with `List` inside the batch:
        >>> nerfstudio_collate([[0, 1], [2, 3]])
        [tensor([0, 2]), tensor([1, 3])]
    """
    if extra_mappings is None:
        extra_mappings = {}
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel, device=elem.device)
            out = elem.new(storage).resize_(len(batch), *list(elem.size()))
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == "numpy" and elem_type.__name__ not in ("str_", "string_"):
        if elem_type.__name__ in ("ndarray", "memmap"):
            # Arrays of strings, bytes or objects cannot be turned into tensors.
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                # The template has two placeholders: the extra supported types, then what was found.
                raise TypeError(NERFSTUDIO_COLLATE_ERR_MSG_FORMAT.format(list(extra_mappings), elem.dtype))

            return nerfstudio_collate([torch.as_tensor(b) for b in batch], extra_mappings=extra_mappings)
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        # Collate once; only the container construction is allowed to fall back.
        collated_map = {key: nerfstudio_collate([d[key] for d in batch], extra_mappings=extra_mappings) for key in elem}
        try:
            return elem_type(collated_map)
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return collated_map
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(nerfstudio_collate(samples, extra_mappings=extra_mappings) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(other) == elem_size for other in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        collated_seq = [nerfstudio_collate(samples, extra_mappings=extra_mappings) for samples in zip(*batch)]

        if isinstance(elem, tuple):
            return collated_seq  # Backwards compatibility.
        try:
            return elem_type(collated_seq)
        except TypeError:
            # The sequence type may not support `__init__(iterable)` (e.g., `range`).
            return collated_seq

    # NerfStudio types supported below

    elif isinstance(elem, Cameras):
        return _collate_cameras(batch)

    for type_key in extra_mappings:
        if isinstance(elem, type_key):
            return extra_mappings[type_key](batch)

    # The template has two placeholders: the extra supported types, then the offending type.
    raise TypeError(NERFSTUDIO_COLLATE_ERR_MSG_FORMAT.format(list(extra_mappings), elem_type))
