from typing import Any, Dict, Iterable, List, Optional, Tuple
from functools import partial

import torch

from transformers.cache_utils import Cache, DynamicCache
from torch.nn.attention.flex_attention import (
    _mask_mod_signature,
    BlockMask,
    create_block_mask,
)

from .mask_utils import AttentionMask, causal_mask_fn, causal_attention_mask_fn, get_mask_mod_w_offset

# `create_block_mask` wrapped with torch.compile at module import time; the actual
# compilation happens lazily on the first call. Used by RNSACache.get_block_mask to
# build the full (max_seq_len x max_seq_len) block mask once per cache.
create_block_mask_compiled = torch.compile(create_block_mask)


class RNSACache(DynamicCache):
    """Dynamic KV cache that also stores per-token forget weights and cache positions.

    Per layer it keeps four tensors: keys, values, forget weights and the cache
    positions the cached tokens came from. The forget weights and positions are used
    to score tokens when the cache is compressed (see ``compress_layer``). It also
    tracks how many tokens each layer has seen, which can exceed the cached length
    after compression.
    """

    def __init__(
        self,
        max_seq_len: int = 20480,
        _distributed_cache_data: Optional[Iterable] = None,
        device: str = "cuda",
    ) -> None:
        super().__init__()
        # Side length of the full block mask built lazily in get_block_mask.
        self.max_seq_len = max_seq_len

        self._seen_tokens = 0
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.forget_weights: List[torch.Tensor] = []
        self.kv_positions: List[torch.Tensor] = []
        # Number of tokens each layer has received (not reduced by compression).
        self.n_seen_tokens: List[int] = []
        self.device = device

        if _distributed_cache_data is not None:
            for key_states, value_states, forget_weights, kv_positions in _distributed_cache_data:
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
                self.forget_weights.append(forget_weights)
                self.kv_positions.append(kv_positions)
                # Keep n_seen_tokens aligned with the per-layer tensor lists; without it,
                # update() on these layers raises IndexError. Empty (skipped) layers count 0.
                # NOTE(review): for an already-compressed cache this undercounts the seen tokens.
                self.n_seen_tokens.append(key_states.shape[-2] if key_states.dim() >= 2 else 0)

        # Scalar holding the current KV length, refreshed in get_block_mask.
        self.offset = torch.tensor(0, dtype=torch.int64)
        # Full block mask, created on the first get_block_mask call and then reused.
        self.block_mask = None

    def __getitem__(
        self, layer_idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (keys, values, forget_weights, kv_positions) for ``layer_idx``."""
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx], self.forget_weights[layer_idx], self.kv_positions[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """Yield (keys, values, forget_weights, kv_positions) for every layer in order."""
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx], self.forget_weights[layer_idx], self.kv_positions[layer_idx])

    def __len__(self):
        """Number of layers currently held in the cache."""
        return len(self.key_cache)

    def get_seen_tokens(self, layer_idx: Optional[int] = None) -> int:
        """Return the number of tokens seen.

        With ``layer_idx=None`` this is the global counter (advanced on layer 0 updates);
        otherwise it is the per-layer counter, or 0 for a layer not yet in the cache.
        """
        if layer_idx is None:
            return self._seen_tokens
        if layer_idx < len(self.n_seen_tokens):
            return self.n_seen_tokens[layer_idx]
        else:
            return 0

    def get_total_cached_tokens(self, num_key_value_heads: Optional[int] = None) -> int:
        """Sum over layers of (cached sequence length * num_key_value_heads).

        ``num_key_value_heads`` defaults to dim 1 of the first layer's key tensor
        (0 if the cache is empty).
        """
        if num_key_value_heads is None:
            num_key_value_heads = self.key_cache[0].shape[1] if self.key_cache else 0

        return sum(self.get_seq_length(layer_idx) * num_key_value_heads for layer_idx in range(len(self)))

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        forget_weights: torch.Tensor,
        cache_positions: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Append new states for ``layer_idx`` and return that layer's full cached tensors.

        Keys/values are concatenated along dim -2; forget weights and positions along
        dim -1. A 1-D ``cache_positions`` is broadcast to ``forget_weights``' shape.
        ``cache_kwargs`` is accepted for interface compatibility and ignored.

        Returns:
            (keys, values, forget_weights, kv_positions) for the layer.
        """
        if key_states is not None:
            # Global seen-token counter advances once per step, on the first layer.
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]

            # Positions are stored per element of forget_weights (one per head/token).
            cache_positions = cache_positions[None, None, :].expand_as(forget_weights) if cache_positions.dim() == 1 else cache_positions

            if len(self.key_cache) <= layer_idx:
                # There may be skipped layers, fill them with empty placeholders
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append(torch.tensor([]))
                    self.value_cache.append(torch.tensor([]))
                    self.forget_weights.append(torch.tensor([]))
                    self.kv_positions.append(torch.tensor([]))
                    self.n_seen_tokens.append(0)

                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
                self.forget_weights.append(forget_weights)
                self.kv_positions.append(cache_positions)
                self.n_seen_tokens.append(key_states.shape[-2])
            elif (
                not self.key_cache[layer_idx].numel()  # prefers not t.numel() to len(t) == 0 to export the model
            ):  # fills previously skipped layers; checking for tensor causes errors
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
                self.forget_weights[layer_idx] = forget_weights
                self.kv_positions[layer_idx] = cache_positions
                self.n_seen_tokens[layer_idx] = key_states.shape[-2]
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
                self.forget_weights[layer_idx] = torch.cat([self.forget_weights[layer_idx], forget_weights], dim=-1)
                self.kv_positions[layer_idx] = torch.cat([self.kv_positions[layer_idx], cache_positions], dim=-1)
                self.n_seen_tokens[layer_idx] += key_states.shape[-2]

        return self.key_cache[layer_idx], self.value_cache[layer_idx], self.forget_weights[layer_idx], self.kv_positions[layer_idx]

    def batch_select_indices(self, indices: torch.Tensor):
        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
            self.forget_weights[layer_idx] = self.forget_weights[layer_idx][indices, ...]
            self.kv_positions[layer_idx] = self.kv_positions[layer_idx][indices, ...]

    def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
        """Split the current instance into a list of `RNSACache` by the batch size. This will be used by
        `_split_model_inputs()` in `generation.utils`"""
        out = []
        for i in range(0, full_batch_size, split_size):
            current_split = RNSACache(max_seq_len=self.max_seq_len, device=self.device)
            current_split._seen_tokens = self._seen_tokens
            current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
            current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
            current_split.forget_weights = [tensor[i : i + split_size] for tensor in self.forget_weights]
            current_split.kv_positions = [tensor[i : i + split_size] for tensor in self.kv_positions]
            # Per-layer seen counts are batch-independent; without them update() on the
            # split raises IndexError and compress_layer would use a wrong query index.
            current_split.n_seen_tokens = list(self.n_seen_tokens)
            out.append(current_split)
        return out

    def get_block_mask(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        layer_idx: Optional[int] = 0,
    ) -> BlockMask:
        """Return a single-query slice of the cached block mask for decoding.

        On first use a (max_seq_len x max_seq_len) block mask is built from
        ``causal_mask_fn`` (or ``causal_attention_mask_fn`` bound to ``attention_mask``)
        and cached on the instance. Each call slices the query-block row selected by
        ``cache_position``. It then wraps the mask_mod with the current cached length
        of ``layer_idx`` as offset and sets ``seq_lengths`` to (1, offset + q_len).

        Raises:
            ValueError: if ``cache_position`` is None or its last dimension is not 1.
        """
        # cache_position is required: it selects the block row and the device.
        if cache_position is None:
            raise ValueError("cache_position is required to build a decoding block mask")
        if cache_position.shape[-1] != 1:
            raise ValueError("cache_position must be a tensor with a single element in the last dimension")

        if self.block_mask is None:
            # NOTE(review): the mask is built once from the first attention_mask seen;
            # later calls with a different attention_mask reuse it unchanged.
            causal_mask = causal_mask_fn if attention_mask is None else partial(causal_attention_mask_fn, mask=attention_mask)
            self.block_mask = create_block_mask_compiled(
                causal_mask, None, None, self.max_seq_len, self.max_seq_len,
                device=cache_position.device, _compile=True,
            )

        # Current cached length of this layer (may be smaller than the seen-token count
        # after compression), used as the kv offset for the mask.
        offset = self.get_seq_length(layer_idx)
        self.offset.fill_(offset)
        q_len = cache_position.shape[-1]

        block_index = cache_position // self.block_mask.BLOCK_SIZE[0]
        mask = self.block_mask[:, :, block_index]
        mask.mask_mod = get_mask_mod_w_offset(self.block_mask.mask_mod, offset)
        # Use the query length (last dim), not len(), which is the batch dim for 2-D input.
        mask.seq_lengths = (1, offset + q_len)
        return mask

    def compress(
        self,
        strategy: str = "lw_knorm_alpha",
        memory_size: int = 2048,
        buffer_size: int = 512,
        num_layers: Optional[int] = None,
        num_key_value_heads: Optional[int] = None,
        skip_layers: int = 0,
    ):
        """Compress the cache once it has grown past ``memory_size + buffer_size``.

        Strategies containing ``'lw_'`` compress each layer in
        ``[skip_layers, num_layers)`` independently down to ``memory_size`` tokens.
        Any other strategy goes through ``compress_global``, which is not implemented.
        An empty cache is left untouched.
        """
        num_layers = len(self.key_cache) if num_layers is None else num_layers
        assert num_layers == len(self.value_cache) == len(self.forget_weights) == len(self.kv_positions), "All caches must have the same number of layers"
        if num_layers == 0:
            # Nothing cached yet; indexing key_cache[0] below would fail.
            return
        num_key_value_heads = self.key_cache[0].shape[1] if num_key_value_heads is None else num_key_value_heads
        assert num_key_value_heads == self.value_cache[0].shape[1], "Key and value caches must have the same number of heads"

        if 'lw_' in strategy:
            # layer wise compression
            for layer_idx in range(skip_layers, num_layers):
                # Only compress once the buffer is full, so compression is amortized.
                if memory_size + buffer_size <= self.get_seq_length(layer_idx):
                    self.compress_layer(layer_idx, memory_size, strategy)
        else:
            # global compression across all layers
            total_memory_size = num_layers * num_key_value_heads * memory_size
            if total_memory_size + buffer_size <= self.get_total_cached_tokens():
                self.compress_global(total_memory_size, strategy)

    def compress_global(
        self,
        memory_size: int,
        compress_strategy: str = "lw_knorm_alpha",
        num_layers: Optional[int] = None,
    ):
        """Compress across all layers jointly to ``memory_size`` total entries.

        Raises:
            NotImplementedError: always; global compression is not implemented yet.
        """
        # TODO: gather key/value/forget-weight/position tensors across layers, score
        # them jointly and keep the best `memory_size` entries.
        raise NotImplementedError("Global compression is not implemented yet")

    def compress_layer(
        self,
        layer_idx: int,
        memory_size: int,
        comrpess_strategy: str = "lw_knorm_alpha",
    ):
        """Keep only the ``memory_size`` highest-scoring cached tokens of one layer.

        Each cached token is scored with log_alpha = (q_idx - position) * forget_weight,
        where q_idx = seen tokens + 1. ``"lw_knorm_alpha"`` adds log(||key||) to the
        score. ``"lw_alpha"`` uses log_alpha alone. The selected tokens keep their
        original order. (The misspelled parameter name is kept for keyword callers.)

        Raises:
            ValueError: for an unknown strategy.
        """
        key_states = self.key_cache[layer_idx]
        value_states = self.value_cache[layer_idx]
        kv_positions = self.kv_positions[layer_idx]
        forget_weights = self.forget_weights[layer_idx]

        # NOTE(review): forget_weights are presumably per-token log decay rates, so
        # log_alpha approximates the accumulated decay up to the next query — confirm.
        log_beta = forget_weights
        q_idx = self.get_seen_tokens(layer_idx) + 1
        log_alpha = (q_idx - kv_positions) * log_beta

        if comrpess_strategy == "lw_knorm_alpha":
            # Log-space equivalent of exp(log_alpha) * ||k||.
            key_norm = key_states.norm(dim=-1, keepdim=False)
            scores = log_alpha + torch.log(key_norm)
        elif comrpess_strategy == "lw_alpha":
            scores = log_alpha
        else:
            raise ValueError(f"Unknown compression strategy: {comrpess_strategy}")

        # Keep the memory_size highest scores, then restore chronological order.
        top_k_indices = torch.topk(scores, memory_size, dim=-1).indices
        top_k_indices, _ = torch.sort(top_k_indices, dim=-1)

        # Broadcast token indices over the last (head_dim) axis for the 4-D gathers.
        token_index = top_k_indices.unsqueeze(-1)
        self.key_cache[layer_idx] = key_states.gather(-2, token_index.expand(-1, -1, -1, key_states.shape[-1]))
        self.value_cache[layer_idx] = value_states.gather(-2, token_index.expand(-1, -1, -1, value_states.shape[-1]))
        self.forget_weights[layer_idx] = forget_weights.gather(-1, top_k_indices)
        self.kv_positions[layer_idx] = kv_positions.gather(-1, top_k_indices)

    def log(self, layer_idx: Optional[int] = None):
        """Return seen-token counts and CPU copies of kv_positions.

        With ``layer_idx=None`` the result maps each layer index to a dict;
        otherwise it is a single dict for the given layer.
        """
        logs = {}
        if layer_idx is None:
            for layer_idx in range(len(self.key_cache)):
                logs[layer_idx] = {
                    "seen_tokens": self.get_seen_tokens(layer_idx),
                    "kv_positions": self.kv_positions[layer_idx].detach().cpu(),
                }
        else:
            logs["seen_tokens"] = self.get_seen_tokens(layer_idx)
            logs["kv_positions"] = self.kv_positions[layer_idx].detach().cpu()
        return logs

    def copy_to_device(self, device: str):
        """Return a new RNSACache with all tensors (and the block mask) moved to ``device``."""
        new_cache = RNSACache(max_seq_len=self.max_seq_len, device=device)
        new_cache._seen_tokens = self._seen_tokens
        new_cache.offset = self.offset.to(device)
        new_cache.block_mask = self.block_mask.to(device) if self.block_mask is not None else None
        new_cache.key_cache = [tensor.to(device) for tensor in self.key_cache]
        new_cache.value_cache = [tensor.to(device) for tensor in self.value_cache]
        new_cache.forget_weights = [tensor.to(device) for tensor in self.forget_weights]
        new_cache.kv_positions = [tensor.to(device) for tensor in self.kv_positions]
        new_cache.n_seen_tokens = self.n_seen_tokens.copy()
        return new_cache
