# SPDX-License-Identifier: Apache-2.0
"""
    This module implements a PyNccl pipe for sending and receiving
    Optional[torch.Tensor] between distributed ranks with advanced
    communication features.

    Key Features:
    - Supports sending and receiving tensors with metadata
    - Handles both CUDA and CPU device communications
    - Sends tensors asynchronously on a background thread (receives block
      until the tensor arrives)
    - Manages buffer size and provides backpressure control
    - Supports distributed process groups with configurable parameters
"""

import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Optional, Tuple

import torch

from vllm.config import KVTransferConfig
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger

logger = init_logger(__name__)


class BrokenPipeException(Exception):
    """
    Exception that also exposes its text as a ``message`` attribute.

    NOTE(review): nothing in this module raises it; presumably intended for
    signalling an unusable pipe to callers elsewhere — confirm before use.
    """

    def __init__(self, message):
        super().__init__(message)
        # Keep the text reachable as an attribute as well as via args/str().
        self.message = message


# Metadata sent ahead of every payload, with keys "dtype" and "shape"
# (both None when the sender sends None instead of a tensor).
# NOTE(review): PyNcclPipe._make_metadata stores a torch.dtype and a
# torch.Size here, not tensors, so this value annotation looks inaccurate;
# something like Dict[str, Any] would describe it better — confirm no
# external code relies on the current alias before changing it.
Metadata = Dict[str, Optional[torch.Tensor]]


