import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Dict, Union
from dataclasses import dataclass, field
import math


@dataclass
class TASMCompressorConfig:
    """Tunable settings shared by the TASM key/value cache compression components."""

    # Task-vector construction mode handed to TaskVectorComputer: "embedding_diff",
    # "head_activation", or any other value (such as "combined") for the 50/50 blend.
    task_vector_method: str = "combined" 
    # Weight of the task-vector importance when TASMCompressor.compress_layer blends
    # it with the caller-supplied attention scores (the scores get 1 - weight).
    task_vector_weight: float = 0.3
    
    # When True, compress_layer folds dropped tokens into kept ones via SemanticMerger
    # instead of simply discarding them.
    enable_merging: bool = False
    # Minimum (penalty-adjusted) cosine similarity a dropped token needs with its best
    # kept token to be merged into it.
    merge_similarity_threshold: float = 0.5
    # When True and a visual mask is given, merging two visual tokens whose positions
    # differ by more than `spatial_window` is penalised.
    preserve_spatial: bool = True
    # Largest position distance between two visual tokens that is not penalised.
    spatial_window: int = 3
    
    # Enables the core/latent split in HierarchicalMemory and later latent retrieval.
    enable_dynamic_retrieval: bool = True
    # Fraction of a layer's stored tokens kept in the core tier (at least 4 tokens).
    core_ratio: float = 0.2
    # Fraction of stored tokens covered by core + latent together; the latent tier
    # receives int(seq_len * latent_ratio) - num_core tokens.
    latent_ratio: float = 0.4
    # Default number of latent tokens returned by HierarchicalMemory.retrieve.
    retrieval_top_k: int = 96
    # Jensen-Shannon divergence above which HierarchicalMemory.should_retrieve is True.
    js_threshold: float = 0.002

    # Fraction of the sequence compress_layer keeps when no explicit num_keep is given.
    target_compression_ratio: float = 0.35 
    # sink_tokens + local_tokens is used in this file only as a lower bound on num_keep.
    sink_tokens: int = 4
    local_tokens: int = 96 
    # Not read by any code in this file.
    # NOTE(review): presumably consumed by a caller elsewhere - confirm.
    layer_adaptive: bool = True 
    

