# modified from https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/nn/layer/parallel_sequence
# !/usr/bin/env python
# -*- encoding: utf-8 -*-

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed as dist
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed.distributed_c10d import _get_global_rank
from torch.nn import Parameter

from .distributed_utils import DistGroups


class TransformerSelfAttentionRing(nn.Module):
    """Self-attention layer whose sequence dimension is sharded over ``DistGroups["sp"]``.

    Each rank projects its local slice of the sequence to query/key/value and then
    relies on :class:`RingQK` and :class:`RingAV` to build the attention scores
    against the full sequence and the resulting context layer.

    Args:
        hidden_size (int): hidden size.
        num_attention_heads (int): number of attention heads; must divide ``hidden_size``.
        attention_dropout (float): dropout probability applied to the attention probabilities.
        layer_number (int): must be > 0. Only used as the extra scaling coefficient
            when ``apply_query_key_layer_scaling`` is set.
        apply_query_key_layer_scaling (bool): if set, ``norm_factor`` is multiplied by
            ``layer_number`` and ``convert_fp16_to_fp32_in_softmax`` is forced to ``True``.
        convert_fp16_to_fp32_in_softmax (bool): stored on the module; ``forward`` in this
            class does not read it.
        masked_softmax_fusion, fp16, bf16: accepted but unused by this class (they are only
            referenced by the commented-out fused softmax below).

    """

    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        attention_dropout,
        layer_number,
        apply_query_key_layer_scaling: bool = False,
        convert_fp16_to_fp32_in_softmax: bool = False,
        masked_softmax_fusion=True,
        fp16=False,
        bf16=False,
    ):
        super().__init__()
        self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        # self.attention_mask_func = attention_mask_func
        self.layer_number = layer_number
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        # self.attn_mask_type = attn_mask_type
        assert self.layer_number > 0
        # Temporarily holds the float probability; rebound to an nn.Dropout module below.
        self.attention_dropout = attention_dropout

        if self.apply_query_key_layer_scaling:
            self.convert_fp16_to_fp32_in_softmax = True

        assert (
            self.hidden_size % self.num_attention_heads == 0
        ), "hidden size is not divisible by the number of attention heads"

        self.hidden_size_per_attention_head = self.hidden_size // num_attention_heads

        # Number of ranks sharing the sequence; used in forward to size the full-sequence axis.
        self.world_size = DistGroups["sp"].size()

        # Fused projection producing query, key and value in a single matmul.
        self.query_key_value = nn.Linear(
            hidden_size,
            3 * self.hidden_size,
        )

        self.coeff = None
        # NOTE(review): the scores are scaled by sqrt(hidden_size), not by
        # sqrt(hidden_size_per_attention_head) -- confirm this is intended.
        self.norm_factor = math.sqrt(self.hidden_size)

        if self.apply_query_key_layer_scaling:
            self.coeff = layer_number
            self.norm_factor *= self.coeff

        # self.scale_mask_softmax = FusedScaleMaskSoftmax(
        #     fp16, bf16,
        #     self.attn_mask_type,
        #     masked_softmax_fusion,
        #     self.attention_mask_func,
        #     self.convert_fp16_to_fp32_in_softmax,
        #     self.coeff)

        self.attention_dropout = nn.Dropout(attention_dropout)

        # Output.
        self.dense = nn.Linear(hidden_size, hidden_size, bias=True)
        # note: need to be replaced. forward() calls self.dense with two positional
        # arguments, which this nn.Linear placeholder does not accept.
        # self.dense = None

    def forward(
        self,
        hidden_states,
        attention_mask,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        **kwargs,
    ):
        """Compute self-attention for the local sequence shard.

        Args:
            hidden_states: local shard, unpacked as ``[batch_size, sub_seq_len, hidden_size]``.
            attention_mask: optional additive mask; when not ``None`` it is added to the
                ``[batch_size, num_heads, sub_seq_len, seq_len]`` scores before the softmax
                (presumably shaped ``[batch_size, 1, 1, seq_len]`` -- confirm with the caller).
            head_mask: optional tensor multiplied into the attention probabilities.
            encoder_hidden_states, encoder_attention_mask, past_key_value,
                output_attentions, **kwargs: accepted but not used by this method.

        Returns:
            A 1-tuple ``(output,)`` where ``output`` is ``self.dense(context_layer, hidden_states)``.
        """
        # The layout used below is [batch_size, sub_seq_len, hidden_size]
        # (batch first, unlike the upstream [sub_seq_len, batch_size, hidden_size]).
        # attention_mask: [batch_size, 1, 1, seq_len]
        batch_size, sub_seq_length, hidden_size = hidden_states.size()

        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads shape change:
        # [batch_size, sub_seq_len, hidden_size] --> [batch_size, sub_seq_len, (3 * head_size * num_heads)]
        mixed_x_layer = self.query_key_value(hidden_states)

        # [batch_size, sub_seq_len, num_heads, 3 * head_size] --> 3 [batch_size, sub_seq_len, num_heads, head_size]
        new_tensor_shape = mixed_x_layer.size()[:-1] + (
            self.num_attention_heads,
            3 * self.hidden_size_per_attention_head,
        )
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # split into query, key and value along the last (per-head) dimension
        last_dim = mixed_x_layer.dim() - 1
        last_dim_value = mixed_x_layer.size(-1)
        assert last_dim_value % 3 == 0, (
            "the last dimension is not a multiple of 3, " "cannot be divided into query, key and value"
        )
        partition_size = last_dim_value // 3
        (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, partition_size, dim=last_dim)

        # attention scores: [batch_size, num_heads, sub_seq_len, seq_len]
        # (seq_len is the local key length times the number of ranks in the group)
        output_size = (
            query_layer.size(0),
            query_layer.size(2),
            query_layer.size(1),
            key_layer.size(1) * self.world_size,
        )

        # [batch_size, sub_seq_len, num_heads, head_size] -> [batch_size * num_heads, sub_seq_len, head_size]
        query_layer = query_layer.transpose(1, 2).contiguous().view(output_size[0] * output_size[1], output_size[2], -1)
        # [batch_size, sub_seq_len, num_heads, head_size] -> [batch_size * num_heads, sub_seq_len, head_size]
        key_layer = key_layer.transpose(1, 2).contiguous().view(output_size[0] * output_size[1], output_size[2], -1)

        # attention_scores: [batch_size * num_heads, sub_seq_len, seq_len]
        attention_scores = RingQK.apply(
            query_layer,  # [batch_size * num_heads, sub_seq_len, head_size]
            key_layer,  # [batch_size * num_heads, sub_seq_len, head_size],
            batch_size,
            self.num_attention_heads,
            sub_seq_length,
        )

        attention_scores /= self.norm_factor

        # change view to [batch_size, num_heads, sub_seq_len, seq_len]
        attention_scores = attention_scores.view(*output_size)

        # change shape to [batch_size, num_heads, sub_seq_len, seq_len]
        # attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
        # note: no scale mask softmax kernel; the mask is added and a plain softmax is used
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask
        attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)
        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask
        # context layer shape: [batch_size, num_heads, sub_seq_len, head_size]
        output_size = (value_layer.size(0), value_layer.size(2), value_layer.size(1), value_layer.size(3))

        # change view [batch_size, sub_seq_len, num_heads, head_size] ->
        #             [batch_size * num_heads, sub_seq_len, head_size]
        value_layer = value_layer.transpose(1, 2).contiguous().view(output_size[0] * output_size[1], output_size[2], -1)

        # change view to [batch_size * num_heads, sub_seq_len, seq_len]
        attention_probs = attention_probs.view(
            attention_probs.size(0) * attention_probs.size(1), attention_probs.size(2), attention_probs.size(3)
        )

        # matmul: [batch_size * num_heads, sub_seq_len, head_size]
        context_layer = RingAV.apply(
            attention_probs,
            value_layer,
            batch_size,
            self.num_attention_heads,
            self.hidden_size_per_attention_head,
            sub_seq_length,
        )

        # change view [batch_size, num_heads, sub_seq_len, head_size]
        context_layer = context_layer.view(*output_size)

        # [batch_size, num_heads, sub_seq_len, head_size] -> [batch_size, sub_seq_len, num_heads, head_size]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

        # [batch_size, sub_seq_len, num_heads, head_size] -> [batch_size, sub_seq_len, hidden_size]
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_attention_head * self.num_attention_heads,
        )
        context_layer = context_layer.view(*new_context_layer_shape)

        # note: replace back to bert's attn output.
        # NOTE(review): self.dense receives (context_layer, hidden_states); presumably it has
        # been swapped for a module taking the residual input as well -- confirm at the call site.
        output = self.dense(context_layer, hidden_states)

        return (output,)

    def extra_repr(self):
        # Summary of the configuration shown inside the module's repr.
        return (
            f"apply_query_key_layer_scaling={self.apply_query_key_layer_scaling}, "
            f"layer_number={self.layer_number}, hidden_size:{self.hidden_size}, "
            f"num_attention_heads={self.num_attention_heads}, "
            f"hidden_size_per_attention_head={self.hidden_size_per_attention_head}, "
            f"coeff={self.coeff}, norm_factor={self.norm_factor}, "
            f"convert_fp16_to_fp32_in_softmax={self.convert_fp16_to_fp32_in_softmax} "
        )


