              
                                                      
                       

from typing import List

import torch

from megatron.core.parallel_state import (
    get_context_parallel_group,
    get_context_parallel_world_size,
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)


def _all_reduce_of_cp(input_):
    """All-reduce ``input_`` in place over the context-parallel group.

    No reduction op is passed, so ``torch.distributed.all_reduce`` applies its
    default one. The collective writes its result into ``input_`` itself, and
    that same tensor object is returned.
    """
    cp_group = get_context_parallel_group()
    torch.distributed.all_reduce(input_, group=cp_group)
    return input_


class _AllReduceOfContextParallel(torch.autograd.Function):
    """Autograd function: all-reduce in forward, pass-through in backward."""

    @staticmethod
    def forward(ctx, input_):
        """Return ``input_`` after all-reducing it over the context-parallel group."""
        reduced = _all_reduce_of_cp(input_)
        return reduced

    @staticmethod
    def backward(ctx, grad_output):
        """Return the incoming gradient unchanged; no communication happens here."""
        return grad_output


def reduce_from_context_parallel_region(input_):
    """Apply the differentiable context-parallel all-reduce to ``input_``.

    Delegates to ``_AllReduceOfContextParallel``: the forward pass all-reduces
    over the context-parallel group, the backward pass forwards the gradient
    unchanged.
    """
    reduce_fn = _AllReduceOfContextParallel.apply
    return reduce_fn(input_)


