from typing import List, Optional, Tuple, Union, Any,Dict
from transformers.cache_utils import Cache, DynamicCache
import torch.nn.functional as F
import numpy as np
import torch.nn as nn
import warnings
import torch
import time
import json
import math
import os

def _norm_probs(x: torch.Tensor) -> torch.Tensor:
    x = x.to(torch.float32)
    return x / (x.sum(dim=-1, keepdim=True) + 1e-8)

def _pairwise_jsd(P_1HT: torch.Tensor) -> torch.Tensor:
    """
    P_1HT: [1,H,T] probabilities (sum over T = 1 per head). Returns [H,H] distances.
    """
    P = P_1HT[0].clamp_min(1e-8)       # [H,T]
    logP = P.log()
    a = P.unsqueeze(1)                 # [H,1,T]
    b = P.unsqueeze(0)                 # [1,H,T]
    m = 0.5 * (a + b).clamp_min(1e-8)
    js = 0.5 * ((a * (logP.unsqueeze(1) - m.log())).sum(-1) +
                (b * (logP.unsqueeze(0) - m.log())).sum(-1))
    return js.abs()

def _pairwise_l1(P_1HT: torch.Tensor) -> torch.Tensor:
    P = P_1HT[0]
    a = P.unsqueeze(1)
    b = P.unsqueeze(0)
    return (a - b).abs().sum(dim=-1)

def _pairwise_l2(P_1HT: torch.Tensor) -> torch.Tensor:
    P = P_1HT[0]
    a = P.unsqueeze(1)
    b = P.unsqueeze(0)
    return ((a - b) ** 2).sum(dim=-1).sqrt()