class _Linear(nn.Module):
    """Dense linear layer ``Y = X W^T + b`` with an optional deferred bias.

    The weight has shape ``(output_size, input_size)`` and is initialised with
    Xavier-normal; the bias, when present, starts at zero.

    Arguments:
        input_size: size of the last dimension of the input.
        output_size: size of the last dimension of the output.
        bias: if true, create a bias parameter; otherwise ``bias`` is registered as ``None``.
        skip_bias_add: if true, ``forward`` does not add the bias and instead returns
                       ``(output, bias)`` so the caller can fuse the addition with
                       other elementwise operations.
    """

    def __init__(self, input_size, output_size, bias=True, skip_bias_add=False):
        super().__init__()

        # Remember the construction arguments (also used by __repr__).
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add

        weight_storage = torch.empty(self.output_size, self.input_size)
        self.weight = Parameter(weight_storage)
        nn.init.xavier_normal_(self.weight)

        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = Parameter(torch.empty(self.output_size))
            # The bias always starts at zero.
            nn.init.zeros_(self.bias)

    def forward(self, input_):
        """Apply the layer; returns ``output`` or ``(output, bias)`` when ``skip_bias_add`` is set."""
        if self.skip_bias_add:
            # Leave the bias out of the matmul and hand it back to the caller.
            return F.linear(input_, self.weight, None), self.bias
        return F.linear(input_, self.weight, self.bias)

    def __repr__(self):
        has_bias = self.bias is not None
        return (
            f"Linear(in_features={self.input_size}, out_features={self.output_size}, "
            f"bias={has_bias}, skip_bias_add={self.skip_bias_add})"
        )


