# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Manipulation of micro-batches."""
import typing
from typing import Callable, Iterable, Iterator, List, Tuple, Union, cast

import torch
from torch import Tensor
import torch.cuda.comm

# Empty on purpose: ``from <this module> import *`` exports no names.
__all__: List[str] = []


# A tuple holding any number of tensors.
Tensors = Tuple[Tensor, ...]
# Either a single tensor or a tuple of tensors.
TensorOrTensors = Union[Tensor, Tensors]
# A callable that takes a tensor-or-tensors value and returns one.
Function = Callable[[TensorOrTensors], TensorOrTensors]


class Batch:
    """An abstraction of an atomic tensor or a tuple of tensors. This
    removes the boilerplate otherwise needed to tell a single tensor apart
    from a tuple of tensors.
    ::

        x = generate_tensor_or_tensors()
        x = Batch(x, 0)

        # in-place update
        x[0] = F.apply(x[0])
        x[:] = F.apply(*x)

        # f receives the underlying value as ONE argument: the tensor itself
        # if x is atomic, or the whole tuple of tensors otherwise.
        # y is also a batch and keeps the index of x.
        y = x.call(f)

    Args:
        value: the tensor or the tuple of tensors to wrap.
        index: an integer identifying this batch (``scatter`` passes the
            position of the micro-batch). Exposed read-only as
            :attr:`index`.

    Attributes:
        value: the wrapped tensor or tuple of tensors.
        atomic: ``True`` when ``value`` is a single tensor, as decided by
            :func:`torch.is_tensor` at construction time.

    """

    def __init__(self, value: TensorOrTensors, index: int) -> None:
        self.value = value
        self.atomic = torch.is_tensor(value)
        self.__index = index

    @property
    def index(self) -> int:
        """The index given at construction time (read-only)."""
        return self.__index

    @property
    def tensor(self) -> Tensor:
        """Retrieves the underlying tensor.

        Raises:
            AttributeError: the batch wraps a tuple of tensors.

        """
        if not self.atomic:
            raise AttributeError("not atomic batch")
        return cast(Tensor, self.value)

    @property
    def tensors(self) -> Tensors:
        """Retrieves the underlying tensors.

        Raises:
            AttributeError: the batch wraps a single tensor.

        """
        if self.atomic:
            raise AttributeError("batch is atomic")
        return cast(Tensors, self.value)

    @property
    def tensor_or_tensors(self) -> TensorOrTensors:
        """Retrieves the underlying tensor or tensors regardless of type."""
        return self.value

    def call(self, function: Function) -> "Batch":
        """Calls a function with the underlying tensor or tensors passed as
        a single argument. It also wraps the output with :class:`Batch`,
        carrying over this batch's index.
        """
        return Batch(function(self.value), self.index)

    def __repr__(self) -> str:
        return f"Batch[atomic={self.atomic!r}]({self.value!r})"

    def __iter__(self) -> Iterator[Tensor]:
        """Yields the single tensor of an atomic batch, or each tensor of a
        non-atomic batch in order.
        """
        if self.atomic:
            yield self.tensor
        else:
            yield from self.tensors

    def __len__(self) -> int:
        """Returns 1 for an atomic batch, otherwise the number of tensors."""
        return 1 if self.atomic else len(self.tensors)

    def __getitem__(self, index: int) -> Tensor:
        """Returns the tensor at ``index``.

        Raises:
            IndexError: the batch is atomic and ``index`` is not 0, or the
                index is out of range for the underlying tuple.

        """
        if not self.atomic:
            return self.tensors[index]

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        return self.tensor

    # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload".
    @typing.overload
    def __setitem__(self, index: int, value: Tensor) -> None:
        ...

    @typing.overload
    def __setitem__(self, index: slice, value: Tensors) -> None:
        ...

    def __setitem__(self, index: Union[int, slice], value: TensorOrTensors) -> None:
        """Replaces one tensor (integer index) or all of them (slice ``[:]``)."""
        if isinstance(index, int):
            value = cast(Tensor, value)
            self._setitem_by_index(index, value)
        else:
            value = cast(Tensors, value)
            self._setitem_by_slice(index, value)

    def _setitem_by_index(self, index: int, value: Tensor) -> None:
        """Replaces the tensor at ``index`` with ``value``.

        Raises:
            IndexError: the batch is atomic and ``index`` is not 0, or the
                index is out of range for a non-atomic batch.

        """
        if not self.atomic:
            tensors = self.tensors
            size = len(tensors)

            # Resolve a negative index the same way tuple indexing does, so
            # that ``batch[-1] = t`` replaces the last tensor instead of
            # splicing an extra element into the tuple.
            position = index + size if index < 0 else index

            # Without this check the slicing below would silently append
            # (or insert) instead of failing like ``__getitem__`` does.
            if not 0 <= position < size:
                raise IndexError("batch index out of range")

            # Tuples are immutable, so rebuild the tuple around the new value.
            self.value = tensors[:position] + (value,) + tensors[position + 1 :]
            return

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        self.value = value

    def _setitem_by_slice(self, index: slice, value: Tensors) -> None:
        """Replaces the whole content of the batch; only ``[:]`` is accepted.

        Raises:
            NotImplementedError: the slice has a start, stop or step.
            IndexError: the batch is atomic and ``value`` does not hold
                exactly one tensor.

        """
        if not (index.start is index.stop is index.step is None):
            raise NotImplementedError("only slice [:] supported")

        if not self.atomic:
            self.value = value
            return

        # An atomic batch stays atomic: unwrap the single replacement tensor.
        if len(value) != 1:
            raise IndexError("atomic batch cannot be replaced with multiple tensors")

        self.value = value[0]


def check(input: TensorOrTensors) -> None:
    """Validates that the input is a tensor or a tuple made of tensors.

    Tuples are inspected element by element, recursively, so the first
    offending element (in order) is the one reported.

    Raises:
        TypeError: input is not a tensor or tensors.

    """
    if not isinstance(input, tuple):
        # Leaf value: it must be a tensor.
        if isinstance(input, Tensor):
            return
        raise TypeError(f"expected Tensor, but got {input.__class__.__name__}")

    for element in input:
        check(element)


def scatter(input: TensorOrTensors, chunks: int) -> List[Batch]:
    """Splits an input mini-batch into multiple micro-batches.

    A single tensor is chunked directly. For a tuple of tensors, every
    tensor is chunked and the chunks at the same position are grouped into
    one micro-batch. Each resulting :class:`Batch` is given its position in
    the sequence as its index.
    """
    pieces: Iterable[TensorOrTensors]

    if not isinstance(input, Tensor):
        # One tuple of chunks per input tensor ...
        per_tensor: List[Tensors] = [cast(Tensors, member.chunk(chunks)) for member in input]
        # ... transposed so that each item holds the i-th chunk of every tensor.
        pieces = zip(*per_tensor)
    else:
        pieces = input.chunk(chunks)

    return [Batch(piece, position) for position, piece in enumerate(pieces)]


def gather(outputs: List[Batch]) -> TensorOrTensors:
    """Concatenates output micro-batches into a mini-batch.

    The first micro-batch decides the shape of the result: if it is atomic,
    the tensors of all micro-batches are concatenated into one tensor;
    otherwise the tensors at the same position across micro-batches are
    concatenated, producing a tuple of tensors.
    """
    if outputs[0].atomic:
        return torch.cat(tuple(batch.tensor for batch in outputs))

    # Transpose: group the k-th tensor of every micro-batch together, then
    # concatenate each group.
    columns = zip(*[batch.tensors for batch in outputs])
    return tuple(torch.cat(column) for column in columns)