def _to_redundancy(mean_dist_per_head: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    Convert 'distance' (bigger=more diverse) to 'redundancy' (bigger=more redundant).
    """
    inv = 1.0 / (mean_dist_per_head + eps)
    inv = inv.clamp_min(eps)
    return inv / inv.sum()

def _roundrobin_dedup(per_head_sorted_idx: List[torch.Tensor], per_head_cap: List[int], T_domain: int) -> List[torch.Tensor]:
    """
    Simple cross-head round-robin dedup on absolute token indices (0..T_domain-1).
    Inputs per_head_sorted_idx are 1×1×T tensors (descending score order in *prefix* domain).
    Returns 1×1×cap_h tensors (unique globally where possible).
    """
    H = len(per_head_sorted_idx)
    device = per_head_sorted_idx[0].device
    dtype  = per_head_sorted_idx[0].dtype
    idx    = [t.view(-1).long() for t in per_head_sorted_idx]   # each [T]
    taken  = torch.zeros((T_domain,), device=device, dtype=torch.bool)
    out    = [torch.empty((0,), device=device, dtype=dtype) for _ in range(H)]
    remain = [int(c) for c in per_head_cap]
    ptr    = [0 for _ in range(H)]
    active = True
    step   = 16

    def has_active():
        return any(remain[h] > 0 and ptr[h] < idx[h].numel() for h in range(H))

    while active and has_active():
        for h in range(H):
            if remain[h] <= 0 or ptr[h] >= idx[h].numel():
                continue
            span = min(step, idx[h].numel() - ptr[h])
            cand = idx[h][ptr[h]:ptr[h]+span]
            # pick first not-taken
            picked = None
            for c in cand:
                if (c >= 0) and (c < T_domain) and (not taken[int(c)]):
                    picked = c
                    break
            if picked is not None:
                out[h] = torch.cat([out[h], picked.view(1).to(dtype)], dim=0)
                taken[int(picked)] = True
                remain[h] -= 1
                ptr[h] += span  # advance window
            else:
                ptr[h] += span
        active = has_active()

    # fill leftovers locally
    for h in range(H):
        need = max(0, remain[h])
        if need > 0:
            chosen = set(out[h].tolist())
            extra = []
            for c in idx[h].tolist():
                if c not in chosen:
                    extra.append(c)
                    if len(extra) >= need:
                        break
            if extra:
                out[h] = torch.cat([out[h], torch.tensor(extra, device=device, dtype=dtype)], dim=0)
        out[h] = out[h].view(1,1,-1)
    return out


class DynamicCacheSplitHeadFlatten(Cache):
    """Per-layer key/value cache whose stored tensors are updated in flattened form.

    The first ``update`` call for a layer stores the given tensors unchanged.
    Every later call for that layer requires the stored key tensor to be 2-D
    and forwards the single new token per head to
    ``tiny_api_cuda.update_flatten_view`` together with ``head_lens`` and
    ``cu_klen`` taken from ``cache_kwargs``.
    """

    def __init__(self) -> None:
        super().__init__()
        # One entry per layer.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self._seen_tokens = 0

    def __len__(self):
        """Return the number of layers that have cached tensors."""
        return len(self.key_cache)

    def __iter__(self):
        """Yield ``(tuple(keys), tuple(values))`` for every cached layer."""
        for layer_idx in range(len(self)):
            yield (tuple(self.key_cache[layer_idx]), tuple(self.value_cache[layer_idx]))

    def __getitem__(self, layer_idx: int) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Return ``(tuple(keys), tuple(values))`` for ``layer_idx``.

        ``tuple(tensor)`` iterates the tensor along its first dimension.

        Raises:
            KeyError: if ``layer_idx`` is not smaller than the number of layers.
        """
        if layer_idx < len(self):
            return (tuple(self.key_cache[layer_idx]), tuple(self.value_cache[layer_idx]))
        raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        """Store or extend the cache of ``layer_idx`` and return its (keys, values).

        Args:
            key_states / value_states: new states. On the append path they
                must have shape ``[1, H, 1, D]``.
            layer_idx: layer to update.
            cache_kwargs: required on the append path; must provide
                ``"head_lens"`` (tensor, or a scalar that is broadcast to all
                heads) and ``"cu_klen"`` (tensor).

        Raises:
            TypeError: if ``cache_kwargs`` is missing on the append path or
                ``cu_klen`` is not a tensor.
        """
        if len(self.key_cache) <= layer_idx:
            # First sighting of this layer: keep the tensors as given.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            return self.key_cache[layer_idx], self.value_cache[layer_idx]

        assert self.key_cache[layer_idx].dim() == 2
        bs, head, seqlen, dim = key_states.shape
        assert bs == 1 and seqlen == 1

        if cache_kwargs is None:
            # Previously this surfaced as "'NoneType' object is not subscriptable".
            raise TypeError("cache_kwargs with 'head_lens' and 'cu_klen' is required to append to an existing layer")
        head_lens = cache_kwargs["head_lens"]
        cu_klen = cache_kwargs["cu_klen"]

        if torch.is_tensor(head_lens):
            head_lens = head_lens.to(dtype=torch.int32, device=key_states.device)
        else:
            # Scalar length: every head gets the same value.
            head_lens = torch.full((head,), int(head_lens), dtype=torch.int32, device=key_states.device)

        if not torch.is_tensor(cu_klen):
            raise TypeError("cu_klen must be a tensor")
        cu_klen = cu_klen.to(dtype=torch.int32, device=key_states.device)

        import nvtx
        from tiny_api_cuda import update_flatten_view

        copy_old_rng = nvtx.start_range("copy old")
        try:
            # NOTE(review): update_flatten_view is an external CUDA helper; it
            # presumably returns the flattened cache with the new per-head rows
            # inserted at the positions described by head_lens / cu_klen — confirm.
            new_key_cache = update_flatten_view(self.key_cache[layer_idx].view(-1, dim), key_states.view(-1, dim), head_lens, cu_klen)
            new_value_cache = update_flatten_view(self.value_cache[layer_idx].view(-1, dim), value_states.view(-1, dim), head_lens, cu_klen)
        finally:
            # Close the profiling range even if the kernel call fails.
            nvtx.end_range(copy_old_rng)

        self.key_cache[layer_idx] = new_key_cache
        self.value_cache[layer_idx] = new_value_cache
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return 0 for an uncached layer, else ``max(1, shape[-2])`` of its key tensor."""
        if len(self.key_cache) <= layer_idx:
            return 0

        return max(1, self.key_cache[layer_idx].shape[-2]) if hasattr(self.key_cache[layer_idx], "shape") else 1

    def get_max_length(self) -> Optional[int]:
        """This cache has no maximum length; always returns ``None``."""
        return None

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Convert this cache into a tuple of per-layer ``(key, value)`` pairs."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
        return legacy_cache

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCacheSplitHeadFlatten":
        """Build a cache from a tuple of per-layer ``(key, value)`` pairs.

        Each pair is stored through ``update`` on a fresh cache, so every
        layer takes the "first sighting" path and is kept as given.
        """
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head dimension.

    Equivalent to ``torch.repeat_interleave(hidden_states, n_rep, dim=1)``:
    (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With
    ``n_rep == 1`` the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it in.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)

class SnapKVCluster:
    """Prompt KV compressor that keeps the last ``window_size`` tokens plus the
    prefix tokens that receive the most attention from that window.

    With ``pyram_mode`` enabled, the first ``update_kv`` call replaces
    ``max_capacity_prompt`` with a layer-dependent value computed from
    ``layer_idx``, ``num_hidden_layers`` and ``pyram_beta``.
    """

    def __init__(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool', layer_idx = None, num_hidden_layers = None, pyram_mode = False, pyram_beta = 20):
        self.window_size = window_size
        self.max_capacity_prompt = max_capacity_prompt
        # The budget must leave room for at least one prefix token.
        assert self.max_capacity_prompt > self.window_size
        self.kernel_size = kernel_size
        self.pooling = pooling

        # Pyramid-mode bookkeeping.
        self.pyram_init = False
        self.pyram_mode = pyram_mode
        self.pyram_beta = pyram_beta
        self.layer_idx = layer_idx
        self.num_hidden_layers = num_hidden_layers

    def reset(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):
        """Overwrite window, capacity, kernel size and pooling; pyramid state is untouched."""
        self.window_size = window_size
        self.max_capacity_prompt = max_capacity_prompt
        assert self.max_capacity_prompt > self.window_size
        self.kernel_size = kernel_size
        self.pooling = pooling

    def _init_pyramid_capacity(self, q_len):
        """One-time computation of this layer's capacity in pyramid mode."""
        window = self.window_size
        # (largest + smallest) / 2 == base_capacity keeps the total budget fixed.
        base_capacity = self.max_capacity_prompt - window
        smallest = base_capacity // self.pyram_beta
        largest = base_capacity * 2 - smallest

        # The largest budget cannot exceed the available prefix length.
        if largest >= q_len - window:
            largest = q_len - window
            smallest = base_capacity * 2 - largest

        # Per-layer decrement between the largest and smallest budget.
        steps = (largest - smallest) // (self.num_hidden_layers - 1)

        self.max_capacity_prompt = largest - self.layer_idx * steps + window
        self.pyram_init = True
        print(f"Pyram mode adaptive capacity, layer: {self.layer_idx}, max_capacity_prompt: {self.max_capacity_prompt}, base_capacity: {self.max_capacity_prompt - self.window_size}")

    def update_kv(self, key_states, query_states, value_states):
        """Return compressed ``(key_states, value_states)`` for a prefill step.

        Inputs shorter than ``max_capacity_prompt`` are returned unchanged.
        Otherwise the result holds, per head, the top
        ``max_capacity_prompt - window_size`` prefix positions (ranked by
        pooled attention from the last ``window_size`` queries) followed by
        the last ``window_size`` positions.

        Raises:
            ValueError: if ``pooling`` is neither 'avgpool' nor 'maxpool'.
        """
        # Prefill only: keys and queries must cover the same positions.
        assert key_states.shape[-2] == query_states.shape[-2]
        bsz, num_heads, q_len, head_dim = query_states.shape

        if self.pyram_mode and not self.pyram_init:
            self._init_pyramid_capacity(q_len)

        if q_len < self.max_capacity_prompt:
            return key_states, value_states

        window = self.window_size
        recent_queries = query_states[..., -window:, :]
        scores = torch.matmul(recent_queries, key_states.transpose(2, 3)) / math.sqrt(head_dim)

        # Causal mask inside the window: strictly-upper triangle gets the
        # most negative representable value, everything else gets 0.
        lowest = torch.finfo(scores.dtype).min
        causal = torch.triu(torch.full((window, window), lowest, device=scores.device), diagonal=1)
        scores[:, :, -window:, -window:] += causal[None, None, :, :]

        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(query_states.dtype)
        # Total attention each prefix position receives from the window.
        votes = probs[:, :, -window:, :-window].sum(dim=-2)

        if self.pooling not in ('avgpool', 'maxpool'):
            raise ValueError('Pooling method not supported')
        pool = F.avg_pool1d if self.pooling == 'avgpool' else F.max_pool1d
        pooled = pool(votes, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)

        keep = pooled.topk(self.max_capacity_prompt - window, dim=-1).indices
        gather_idx = keep.unsqueeze(-1).expand(-1, -1, -1, head_dim)

        kept_keys = key_states[:, :, :-window, :].gather(dim = 2, index = gather_idx)
        kept_values = value_states[:, :, :-window, :].gather(dim = 2, index = gather_idx)
        recent_keys = key_states[:, :, -window:, :]
        recent_values = value_states[:, :, -window:, :]

        return (
            torch.cat([kept_keys, recent_keys], dim = 2),
            torch.cat([kept_values, recent_values], dim = 2),
        )

# --------------------------
# Utilities for RABA / Obs
# --------------------------
def _norm_probs(x: torch.Tensor) -> torch.Tensor:
    x = x.to(torch.float32)
    return x / (x.sum(dim=-1, keepdim=True) + 1e-8)

def _pairwise_jsd(P_1HT: torch.Tensor) -> torch.Tensor:
    """
    P_1HT: [1,H,T] probabilities (sum over T = 1 per head). Returns [H,H] distances.
    """
    P = P_1HT[0].clamp_min(1e-8)       # [H,T]
    logP = P.log()
    a = P.unsqueeze(1)                 # [H,1,T]
    b = P.unsqueeze(0)                 # [1,H,T]
    m = 0.5 * (a + b).clamp_min(1e-8)
    js = 0.5 * ((a * (logP.unsqueeze(1) - m.log())).sum(-1) +
                (b * (logP.unsqueeze(0) - m.log())).sum(-1))
    return js.abs()

def _pairwise_l1(P_1HT: torch.Tensor) -> torch.Tensor:
    P = P_1HT[0]
    a = P.unsqueeze(1)
    b = P.unsqueeze(0)
    return (a - b).abs().sum(dim=-1)

def _pairwise_l2(P_1HT: torch.Tensor) -> torch.Tensor:
    P = P_1HT[0]
    a = P.unsqueeze(1)
    b = P.unsqueeze(0)
    return ((a - b) ** 2).sum(dim=-1).sqrt()

def _to_redundancy(mean_dist_per_head: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    inv = 1.0 / (mean_dist_per_head + eps)
    inv = inv.clamp_min(eps)
    return inv / inv.sum()

def _roundrobin_dedup(per_head_sorted_idx: List[torch.Tensor], per_head_cap: List[int], T_domain: int) -> List[torch.Tensor]:
    H = len(per_head_sorted_idx)
    device = per_head_sorted_idx[0].device
    dtype  = per_head_sorted_idx[0].dtype
    idx    = [t.view(-1).long() for t in per_head_sorted_idx]   # each [T]
    taken  = torch.zeros((T_domain,), device=device, dtype=torch.bool)
    out    = [torch.empty((0,), device=device, dtype=dtype) for _ in range(H)]
    remain = [int(c) for c in per_head_cap]
    ptr    = [0 for _ in range(H)]
    active = True
    step   = 16

    def has_active():
        return any(remain[h] > 0 and ptr[h] < idx[h].numel() for h in range(H))

    while active and has_active():
        for h in range(H):
            if remain[h] <= 0 or ptr[h] >= idx[h].numel():
                continue
            span = min(step, idx[h].numel() - ptr[h])
            cand = idx[h][ptr[h]:ptr[h]+span]
            # pick first not-taken
            picked = None
            for c in cand:
                if (c >= 0) and (c < T_domain) and (not taken[int(c)]):
                    picked = c
                    break
            if picked is not None:
                out[h] = torch.cat([out[h], picked.view(1).to(dtype)], dim=0)
                taken[int(picked)] = True
                remain[h] -= 1
                ptr[h] += span  # advance window
            else:
                ptr[h] += span
        active = has_active()

    # fill leftovers locally
    for h in range(H):
        need = max(0, remain[h])
        if need > 0:
            chosen = set(out[h].tolist())
            extra = []
            for c in idx[h].tolist():
                if c not in chosen:
                    extra.append(c)
                    if len(extra) >= need:
                        break
            if extra:
                out[h] = torch.cat([out[h], torch.tensor(extra, device=device, dtype=dtype)], dim=0)
        out[h] = out[h].view(1,1,-1)
    return out


class AdaptiveSnapKVCluster:

    def __init__(
        self,
        window_size: int = 8,
        kernel_size: int = 7,
        pooling: str = "maxpool",
        base_capacity: int | None = None,   # top-k part only
        floor: float | None = None,
        skip: int | None = None,
        normalize: bool | None = None,
        layer_idx: int | None = None,
        num_hidden_layers: int | None = None,
        # QD
        query_scaling: str | None = None,
        query_scaling_lambda: float | None = None,
        # RABA
        redundancy_mode: str | None = None,
        # Extras
        cross_head_dedup: bool = True,
    ):
        assert base_capacity is not None, "base_capacity (top-k part) must be set"
        self.window_size = int(window_size)
        self.kernel_size = int(kernel_size)
        self.pooling     = pooling

        self.budget_capacity_total = int(base_capacity) + self.window_size  # display only
        self.base_capacity   = int(base_capacity)                            # ★ top-k part
        self.floor_ratio     = 0.0 if (floor is None) else float(floor)
        self.floor_capacity  = int(self.base_capacity * self.floor_ratio)
        self.adaptive_capacity = self.base_capacity - self.floor_capacity

        self.skip_layer_nums = int(skip or 1000)
        self.normalize = bool(normalize)
        self.layer_idx = int(layer_idx or 0)
        self.num_hidden_layers = int(num_hidden_layers or 0)

        # varlen flash-attn meta
        self.head_lens = None
        self.max_seqlen_k = 0
        self.klen_sum = 0
        self.cu_klen = 0
        self.cu_offset = None
        self.cu_headlens = None

        # QD / RABA
        self.query_scaling_mode   = query_scaling
        self.query_scaling_lambda = 0.0 if query_scaling_lambda is None else float(query_scaling_lambda)
        self.redundancy_mode      = redundancy_mode.lower() if redundancy_mode else None

        # OBS
        self._last_attn_weights_softmax = None  # [1,H,W,L] (we'll pad to total L)
        self.obs_ana = None
        self.method_label = None

        self.cross_head_dedup = bool(cross_head_dedup)

    # ------------------------------ scoring over prefix (no sink) ------------------------------
    def calcul_attn_sore(self, key_states, query_states):
        """
        observation_window only.
        - Queries: last W
        - Keys   : prefix [0 : L-W]  (exclude the same window from candidates)
        Returns pooled scores S: [1,H,Kp], Kp=L-W (same-length padding pooling)
        Also caches per-query softmax over FULL length [1,H,W,L] for Obs2'.
        """
        bsz, H, L, D = query_states.shape
        W = min(self.window_size, L)
        Kp = max(0, L - W)

        # slice
        Q = query_states[..., L - W :, :] if W > 0 else query_states[..., :0, :]
        K = key_states[..., : Kp, :]

        if W == 0 or Kp == 0:
            # cache a zero tensor of shape [1,H,W,L] for Obs2'
            if W > 0:
                self._last_attn_weights_softmax = torch.zeros((bsz, H, W, L), dtype=query_states.dtype, device=query_states.device)
            else:
                self._last_attn_weights_softmax = None
            return torch.zeros((bsz, H, 0), dtype=query_states.dtype, device=query_states.device)

        # QD (optional)
        if (self.query_scaling_mode == "orthogonal") and (self.query_scaling_lambda > 0.0):
            u = Q.mean(dim=2, keepdim=True)                               # [B,H,1,D]
            u_hat = u / (u.norm(p=2, dim=-1, keepdim=True) + 1e-6)        # [B,H,1,D]
            alpha = (Q * u_hat).sum(dim=-1, keepdim=True)                 # [B,H,W,1]
            Q = Q + float(self.query_scaling_lambda) * (Q - alpha * u_hat)

        attn_raw  = torch.matmul(Q, K.transpose(2, 3)) / math.sqrt(D)     # [B,H,W,Kp]
        attn_soft = torch.softmax(attn_raw.to(torch.float32), dim=-1).to(Q.dtype)  # [B,H,W,Kp]


        pad_w = W
        attn_full = F.pad(attn_soft, (0, pad_w, 0, 0, 0, 0, 0, 0), value=0.0)  # -> [B,H,W,Kp+W=L]
        self._last_attn_weights_softmax = attn_full  # [B,H,W,L]

        # mean over queries in window -> [B,H,Kp]
        S = attn_soft.mean(dim=-2)

        # 1D pooling (same length)
        if self.pooling == "avgpool":
            S = F.avg_pool1d(S, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)
        elif self.pooling == "maxpool":
            S = F.max_pool1d(S, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)
        else:
            raise ValueError("Unsupported pooling")
        return S  # [B,H,Kp]

    # ------------------------------ update_kv (no sink, observation_window only) ------------------------------
    def update_kv(self, key_states, query_states, value_states):
        """
        Returns:
          heads_key_states:   [sum_h Lh, D]
          heads_value_states: [sum_h Lh, D]
          head_budget_topk:   torch.IntTensor [H]    (per-head top-k counts)
          sorted_indices_per_head: List[Tensor(H*[Kp])]
          topk_indices_per_head:   List[Tensor(H*[k_h])]
          selected_abs_indices_per_head: List[Tensor(H*[Lh])]  (absolute indices used for cache)
        Side effects:
          - sets varlen meta (head_lens, cu_klen, ...)
          - logs Obs1/2/2’/3 if analyzer is present
        """
        device = key_states.device
        bsz, H, L, D = key_states.shape
        assert bsz == 1, "batch must be 1 here"
        W  = min(self.window_size, L)
        Kp = max(0, L - W)

        # 1) scores on prefix
        attn_score = self.calcul_attn_sore(key_states, query_states)  # [1,H,Kp]

        # trivial case: no prefix
        if Kp == 0:
            head_budget_topk = torch.zeros(H, dtype=torch.int32, device=device)
            per_head_idx_topk = [torch.empty(0, dtype=torch.long, device=device) for _ in range(H)]
            sorted_indices_per_head = [torch.empty(0, dtype=torch.long, device=device) for _ in range(H)]
        else:
            # sort on prefix domain [0..Kp-1]
            sorted_score, sorted_idx = attn_score.sort(dim=-1, descending=True)      # [1,H,Kp]
            sorted_indices_per_head = [sorted_idx[:, h, :].squeeze(0) for h in range(H)]  # H * [Kp]

            # head budgets for top-k (no window inside)
            bH = self.base_capacity * H
            if (self.redundancy_mode is not None) and (self.layer_idx >= self.skip_layer_nums):
                # global competition counts
                flat = attn_score.reshape(1, H * Kp)
                top_heads = torch.topk(flat, k=min(bH, flat.size(-1)), dim=-1, largest=True, sorted=False).indices // Kp
                C = torch.zeros((1, H), device=device, dtype=torch.float32)
                C.scatter_add_(1, top_heads, torch.ones_like(top_heads, dtype=torch.float32))

                # redundancy weights
                probs = _norm_probs(attn_score)                              # [1,H,Kp]
                w_red = self._redundancy_weights(probs)                      # [H]

                I = C * (1.0 - w_red.unsqueeze(0)).clamp_min(0.0)            # [1,H]
                I_sum = I.sum(dim=1, keepdim=True) + 1e-8
                I_tilde = I / I_sum

                per_head_floor = int(self.floor_capacity)
                Ktot_floor = H * per_head_floor
                Ktot_rem   = max(0, bH - Ktot_floor)

                cap_float = I_tilde * float(Ktot_rem)
                cap_rem   = torch.floor(cap_float).to(torch.int32)
                diff = int(Ktot_rem) - int(cap_rem.sum().item())
                if diff > 0:
                    frac = (cap_float - cap_rem.to(cap_float.dtype))[0]
                    add_order = torch.argsort(frac, descending=True)
                    for i in range(diff):
                        cap_rem[0, int(add_order[i % H])] += 1

                head_adaptive_capacity = cap_rem + per_head_floor            # [1,H]
            else:
                head_adaptive_capacity = torch.full((1, H), int(self.base_capacity), device=device, dtype=torch.int32)

            # cap by available prefix length
            head_adaptive_capacity = torch.clamp(head_adaptive_capacity, max=Kp)
            head_budget_topk = head_adaptive_capacity[0].to(torch.int32)     # [H]

            # build per-head indices
            if self.cross_head_dedup:
                blocks = [sorted_indices_per_head[h].view(1, 1, -1) for h in range(H)]
                caps   = [int(c) for c in head_budget_topk.tolist()]
                idx3d  = _roundrobin_dedup(blocks, caps, T_domain=Kp)        # H * [1,1,k_h]
                per_head_idx_topk = [t.view(-1) for t in idx3d]
            else:
                per_head_idx_topk = [sorted_indices_per_head[h][: int(head_budget_topk[h].item())] for h in range(H)]

        # 2) gather K/V: top-k (prefix absolute) + window
        heads_key_states, heads_value_states = [], []
        selected_abs_indices_per_head = []
        k_lens, klen_sum, max_seqlen_k = [], 0, 0

        origin_heads_key_states   = torch.split(key_states, 1, dim=1)   # H * [1,1,L,D]
        origin_heads_value_states = torch.split(value_states, 1, dim=1)

        win_start = max(0, L - W)
        win_idx_abs = torch.arange(win_start, L, device=device, dtype=torch.long)  # [W]

        for h in range(H):
            # absolute prefix indices are identical to prefix domain here
            idx_topk_abs = per_head_idx_topk[h]                                  # [k_h] & each < Kp
            hd = D

            head_K = origin_heads_key_states[h]
            head_V = origin_heads_value_states[h]

            if idx_topk_abs.numel() > 0:
                gidx = idx_topk_abs.view(1, 1, -1, 1).expand(1, 1, -1, hd)      # [1,1,k_h,D]
                topK = head_K.gather(2, gidx)
                topV = head_V.gather(2, gidx)
            else:
                topK = head_K[:, :, :0, :]
                topV = head_V[:, :, :0, :]

            win_K = head_K[:, :, win_start:, :]                                  # [1,1,W,D]
            win_V = head_V[:, :, win_start:, :]

            selected_k = torch.cat([topK, win_K], dim=2)                         # [1,1,k_h+W,D]
            selected_v = torch.cat([topV, win_V], dim=2)
            selected_abs = torch.cat([idx_topk_abs, win_idx_abs], dim=0)         # [k_h+W]

            Lh = int(selected_k.shape[2])
            k_lens.append(Lh); klen_sum += Lh; max_seqlen_k = max(max_seqlen_k, Lh)

            heads_key_states.append(selected_k.view(-1, D))
            heads_value_states.append(selected_v.view(-1, D))
            selected_abs_indices_per_head.append(selected_abs)

        # 3) varlen meta
        self.head_lens = torch.tensor(k_lens, dtype=torch.int32, device=device)  # [H]
        self.klen_sum = int(klen_sum)
        self.max_seqlen_k = int(max_seqlen_k)
        self.cu_headlens = torch.cumsum(self.head_lens, dim=0, dtype=torch.int32)
        self.cu_klen = self.cu_headlens - self.head_lens
        self.cu_klen = torch.cat([self.cu_klen, torch.tensor([self.klen_sum], dtype=torch.int32, device=device)], dim=0)
        self.layer_qlens = torch.ones(H, dtype=torch.int32, device=device)
        self.qlen_sum = H
        self.cu_qlen = torch.cumsum(self.layer_qlens, dim=0, dtype=torch.int32) - self.layer_qlens
        self.cu_qlen = torch.cat([self.cu_qlen, torch.tensor([self.qlen_sum], dtype=torch.int32, device=device)], dim=0)
        self.cu_offset = torch.arange(0, H + 1, dtype=torch.int32, device=device)
        self.cu_head_offset = torch.arange(1, H + 1, dtype=torch.int32, device=device)

        heads_key_states   = torch.cat(heads_key_states, dim=0)     # [sum_h Lh, D]
        heads_value_states = torch.cat(heads_value_states, dim=0)

        # 4) Obs hooks
        head_budget_topk_out = torch.tensor(
            [int(x.numel()) for x in per_head_idx_topk],
            dtype=torch.int32, device=device
        ) if Kp > 0 else torch.zeros(H, dtype=torch.int32, device=device)

        ana = getattr(self, "obs_ana", None)
        if ana is not None:
            head_budget_eff = self.head_lens  # topk + window
            method_label = int(getattr(self, "method_label", 1))

            max_ph = int(getattr(self.config, "obs1_max_per_head", 200))
            ana.log_obs1(
                layer_idx=int(getattr(self, "layer_idx", 0)),
                key_states_compress=heads_key_states,
                value_states_compress=heads_value_states,
                head_budget_eff=head_budget_eff,
                method_label=method_label,
                max_per_head=max_ph,   # ← config에서 읽어옴
            )

            ana.update_obs2(
                layer_idx=int(getattr(self, "layer_idx", 0)),
                head_budget_eff=head_budget_eff,
                sorted_indices_per_head=sorted_indices_per_head,
                query_states=query_states[:, :, -1, :].squeeze(0),  # [H,d]
            )

            if self._last_attn_weights_softmax is not None:
                ana.update_obs2_querywise(
                    layer_idx=int(getattr(self, "layer_idx", 0)),
                    attn_weights_softmax=self._last_attn_weights_softmax,  # [1,H,W,L]
                    topk_per_head=head_budget_topk_out,                    # top-k only
                    exclude_recent_window=True,                           
                )

            K_full_heads = torch.stack([origin_heads_key_states[h].squeeze(0).squeeze(0) for h in range(H)], dim=0)
            V_full_heads = torch.stack([origin_heads_value_states[h].squeeze(0).squeeze(0) for h in range(H)], dim=0)
            q_heads      = query_states[:, :, -1, :].squeeze(0)  # [H,d]
            selected_abs_np = [t.detach().cpu().numpy() for t in selected_abs_indices_per_head]
            ana.update_obs3_aae(q_heads, K_full_heads, V_full_heads, selected_abs_np)
            ana.update_obs3_extra(q_heads, K_full_heads, selected_abs_np)

        return (
            heads_key_states,
            heads_value_states,
            head_budget_topk_out,             # per-head top-k
            sorted_indices_per_head,          # full order on prefix
            per_head_idx_topk,                # actual top-k on prefix
            selected_abs_indices_per_head,    # absolute (top-k + window)
        )

    # ---- RABA weights (unchanged) ----
    def _redundancy_weights(self, attn_probs_1hT: torch.Tensor) -> torch.Tensor:
        if (self.redundancy_mode is None) or (self.redundancy_mode == "none"):
            H = attn_probs_1hT.shape[1]
            return torch.ones(H, device=attn_probs_1hT.device) / max(1, H)
        mode = self.redundancy_mode
        if mode == "jsd":
            M = _pairwise_jsd(attn_probs_1hT)
        elif mode == "l1":
            M = _pairwise_l1(attn_probs_1hT)
        elif mode == "l2":
            M = _pairwise_l2(attn_probs_1hT)
        else:
            H = attn_probs_1hT.shape[1]
            return torch.ones(H, device=attn_probs_1hT.device) / max(1, H)
        mean_dist = M.mean(dim=1)
        return _to_redundancy(mean_dist)


class ReasonSnapKVCluster():
    """
    SnapKV-style prompt KV compression with an individual token budget per head.

    At construction a JSON file of per-head importance scores is loaded,
    normalised, raised to the power ``temp`` and converted into
    ``head_capacity``: an integer tensor of shape
    (num_hidden_layers, num_attention_heads) giving how many prefix tokens each
    head may keep. ``update_kv`` keeps, per head, the best-scored prefix tokens
    plus the last ``window_size`` tokens, returns them flattened to
    (sum of kept lengths, head_dim), and records the per-head lengths and
    cumulative offsets on the instance.
    """

    # Score files under <root>/Important_Head/head_score/, keyed by head_choice and
    # then by the model-family substring searched in the lower-cased model name.
    # Insertion order matters: 'llama' is tested before 'mistral'.
    _HEAD_SCORE_FILES = {
        'copy': {
            'llama': 'Meta-Llama-3-8B-Instruct_retrieval_heads.json',
            'mistral': 'Mistral-7B-Instruct-v0.2_retrieval_heads.json',
        },
        'reason': {
            'llama': 'Meta-Llama-3-8B-Instruct_retrieval_reasoning_heads.json',
            'mistral': 'Mistral-7B-Instruct-v0.2_retrieval_reasoning_heads.json',
        },
    }

    def __init__(self, window_size = 32, kernel_size = 7, pooling = 'maxpool',base_capacity=None, head_choice=None, beta=None, temp=None, layer_idx = None, num_hidden_layers = None, num_attention_heads=None, model=None):
        """
        Load the head scores and derive the per-head prefix budgets.

        window_size: number of trailing tokens that are always kept and whose
            queries are used to score the prefix.
        kernel_size / pooling: 1-D pooling ('avgpool' or 'maxpool') applied to the
            prefix scores.
        base_capacity: per-head budget including the window; ``window_size`` is
            subtracted before it is stored.
        head_choice: 'copy' or 'reason', selecting which score file is read.
        beta: ``base_capacity // beta`` tokens per head go into the pool that is
            redistributed according to the head scores; the rest is a fixed
            per-head minimum.
        temp: exponent applied to the normalised head scores.
        layer_idx: row of ``head_capacity`` used by ``update_kv``.
        num_hidden_layers / num_attention_heads: shape of the score table.
        model: model name; matched case-insensitively against 'llama'/'mistral'.

        Raises:
            ValueError: unsupported ``head_choice`` (including 'random' and the
                default ``None``), missing ``model``, or a model name that
                matches no known family.
        """
        self.window_size = window_size
        self.kernel_size = kernel_size
        self.pooling = pooling
        self.base_capacity = base_capacity - window_size
        self.beta = beta
        self.temp = temp

        self.layer_idx = layer_idx
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        self.head_lens = None
        self.max_seqlen_k = 0
        self.klen_sum = 0
        self.cu_klen = 0
        self.cu_offset = None
        self.cu_headlens = None

        root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        path = self._head_score_path(root_path, head_choice, model)
        # Only the first line of the file is parsed as JSON.
        with open(path, 'r', encoding='utf-8') as file:
            head_list = json.loads(file.readline())

        # One scalar per head: the mean of that head's recorded scores.
        head_score_list = [np.mean(scores) for scores in head_list.values()]
        head_score_list = torch.tensor(np.asarray(head_score_list) / sum(head_score_list))
        head_score_list = torch.pow(head_score_list, self.temp)
        head_score_list = head_score_list / torch.sum(head_score_list)
        # NOTE(review): the reshape presumes the JSON entries are ordered
        # layer-major (all heads of layer 0, then layer 1, ...) -- confirm.
        self.total_attention = head_score_list.reshape(self.num_hidden_layers, self.num_attention_heads)

        pooled_per_head = self.base_capacity // self.beta
        total_pool_capacity = pooled_per_head * self.num_hidden_layers * self.num_attention_heads
        min_num = self.base_capacity - pooled_per_head
        self.head_capacity = torch.round(self.total_attention * total_pool_capacity + min_num).int()

    @classmethod
    def _head_score_path(cls, root_path, head_choice, model):
        """Return the score-file path for ``head_choice``/``model`` or raise ValueError."""
        families = cls._HEAD_SCORE_FILES.get(head_choice)
        if families is None:
            raise ValueError(
                f"Unsupported head_choice {head_choice!r}; expected one of {sorted(cls._HEAD_SCORE_FILES)}")
        if model is None:
            raise ValueError("model name is required to locate the head score file")
        model_name = model.lower()
        for family, filename in families.items():
            if family in model_name:
                return f'{root_path}/Important_Head/head_score/{filename}'
        raise ValueError(
            f"No head score file for model {model!r}; expected a name containing one of {sorted(families)}")

    def calcul_attn_sore(self, key_states, query_states):
        """
        Score every prefix position by the attention it receives from the window.

        The last ``window_size`` queries attend over all keys (causally masked
        inside the window), the softmax rows are averaged over the window
        queries, the window columns are dropped, and the result is pooled along
        the sequence axis with stride 1.

        key_states / query_states: 4-D tensors; ``query_states`` is unpacked as
            (bsz, num_heads, q_len, head_dim).
        Returns: pooled scores over the prefix, one row per (batch, head).
        Raises: ValueError if ``self.pooling`` is neither 'avgpool' nor 'maxpool'.
        """
        bsz, num_heads, q_len, head_dim = query_states.shape
        window = self.window_size
        attn_weights = torch.matmul(query_states[..., -window:, :], key_states.transpose(2, 3)) / math.sqrt(
            head_dim)

        # Causal mask for the window-vs-window corner: 0 on and below the
        # diagonal, the dtype's most negative value above it.
        mask = torch.full((window, window), torch.finfo(attn_weights.dtype).min,
                          device=attn_weights.device)
        mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
        attn_weights[:, :, -window:, -window:] += mask[None, None, :, :]

        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights_mean = attn_weights[:, :, -window:, : -window].mean(dim=-2)
        if self.pooling == 'avgpool':
            pool_fn = F.avg_pool1d
        elif self.pooling == 'maxpool':
            pool_fn = F.max_pool1d
        else:
            raise ValueError('Pooling method not supported')
        return pool_fn(attn_weights_mean, kernel_size=self.kernel_size,
                       padding=self.kernel_size // 2,
                       stride=1)

    def _init_metadata(self, num_heads, k_lens, klen_sum, max_seqlen_k, device):
        """Store per-head kept lengths and their cumulative offsets (int32, on ``device``)."""
        self.head_lens = torch.tensor(k_lens, dtype=torch.int32, device=device)
        self.klen_sum = klen_sum
        self.max_seqlen_k = max_seqlen_k
        self.cu_headlens = torch.cumsum(self.head_lens, dim=0, dtype=torch.int32)
        # varlen flash attention metadata: exclusive prefix sums followed by the
        # grand total, i.e. [0, l0, l0+l1, ..., klen_sum].
        self.cu_klen = torch.cat(
            [self.cu_headlens - self.head_lens,
             torch.tensor([self.klen_sum], dtype=torch.int32, device=device)], dim=0)
        # Query side: one position per head.
        self.layer_qlens = torch.ones(num_heads, dtype=torch.int32, device=device)
        self.qlen_sum = num_heads
        self.cu_qlen = torch.cat(
            [torch.cumsum(self.layer_qlens, dim=0, dtype=torch.int32) - self.layer_qlens,
             torch.tensor([self.qlen_sum], dtype=torch.int32, device=device)], dim=0)
        self.cu_offset = torch.arange(0, num_heads + 1, dtype=torch.int32, device=device)
        self.cu_head_offset = torch.arange(1, num_heads + 1, dtype=torch.int32, device=device)

    def update_kv(self,  key_states, query_states, value_states):
        """
        Compress the prompt KV cache head by head.

        For each head the top ``head_capacity[layer_idx][head]`` prefix positions
        (by ``calcul_attn_sore``) are gathered and the last ``window_size``
        positions are appended. Heads are concatenated along dim 0.

        The cache is returned uncompressed (merely flattened) when the prompt is
        not longer than the window, or when ``base_capacity`` exceeds the prefix
        length. In that case every head length is recorded as ``q_len``.
        NOTE(review): that presumes keys and queries cover the same positions
        (prefill) -- confirm against the caller.

        Returns: ``(keys, values)``, each of shape (sum of kept lengths, head_dim).
        Side effects: sets ``head_lens``, ``klen_sum``, ``max_seqlen_k``,
            ``cu_headlens``, ``cu_klen``, ``layer_qlens``, ``qlen_sum``,
            ``cu_qlen``, ``cu_offset`` and ``cu_head_offset``.
        Raises: AssertionError if compression is needed and the batch size is not 1.
        """
        _device = key_states.device
        bsz, num_heads, q_len, head_dim = query_states.shape

        # A prompt that fits inside the window has no prefix to score; the
        # window mask could not even be applied to it, so skip scoring entirely.
        if q_len <= self.window_size:
            self._init_metadata(num_heads, [q_len] * num_heads, q_len * num_heads, q_len, _device)
            return key_states.reshape(-1, head_dim), value_states.reshape(-1, head_dim)

        attn_score = self.calcul_attn_sore(key_states, query_states)
        if self.base_capacity > attn_score.size(-1):
            # not compress
            self._init_metadata(num_heads, [q_len] * num_heads, q_len * num_heads, q_len, _device)
            return key_states.reshape(-1, head_dim), value_states.reshape(-1, head_dim)

        origin_heads_key_states = torch.split(key_states, 1, dim=1)
        origin_heads_value_states = torch.split(value_states, 1, dim=1)

        # Prefix positions ordered from most to least attended, one chunk per head.
        _, indices = attn_score.sort(dim=-1, descending=True)
        indices = indices.split(1, dim=1)

        assert bsz == 1

        layer_capacity = self.head_capacity[self.layer_idx]
        heads_key_states = []
        heads_value_states = []
        k_lens = []
        for head_idx in range(num_heads):
            cache_index = indices[head_idx][..., :int(layer_capacity[head_idx])]
            k_lens.append(cache_index.shape[-1] + self.window_size)

            cache_index = cache_index.view(1, 1, -1, 1).expand(-1, -1, -1, head_dim)
            head_keys = origin_heads_key_states[head_idx]
            head_values = origin_heads_value_states[head_idx]
            selected_k = torch.cat(
                [head_keys.gather(dim=2, index=cache_index), head_keys[:, :, -self.window_size:, :]], dim=2)
            selected_v = torch.cat(
                [head_values.gather(dim=2, index=cache_index), head_values[:, :, -self.window_size:, :]], dim=2)

            # NOTE: flatten view
            heads_key_states.append(selected_k.view(-1, head_dim))
            heads_value_states.append(selected_v.view(-1, head_dim))

        self._init_metadata(num_heads, k_lens, sum(k_lens), max(k_lens, default=0), _device)

        # NOTE: compose as flatten view
        heads_key_states = torch.cat(heads_key_states, dim=0)
        heads_value_states = torch.cat(heads_value_states, dim=0)

        return heads_key_states,heads_value_states


def init_pyramidkv(self):
    """
    Attach a ``SnapKVCluster`` to this module on the first call (pyramid variant).

    Requires ``self.config`` to define window_size, kernel_size, pooling and
    base_capacity (checked with asserts). ``config.pyram_beta`` is set to 20
    when absent, even if a cluster already exists. Subsequent calls leave the
    existing ``self.kv_cluster`` untouched.
    """
    cfg = self.config
    for required in ('window_size', 'kernel_size', "pooling", "base_capacity"):
        assert hasattr(cfg, required), f"{required} not set"

    if not hasattr(cfg, "pyram_beta"):
        cfg.pyram_beta = 20

    # init only once
    if hasattr(self, "kv_cluster"):
        return

    # NOTE(review): the beta value is forwarded under the keyword ``pyram_mode``;
    # confirm this matches SnapKVCluster's signature.
    self.kv_cluster = SnapKVCluster(
        window_size=cfg.window_size,
        max_capacity_prompt=cfg.base_capacity,
        kernel_size=cfg.kernel_size,
        pooling=cfg.pooling,
        layer_idx=self.layer_idx,
        num_hidden_layers=cfg.num_hidden_layers,
        pyram_mode=cfg.pyram_beta,
    )

def init_snapkv(self):
    """
    Attach a ``SnapKVCluster`` to this module on the first call and print its settings.

    Requires ``self.config`` to define window_size, kernel_size, pooling and
    base_capacity (checked with asserts). When ``self.kv_cluster`` already
    exists nothing is created and nothing is printed.
    """
    cfg = self.config
    for required in ('window_size', 'kernel_size', "pooling", "base_capacity"):
        assert hasattr(cfg, required), f"{required} not set"

    # init only once
    if hasattr(self, "kv_cluster"):
        return

    self.kv_cluster = SnapKVCluster(
        window_size=cfg.window_size,
        max_capacity_prompt=cfg.base_capacity,
        kernel_size=cfg.kernel_size,
        pooling=cfg.pooling,
        layer_idx=self.layer_idx,
        num_hidden_layers=cfg.num_hidden_layers,
    )
    kc = self.kv_cluster
    print(f"Compress config(Snap): window_size={kc.window_size}, max_capacity_prompt={kc.max_capacity_prompt}, kernel_size={kc.kernel_size}, pooling={kc.pooling}")

def init_reason_snapkv(self):
    """
    Attach a ``ReasonSnapKVCluster`` to this module on the first call.

    Requires ``self.config`` to define window_size, kernel_size, pooling,
    base_capacity, head_choice, beta and temp (checked with asserts). The model
    name handed to the cluster is ``config._name_or_path``. Subsequent calls
    leave the existing ``self.kv_cluster`` untouched.
    """
    cfg = self.config
    required_fields = (
        'window_size', 'kernel_size', "pooling", "base_capacity",
        'head_choice', 'beta', 'temp',
    )
    for required in required_fields:
        assert hasattr(cfg, required), f"{required} not set"

    # init only once
    if hasattr(self, "kv_cluster"):
        return

    self.kv_cluster = ReasonSnapKVCluster(
        window_size=cfg.window_size,
        base_capacity=cfg.base_capacity,
        head_choice=cfg.head_choice,
        beta=cfg.beta,
        temp=cfg.temp,
        kernel_size=cfg.kernel_size,
        pooling=cfg.pooling,
        layer_idx=self.layer_idx,
        num_hidden_layers=cfg.num_hidden_layers,
        num_attention_heads=cfg.num_attention_heads,
        model=cfg._name_or_path,
    )

def init_headkv(self):
    """
    Attach an ``AdaptiveSnapKVCluster`` to this module on the first call.

    Requires ``self.config`` to define window_size, kernel_size, pooling and
    base_capacity (checked with asserts). All options are read and converted on
    every call, so a malformed config value raises even when a cluster already
    exists; the cluster itself is only built when ``self.kv_cluster`` is absent.

    The floor ratio is resolved in this order: ``config.floor_ratio``, else
    ``config.floor_capacity / base_capacity`` (when base_capacity > 0), else
    ``config.floor`` (default 0.0).

    On creation the cluster also receives ``self.obs_ana`` (if present) as
    ``obs_ana`` and the module's config as ``config``.
    """
    cfg = self.config
    for required in ('window_size', 'kernel_size', 'pooling', 'base_capacity'):
        assert hasattr(cfg, required), f"{required} not set"

    def option(name, default, cast):
        # Read an optional config field and coerce it to the expected type.
        return cast(getattr(cfg, name, default))

    capacity = int(cfg.base_capacity)
    ratio_override = getattr(cfg, "floor_ratio", None)
    absolute_floor = getattr(cfg, "floor_capacity", None)
    if ratio_override is not None:
        floor = float(ratio_override)
    elif absolute_floor is not None and capacity > 0:
        floor = float(absolute_floor) / float(capacity)
    else:
        floor = option("floor", 0.0, float)

    skip = option("skip_layer_nums", 1000, int)
    norm = option("normalize", True, bool)
    n_layers = option("num_hidden_layers", 0, int)
    pool = option("pooling", "maxpool", str)
    kernel = option("kernel_size", 7, int)
    window = option("window_size", 8, int)
    sink = option("sink", 0, int)  # 0 when the config has no sink setting
    layer = int(getattr(self, "layer_idx", 0))

    qd_on = option("qd_enable", False, bool)
    qd_kind = option("qd_mode", "roundrobin", str)
    qd_weight = option("qd_lambda", 0.2, float)
    red_mode = option("redundancy_mode", "none", str)
    use_random = option("random_baseline", False, bool)

    # init only once
    if hasattr(self, "kv_cluster"):
        return

    self.kv_cluster = AdaptiveSnapKVCluster(
        window_size=window,
        base_capacity=capacity,
        kernel_size=kernel,
        pooling=pool,
        floor=floor,
        skip=skip,
        normalize=norm,
        layer_idx=layer,
        num_hidden_layers=n_layers,
        sink=sink,
        # === QD / RABA / Random ===
        qd_enable=qd_on,
        qd_mode=qd_kind,
        qd_lambda=qd_weight,
        redundancy_mode=red_mode,
        random_baseline=use_random,
    )

    if hasattr(self, "obs_ana"):
        self.kv_cluster.obs_ana = self.obs_ana
    self.kv_cluster.config = getattr(self, "config", None)




def _init_kv_cluster_from_config(self):
    import math
    cfg = getattr(self, "config", None)
    assert cfg is not None, "[HEADKV] missing self.config"

    base_capacity_total = int(getattr(cfg, "base_capacity"))   
    window_size         = int(getattr(cfg, "window_size", 8))  
    kernel_size         = int(getattr(cfg, "kernel_size", 7))
    pooling             = str(getattr(cfg, "pooling", "maxpool"))
    normalize           = bool(getattr(cfg, "normalize", True))

    floor_ratio_cfg = getattr(cfg, "floor_ratio", None)
    floor_cap_cfg   = getattr(cfg, "floor_capacity", None)
    if floor_ratio_cfg is not None:
        floor_ratio = float(floor_ratio_cfg)
    elif floor_cap_cfg is not None and base_capacity_total > 0:
        floor_ratio = float(floor_cap_cfg) / float(base_capacity_total)
    else:
        floor_ratio = float(getattr(cfg, "floor", 0.0))

    skip_layers = int(getattr(cfg, "skip_layer_nums", 1000))
    num_layers  = int(getattr(cfg, "num_hidden_layers", 0))
    layer_idx   = int(getattr(self, "layer_idx", 0))

    # ---- QD / RABA ----
    qd_enable     = bool(getattr(cfg, "qd_enable", False))
    qd_lambda     = float(getattr(cfg, "qd_lambda", 0.0))
    query_scaling = "orthogonal" if qd_enable and qd_lambda > 0 else None

    redundancy_mode = str(getattr(cfg, "redundancy_mode", "none")).lower()
    if redundancy_mode == "none":
        redundancy_mode = None

    base_capacity_topk = max(0, base_capacity_total - window_size)

    created = False
    if not hasattr(self, "kv_cluster"):
        from headkv.snapkv_utils import AdaptiveSnapKVCluster  
        self.kv_cluster = AdaptiveSnapKVCluster(
            window_size=window_size,
            kernel_size=kernel_size,
            pooling=pooling,
            base_capacity=base_capacity_topk,    
            floor=floor_ratio,                   
            skip=skip_layers,
            normalize=normalize,
            layer_idx=layer_idx,
            num_hidden_layers=num_layers,
            # QD
            query_scaling=query_scaling,
            query_scaling_lambda=qd_lambda,
            # RABA
            redundancy_mode=redundancy_mode,
            # no sink, no query_mode
            cross_head_dedup=True,
        )
        created = True
    else:
        kc = self.kv_cluster
        kc.window_size = window_size
        kc.kernel_size = kernel_size
        kc.pooling     = pooling

        kc.budget_capacity_total = base_capacity_topk + window_size  
        kc.base_capacity = int(base_capacity_topk)                  
        kc.floor_ratio   = float(floor_ratio)
        kc.floor_capacity    = int(kc.base_capacity * kc.floor_ratio)
        kc.adaptive_capacity = kc.base_capacity - kc.floor_capacity

        kc.skip_layer_nums   = skip_layers
        kc.normalize         = normalize
        kc.layer_idx         = layer_idx
        kc.num_hidden_layers = num_layers

        kc.query_scaling_mode   = query_scaling
        kc.query_scaling_lambda = float(qd_lambda)
        kc.redundancy_mode      = redundancy_mode



    ana = getattr(cfg, "_obs_ana", None)
    self.kv_cluster.obs_ana      = ana
    self.kv_cluster.method_label = int(getattr(cfg, "obs_method_label", 1))
    self.kv_cluster.config       = cfg

    if created:
        print(f"[HEADKV] init L{layer_idx}: window={window_size} "
              f"cap_total={base_capacity_topk + window_size} (topk={base_capacity_topk}+window={window_size}), "
              f"qd_lambda={qd_lambda} red={redundancy_mode}")