class RingQK(torch.autograd.Function):
    """Ring-exchange computation of the attention scores ``Q @ K^T``.

    Every rank of ``DistGroups["sp"]`` holds one sub-sequence of ``Q`` and ``K``.
    The key shards are passed around the ring with :func:`ring_forward` so that each
    rank can fill every column block of its local score matrix.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length):
        """Return scores of shape ``[batch_size * num_attention_heads, sub_seq_length, sub_seq_length * group_size]``."""
        # The backward pass needs the local shards, not the rotated keys.
        ctx.save_for_backward(sub_q, sub_k)
        ctx.sub_seq_length = sub_seq_length

        sp_group = DistGroups["sp"]
        rank = sp_group.rank()
        world = sp_group.size()

        # Every column block is written below, so uninitialised storage is sufficient.
        attention_score = torch.empty(
            batch_size * num_attention_heads,
            sub_seq_length,
            sub_seq_length * world,
            dtype=sub_q.dtype,
            device=sub_q.device,
        )

        # Column block owned by this rank: local Q against local K.
        lo, hi = _calc_current_device_range(rank, sub_seq_length)
        attention_score[:, :, lo:hi] = torch.matmul(sub_q, sub_k.transpose(2, 1))

        # Remaining blocks: rotate the keys one hop at a time and score against each arrival.
        travelling_k = sub_k
        for hop in range(world - 1):
            travelling_k = ring_forward(travelling_k)
            lo, hi = _calc_incoming_device_range(hop, rank, world, sub_seq_length)
            attention_score[:, :, lo:hi] = torch.matmul(sub_q, travelling_k.transpose(2, 1))

        return attention_score

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        """Return gradients for ``sub_q`` and ``sub_k`` (``None`` for the integer arguments).

        NOTE(review): both gradients are divided by the group size -- confirm this
        scaling matches how gradients are reduced across the group elsewhere.
        """
        sub_q, sub_k = ctx.saved_tensors
        sp_group = DistGroups["sp"]
        rank = sp_group.rank()
        world = sp_group.size()
        seq = ctx.sub_seq_length

        # Gradient w.r.t. K: contributions for all key positions are summed over the
        # group, then only the rows belonging to this rank are kept.
        grad_k = torch.matmul(grad_output.transpose(2, 1), sub_q)
        dist.all_reduce(grad_k, group=sp_group)
        grad_k = grad_k[:, rank * seq : (rank + 1) * seq]
        grad_k /= world

        # Gradient w.r.t. Q: accumulate one column block of grad_output per key shard.
        grad_q = torch.zeros_like(sub_q)

        # Contribution of the local key shard.
        lo, hi = _calc_current_device_range(rank, seq)
        grad_q += torch.matmul(grad_output[:, :, lo:hi], sub_k)

        # Contributions of the key shards arriving over the ring.
        travelling_k = sub_k
        for hop in range(world - 1):
            travelling_k = ring_forward(travelling_k)
            lo, hi = _calc_incoming_device_range(hop, rank, world, seq)
            grad_q += torch.matmul(grad_output[:, :, lo:hi], travelling_k)

        grad_q /= world

        return grad_q, grad_k, None, None, None


class RingAV(torch.autograd.Function):
    """Ring-exchange computation of the context layer ``A @ V``.

    Each rank of ``DistGroups["sp"]`` holds the attention probabilities of its own
    queries against the full sequence and one shard of ``V``; the value shards are
    rotated with :func:`ring_forward` and the partial products are accumulated.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, attention_score, sub_v, batch_size, num_attention_heads, attention_head_size, sub_seq_length):
        """Return a ``[batch_size * num_attention_heads, sub_seq_length, attention_head_size]`` tensor."""
        sp_group = DistGroups["sp"]
        rank = sp_group.rank()
        world = sp_group.size()

        # The backward pass needs the probabilities and the local (un-rotated) values.
        ctx.save_for_backward(attention_score, sub_v)
        ctx.sub_seq_length = sub_seq_length

        sub_attention_result = torch.zeros(
            batch_size * num_attention_heads,
            sub_seq_length,
            attention_head_size,
            device=attention_score.device,
            dtype=attention_score.dtype,
        )

        # Local columns of the probabilities against the local value shard.
        lo, hi = _calc_current_device_range(rank, sub_seq_length)
        sub_attention_result += torch.matmul(attention_score[:, :, lo:hi], sub_v)

        # Rotate the value shards and add the matching column block each hop.
        travelling_v = sub_v
        for hop in range(world - 1):
            travelling_v = ring_forward(travelling_v)
            lo, hi = _calc_incoming_device_range(hop, rank, world, sub_seq_length)
            sub_attention_result += torch.matmul(attention_score[:, :, lo:hi], travelling_v)
        return sub_attention_result

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        """Return gradients for ``attention_score`` and ``sub_v`` (``None`` for the integer arguments).

        NOTE(review): ``grad_v`` is divided by the group size while the attention-score
        gradient is not -- confirm this asymmetry is intended.
        """
        attention_scores, sub_v = ctx.saved_tensors
        sp_group = DistGroups["sp"]
        rank = sp_group.rank()
        world = sp_group.size()
        seq = ctx.sub_seq_length
        local_lo, local_hi = _calc_current_device_range(rank, seq)

        # Gradient w.r.t. V: sum the contributions for all value positions over the
        # group, then keep only the rows owned by this rank.
        grad_v = torch.matmul(attention_scores.transpose(2, 1), grad_output)
        dist.all_reduce(grad_v, group=sp_group)
        grad_v = grad_v[:, local_lo:local_hi]
        grad_v /= world

        # Gradient w.r.t. the attention probabilities, filled one column block at a time.
        grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=grad_output.device)

        # Column block matching the local value shard.
        grad_attention_score[:, :, local_lo:local_hi] += torch.matmul(grad_output, sub_v.transpose(2, 1))

        # Column blocks matching the value shards arriving over the ring.
        travelling_v = sub_v
        for hop in range(world - 1):
            travelling_v = ring_forward(travelling_v)
            lo, hi = _calc_incoming_device_range(hop, rank, world, seq)
            grad_attention_score[:, :, lo:hi] += torch.matmul(grad_output, travelling_v.transpose(2, 1))

        return grad_attention_score, grad_v, None, None, None, None


