# SPDX-License-Identifier: Apache-2.0
"""
This file contains a new class `KVLookupBufferBase` that allows developers to
think of KV cache operations as inserting new KV cache entries (`insert`)
into the lookup buffer and querying existing KV caches (`drop_select`)
from the lookup buffer.

This file also contains a new class `KVStoreBufferBase` that allows developers
to manage the KVCache buffer as a simple key-value storage buffer with basic
put/get operations.

Both classes above derive from the common abstract base class
`KVCacheBufferBase`.
"""

from abc import ABC, abstractmethod
from typing import List, Optional

import torch


class KVCacheBufferBase(ABC):
    """Common abstract root for every kind of KVCache buffer.

    Subclasses must, at minimum, provide a way to shut the buffer down
    once it is no longer needed.
    """

    @abstractmethod
    def close(self) -> None:
        """Shut the buffer down and free whatever resources it holds.

        Concrete subclasses override this to tear down anything tied to the
        KVCache buffer when it is no longer in use.

        Raises:
            NotImplementedError: Whenever this base implementation is
                reached; subclasses are expected to override it.
        """
        raise NotImplementedError()


class KVLookupBufferBase(KVCacheBufferBase):
    """
    Abstract base class for a KVCache lookup buffer.

    This class provides an abstraction for a key-value (KV) cache lookup buffer.

    The key of the lookup buffer:
    - input_tokens: token IDs of the request
    - roi: a binary mask on top of input_tokens.
      - Purpose of roi: Since KV cache may only be available for a subset of
        tokens in the input (for example, when vLLM is connected to an external
        KV cache service), roi specifies the subset of tokens that the KV cache
        is associated with.
      - NOTE: roi can be further extended to describe which part of KV the
        current process is holding (each process may only hold a part of KV
        due to TP and PP). This is not implemented for now.

    The value of the lookup buffer:
    - key: the key tensor in the KV cache
    - value: the value tensor in the KV cache
    - hidden: the final hidden state generated by model forwarding. This allows
      vLLM to bypass further model forwarding by transmitting the hidden state.
    """

    @abstractmethod
    def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
               key: torch.Tensor, value: torch.Tensor,
               hidden: torch.Tensor) -> None:
        """Insert an entry into the lookup buffer.

        The functionality is similar to the following Python statement:
        ```
        buffer[input_tokens, roi] = [key, value, hidden]
        ```

        FIXME: in the future, we should only have two arguments, key and value,
        where key is a tensor dict and value is a tensor dict.

        FIXME: we should transmit both sampler outputs and the hidden states.

        Args:
            input_tokens (torch.Tensor): Token IDs.
            roi (torch.Tensor): A binary mask on top of the input tokens.
            key (torch.Tensor): The key tensor in the KV cache.
            value (torch.Tensor): The value tensor in the KV cache.
            hidden (torch.Tensor): The final hidden state tensor generated
                during model forwarding, transmitted so that the receiver can
                bypass further model forwarding.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def drop_select(
            self, input_tokens: Optional[torch.Tensor],
            roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
        """Select and *drop* KV cache entries from the lookup buffer.

        The functionality is similar to the following Python statements:
        ```
        ret = buffer.pop(input_tokens, roi)
        return ret
        ```

        If `input_tokens` and `roi` are `None`, any of the KV caches in the
        buffer may be selected, returned, and removed from the buffer. This is
        useful when offloading KV cache to a KV cache storage service.

        Args:
            input_tokens (Optional[torch.Tensor]): Token IDs, or `None` to
                select an arbitrary entry.
            roi (Optional[torch.Tensor]): A binary mask on top of the input
                tokens, or `None` to select an arbitrary entry.

        Returns:
            List[Optional[torch.Tensor]]: A list of tensors; individual
                entries of the list may be `None`.

        Raises:
            NotImplementedError: This method must be implemented in subclasses.
        """
        raise NotImplementedError


class KVStoreBufferBase(KVCacheBufferBase):
    """Abstract KVCache buffer that exposes plain key-value storage semantics.

    Only two operations are offered, `put` and `get`, which gives callers
    fine-grained control over how individual KVCache tensors are transferred.
    Conceptually the buffer behaves like a distributed key-value store in
    which:

    - Key: a unique string that names the cached entry.
    - Value: either a tensor to be stored and retrieved, or None (meaning
      the entry was deleted or holds no value).
    """

    @abstractmethod
    def put(self, key: str, value: Optional[torch.Tensor]) -> None:
        """Write `value` into the buffer under `key`.

        Args:
            key (str): Unique name of the tensor being stored. The tensor may
                be a key-cache tensor, a value-cache tensor, or a hidden-state
                tensor produced during model forwarding.
            value (Optional[torch.Tensor]): The tensor to store, or None.

        Raises:
            NotImplementedError: Whenever this base implementation is
                reached; subclasses are expected to override it.
        """
        raise NotImplementedError()

    @abstractmethod
    def get(self, key: str) -> Optional[torch.Tensor]:
        """Look up the tensor stored under `key`.

        Args:
            key (str): Unique name of the tensor to fetch. The tensor may be
                a key-cache tensor, a value-cache tensor, or a hidden-state
                tensor produced during model forwarding.

        Returns:
            Optional[torch.Tensor]: The stored tensor when one exists,
                otherwise None.

        Raises:
            NotImplementedError: Whenever this base implementation is
                reached; subclasses are expected to override it.
        """
        raise NotImplementedError()
