# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed
from deepspeed.runtime.utils import partition_uniform as partition


def split_tensor_along_last_dim(tensor, partitions, contiguous_split_chunks=False):
    """Split ``tensor`` into chunks along its final dimension (adapted from Megatron-LM).

    Arguments:
        tensor: input tensor.
        partitions: chunk sizes forwarded verbatim to ``torch.split`` (a list of
                    sizes, or a single int chunk size).
        contiguous_split_chunks: when True, every returned chunk is made
                                 contiguous in memory.

    Returns:
        A tuple of tensor chunks.
    """
    chunks = torch.split(tensor, partitions, dim=tensor.dim() - 1)

    # The chunks produced by torch.split are views and usually not contiguous,
    # so only copy them when the caller explicitly asks for it.
    if not contiguous_split_chunks:
        return chunks
    return tuple(piece.contiguous() for piece in chunks)


class TiledLinear(torch.nn.Module):

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 in_splits=1,
                 out_splits=1,
                 input_is_already_split=False,
                 combine_out_splits=True,
                 linear_cls=torch.nn.Linear,
                 init_linear=None,
                 **kwargs):
        """A replacement for ``torch.nn.Linear`` that works with ZeRO-3 to reduce
        memory requirements via tiling.

        TiledLinear breaks the input and output dimensions of a linear layer
        into tiles that are processed in sequence. This class enables huge
        linear layers when combined with ZeRO-3 because inactive tiles can be
        partitioned and offloaded.

        .. note::
            We recommend using as few tiles as necessary. Tiling
            significantly reduces memory usage, but can reduce throughput
            for inexpensive layers. This is due to the smaller kernels having
            less parallelism and lower arithmetic intensity, while
            introducing more frequent synchronization and communication.

        Args:
            in_features (int): See ``torch.nn.Linear``
            out_features (int): See ``torch.nn.Linear``
            bias (bool, optional): See ``torch.nn.Linear``
            in_splits (int, optional): The number of tiles along the input dimension. Defaults to 1.
            out_splits (int, optional): The number of tiles along the output dimension. Defaults to 1.
            input_is_already_split (bool, optional): If set to ``True``, assume that the ``input_`` passed
                to ``forward()`` is already split into ``in_splits`` chunks. Defaults to ``False``.
            combine_out_splits (bool, optional): If set to ``False``, do not combine the ``out_splits`` outputs
                into a single tensor. Defaults to ``True``.
            linear_cls (class, optional): The underlying class to build individual tiles.
                Defaults to ``torch.nn.Linear``.
            init_linear (``torch.nn.Linear``, optional): If set, copy the parameters of
                ``init_linear``. Useful for debugging. Defaults to ``None``.
            kwargs (dict, optional): additional keyword arguments to provide to ``linear_cls()``.

        Raises:
            RuntimeError: ``in_splits`` must be within the range [1, in_features].
            RuntimeError: ``out_splits`` must be within the range [1, out_features].
        """

        super().__init__()

        if (in_splits < 1) or (in_splits > in_features):
            raise RuntimeError('in splits must be in range [1, in_features].')
        if (out_splits < 1) or (out_splits > out_features):
            raise RuntimeError('out splits must be in range [1, out_features].')

        # global, not necessarily local
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = bias

        self.out_splits = out_splits
        self.in_splits = in_splits
        self.input_is_already_split = input_is_already_split
        self.combine_out_splits = combine_out_splits

        # Build partition-lists. These are CSR-style splits [0, part0, part1, ..., features]
        # For example, row_parts[p] gives the start of partition p and row_parts[p+1]
        # is the exclusive end.
        self.in_parts = partition(num_items=in_features, num_parts=in_splits)
        self.out_parts = partition(num_items=out_features, num_parts=out_splits)

        assert len(self.out_parts) == out_splits + 1
        assert len(self.in_parts) == in_splits + 1
        assert self.out_parts[0] == 0
        assert self.out_parts[out_splits] == out_features
        assert self.in_parts[0] == 0
        assert self.in_parts[in_splits] == in_features

        # linears[out_id][in_id] is the tile mapping input chunk in_id to output chunk out_id.
        self.linears = torch.nn.ModuleList()
        for out_id in range(out_splits):
            self.linears.append(torch.nn.ModuleList())

            local_out_dim = self.out_parts[out_id + 1] - self.out_parts[out_id]

            for in_id in range(in_splits):
                # The partial products of an output chunk are summed over the input
                # splits, so only one tile (the last one) may carry the bias.
                local_bias = bias if in_id == (in_splits - 1) else False

                local_in_dim = self.in_parts[in_id + 1] - self.in_parts[in_id]
                local = linear_cls(local_in_dim, local_out_dim, bias=local_bias, **kwargs)
                self.linears[out_id].append(local)

        # Optionally initialize with a known tensor
        if init_linear is not None:
            self.copy_params_from(init_linear)

    def forward(self, input_):
        """Apply every tile and reduce/combine their results.

        Args:
            input_: The input whose last dimension is split into ``in_splits``
                chunks. When ``input_is_already_split`` is set, a sequence of
                ``in_splits`` pre-split chunks instead. With a single split, a
                bare (unsplit) input is also accepted.

        Returns:
            The combined output when ``combine_out_splits`` is set, otherwise a
            list with one reduced result per output split.
        """
        if self.in_splits > 1 and not self.input_is_already_split:
            input_parts = partition(input_.shape[-1], self.in_splits)
            split_sizes = [input_parts[p + 1] - input_parts[p] for p in range(self.in_splits)]
            inputs = self._split_global_input(input_, split_sizes)
        elif self.in_splits > 1 or (self.input_is_already_split and isinstance(input_, (list, tuple))):
            # Caller already split the input. For in_splits == 1 a one-element
            # sequence must not be wrapped again, or the tile would receive a list.
            inputs = input_
            assert len(
                inputs) == self.in_splits, f"Col splits {self.in_splits} does not match input splits {len(inputs)}"
        else:
            # no splits
            inputs = [input_]

        outputs = [None] * self.out_splits
        for out_id in range(self.out_splits):
            for in_id in range(self.in_splits):
                local_output = self.linears[out_id][in_id](inputs[in_id])

                outputs[out_id] = self._reduce_local_output(in_id=in_id,
                                                            out_id=out_id,
                                                            current_out=outputs[out_id],
                                                            new_out=local_output)

        if self.combine_out_splits:
            return self._combine_output_splits(outputs)

        return outputs

    def _split_global_input(self, input, split_sizes):
        """Partition an input tensor along the last dimension, aligned with given splits.

        Subclasses should override this method to account for new input types.

        Args:
            input (Tensor): The tensor to partition along the last dimension.
            split_sizes (List[int]): The size of each partition.

        Returns:
            List[Any]: A list of the chunks of ``input``.
        """
        return split_tensor_along_last_dim(input, split_sizes)

    def _reduce_local_output(self, in_id, out_id, current_out, new_out):
        """Reduce (sum) a new local result into the existing local results.

        Subclasses should override this method.

        For a given ``out_id``, this method is called once per input split
        (``in_splits`` times). The first call receives ``current_out=None`` and
        simply keeps (a clone of) ``new_out``; later calls sum into it.

        Args:
            in_id (int): The input split that produced ``new_out``.
            out_id (int): The output split that produced ``new_out``.
            current_out (Any): The reduced form of all previous ``out_id`` results.
            new_out (Any): The local result from forward (``in_id``, ``out_id``).

        Returns:
            Any: The combined result of ``current_out`` and ``new_out``.
        """

        if current_out is None:
            #this clone is necessary to preserve auto grad
            #there is some issue with inplace update for outputs that are views
            return new_out.clone()
        else:
            return current_out + new_out

    def _combine_output_splits(self, outputs):
        """Join the splits of the output into a single result.

        Args:
            outputs (List[Any]): The reduced outputs for each output split.

        Returns:
            Any: The outputs concatenated along the last dimension.
        """
        assert len(outputs) == self.out_splits
        return torch.cat(outputs, dim=-1)

    @torch.no_grad()
    def copy_params_from(self, other):
        """Copy the weight and bias data from ``other``.

        This is especially useful for reproducible initialization and testing.

        Equivalent to:

        .. code-block:: python

            with torch.no_grad():
                self.weight.copy_(other.weight)
                if self.bias is not None:
                    self.bias.copy_(other.bias)

        .. note::
            If ZeRO-3 is enabled, this is a collective operation and the
            updated parameters of data-parallel rank 0 will be visible on all
            ranks. See :class:`deepspeed.zero.GatheredParameters` for more
            information.


        Args:
            other (``torch.nn.Linear``): the linear layer to copy from.
        """
        assert hasattr(other, 'weight')
        assert other.weight.size() == (self.out_features, self.in_features)
        # ``other`` may not define ``bias`` at all; treat that the same as ``bias=None``.
        other_bias = getattr(other, 'bias', None)
        if self.use_bias:
            assert other_bias is not None
            assert other_bias.size() == (self.out_features, )
        else:
            assert other_bias is None

        for row in range(self.out_splits):
            rstart = self.out_parts[row]
            rstop = self.out_parts[row + 1]

            for col in range(self.in_splits):
                cstart = self.in_parts[col]
                cstop = self.in_parts[col + 1]

                local = self.linears[row][col]
                global_weight = other.weight[rstart:rstop, cstart:cstop]
                with deepspeed.zero.GatheredParameters(local.weight, modifier_rank=0):
                    local.weight.copy_(global_weight)

            # Only the last input tile of each row was constructed with a bias.
            bias_tile = self.linears[row][self.in_splits - 1]
            if bias_tile.bias is not None:
                with deepspeed.zero.GatheredParameters(bias_tile.bias, modifier_rank=0):
                    bias_tile.bias.data.copy_(other_bias[rstart:rstop].data)


