from typing import List, Tuple, Optional
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

from transformers.cache_utils import Cache, QuantizedCache, QuantizedCacheProcessor
from .quant_utils import fake_quant_channel_wise
from typing import Any, Callable, Optional, Union

from src.model.feature_map import init_feature_map, init_learned_kernel
from .utils import repeat_kv, apply_rotary_pos_emb, get_causal_mask, get_chunk_mask
import torch.utils.checkpoint as cp
from transformers.integrations.flash_attention import flash_attention_forward
from .quant_utils import triton_quantize_and_pack_along_last_dim
from .fused_attention import kvlinc_attention_forward

# ----------------------
# Sliding window helpers
# ----------------------


def softmax_attention(q: torch.Tensor, k: torch.Tensor, v: Optional[torch.Tensor] = None,
                      causal: bool = True, fp32_attention: bool = False,
                      ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Plain scaled dot-product softmax attention.

    Inputs are assumed to be laid out as (batch_size, num_heads, seq_len, head_dim).

    Args:
        q, k: queries and keys.
        v: values; when None, only the attention weights are computed.
        causal: mask out future keys. The mask is aligned bottom-right, so it
            also works when k is longer than q.
        fp32_attention: run the softmax in float32 and cast back to q.dtype.

    Returns:
        (output or None, attention weights, None)
    """
    scale = k.shape[-1] ** -0.5
    scores = torch.einsum('bhmd,bhnd->bhmn', q, k) * scale

    if causal:
        q_len, k_len = scores.shape[-2:]
        future = torch.ones((q_len, k_len), device=scores.device, dtype=torch.bool).triu(k_len - q_len + 1)
        scores = scores.masked_fill(future, -torch.finfo(scores.dtype).max)

    if fp32_attention:
        probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
    else:
        probs = torch.softmax(scores, dim=-1)

    out = None if v is None else torch.einsum('bhmn,bhnd->bhmd', probs, v)
    return out, probs, None

def softmax_attention_quant_fp(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, k_q : torch.Tensor,
                               residual_length: int, mask_value: float=-float("inf"),
                               ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Build the teacher targets and the softmax part of the student for KVLinC training.

    Args:
        q: queries, assumed (batch, heads, q_len, head_dim) — TODO confirm.
        k, v: full-precision keys / values, assumed (batch, heads, k_len, head_dim).
        k_q: fake-quantized keys, same shape as ``k``.
        residual_length: passed to ``get_chunk_mask``. That mask selects which
            causal positions use full-precision scores (presumably the recent
            residual window — confirm in utils).
        mask_value: score given to non-causal positions before exponentiation.

    Returns:
        ``((a_t, y_t), a_qfp)`` where
        - ``a_t`` is the normalized causal softmax over full-precision scores (teacher),
        - ``y_t`` is ``a_t @ v``,
        - ``a_qfp`` is the UNnormalized ``exp(score - row_max)``. Its scores are
          full precision where the chunk mask is set and computed against ``k_q``
          on the remaining causal positions.
    """
    q_len = q.shape[-2]
    causal = get_causal_mask(q_len, k.shape[-2], q.device)
    recent = get_chunk_mask(q_len, residual_length, q.device)
    older = causal & ~recent  # causal positions that are NOT kept in full precision

    scale = 1.0 / math.sqrt(k.shape[-1])
    keep = causal.bool()

    def _shifted_exp(scores: torch.Tensor) -> torch.Tensor:
        # Mask non-causal scores, then exponentiate relative to the row max for stability.
        scores = scores.masked_fill(~keep, mask_value)
        return torch.exp(scores - scores.amax(dim=-1, keepdim=True))

    scores_fp = torch.matmul(q, k.transpose(-1, -2)) * scale

    # Teacher: ordinary causal softmax on full-precision keys.
    teacher = _shifted_exp(scores_fp)
    teacher = teacher / (teacher.sum(dim=-1, keepdim=True))
    teacher_out = torch.matmul(teacher, v)

    # Student softmax part: full-precision scores on chunk-mask positions,
    # quantized-key scores on the older causal positions. Left unnormalized
    # so the linear correction can be added before normalizing.
    scores_q = torch.matmul(q, k_q.transpose(-1, -2)) * scale
    mixed = scores_fp.masked_fill(~recent.bool(), 0) + scores_q.masked_fill(~older.bool(), 0)

    return (teacher, teacher_out), _shifted_exp(mixed)



def kvlinc_attention(f_q: torch.Tensor,
                     f_k: torch.Tensor,
                     a_qfp : torch.Tensor, # unnormalized
                     residual_length: int,
                     seqlen: int):
    """
    Predicted (student) KVLinC attention weights.

    Adds the linear correction ``f_q @ f_k^T`` to the unnormalized softmax
    weights ``a_qfp``. The correction is applied only on causal positions
    outside the full-precision chunk mask. Each row is then normalized to sum to 1.

    Args:
        f_q: query features, assumed (batch, heads, seqlen, feature_dim) — TODO confirm.
        f_k: features of the key quantization error, same layout as ``f_q``.
        a_qfp: unnormalized weights from softmax_attention_quant_fp,
            assumed (batch, heads, seqlen, seqlen).
        residual_length: passed to ``get_chunk_mask``.
        seqlen: query and key length used to build the masks.

    Returns:
        Row-normalized attention weights with the shape of ``a_qfp``.
    """
    causal = get_causal_mask(seqlen, seqlen, f_q.device)
    recent = get_chunk_mask(seqlen, residual_length, f_q.device)
    older = causal & ~recent  # only quantized (older) positions get the correction

    correction = (f_q @ f_k.transpose(-1, -2)).masked_fill(~older, 0)
    unnormalized = a_qfp + correction
    return unnormalized / (unnormalized.sum(dim=-1, keepdim=True))


# ---------------------
# Attention layer class
# ---------------------

class KVLinCAttention(nn.Module):
    """
    KVLinC attention implementation.

    Wraps a HuggingFace softmax attention module (``base_attn``, e.g.
    ``LlamaAttention``). Its q/k/v/o projections, and ``q_norm``/``k_norm``
    when present, are reused as-is. Query/key feature maps are added to build
    a linear correction for the error introduced by key quantization.

    ``forward`` has three modes:
    -> training this layer (``train_attention`` and not ``skip_training_this_layer``):
       returns the o_proj'd full-precision softmax output plus
       ``((a_pred, a_true), (None, None))`` for attention distillation;
    -> training mode on a skipped layer: flash attention only, no weights;
    -> inference: flash attention for prefill and for fully full-precision
       caches; the fused KVLinC kernel for single-token decoding once the
       cache holds quantized tokens.
    """
    def __init__(self,
                 base_attn: nn.Module,  # like LlamaAttention
                 feature_map: str,
                 feature_map_kwargs: dict,
                 layer_idx: Optional[int] = None,
                 learned_kernel: Optional[str] = None,
                 learned_kernel_kwargs: Optional[dict] = None,
                 train_attention: Optional[bool] = False,
                 mask_value: int = 0,
                 eps: float = 1e-12,
                 rank: Optional[int] = 0,
                 kvquant: Optional[dict] = None,
                 **kwargs: Any) -> None:
        """
        Args:
            base_attn: softmax attention module whose projections/config are reused.
            feature_map: feature-map name passed to ``init_feature_map``.
            feature_map_kwargs: extra kwargs for ``init_feature_map``.
            layer_idx: only used to decide whether to print the kwargs; the
                stored ``self.layer_idx`` is taken from ``base_attn.layer_idx``.
            learned_kernel: optional name passed to ``init_learned_kernel``.
            learned_kernel_kwargs: kwargs for ``init_learned_kernel``.
            train_attention: if True, ``forward`` produces distillation outputs.
            mask_value, eps: stored on the module.
            rank: process rank; only rank 0 prints the kwargs.
            kvquant: dict with ``residual_length``, ``q_group_size`` and ``nbits``.
            **kwargs: must contain ``feature_map_before_repeat``.

        Raises:
            ValueError: if ``kvquant`` is None.
            KeyError: if ``feature_map_before_repeat`` or a ``kvquant`` key is missing.
        """
        super().__init__()
        if kvquant is None:
            raise ValueError("KVLinCAttention requires `kvquant` with keys "
                             "'residual_length', 'q_group_size' and 'nbits'")
        self.base_config = getattr(base_attn, 'config', None)
        self.attention_type = "kv_linc"
        self.config = self.base_config
        self.mask_value = mask_value
        self.eps = eps
        self.layer_idx = base_attn.layer_idx
        self.train_attention = train_attention
        # When True, the training branch of forward() only runs flash attention.
        self.skip_training_this_layer = not train_attention
        if rank == 0:  # multi-gpu
            if layer_idx == 0 and feature_map_kwargs is not None:
                for k, v in feature_map_kwargs.items():
                    print(f'-> {k}: {v}')
            if layer_idx == 0 and learned_kernel_kwargs is not None:
                for k, v in learned_kernel_kwargs.items():
                    print(f'-> {k}: {v}')

        # Apply the key feature map before (True) or after (False) repeat_kv for GQA.
        self.feature_map_before_repeat = kwargs['feature_map_before_repeat']

        self.init_weights_(base_attn)
        self.init_feature_map_(feature_map, feature_map_kwargs,
                               learned_kernel, learned_kernel_kwargs)

        self.kvquant_config = kvquant
        self.residual_length = kvquant['residual_length']
        self.q_group_size = kvquant['q_group_size']
        self.nbits = kvquant['nbits']

        self.scaling = base_attn.scaling
        self.is_causal = base_attn.is_causal

    def init_feature_map_(self,
                          feature_map: str,
                          feature_map_kwargs: Optional[dict],
                          learned_kernel: Optional[str] = None,
                          learned_kernel_kwargs: Optional[dict] = None):
        """
        Create ``self.feature_map_q`` and ``self.feature_map_k``.

        When ``learned_kernel`` is given, one MLP per side is built with
        ``init_learned_kernel`` and passed as ``mlp`` to ``init_feature_map``.
        Otherwise ``mlp=None`` is passed. NOTE(review): this assumes
        ``init_feature_map`` accepts ``mlp=None``; confirm in src.model.feature_map.
        When the feature map is applied before repeat_kv, the key MLP is sized
        for ``num_key_value_heads``; otherwise the key map is a deep copy of the
        query map.
        """
        self.fmap_gqa = False  # Turn True if specified below
        feature_map_kwargs = feature_map_kwargs or {}
        # Bound up front so the no-learned-kernel path does not hit an unbound name.
        mlp_learned_kernel_q = None
        mlp_learned_kernel_k = None
        if learned_kernel is not None:
            # Copy so the caller's dict is not mutated
            learned_kernel_kwargs = dict(learned_kernel_kwargs or {})
            learned_kernel_kwargs['num_heads'] = self.num_heads
            learned_kernel_kwargs['head_dim']  = self.head_dim
            learned_kernel_kwargs['dtype']     = self.q_proj.weight.dtype
            learned_kernel_kwargs['device']    = self.q_proj.weight.device
            # Create MLP
            mlp_learned_kernel_q = init_learned_kernel(learned_kernel, **learned_kernel_kwargs)
            if self.feature_map_before_repeat:
                # keys are featurized before repeat_kv -> one MLP head per KV head
                learned_kernel_kwargs['num_heads'] = self.num_key_value_heads
            mlp_learned_kernel_k = init_learned_kernel(learned_kernel, **learned_kernel_kwargs)

        # Add "activation"; see src.models.feature_map.py
        self.feature_map_q = init_feature_map(name=feature_map,
                                              mlp=mlp_learned_kernel_q,
                                              **feature_map_kwargs)
        if not self.feature_map_before_repeat:
            self.feature_map_k = copy.deepcopy(self.feature_map_q)
        else:
            self.feature_map_k = init_feature_map(name=feature_map,
                                                  mlp=mlp_learned_kernel_k,
                                                  **feature_map_kwargs)

    def init_weights_(self, base_attn: nn.Module):
        """
        Copy sizes and projection layers from the original softmax attention
        layer (``base_attn``). The modules are shared, not cloned.
        """
        # Make other attributes accessible
        self.attention_dropout = 0  # We don't use dropout
        self.hidden_size = base_attn.config.hidden_size
        self.num_heads = base_attn.config.num_attention_heads
        self.head_dim = base_attn.head_dim
        self.num_key_value_heads = base_attn.config.num_key_value_heads
        self.num_key_value_groups = base_attn.num_key_value_groups

        self.q_shape = [self.num_heads, self.head_dim]
        self.k_shape = [self.num_key_value_heads, self.head_dim]
        self.v_shape = [self.num_key_value_heads, self.head_dim]

        # Reuse original model projection layers
        self.q_proj = base_attn.q_proj
        self.k_proj = base_attn.k_proj
        self.v_proj = base_attn.v_proj
        self.o_proj = base_attn.o_proj
        if hasattr(base_attn, "k_norm"):
            # used in qwen3 models
            self.k_norm = base_attn.k_norm
            self.q_norm = base_attn.q_norm
        else:
            self.k_norm = None
            self.q_norm = None

        try:  # If wanting to use FA2 for ground-truth inference
            self._flash_attn_uses_top_left_mask = base_attn._flash_attn_uses_top_left_mask
        except AttributeError:
            pass

    def process_qkv(self,
                    hidden_states: torch.Tensor,
                    position_embeddings: tuple[torch.Tensor, torch.Tensor],):
        """
        Project hidden states to q, k, v and apply rotary embeddings.

        Returns q, k, v transposed to (batch_size, heads, seq_len, head_dim).
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # view() gives (batch_size, seq_len, heads, head_dim); transpose swaps seq and heads
        if self.k_norm is not None:
            # qwen3 model
            q = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
            k = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        else:
            q = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            k = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        v = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        return q, k, v

    def repeat_tensors(self, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Apply ``repeat_kv`` with ``num_key_value_groups`` to each tensor independently."""
        return tuple(repeat_kv(t, self.num_key_value_groups) for t in tensors)

    def _flash_attention(self, q, k, v, attention_mask, input_shape, flash_kwargs):
        """Run HF flash attention and reshape the output to ``(*input_shape, -1)``."""
        out, _ = flash_attention_forward(self, q, k, v, attention_mask=attention_mask, dropout=0.0,
                                         scaling=1.0 / math.sqrt(self.head_dim), **flash_kwargs)
        return out.reshape(*input_shape, -1).contiguous()

    def _linear_correction(self, q, ke_v_state, ke_state):
        """
        Return ``(a_ln, sum_ln)``, the numerator and denominator of the linear
        correction for decoding, both in float32.

        The cache sets the correction states to None once the sequence passes
        its ``max_linc_len`` (see KVLinc_CacheProcessor.post_update). In that
        case a zero correction is returned, with the same shapes the matmuls
        would produce, instead of crashing on ``None.float()``.
        """
        if ke_v_state is None or ke_state is None:
            lead_shape = q.shape[:-1]
            a_ln = torch.zeros((*lead_shape, self.head_dim), dtype=torch.float32, device=q.device)
            sum_ln = torch.zeros((*lead_shape, 1), dtype=torch.float32, device=q.device)
            return a_ln, sum_ln

        if self.feature_map_before_repeat:
            # states were built from per-KV-head features; repeat for GQA
            ke_v_state, ke_state = self.repeat_tensors(ke_v_state, ke_state)

        # query feature maps for linear correction
        f_q = self.feature_map_q(q)
        a_ln = torch.matmul(f_q.float(), ke_v_state.float())
        sum_ln = torch.matmul(f_q.float(), ke_state.float().transpose(-1, -2))
        return a_ln, sum_ln

    def forward(self,
                hidden_states: torch.Tensor,
                position_embeddings: tuple[torch.Tensor, torch.Tensor],
                attention_mask: Optional[torch.Tensor] = None,
                cache_position: Optional[torch.LongTensor] = None,
                past_key_value: Optional[Cache] = None,
                **kwargs,
               ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Attention forward pass; the signature is consistent with HuggingFace
        Transformers for easy use with their pretrained models.

        Args:
            hidden_states: input activations.
            position_embeddings: ``(cos, sin)`` rotary tables.
            attention_mask: forwarded to flash attention.
            cache_position: forwarded to the cache in ``cache_kwargs``.
            past_key_value: KVLinC cache whose ``update`` returns the nested
                tuple built by KVLinc_CacheProcessor.post_update, or None.
            **kwargs: forwarded to ``flash_attention_forward``.

        Returns:
            ``(y, attn_weights)``. ``y`` is the o_proj output. ``attn_weights`` is
            ``((a_pred, a_true), (None, None))`` when this layer is being
            trained, otherwise None.
        """
        input_shape = hidden_states.shape[:-1]
        softmax_scale = 1.0 / math.sqrt(self.head_dim)
        with torch.no_grad():
            q, k, v = self.process_qkv(hidden_states, position_embeddings)

        if not self.feature_map_before_repeat:
            k, v = self.repeat_tensors(k, v)

        if self.train_attention:
            if not self.skip_training_this_layer:
                # quantize kv cache using your favorite quantizer.
                k_q = fake_quant_channel_wise(k, self.q_group_size, self.nbits)

                # quantization error
                k_e = k - k_q

                # kernel functions
                f_q, f_k = self.feature_map_q(q), self.feature_map_k(k_e)

                if self.feature_map_before_repeat:
                    k, v, k_q, f_k = self.repeat_tensors(k, v, k_q, f_k)

                # 1. Compute "ground-truth" attention output and weights
                with torch.no_grad():
                    (a_true, y_true), a_qfp = softmax_attention_quant_fp(q, k, v, k_q, self.residual_length)
                    y_true = y_true.transpose(1, 2).contiguous().reshape(*input_shape, -1)
                    y_true = self.o_proj(y_true)

                # 2. Compute "predicted" attention weights (activation-checkpointed)
                a_pred = cp.checkpoint(kvlinc_attention, f_q, f_k, a_qfp, self.residual_length, q.shape[-2])
                attn_weights = ((a_pred, a_true), (None, None))
            else:
                if self.feature_map_before_repeat:
                    k, v = self.repeat_tensors(k, v)
                with torch.no_grad():
                    y_true = self._flash_attention(q, k, v, attention_mask, input_shape, kwargs)
                    y_true = self.o_proj(y_true)
                attn_weights = None
            return y_true, attn_weights

        attn_weights = None
        if past_key_value is None:
            y_true = self._flash_attention(q, k, v, attention_mask, input_shape, kwargs)
        else:
            cos, sin = position_embeddings
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin, "cos": cos, "cache_position": cache_position,
                "feature_map_k": self.feature_map_k,
            }
            (k_fp, v_fp), (k_q, k_sc, k_mn, v_q, v_sc, v_mn, ke_v_state, ke_state) = past_key_value.update(k, v, self.layer_idx, cache_kwargs)

            if q.shape[-2] == 1 and k_q is not None:
                # generating with quantized history: fused kvlinc attention
                a_ln, sum_ln = self._linear_correction(q, ke_v_state, ke_state)
                _y_true = kvlinc_attention_forward(q, k_fp, v_fp, k_q, k_sc, k_mn, v_q, v_sc, v_mn, a_ln, sum_ln,
                                                   self.q_group_size, self.nbits, softmax_scale=softmax_scale)
                y_true = _y_true.transpose(1, 2).contiguous().reshape(*input_shape, -1)
            else:
                # prefill, or generating while every cached token is still full precision
                y_true = self._flash_attention(q, k_fp, v_fp, attention_mask, input_shape, kwargs)

        y_true = self.o_proj(y_true)
        return y_true, attn_weights


class KVLinCAttentionCache(QuantizedCache):
    """
    Class for `past_key_values`.

    -> Stores a quantized KV cache plus the most recent tokens in full precision.
    -> Also stores the linear-correction states.
    -> Modified from transformers.cache_utils.QuantizedCache.
    """
    def __init__(self, **kwargs) -> None:
        # Skip QuantizedCache.__init__ on purpose: call the class that follows
        # QuantizedCache in the MRO and hand it the KVLinC processor class.
        processor_cls = KVLinc_CacheProcessor
        super(QuantizedCache, self).__init__(cache_processor=processor_cls, **kwargs)

class KVLinc_CacheProcessor(QuantizedCacheProcessor):
    """
    Quantized cache processor that uses the kvlinc quantization backend.

    Per layer it keeps the following. The layout is assumed to be
    (batch, heads, seq, head_dim), as produced by KVLinCAttention.process_qkv.
    -> ``cache.layers[i].keys/values``: the most recent full-precision tokens
       (the residual window; after every update it holds fewer than
       ``residual_length`` tokens);
    -> ``_quantized_keys``/``_k_sc``/``_k_mn`` and
       ``_quantized_values``/``_v_sc``/``_v_mn``: packed older tokens with their
       scale factors and zero points (quantization adopted from KIVI);
    -> ``_ke_v_state``/``_ke_state``: numerator / denominator states of the
       linear correction built from the key quantization error.
    """
    def __init__(
            self,
            cache: "Cache",
            backend: str = "kvlinc",
            nbits: int = 2,
            axis_key: int = 0,
            axis_value: int = 0,
            q_group_size: int = 64,
            residual_length: int = 128,
            compute_dtype: torch.dtype = torch.float16,
            device: str = "cpu",
    )-> None:
        """Initialize the kivi_kvlinc quantization processor.
        self.keys, self.values => full precision recent keys and values
        self._quantized_keys, self._quantized_values => quantized keys and values
        self._ke_v_state, self._ke_state => Linear correction terms

        Raises:
            ValueError: if ``backend`` is not ``"kvlinc"``.
        """
        super().__init__(
            cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype, device
        )
        if backend != "kvlinc":
            raise ValueError(f"KVLinc_CacheProcessor only supports `kvlinc` backend, but got {backend}")
        self._ke_v_state: List[torch.Tensor] = [] # kv state for k_e and v (numerator linear_correction)
        self._ke_state:  List[torch.Tensor] = [] # k_state for k_e (denominator linear_correction)
        self._k_sc: List[torch.Tensor] = [] # _k_sc scale factor for keys
        self._k_mn: List[torch.Tensor] = [] # _k_mn zero point for keys
        self._v_sc: List[torch.Tensor] = [] # _v_sc scale factor for values
        self._v_mn: List[torch.Tensor] = [] # _v_mn zero point for values
        self.linc_len = 0  # total tokens seen (tracked by layer 0)
        self.max_linc_len = 8400  # beyond this length the linear correction is dropped

    def post_update(
            self,
            cache: "Cache",
            k: torch.Tensor,
            v: torch.Tensor,
            layer_idx: int,
            cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply quantization after the residual (full-precision) cache update.

        Args:
            cache: owning cache; ``cache.layers[layer_idx].keys/values`` hold the residual window.
            k, v: content of the residual cache after being updated by DynamicLayer.
            layer_idx: layer being updated; layers must be visited in order.
            cache_kwargs: must contain ``"feature_map_k"``, the key feature map.

        Returns:
            ``((k_fp, v_fp), (k_q, k_sc, k_mn, v_q, v_sc, v_mn, ke_v_state, ke_state))``.
            The second tuple is all None during prefill and while nothing has been
            quantized. The nesting is a hack because ``@apply_processor`` in
            transformers.cache_utils expects 2 outputs.

        Raises:
            ValueError: if layers are skipped.
        """
        if len(cache) < layer_idx:
            raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.")

        feature_map_k = cache_kwargs["feature_map_k"]
        nothing_quantized = (None,) * 8

        if self._is_quantized_length_zero(layer_idx):
            # prefill: attention runs over the whole prompt in full precision
            if layer_idx == 0:  # only one layer updates linc_len
                self.linc_len = k.shape[-2]
            self._prefill_layer(cache, k, v, layer_idx, feature_map_k)
            return (k, v), nothing_quantized

        # generating
        if layer_idx == 0:  # only one layer updates linc_len
            self.linc_len += 1

        if self._quantized_keys[layer_idx] is None:
            quantized = nothing_quantized
        else:
            quantized = (
                self._quantized_keys[layer_idx], self._k_sc[layer_idx], self._k_mn[layer_idx],
                self._quantized_values[layer_idx], self._v_sc[layer_idx], self._v_mn[layer_idx],
                self._ke_v_state[layer_idx], self._ke_state[layer_idx],
            )

        if k.shape[-2] >= self.residual_length:
            # residual window is full: fold it into the quantized state
            self._flush_window(cache, k, v, layer_idx, feature_map_k)

        return (k, v), quantized

    def _prefill_layer(self, cache, k, v, layer_idx, feature_map_k):
        """Split the prompt into a full-precision tail and a quantized prefix for one layer."""
        n_tokens = k.shape[-2]
        n_recent = n_tokens % self.residual_length
        if n_recent == 0:
            # prompt is a whole number of windows: quantize everything, empty fp cache
            self._clear_residual(cache, layer_idx, k, v)
            self._append_quantized_prefix(k, v, feature_map_k)
        elif n_tokens < self.residual_length:
            # nothing to quantize. residual window not full yet.
            self._quantized_keys.append(None)
            self._k_sc.append(None)
            self._k_mn.append(None)

            self._quantized_values.append(None)
            self._v_sc.append(None)
            self._v_mn.append(None)

            # int 0 (not None) so the first flush during generation computes 0 + state
            self._ke_v_state.append(0)
            self._ke_state.append(0)
        else:
            # keys within window are in fp cache; quantize the rest
            cache.layers[layer_idx].keys = k[:, :, -n_recent:, :].contiguous()
            cache.layers[layer_idx].values = v[:, :, -n_recent:, :].contiguous()
            self._append_quantized_prefix(
                k[:, :, :-n_recent, :].contiguous(),
                v[:, :, :-n_recent, :].contiguous(),
                feature_map_k,
            )

    def _append_quantized_prefix(self, k_old, v_old, feature_map_k):
        """Quantize prompt tokens outside the residual window and append this layer's state."""
        self.erased_length = k_old.shape[-2]

        # keys are transposed so quantization runs along the token axis
        k_q, k_sc, k_mn, k_e, _ = triton_quantize_and_pack_along_last_dim(k_old.transpose(2, 3).contiguous(), self.q_group_size, self.nbits)
        v_q, v_sc, v_mn, _, v_dq = triton_quantize_and_pack_along_last_dim(v_old, self.q_group_size, self.nbits)

        self._quantized_keys.append(k_q)
        self._k_sc.append(k_sc)
        self._k_mn.append(k_mn)

        self._quantized_values.append(v_q)
        self._v_sc.append(v_sc)
        self._v_mn.append(v_mn)

        if k_old.shape[-2] < self.max_linc_len:
            # correct errors if within max_linc_len
            ke_v_state, ke_state = self._linc_states(feature_map_k, k_e, v_dq)
        else:
            # ignore after max_linc_len
            ke_v_state, ke_state = None, None
        self._ke_v_state.append(ke_v_state)
        self._ke_state.append(ke_state)

    def _flush_window(self, cache, k, v, layer_idx, feature_map_k):
        """Quantize a full residual window, merge it into the stored state and empty the fp cache."""
        k_q_new, k_sc_new, k_mn_new, k_e, _ = triton_quantize_and_pack_along_last_dim(k.transpose(2, 3), self.q_group_size, self.nbits)
        v_q_new, v_sc_new, v_mn_new, _, v_dq_new = triton_quantize_and_pack_along_last_dim(v, self.q_group_size, self.nbits)

        if self.linc_len < self.max_linc_len:
            # correct errors if within max_linc_len
            ke_v_new, ke_new = self._linc_states(feature_map_k, k_e, v_dq_new)
            # Out-of-place on purpose. post_update may already have handed the
            # previous state tensors to the attention layer for THIS step, where
            # the window tokens are still attended in full precision. An in-place
            # `+=` would leak the window's correction into that step.
            self._ke_v_state[layer_idx] = self._ke_v_state[layer_idx] + ke_v_new
            self._ke_state[layer_idx] = self._ke_state[layer_idx] + ke_new
        else:
            # ignore after max_linc_len
            self._ke_v_state[layer_idx] = None
            self._ke_state[layer_idx] = None

        if self._quantized_keys[layer_idx] is not None:
            # tokens are along dim -1 for the transposed keys and dim -2 for values
            self._quantized_keys[layer_idx] = torch.cat([self._quantized_keys[layer_idx], k_q_new], dim=-1)
            self._k_sc[layer_idx] = torch.cat([self._k_sc[layer_idx], k_sc_new], dim=-1)
            self._k_mn[layer_idx] = torch.cat([self._k_mn[layer_idx], k_mn_new], dim=-1)

            self._quantized_values[layer_idx] = torch.cat([self._quantized_values[layer_idx], v_q_new], dim=-2)
            self._v_sc[layer_idx] = torch.cat([self._v_sc[layer_idx], v_sc_new], dim=-2)
            self._v_mn[layer_idx] = torch.cat([self._v_mn[layer_idx], v_mn_new], dim=-2)
        else:
            self._quantized_keys[layer_idx] = k_q_new
            self._k_sc[layer_idx] = k_sc_new
            self._k_mn[layer_idx] = k_mn_new

            self._quantized_values[layer_idx] = v_q_new
            self._v_sc[layer_idx] = v_sc_new
            self._v_mn[layer_idx] = v_mn_new

        self._clear_residual(cache, layer_idx, k, v)
        # NOTE(review): erased_length is a single counter, but this line runs once
        # per layer, so it grows by num_layers * window per flush. Confirm this
        # matches how erased_length is read elsewhere.
        self.erased_length += k.shape[-2]

    @staticmethod
    def _linc_states(feature_map_k, k_e, v_dq):
        """Return (numerator, denominator) linear-correction states for one quantized block."""
        # get feature maps for keys (k_e is in the transposed key layout)
        f_k = feature_map_k(k_e.transpose(2, 3))
        ke_v_state = f_k.transpose(-1, -2) @ v_dq  # [B, H, F, D]
        ke_state = f_k.sum(dim=-2, keepdim=True)  # b, h, 1, f; note the 1
        return ke_v_state, ke_state

    @staticmethod
    def _clear_residual(cache, layer_idx, k, v):
        """Reset one layer's full-precision residual cache to empty tensors."""
        cache.layers[layer_idx].keys = torch.zeros(0, dtype=k.dtype, device=k.device)
        cache.layers[layer_idx].values = torch.zeros(0, dtype=v.dtype, device=v.device)





                    


        