def ring_forward(tensor_send_next: torch.Tensor) -> torch.Tensor:
    """Send a tensor to the next rank of ``DistGroups["sp"]`` and receive one from the previous rank.

    Args:
        tensor_send_next (:class:`torch.Tensor`): Tensor sent to the next member of the ring.

    Returns:
        :class:`torch.Tensor`: A newly allocated tensor (same shape, dtype and device as the
        input, ``requires_grad=True``) holding the data received from the previous member.

    Note:
        The group size is asserted to be even. The function blocks until both transfers
        have completed and then calls ``torch.cuda.synchronize()``.
    """
    sp_group = DistGroups["sp"]
    rank_in_group = sp_group.rank()
    group_size = sp_group.size()

    # Ring neighbours, translated from group-local ranks to global ranks.
    dst_global = _get_global_rank(sp_group, (rank_in_group + 1) % group_size)
    src_global = _get_global_rank(sp_group, (rank_in_group - 1) % group_size)

    tensor_recv_prev = torch.empty(
        tensor_send_next.size(),
        requires_grad=True,
        device=tensor_send_next.device,
        dtype=tensor_send_next.dtype,
    )

    send_op = dist.P2POp(dist.isend, tensor_send_next, dst_global, group=sp_group)
    recv_op = dist.P2POp(dist.irecv, tensor_recv_prev, src_global, group=sp_group)

    # seems to be even
    assert group_size % 2 == 0
    # Even ranks post the receive first, odd ranks post the send first.
    if rank_in_group % 2 == 0:
        ordered_ops = [recv_op, send_op]
    else:
        ordered_ops = [send_op, recv_op]

    for request in dist.batch_isend_irecv(ordered_ops):
        request.wait()

    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    return tensor_recv_prev