class _AllGatherToContextParallelRegion(torch.autograd.Function):
    """All-gather a sequence-sharded tensor across the context-parallel group.

    Forward views the local sequence dimension as two halves, all-gathers the
    result from every rank, and concatenates the halves so that rank ``r``'s
    first half becomes chunk ``r`` and its second half becomes chunk
    ``2 * cp_size - r - 1`` of the output sequence.

    Backward undoes that chunk ordering, groups the gradient into one
    two-chunk piece per rank, and reduce-scatters the pieces with the
    reduction op given at forward time.

    Only ``seq_dim == 1`` is supported (asserted in both passes).
    """

    @staticmethod
    def forward(ctx, input_, seq_dim, bwd_op):
        """Gather and reorder the local shards along ``seq_dim``.

        Args:
            ctx: autograd context; receives the state needed by ``backward``.
            input_: local tensor; its size along ``seq_dim`` is halved by the
                view below, so it is presumably expected to be even —
                TODO confirm with callers.
            seq_dim: sequence dimension; must be 1.
            bwd_op: reduction op used by the reduce-scatter in ``backward``.

        Returns:
            The gathered tensor, ``2 * cp_size`` half-chunks concatenated
            along ``seq_dim``.
        """
        assert seq_dim == 1
        cp_size = get_context_parallel_world_size()
        # Split the sequence dimension into (2, seq // 2). ``input_`` is rebound
        # to this view, so everything below (including ``ctx.local_shape``)
        # sees the split shape, not the caller's original shape.
        input_ = input_.view(
            *input_.shape[0:seq_dim],
            2,
            input_.shape[seq_dim] // 2,
            *input_.shape[(seq_dim + 1):],
        )

        # One receive buffer per rank, each shaped like the split local view.
        gathered_logits = [torch.zeros_like(input_) for _ in range(cp_size)]
        torch.distributed.all_gather(gathered_logits, input_, group=get_context_parallel_group())

        # Place rank r's first half at slot r and its second half at the
        # mirrored slot 2 * cp_size - r - 1. Indexing with ``[:, 0]`` / ``[:, 1]``
        # selects a half only because seq_dim == 1.
        reorded_logits = [None for _ in range(2 * cp_size)]
        for rank in range(cp_size):
            reorded_logits[rank] = gathered_logits[rank][:, 0]
            reorded_logits[2 * cp_size - rank - 1] = gathered_logits[rank][:, 1]
        gathered_logits = torch.cat(reorded_logits, dim=seq_dim)

        # Stash everything backward needs; no tensors are saved.
        ctx.cp_size = cp_size
        ctx.seq_dim = seq_dim
        ctx.cp_group = get_context_parallel_group()
        ctx.local_shape = input_.shape  # shape *after* the (2, seq // 2) split
        ctx.bwd_op = bwd_op

        return gathered_logits

    @staticmethod
    def backward(ctx, grad_output):
        """Reduce-scatter the gathered gradient back to the local shard.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: gradient with respect to the gathered output.

        Returns:
            ``(local_grad, None, None)``: the gradient for ``input_`` with the
            two halves merged back into one sequence dimension, and ``None``
            for the non-tensor arguments ``seq_dim`` and ``bwd_op``.
        """
        seq_dim = ctx.seq_dim
        assert seq_dim == 1
        cp_group = ctx.cp_group
        cp_size = ctx.cp_size
        local_shape = ctx.local_shape
        bwd_op = ctx.bwd_op

        # View the sequence dimension as 2 * cp_size chunks, mirroring the
        # layout produced by forward.
        # NOTE(review): ``view`` requires a compatible memory layout; a
        # non-contiguous grad_output would raise here — confirm upstream ops.
        grad_output = grad_output.view(
            *grad_output.shape[0:seq_dim],
            2 * cp_size,
            grad_output.shape[seq_dim] // (2 * cp_size),
            *grad_output.shape[(seq_dim + 1):],
        )

        # Chunk order [0, 2cp-1, 1, 2cp-2, ...]: puts each rank's two chunks
        # next to each other, in rank order.
        reordered_indices = []
        for rank in range(cp_size):
            reordered_indices.append(rank)
            reordered_indices.append(2 * cp_size - rank - 1)
        grad_output = grad_output[:, reordered_indices, :]

        # 2 * cp_size chunks split into cp_size pieces of two chunks each;
        # piece r has the same shape as rank r's split local view.
        split_tensors = torch.split(grad_output, grad_output.size(1) // cp_size, dim=seq_dim)
        grad_list = [t.contiguous() for t in split_tensors]
        assert split_tensors[0].shape == local_shape

        # NOTE(review): the output is allocated on the current CUDA device
        # rather than on grad_output.device — assumes the two coincide.
        local_grad = torch.empty(local_shape,
                                 dtype=grad_output.dtype,
                                 device=torch.cuda.current_device())
        torch.distributed.reduce_scatter(local_grad, grad_list, op=bwd_op, group=cp_group)

        # Merge the (2, seq // 2) pair back into a single sequence dimension;
        # ``seq_dim + 2`` skips over both of those dimensions.
        local_grad = local_grad.view(
            *local_grad.shape[0:seq_dim],
            -1,
            *local_grad.shape[(seq_dim + 2):],
        )
        return local_grad, None, None


def _gather_along_any_dim(input_, gather_dim, comm_group=None):
    """All-gather ``input_`` across a group and concatenate along ``gather_dim``.

    Args:
        input_: local tensor to contribute to the gather.
        gather_dim: non-negative dimension along which the per-rank tensors
            are concatenated.
        comm_group: process group to communicate over. When ``None`` the
            tensor-model-parallel world size, rank and group are used.

    Returns:
        ``input_`` itself when the group has a single rank; otherwise a
        contiguous tensor holding every rank's tensor concatenated along
        ``gather_dim`` in rank order.
    """
    use_default_group = comm_group is None

    if use_default_group:
        world_size = get_tensor_model_parallel_world_size()
    else:
        world_size = torch.distributed.get_world_size(group=comm_group)

    # A single rank has nothing to gather.
    if world_size == 1:
        return input_

    assert 0 <= gather_dim < input_.dim(), "Invalid dimension to gather along."

    if use_default_group:
        rank = get_tensor_model_parallel_rank()
    else:
        rank = torch.distributed.get_rank(group=comm_group)

    # One receive buffer per rank; this rank's own slot is the input tensor.
    recv_buffers = [torch.empty_like(input_) for _ in range(world_size)]
    recv_buffers[rank] = input_

    if use_default_group:
        comm_group = get_tensor_model_parallel_group()
    torch.distributed.all_gather(recv_buffers, input_, group=comm_group)

    gathered = torch.cat(recv_buffers, dim=gather_dim)
    return gathered.contiguous()


def split_tensor_along_any_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
    any_dim=0,
) -> List[torch.Tensor]:
    """Split a tensor into equally sized chunks along ``any_dim``.

    Args:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into; must
            evenly divide the size of ``any_dim``.
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.
        any_dim: non-negative dimension to split along (default 0).

    Returns:
        A tuple of ``num_partitions`` tensors. Without
        ``contiguous_split_chunks`` these are the views returned by
        ``torch.split``.

    Raises:
        AssertionError: if ``any_dim`` is out of range or the dimension size
            is not divisible by ``num_partitions``.
    """
    assert any_dim < tensor.dim() and any_dim >= 0, "Invalid dimension"

    # The divisibility check is done inline: the previously used ``divide``
    # helper was never imported in this module, so every call raised NameError.
    dim_size = tensor.size()[any_dim]
    assert dim_size % num_partitions == 0, "{} is not divisible by {}".format(
        dim_size, num_partitions)
    chunk_size = dim_size // num_partitions

    chunks = torch.split(tensor, chunk_size, dim=any_dim)
    # torch.split returns views; copy them only when the caller asks for it.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in chunks)

    return chunks


