"""
Cache implementations derived from the `Cache` base class in
github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py
Requires transformers==4.47.0
Newly implemented cache functionality:
    - SinkCache (re-implemented)
    - H2O, TOVA, and SnapKV (special cases of OBCache)
    - OBCache (our method)
"""

import torch
import warnings
from typing import Any, Dict, List, Optional, Tuple
from transformers.cache_utils import Cache
import math


class SinkCache(Cache):
    """KV cache keeping a fixed number of leading "sink" tokens plus a window of recent tokens.

    Budgets are given either as token counts (`num_recent`, `num_heavy`) or, when both
    `recent_ratio` and `heavy_ratio` are given, as ratios of the length of the first chunk
    passed to `update` for each layer. `num_heavy` / `heavy_ratio` size the sink region.
    Tokens between the sinks and the recent window are dropped by `evict`.
    """

    def __init__(self,
                 num_recent: int = None,
                 num_heavy: int = None,
                 recent_ratio: float = None,
                 heavy_ratio: float = None) -> None:
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

        self.num_recent_tokens = num_recent
        self.num_sink_tokens = num_heavy

        # Ratio attributes only exist when given; `update` checks for both with hasattr.
        if recent_ratio is not None:
            self.recent_ratio = recent_ratio
        if heavy_ratio is not None:
            self.sink_ratio = heavy_ratio

        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_cache_shape(self) -> Optional[int]:
        """Return the cache budget (sink + recent tokens).

        Returns None while the budget is unknown, e.g. in ratio mode before the first
        `update` has resolved the token counts.
        """
        if self.num_recent_tokens is None or self.num_sink_tokens is None:
            return None
        return self.num_recent_tokens + self.num_sink_tokens

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass.
        Return:
            A tuple containing the updated key and value states.
        """
        num_coming = key_states.shape[-2]
        if layer_idx == 0:
            self._seen_tokens += num_coming

        if len(self.key_cache) <= layer_idx:
            ## Empty cache for this layer: in ratio mode, size the budget from this first chunk.
            if hasattr(self, "recent_ratio") and hasattr(self, "sink_ratio"):
                assert self.sink_ratio <= self.recent_ratio, "Sink ratio should be smaller than recent ratio"
                self.num_recent_tokens = int(num_coming * self.recent_ratio)
                self.num_sink_tokens = int(num_coming * self.sink_ratio)

            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def evict(self, num_coming=0):
        """ KV Cache Pruning: keep the first `num_sink_tokens` tokens and the trailing tokens
        from `recent_cutoff` on, in every layer.

        When `num_coming <= num_recent_tokens`, the cache is back at its budget once the
        `num_coming` new tokens are appended. The length of layer 0 is used for all
        layers, so all layers are expected to hold the same number of tokens.

        Args:
            num_coming (int): Number of tokens coming in the next step.
        """
        max_len = self.get_max_cache_shape()
        if not self.key_cache or max_len is None:
            # Nothing cached yet, or the budget is still unresolved: nothing to prune.
            return

        seq_len = self.get_seq_length()
        if seq_len + num_coming <= max_len:
            return

        # Tokens at positions >= recent_cutoff form the kept recent window.
        recent_cutoff = seq_len - self.num_recent_tokens + num_coming
        for layer_idx in range(len(self.key_cache)):
            k = self.key_cache[layer_idx]
            v = self.value_cache[layer_idx]

            self.key_cache[layer_idx] = torch.cat(
                [
                    k[:, :, :self.num_sink_tokens],
                    k[:, :, recent_cutoff:]
                ], dim=-2
            )
            self.value_cache[layer_idx] = torch.cat(
                [
                    v[:, :, :self.num_sink_tokens],
                    v[:, :, recent_cutoff:]
                ], dim=-2
            )

    def reset(self):
        """Drop all cached states and reset the seen-token counter."""
        self.key_cache.clear()
        self.value_cache.clear()
        self._seen_tokens = 0

    def __repr__(self):
        return f"SinkCache: num_recent_tokens={self.num_recent_tokens}, num_sink_tokens={self.num_sink_tokens}, " \
               f"recent_ratio={getattr(self, 'recent_ratio', None)}, sink_ratio={getattr(self, 'sink_ratio', None)}"


class OBCache(Cache):
    """KV cache keeping `num_heavy_tokens` high-scoring tokens plus `num_recent_tokens` recent
    tokens per layer, with token scores supplied by an `OBCScoreTracker`.

    Budgets can be given as token counts (`num_recent`, `num_heavy`), as ratios of the length
    of the first chunk passed to `update` (`recent_ratio` and `heavy_ratio`), or, with
    `fix_recent_token=True`, as a fixed recent window plus a total `cache_ratio` from which
    the heavy budget is derived on the first `update`. Extra keyword arguments are forwarded
    to `OBCScoreTracker`.
    """

    def __init__(self, 
                 num_recent: int=None,
                 num_heavy: int=None, 
                 recent_ratio: float = None, 
                 heavy_ratio: float = None,
                 decode_evict: bool = True, 
                 fix_recent_token: bool = False,
                 cache_ratio: float = None,  # fixed cache ratio of prompt length, required when fix_recent is True
                 **score_tracker_kwargs) -> None:
        # Initialise the transformers `Cache` base class (as SinkCache does); depending on the
        # transformers version the base class may set up its own state here.
        super().__init__()

        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

        self.num_recent_tokens = num_recent
        self.num_heavy_tokens = num_heavy

        # Ratio attributes only exist when given; `update` checks for both with hasattr.
        if recent_ratio is not None:
            self.recent_ratio = recent_ratio
        if heavy_ratio is not None:
            self.heavy_ratio = heavy_ratio

        self.score_tracker = OBCScoreTracker(**score_tracker_kwargs)

        self.fix_recent_token = fix_recent_token
        if self.fix_recent_token:
            # Heavy budget is derived from `cache_ratio` on the first `update`.
            self.num_heavy_tokens = None
            assert cache_ratio is not None, "When fix_recent_token is True, cache_ratio should be provided."
            self.cache_ratio = cache_ratio

        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
        self.decode_evict = decode_evict  # whether to evict during decoding (for comparison with prefill-eviction only methods)
        self.method = "default"

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_cache_shape(self) -> Optional[int]:
        """Return the cache budget (heavy + recent tokens).

        Returns None while the budget is unknown, e.g. in ratio or `fix_recent_token` mode
        before the first `update` has resolved the token counts.
        """
        if self.num_recent_tokens is None or self.num_heavy_tokens is None:
            return None
        return self.num_recent_tokens + self.num_heavy_tokens

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None
    ):
        """ Cache update method called in hf's attention forward.

        Appends `key_states` / `value_states` along the sequence dimension (-2) of layer
        `layer_idx` and returns the layer's full key and value tensors. On the first call
        for a layer, budgets given as ratios are resolved from the incoming length.
        """
        num_coming = key_states.shape[-2]
        if layer_idx == 0:
            self._seen_tokens += num_coming

        if len(self.key_cache) <= layer_idx:
            ## Empty cache for this layer
            if self.fix_recent_token:
                kv_budget = int(num_coming * self.cache_ratio)
                self.num_heavy_tokens = kv_budget - self.num_recent_tokens
                assert self.num_heavy_tokens > 0, f"Cache ratio {self.cache_ratio} is too small for num_recent_tokens {self.num_recent_tokens}"

            if hasattr(self, "recent_ratio") and hasattr(self, "heavy_ratio"):
                self.num_recent_tokens = int(num_coming * self.recent_ratio)
                self.num_heavy_tokens = int(num_coming * self.heavy_ratio)
                # Unless the tracker is in TOVA-style mode (1-query window on recent tokens),
                # the perturbation window follows the recent window.
                if not (
                    getattr(self.score_tracker, "ptb_window", None) == 1
                    and getattr(self.score_tracker, "ptb_is_recent", False)
                ):
                    self.score_tracker.ptb_window = self.num_recent_tokens

            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def evict(self, layer_idx, num_coming=0):
        """ KV Cache Pruning for a single layer.

        Keeps the heavy tokens chosen by the score tracker among positions before
        `recent_cutoff`, plus every token from `recent_cutoff` on.

        Args:
            layer_idx (int): Index of the layer to prune.
            num_coming (int): Number of tokens coming in the next step.
        """
        max_len = self.get_max_cache_shape()
        if len(self.key_cache) <= layer_idx or max_len is None:
            # Layer not cached yet, or budget still unresolved: nothing to prune.
            return

        seq_len = self.get_seq_length(layer_idx)
        if seq_len + num_coming <= max_len:
            return

        # Tokens at positions >= recent_cutoff are kept unconditionally.
        recent_cutoff = seq_len - self.num_recent_tokens + num_coming

        k = self.key_cache[layer_idx]
        v = self.value_cache[layer_idx]

        # The tracker also prunes its own score buffers so they stay aligned with the cache.
        keep_topk_idx = self.score_tracker.evict(
            self.num_recent_tokens,
            self.num_heavy_tokens,
            num_coming,
            layer_idx,
        ).to(k.device)

        self.key_cache[layer_idx] = take_gather(k, keep_topk_idx, recent_cutoff, gather_dim=-2)
        self.value_cache[layer_idx] = take_gather(v, keep_topk_idx, recent_cutoff, gather_dim=-2)

    def evict_all_layers(self, num_coming=0):
        """Run `evict` on every cached layer."""
        for layer_idx in range(len(self.key_cache)):
            self.evict(layer_idx, num_coming=num_coming)

    def reset(self):
        """Drop all cached states and tracked scores, and reset the seen-token counter."""
        self.key_cache.clear()
        self.value_cache.clear()
        self._seen_tokens = 0
        self.score_tracker.reset()

    def __repr__(self):
        return f"--- OBCache Config {self.method} ---: \nnum_recent_tokens={self.num_recent_tokens}, num_heavy_tokens={self.num_heavy_tokens}, " \
               f"\nrecent_ratio={getattr(self, 'recent_ratio', None)}, heavy_ratio={getattr(self, 'heavy_ratio', None)}, cache_ratio={getattr(self, 'cache_ratio', None)}" \
               f"\nuse_v_score={self.score_tracker.use_v_score}, use_k_score={self.score_tracker.use_k_score}, use_cross={self.score_tracker.use_cross}, " \
               f"\np={self.score_tracker.p}, use_act={self.score_tracker.use_act}" \
               f"\nptb_window={getattr(self.score_tracker, 'ptb_window', None)}, ptb_is_recent={getattr(self.score_tracker, 'ptb_is_recent', None)}" \
               f"\npool_fn={self.score_tracker.pool_fn}, decode_evict={self.decode_evict}" \
               f"\n----------------"


# Registry of 1-D pooling functions, keyed by the `pool_fn` name accepted by OBCScoreTracker;
# the selected one smooths token scores along the sequence before top-k selection.
pool_fns = dict(
    maxpool=torch.nn.functional.max_pool1d,
    avgpool=torch.nn.functional.avg_pool1d,
)


class OBCScoreTracker:
    """Accumulates per-token importance scores for OBCache and selects which tokens to keep.

    Per layer, up to two score terms are tracked, each reduced to [bsz, num_kv_heads, kv_len]:
      * value score: sum over queries of A^p, optionally multiplied by sum_d |V|^p of each
        token's value vector (`use_act`);
      * key score: sum over queries of (A * qK)^2 * ||V_p - O_i||^2, optionally plus the
        cross term 2 * A^2 * qK * V_p^T (V_p - O_i) (`use_cross`, requires `use_v_score`).
    Scores accumulate across steps, except in TOVA-style mode (`ptb_is_recent` with
    `ptb_window == 1`), where each step's scores replace the previous ones.
    """

    def __init__(self,  
                 use_v_score=True,
                 use_k_score=True,
                 use_cross=True,
                 p=2, 
                 use_act=True,
                 ptb_window=None,       # minimize output perturbation starting from `-ptb_window : ...` tokens
                 pool_fn=None,          # use pooling to smooth the scores
                 ptb_is_recent=False,   # whether to only consider recent `ptb_window` tokens for perturbation
                 num_sink=0,
                 ):

        self.use_v_score = use_v_score
        self.use_k_score = use_k_score
        self.p = p
        self.use_act = use_act
        self.use_cross = use_cross
        self.pool_fn = pool_fn
        self.ptb_is_recent = ptb_is_recent
        self.num_sink = num_sink

        # Only set when given: `update` uses hasattr(self, 'ptb_window') to decide whether to
        # restrict prefill scoring to the last `ptb_window` queries.
        if ptb_window is not None:
            self.ptb_window = ptb_window

        if self.use_v_score:
            self.all_a_normp: List[torch.Tensor] = []         # [(bsz, num_kv_heads, kv_len), ...]
            if self.use_act:
                self.all_v_normp: List[torch.Tensor] = []     # [(bsz, num_kv_heads, kv_len), ...]

        if self.use_k_score:
            self.all_k_scores: List[torch.Tensor] = []        # [(bsz, num_kv_heads, kv_len), ...]

    def _replaces_scores(self) -> bool:
        """True in TOVA-style mode (recent-only perturbation with a 1-query window): each step's
        scores replace the stored ones instead of being accumulated.

        Uses getattr because `ptb_window` is only set when one was provided.
        """
        return bool(self.ptb_is_recent and getattr(self, 'ptb_window', None) == 1)

    @torch.no_grad()
    def retrieve_score(self, layer_idx):
        """Return the combined score of layer `layer_idx` (None if no score term is enabled).

        Note: in some configurations this is the stored buffer itself, not a copy.
        """
        score = None
        if self.use_v_score:
            score = self.all_a_normp[layer_idx]
            if self.use_act:
                score = score * self.all_v_normp[layer_idx]

        if self.use_k_score:
            k_score = self.all_k_scores[layer_idx]
            score = k_score if score is None else (score + k_score)

        return score # [bsz, num_kv_heads, kv_len]

    @torch.no_grad()
    def update(self, A, V, qK, O, Q=None, K=None, layer_idx=0):
        """ Update attention scores and activation norms.
        Args:
            A (tensor): attn_weights of shape [bsz, num_heads, q_len, kv_len+q_len]
            V (tensor): value_states of shape [bsz, num_heads, kv_len+q_len, head_dim]
            qK (tensor): pre-softmax attn_weights of shape [bsz, num_heads, q_len, kv_len+q_len]
            O (tensor): pre-o-proj attn_output of shape [bsz, num_heads, q_len, head_dim]
            Q, K (tensor, optional): query / key states; only used when both `A` and `qK` are
                None (prefill with flash attention) to recompute the attention weights here.
            layer_idx (int): index of the layer to update
        """
        if A is None and qK is None: # only when prefill w/ flashattn: manually compute attn_weights (adapted from SnapKV)
            if hasattr(self, 'ptb_window'):
                Q = Q[..., -self.ptb_window:, :]
                if O is not None:
                    O = O[..., -self.ptb_window:, :]

            qK = torch.matmul(Q, K.transpose(2, 3)) / math.sqrt(Q.size(-1))

            # Causal mask on the trailing n_q x n_q block: the i-th of the last n_q queries may
            # attend the j-th of the last n_q keys only if j <= i. Masked logits are *set* to
            # the dtype minimum rather than having it added, so they cannot overflow to -inf;
            # -inf would make A * qK = 0 * -inf = NaN in `k_update`.
            n_q = Q.size(-2)
            causal = torch.ones(n_q, n_q, dtype=torch.bool, device=qK.device).tril()
            qK[..., -n_q:, -n_q:].masked_fill_(~causal, torch.finfo(qK.dtype).min)

            A = torch.nn.functional.softmax(qK, dim=-1, dtype=torch.float32).to(Q.dtype) # [bsz, num_heads, q_len (ptb_window), kv_len=q_len]

            # In prefill every key is new, so the number of new tokens is the key length.
            q_len = A.size(-1)

        else: # prefill w/ eager-attn or decoding
            _, _, q_len, kv_len = A.shape
            if q_len == kv_len and hasattr(self, 'ptb_window'): # prefill
                A = A[..., -self.ptb_window:, :]
                if qK is not None:
                    qK = qK[..., -self.ptb_window:, :]
                if O is not None:
                    O = O[..., -self.ptb_window:, :]

        if self.use_v_score:
            self.v_update(A, V, q_len, layer_idx=layer_idx)

        if self.use_k_score:
            self.k_update(A, V, qK, O, q_len, layer_idx=layer_idx)

    @torch.no_grad()
    def v_update(self, A, V, q_len, layer_idx):
        """ Update attention scores and activation norms for isolated value pruning
        Args:
            A (tensor): attn_weights of shape [bsz, num_attn_heads, q_len (ptb_window), kv_len]
            V (tensor): value_states of shape [bsz, num_kv_heads, kv_len, head_dim]
            q_len (int): number of new tokens appended in this step
            layer_idx (int): index of the layer to update
        """
        num_kv_groups = A.size(1) // V.size(1)
        if num_kv_groups > 1:
            A = A.view(A.size(0), V.size(1), num_kv_groups, A.size(2), A.size(3))

        A_normp = A.pow(self.p).sum(-2) # sum over q_len
        if num_kv_groups > 1:
            A_normp = A_normp.sum(-2)   # sum over kv_groups

        if len(self.all_a_normp) <= layer_idx:
            self.all_a_normp.append(A_normp)
        elif self._replaces_scores():
            self.all_a_normp[layer_idx] = A_normp
        else:
            # accumulate:  [bsz, num_kv_heads, prev_kv_len] -> [bsz, num_kv_heads, prev_kv_len+q_len]
            A_normp[..., :-q_len] += self.all_a_normp[layer_idx]
            self.all_a_normp[layer_idx] = A_normp

        if self.use_act:
            V_new = V[..., -q_len:, :]
            if self.p == 1:
                V_normp = V_new.abs().sum(dim=-1)                            # [bsz, num_kv_heads, q_len]
            else:
                # |v|^p keeps the p-norm term non-negative for odd or fractional p
                V_normp = V_new.abs().pow(self.p).sum(dim=-1)                # [bsz, num_kv_heads, q_len]

            if len(self.all_v_normp) <= layer_idx:
                self.all_v_normp.append(V_normp)
            else:
                # [bsz, num_kv_heads, prev_kv_len] -> [bsz, num_kv_heads, prev_kv_len+q_len]
                self.all_v_normp[layer_idx] = torch.cat([self.all_v_normp[layer_idx], V_normp], dim=2)

    @torch.no_grad()
    def k_update(self, A, V, qK, O, q_len, layer_idx):
        """ Update key scores for isolated key pruning
        Args:
            A (tensor): attn_weights of shape [bsz, num_attn_heads, q_len, kv_len]
            V (tensor): value_states of shape [bsz, num_kv_heads, kv_len, head_dim]
            qK (tensor): pre-softmax attn_weights of shape [bsz, num_attn_heads, q_len, kv_len]
            O (tensor): pre-o-proj attn_output of shape [bsz, num_attn_heads, q_len, head_dim]
            q_len (int): number of new tokens appended in this step
            layer_idx (int): index of the layer to update
        """
        num_kv_groups = A.size(1) // V.size(1)
        if num_kv_groups > 1:
            A = A.view(A.size(0), V.size(1), num_kv_groups, A.size(2), A.size(3))
            qK = qK.view(qK.size(0), V.size(1), num_kv_groups, qK.size(2), qK.size(3))
            O = O.view(O.size(0), V.size(1), num_kv_groups, O.size(2), O.size(3))
            V = V.unsqueeze(2)

        # az_vo = (A * Z)^2  * |V_p - O_i|^2
        azp = (A * qK).pow(2)                              # [bsz, num_heads, q_len, kv_len]

        # vop = |V_p - O_i|^2 (= ||V_p||^2 + ||O_i||^2 - 2 V_p^T O_i)
        V_sq = (V * V).sum(dim=-1)                         # [bsz, num_heads, kv_len]
        O_sq = (O * O).sum(dim=-1)                         # [bsz, num_heads, q_len]

        if num_kv_groups > 1:
            dot = torch.einsum('bhgqd,bhgpd->bhgqp', O, V)
        else:
            dot = torch.einsum('bhqd,bhpd->bhqp', O, V)

        vop = V_sq[..., None, :] + O_sq[..., None] - 2.0 * dot

        az_vo = azp * vop
        key_score = az_vo.sum(dim=-2)          # sum over q_len
        if num_kv_groups > 1:
            key_score = key_score.sum(dim=-2)  # sum over kv_groups

        if self.use_v_score and self.use_cross:
            # a2z = A^2 * Z
            a2z = A.pow(2) * qK

            # vvo = V_p^T (V_p - O_i)
            vvo = V_sq[..., None, :] - dot

            a2z_vvo = 2.0 * a2z * vvo
            kv_cross = a2z_vvo.sum(dim=-2)       # sum over q_len
            if num_kv_groups > 1:
                kv_cross = kv_cross.sum(dim=-2)  # sum over kv_groups
        else:
            kv_cross = 0

        key_score = key_score + kv_cross

        if len(self.all_k_scores) <= layer_idx:
            self.all_k_scores.append(key_score)
        elif self._replaces_scores():
            self.all_k_scores[layer_idx] = key_score
        else:
            # accumulate:  [bsz, num_kv_heads, prev_kv_len] -> [bsz, num_kv_heads, prev_kv_len+q_len]
            key_score[..., :-q_len] += self.all_k_scores[layer_idx]
            self.all_k_scores[layer_idx] = key_score

    @torch.no_grad()
    def evict(self, 
              num_recent_tokens, 
              num_heavy_tokens, 
              num_coming,
              layer_idx):
        """ update and evict attention scores / activation norms (cache eviction phase)
        Args:
            num_recent_tokens: number of recent tokens to keep
            num_heavy_tokens: number of heavy tokens to keep
            num_coming: number of new tokens coming in (q_len)
            layer_idx: index of the layer to update
        Returns:
            Sorted indices (into positions before `recent_cutoff`) of the heavy tokens kept,
            shape [bsz, num_kv_heads, num_heavy_tokens].
        """
        score = self.retrieve_score(layer_idx)  # [bsz, num_kv_heads, kv_len]
        _, _, seq_len = score.shape
        if num_coming >= num_recent_tokens:  # only in streaming mode
            num_heavy_tokens -= (num_coming - num_recent_tokens)
            num_heavy_tokens = max(num_heavy_tokens, 0)

        # select heavy tokens to keep
        recent_cutoff = seq_len - num_recent_tokens + num_coming
        score_to_select = score[..., :recent_cutoff]
        if self.pool_fn in pool_fns:
            score_to_select = pool_fns[self.pool_fn](score_to_select, kernel_size=7, padding=3, stride=1)

        if self.num_sink > 0:
            # Copy first: without pooling, `score_to_select` can be a view of a stored score
            # buffer, and the sink bonus must not be written into the tracked scores.
            score_to_select = score_to_select.clone()
            score_to_select[..., :self.num_sink] = torch.finfo(score.dtype).max

        _, keep_topk_idx = torch.topk(score_to_select, num_heavy_tokens, dim=-1)
        keep_topk_idx = keep_topk_idx.sort().values        # [bsz, num_kv_heads, num_heavy_tokens]

        # update scores: remove evicted columns (replaced scores in TOVA mode need no pruning)
        replaces = self._replaces_scores()
        if self.use_v_score:
            if not replaces:
                self.all_a_normp[layer_idx] = take_gather(self.all_a_normp[layer_idx], keep_topk_idx, recent_cutoff, gather_dim=-1)

            if self.use_act:
                self.all_v_normp[layer_idx] = take_gather(self.all_v_normp[layer_idx], keep_topk_idx, recent_cutoff, gather_dim=-1)

        if self.use_k_score:
            if not replaces:
                self.all_k_scores[layer_idx] = take_gather(self.all_k_scores[layer_idx], keep_topk_idx, recent_cutoff, gather_dim=-1)

        return keep_topk_idx

    def reset(self):
        """Clear all tracked per-layer scores."""
        if self.use_v_score:
            self.all_a_normp.clear()
            if self.use_act:
                self.all_v_normp.clear()

        if self.use_k_score:
            self.all_k_scores.clear()


def take_gather(buf, keep_topk_idx, recent_cutoff, gather_dim=-2):
    """Keep selected positions before `recent_cutoff` plus every position from it on.

    Along the sequence dimension `gather_dim`, positions `[0, recent_cutoff)` are candidates:
    only those listed in `keep_topk_idx` are kept, in index order. Positions
    `[recent_cutoff, end)` are kept unconditionally and appended after the gathered ones.

    Args:
        buf: tensor to prune. Supported layouts:
            - gather_dim=-1: [bsz, heads, seq_len] (same rank as `keep_topk_idx`), or
              [bsz, heads, q_len, seq_len] (one rank higher; the same indices are used for
              every q_len row).
            - gather_dim=-2: [bsz, heads, seq_len, head_dim], e.g. key/value states.
        keep_topk_idx: integer indices into `[0, recent_cutoff)`, shape [bsz, heads, num_hh].
        recent_cutoff (int): first position that is always kept.
        gather_dim (int): -1 or -2, the sequence dimension of `buf`.

    Returns:
        Tensor of the same rank as `buf` holding the gathered heavy entries followed by the
        recent entries along `gather_dim`.

    Raises:
        ValueError: if `gather_dim` is unsupported, or, for gather_dim=-1, if the rank of
            `buf` matches neither supported layout.
    """
    if gather_dim == -1:
        if buf.dim() == keep_topk_idx.dim() + 1:
            # Per-query scores: share the token indices across the q_len dimension.
            keep_topk_idx = keep_topk_idx.unsqueeze(-2).expand(-1, -1, buf.size(-2), -1)  # [bsz, heads, q_len, num_hh]
        elif buf.dim() != keep_topk_idx.dim():
            raise ValueError(
                f"buf of rank {buf.dim()} is incompatible with keep_topk_idx of rank "
                f"{keep_topk_idx.dim()} for gather_dim=-1"
            )
        select_buf = buf[..., :recent_cutoff]                                             # [..., recent_cutoff]
        keep_recent = buf[..., recent_cutoff:]                                            # [..., num_recent]

    elif gather_dim == -2: # buf: [bsz, heads, seq_len, head_dim] for kv
        select_buf = buf[..., :recent_cutoff, :]                                          # [bsz, heads, recent_cutoff, head_dim]
        keep_recent = buf[..., recent_cutoff:, :]                                         # [bsz, heads, num_recent, head_dim]
        keep_topk_idx = keep_topk_idx.unsqueeze(-1).expand(-1, -1, -1, buf.size(-1))      # [bsz, heads, num_hh, head_dim]

    else:
        raise ValueError(f"gather_dim {gather_dim} not supported")

    hh = torch.gather(select_buf, dim=gather_dim, index=keep_topk_idx)                    # [bsz, heads, num_hh, ...]
    return torch.cat([hh, keep_recent], dim=gather_dim)                                   # [bsz, heads, num_hh + num_recent, ...]