def _calc_incoming_device_range(i, rank, world_size, sub_seq_length):
    """Return the ``(start, end)`` sequence range of the shard received at ring step ``i``.

    After ``i + 1`` hops the shard held by ``rank`` originates from rank
    ``(rank - i - 1) % world_size``; the range is that rank's slice of the full sequence.
    """
    origin_rank = (rank - (i + 1)) % world_size
    return sub_seq_length * origin_rank, sub_seq_length * (origin_rank + 1)


def _calc_current_device_range(rank, sub_seq_length):
    """Return the ``(start, end)`` range of the full sequence owned by ``rank``."""
    return sub_seq_length * rank, sub_seq_length * (rank + 1)


class AllGatherSP(torch.autograd.Function):
    """All-gather the per-rank sequence shards of ``DistGroups["sp"]`` into the full sequence.

    Forward concatenates the shards of every rank along dim 1 (presumably the input is
    ``[bs, local_seq, hs]`` -- confirm with the caller). Backward takes the gradient held
    by group rank 0, broadcasts it to the group and returns this rank's slice of it.
    """

    @staticmethod
    def forward(ctx, hidden_states):
        """Return the shards of all ranks concatenated along dim 1."""
        sp_group = DistGroups["sp"]
        local_world_size = sp_group.size()
        # Bypass the function if we are using only 1 GPU.
        if local_world_size == 1:
            return hidden_states

        split_dim = 1
        local_rank = sp_group.rank()

        tensor_list = [torch.empty_like(hidden_states) for _ in range(local_world_size)]
        tensor_list[local_rank] = hidden_states
        torch.distributed.all_gather(tensor_list, hidden_states, group=sp_group)

        # torch.cat already creates a contiguous tensor.
        return torch.cat(tensor_list, dim=split_dim)

    @staticmethod
    def backward(ctx, grad):
        """Return this rank's dim-1 slice of the gradient held by group rank 0."""
        sp_group = DistGroups["sp"]
        local_world_size = sp_group.size()
        # Bypass the function if we are using only 1 GPU.
        if local_world_size == 1:
            return grad

        split_dim = 1
        local_rank = sp_group.rank()

        # Gradient tensors handed to backward() must never be modified in place: the
        # same tensor may also feed another branch of the graph (e.g. a residual add).
        # Group rank 0 only reads its gradient when broadcasting, every other rank
        # receives into a fresh buffer instead of overwriting ``grad``.
        if local_rank == 0:
            buffer = grad.contiguous()
        else:
            buffer = torch.empty_like(grad, memory_format=torch.contiguous_format)

        src = _get_global_rank(sp_group, 0)
        torch.distributed.broadcast(buffer, src=src, group=sp_group)
        return torch.chunk(buffer, local_world_size, split_dim)[local_rank]