def _split_along_any_dim(input_, split_dim, comm_group=None):
    """Keep only this rank's slice of ``input_`` along ``split_dim``.

    Args:
        input_: tensor to split.
        split_dim: non-negative dimension to split along.
        comm_group: process group that defines the number of slices and this
            rank's index. When ``None`` the tensor-model-parallel world size
            and rank are used.

    Returns:
        ``input_`` itself when the group has a single rank; otherwise a
        contiguous copy of the slice whose index equals this rank.
    """
    assert 0 <= split_dim < input_.dim()

    has_custom_group = comm_group is not None

    world_size = (torch.distributed.get_world_size(group=comm_group)
                  if has_custom_group else get_tensor_model_parallel_world_size())
    # Nothing to split on a single rank.
    if world_size == 1:
        return input_

    shards = split_tensor_along_any_dim(input_, world_size, any_dim=split_dim)

    rank = (torch.distributed.get_rank(group=comm_group)
            if has_custom_group else get_tensor_model_parallel_rank())
    local_shard = shards[rank]
    return local_shard.contiguous()


def _reduce_scatter_along_any_dim(input_, dim, comm_group=None):
    """Reduce-scatter ``input_`` across a group, sharding the result along ``dim``.

    Args:
        input_: full-size tensor held by every rank.
        dim: non-negative dimension along which the reduced result is sharded;
            its size must be divisible by the group size.
        comm_group: process group to communicate over. When ``None`` the
            tensor-model-parallel world size and group are used.

    Returns:
        ``input_`` itself when the group has a single rank; otherwise a new
        tensor shaped like ``input_`` with ``dim`` divided by the group size,
        holding this rank's slice (along ``dim``) of the reduced tensor.

    Raises:
        AssertionError: if ``dim`` is out of range or not divisible by the
            group size.
    """
    assert dim < input_.dim() and dim >= 0, "Invalid dimension to reduce-scatter along."
    if comm_group is not None:
        world_size = torch.distributed.get_world_size(group=comm_group)
    else:
        world_size = get_tensor_model_parallel_world_size()
    # Nothing to reduce or scatter on a single rank.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    assert dim_size[dim] % world_size == 0, (
        "Dimension {} of the tensor (size {}) should be divisible by the "
        "parallel group size {}".format(dim, dim_size[dim], world_size))
    dim_size[dim] = dim_size[dim] // world_size

    # reduce_scatter_tensor hands rank r the r-th block of the input along the
    # primary (0th) dimension. For dim != 0 the per-rank slices along ``dim``
    # must therefore be laid out back-to-back along dim 0 first; otherwise each
    # rank receives the wrong elements whenever a leading dimension is > 1.
    if dim == 0:
        send_buffer = input_.contiguous()
    else:
        send_buffer = torch.cat(torch.chunk(input_, world_size, dim=dim), dim=0)

    # Allocate the result next to the input rather than on whichever CUDA
    # device happens to be current.
    output = torch.empty(dim_size, dtype=input_.dtype, device=input_.device)
    if comm_group is None:
        comm_group = get_tensor_model_parallel_group()
    torch.distributed.reduce_scatter_tensor(output, send_buffer, group=comm_group)
    return output