class PyNcclPipe(KVPipeBase):
    """
    Ring-style pipe that transfers ``Optional[torch.Tensor]`` values between
    KV-transfer ranks.

    Each rank sends to ``(kv_rank + 1) % kv_parallel_size`` and receives
    from ``(kv_rank - 1) % kv_parallel_size``. The metadata (dtype and
    shape) always travels over the ``StatelessProcessGroup`` object channel.
    The payload travels over PyNccl when the pipe device is CUDA, and over
    the same object channel otherwise.

    ``send_tensor`` returns once the transfer is queued on a single-worker
    thread pool. The bytes queued but not yet sent are tracked in
    ``buffer_size``, and ``send_tensor`` waits while that exceeds
    ``kv_buffer_size``. ``recv_tensor`` is queued on the same worker and
    blocks until the tensor arrives. Sends and receives therefore execute
    one at a time, in submission order.
    """

    # NOTE(review): these constants are not referenced by the code in this
    # class; kept for compatibility with possible external users.
    METADATA_LENGTH = 16
    MAX_TENSOR_DIMENSIONS = 14
    METADATA_DTYPE = torch.int64

    def __init__(self,
                 local_rank: int,
                 config: KVTransferConfig,
                 device: Optional[str] = None,
                 port_offset: int = 0):
        """
        Connect to the peer ranks and select the send/recv implementation.

        Parameters:
            - local_rank: Local device index. Used as the CUDA ordinal when
              the pipe runs on CUDA, and passed to PyNcclCommunicator.
            - config: KV-transfer configuration. Provides ``kv_rank``,
              ``kv_parallel_size``, ``kv_ip``, ``kv_port``,
              ``kv_buffer_device``, ``kv_buffer_size`` and the optional
              extra-config entry ``store_timeout``.
            - device: Overrides ``config.kv_buffer_device`` when given.
            - port_offset: Added to ``config.kv_port``. Presumably this lets
              several pipes share one config on distinct ports.
        """
        self.config = config
        self.local_rank = local_rank
        self.kv_rank = self.config.kv_rank
        self.kv_parallel_size = self.config.kv_parallel_size
        if device is None:
            self.device = self._select_device(self.config.kv_buffer_device)
        else:
            self.device = self._select_device(device)

        # build distributed connection and send/recv implementation
        store_timeout = self.config.get_from_extra_config("store_timeout", 300)
        self.group = StatelessProcessGroup.create(
            host=self.config.kv_ip,
            port=self.config.kv_port + port_offset,
            rank=self.kv_rank,
            world_size=self.kv_parallel_size,
            store_timeout=store_timeout,
        )
        # add a barrier to make sure the connection is initiated properly
        self.group.barrier()
        impl = self._get_device_send_recv_impl(self.group)
        self.device_send_func, self.device_recv_func = impl
        # Ring topology: send to the next rank, receive from the previous.
        self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
        self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size

        # transportation-related variables
        # The executor is created lazily on the first send/recv.
        self.transport_thread: Optional[ThreadPoolExecutor] = None
        # Bytes reserved by send_tensor and not yet released by
        # send_tensor_wrapper.
        self.buffer_size = 0
        self.buffer_size_lock = threading.Lock()
        self.buffer_size_thresh = self.config.kv_buffer_size

    def _get_device_send_recv_impl(
        self, group: StatelessProcessGroup
    ) -> Tuple[Callable[[torch.Tensor, int], None], Callable[
        [torch.Tensor, int], None]]:
        """
        Return the ``(send, recv)`` pair used for tensor payloads.

        On CUDA this is PyNccl send/recv. Otherwise ``send`` is
        ``group.send_obj`` and ``recv`` writes the received object into the
        preallocated buffer in place.

        Parameters:
            - group: The connected StatelessProcessGroup.

        Returns:
            - A ``(send, recv)`` tuple. Each callable takes a tensor and a
              peer rank.
        """
        send: Callable[[torch.Tensor, int], None]
        recv: Callable[[torch.Tensor, int], None]
        if self.device.type == "cuda":
            # use PyNCCL for send / recv
            comm = PyNcclCommunicator(group, device=self.local_rank)
            # NOTE(review): assumes the communicator may start disabled and
            # must be explicitly enabled for send/recv to run — confirm.
            comm.disabled = False
            send, recv = comm.send, comm.recv  # type: ignore
        else:
            # This send / recv implementation here is NOT intended to transfer
            # KV caches (and should NOT be repurposed to transfer KV caches).
            # Currently it is only used to transmit control-plane messages
            # for PyNcclBuffer.
            send = group.send_obj

            def my_recv(x, src):
                x[...] = group.recv_obj(src)

            recv = my_recv

        return send, recv

    def _select_device(self, device: str):
        """
        Map a device name to a ``torch.device``.

        Only the exact string ``"cuda"`` selects ``cuda:<local_rank>``.
        Any other value, including e.g. ``"cuda:1"``, selects CPU.
        """
        logger.info("Selecting device: %s", device)
        if device == "cuda":
            return torch.device(f"cuda:{self.local_rank}")
        else:
            return torch.device("cpu")

    def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata:
        """
        Create the metadata as a dictionary based on the input tensor.

        Parameters:
            - tensor: The input tensor or None if no tensor is provided.

        Returns:
            - metadata: A dictionary with the following keys:
                - "dtype": The data type of the tensor or None.
                - "shape": The shape of the tensor or None.
        """
        if tensor is None:
            return {"dtype": None, "shape": None}
        else:
            return {"dtype": tensor.dtype, "shape": tensor.shape}

    def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
        """
        Create a buffer to receive the tensor based on the provided metadata.

        Parameters:
            - metadata: A dictionary with keys "dtype" and "shape", describing
              the tensor's data type and shape.

        Returns:
            - buffer: An uninitialized tensor (``torch.empty``) of the given
              dtype and shape, allocated on self.device.
        """
        return torch.empty(metadata["shape"],
                           dtype=metadata["dtype"],
                           device=self.device)

    def _send_metadata(self, metadata: Metadata):
        """
        Send the metadata dictionary to the target rank.

        Parameters:
            - metadata: A dictionary with keys "dtype" and "shape".
        """
        self.group.send_obj(metadata, self.target_rank_for_send)

    def _recv_metadata(self) -> Metadata:
        """
        Receive the metadata dictionary from the target rank.

        Returns:
            - metadata: A dictionary with keys "dtype" and "shape" describing
              the tensor.
        """
        return self.group.recv_obj(self.target_rank_for_recv)

    def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
        """
        Send the metadata and, if present, the tensor to the target rank.

        The payload is moved to ``self.device`` first, because the selected
        send implementation operates on that device.

        Parameters:
            - tensor: The input tensor to be sent, or None if no tensor is
              being sent. For None, only the metadata is sent.
        """
        metadata = self._make_metadata(tensor)
        self._send_metadata(metadata)
        if tensor is not None:
            self.device_send_func(tensor.to(self.device),
                                  self.target_rank_for_send)

    def _recv_impl(self) -> Optional[torch.Tensor]:
        """
        Receive the metadata, then the tensor it describes, from the
        source rank.

        Returns:
            - buffer: The received tensor, or None if the sender sent None.
        """
        metadata = self._recv_metadata()
        if metadata["dtype"] is None:
            return None
        buffer = self._prepare_recv_buffer(metadata)
        self.device_recv_func(buffer, self.target_rank_for_recv)

        return buffer

    def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
                            tensor_size: int) -> None:
        """
        Run ``_send_impl`` on the transport thread and release the buffer
        reservation made by ``send_tensor``.

        Errors are logged rather than raised, because ``send_tensor`` never
        waits on the submitted future. The reservation is released in
        ``finally``. Releasing it only on success would let a failed send
        permanently inflate ``buffer_size``, and ``block_if_full`` could
        then wait forever.

        Parameters:
            - tensor: The tensor to send, or None.
            - tensor_size: Number of bytes ``send_tensor`` reserved for this
              tensor in ``buffer_size``.
        """
        try:
            self._send_impl(tensor)
        except Exception as e:
            # Report this pipe's kv_rank. torch.distributed.get_rank() may
            # itself raise when no default process group exists (this pipe
            # only creates a StatelessProcessGroup), masking the original
            # error; it would also report a different rank space.
            # Log only dtype/shape: str() of a large device tensor copies
            # its contents to host just to build the message.
            logger.exception(
                "[rank%d]: Exception when trying to send %s, msg: %s",
                self.kv_rank, self._make_metadata(tensor), str(e))
        finally:
            with self.buffer_size_lock:
                self.buffer_size -= tensor_size

    def block_if_full(self):
        """
        Block the current thread while the reserved buffer size is larger
        than the threshold.

        Polls every 50 ms. Space is freed only as queued sends finish in
        ``send_tensor_wrapper``.
        """
        while self.buffer_size > self.buffer_size_thresh:
            logger.debug("KV cache transfer pipe is full. Waiting...")
            time.sleep(0.05)

    def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
        """
        Sends a tensor and its metadata to the destination rank in a
        non-blocking way.

        May block first, in ``block_if_full``, while too many bytes are
        still queued.

        Parameters:
            - tensor: The tensor to send, or None if no tensor is being sent.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        if tensor is not None:
            tensor_size = tensor.element_size() * tensor.numel()
        else:
            tensor_size = 0

        self.block_if_full()

        # Reserve the bytes now; send_tensor_wrapper releases them.
        with self.buffer_size_lock:
            self.buffer_size += tensor_size

        self.transport_thread.submit(self.send_tensor_wrapper, tensor,
                                     tensor_size)

    def recv_tensor(self) -> Optional[torch.Tensor]:
        """
        Receives a tensor and its metadata from the source rank. Blocking call.

        The receive runs on the same single-worker executor as sends, so it
        starts only after all previously queued sends have finished.

        Returns:
            - tensor: The received tensor, or None if no tensor is received.

        Raises:
            - Any exception raised while receiving, after logging it with
              its traceback.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        future = self.transport_thread.submit(self._recv_impl)

        try:
            return future.result()
        except Exception as e:
            # logger.exception records the traceback through the logging
            # system instead of writing it straight to stderr.
            logger.exception("Encountering exception in KV receiving thread")
            logger.error("%s", e)
            logger.error("My device: %s", self.device)
            raise

    def close(self):
        """
        Close the pipe and release associated resources.

        The ``hasattr`` check lets this run even when ``__init__`` failed
        before ``transport_thread`` was assigned.
        """
        if hasattr(self,
                   "transport_thread") and self.transport_thread is not None:
            self.transport_thread.shutdown()