class BroadcastSP(torch.autograd.Function):
    """Broadcast the full sequence from group rank 0 of ``DistGroups["sp"]`` and keep the local slice.

    Forward splits along dim 1 (presumably the input is ``[bs, seq, hs]`` -- confirm with
    the caller). Backward reassembles the per-rank gradients on group rank 0 and returns
    zeros of the full shape on every other rank.
    """

    @staticmethod
    def forward(ctx, hidden_states):
        """Return this rank's dim-1 chunk of the tensor broadcast from group rank 0."""
        sp_group = DistGroups["sp"]
        world = sp_group.size()
        # Nothing to distribute with a single rank.
        if world == 1:
            return hidden_states

        rank = sp_group.rank()
        root = _get_global_rank(sp_group, 0)
        # NOTE(review): the broadcast writes into ``hidden_states`` itself on the
        # receiving ranks, i.e. the caller's tensor is overwritten there.
        torch.distributed.broadcast(hidden_states, src=root, group=sp_group)
        shards = torch.chunk(hidden_states, world, 1)
        return shards[rank]

    @staticmethod
    def backward(ctx, grad):
        """Return the concatenated gradient on group rank 0 and zeros elsewhere."""
        sp_group = DistGroups["sp"]
        rank = sp_group.rank()
        world = sp_group.size()
        # Nothing to collect with a single rank.
        if world == 1:
            return grad

        gathered = [torch.empty_like(grad) for _ in range(world)]
        gathered[rank] = grad
        torch.distributed.all_gather(gathered, grad, group=sp_group)

        if rank != 0:
            # Ranks other than 0 only need a placeholder of the full-sequence shape.
            full_dims = list(grad.shape)
            full_dims[1] *= world
            full_shape = tuple(int(dim) for dim in full_dims)
            return torch.zeros(full_shape, dtype=grad.dtype, device=grad.device)

        # Note: torch.cat already creates a contiguous tensor.
        return torch.cat(gathered, dim=1).contiguous()


