
from typing import Any, Dict, List, Optional, Tuple, Union

import torch

from typing import List, Optional, Tuple

from transformers import DynamicCache
import numpy as np

# NOTE(review): importing this module turns on autograd anomaly detection for the
# entire process. PyTorch documents this mode as a debugging aid that slows down
# execution; confirm it is meant to stay enabled outside of debugging sessions.
torch.autograd.set_detect_anomaly(True)

from transformers.utils import (
    logging,
)


# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)

# Shared argument documentation in the style of the Hugging Face Llama model
# docstrings. Nothing in this part of the file references it; presumably it is
# consumed elsewhere (e.g. by a docstring decorator) -- TODO confirm or remove.
# NOTE(review): the "1 indicates the head is **not masked**" / "0 indicates the
# head is **masked**" bullets under `attention_mask` appear to be left over from
# a `head_mask` entry and do not describe `attention_mask`.
LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance, see our
            [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""


class DynamicCache_LA_SA_back(DynamicCache):
    """`DynamicCache` that folds evicted tokens into running linear-attention statistics.

    Next to the usual ``key_cache`` / ``value_cache`` every layer keeps:

    * ``L_cache[i]`` -- running sum of ``k^T @ v`` over evicted tokens, created as
      ``(1, H, D, D)`` zeros and broadcast to the batch size on the first fold;
    * ``k_sum[i]`` / ``v_sum[i]`` -- running sums of evicted keys / values,
      created as ``(1, H, 1, D)`` zeros;
    * ``lin_cached[i]`` -- ``True`` once at least one eviction was folded in.

    In this variant the caller decides which tokens are kept and which are evicted
    and passes both index sets to :meth:`update_index`.
    """

    def __init__(self, layer_num) -> None:
        super().__init__()
        self.L_cache: List[torch.Tensor] = []
        self.k_sum: List[torch.Tensor] = []
        self.v_sum: List[torch.Tensor] = []
        self.lin_cached: List[bool] = []
        self.layer_num = layer_num
        self.prefill_stage: List[bool] = [True] * layer_num
        self.attn_score: List[torch.Tensor] = []

    def linear_cache(self, layer_idx: int):
        """Return ``(L, k_sum, v_sum, seen_tokens, keys, values)`` for ``layer_idx``."""
        return (
            self.L_cache[layer_idx],
            self.k_sum[layer_idx],
            self.v_sum[layer_idx],
            self._seen_tokens,
            self.key_cache[layer_idx],
            self.value_cache[layer_idx],
        )

    @staticmethod
    def _zero_linear_state(num_heads, head_dim, device, dtype):
        """Fresh zero ``(L, k_sum, v_sum)`` tensors for one layer (batch dim 1, broadcast later)."""
        return (
            torch.zeros(1, num_heads, head_dim, head_dim, device=device, dtype=dtype),
            torch.zeros(1, num_heads, 1, head_dim, device=device, dtype=dtype),
            torch.zeros(1, num_heads, 1, head_dim, device=device, dtype=dtype),
        )

    def _append_linear_state(self, num_heads, head_dim, device, dtype):
        """Append a zeroed linear state for a new layer slot."""
        lin, k_sum, v_sum = self._zero_linear_state(num_heads, head_dim, device, dtype)
        self.L_cache.append(lin)
        self.k_sum.append(k_sum)
        self.v_sum.append(v_sum)
        self.lin_cached.append(False)

    def _reset_linear_state(self, layer_idx, num_heads, head_dim, device, dtype):
        """Overwrite the linear state of an existing layer slot with zeros."""
        lin, k_sum, v_sum = self._zero_linear_state(num_heads, head_dim, device, dtype)
        self.L_cache[layer_idx] = lin
        self.k_sum[layer_idx] = k_sum
        self.v_sum[layer_idx] = v_sum
        self.lin_cached[layer_idx] = False

    def _split_at_window(self, layer_idx, window_size):
        """Split layer ``layer_idx``'s K/V into (older part, last ``window_size`` tokens).

        Returns ``(old_k, recent_k, old_v, recent_v)``. An explicit split point is
        used instead of ``[:-window_size]`` so that ``window_size == 0`` means
        "no protected window" rather than "the whole cache is the window".
        """
        keys = self.key_cache[layer_idx]
        values = self.value_cache[layer_idx]
        k_split = max(keys.shape[-2] - window_size, 0)
        v_split = max(values.shape[-2] - window_size, 0)
        return (
            keys[:, :, :k_split, :],
            keys[:, :, k_split:, :],
            values[:, :, :v_split, :],
            values[:, :, v_split:, :],
        )

    def update(
            self,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            layer_idx: int,
            cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache, shape `(batch, heads, new_tokens, head_dim)`.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Unused; kept for interface compatibility with `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.

        A layer seen for the first time (or filling a previously skipped slot) gets
        a zeroed linear state.
        """
        if key_states is None:
            # Nothing new to store; return what is already cached for this layer.
            return self.key_cache[layer_idx], self.value_cache[layer_idx]

        _, num_heads, new_tokens, head_dim = key_states.shape
        device, dtype = key_states.device, key_states.dtype

        # The global token counter is advanced once per forward pass (on layer 0).
        if layer_idx == 0:
            self._seen_tokens += new_tokens

        if len(self.key_cache) <= layer_idx:
            # Some layers may have been skipped: pad them with empty placeholders
            # (plus a zero linear state) so list positions match layer indices.
            for _ in range(len(self.key_cache), layer_idx):
                self.key_cache.append([])
                self.value_cache.append([])
                self._append_linear_state(num_heads, head_dim, device, dtype)
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            self._append_linear_state(num_heads, head_dim, device, dtype)
        elif len(self.key_cache[layer_idx]) == 0:
            # Fills a previously skipped layer; `len` is used because the
            # placeholder is a plain list, not a tensor.
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
            self._reset_linear_state(layer_idx, num_heads, head_dim, device, dtype)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def update_index(
            self,
            layer_idx: int,
            selected_indices: torch.Tensor,
            mink_indices: torch.Tensor,
            window_size: int,
    ):
        """Fold evicted tokens of layer ``layer_idx`` into the linear state and compact it.

        The last ``window_size`` tokens are always kept verbatim. Both index tensors
        address positions in the part of the cache *before* that window and are
        passed directly to ``torch.gather(..., dim=-2)`` on the ``(B, H, T', D)``
        older part, so they must be 4-D (typically ``(B, H, K, D)``):

        * ``mink_indices`` -- tokens added to ``L_cache`` / ``k_sum`` / ``v_sum``;
        * ``selected_indices`` -- tokens kept in the cache (in that order), which
          are followed by the recent window.
        """
        old_k, recent_k, old_v, recent_v = self._split_at_window(layer_idx, window_size)

        evicted_k = torch.gather(old_k, dim=-2, index=mink_indices)
        evicted_v = torch.gather(old_v, dim=-2, index=mink_indices)
        # Accumulate sum_t k_t^T v_t and the key / value sums of the evicted tokens.
        self.L_cache[layer_idx] = self.L_cache[layer_idx] + torch.matmul(evicted_k.transpose(2, 3), evicted_v)
        self.k_sum[layer_idx] = self.k_sum[layer_idx] + torch.sum(evicted_k, dim=-2, keepdim=True)
        self.v_sum[layer_idx] = self.v_sum[layer_idx] + torch.sum(evicted_v, dim=-2, keepdim=True)

        self.key_cache[layer_idx] = torch.cat(
            [torch.gather(old_k, dim=-2, index=selected_indices), recent_k], dim=-2)
        self.value_cache[layer_idx] = torch.cat(
            [torch.gather(old_v, dim=-2, index=selected_indices), recent_v], dim=-2)

        self.lin_cached[layer_idx] = True

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
                          num_hidden_layers=None) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
        backward compatibility.

        When ``num_hidden_layers`` is omitted it is inferred from the number of
        legacy layers (0 if ``past_key_values`` is None) instead of failing on
        ``[True] * None`` in ``__init__``.
        """
        if num_hidden_layers is None:
            num_hidden_layers = len(past_key_values) if past_key_values is not None else 0
        cache = cls(num_hidden_layers)
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache


class DynamicCache_LA_SA(DynamicCache):
    """Budgeted `DynamicCache` that folds low-scoring tokens into a linear state.

    Per layer it keeps the usual ``key_cache`` / ``value_cache`` plus a running
    linear state for tokens evicted from it: ``L_cache`` (sum of ``k^T @ v``),
    ``k_sum`` and ``v_sum``, all starting as zeros with batch dim 1.

    Eviction in :meth:`update_index` is driven by lists the caller fills in:
    ``layer_budget[i]`` (maximum tokens kept), ``window_size[i]`` (most recent
    tokens that are never evicted) and ``evict_scores[i]`` (one score per older
    token, higher = more worth keeping).
    """

    def __init__(self, layer_num) -> None:
        super().__init__()
        self.average_cache_size = 128
        self.min_window = 32
        self.layer_num = layer_num
        self.L_cache: List[torch.Tensor] = []
        self.k_sum: List[torch.Tensor] = []
        self.v_sum: List[torch.Tensor] = []
        self.lin_cached: List[bool] = []
        self.prefill_stage: List[bool] = [True] * layer_num
        self.pref_scores = []
        self.evict_scores = []
        self.layer_budget = []
        self.window_size = []

    def linear_cache(self, layer_idx: int):
        """Return ``(L, k_sum, v_sum, seen_tokens, keys, values)`` for ``layer_idx``."""
        return (
            self.L_cache[layer_idx],
            self.k_sum[layer_idx],
            self.v_sum[layer_idx],
            self._seen_tokens,
            self.key_cache[layer_idx],
            self.value_cache[layer_idx],
        )

    def budget_update(self, layer_idx):
        """Cross-layer budget redistribution hook -- currently a no-op.

        Redistributing ``layer_budget`` in proportion to ``pref_scores`` is
        disabled in this class: the method leaves all state untouched and
        returns ``None``.
        """
        return None

    @staticmethod
    def _zero_linear_state(num_heads, head_dim, device, dtype):
        """Fresh zero ``(L, k_sum, v_sum)`` tensors for one layer (batch dim 1, broadcast later)."""
        return (
            torch.zeros(1, num_heads, head_dim, head_dim, device=device, dtype=dtype),
            torch.zeros(1, num_heads, 1, head_dim, device=device, dtype=dtype),
            torch.zeros(1, num_heads, 1, head_dim, device=device, dtype=dtype),
        )

    def _append_linear_state(self, num_heads, head_dim, device, dtype):
        """Append a zeroed linear state for a new layer slot."""
        lin, k_sum, v_sum = self._zero_linear_state(num_heads, head_dim, device, dtype)
        self.L_cache.append(lin)
        self.k_sum.append(k_sum)
        self.v_sum.append(v_sum)
        self.lin_cached.append(False)

    def _reset_linear_state(self, layer_idx, num_heads, head_dim, device, dtype):
        """Overwrite the linear state of an existing layer slot with zeros."""
        lin, k_sum, v_sum = self._zero_linear_state(num_heads, head_dim, device, dtype)
        self.L_cache[layer_idx] = lin
        self.k_sum[layer_idx] = k_sum
        self.v_sum[layer_idx] = v_sum
        self.lin_cached[layer_idx] = False

    def _split_at_window(self, layer_idx, window_size):
        """Split layer ``layer_idx``'s K/V into (older part, last ``window_size`` tokens).

        Returns ``(old_k, recent_k, old_v, recent_v)``. An explicit split point is
        used instead of ``[:-window_size]`` so that ``window_size == 0`` means
        "no protected window" rather than "the whole cache is the window".
        """
        keys = self.key_cache[layer_idx]
        values = self.value_cache[layer_idx]
        k_split = max(keys.shape[-2] - window_size, 0)
        v_split = max(values.shape[-2] - window_size, 0)
        return (
            keys[:, :, :k_split, :],
            keys[:, :, k_split:, :],
            values[:, :, :v_split, :],
            values[:, :, v_split:, :],
        )

    @staticmethod
    def _partition_by_score(scores, n_keep, n_evict):
        """Return ``(keep_idx, evict_idx)`` positions along the last axis of ``scores``.

        ``keep_idx`` holds the ``n_keep`` highest-scoring positions (descending),
        ``evict_idx`` the ``n_evict`` lowest-scoring ones. Both are slices of one
        stable sort, so -- unlike two independent ``topk`` calls -- tied scores can
        never place a position in both sets when ``n_keep + n_evict`` equals the
        axis length.

        Raises:
            RuntimeError: if either count is negative or exceeds the axis length
                (the same situation in which ``torch.topk`` fails).
        """
        total = scores.shape[-1]
        if not (0 <= n_keep <= total and 0 <= n_evict <= total):
            raise RuntimeError(
                f"cannot keep {n_keep} / evict {n_evict} of {total} scored tokens")
        order = torch.sort(scores, dim=-1, descending=True, stable=True).indices
        return order[..., :n_keep], order[..., total - n_evict:]

    def _fold_into_linear(self, layer_idx, evict_idx, window_size):
        """Add the older tokens at ``evict_idx`` (3-D positions) to the linear state."""
        old_k, _, old_v, _ = self._split_at_window(layer_idx, window_size)
        idx = evict_idx.unsqueeze(-1).expand(-1, -1, -1, old_k.shape[-1])
        evicted_k = torch.gather(old_k, dim=-2, index=idx)
        evicted_v = torch.gather(old_v, dim=-2, index=idx)
        self.L_cache[layer_idx] = self.L_cache[layer_idx] + torch.matmul(evicted_k.transpose(2, 3), evicted_v)
        self.k_sum[layer_idx] = self.k_sum[layer_idx] + torch.sum(evicted_k, dim=-2, keepdim=True)
        self.v_sum[layer_idx] = self.v_sum[layer_idx] + torch.sum(evicted_v, dim=-2, keepdim=True)

    def _compact(self, layer_idx, keep_idx, window_size):
        """Keep only the older tokens at ``keep_idx`` (in that order), then the recent window."""
        old_k, recent_k, old_v, recent_v = self._split_at_window(layer_idx, window_size)
        # The key feature dim is used for both gathers, as the original code did.
        idx = keep_idx.unsqueeze(-1).expand(-1, -1, -1, old_k.shape[-1])
        self.key_cache[layer_idx] = torch.cat([torch.gather(old_k, dim=-2, index=idx), recent_k], dim=-2)
        self.value_cache[layer_idx] = torch.cat([torch.gather(old_v, dim=-2, index=idx), recent_v], dim=-2)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache, shape `(batch, heads, new_tokens, head_dim)`.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Unused; kept for interface compatibility with `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        if key_states is None:
            # Nothing new to store; return what is already cached for this layer.
            return self.key_cache[layer_idx], self.value_cache[layer_idx]

        _, num_heads, new_tokens, head_dim = key_states.shape
        device, dtype = key_states.device, key_states.dtype

        # The global token counter is advanced once per forward pass (on layer 0).
        if layer_idx == 0:
            self._seen_tokens += new_tokens

        if len(self.key_cache) <= layer_idx:
            # Some layers may have been skipped: pad them with empty placeholders
            # (plus a zero linear state) so list positions match layer indices.
            for _ in range(len(self.key_cache), layer_idx):
                self.key_cache.append([])
                self.value_cache.append([])
                self._append_linear_state(num_heads, head_dim, device, dtype)
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            self._append_linear_state(num_heads, head_dim, device, dtype)
        elif len(self.key_cache[layer_idx]) == 0:
            # Fills a previously skipped layer; `len` is used because the
            # placeholder is a plain list, not a tensor.
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
            self._reset_linear_state(layer_idx, num_heads, head_dim, device, dtype)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def update_index(
            self,
            layer_idx: int,
    ):
        """Shrink layer ``layer_idx`` to ``layer_budget[layer_idx]`` tokens if it is over budget.

        The last ``window_size[layer_idx]`` tokens are always kept. The older tokens
        are ranked by ``evict_scores[layer_idx]`` (expected 3-D, one score per older
        token): the top ``budget - window`` stay in the cache, and the lowest
        ``seq_len - budget`` are added to ``L_cache`` / ``k_sum`` / ``v_sum``.
        ``evict_scores[layer_idx]`` is reduced to the kept tokens' scores, in the
        same order as the compacted cache.
        """
        seq_len = self.key_cache[layer_idx].shape[-2]
        budget = self.layer_budget[layer_idx]
        if seq_len <= budget:
            return

        window = self.window_size[layer_idx]
        scores = self.evict_scores[layer_idx]
        keep_idx, evict_idx = self._partition_by_score(scores, budget - window, seq_len - budget)
        self.evict_scores[layer_idx] = torch.gather(scores, dim=-1, index=keep_idx)

        # Fold first: it reads the pre-compaction cache.
        self._fold_into_linear(layer_idx, evict_idx, window)
        self._compact(layer_idx, keep_idx, window)
        self.lin_cached[layer_idx] = True

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers=None) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
        backward compatibility.

        When ``num_hidden_layers`` is omitted it is inferred from the number of
        legacy layers (0 if ``past_key_values`` is None) instead of failing on
        ``[True] * None`` in ``__init__``.
        """
        if num_hidden_layers is None:
            num_hidden_layers = len(past_key_values) if past_key_values is not None else 0
        cache = cls(num_hidden_layers)
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache


class DynamicCache_LA_SA_1(DynamicCache):
    """Budgeted hybrid sparse / linear `DynamicCache` with a prefill-specific eviction path.

    Per layer it keeps the usual ``key_cache`` / ``value_cache`` plus a linear
    state for evicted tokens: ``L_cache`` (sum of ``k^T @ v``), ``k_sum`` and
    ``v_sum``, all starting as zeros with batch dim 1. It also keeps a per-layer
    token counter ``seen_tokens_layerwise``.

    Eviction is driven by lists the caller fills in: ``layer_budget[i]``,
    ``window_size[i]`` (most recent tokens never evicted), ``linear_cache_size[i]``
    (extra tokens tolerated during prefill), ``evict_scores[i]`` (one score per
    older token, higher = more worth keeping) and ``pref_scores[i]`` (used by
    :meth:`budget_update` to split the total budget across layers).
    """

    def __init__(self, layer_num) -> None:
        super().__init__()
        self.average_cache_size = 128
        self.min_window = 32
        self.layer_num = layer_num
        self.L_cache: List[torch.Tensor] = []
        self.k_sum: List[torch.Tensor] = []
        self.v_sum: List[torch.Tensor] = []
        self.lin_cached: List[bool] = []
        self.prefill_stage: List[bool] = [True] * layer_num
        self.pref_scores = []
        self.evict_scores = []
        self.layer_budget = []
        self.window_size = []
        self.first_token = True
        self.linear_cache_size = []
        self.seen_tokens_layerwise = [0] * layer_num

    def linear_cache(self, layer_idx: int):
        """Return ``(L, k_sum, v_sum, seen_tokens_of_layer, keys, values)`` for ``layer_idx``."""
        return (
            self.L_cache[layer_idx],
            self.k_sum[layer_idx],
            self.v_sum[layer_idx],
            self.seen_tokens_layerwise[layer_idx],
            self.key_cache[layer_idx],
            self.value_cache[layer_idx],
        )

    def budget_update(self, layer_idx, seq_len):
        """Redistribute the total token budget across layers by ``pref_scores``.

        The pool to distribute is ``(layer_budget[layer_idx] - window_size[layer_idx]
        - min_window) * layer_num``. Layer ``i`` receives the share
        ``pref_scores[i] / sum(pref_scores)`` of that pool (uniform shares when the
        scores sum to zero), truncated to an int, plus ``window_size[i] +
        min_window + 1`` reserved tokens.

        No layer may exceed ``seq_len``. The surplus of capped layers is spread
        evenly, with the remainder going to the lowest indices, over layers still
        below ``seq_len``. This repeats until nothing overflows or every layer is full.

        Returns:
            The new ``layer_budget`` list (also stored on ``self``).
        """
        num_layers = len(self.layer_budget)
        total_budget = (self.layer_budget[layer_idx] - self.window_size[layer_idx]
                        - self.min_window) * self.layer_num

        # float() accepts Python numbers, NumPy scalars and 0-d tensors alike; a
        # plain Python list cannot be divided by a scalar directly.
        scores = [float(score) for score in self.pref_scores]
        score_total = sum(scores)
        if score_total == 0:
            shares = np.full(num_layers, 1.0 / num_layers)
        else:
            shares = np.asarray(scores, dtype=float) / score_total

        budgets = np.array(
            [int(shares[i] * total_budget) + self.window_size[i] + self.min_window + 1
             for i in range(num_layers)],
            dtype=int,
        )

        # Cap every layer at seq_len and hand the surplus to layers with room left.
        # Each round either ends with no overflow or saturates at least one more
        # layer, so the loop runs at most `num_layers` times.
        surplus = int(np.maximum(budgets - seq_len, 0).sum())
        budgets = np.minimum(budgets, seq_len)
        while surplus > 0:
            has_room = budgets < seq_len
            n_room = int(has_room.sum())
            if n_room == 0:
                break
            share, remainder = divmod(surplus, n_room)
            budgets[has_room] += share
            budgets[np.where(has_room)[0][:remainder]] += 1
            surplus = int(np.maximum(budgets - seq_len, 0).sum())
            budgets = np.minimum(budgets, seq_len)

        self.layer_budget = budgets.tolist()
        return self.layer_budget

    @staticmethod
    def _zero_linear_state(num_heads, head_dim, device, dtype):
        """Fresh zero ``(L, k_sum, v_sum)`` tensors for one layer (batch dim 1, broadcast later)."""
        return (
            torch.zeros(1, num_heads, head_dim, head_dim, device=device, dtype=dtype),
            torch.zeros(1, num_heads, 1, head_dim, device=device, dtype=dtype),
            torch.zeros(1, num_heads, 1, head_dim, device=device, dtype=dtype),
        )

    def _append_linear_state(self, num_heads, head_dim, device, dtype):
        """Append a zeroed linear state for a new layer slot."""
        lin, k_sum, v_sum = self._zero_linear_state(num_heads, head_dim, device, dtype)
        self.L_cache.append(lin)
        self.k_sum.append(k_sum)
        self.v_sum.append(v_sum)
        self.lin_cached.append(False)

    def _reset_linear_state(self, layer_idx, num_heads, head_dim, device, dtype):
        """Overwrite the linear state of an existing layer slot with zeros."""
        lin, k_sum, v_sum = self._zero_linear_state(num_heads, head_dim, device, dtype)
        self.L_cache[layer_idx] = lin
        self.k_sum[layer_idx] = k_sum
        self.v_sum[layer_idx] = v_sum
        self.lin_cached[layer_idx] = False

    def _split_at_window(self, layer_idx, window_size):
        """Split layer ``layer_idx``'s K/V into (older part, last ``window_size`` tokens).

        Returns ``(old_k, recent_k, old_v, recent_v)``. An explicit split point is
        used instead of ``[:-window_size]`` so that ``window_size == 0`` means
        "no protected window" rather than "the whole cache is the window".
        """
        keys = self.key_cache[layer_idx]
        values = self.value_cache[layer_idx]
        k_split = max(keys.shape[-2] - window_size, 0)
        v_split = max(values.shape[-2] - window_size, 0)
        return (
            keys[:, :, :k_split, :],
            keys[:, :, k_split:, :],
            values[:, :, :v_split, :],
            values[:, :, v_split:, :],
        )

    @staticmethod
    def _partition_by_score(scores, n_keep, n_evict):
        """Return ``(keep_idx, evict_idx)`` positions along the last axis of ``scores``.

        ``keep_idx`` holds the ``n_keep`` highest-scoring positions (descending),
        ``evict_idx`` the ``n_evict`` lowest-scoring ones. Both are slices of one
        stable sort, so -- unlike two independent ``topk`` calls -- tied scores can
        never place a position in both sets when ``n_keep + n_evict`` equals the
        axis length.

        Raises:
            RuntimeError: if either count is negative or exceeds the axis length
                (the same situation in which ``torch.topk`` fails).
        """
        total = scores.shape[-1]
        if not (0 <= n_keep <= total and 0 <= n_evict <= total):
            raise RuntimeError(
                f"cannot keep {n_keep} / evict {n_evict} of {total} scored tokens")
        order = torch.sort(scores, dim=-1, descending=True, stable=True).indices
        return order[..., :n_keep], order[..., total - n_evict:]

    def _fold_into_linear(self, layer_idx, evict_idx, window_size, accumulate):
        """Add (or, if not ``accumulate``, assign) the older tokens at ``evict_idx`` to the linear state."""
        old_k, _, old_v, _ = self._split_at_window(layer_idx, window_size)
        idx = evict_idx.unsqueeze(-1).expand(-1, -1, -1, old_k.shape[-1])
        evicted_k = torch.gather(old_k, dim=-2, index=idx)
        evicted_v = torch.gather(old_v, dim=-2, index=idx)
        kv = torch.matmul(evicted_k.transpose(2, 3), evicted_v)
        k_sum = torch.sum(evicted_k, dim=-2, keepdim=True)
        v_sum = torch.sum(evicted_v, dim=-2, keepdim=True)
        if accumulate:
            kv = self.L_cache[layer_idx] + kv
            k_sum = self.k_sum[layer_idx] + k_sum
            v_sum = self.v_sum[layer_idx] + v_sum
        self.L_cache[layer_idx] = kv
        self.k_sum[layer_idx] = k_sum
        self.v_sum[layer_idx] = v_sum

    def _compact(self, layer_idx, keep_idx, window_size):
        """Keep only the older tokens at ``keep_idx`` (in that order), then the recent window."""
        old_k, recent_k, old_v, recent_v = self._split_at_window(layer_idx, window_size)
        # The key feature dim is used for both gathers, as the original code did.
        idx = keep_idx.unsqueeze(-1).expand(-1, -1, -1, old_k.shape[-1])
        self.key_cache[layer_idx] = torch.cat([torch.gather(old_k, dim=-2, index=idx), recent_k], dim=-2)
        self.value_cache[layer_idx] = torch.cat([torch.gather(old_v, dim=-2, index=idx), recent_v], dim=-2)

    def update(
            self,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            layer_idx: int,
            cache_kwargs: Optional[Dict[str, Any]] = None,

    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache, shape `(batch, heads, new_tokens, head_dim)`.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Unused; kept for interface compatibility with `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.

        Besides the global ``_seen_tokens`` counter, advanced on layer 0 only, the
        per-layer ``seen_tokens_layerwise[layer_idx]`` is advanced on every call.
        """
        if key_states is None:
            # Nothing new to store; return what is already cached for this layer.
            return self.key_cache[layer_idx], self.value_cache[layer_idx]

        _, num_heads, new_tokens, head_dim = key_states.shape
        device, dtype = key_states.device, key_states.dtype

        if layer_idx == 0:
            self._seen_tokens += new_tokens

        # Grow the per-layer counters if the cache was built with fewer layers
        # (e.g. through `from_legacy_cache` without `num_hidden_layers`).
        missing = layer_idx + 1 - len(self.seen_tokens_layerwise)
        if missing > 0:
            self.seen_tokens_layerwise.extend([0] * missing)
        self.seen_tokens_layerwise[layer_idx] += new_tokens

        if len(self.key_cache) <= layer_idx:
            # Some layers may have been skipped: pad them with empty placeholders
            # (plus a zero linear state) so list positions match layer indices.
            for _ in range(len(self.key_cache), layer_idx):
                self.key_cache.append([])
                self.value_cache.append([])
                self._append_linear_state(num_heads, head_dim, device, dtype)
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            self._append_linear_state(num_heads, head_dim, device, dtype)
        elif len(self.key_cache[layer_idx]) == 0:
            # Fills a previously skipped layer; `len` is used because the
            # placeholder is a plain list, not a tensor.
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
            self._reset_linear_state(layer_idx, num_heads, head_dim, device, dtype)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def update_index_prefill(
            self,
            layer_idx: int,
            curr_layer_idx: int
    ):
        """Post-prefill eviction for layer ``layer_idx``.

        Step 1 (every call) runs if the layer holds more than
        ``layer_budget + linear_cache_size`` tokens. It keeps the top
        ``linear_cache_size + layer_budget - window`` older tokens by
        ``evict_scores`` plus the recent window and drops the rest outright (they
        are NOT added to the linear state). ``evict_scores`` is reduced to the
        kept tokens.

        Step 2 runs only when ``curr_layer_idx`` is the last layer. It records the
        layer's length in ``seen_tokens_layerwise``. If that length still exceeds
        ``layer_budget``, the ``seq_len - budget`` lowest-scoring older tokens
        *replace* the linear state. The top ``budget - window`` older tokens plus
        the recent window are kept, and ``evict_scores`` is left unchanged.
        """
        seq_len = self.key_cache[layer_idx].shape[-2]
        budget = self.layer_budget[layer_idx]

        if seq_len > budget + self.linear_cache_size[layer_idx]:
            window = self.window_size[layer_idx]
            scores = self.evict_scores[layer_idx]
            keep_idx, _ = self._partition_by_score(
                scores, self.linear_cache_size[layer_idx] + budget - window, 0)
            self.evict_scores[layer_idx] = torch.gather(scores, dim=-1, index=keep_idx)
            self._compact(layer_idx, keep_idx, window)

        if curr_layer_idx != self.layer_num - 1:
            return

        seq_len = self.key_cache[layer_idx].shape[-2]
        self.seen_tokens_layerwise[layer_idx] = seq_len
        if seq_len > budget:
            window = self.window_size[layer_idx]
            keep_idx, evict_idx = self._partition_by_score(
                self.evict_scores[layer_idx], budget - window, seq_len - budget)
            # Fold first: it reads the pre-compaction cache. Prefill overwrites
            # the linear state rather than accumulating into it.
            self._fold_into_linear(layer_idx, evict_idx, window, accumulate=False)
            self._compact(layer_idx, keep_idx, window)
            self.lin_cached[layer_idx] = True

    def update_index(
            self,
            layer_idx: int,
    ):
        """Shrink layer ``layer_idx`` to ``layer_budget[layer_idx]`` tokens if it is over budget.

        The last ``window_size[layer_idx]`` tokens are always kept. The older tokens
        are ranked by ``evict_scores[layer_idx]`` (expected 3-D, one score per older
        token): the top ``budget - window`` stay in the cache, and the lowest
        ``seq_len - budget`` are accumulated into ``L_cache`` / ``k_sum`` / ``v_sum``.
        ``evict_scores`` is not modified here.
        """
        seq_len = self.key_cache[layer_idx].shape[-2]
        budget = self.layer_budget[layer_idx]
        if seq_len <= budget:
            return

        window = self.window_size[layer_idx]
        keep_idx, evict_idx = self._partition_by_score(
            self.evict_scores[layer_idx], budget - window, seq_len - budget)
        # Fold first: it reads the pre-compaction cache.
        self._fold_into_linear(layer_idx, evict_idx, window, accumulate=True)
        self._compact(layer_idx, keep_idx, window)
        self.lin_cached[layer_idx] = True

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
                          num_hidden_layers=None) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
        backward compatibility.

        When ``num_hidden_layers`` is omitted it is inferred from the number of
        legacy layers (0 if ``past_key_values`` is None) instead of failing on
        ``[True] * None`` in ``__init__``.
        """
        if num_hidden_layers is None:
            num_hidden_layers = len(past_key_values) if past_key_values is not None else 0
        cache = cls(num_hidden_layers)
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache

class DynamicCache_LA_SA_2(DynamicCache):
    """KV cache that keeps a bounded softmax cache per layer plus a linear summary.

    Each layer keeps a budgeted set of high-scoring history tokens, chosen by
    ``evict_scores``, followed by a recent window of ``window_size[layer_idx]``
    tokens. Tokens evicted by :meth:`update_index` are not simply dropped. They
    are folded into running per-layer sums: ``L_cache`` accumulates ``k^T v``,
    and ``k_sum`` / ``v_sum`` accumulate ``k`` and ``v``. Presumably these are
    consumed as a linear-attention state by the attention module (TODO confirm).

    ``layer_budget``, ``window_size``, ``evict_scores``, ``pref_scores`` and
    ``linear_cache_size`` start empty. NOTE(review): they are expected to be
    populated by the caller before the budget/eviction methods run.
    """

    def __init__(self, layer_num) -> None:
        super().__init__()
        self.average_cache_size = 128
        self.min_window = 0
        self.layer_num = layer_num
        # Per-layer running sums of evicted tokens: k^T v, sum(k), sum(v).
        self.L_cache: List[torch.Tensor] = []
        self.k_sum: List[torch.Tensor] = []
        self.v_sum: List[torch.Tensor] = []
        # Set to True by update_index once a layer has folded evicted tokens into its sums.
        self.lin_cached: List[bool] = []
        self.prefill_stage: List[bool] = [True] * layer_num
        # Per-layer preference scores that drive budget_update (filled by the caller).
        self.pref_scores = []
        # Per-layer token scores; higher means more worth keeping in the softmax cache.
        self.evict_scores = []
        # Per-layer total softmax-cache budget (kept history + recent window).
        self.layer_budget = []
        # Per-layer number of most recent tokens that are never evicted.
        self.window_size = []
        self.first_token = True
        self.linear_cache_size = []
        # Tokens seen per layer; overwritten with the pruned length at the end of prefill.
        self.seen_tokens_layerwise = [0] * layer_num

    def linear_cache(self, layer_idx: int):
        """Return every piece of cached state for ``layer_idx``.

        Returns:
            ``(L_cache, k_sum, v_sum, seen_tokens, key_cache, value_cache)`` for that layer.
        """
        return (
            self.L_cache[layer_idx],
            self.k_sum[layer_idx],
            self.v_sum[layer_idx],
            self.seen_tokens_layerwise[layer_idx],
            self.key_cache[layer_idx],
            self.value_cache[layer_idx],
        )

    def budget_update(self, layer_idx, seq_len):
        """Redistribute per-layer KV budgets in proportion to ``pref_scores``.

        A shared pool is formed from the freely assignable part of layer
        ``layer_idx``'s budget (budget minus its window and ``min_window``),
        multiplied by ``layer_num``. The pool is split across layers in
        proportion to ``pref_scores``, or uniformly if the scores sum to zero.
        Each layer ``i`` then gets its own floor of
        ``window_size[i] + min_window + 1`` added back.

        Budgets are capped at ``seq_len``. The capped surplus is repeatedly
        spread over the layers still below ``seq_len`` until no layer exceeds
        it, or until every layer is saturated.

        Args:
            layer_idx: Layer whose current budget and window define the pool size.
            seq_len: Upper bound for any single layer's budget.

        Returns:
            The new per-layer budgets as a list. They are also stored in ``self.layer_budget``.
        """
        num_layers = len(self.layer_budget)
        total_budget_initial = (self.layer_budget[layer_idx] - self.window_size[layer_idx]
                                - self.min_window) * self.layer_num

        # np.asarray makes the element-wise division below work when pref_scores is a
        # plain Python list (list / number raises TypeError); ndarrays pass through as-is.
        scores = np.asarray(self.pref_scores)
        total_score = sum(scores)
        if total_score == 0:
            normalized_scores = np.full(num_layers, 1.0 / num_layers)
        else:
            normalized_scores = scores / total_score

        budget_list = np.array(
            [int(normalized_scores[i] * total_budget_initial) + self.window_size[i] + self.min_window + 1
             for i in range(num_layers)],
            dtype=int,
        )

        # Cap at seq_len and hand the surplus to layers that still have room. A single
        # pass can push a receiving layer past seq_len again, so keep going. Each round
        # that still has excess saturates at least one more layer, so this terminates.
        while True:
            total_excess = int(np.maximum(budget_list - seq_len, 0).sum())
            budget_list = np.minimum(budget_list, seq_len)
            if total_excess == 0:
                break
            under_cap = np.flatnonzero(budget_list < seq_len)
            if under_cap.size == 0:
                break
            per_layer, remainder = divmod(total_excess, under_cap.size)
            budget_list[under_cap] += per_layer
            budget_list[under_cap[:remainder]] += 1

        self.layer_budget = budget_list.tolist()
        return self.layer_budget

    @staticmethod
    def _zero_linear_state(num_heads, head_dim, device, dtype):
        """Return all-zero ``(L_cache, k_sum, v_sum)`` tensors for one layer."""
        return (
            torch.zeros(1, num_heads, head_dim, head_dim, device=device, dtype=dtype),
            torch.zeros(1, num_heads, 1, head_dim, device=device, dtype=dtype),
            torch.zeros(1, num_heads, 1, head_dim, device=device, dtype=dtype),
        )

    def _append_zero_linear_state(self, num_heads, head_dim, device, dtype):
        """Append a fresh zero linear state (and ``lin_cached=False``) for a new layer."""
        zero_l, zero_k, zero_v = self._zero_linear_state(num_heads, head_dim, device, dtype)
        self.L_cache.append(zero_l)
        self.k_sum.append(zero_k)
        self.v_sum.append(zero_v)
        self.lin_cached.append(False)

    def update(
            self,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            layer_idx: int,
            cache_kwargs: Optional[Dict[str, Any]] = None,

    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        A layer seen for the first time (or a previously skipped placeholder)
        also gets a zero linear state. Later calls append along the sequence
        axis (dim -2).

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache. Unpacked as 4-D `(B, H, T, D)`.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. Unused here.

        Return:
            A tuple containing the updated key and value states.
        """
        _, num_heads, num_new, head_dim = key_states.shape
        device, dtype = key_states.device, key_states.dtype

        if layer_idx == 0:
            self._seen_tokens += num_new
        self.seen_tokens_layerwise[layer_idx] += num_new

        if len(self.key_cache) <= layer_idx:
            # Layers skipped so far get empty KV placeholders and a zero linear state.
            for _ in range(len(self.key_cache), layer_idx):
                self.key_cache.append([])
                self.value_cache.append([])
                self._append_zero_linear_state(num_heads, head_dim, device, dtype)
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            self._append_zero_linear_state(num_heads, head_dim, device, dtype)
        elif len(self.key_cache[layer_idx]) == 0:
            # Fills a previously skipped layer; checking the tensor itself causes errors.
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
            (self.L_cache[layer_idx],
             self.k_sum[layer_idx],
             self.v_sum[layer_idx]) = self._zero_linear_state(num_heads, head_dim, device, dtype)
            self.lin_cached[layer_idx] = False
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def _split_window(self, tensor, layer_idx):
        """Split ``tensor`` along dim -2 into ``(history, recent_window)``.

        An explicit, clamped boundary is used instead of ``[:-w]`` / ``[-w:]``.
        With ``w == 0``, ``[:-0]`` would be empty and ``[-0:]`` the whole cache.
        For ``w > 0`` the result is identical to the negative-index form.
        """
        boundary = max(tensor.shape[-2] - self.window_size[layer_idx], 0)
        return tensor[:, :, :boundary, :], tensor[:, :, boundary:, :]

    def _keep_history(self, tensor, layer_idx, keep_index):
        """Gather the history tokens at ``keep_index`` and re-append the recent window."""
        history, recent = self._split_window(tensor, layer_idx)
        return torch.cat([torch.gather(history, dim=-2, index=keep_index), recent], dim=-2)

    def update_index_prefill(
            self,
            layer_idx: int,
            curr_layer_idx: int
    ):
        """Prune layer ``layer_idx`` after prefill, keeping its highest-scoring history.

        Pruning happens only if the layer holds more than
        ``layer_budget + linear_cache_size`` tokens. In that case:

        - The ``linear_cache_size + layer_budget - window_size`` history tokens
          with the largest ``evict_scores`` are kept, in descending score order.
        - ``evict_scores[layer_idx]`` is narrowed to those same tokens.
        - The recent window is re-appended unchanged.

        Dropped tokens are discarded and are NOT folded into the linear state.

        When ``curr_layer_idx`` is the last layer, ``seen_tokens_layerwise[layer_idx]``
        is overwritten with the layer's (possibly pruned) cache length.

        Args:
            layer_idx: Layer whose cache is pruned.
            curr_layer_idx: Layer currently being processed by the caller. It is only
                compared against ``layer_num - 1``.
        """
        _, _, seq_len, head_dim = self.key_cache[layer_idx].shape
        keep_total = self.layer_budget[layer_idx] + self.linear_cache_size[layer_idx]

        if seq_len > keep_total:
            num_keep = (self.linear_cache_size[layer_idx] + self.layer_budget[layer_idx]
                        - self.window_size[layer_idx])
            topk_indices = self.evict_scores[layer_idx].topk(num_keep, dim=-1).indices
            # Keep the scores aligned with the tokens that survive.
            self.evict_scores[layer_idx] = torch.gather(self.evict_scores[layer_idx], dim=-1,
                                                        index=topk_indices)
            gather_index = topk_indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            self.key_cache[layer_idx] = self._keep_history(self.key_cache[layer_idx], layer_idx, gather_index)
            self.value_cache[layer_idx] = self._keep_history(self.value_cache[layer_idx], layer_idx,
                                                             gather_index)

        if curr_layer_idx == self.layer_num - 1:
            self.seen_tokens_layerwise[layer_idx] = self.key_cache[layer_idx].shape[-2]

    def update_index(
            self,
            layer_idx: int,
    ):
        """Evict low-scoring history tokens of ``layer_idx`` into the linear state.

        This only acts when the layer holds more than ``layer_budget`` tokens:

        - The ``layer_budget - window_size`` history tokens with the largest
          ``evict_scores`` stay in the softmax cache.
        - The ``seq_len - layer_budget`` lowest-scoring remaining tokens are
          folded into ``L_cache`` (``+= k^T v``), ``k_sum`` and ``v_sum``.
        - The recent window is kept unchanged, and ``lin_cached`` is set to True.

        The evicted set is chosen from tokens NOT already kept, so a token is never
        both kept and folded into the linear sums. Independent largest/smallest
        ``topk`` calls could pick the same token when scores tie.
        NOTE(review): ``evict_scores[layer_idx]`` is not narrowed here, unlike in
        ``update_index_prefill``; presumably the caller refreshes it.
        """
        _, _, seq_len, head_dim = self.key_cache[layer_idx].shape
        budget = self.layer_budget[layer_idx]
        if seq_len <= budget:
            return

        scores = self.evict_scores[layer_idx]
        num_keep = budget - self.window_size[layer_idx]
        num_evict = seq_len - budget

        keep = scores.topk(num_keep, dim=-1).indices
        # Mask already-kept positions with +inf so the smallest-k pick cannot reselect them.
        # Assumes floating-point scores.
        evict = scores.scatter(-1, keep, float("inf")).topk(num_evict, dim=-1, largest=False).indices

        keep_index = keep.unsqueeze(-1).expand(-1, -1, -1, head_dim)
        evict_index = evict.unsqueeze(-1).expand(-1, -1, -1, head_dim)

        history_k, _ = self._split_window(self.key_cache[layer_idx], layer_idx)
        history_v, _ = self._split_window(self.value_cache[layer_idx], layer_idx)
        red_k = torch.gather(history_k, dim=-2, index=evict_index)
        red_v = torch.gather(history_v, dim=-2, index=evict_index)

        # Out-of-place adds (not +=), matching the original autograd-friendly updates.
        self.L_cache[layer_idx] = self.L_cache[layer_idx] + torch.matmul(red_k.transpose(2, 3), red_v)
        self.k_sum[layer_idx] = self.k_sum[layer_idx] + torch.sum(red_k, dim=-2, keepdim=True)
        self.v_sum[layer_idx] = self.v_sum[layer_idx] + torch.sum(red_v, dim=-2, keepdim=True)

        self.key_cache[layer_idx] = self._keep_history(self.key_cache[layer_idx], layer_idx, keep_index)
        self.value_cache[layer_idx] = self._keep_history(self.value_cache[layer_idx], layer_idx, keep_index)

        self.lin_cached[layer_idx] = True

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
                          num_hidden_layers=None) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
        backward compatibility.

        If ``num_hidden_layers`` is omitted, it is inferred from the number of legacy
        layers. ``__init__`` needs an int layer count (``[True] * layer_num``).
        """
        if num_hidden_layers is None and past_key_values is not None:
            num_hidden_layers = len(past_key_values)
        cache = cls(num_hidden_layers)
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Tile each key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. When ``n_rep == 1``
    the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis right after the head axis, broadcast it, then merge.
        tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        hidden_states = tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states