class TaskVectorComputer(nn.Module):
    """Builds per-layer task vectors and scores cached tokens against them.

    A task vector is a unit-length direction derived from the difference between
    mean-pooled answer and question hidden states. Token importance is then the
    alignment of each key with that direction, mixed with value magnitude and,
    optionally, attention scores and a boost for visual tokens.
    """

    def __init__(self, hidden_size: int, num_layers: int, method: str = "combined"):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.method = method

        # Small MLP applied to the pooled hidden states in the projected modes.
        self.task_projector = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

        # layer index -> normalised task vector
        self.task_vectors: Dict[int, torch.Tensor] = {}

    def extract_from_qa_pairs(
        self,
        question_hidden: torch.Tensor,
        answer_hidden: torch.Tensor,
        layer_idx: int,
    ) -> torch.Tensor:
        """Compute, store and return the task vector for ``layer_idx``.

        Both inputs are mean-pooled over dim 1; the answer-minus-question difference
        is then averaged over dim 0. ``self.method`` selects the raw difference
        ("embedding_diff"), the difference of projected states ("head_activation"),
        or an equal blend of the two (any other value).
        """
        question_mean = question_hidden.mean(dim=1)
        answer_mean = answer_hidden.mean(dim=1)

        def raw_delta() -> torch.Tensor:
            return (answer_mean - question_mean).mean(dim=0)

        def projected_delta() -> torch.Tensor:
            projected_question = self.task_projector(question_mean)
            projected_answer = self.task_projector(answer_mean)
            return (projected_answer - projected_question).mean(dim=0)

        mode = self.method
        if mode == "embedding_diff":
            direction = raw_delta()
        elif mode == "head_activation":
            direction = projected_delta()
        else:
            direction = 0.5 * raw_delta() + 0.5 * projected_delta()

        unit_direction = F.normalize(direction, dim=-1)
        self.task_vectors[layer_idx] = unit_direction
        return unit_direction

    def compute_importance_scores(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        attention_scores: Optional[torch.Tensor] = None,
        num_layers: int = 28,
        visual_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Score every cached token of one layer.

        ``key_states`` is unpacked as (batch, heads, seq_len, head_dim). If no task
        vector is stored for ``layer_idx`` the attention scores are returned
        untouched, or an all-ones tensor when none were given. Otherwise the result
        is normalised to sum to one over the last dimension.
        """
        batch_size, num_heads, seq_len, head_dim = key_states.shape
        device, dtype = key_states.device, key_states.dtype

        stored = self.task_vectors.get(layer_idx)
        if stored is None:
            if attention_scores is None:
                return torch.ones(batch_size, num_heads, seq_len, device=device, dtype=dtype)
            return attention_scores

        stored = stored.to(device=device, dtype=dtype)
        flat_size = stored.shape[0]

        # Shape the flat vector into one direction per key head.
        if flat_size == num_heads * head_dim:
            per_head = stored.view(num_heads, head_dim)
        else:
            total_heads = flat_size // head_dim
            all_heads = stored.view(total_heads, head_dim)
            heads_per_group = total_heads // num_heads
            if heads_per_group > 0:
                # More vector heads than key heads: average each group.
                per_head = all_heads.view(num_heads, heads_per_group, head_dim).mean(dim=1)
            else:
                per_head = all_heads[:num_heads]

        per_head = F.normalize(per_head, dim=-1)
        unit_keys = F.normalize(key_states, dim=-1)
        alignment = torch.einsum('bhsd,hd->bhs', unit_keys, per_head)

        # Value magnitude relative to the largest value in the same head.
        value_norms = value_states.norm(dim=-1)
        peak = value_norms.max(dim=-1, keepdim=True).values
        value_norms = value_norms / (peak + 1e-8)

        task_score = 0.7 * alignment + 0.3 * value_norms
        # Shift so that the smallest score per head is 0.01.
        task_score = task_score - task_score.min(dim=-1, keepdim=True).values + 0.01

        depth = layer_idx / max(num_layers - 1, 1)
        # Sigmoid schedule in depth, ranging over (0.1, 0.9).
        task_weight = 0.1 + 0.8 * (1 / (1 + math.exp(-10 * (depth - 0.5))))

        if attention_scores is None:
            combined = task_score
        else:
            shifted_attention = (
                attention_scores - attention_scores.min(dim=-1, keepdim=True).values + 0.01
            )
            combined = task_weight * task_score + (1 - task_weight) * shifted_attention

        if visual_mask is not None:
            # Gaussian bump in depth, peaking at 0.1 for the middle layers.
            boost = 0.1 * math.exp(-((depth - 0.5) ** 2) / 0.1)

            mask = visual_mask
            if mask.dim() == 1:
                mask = mask.unsqueeze(0)
            mask_len = mask.shape[-1]
            if mask_len > seq_len:
                mask = mask[..., :seq_len]
            elif mask_len < seq_len:
                mask = F.pad(mask.float(), (0, seq_len - mask_len)).bool()

            per_head_mask = mask.unsqueeze(1).expand(-1, num_heads, -1).float()
            combined = combined + boost * per_head_mask

        total = combined.sum(dim=-1, keepdim=True)
        return combined / (total + 1e-8)

    def clear(self):
        """Forget all stored task vectors."""
        self.task_vectors.clear()


class SemanticMerger:
    """Shrinks a key/value cache by folding low-importance tokens into similar kept ones.

    Only ``merge_similarity_threshold``, ``preserve_spatial`` and ``spatial_window``
    are read from the config.
    """

    # Quoted annotation: the class does not depend on where the config type is defined.
    def __init__(self, config: "TASMCompressorConfig"):
        self.config = config

    @staticmethod
    def _fit_length(tensor: torch.Tensor, seq_len: int) -> torch.Tensor:
        """Truncate or zero-pad the last dimension of ``tensor`` to ``seq_len``."""
        current = tensor.shape[-1]
        if current > seq_len:
            return tensor[..., :seq_len]
        if current < seq_len:
            filler = tensor.new_zeros(list(tensor.shape[:-1]) + [seq_len - current])
            return torch.cat([tensor, filler], dim=-1)
        return tensor

    @classmethod
    def _prepare_mask(
        cls,
        mask: torch.Tensor,
        batch_size: int,
        seq_len: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Return ``mask`` as a boolean (batch_size, seq_len) tensor on ``device``."""
        mask = mask.to(device)
        if mask.dtype != torch.bool:
            # Any non-zero entry counts as set.
            mask = mask != 0
        if mask.dim() == 1:
            mask = mask.unsqueeze(0)
        mask = cls._fit_length(mask, seq_len)
        if mask.shape[0] != batch_size:
            # Only a single-row mask can be shared across the batch; anything else raises.
            mask = mask.expand(batch_size, -1)
        return mask

    @torch.no_grad()
    def merge_tokens(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        importance_scores: torch.Tensor,
        num_keep: int,
        visual_mask: Optional[torch.Tensor] = None,
        protected_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Keep the ``num_keep`` most important tokens and merge the rest into them.

        Args:
            key_states: Keys, unpacked as (batch, heads, seq_len, head_dim).
            value_states: Values with the same layout as ``key_states``.
            importance_scores: (batch, heads, seq_len) scores, averaged over heads,
                or (batch, seq_len); any other rank means uniform importance.
            num_keep: Number of tokens to keep; capped at ``seq_len``.
            visual_mask: Optional mask of visual tokens, 1-D or (batch, seq_len).
                Used only when ``config.preserve_spatial`` is set.
            protected_mask: Optional mask of tokens forced to the top of the ranking.

        Returns:
            ``(merged_keys, merged_values, mapping)``. ``mapping`` is
            (batch, seq_len): for each original position, the index of the kept
            slot it ended up in, or -1 if the token was dropped without merging.

        Raises:
            ValueError: If ``num_keep`` is below 1 for a non-empty sequence.
        """
        batch_size, num_heads, seq_len, head_dim = key_states.shape
        device = key_states.device
        dtype = key_states.dtype

        num_keep = min(num_keep, seq_len)
        if seq_len > 0 and num_keep < 1:
            raise ValueError(f"num_keep must be at least 1, got {num_keep}")

        if importance_scores.dim() == 3:
            imp_avg = importance_scores.mean(dim=1)
        elif importance_scores.dim() == 2:
            imp_avg = importance_scores
        else:
            imp_avg = torch.ones(batch_size, seq_len, device=device, dtype=dtype)
        imp_avg = self._fit_length(imp_avg.to(device), seq_len)

        if protected_mask is not None:
            protected_mask = self._prepare_mask(protected_mask, batch_size, seq_len, device)
            # Infinite importance ranks protected tokens first.
            imp_avg = imp_avg.masked_fill(protected_mask, float('inf'))

        if visual_mask is not None:
            visual_mask = self._prepare_mask(visual_mask, batch_size, seq_len, device)

        _, order = imp_avg.sort(dim=-1, descending=True)

        # Kept ("target") and dropped ("source") positions, both in sequence order.
        target_idx = order[:, :num_keep].sort(dim=-1)[0]
        source_idx = order[:, num_keep:].sort(dim=-1)[0]
        num_source = source_idx.shape[1]

        if num_source == 0:
            mapping = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
            return key_states, value_states, mapping

        def per_head(index: torch.Tensor) -> torch.Tensor:
            # (batch, n) -> (batch, heads, n, head_dim), as gather/scatter need.
            return index[:, None, :, None].expand(-1, num_heads, -1, head_dim)

        merged_keys = torch.gather(key_states, dim=2, index=per_head(target_idx))
        merged_values = torch.gather(value_states, dim=2, index=per_head(target_idx))

        # Matching uses head-averaged keys, so every head shares one assignment.
        k_avg = key_states.mean(dim=1)
        source_tokens = torch.gather(
            k_avg, dim=1, index=source_idx.unsqueeze(-1).expand(-1, -1, head_dim)
        )
        target_tokens = torch.gather(
            k_avg, dim=1, index=target_idx.unsqueeze(-1).expand(-1, -1, head_dim)
        )
        similarity = torch.bmm(
            F.normalize(source_tokens, dim=-1),
            F.normalize(target_tokens, dim=-1).transpose(1, 2),
        )  # (batch, num_source, num_keep)

        if self.config.preserve_spatial and visual_mask is not None:
            source_visual = torch.gather(visual_mask, 1, source_idx)
            target_visual = torch.gather(visual_mask, 1, target_idx)
            distance = (source_idx.unsqueeze(2) - target_idx.unsqueeze(1)).abs()
            too_far = (
                source_visual.unsqueeze(2)
                & target_visual.unsqueeze(1)
                & (distance > self.config.spatial_window)
            )
            # A -10 offset pushes distant visual pairs below any sane threshold.
            similarity = similarity + too_far.to(similarity.dtype) * -10.0

        match_scores, match_idx = similarity.max(dim=-1)
        valid_match = match_scores >= self.config.merge_similarity_threshold

        # Similarity-weighted accumulation of every matched source into its target,
        # done with one scatter per tensor instead of a batch x head x source loop.
        # Accumulation order can differ from a sequential loop, so results agree
        # with it only up to floating-point rounding.
        weights = torch.where(valid_match, match_scores, torch.zeros_like(match_scores))
        valid_4d = valid_match[:, None, :, None]
        weights_4d = weights[:, None, :, None]

        source_keys = torch.gather(key_states, dim=2, index=per_head(source_idx))
        source_values = torch.gather(value_states, dim=2, index=per_head(source_idx))
        # masked_fill (not just a zero weight) keeps non-finite unmatched tokens out.
        key_contrib = (source_keys * weights_4d.to(source_keys.dtype)).masked_fill(~valid_4d, 0.0)
        value_contrib = (source_values * weights_4d.to(source_values.dtype)).masked_fill(~valid_4d, 0.0)

        slot_idx = per_head(match_idx)
        merged_keys.scatter_add_(2, slot_idx, key_contrib)
        merged_values.scatter_add_(2, slot_idx, value_contrib)

        # Each kept token starts with weight 1 for itself.
        weight_sum = torch.ones(batch_size, num_keep, device=device, dtype=weights.dtype)
        weight_sum.scatter_add_(1, match_idx, weights)
        denom = weight_sum.clamp(min=1.0)[:, None, :, None]
        merged_keys = merged_keys / denom.to(merged_keys.dtype)
        merged_values = merged_values / denom.to(merged_values.dtype)

        # Targets and sources partition the positions, so every entry gets written.
        mapping = torch.zeros(batch_size, seq_len, dtype=torch.long, device=device)
        kept_slots = torch.arange(num_keep, device=device).unsqueeze(0).repeat(batch_size, 1)
        mapping.scatter_(1, target_idx, kept_slots)
        source_slots = torch.where(valid_match, match_idx, torch.full_like(match_idx, -1))
        mapping.scatter_(1, source_idx, source_slots)

        return merged_keys, merged_values, mapping


class HierarchicalMemory:
    """Two-tier, per-layer store for compressed key/value states.

    ``store`` keeps the highest-importance tokens of a layer as the "core" tier and,
    when dynamic retrieval is enabled, moves the next-best tokens to CPU as the
    "latent" tier. ``retrieve`` returns the latent tokens most similar to the newest
    query position, on the query's device.
    """

    # Quoted annotation: the class does not depend on where the config type is defined.
    def __init__(self, num_layers: int, num_heads: int, head_dim: int, config: "TASMCompressorConfig"):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.config = config

        self.core_keys: List[Optional[torch.Tensor]] = [None] * num_layers
        self.core_values: List[Optional[torch.Tensor]] = [None] * num_layers

        self.latent_keys: List[Optional[torch.Tensor]] = [None] * num_layers
        self.latent_values: List[Optional[torch.Tensor]] = [None] * num_layers
        self.latent_importance: List[Optional[torch.Tensor]] = [None] * num_layers

        # layer index -> (keys, values) returned by the last retrieve() for that layer
        self._retrieval_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}

        self.reference_distribution: Optional[torch.Tensor] = None

    def store(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        importance_scores: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split one layer's tokens into core and latent tiers and return the core.

        Args:
            layer_idx: Layer whose slots are overwritten.
            key_states: Keys, unpacked as (batch, heads, seq_len, head_dim).
            value_states: Values with the same layout.
            importance_scores: (batch, heads, seq_len) or (batch, seq_len) scores;
                any other rank means uniform importance. The last dimension is
                truncated or zero-padded to ``seq_len``.

        Returns:
            The core keys and values, in sequence order. Sequences shorter than four
            tokens are stored and returned unchanged.
        """
        batch_size, num_heads, seq_len, head_dim = key_states.shape
        device = key_states.device

        # Storing replaces the layer's contents, so drop the previous latent tier and
        # any cached retrieval; otherwise stale tokens would be served later.
        self.latent_keys[layer_idx] = None
        self.latent_values[layer_idx] = None
        self.latent_importance[layer_idx] = None
        self._retrieval_cache.pop(layer_idx, None)

        if seq_len < 4:
            self.core_keys[layer_idx] = key_states
            self.core_values[layer_idx] = value_states
            return key_states, value_states

        num_core = min(max(int(seq_len * self.config.core_ratio), 4), seq_len)
        num_latent = min(max(int(seq_len * self.config.latent_ratio) - num_core, 0), seq_len - num_core)

        if importance_scores.dim() in (2, 3):
            importance_scores = importance_scores.to(device)
            # Align the scores with the tokens before ranking, so the per-head gather
            # below can never index past the end of a shorter score tensor.
            scored_len = importance_scores.shape[-1]
            if scored_len > seq_len:
                importance_scores = importance_scores[..., :seq_len]
            elif scored_len < seq_len:
                importance_scores = F.pad(importance_scores, (0, seq_len - scored_len), value=0.0)
            if importance_scores.dim() == 2:
                imp_avg = importance_scores
            else:
                imp_avg = importance_scores.mean(dim=1)
        else:
            imp_avg = torch.ones(batch_size, seq_len, device=device)

        _, sorted_idx = imp_avg.sort(dim=-1, descending=True)

        # Tensor.sort returns (values, indices); [0] takes the positions themselves,
        # re-ordered ascending so the stored tokens stay in sequence order.
        core_idx = sorted_idx[:, :num_core].sort(dim=-1)[0]
        core_idx_exp = core_idx.unsqueeze(1).unsqueeze(-1).expand(-1, num_heads, -1, head_dim)
        core_k = torch.gather(key_states, dim=2, index=core_idx_exp)
        core_v = torch.gather(value_states, dim=2, index=core_idx_exp)

        self.core_keys[layer_idx] = core_k
        self.core_values[layer_idx] = core_v

        if num_latent > 0 and self.config.enable_dynamic_retrieval:
            latent_idx = sorted_idx[:, num_core:num_core + num_latent].sort(dim=-1)[0]
            latent_idx_exp = latent_idx.unsqueeze(1).unsqueeze(-1).expand(-1, num_heads, -1, head_dim)
            latent_k = torch.gather(key_states, dim=2, index=latent_idx_exp)
            latent_v = torch.gather(value_states, dim=2, index=latent_idx_exp)

            if importance_scores.dim() == 3:
                latent_imp_idx = latent_idx.unsqueeze(1).expand(-1, num_heads, -1)
                latent_imp = torch.gather(importance_scores, dim=2, index=latent_imp_idx)
            else:
                latent_imp = torch.ones(batch_size, num_heads, num_latent, device=device)

            # The latent tier lives on CPU until retrieve() brings tokens back.
            self.latent_keys[layer_idx] = latent_k.cpu()
            self.latent_values[layer_idx] = latent_v.cpu()
            self.latent_importance[layer_idx] = latent_imp.cpu()

        return core_k, core_v

    def retrieve(
        self,
        layer_idx: int,
        query_states: torch.Tensor,
        top_k: Optional[int] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Return the latent tokens best matching the last query position.

        Ranking is cosine similarity between the final query position and each latent
        key, plus 0.3 times the min-max normalised stored importance. The result is
        cached per layer and returned as-is, regardless of the query, until
        ``clear_retrieval_cache`` or another ``store`` for the layer.

        Returns ``(None, None)`` when the layer has no latent tier.
        """
        if self.latent_keys[layer_idx] is None:
            return None, None

        cached = self._retrieval_cache.get(layer_idx)
        if cached is not None:
            return cached

        top_k = top_k or self.config.retrieval_top_k
        device = query_states.device
        dtype = query_states.dtype

        latent_k = self.latent_keys[layer_idx].to(device, dtype=dtype)
        latent_v = self.latent_values[layer_idx].to(device, dtype=dtype)
        stored_imp = self.latent_importance[layer_idx]
        latent_imp = stored_imp.to(device, dtype=dtype) if stored_imp is not None else None

        top_k = min(top_k, latent_k.shape[2])

        query_last = query_states[:, :, -1:, :]
        query_heads, kv_heads = query_last.shape[1], latent_k.shape[1]
        if query_heads != kv_heads and kv_heads > 0 and query_heads % kv_heads == 0:
            # Grouped-query attention: average the query heads that share a KV head.
            # Assumes consecutive query heads form a group (the usual GQA layout).
            query_last = query_last.reshape(
                query_last.shape[0], kv_heads, query_heads // kv_heads, 1, query_last.shape[-1]
            ).mean(dim=2)

        query_norm = F.normalize(query_last, dim=-1)
        latent_norm = F.normalize(latent_k, dim=-1)
        scores = torch.matmul(query_norm, latent_norm.transpose(-2, -1)).squeeze(2)

        if latent_imp is not None:
            spread = latent_imp.max() - latent_imp.min()
            imp_weight = (latent_imp - latent_imp.min()) / (spread + 1e-8)
            scores = scores + 0.3 * imp_weight

        _, top_idx = scores.topk(top_k, dim=-1)

        # Use the stored tensor's own width, so a mismatched self.head_dim cannot
        # silently truncate the gathered vectors.
        top_idx_exp = top_idx.unsqueeze(-1).expand(-1, -1, -1, latent_k.shape[-1])
        retrieved_k = torch.gather(latent_k, dim=2, index=top_idx_exp)
        retrieved_v = torch.gather(latent_v, dim=2, index=top_idx_exp)

        self._retrieval_cache[layer_idx] = (retrieved_k, retrieved_v)

        return retrieved_k, retrieved_v

    def get_augmented_kv(
        self,
        layer_idx: int,
        query_states: torch.Tensor,
        current_keys: Optional[torch.Tensor] = None,
        current_values: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Concatenate core, retrieved latent and current tokens along the sequence axis.

        Parts are joined in that order, skipping any that are missing. Returns
        ``(None, None)`` when there is nothing to join.
        """
        parts_k = []
        parts_v = []

        device = query_states.device
        dtype = query_states.dtype
        batch_size = query_states.shape[0]

        if self.core_keys[layer_idx] is not None:
            core_k = self.core_keys[layer_idx].to(device, dtype=dtype)
            core_v = self.core_values[layer_idx].to(device, dtype=dtype)
            # Broadcasts a batch-1 core across the query batch.
            parts_k.append(core_k.expand(batch_size, -1, -1, -1))
            parts_v.append(core_v.expand(batch_size, -1, -1, -1))

        if self.config.enable_dynamic_retrieval:
            ret_k, ret_v = self.retrieve(layer_idx, query_states)
            if ret_k is not None:
                parts_k.append(ret_k)
                parts_v.append(ret_v)

        if current_keys is not None:
            parts_k.append(current_keys)
            parts_v.append(current_values)

        if not parts_k:
            return None, None

        return torch.cat(parts_k, dim=2), torch.cat(parts_v, dim=2)

    def set_reference_distribution(self, distribution: torch.Tensor):
        """Set reference attention distribution for JS divergence."""
        self.reference_distribution = distribution.detach().cpu()

    def should_retrieve(self, current_distribution: torch.Tensor) -> bool:
        """Check if retrieval is needed based on JS divergence.

        Returns False when no reference distribution has been set. Otherwise the
        current distribution is zero-padded or truncated to the reference length,
        both are renormalised, and the mean Jensen-Shannon divergence is compared
        against ``config.js_threshold``.
        """
        if self.reference_distribution is None:
            return False

        ref = self.reference_distribution.to(current_distribution.device)

        curr_len = current_distribution.shape[-1]
        ref_len = ref.shape[-1]

        if curr_len < ref_len:
            current_distribution = F.pad(current_distribution, (0, ref_len - curr_len))
        elif curr_len > ref_len:
            current_distribution = current_distribution[..., :ref_len]

        # The epsilon keeps the logarithms finite for zero-probability entries.
        p = current_distribution + 1e-8
        q = ref + 1e-8
        p = p / p.sum(dim=-1, keepdim=True)
        q = q / q.sum(dim=-1, keepdim=True)

        m = 0.5 * (p + q)
        kl_pm = (p * (p.log() - m.log())).sum(dim=-1)
        kl_qm = (q * (q.log() - m.log())).sum(dim=-1)
        js_div = 0.5 * (kl_pm + kl_qm)

        return js_div.mean().item() > self.config.js_threshold

    def clear_retrieval_cache(self):
        """Drop cached retrieval results so the next retrieve() re-ranks."""
        self._retrieval_cache.clear()

    def clear_all(self):
        """Reset every layer's tiers, the retrieval cache and the reference distribution."""
        self.core_keys = [None] * self.num_layers
        self.core_values = [None] * self.num_layers
        self.latent_keys = [None] * self.num_layers
        self.latent_values = [None] * self.num_layers
        self.latent_importance = [None] * self.num_layers
        self._retrieval_cache.clear()
        self.reference_distribution = None


class TASMCompressor:
    """Facade tying task-vector scoring, token merging and hierarchical memory together.

    Typical use: call ``extract_task_vector`` once, ``compress_layer`` per layer to
    shrink the cache, then ``get_augmented_kv`` to assemble keys/values for a query.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        num_layers: int,
        head_dim: Optional[int] = None,
        config: Optional[TASMCompressorConfig] = None,
    ):
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_layers = num_layers
        self.head_dim = head_dim or hidden_size // num_heads
        self.config = config or TASMCompressorConfig()

        self.task_vector = TaskVectorComputer(
            hidden_size=hidden_size,
            num_layers=num_layers,
            method=self.config.task_vector_method,
        )

        self.merger = SemanticMerger(self.config)

        # The memory holds key/value tensors, so it is sized by the KV head count.
        self.memory = HierarchicalMemory(
            num_layers=num_layers,
            num_heads=num_kv_heads,
            head_dim=self.head_dim,
            config=self.config,
        )

        # Running token counts, summed over every compress_layer call.
        self._stats = {
            'original_tokens': 0,
            'core_tokens': 0,
            'latent_tokens': 0,
            'merged_tokens': 0,
        }

        self._device = None
        self._dtype = torch.bfloat16

    def to(self, device, dtype=torch.bfloat16):
        """Move components to device."""
        self._device = device
        self._dtype = dtype
        self.task_vector = self.task_vector.to(device, dtype=dtype)
        return self

    def extract_task_vector(
        self,
        question_hidden: torch.Tensor,
        answer_hidden: torch.Tensor,
        layer_idx: Optional[int] = None,
    ):
        """
        Extract task vector from ICL examples.

        Should be called once during compression setup. With ``layer_idx`` given only
        that layer's vector is set; otherwise the same inputs are used for every layer.
        """
        if layer_idx is not None:
            self.task_vector.extract_from_qa_pairs(question_hidden, answer_hidden, layer_idx)
            return
        for idx in range(self.num_layers):
            self.task_vector.extract_from_qa_pairs(question_hidden, answer_hidden, idx)

    @staticmethod
    def _align_mask(mask: torch.Tensor, seq_len: int, device: torch.device) -> torch.Tensor:
        """Return ``mask`` as a boolean tensor with at least 2 dims and last dim ``seq_len``."""
        mask = mask.to(device)
        if mask.dtype != torch.bool:
            mask = mask != 0
        if mask.dim() == 1:
            mask = mask.unsqueeze(0)
        current = mask.shape[-1]
        if current > seq_len:
            mask = mask[..., :seq_len]
        elif current < seq_len:
            filler = mask.new_zeros(list(mask.shape[:-1]) + [seq_len - current])
            mask = torch.cat([mask, filler], dim=-1)
        return mask

    def compress_layer(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_scores: Optional[torch.Tensor] = None,
        visual_mask: Optional[torch.Tensor] = None,
        protected_mask: Optional[torch.Tensor] = None,
        num_keep: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compress one layer's key/value cache and return what stays resident.

        Args:
            layer_idx: Index of the layer being compressed.
            key_states: Keys, unpacked as (batch, heads, seq_len, head_dim).
            value_states: Values with the same layout.
            attention_scores: Optional per-token scores blended into the importance;
                used as (batch, heads, seq_len).
            visual_mask: Optional mask of visual tokens; it feeds the visual boost
                during scoring and the spatial constraint during merging.
            protected_mask: Optional mask of tokens ranked ahead of all others.
            num_keep: Tokens to keep; defaults to
                ``seq_len * target_compression_ratio``. Always at least
                ``sink_tokens + local_tokens`` and never more than ``seq_len``.

        Returns:
            The core keys and values when dynamic retrieval is enabled, otherwise all
            kept (or merged) keys and values.
        """
        batch_size, num_heads, seq_len, head_dim = key_states.shape

        if num_keep is None:
            num_keep = int(seq_len * self.config.target_compression_ratio)
        floor = self.config.sink_tokens + self.config.local_tokens
        num_keep = min(max(num_keep, floor), seq_len)

        self._stats['original_tokens'] += seq_len

        if visual_mask is not None:
            visual_mask = visual_mask.to(key_states.device)

        # Forward the real layer count so the depth schedule is not computed against
        # the scorer's built-in default, and the visual mask so its boost is applied.
        importance = self.task_vector.compute_importance_scores(
            key_states,
            value_states,
            layer_idx,
            attention_scores,
            num_layers=self.num_layers,
            visual_mask=visual_mask,
        )

        if attention_scores is not None:
            importance = (
                self.config.task_vector_weight * importance +
                (1 - self.config.task_vector_weight) * attention_scores
            )

        if self.config.enable_merging:
            merged_k, merged_v, merge_map = self.merger.merge_tokens(
                key_states=key_states,
                value_states=value_states,
                importance_scores=importance,
                num_keep=num_keep,
                visual_mask=visual_mask,
                protected_mask=protected_mask,
            )
            self._stats['merged_tokens'] += merged_k.shape[2]

            # NOTE(review): merged tokens get uniform importance, so the core/latent
            # split below ranks them arbitrarily - confirm this is intended.
            merged_importance = torch.ones(
                batch_size, num_heads, merged_k.shape[2],
                device=key_states.device, dtype=importance.dtype
            )
        else:
            imp_avg = importance.mean(dim=1)
            if protected_mask is not None:
                guard = self._align_mask(protected_mask, seq_len, imp_avg.device)
                # Infinite importance ranks protected tokens first.
                imp_avg = imp_avg.masked_fill(guard, float('inf'))

            _, sorted_idx = imp_avg.sort(dim=-1, descending=True)
            # Re-sort the kept positions so tokens stay in sequence order.
            keep_idx = sorted_idx[:, :num_keep].sort(dim=-1)[0]

            keep_idx_exp = keep_idx.unsqueeze(1).unsqueeze(-1).expand(-1, num_heads, -1, head_dim)
            merged_k = torch.gather(key_states, dim=2, index=keep_idx_exp)
            merged_v = torch.gather(value_states, dim=2, index=keep_idx_exp)

            keep_imp_idx = keep_idx.unsqueeze(1).expand(-1, num_heads, -1)
            merged_importance = torch.gather(importance, dim=2, index=keep_imp_idx)

        if self.config.enable_dynamic_retrieval:
            core_k, core_v = self.memory.store(
                layer_idx=layer_idx,
                key_states=merged_k,
                value_states=merged_v,
                importance_scores=merged_importance,
            )
            self._stats['core_tokens'] += core_k.shape[2]
            # Count only what the memory actually holds in its latent tier; tokens
            # beyond core + latent are discarded by store() and are not retained.
            stored_latent = self.memory.latent_keys[layer_idx]
            if stored_latent is not None:
                self._stats['latent_tokens'] += stored_latent.shape[2]
        else:
            core_k = merged_k
            core_v = merged_v
            self._stats['core_tokens'] += core_k.shape[2]

        return core_k, core_v

    def get_augmented_kv(
        self,
        layer_idx: int,
        query_states: torch.Tensor,
        current_keys: Optional[torch.Tensor] = None,
        current_values: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Delegate to the memory's ``get_augmented_kv`` for this layer."""
        return self.memory.get_augmented_kv(
            layer_idx, query_states, current_keys, current_values
        )

    def clear(self):
        """Drop task vectors and stored memory, and zero the statistics."""
        self.task_vector.clear()
        self.memory.clear_all()
        self._stats = {k: 0 for k in self._stats}

    def clear_retrieval_cache(self):
        """Drop the memory's cached retrieval results."""
        self.memory.clear_retrieval_cache()

    def get_stats(self) -> Dict:
        """Return a copy of the token counters, plus ratios once any tokens were seen.

        ``compression_ratio`` is core / original; ``total_retained`` is
        (core + latent) / original.
        """
        stats = self._stats.copy()
        original = stats['original_tokens']
        if original > 0:
            stats['compression_ratio'] = stats['core_tokens'] / original
            stats['total_retained'] = (stats['core_tokens'] + stats['latent_tokens']) / original
        return stats


def create_tasm_compressor(model_config, tasm_config: Optional[TASMCompressorConfig] = None) -> TASMCompressor:
    """Build a TASMCompressor from a model configuration object.

    ``model_config`` must expose ``hidden_size``, ``num_attention_heads`` and
    ``num_hidden_layers``. ``num_key_value_heads`` and ``head_dim`` are optional:
    when absent or None they default to ``num_attention_heads`` and
    ``hidden_size // num_attention_heads`` respectively.
    """
    num_heads = model_config.num_attention_heads
    # Configs without grouped-query attention may omit this or leave it as None.
    num_kv_heads = getattr(model_config, 'num_key_value_heads', None) or num_heads
    # Prefer an explicit head_dim: it need not equal hidden_size // num_heads.
    head_dim = getattr(model_config, 'head_dim', None) or model_config.hidden_size // num_heads
    return TASMCompressor(
        hidden_size=model_config.hidden_size,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        num_layers=model_config.num_hidden_layers,
        head_dim=head_dim,
        config=tasm_config,
    )