class ScatterSP(torch.autograd.Function):
    """Scatter dim-1 chunks of the tensor held by group rank 0 of ``DistGroups["sp_cpu"]``.

    Forward hands every rank one chunk of rank 0's tensor (presumably ``[bs, seq, hs]`` --
    confirm with the caller). Backward gathers the per-rank gradients back onto rank 0
    and returns zeros of the full shape on every other rank.
    """

    @staticmethod
    def forward(ctx, hidden_states):
        """Return this rank's chunk of the tensor held by group rank 0."""
        cpu_group = DistGroups["sp_cpu"]
        world = cpu_group.size()
        # Nothing to distribute with a single rank.
        if world == 1:
            return hidden_states

        rank = cpu_group.rank()

        # Only the source rank supplies the list of chunks.
        chunks = list(torch.chunk(hidden_states, world, 1)) if rank == 0 else None

        # Every rank sizes its receive buffer from the full-length input it was given.
        # NOTE(review): the division truncates, so a length that is not a multiple of the
        # group size would not match the chunk sizes -- confirm callers guarantee divisibility.
        dims = list(hidden_states.shape)
        dims[1] = dims[1] / world
        local_shape = tuple(int(dim) for dim in dims)
        output = torch.zeros(local_shape, dtype=hidden_states.dtype, device=hidden_states.device)

        root = _get_global_rank(cpu_group, 0)
        torch.distributed.scatter(output, scatter_list=chunks, src=root, group=cpu_group)
        return output

    @staticmethod
    def backward(ctx, grad):
        """Return the concatenated gradient on group rank 0 and zeros elsewhere."""
        cpu_group = DistGroups["sp_cpu"]
        rank = cpu_group.rank()
        world = cpu_group.size()
        # Nothing to collect with a single rank.
        if world == 1:
            return grad

        is_root = rank == 0
        if is_root:
            gathered = [torch.empty_like(grad) for _ in range(world)]
            gathered[rank] = grad
        else:
            gathered = None

        root = _get_global_rank(cpu_group, 0)
        torch.distributed.gather(grad, gather_list=gathered, dst=root, group=cpu_group)

        if is_root:
            return torch.cat(gathered, dim=1).contiguous()

        # Ranks other than 0 only need a placeholder of the full-sequence shape.
        full_dims = list(grad.shape)
        full_dims[1] *= world
        full_shape = tuple(int(dim) for dim in full_dims)
        return torch.zeros(full_shape, dtype=grad.dtype, device=grad.device)


