from dataclasses import dataclass
from typing import Any, Dict, DefaultDict, List, Optional, Tuple
from collections import defaultdict
from pathlib import Path

import os
import torch
import faiss
import einops

from .configuration_utils import PretrainedConfig
from .utils import logging


# Module-level logger; used by `Cache.seen_tokens` to emit its deprecation warning.
logger = logging.get_logger(__name__)


@dataclass
class Cache:
    """
    Abstract base for every cache implementation in this module.

    Subclasses decide how per-layer key/value states are stored. This base only fixes the public interface and
    derives `get_usable_length` from the subclass-provided `get_seq_length` and `get_max_length`.
    """

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Store `key_states` / `value_states` for layer `layer_idx` and hand back the resulting cached states.

        Parameters:
            key_states (`torch.Tensor`):
                Key states produced by the current forward pass.
            value_states (`torch.Tensor`):
                Value states produced by the current forward pass.
            layer_idx (`int`):
                Index of the layer whose cache is being updated.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Extra, subclass-specific arguments; they let new cache types receive whatever they need.

        Return:
            A `(keys, values)` tuple with the updated states.

        Raises:
            NotImplementedError: always, on this base class.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Number of cached positions for layer `layer_idx`. Subclasses must provide it."""
        raise NotImplementedError(
            "Make sure to implement `get_seq_length` in a subclass."
        )

    def get_max_length(self) -> Optional[int]:
        """Upper bound on the cached sequence length, or `None` when unbounded. Subclasses must provide it."""
        raise NotImplementedError(
            "Make sure to implement `get_max_length` in a subclass."
        )

    def get_usable_length(
        self, new_seq_length: int, layer_idx: Optional[int] = 0
    ) -> int:
        """
        How many already-cached positions remain usable once `new_seq_length` new tokens are appended.

        An unbounded cache keeps everything. A bounded cache that would overflow has to evict, leaving only
        `max_length - new_seq_length` usable positions.
        """
        limit = self.get_max_length()
        cached = self.get_seq_length(layer_idx)
        overflows = limit is not None and cached + new_seq_length > limit
        return limit - new_seq_length if overflows else cached

    @property
    def seen_tokens(self):
        """Deprecated: the `_seen_tokens` counter of the subclass, or `None` when it does not keep one."""
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        return getattr(self, "_seen_tokens", None)


class DynamicCache(Cache):
    """
    Cache whose per-layer tensors grow along the sequence axis on every `update`; the default for generative models.

    Keys and values are kept in two parallel lists holding one tensor per layer. Each tensor is expected to have
    shape `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        # Running count of tokens that went through layer 0; `generate` reads it to track progress.
        self._seen_tokens = 0

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Backwards-compatible indexing into the legacy format.

        `past_key_value[i]` is the `(keys, values)` pair of layer `i`, so `past_key_value[0][0].shape[2]` is the
        cached sequence length.
        """
        if layer_idx >= len(self):
            raise KeyError(
                f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}"
            )
        return (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __iter__(self):
        """Backwards-compatible iteration: yields one `(keys, values)` pair per layer, in layer order."""
        yield from (
            (self.key_cache[idx], self.value_cache[idx]) for idx in range(len(self))
        )

    def __len__(self):
        """Number of layers stored so far, which matches the legacy `len(past_key_value)`."""
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Append `key_states` / `value_states` to layer `layer_idx` along the sequence axis (dim -2).

        Parameters:
            key_states (`torch.Tensor`):
                Key states produced by the current forward pass.
            value_states (`torch.Tensor`):
                Value states produced by the current forward pass.
            layer_idx (`int`):
                Layer to update. A layer seen for the first time is appended to the lists.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Ignored by `DynamicCache`.

        Return:
            The full `(keys, values)` tensors now cached for that layer.
        """
        # Count tokens once per forward pass by only looking at the first layer.
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if layer_idx >= len(self.key_cache):
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=-2
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=-2
            )

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Cached sequence length of `layer_idx` (dim -2 of its key tensor), or 0 if that layer is not cached yet."""
        if layer_idx >= len(self.key_cache):
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """A dynamic cache is unbounded, so there is no maximum length: always `None`."""
        return None

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Select batch rows by `beam_idx` (beam search), moving the indices to each tensor's device."""
        for layer_idx in range(len(self.key_cache)):
            for store in (self.key_cache, self.value_cache):
                tensor = store[layer_idx]
                store[layer_idx] = tensor.index_select(0, beam_idx.to(tensor.device))

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Legacy format: a tuple holding one `(keys, values)` tuple per layer."""
        return tuple(self)

    @classmethod
    def from_legacy_cache(
        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    ) -> "DynamicCache":
        """Build a `DynamicCache` from the legacy per-layer `(keys, values)` tuples; `None` gives an empty cache."""
        cache = cls()
        if past_key_values is None:
            return cache
        for layer_idx, (key_states, value_states) in enumerate(past_key_values):
            cache.update(key_states, value_states, layer_idx)
        return cache


class SinkCache(Cache):
    """
    A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Eviction starts when appending the new tokens would bring a layer to `window_length` or more. At that point the
    first `num_sink_tokens` positions are always kept, and the oldest of the remaining positions are dropped to make
    room for the new ones.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        # Memoized re-rotation (cos, sin) pairs, keyed only by the incoming token count (`key_states.shape[-2]`).
        self.cos_sin_cache = {}
        self._seen_tokens = (
            0  # Used in `generate` to keep tally of how many tokens the cache has seen
        )

    @staticmethod
    def _rotate_half(x):
        """Split the last dimension into halves `[x1, x2]` and return `[-x2, x1]` (the RoPE rotation helper)."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        """Rotate `key_states` by the rotary angles encoded in `cos` / `sin`."""
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return the `(cos, sin)` pair that moves the kept, non-sink keys `key_states.shape[-2]` positions earlier.

        The result is cast to `key_states.dtype`, gets a leading unit dimension, and is memoized in
        `self.cos_sin_cache` under the incoming token count.
        """
        if key_states.shape[-2] not in self.cos_sin_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_cache[key_states.shape[-2]] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_cache[key_states.shape[-2]]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states (the window length)."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted. When omitted, no re-rotation is applied.

        Return:
            A tuple containing the updated key and value states.
        """
        # `cache_kwargs` defaults to None; treat that as "no extra arguments" instead of failing on `.get`.
        if cache_kwargs is None:
            cache_kwargs = {}

        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # with partially rotated position embeddings, like Phi or Persimmon.
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
            # Empty cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Growing cache
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=-2
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=-2
            )

        else:
            # Shifting cache: keep the most recent positions that still fit next to the sinks and the new tokens.
            keys_to_keep = self.key_cache[layer_idx][
                :,
                :,
                -self.window_length + self.num_sink_tokens + key_states.shape[-2] :,
            ]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, cos[: self.window_length], sin[: self.window_length]
                )
                if partial_rotation_size is not None:
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(
                    keys_to_keep, rerotation_cos, rerotation_sin
                )
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat(
                [sink_keys, keys_to_keep, key_states], dim=-2
            )

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :,
                :,
                -self.window_length + self.num_sink_tokens + value_states.shape[-2] :,
            ]
            self.value_cache[layer_idx] = torch.cat(
                [sink_values, values_to_keep, value_states], dim=-2
            )

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
                0, beam_idx.to(device)
            )
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
                0, beam_idx.to(device)
            )


class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)`.

    Keys and values live in two tensors allocated once in `__init__`, with shape
    `[max_batch_size, num_key_value_heads, max_cache_len, head_dim]`. `update` writes into them in place.

    Parameters:
        config (`PretrainedConfig`):
            The configuration defining `max_position_embeddings`, `hidden_size` and `num_attention_heads` (and
            optionally `head_dim` / `num_key_value_heads`) required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used. When `None`,
            `config.max_position_embeddings` is used.
        device (`torch.device`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        max_batch_size: int,
        max_cache_len: int,
        device,
        dtype=None,
    ) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
        self.max_cache_len = (
            config.max_position_embeddings if max_cache_len is None else max_cache_len
        )
        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads.
        # A missing or `None` attribute falls back to the derived value.
        head_dim = getattr(config, "head_dim", None)
        self.head_dim = (
            head_dim
            if head_dim is not None
            else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype if dtype is not None else torch.float32
        # `num_key_value_heads` may be absent from the config or set to None; fall back to `num_attention_heads`.
        num_key_value_heads = getattr(config, "num_key_value_heads", None)
        self.num_key_value_heads = (
            config.num_attention_heads
            if num_key_value_heads is None
            else num_key_value_heads
        )

        cache_shape = (
            max_batch_size,
            self.num_key_value_heads,
            self.max_cache_len,
            self.head_dim,
        )
        self.key_cache: torch.Tensor = torch.zeros(
            cache_shape, dtype=self.dtype, device=device
        )
        self.value_cache: torch.Tensor = torch.zeros(
            cache_shape, dtype=self.dtype, device=device
        )

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for. Kept for backward compatibility
            cache_kwargs (`Dict[str, Any]`):
                Must contain `cache_position`, the slots along the sequence axis (dim 2) that the new states
                overwrite.

        Return:
            The whole pre-allocated `(key_cache, value_cache)` tensors after the in-place write.

        Raises:
            ValueError: if `cache_kwargs` is missing or has no `cache_position`.
        """
        new_cache_positions = (
            None if cache_kwargs is None else cache_kwargs.get("cache_position")
        )
        # Indexing with `None` would insert a new axis instead of selecting slots, so fail loudly here.
        if new_cache_positions is None:
            raise ValueError(
                "`StaticCache.update` requires `cache_kwargs['cache_position']` to know which cache slots to write."
            )
        k_out = self.key_cache
        v_out = self.value_cache

        k_out[:, :, new_cache_positions] = key_states
        v_out[:, :, new_cache_positions] = value_states

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC.

        The value is a 0-d tensor (the result of `.sum()`), not a Python int.
        """
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
        # https://github.com/pytorch/pytorch/issues/120248 is fixed
        return (self.key_cache[0, 0].any(dim=-1)).sum()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, i.e. the pre-allocated `max_cache_len`."""
        return self.max_cache_len

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        device = self.key_cache.device
        self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
        device = self.value_cache.device
        self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))

    def to_legacy_cache(self):
        """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
        return None


class DynamicFaissCache(Cache):
    """
    Structure of the cache:
    - Two fields, key_cache and value_cache.
        - Each field is a list, with one item per layer of the model
            - Each item is a tuple, with two elements.
                - The first element contains keys/values from the prefix (construct phase)
                - The second element contains keys/values from the suffix (query/generation phase)
                
    Example:
        self.key_cache:
            [
                (faiss.swigfaiss_avx2.IndexFlatIP, torch.Tensor),
                ...
                (faiss.swigfaiss_avx2.IndexFlatIP, torch.Tensor)
            ] 
        self.value_cache:
            [
                (torch.Tensor, torch.Tensor),
                ...
                (torch.Tensor, torch.Tensor)
            ] 
    """

    def __init__(self, flat: bool = True) -> None:
        """
        Create an empty cache.

        Args:
            flat (`bool`, defaults to `True`):
                Passed as `flat` to `create_key_database` when `update` builds a layer's prefix key database
                (construct mode without block mode).
        """
        # Per-layer (prefix, suffix) key pairs. The prefix is a list of FAISS indexes (one per batch*head) when built
        # by `create_key_database`, or a CPU tensor while keys are accumulated in block mode.
        self.key_cache: List[Tuple[Any, torch.Tensor]] = []
        # Per-layer (prefix, suffix) value tensor pairs.
        self.value_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
        # Per-layer dense keys/values accumulated in block construct mode until `update_dense_to_sparse`.
        self.dense_key_cache: List[torch.Tensor] = []
        self.dense_value_cache: List[torch.Tensor] = []
        # Total number of tokens passed to `update`, per layer.
        self.seq_lengths: DefaultDict[int, int] = defaultdict(int)
        # Whether block-mode `update` already appended the (prefix, suffix) placeholder entries for a layer.
        self.sparse_cache_initialized: DefaultDict[int, bool] = defaultdict(bool)
        self.flat = flat

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        The mode is selected by `cache_kwargs`:
            - `construct_mode` and `block_mode`: states are appended to the dense per-layer tensors. Returns
              `((sparse_prefix_keys, dense_keys), (sparse_prefix_values, dense_values))`.
            - `construct_mode` only: keys are added to the layer's FAISS prefix database and values to its CPU
              prefix tensor; the suffix is reset to an empty tensor.
            - otherwise (including `cache_kwargs=None`): states are appended to the layer's suffix tensors.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                May contain the boolean flags `construct_mode` and `block_mode`.

        Return:
            For the non-block modes, the layer's `(key_cache entry, value_cache entry)`, each a (prefix, suffix)
            tuple.
        """
        # `cache_kwargs` defaults to None; treat it as "no flags" (suffix update) rather than failing on `.get`.
        if cache_kwargs is None:
            cache_kwargs = {}
        self.seq_lengths[layer_idx] += key_states.shape[-2]
        construct_mode = cache_kwargs.get("construct_mode")
        block_mode = cache_kwargs.get("block_mode")

        if construct_mode and block_mode:
            return self._update_dense_block(key_states, value_states, layer_idx)
        if construct_mode:
            self._update_prefix_database(key_states, value_states, layer_idx)
        else:
            self._update_suffix(key_states, value_states, layer_idx)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def _update_dense_block(self, key_states, value_states, layer_idx):
        """Block construct mode: accumulate keys/values densely for `layer_idx` (moved later by `update_dense_to_sparse`)."""
        if len(self.dense_key_cache) <= layer_idx:
            if not self.sparse_cache_initialized[layer_idx]:
                # Empty CPU prefix placeholders plus empty suffixes on the incoming states' device.
                self.key_cache.append(
                    (
                        torch.empty(0, dtype=key_states.dtype, device="cpu"),
                        torch.empty(0, dtype=key_states.dtype, device=key_states.device),
                    )
                )
                self.value_cache.append(
                    (
                        torch.empty(0, dtype=value_states.dtype, device="cpu"),
                        torch.empty(0, dtype=value_states.dtype, device=value_states.device),
                    )
                )
                self.sparse_cache_initialized[layer_idx] = True
            self.dense_key_cache.append(key_states)
            self.dense_value_cache.append(value_states)
        else:
            self.dense_key_cache[layer_idx] = torch.cat(
                [self.dense_key_cache[layer_idx], key_states], dim=-2
            )
            self.dense_value_cache[layer_idx] = torch.cat(
                [self.dense_value_cache[layer_idx], value_states], dim=-2
            )
        return (
            (self.key_cache[layer_idx][0], self.dense_key_cache[layer_idx]),
            (self.value_cache[layer_idx][0], self.dense_value_cache[layer_idx]),
        )

    def _update_prefix_database(self, key_states, value_states, layer_idx):
        """Construct mode: add keys to the layer's FAISS prefix database, values to its CPU prefix; reset the suffix."""
        empty_key_suffix = torch.empty(0, dtype=key_states.dtype, device=key_states.device)
        empty_value_suffix = torch.empty(0, dtype=value_states.dtype, device=value_states.device)
        if len(self.key_cache) <= layer_idx:
            key_db = DynamicFaissCache.create_key_database(key_states=key_states, flat=self.flat)
            self.key_cache.append((key_db, empty_key_suffix))
            self.value_cache.append((value_states.cpu(), empty_value_suffix))
        else:
            prefix_key_db = DynamicFaissCache.update_key_database(key_states, self.key_cache[layer_idx][0])
            self.key_cache[layer_idx] = (prefix_key_db, empty_key_suffix)
            prefix_values = torch.cat((self.value_cache[layer_idx][0], value_states.cpu()), dim=-2)
            self.value_cache[layer_idx] = (prefix_values, empty_value_suffix)

    def _update_suffix(self, key_states, value_states, layer_idx):
        """Query/generation phase: append the states to the layer's suffix tensors, leaving the prefix untouched."""
        assert len(self.key_cache) > layer_idx, (
            f"No prefix cache for layer {layer_idx}; run a construct-mode update first."
        )
        prefix_keys, suffix_keys = self.key_cache[layer_idx]
        self.key_cache[layer_idx] = (prefix_keys, torch.cat([suffix_keys, key_states], dim=-2))
        prefix_values, suffix_values = self.value_cache[layer_idx]
        self.value_cache[layer_idx] = (prefix_values, torch.cat([suffix_values, value_states], dim=-2))

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Return the total number of tokens passed to `update` for `layer_idx`, or 0 for a layer never updated.

        Uses `dict.get` so that querying does not insert a new entry into the `seq_lengths` defaultdict.
        """
        return self.seq_lengths.get(layer_idx, 0)

    def get_max_length(self) -> Optional[int]:
        """This cache never evicts or truncates, so it has no maximum length and always reports `None`."""
        return None

    def to_legacy_cache(self):
        """No tuple-based legacy form exists for this cache; the instance itself is handed back unchanged."""
        legacy = self
        return legacy

    def update_dense_to_sparse(self):
        """
        Move the dense block-mode keys/values into each layer's CPU prefix and clear the dense caches.

        For every layer, the dense tensor is copied to CPU and concatenated (along dim -2) after the existing
        prefix. The suffix is replaced by an empty tensor on the dense tensor's own device. Expects
        `dense_key_cache` / `dense_value_cache` to have an entry for every layer of `key_cache` / `value_cache`,
        as block-mode `update` produces.
        """
        for layer_idx, (sparse_prefix_keys, _) in enumerate(self.key_cache):
            dense_keys = self.dense_key_cache[layer_idx]
            merged_keys = torch.cat((sparse_prefix_keys, dense_keys.cpu()), dim=-2)
            # Place the empty suffix next to the dense data instead of hard-coding `.cuda()`.
            empty_suffix = torch.empty(0, dtype=dense_keys.dtype, device=dense_keys.device)
            self.key_cache[layer_idx] = (merged_keys, empty_suffix)

        for layer_idx, (sparse_prefix_values, _) in enumerate(self.value_cache):
            dense_values = self.dense_value_cache[layer_idx]
            merged_values = torch.cat((sparse_prefix_values, dense_values.cpu()), dim=-2)
            empty_suffix = torch.empty(0, dtype=dense_values.dtype, device=dense_values.device)
            self.value_cache[layer_idx] = (merged_values, empty_suffix)

        self.dense_key_cache = []
        self.dense_value_cache = []

    def update_tensor_to_faiss_index(self):
        """
        Replace each layer's tensor key prefix with FAISS indexes built by `create_key_database`; suffixes are kept.

        Always builds flat indexes (`flat=True`), regardless of `self.flat`.
        """
        for idx in range(len(self.key_cache)):
            prefix_tensor, suffix_tensor = self.key_cache[idx]
            prefix_index = DynamicFaissCache.create_key_database(flat=True, key_states=prefix_tensor)
            self.key_cache[idx] = (prefix_index, suffix_tensor)

    @staticmethod
    def create_key_database(key_states=None, flat=True, BH=None, D=None, ivf_centers_num=None, nprobe=2):
        """Create one FAISS inner-product index per batch*head. Indexes are held in CPU memory.

        Args:
            key_states (torch.Tensor, optional): Keys shaped `(B, H, N, D)` or `(B*H, N, D)`. Each `(N, D)` slice is
                added, as float32 on CPU, to its own index. When None, `BH` empty indexes of dimension `D` are created.
            flat (bool): True builds `faiss.IndexFlatIP` indexes. False builds `faiss.IndexIVFFlat` indexes with the
                inner-product metric, trained on each slice of `key_states`.
            BH (int, optional): Number of indexes to create when `key_states` is None.
            D (int, optional): Vector dimension when `key_states` is None.
            ivf_centers_num (int, optional): Number of IVF lists; required when `flat` is False.
            nprobe (int): Number of IVF lists probed per search when `flat` is False (default 2).

        Returns:
            list: List of FAISS search indexes, one for each batch and head.

        Raises:
            ValueError: if `key_states` is None and `BH` or `D` is missing, or if `flat` is False without
                `key_states` (an IVF index needs training data) or without `ivf_centers_num`.
        """
        if key_states is None:
            if BH is None or D is None:
                raise ValueError("`BH` and `D` must be given when `key_states` is None")
        else:
            if len(key_states.shape) == 4:
                key_states = einops.rearrange(key_states, "B H N D -> (B H) N D")
            BH, _, D = key_states.shape

        if not flat:
            if key_states is None:
                raise ValueError("An IVF index (`flat=False`) needs `key_states` to train on")
            if ivf_centers_num is None:
                raise ValueError("`ivf_centers_num` must be given when `flat=False`")

        # TODO parallelize?
        key_database = []
        for i in range(BH):
            # FAISS consumes contiguous float32 vectors in host memory.
            vectors = None if key_states is None else key_states[i].contiguous().to(torch.float32).cpu()
            if flat:
                search_index = faiss.IndexFlatIP(D)
            else:
                quantizer = faiss.IndexFlatIP(D)
                search_index = faiss.IndexIVFFlat(quantizer, D, ivf_centers_num, faiss.METRIC_INNER_PRODUCT)
                search_index.train(vectors)
                search_index.nprobe = nprobe
            # Add each slice exactly once (the old IVF branch added it twice).
            if vectors is not None:
                search_index.add(vectors)
            key_database.append(search_index)

        return key_database

    @staticmethod
    def update_key_database(key_states, key_db):
        """Add key_states to a given per-(batch*head) FAISS key vector database, in place.

        Args:
            key_states (torch.Tensor): Keys shaped `(B, H, N, D)` or `(B*H, N, D)`.
            key_db (list of faiss indexes): Current key database, one index per batch*head.

        Returns:
            list: The same `key_db` list; its indexes have been extended in place.

        Raises:
            ValueError: if `key_states` is not 3-D/4-D, or its batch*head count differs from `len(key_db)`.
        """
        if len(key_states.shape) == 4:
            key_states = einops.rearrange(key_states, "B H N D -> (B H) N D")
        if len(key_states.shape) != 3:
            raise ValueError(f"Expected 3-D or 4-D key_states, got shape {tuple(key_states.shape)}")

        # A mismatch used to silently drop key slices (or fail with an IndexError).
        num_slices = key_states.shape[0]
        if num_slices != len(key_db):
            raise ValueError(
                f"key_states has {num_slices} batch*head slices but key_db holds {len(key_db)} indexes"
            )

        # TODO parallelize?
        for i, db in enumerate(key_db):
            db.add(key_states[i].contiguous().to(torch.float32).cpu())

        return key_db

    @staticmethod
    def save_cache(cache: 'DynamicFaissCache', directory: str) -> None:
        """
        Write the prefix part of `cache` under `directory`, creating directories as needed.

        Layout:
            <directory>/key_cache/layer_{i}_indices/head_{j}.index   one FAISS index per batch*head (faiss.write_index)
            <directory>/value_cache/layer_{i}_prefix.pt               the layer's prefix value tensor (torch.save)

        Suffix keys/values are not written. NOTE(review): each key prefix is handed to `faiss.write_index`, so it
        presumably must already be a list of FAISS indexes (e.g. after `update_tensor_to_faiss_index`) — confirm.
        """
        root = Path(directory)
        keys_dir = root / 'key_cache'
        values_dir = root / 'value_cache'
        for folder in (root, keys_dir, values_dir):
            folder.mkdir(parents=True, exist_ok=True)

        for layer_idx, (prefix_key_db, _suffix_keys) in enumerate(cache.key_cache):
            layer_dir = keys_dir / f"layer_{layer_idx}_indices"
            layer_dir.mkdir(parents=True, exist_ok=True)
            for head_idx, index in enumerate(prefix_key_db):
                index_file = layer_dir / f"head_{head_idx}.index"
                faiss.write_index(index, str(index_file.absolute()))

        for layer_idx, (prefix_values, _suffix_values) in enumerate(cache.value_cache):
            tensor_file = values_dir / f"layer_{layer_idx}_prefix.pt"
            torch.save(prefix_values, str(tensor_file.absolute()))
    
    @staticmethod
    def save_cache_tensor(cache: 'DynamicFaissCache', filepath: str) -> None:
        """
        Save the prefix keys and values of every layer as one stacked tensor at `filepath`.

        Keys are read back out of each layer's FAISS databases (one per entry of the layer's key list),
        converted to the dtype of the value tensors and moved to the values' device so that both can be
        stacked. Suffix keys/values are not saved. The saved tensor is
        `torch.stack([keys, values])`, so index 0 holds keys and index 1 holds values, each of shape
        (L, num_key_databases, N, D). `torch.stack` therefore requires every layer's value prefix to be
        shaped (num_key_databases, N, D) — presumably (heads, seq, head_dim); TODO confirm.

        Args:
            cache: cache whose `key_cache` holds `(key_databases, suffix_keys)` and whose `value_cache`
                holds `(prefix_values, suffix_values)` per layer.
            filepath: destination file. Its parent directory must already exist; a bare filename refers
                to the current working directory.

        Raises:
            ValueError: if `filepath` is not a string, or if a key database has a different vector
                dimension than the values or holds fewer than N vectors.
            FileNotFoundError: if the parent directory of `filepath` does not exist.
        """
        if not isinstance(filepath, str):
            raise ValueError("Filepath must be a string.")

        # os.path.dirname("cache.pt") is "" (the current directory) and os.path.exists("") is False,
        # so only validate when the path actually has a directory component.
        dir_name = os.path.dirname(filepath)
        if dir_name and not os.path.exists(dir_name):
            raise FileNotFoundError(f"Directory {dir_name} does not exist.")

        # Prefix values of every layer, stacked along a new leading layer axis.
        v_cache_tensor = torch.stack([prefix_values for prefix_values, _ in cache.value_cache])
        N, D = v_cache_tensor.shape[-2], v_cache_tensor.shape[-1]
        dtype = v_cache_tensor.dtype

        def read_keys(db) -> torch.Tensor:
            # Copy the first N stored vectors out of the index as an (N, D) tensor.
            # rev_swig_ptr wraps the raw storage pointer with whatever length it is given, so check the
            # index really holds that many D-dimensional vectors before reading (otherwise the read runs
            # past the end of the index's storage). Assumes flat indexes exposing get_xb() — TODO confirm.
            if db.d != D or db.ntotal < N:
                raise ValueError(
                    f"Key database holds {db.ntotal} vectors of dimension {db.d}, but the value cache "
                    f"needs at least {N} vectors of dimension {D}."
                )
            flat_keys = faiss.rev_swig_ptr(db.get_xb(), N * D)
            # torch.tensor copies, so the result does not alias the index's memory.
            return torch.tensor(flat_keys).reshape((N, D)).to(dtype)

        k_cache_tensor = torch.stack(
            [torch.stack([read_keys(db) for db in key_db]) for key_db, _ in cache.key_cache]
        )
        # Keys are rebuilt on CPU; values may live elsewhere, and torch.stack needs a single device.
        k_cache_tensor = k_cache_tensor.to(v_cache_tensor.device)

        kv_cache_tensor = torch.stack([k_cache_tensor, v_cache_tensor])
        torch.save(kv_cache_tensor, filepath)

    @staticmethod
    def load_cache(directory: str, num_layers: int, bh_size: int, device="cuda", dtype=torch.bfloat16) -> 'DynamicFaissCache':
        """
        Rebuild a DynamicFaissCache from the directory layout written by `save_cache`.

        Args:
            directory: root directory containing `key_cache/` and `value_cache/`.
            num_layers: number of layers to read.
            bh_size: number of FAISS index files to read per layer (`head_0.index` ... `head_{bh_size-1}.index`).
            device: device for the empty suffix tensors.
            dtype: dtype for the empty suffix tensors.

        Returns:
            A new DynamicFaissCache whose per-layer key entries are `(list_of_indexes, empty_suffix)`, whose
            value entries are `(loaded_prefix_values, empty_suffix)`, and whose sequence length per layer is
            the `ntotal` of that layer's first index. The loaded value tensors are not moved to `device`
            or cast to `dtype`.
        """
        root = Path(directory)
        key_dir = root / 'key_cache'
        value_dir = root / 'value_cache'

        key_cache = []
        seq_lengths = defaultdict(int)
        for layer in range(num_layers):
            layer_dir = key_dir / f"layer_{layer}_indices"
            indexes = [
                faiss.read_index(str((layer_dir / f"head_{head}.index").absolute()))
                for head in range(bh_size)
            ]
            key_cache.append((indexes, torch.empty(0, device=device, dtype=dtype)))
            seq_lengths[layer] = indexes[0].ntotal

        # NOTE: torch.load unpickles the file; only load caches from trusted locations.
        value_cache = [
            (
                torch.load(str((value_dir / f"layer_{layer}_prefix.pt").absolute())),
                torch.empty(0, device=device, dtype=dtype),
            )
            for layer in range(num_layers)
        ]

        cache = DynamicFaissCache()
        cache.key_cache = key_cache
        cache.value_cache = value_cache
        cache.seq_lengths = seq_lengths
        return cache

    @staticmethod
    def load_cache_tensor(
        filepath: Optional[str] = None,
        tensor: Optional[torch.Tensor] = None,
        flat: bool = True,
        device="cuda",
    ) -> 'DynamicFaissCache':
        """
        Build a DynamicFaissCache from a stacked KV tensor such as the one written by `save_cache_tensor`.

        The tensor is indexed as `kv[kind, layer, :, :, :]`: `kind` 0 holds keys and 1 holds values, and the
        sequence length is read from dim -2.

        Args:
            filepath: file to `torch.load` the tensor from. Takes precedence over `tensor` when given.
            tensor: in-memory KV tensor, used when `filepath` is None.
            flat: forwarded to `create_key_database` and stored on the cache as `cache.flat`.
            device: device for the empty suffix key/value tensors (defaults to "cuda", as before).

        Returns:
            The new cache. Each layer's keys are indexed with `create_key_database`, each layer's values are
            kept as a slice of the loaded tensor, and `sparse_cache_initialized` is set for every layer.

        Raises:
            ValueError: if neither `filepath` nor `tensor` is provided.
        """
        if filepath is None:
            kv_cache_tensor = tensor
        else:
            # NOTE: torch.load unpickles the file; only load caches from trusted sources.
            kv_cache_tensor = torch.load(filepath)

        # An explicit check rather than `assert`, which is stripped under `python -O`.
        if kv_cache_tensor is None:
            raise ValueError("Either `filepath` or `tensor` must be provided.")

        cache = DynamicFaissCache()
        cache.flat = flat

        num_layers = kv_cache_tensor.shape[1]
        seq_len = kv_cache_tensor.shape[-2]
        dtype = kv_cache_tensor.dtype
        for layer_idx in range(num_layers):
            key_db = DynamicFaissCache.create_key_database(
                flat=flat, key_states=kv_cache_tensor[0, layer_idx, :, :, :]
            )
            cache.key_cache.append((key_db, torch.empty(0, device=device, dtype=dtype)))

            cache.value_cache.append(
                (kv_cache_tensor[1, layer_idx, :, :, :], torch.empty(0, device=device, dtype=dtype))
            )
            cache.seq_lengths[layer_idx] = seq_len
            cache.sparse_cache_initialized[layer_idx] = True

        return cache


    @classmethod
    def from_dynamic_cache(cls, dynamic_cache: DynamicCache):
        """
        Convert a DynamicCache into this FAISS-backed cache.

        Each layer's key tensor is indexed with `create_key_database`. Each layer's value tensor is moved to
        CPU and kept as the prefix values. Suffix keys/values start as empty CUDA tensors with the layer's
        dtype, and per-layer sequence lengths are copied from `dynamic_cache.get_seq_length`.

        Args:
            dynamic_cache: source cache exposing `key_cache`, `value_cache` (one tensor per layer) and
                `get_seq_length(layer)`.

        Returns:
            A new instance of `cls`.
        """
        cache = cls()

        # Pass the keys by keyword, like the other `create_key_database` call sites. A positional call
        # only works if `key_states` happens to be the first parameter.
        cache.key_cache = [
            (cls.create_key_database(key_states=k), torch.empty(0, device="cuda", dtype=k.dtype))
            for k in dynamic_cache.key_cache
        ]

        cache.value_cache = [
            (v.cpu(), torch.empty(0, device="cuda", dtype=v.dtype))
            for v in dynamic_cache.value_cache
        ]

        seq_lengths = defaultdict(int)
        for layer, _ in enumerate(dynamic_cache.key_cache):
            seq_lengths[layer] = dynamic_cache.get_seq_length(layer)
        cache.seq_lengths = seq_lengths

        return cache

class DynamicWindowCache(Cache):
    """
    Cache that keeps the full key/value history on CPU and a window of recent states on the device the
    states arrive on (normally the GPU).

    Structure of the cache:
    - `key_cache` / `value_cache`: one CPU tensor per layer with every key/value state seen so far,
      concatenated along the sequence dimension (dim -2). `update` requires 4-D states.
    - `gpu_key_cache` / `gpu_value_cache`: one tensor per layer holding the window that `update` returns.
    - `seq_lengths`: number of tokens seen so far, per layer.

    Window behaviour of `update`:
    - The first chunk for a layer is kept whole, even if it is longer than `gpu_cache_size`.
    - Later chunks are appended while the total number of tokens seen for the layer is <= `gpu_cache_size`.
    - After that the window keeps its current length: the oldest entries are shifted out and the new chunk
      is written into the tail. The window therefore stays at whatever length it had when the total first
      exceeded `gpu_cache_size`, which can be shorter (or, after a long first chunk, longer) than
      `gpu_cache_size`.

    `update_tensor_to_faiss_index` rewrites `key_cache` into the layout used by `DynamicFaissCache`:
        self.key_cache:
            [
                (<FAISS key database>, torch.Tensor),   # (prefix keys index, suffix keys)
                ...
            ]
    """

    def __init__(self, gpu_cache_size: int) -> None:
        self.key_cache: List[torch.Tensor] = []  # full CPU key history, one tensor per layer
        self.value_cache: List[torch.Tensor] = []  # full CPU value history, one tensor per layer
        self.gpu_key_cache: List[torch.Tensor] = []  # windowed keys returned by `update`
        self.gpu_value_cache: List[torch.Tensor] = []  # windowed values returned by `update`
        self.seq_lengths: DefaultDict[int, int] = defaultdict(int)  # tokens seen per layer
        self.gpu_cache_size: int = gpu_cache_size  # token budget that switches `update` to sliding mode

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Append `key_states` / `value_states` to the CPU history of layer `layer_idx` and update its window.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache. Must be 4-D; the sequence dimension is dim -2.
            value_states (`torch.Tensor`):
                The new value states to cache. Must be 4-D; the sequence dimension is dim -2.
            layer_idx (`int`):
                The index of the layer to cache the states for. A layer seen for the first time is appended,
                so layers are expected to first appear in increasing order.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Unused by this cache.

        Return:
            `(window_keys, window_values)` for the layer: the windowed states kept on the incoming states'
            device, not the full history.
        """
        assert len(key_states.shape) == 4
        assert len(value_states.shape) == 4

        self.seq_lengths[layer_idx] += key_states.shape[-2]

        if len(self.gpu_key_cache) <= layer_idx:
            # First chunk for this layer: keep it whole, both in the CPU history and in the window.
            self.key_cache.append(key_states.cpu())
            self.value_cache.append(value_states.cpu())
            self.gpu_key_cache.append(key_states)
            self.gpu_value_cache.append(value_states)
            return self.gpu_key_cache[layer_idx], self.gpu_value_cache[layer_idx]

        self.key_cache[layer_idx] = torch.cat(
            [self.key_cache[layer_idx], key_states.cpu()], dim=-2
        )
        self.value_cache[layer_idx] = torch.cat(
            [self.value_cache[layer_idx], value_states.cpu()], dim=-2
        )

        if self.key_cache[layer_idx].shape[-2] <= self.gpu_cache_size:
            # Still within budget: grow the window.
            self.gpu_key_cache[layer_idx] = torch.cat(
                [self.gpu_key_cache[layer_idx], key_states], dim=-2
            )
            self.gpu_value_cache[layer_idx] = torch.cat(
                [self.gpu_value_cache[layer_idx], value_states], dim=-2
            )
        else:
            # Over budget: slide the window in place, keeping its length.
            self._slide_window(self.gpu_key_cache[layer_idx], key_states)
            self._slide_window(self.gpu_value_cache[layer_idx], value_states)

        return self.gpu_key_cache[layer_idx], self.gpu_value_cache[layer_idx]

    @staticmethod
    def _slide_window(window: torch.Tensor, new_states: torch.Tensor) -> None:
        """
        Shift `window` left along dim -2 by the length of `new_states` and write `new_states` into the tail.

        The update is in place, so `window` keeps its identity and length. `new_states` must not be longer
        than the window along dim -2 (torch raises a shape-mismatch error otherwise).
        """
        chunk_size = new_states.shape[-2]
        if chunk_size == 0:
            # Nothing to insert. Also avoids `:-0` (selects nothing) vs `0:` (selects everything),
            # which would make the slice assignments below fail on a shape mismatch.
            return
        # Source and destination slices overlap in the same storage. Clone the source so the copy never
        # reads elements it has already overwritten: a parallel GPU copy gives no ordering guarantee,
        # and PyTorch may reject a partially overlapping copy outright.
        window[:, :, :-chunk_size, :] = window[:, :, chunk_size:, :].clone()
        window[:, :, -chunk_size:, :] = new_states

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the number of tokens seen so far for `layer_idx` (the full history, not the window length)."""
        return self.seq_lengths[layer_idx]

    def get_max_length(self) -> Optional[int]:
        """Returns None: the CPU history has no maximum length."""
        return None

    def to_legacy_cache(self):
        """Returns the cache object itself; no legacy tuple conversion is performed."""
        return self

    def update_tensor_to_faiss_index(self):
        """
        Replace every layer's entry in `key_cache` with a `(key_db, suffix_keys)` tuple.

        `key_db` is built by `DynamicFaissCache.create_key_database(flat=True, ...)` from the prefix keys.
        An entry that is already a `(prefix_keys, suffix_keys)` tuple has its prefix indexed and its suffix
        kept. A plain tensor (the layout `update` produces) is indexed in full and paired with an empty
        suffix tensor. `value_cache` is left unchanged.
        """
        for layer_idx, key_cache_entry in enumerate(self.key_cache):
            if isinstance(key_cache_entry, tuple):
                sparse_prefix_keys, suffix_keys = key_cache_entry
            else:
                # `update` stores one plain tensor per layer. Unpacking it as a pair would iterate over the
                # tensor's first dimension, so index the whole history and start with an empty suffix on
                # the window's device.
                sparse_prefix_keys = key_cache_entry
                suffix_device = (
                    self.gpu_key_cache[layer_idx].device
                    if layer_idx < len(self.gpu_key_cache)
                    else key_cache_entry.device
                )
                suffix_keys = torch.empty(0, dtype=key_cache_entry.dtype, device=suffix_device)
            key_db = DynamicFaissCache.create_key_database(flat=True, key_states=sparse_prefix_keys)
            self.key_cache[layer_idx] = (key_db, suffix_keys)

    @staticmethod
    def convert_cache_to_list(cache: 'DynamicWindowCache') -> list:
        """Return the stacked CPU KV history of `cache` (see `get_kv_cache_as_tensor`) as nested lists of floats."""
        kv_cache_tensor = DynamicWindowCache.get_kv_cache_as_tensor(cache)
        return kv_cache_tensor.to(torch.float32).detach().cpu().tolist()

    @staticmethod
    def get_kv_cache_as_tensor(cache: 'DynamicWindowCache') -> torch.Tensor:
        """
        Stack the full CPU history into one tensor: index 0 holds keys and index 1 holds values, each stacked
        over layers. All layers must have the same shape for `torch.stack` to succeed.
        """
        key_cache_tensor = torch.stack(cache.key_cache)
        value_cache_tensor = torch.stack(cache.value_cache)
        return torch.stack([key_cache_tensor, value_cache_tensor])

    @staticmethod
    def save_cache(cache: 'DynamicWindowCache', filepath: str) -> None:
        """`torch.save` the tensor from `get_kv_cache_as_tensor(cache)` to `filepath`."""
        kv_cache_tensor = DynamicWindowCache.get_kv_cache_as_tensor(cache)
        torch.save(kv_cache_tensor, filepath)

    @staticmethod
    def load_cache(directory: str, device="cuda", dtype=torch.bfloat16) -> 'DynamicWindowCache':
        """Not implemented: a DynamicWindowCache cannot currently be restored from disk."""
        raise NotImplementedError
