from transformers.cache_utils import DynamicCache
from typing import (
    Optional,
    Tuple,
    List,
    Dict,
    Any
)
import torch


class NSNCache(DynamicCache):
    """
    Dynamic per-layer cache that stores several named tensor lists side by side.

    ``self.caches`` maps every name in ``self.cache_names`` to a list indexed by
    layer. A layer slot that has not been populated holds an empty list ``[]``
    as a placeholder.

    :meth:`update` writes only ``full_key_cache``, ``full_value_cache``, ``sin``
    and ``cos``. It pads every other named list to the same number of layers.
    The remaining entries (quantized keys/values and their ``*_idx``,
    ``*_scale``, ``*_offset`` and ``*_norm2`` companions) are written only
    through :meth:`direct_update` and :meth:`concat_by_dim`. This class does not
    compute them.
    """

    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
        # NOTE: `num_hidden_layers` is accepted but currently unused; layer slots
        # are allocated lazily by `update`.
        super().__init__()
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
        self.cache_names = [
            "quantized_key_cache", "key_norm_idx", "key_norm_scale", "key_norm_offset",
            "key_mean_idx", "key_mean_scale", "key_mean_offset", "key_norm2",
            "quantized_value_cache", "value_norm_idx", "value_norm_scale", "value_norm_offset",
            "value_mean_idx", "value_mean_scale", "value_mean_offset", "value_norm2",
            "full_key_cache", "full_value_cache", "sin", "cos",
        ]
        self.caches = {cache_name: [] for cache_name in self.cache_names}

    def __getitem__(self, layer_idx: int) -> Dict[str, Any]:
        """
        Return a dict mapping each name in ``self.cache_names`` to its entry for ``layer_idx``.

        Raises:
            KeyError: if ``layer_idx`` is not smaller than ``len(self)``.
        """
        if layer_idx < len(self):
            return {name: self.caches[name][layer_idx] for name in self.cache_names}
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Yield, for each layer in order, a dict mapping cache name to that layer's entry.
        """
        for layer_idx in range(len(self)):
            yield {name: self.caches[name][layer_idx] for name in self.cache_names}

    def __len__(self):
        """
        Number of layer slots currently allocated (the length of ``full_key_cache``).
        """
        return len(self.caches["full_key_cache"])

    def direct_update(self, name: str, layer_idx: int, x: torch.Tensor):
        """
        Overwrite the ``name`` cache entry for ``layer_idx`` with ``x``.

        The slot must already exist, so :meth:`update` must have allocated it
        first; otherwise ``IndexError`` is raised.
        """
        self.caches[name][layer_idx] = x

    def concat_by_dim(self, name: str, layer_idx: int, x: torch.Tensor, dim: int):
        """
        Append ``x`` to the ``name`` cache entry for ``layer_idx`` along ``dim``.

        If the slot is still an empty placeholder, ``x`` is stored as-is. The
        slot must already exist, as in :meth:`direct_update`.
        """
        # len() is used instead of truthiness: truth-testing a multi-element tensor raises.
        if len(self.caches[name][layer_idx]) == 0:
            self.caches[name][layer_idx] = x
        else:
            self.caches[name][layer_idx] = torch.cat([self.caches[name][layer_idx], x], dim=dim)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Append new key/value states and their rotary ``sin``/``cos`` for layer ``layer_idx``.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache; concatenated along dim -2.
            value_states (`torch.Tensor`):
                The new value states to cache; concatenated along dim -2.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Must contain ``"sin"`` and ``"cos"`` whenever ``key_states`` is
                not ``None``. They are cached per layer and concatenated along
                dim -2.

        Return:
            A tuple containing the full-precision key and value states cached for ``layer_idx``.

        Raises:
            ValueError: if ``key_states`` is given but ``cache_kwargs`` lacks ``"sin"`` or ``"cos"``.
        """
        key_cache = self.caches["full_key_cache"]
        value_cache = self.caches["full_value_cache"]
        sin_cache = self.caches["sin"]
        cos_cache = self.caches["cos"]

        if key_states is not None:
            # Validate before mutating anything, so a bad call cannot leave
            # `_seen_tokens` or the per-layer lists half-updated.
            if cache_kwargs is None or "sin" not in cache_kwargs or "cos" not in cache_kwargs:
                raise ValueError("NSNCache.update requires cache_kwargs with 'sin' and 'cos' entries")
            sin_states = cache_kwargs["sin"]
            cos_states = cache_kwargs["cos"]

            # Only layer 0 contributes, so each new token is counted once rather than once per layer.
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]

            if len(key_cache) <= layer_idx:
                # There may be skipped layers, fill them with empty placeholders
                for _ in range(len(key_cache), layer_idx):
                    key_cache.append([])
                    value_cache.append([])
                    sin_cache.append([])
                    cos_cache.append([])
                key_cache.append(key_states)
                value_cache.append(value_states)
                sin_cache.append(sin_states)
                cos_cache.append(cos_states)

                # Keep every other named cache as long as `full_key_cache`, so
                # direct_update / concat_by_dim can index this layer.
                for cache in self.caches.values():
                    for _ in range(len(cache), len(key_cache)):
                        cache.append([])
            else:
                # len() rather than truthiness: truth-testing a multi-element tensor raises.
                if len(key_cache[layer_idx]) == 0:  # fills a previously skipped layer
                    key_cache[layer_idx] = key_states
                    value_cache[layer_idx] = value_states
                else:
                    key_cache[layer_idx] = torch.cat([key_cache[layer_idx], key_states], dim=-2)
                    value_cache[layer_idx] = torch.cat([value_cache[layer_idx], value_states], dim=-2)

                if len(sin_cache[layer_idx]) == 0:
                    sin_cache[layer_idx] = sin_states
                    cos_cache[layer_idx] = cos_states
                else:
                    sin_cache[layer_idx] = torch.cat([sin_cache[layer_idx], sin_states], dim=-2)
                    cos_cache[layer_idx] = torch.cat([cos_cache[layer_idx], cos_states], dim=-2)

        return key_cache[layer_idx], value_cache[layer_idx]

    @staticmethod
    def _cached_len(cache: List[Any], layer_idx: int) -> int:
        """Size of dim -2 of ``cache[layer_idx]``, or 0 if that layer holds no tensor."""
        if not cache or len(cache) <= layer_idx or len(cache[layer_idx]) == 0:
            return 0
        return cache[layer_idx].shape[-2]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Return the number of cached tokens for ``layer_idx``.

        The result is the dim -2 size of the full-precision key cache plus that of
        the quantized key cache. Missing or empty entries count as 0. ``None`` is
        treated as layer 0.
        """
        if layer_idx is None:
            layer_idx = 0
        # NOTE(review): assumes quantized keys also keep the sequence axis at dim -2 — confirm.
        full_len = self._cached_len(self.caches["full_key_cache"], layer_idx)
        quantized_len = self._cached_len(self.caches["quantized_key_cache"], layer_idx)
        return full_len + quantized_len

    def get_max_cache_shape(self) -> Optional[int]:
        """This cache grows dynamically, so it has no maximum length."""
        return None

    def to_legacy_cache(self) -> Tuple[Tuple[Any, ...], ...]:
        """
        Convert the cache to a tuple containing one inner tuple per layer.

        Each inner tuple holds that layer's entries in ``self.cache_names`` order.
        Real tuples are built eagerly, per layer. A lazily-evaluated generator
        would read ``layer_idx`` only after the loop ends, and would return the
        last layer for every entry.
        """
        return tuple(
            tuple(self.caches[name][layer_idx] for name in self.cache_names)
            for layer_idx in range(len(self))
        )