class GatherSP(torch.autograd.Function):
    """Gather the per-rank shards of ``DistGroups["sp_cpu"]`` onto group rank 0.

    Forward returns the shards concatenated along dim 1 on rank 0 and zeros of the same
    full shape on every other rank (presumably the input is ``[bs, local_seq, hs]`` --
    confirm with the caller). Backward scatters rank 0's gradient back as dim-1 chunks.
    """

    @staticmethod
    def forward(ctx, hidden_states):
        """Return the full sequence on group rank 0 and zeros elsewhere."""
        cpu_group = DistGroups["sp_cpu"]
        world = cpu_group.size()
        # Nothing to collect with a single rank.
        if world == 1:
            return hidden_states

        rank = cpu_group.rank()
        is_root = rank == 0

        # Only the destination rank supplies receive buffers.
        if is_root:
            gathered = [torch.empty_like(hidden_states) for _ in range(world)]
            gathered[rank] = hidden_states
        else:
            gathered = None

        root = _get_global_rank(cpu_group, 0)
        torch.distributed.gather(hidden_states, gather_list=gathered, dst=root, group=cpu_group)

        if is_root:
            # Note: torch.cat already creates a contiguous tensor.
            return torch.cat(gathered, dim=1).contiguous()

        # Ranks other than 0 only need a placeholder of the full-sequence shape.
        full_dims = list(hidden_states.shape)
        full_dims[1] *= world
        full_shape = tuple(int(dim) for dim in full_dims)
        return torch.zeros(full_shape, dtype=hidden_states.dtype, device=hidden_states.device)

    @staticmethod
    def backward(ctx, grad):
        """Return this rank's dim-1 chunk of the gradient held by group rank 0."""
        cpu_group = DistGroups["sp_cpu"]
        rank = cpu_group.rank()
        world = cpu_group.size()
        # Nothing to distribute with a single rank.
        if world == 1:
            return grad

        # Only the source rank supplies the list of chunks.
        chunks = list(torch.chunk(grad, world, 1)) if rank == 0 else None

        # NOTE(review): the division truncates, so a length that is not a multiple of the
        # group size would not match the chunk sizes -- confirm callers guarantee divisibility.
        dims = list(grad.shape)
        dims[1] = dims[1] / world
        local_shape = tuple(int(dim) for dim in dims)
        output = torch.zeros(local_shape, dtype=grad.dtype, device=grad.device)

        root = _get_global_rank(cpu_group, 0)
        torch.distributed.scatter(output, scatter_list=chunks, src=root, group=cpu_group)
        return output


def _wrap_sp_split(forward_fn):  # deprecated
    """Wrap ``forward_fn`` so it runs on this rank's dim-1 slice and all-gathers the result.

    The wrapper slices ``hidden_states`` locally (no communication), calls ``forward_fn``
    on the slice and applies :class:`AllGatherSP` to the single returned element.
    """

    def new_forward_fn(hidden_states, *args, **kwargs):
        _, seq_len, _ = hidden_states.shape
        sp_group = DistGroups["sp"]
        rank = sp_group.rank()
        world = sp_group.size()
        assert seq_len % world == 0
        shard_len = seq_len // world
        shard = torch.split(hidden_states, shard_len, 1)[rank]
        result = forward_fn(shard, *args, **kwargs)
        # Only a single-element output is supported for now.
        assert len(result) == 1
        return (AllGatherSP.apply(result[0]),)

    return new_forward_fn


def wrap_sp_split_cpu(forward_fn):
    """Wrap ``forward_fn`` between :class:`ScatterSP` (input) and :class:`GatherSP` (output).

    ``forward_fn`` must return a single-element sequence; the wrapper returns a 1-tuple.
    """

    def new_forward_fn(hidden_states, *args, **kwargs):
        shard = ScatterSP.apply(hidden_states)
        result = forward_fn(shard, *args, **kwargs)
        # Only a single-element output is supported for now.
        assert len(result) == 1
        gathered = GatherSP.apply(result[0])
        return (gathered,)

    return new_forward_fn


def wrap_sp_split(forward_fn):
    """Wrap ``forward_fn`` between :class:`BroadcastSP` (input) and :class:`AllGatherSP` (output).

    ``forward_fn`` must return a single-element sequence; the wrapper returns a 1-tuple.
    """

    def new_forward_fn(hidden_states, *args, **kwargs):
        shard = BroadcastSP.apply(hidden_states)
        result = forward_fn(shard, *args, **kwargs)
        # Only a single-element output is supported for now.
        assert len(result) == 1
        gathered = AllGatherSP.apply(result[0])
        return (gathered,)

    return new_forward_fn
