# SPDX-FileCopyrightText: Copyright (c) 2024 ManifoldKV Research
# SPDX-License-Identifier: Apache-2.0

"""ManifoldKVPress: Magnitude-Aware Outlier Detection for KV Cache Compression.

Key insight: L2 distance (Euclidean) captures both directional AND magnitude deviation,
while cosine similarity (used by KeyDiff) captures only direction.

This simple change yields a roughly 40 percentage point improvement on the RULER benchmark:
- KeyDiff (cosine): 52.8%
- ManifoldKV (L2):  92.7%

Tokens far from the centroid in L2 distance are "spectral outliers" - semantically
unique tokens (names, numbers, entities) that should be retained.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal

import torch
from torch import nn
from torch.nn import functional as F

from kvpress.presses.scorer_press import ScorerPress


@dataclass
class ManifoldKVPress(ScorerPress):
    """Score cached keys by their distance from the per-head mean key.

    For every (batch, head) pair the keys are averaged along the sequence
    axis to form a centroid; a token's score is the norm of its key minus
    that centroid. Because the raw difference is measured (no unit
    normalisation), a key stands out when it has either:
    - an unusual direction (angular deviation), or
    - an unusual magnitude (radial deviation),
    whereas a cosine-based score such as KeyDiff only sees direction.

    Parameters
    ----------
    compression_ratio : float, default=0.5
        Fraction of tokens to evict. Higher = more compression.
    distance_metric : str, default='l2'
        Norm applied to the centred keys: 'l2' (Euclidean), 'l1' (Manhattan)
        or 'linf' (max). Any other value is treated as 'l2'.
    """

    compression_ratio: float = 0.5
    distance_metric: Literal['l2', 'l1', 'linf'] = 'l2'

    def score(self, module, hidden_states, keys, values, attentions, kwargs) -> torch.Tensor:
        """Return one distance-from-centroid score per cached key.

        Only ``keys`` is read; it is indexed as (bsz, heads, seq_len, head_dim)
        and the result has shape (bsz, heads, seq_len).
        """
        # Pick the norm order; anything that is not 'l1' or 'linf' means L2.
        if self.distance_metric == 'l1':
            order = 1
        elif self.distance_metric == 'linf':
            order = float('inf')
        else:
            order = 2

        # Centroid is taken over the sequence axis, separately for each head.
        centroid = keys.mean(dim=2, keepdim=True)  # (bsz, heads, 1, head_dim)
        offsets = keys - centroid  # (bsz, heads, seq_len, head_dim)

        return torch.norm(offsets, p=order, dim=-1)


@dataclass
class ManifoldKVSnapKVScorerPress(ScorerPress):
    """Centroid-distance scorer intended to be wrapped by AdaKV.

    Produces the same score as ManifoldKVPress: the distance of each key
    from the mean key of its head. According to the authors, AdaKV uses the
    variance of these scores across heads to allocate compression budget.

    The authors recommend AdaKV + ManifoldKVSnapKVScorerPress and report
    92.69% on RULER for it (vs 84.09% for AdaKV + ExpectedAttention).

    Parameters
    ----------
    distance_metric : str, default='l2'
        'l2', 'l1' or 'linf'; any other value is treated as 'l2'.
    """

    distance_metric: Literal['l2', 'l1', 'linf'] = 'l2'

    def score(self, module, hidden_states, keys, values, attentions, kwargs) -> torch.Tensor:
        """Return the per-token distance of ``keys`` from their head centroid."""
        metric = self.distance_metric
        if metric == 'l1':
            order = 1
        elif metric == 'linf':
            order = float('inf')
        else:
            # Covers both 'l2' and any unrecognised value.
            order = 2

        centred = keys - keys.mean(dim=2, keepdim=True)
        return torch.norm(centred, p=order, dim=-1)


@dataclass
class ManifoldKVL1Press(ScorerPress):
    """ManifoldKV ablation variant that scores with the L1 (Manhattan) norm.

    Each key is scored by the sum of absolute per-dimension deviations from
    the mean key of its head, so magnitude still contributes to the score,
    but with a different sensitivity than the Euclidean norm.
    """

    compression_ratio: float = 0.5

    def score(self, module, hidden_states, keys, values, attentions, kwargs) -> torch.Tensor:
        """Return the L1 distance of every key from its per-head centroid."""
        centroid = keys.mean(dim=2, keepdim=True)
        deviation = keys - centroid
        return torch.norm(deviation, p=1, dim=-1)


@dataclass
class ManifoldKVLinfPress(ScorerPress):
    """ManifoldKV ablation variant that scores with the L-infinity norm.

    Each key is scored by its largest absolute deviation, over the head
    dimensions, from the mean key of its head.
    """

    compression_ratio: float = 0.5

    def score(self, module, hidden_states, keys, values, attentions, kwargs) -> torch.Tensor:
        """Return the max-norm distance of every key from its per-head centroid."""
        centroid = keys.mean(dim=2, keepdim=True)
        deviation = keys - centroid
        return torch.norm(deviation, p=float('inf'), dim=-1)


@dataclass
class WindowedManifoldKVPress(ScorerPress):
    """Windowed ManifoldKV: Local centroids for long context robustness.

    Instead of computing a single global centroid (which becomes diluted at 64K+),
    this variant computes LOCAL centroids over sliding windows. Each token's score
    is its L2 distance from its local neighborhood's centroid.

    This addresses the "centroid dilution problem" where:
    - At 64K tokens, global centroid averages over too many diverse keys
    - Local centroids remain representative of their neighborhood
    - Outliers are detected relative to their local context

    Every window spans exactly ``window_size`` tokens. Windows start at
    multiples of ``stride``; if those do not end exactly at the last token,
    one extra window aligned to the end of the sequence is added. This keeps
    the tail from being scored against a centroid built from only a handful
    of tokens (a single-token window would always score 0), at the cost of
    the final window overlapping its predecessor.

    Parameters
    ----------
    compression_ratio : float, default=0.5
        Fraction of tokens to evict.
    window_size : int, default=4096
        Size of the local window for computing centroids. Must be positive.
        Recommended: 2048-8192 for 64K contexts.
    stride : int, default=None
        Stride between windows. If None, uses window_size (non-overlapping
        apart from the end-aligned final window). Must satisfy
        ``0 < stride <= window_size`` so that no token is left uncovered.
        Smaller stride = smoother scores but more compute.
    pooling : str, default='mean'
        How to combine scores when windows overlap: 'mean', 'max', 'min'
    """

    compression_ratio: float = 0.5
    window_size: int = 4096
    stride: int | None = None
    pooling: Literal['mean', 'max', 'min'] = 'mean'

    def score(self, module, hidden_states, keys, values, attentions, kwargs) -> torch.Tensor:
        """Return per-token L2 distance from the local (windowed) centroid.

        ``keys`` is indexed as (bsz, heads, seq_len, head_dim); the result has
        shape (bsz, heads, seq_len) and the dtype of ``keys``. Sequences no
        longer than ``window_size`` are scored against one global centroid.

        Raises
        ------
        ValueError
            If ``window_size`` is not positive, ``stride`` is outside
            ``(0, window_size]``, or ``pooling`` is not a supported mode.
        """
        window = self.window_size
        stride = window if self.stride is None else self.stride

        # Fail loudly on configurations that would otherwise yield an opaque
        # range() error, uncovered gaps, or silently all-zero scores.
        if window <= 0:
            raise ValueError(f"window_size must be a positive integer, got {window}")
        if not 0 < stride <= window:
            raise ValueError(
                f"stride must satisfy 0 < stride <= window_size ({window}), got {stride}"
            )
        if self.pooling not in ('mean', 'max', 'min'):
            raise ValueError(f"pooling must be 'mean', 'max' or 'min', got {self.pooling!r}")

        # keys: (bsz, heads, seq_len, head_dim)
        bsz, num_heads, seq_len, _ = keys.shape

        # For short contexts, fall back to global centroid
        if seq_len <= window:
            centroid = keys.mean(dim=2, keepdim=True)
            return torch.norm(keys - centroid, p=2, dim=-1)

        # Full-size windows only. Because stride <= window the windows leave no
        # gaps, and the end-aligned window guarantees the tail is covered.
        last_start = seq_len - window
        starts = list(range(0, last_start + 1, stride))
        if starts[-1] != last_start:
            starts.append(last_start)

        shape = (bsz, num_heads, seq_len)
        if self.pooling == 'mean':
            # Accumulate in at least float32: half precision cannot count past
            # 2048 exactly and its sums can overflow with many overlaps.
            acc_dtype = torch.promote_types(keys.dtype, torch.float32)
            pooled = torch.zeros(shape, device=keys.device, dtype=acc_dtype)
            coverage = torch.zeros(seq_len, device=keys.device, dtype=acc_dtype)
        elif self.pooling == 'max':
            # Distances are non-negative, so 0 is the identity for max.
            pooled = keys.new_zeros(shape)
            coverage = None
        else:
            # Every position is covered at least once, so no inf survives.
            pooled = keys.new_full(shape, float('inf'))
            coverage = None

        for start in starts:
            end = start + window
            window_keys = keys[:, :, start:end, :]  # (bsz, heads, window, head_dim)
            local_mu = window_keys.mean(dim=2, keepdim=True)  # (bsz, heads, 1, head_dim)
            local_scores = torch.norm(window_keys - local_mu, p=2, dim=-1)  # (bsz, heads, window)

            if self.pooling == 'mean':
                pooled[:, :, start:end] += local_scores
                coverage[start:end] += 1
            elif self.pooling == 'max':
                pooled[:, :, start:end] = torch.max(pooled[:, :, start:end], local_scores)
            else:
                pooled[:, :, start:end] = torch.min(pooled[:, :, start:end], local_scores)

        if self.pooling == 'mean':
            # coverage >= 1 everywhere (see window construction above).
            return (pooled / coverage).to(keys.dtype)
        return pooled


@dataclass
class HybridManifoldKVPress(ScorerPress):
    """Hybrid ManifoldKV: Combines global and local centroids.

    Uses a weighted combination of:
    1. Global centroid distance (captures document-level outliers)
    2. Local centroid distance (captures local context outliers)

    This balances the benefits of both approaches:
    - Global: Good for truly unique tokens (entities, special numbers)
    - Local: Good for contextually important tokens (topic transitions)

    Before mixing, each score map is divided by its maximum over the sequence
    (separately per batch element and head) so both lie in [0, 1]. Sequences
    no longer than ``window_size`` skip the local pass and return the raw,
    un-normalized global distances.

    Parameters
    ----------
    compression_ratio : float, default=0.5
        Fraction of tokens to evict.
    window_size : int, default=4096
        Size of local windows. Must be positive.
    global_weight : float, default=0.3
        Weight for global centroid score (0-1).
        Higher = more emphasis on global outliers.
        Recommended: 0.2-0.4 for 64K contexts.
    """

    compression_ratio: float = 0.5
    window_size: int = 4096
    global_weight: float = 0.3

    def score(self, module, hidden_states, keys, values, attentions, kwargs) -> torch.Tensor:
        """Return the weighted mix of global and local centroid distances.

        ``keys`` is indexed as (bsz, heads, seq_len, head_dim); the result has
        shape (bsz, heads, seq_len).

        Raises
        ------
        ValueError
            If ``window_size`` is not a positive integer.
        """
        window = self.window_size
        if window <= 0:
            # Otherwise range() below raises an opaque error (0) or the local
            # pass is silently skipped (negative).
            raise ValueError(f"window_size must be a positive integer, got {window}")

        seq_len = keys.shape[2]

        # Global centroid score
        global_mu = keys.mean(dim=2, keepdim=True)
        global_scores = torch.norm(keys - global_mu, p=2, dim=-1)

        # For short contexts, use only global
        if seq_len <= window:
            return global_scores

        # Local centroid scores (non-overlapping windows for efficiency)
        # NOTE(review): a trailing window shorter than window_size is scored
        # against a centroid built from fewer tokens; a single-token tail
        # always gets a local score of 0.
        local_scores = torch.zeros_like(global_scores)
        for start in range(0, seq_len, window):
            end = min(start + window, seq_len)
            window_keys = keys[:, :, start:end, :]
            local_mu = window_keys.mean(dim=2, keepdim=True)
            local_scores[:, :, start:end] = torch.norm(window_keys - local_mu, p=2, dim=-1)

        # Normalize both to similar scale before combining. The denominator is
        # clamped to the dtype's smallest normal value instead of adding 1e-8:
        # in float16 that constant rounds to 0, so an all-zero head would
        # produce 0/0 = NaN.
        floor = torch.finfo(global_scores.dtype).tiny
        global_peak = global_scores.max(dim=-1, keepdim=True)[0].clamp_min(floor)
        local_peak = local_scores.max(dim=-1, keepdim=True)[0].clamp_min(floor)
        global_scores = global_scores / global_peak
        local_scores = local_scores / local_peak

        # Weighted combination
        return self.global_weight * global_scores + (1 - self.global_weight) * local_scores


@dataclass
class NormalizedManifoldKVPress(ScorerPress):
    """Normalized ManifoldKV: centroid distance computed on unit-length keys.

    Targets magnitude noise at long contexts by:
    1. Rescaling every key to unit L2 norm (discarding its magnitude)
    2. Averaging the rescaled keys per head to obtain a centroid
    3. Scoring each token by the L2 distance of its rescaled key from that
       centroid, so the score depends only on key direction

    This combines ManifoldKV's L2 metric with KeyDiff's normalization insight.

    Parameters
    ----------
    compression_ratio : float, default=0.5
        Fraction of tokens to evict.
    """

    compression_ratio: float = 0.5

    def score(self, module, hidden_states, keys, values, attentions, kwargs) -> torch.Tensor:
        """Return the L2 distance of each unit-normalized key from their mean."""
        # Project keys onto the unit sphere along the head dimension.
        unit_keys = F.normalize(keys, p=2, dim=-1)

        # Centroid over the sequence axis, then deviation of each unit key.
        offsets = unit_keys - unit_keys.mean(dim=2, keepdim=True)

        return torch.norm(offsets, p=2, dim=-1)