class TiledLinearReturnBias(TiledLinear):
    """Tiled linear layer whose tiles return an ``(output, bias)`` pair.

    Meant for ``linear_cls`` implementations that hand back their bias
    parameter alongside the output instead of adding it, such as those used by
    Megatron-LM. Output tensors are summed over input splits and concatenated
    over output splits; returned biases are never summed, only concatenated.
    """

    def _reduce_local_output(self, in_id, out_id, current_out, new_out):
        """Sum the output tensors of the tiles, passing the returned bias through.

        ``current_out`` is either ``None`` (first input split) or the
        ``(tensor, bias)`` pair accumulated so far; ``new_out`` must be a
        2-tuple ``(tensor, bias)`` with a non-``None`` tensor. The latest
        non-``None`` bias is kept.
        """
        prev_tensor, prev_bias = (None, None) if current_out is None else current_out

        assert isinstance(new_out, tuple)
        assert len(new_out) == 2

        step_tensor, step_bias = new_out
        assert step_tensor is not None

        summed = super()._reduce_local_output(in_id=in_id, out_id=out_id, current_out=prev_tensor, new_out=step_tensor)
        kept_bias = prev_bias if step_bias is None else step_bias
        return summed, kept_bias

    def _combine_output_splits(self, outputs):
        """Concatenate the output tensors, and the biases when any were returned."""
        joined_tensor = super()._combine_output_splits([pair[0] for pair in outputs])

        present_biases = [pair[1] for pair in outputs if pair[1] is not None]
        joined_bias = super()._combine_output_splits(present_biases) if present_biases else None

        return joined_tensor, joined_bias