class _GatherFromAnyDim(torch.autograd.Function):
    """Differentiable all-gather along an arbitrary dimension.

    Forward concatenates every rank's tensor along ``gather_dim``. Backward
    either reduce-scatters the gradient along that dimension (when
    ``tensor_parallel_output_grad`` is true) or just keeps this rank's slice
    of it.
    """

    @staticmethod
    def symbolic(graph, input_, gather_dim=1, tensor_parallel_output_grad=True, comm_group=None):
        """Symbolic counterpart of ``forward``; performs the same gather."""
        gathered = _gather_along_any_dim(input_, gather_dim=gather_dim, comm_group=comm_group)
        return gathered

    @staticmethod
    def forward(ctx, input_, gather_dim=1, tensor_parallel_output_grad=True, comm_group=None):
        """Gather ``input_`` along ``gather_dim`` and record the settings for backward."""
        ctx.gather_dim = gather_dim
        ctx.comm_group = comm_group
        ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
        gathered = _gather_along_any_dim(input_, gather_dim=gather_dim, comm_group=comm_group)
        return gathered

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient for ``input_`` and ``None`` for the three other arguments."""
        dim = ctx.gather_dim
        group = ctx.comm_group

        if ctx.tensor_parallel_output_grad:
            # Reduce the gradient across ranks and keep this rank's shard.
            grad_input = _reduce_scatter_along_any_dim(grad_output, dim=dim, comm_group=group)
        else:
            # No reduction: just take this rank's slice of the gradient.
            grad_input = _split_along_any_dim(grad_output, split_dim=dim, comm_group=group)

        return grad_input, None, None, None


def all_gather_to_context_parallel_region(local_tensor,
                                          gather_dim=1,
                                          bwd_op=torch.distributed.ReduceOp.AVG):
    """Differentiably all-gather ``local_tensor`` over the context-parallel group.

    Args:
        local_tensor: this rank's shard.
        gather_dim: dimension to gather along; the underlying autograd
            function asserts that it equals 1.
        bwd_op: reduction op applied by the reduce-scatter in the backward
            pass (defaults to ``ReduceOp.AVG``).

    Returns:
        The result of ``_AllGatherToContextParallelRegion.apply``.
    """
    gather_fn = _AllGatherToContextParallelRegion.apply
    return gather_fn(local_tensor, gather_dim, bwd_op)


def gather_from_any_dim(input_, gather_dim=1, tensor_parallel_output_grad=True, comm_group=None):
    """Differentiably all-gather ``input_`` along ``gather_dim``.

    Args:
        input_: this rank's tensor.
        gather_dim: dimension along which the per-rank tensors are concatenated.
        tensor_parallel_output_grad: if true the backward pass reduce-scatters
            the gradient; otherwise it keeps only this rank's slice.
        comm_group: process group to use; ``None`` selects the
            tensor-model-parallel group.

    Returns:
        The result of ``_GatherFromAnyDim.apply``.
    """
    gather_fn = _GatherFromAnyDim.apply
    return gather_fn(input_, gather_dim, tensor_parallel_output_grad, comm_group)
