{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ManifoldKV: Geometry-Driven KV Cache Compression\n",
    "\n",
    "**ICML 2026 - Interactive Demo**\n",
    "\n",
    "This notebook demonstrates the key ideas behind ManifoldKV:\n",
    "1. L2 distance captures both direction and magnitude\n",
    "2. Cosine similarity loses magnitude information\n",
    "3. Windowed centroids solve centroid dilution at 64K+ tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, '..')  # make packages in the parent directory importable\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "print(f\"PyTorch version: {torch.__version__}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. The Core Algorithm: L2 vs Cosine\n",
    "\n",
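    "ManifoldKV scores each token by the L2 distance of its key from the centroid (the mean key over the sequence, computed per attention head). For comparison, the KeyDiff baseline scores tokens by negative cosine similarity to the mean of the unit-normalized keys:"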
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def manifold_kv_score(keys):\n",
    "    \"\"\"ManifoldKV: L2 distance of each key from the centroid.\n",
    "\n",
    "    keys: (bsz, heads, seq_len, dim) -> scores: (bsz, heads, seq_len)\n",
    "    \"\"\"\n",
    "    mu = keys.mean(dim=2, keepdim=True)  # centroid over the sequence axis\n",
    "    return torch.norm(keys - mu, dim=-1)\n",
    "\n",
    "def keydiff_score(keys):\n",
    "    \"\"\"KeyDiff: negative cosine similarity to the mean of the unit-normalized keys.\"\"\"\n",
    "    keys_norm = F.normalize(keys, dim=-1)\n",
    "    anchor = keys_norm.mean(dim=2, keepdim=True)\n",
    "    return -F.cosine_similarity(keys, anchor, dim=-1)\n",
    "\n",
    "# Create sample keys: (batch, heads, seq_len, head_dim)\n",
    "bsz, heads, seq_len, dim = 1, 4, 100, 64\n",
    "keys = torch.randn(bsz, heads, seq_len, dim)\n",
    "\n",
    "# Score with both methods\n",
    "scores_manifold = manifold_kv_score(keys)\n",
    "scores_keydiff = keydiff_score(keys)\n",
    "\n",
    "print(f\"ManifoldKV scores: mean={scores_manifold.mean():.3f}, std={scores_manifold.std():.3f}\")\n",
    "print(f\"KeyDiff scores:    mean={scores_keydiff.mean():.3f}, std={scores_keydiff.std():.3f}\")"
   ]
  },
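  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of why the two scores differ, written as a small self-contained sketch. The seed, the choice of token 7 and the 10x factor are arbitrary. Rescaling one key by a positive factor leaves every KeyDiff score unchanged, because both `F.normalize` and cosine similarity discard magnitude. The ManifoldKV score of that token, by contrast, grows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Same scoring functions as above, repeated so this cell runs on its own\n",
    "def manifold_kv_score(keys):\n",
    "    mu = keys.mean(dim=2, keepdim=True)  # per-head centroid over the sequence\n",
    "    return torch.norm(keys - mu, dim=-1)\n",
    "\n",
    "def keydiff_score(keys):\n",
    "    anchor = F.normalize(keys, dim=-1).mean(dim=2, keepdim=True)\n",
    "    return -F.cosine_similarity(keys, anchor, dim=-1)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "keys_demo = torch.randn(1, 4, 100, 64)\n",
    "\n",
    "# Rescale token 7 in every head by a factor of 10\n",
    "keys_scaled = keys_demo.clone()\n",
    "keys_scaled[:, :, 7, :] *= 10.0\n",
    "\n",
    "# KeyDiff: all scores unchanged up to floating-point rounding\n",
    "keydiff_unchanged = torch.allclose(keydiff_score(keys_demo), keydiff_score(keys_scaled), atol=1e-5)\n",
    "\n",
    "# ManifoldKV: the rescaled token moves far from the centroid\n",
    "before = manifold_kv_score(keys_demo)[:, :, 7]\n",
    "after = manifold_kv_score(keys_scaled)[:, :, 7]\n",
    "\n",
    "print(f\"KeyDiff scores unchanged: {keydiff_unchanged}\")\n",
    "print(f\"ManifoldKV score of token 7 (mean over heads): {before.mean():.2f} -> {after.mean():.2f}\")"
   ]
  },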
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Radial Outliers: Where Cosine Fails\n",
    "\n",
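    "A radial outlier ($\\mathbf{k} = \\alpha\\boldsymbol{\\mu}$ with $\\alpha \\gg 1$) points in the same direction as the centroid, so its cosine similarity takes the maximum value of 1. Its L2 distance from the centroid, however, is large:"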
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create centroid\n",
    "mu = torch.randn(64)\n",
    "mu = mu / mu.norm()  # unit vector\n",
    "\n",
    "# Create radial outlier (same direction, 10x magnitude)\n",
    "radial_outlier = 10 * mu\n",
    "\n",
    "# Create angular outlier (different direction, same magnitude)\n",
    "angular_outlier = torch.randn(64)\n",
    "angular_outlier = angular_outlier / angular_outlier.norm()\n",
    "\n",
    "# Compute scores\n",
    "cosine_radial = F.cosine_similarity(radial_outlier.unsqueeze(0), mu.unsqueeze(0)).item()\n",
    "cosine_angular = F.cosine_similarity(angular_outlier.unsqueeze(0), mu.unsqueeze(0)).item()\n",
    "\n",
    "l2_radial = (radial_outlier - mu).norm().item()\n",
    "l2_angular = (angular_outlier - mu).norm().item()\n",
    "\n",
    "print(\"Radial Outlier (k = 10*μ):\")\n",
    "print(f\"  Cosine similarity: {cosine_radial:.3f} (looks TYPICAL!)\")\n",
    "print(f\"  L2 distance:       {l2_radial:.3f} (correctly identified as OUTLIER)\")\n",
    "print()\n",
    "print(\"Angular Outlier (different direction):\")\n",
    "print(f\"  Cosine similarity: {cosine_angular:.3f}\")\n",
    "print(f\"  L2 distance:       {l2_angular:.3f}\")"
   ]
  },
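  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The radial-outlier effect above also appears when ranking a whole set of keys. The following toy sketch uses synthetic keys clustered around a random direction; position 42 and all constants are arbitrary. A token replaced by a radial outlier along the centroid direction is ranked as the *most* typical token by KeyDiff but as the *least* typical one by ManifoldKV (rank 0 means the highest score):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def manifold_kv_score(keys):\n",
    "    mu = keys.mean(dim=2, keepdim=True)\n",
    "    return torch.norm(keys - mu, dim=-1)\n",
    "\n",
    "def keydiff_score(keys):\n",
    "    anchor = F.normalize(keys, dim=-1).mean(dim=2, keepdim=True)\n",
    "    return -F.cosine_similarity(keys, anchor, dim=-1)\n",
    "\n",
    "def rank_of(scores, pos):\n",
    "    # Number of tokens in head 0 scoring strictly higher than token `pos` (0 = top)\n",
    "    return int((scores[0, 0] > scores[0, 0, pos]).sum())\n",
    "\n",
    "torch.manual_seed(0)\n",
    "n, dim = 100, 64\n",
    "direction = F.normalize(torch.randn(dim), dim=0)\n",
    "keys_batch = torch.randn(1, 1, n, dim) + 3.0 * direction  # keys clustered around 3 * direction\n",
    "\n",
    "# Replace token 42 with a radial outlier: same direction as the centroid, 10x its length\n",
    "outlier_pos = 42\n",
    "centroid = keys_batch.mean(dim=2, keepdim=True)\n",
    "keys_batch[:, :, outlier_pos:outlier_pos + 1, :] = 10.0 * centroid\n",
    "\n",
    "manifold_rank = rank_of(manifold_kv_score(keys_batch), outlier_pos)\n",
    "keydiff_rank = rank_of(keydiff_score(keys_batch), outlier_pos)\n",
    "print(f\"ManifoldKV rank of the outlier: {manifold_rank} of {n}\")\n",
    "print(f\"KeyDiff rank of the outlier:    {keydiff_rank} of {n}\")"
   ]
  },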
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Geometric Decomposition\n",
    "\n",
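    "L2 distance captures both angular and radial deviation. Expanding the squared distance, with $\\theta$ the angle between $\\mathbf{k}$ and $\\boldsymbol{\\mu}$:\n",
    "\n",
    "$$\\|\\mathbf{k} - \\boldsymbol{\\mu}\\|_2^2 = \\|\\mathbf{k}\\|^2 + \\|\\boldsymbol{\\mu}\\|^2 - 2\\mathbf{k}^\\top\\boldsymbol{\\mu} = \\|\\mathbf{k}\\|^2 + \\|\\boldsymbol{\\mu}\\|^2 - 2\\|\\mathbf{k}\\|\\,\\|\\boldsymbol{\\mu}\\|\\cos\\theta$$\n",
    "\n",
    "The magnitudes $\\|\\mathbf{k}\\|$ and $\\|\\boldsymbol{\\mu}\\|$ carry the radial information. $\\cos\\theta$, the only quantity that cosine similarity sees, carries the angular information."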
   ]
  },
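  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Substituting the two outliers from Section 2 into this expansion explains the printed numbers. There, $\\|\\boldsymbol{\\mu}\\| = 1$, and $\\theta$ is the angle between $\\mathbf{k}$ and $\\boldsymbol{\\mu}$. A radial outlier $\\mathbf{k} = \\alpha\\boldsymbol{\\mu}$ with $\\alpha > 0$ has $\\cos\\theta = 1$ for every $\\alpha$, yet\n",
    "\n",
    "$$\\|\\alpha\\boldsymbol{\\mu} - \\boldsymbol{\\mu}\\|_2^2 = \\alpha^2 + 1 - 2\\alpha = (\\alpha - 1)^2 \\quad\\Rightarrow\\quad \\|\\mathbf{k} - \\boldsymbol{\\mu}\\|_2 = |\\alpha - 1|,$$\n",
    "\n",
    "so $\\alpha = 10$ gives a distance of 9. For a unit-norm angular outlier ($\\|\\mathbf{k}\\| = 1$),\n",
    "\n",
    "$$\\|\\mathbf{k} - \\boldsymbol{\\mu}\\|_2^2 = 2 - 2\\cos\\theta \\le 4.$$\n",
    "\n",
    "A random direction in 64 dimensions is nearly orthogonal to $\\boldsymbol{\\mu}$ ($\\cos\\theta \\approx 0$), which gives a distance close to $\\sqrt{2} \\approx 1.41$."
   ]
  },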
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decompose the squared L2 distance into magnitude and angular parts\n",
    "def decompose_l2(k, mu):\n",
    "    \"\"\"Decompose L2 distance into magnitude and angular components.\"\"\"\n",
    "    k_mag = k.norm()\n",
    "    mu_mag = mu.norm()\n",
    "    cos_angle = F.cosine_similarity(k.unsqueeze(0), mu.unsqueeze(0)).item()\n",
    "    \n",
    "    # L2^2 = |k|^2 + |mu|^2 - 2*|k|*|mu|*cos(angle)\n",
    "    l2_squared = k_mag**2 + mu_mag**2 - 2*k_mag*mu_mag*cos_angle\n",
    "    l2 = l2_squared.clamp_min(0).sqrt()  # guard against tiny negative rounding error\n",
    "    \n",
    "    return {\n",
    "        'l2': l2.item(),\n",
    "        'k_magnitude': k_mag.item(),\n",
    "        'mu_magnitude': mu_mag.item(),\n",
    "        'cos_angle': cos_angle,\n",
    "        'magnitude_term': (k_mag**2 + mu_mag**2).item(),      # |k|^2 + |mu|^2\n",
    "        'angular_term': (-2*k_mag*mu_mag*cos_angle).item(),  # -2*|k|*|mu|*cos(angle)\n",
    "    }\n",
    "\n",
    "# Compare the radial and angular outliers from Section 2\n",
    "print(\"Radial Outlier Decomposition:\")\n",
    "d = decompose_l2(radial_outlier, mu)\n",
    "for name, value in d.items():\n",
    "    print(f\"  {name}: {value:.3f}\")\n",
    "\n",
    "print(\"\\nAngular Outlier Decomposition:\")\n",
    "d = decompose_l2(angular_outlier, mu)\n",
    "for name, value in d.items():\n",
    "    print(f\"  {name}: {value:.3f}\")"
   ]
  },
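  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same expansion also gives an equivalent way to compute ManifoldKV scores for a whole batch: squared norms plus one dot product per token. The sketch below checks numerically that this matches the direct `torch.norm(keys - mu, dim=-1)` used in `manifold_kv_score`. `manifold_score_expanded` is a hypothetical helper, not part of kvpress. The check runs in float64 because the expansion subtracts large, nearly equal terms, which can lose precision in low-precision dtypes. For the same reason, the result is clamped at zero before the square root."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def manifold_score_expanded(keys):\n",
    "    # Hypothetical helper: ||k - mu||^2 = ||k||^2 + ||mu||^2 - 2 k^T mu, per token and head\n",
    "    mu = keys.mean(dim=2, keepdim=True)                # (bsz, heads, 1, dim)\n",
    "    k_sq = keys.pow(2).sum(dim=-1)                     # (bsz, heads, seq_len)\n",
    "    mu_sq = mu.pow(2).sum(dim=-1)                      # (bsz, heads, 1)\n",
    "    cross = (keys @ mu.transpose(-1, -2)).squeeze(-1)  # (bsz, heads, seq_len)\n",
    "    # Clamp tiny negative values caused by floating-point cancellation\n",
    "    return (k_sq + mu_sq - 2 * cross).clamp_min(0).sqrt()\n",
    "\n",
    "torch.manual_seed(0)\n",
    "keys_check = torch.randn(2, 4, 256, 64, dtype=torch.float64)\n",
    "direct = torch.norm(keys_check - keys_check.mean(dim=2, keepdim=True), dim=-1)\n",
    "expanded = manifold_score_expanded(keys_check)\n",
    "print(f\"max |direct - expanded| = {(direct - expanded).abs().max():.2e}\")"
   ]
  },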
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Using ManifoldKV with KVPress\n",
    "\n",
    "Here's how to use ManifoldKV in practice:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from kvpress import (\n",
    "    ManifoldKVPress,\n",
    "    WindowedManifoldKVPress,\n",
    "    AdaKVPress,\n",
    "    ManifoldKVSnapKVScorerPress,\n",
    ")\n",
    "\n",
    "# Standard ManifoldKV (4K-32K contexts)\n",
    "press_standard = ManifoldKVPress(compression_ratio=0.2)\n",
    "print(f\"Standard: {press_standard}\")\n",
    "\n",
    "# AdaKV + ManifoldKV (SOTA configuration)\n",
    "press_adakv = AdaKVPress(ManifoldKVSnapKVScorerPress())\n",
    "press_adakv.compression_ratio = 0.2\n",
    "print(f\"AdaKV + ManifoldKV: compression={press_adakv.compression_ratio}\")\n",
    "\n",
    "# Windowed ManifoldKV (64K+ contexts)\n",
    "press_windowed = WindowedManifoldKVPress(\n",
    "    compression_ratio=0.25,\n",
    "    window_size=4096,  # Best window size\n",
    ")\n",
    "print(f\"Windowed: window_size={press_windowed.window_size}\")"
   ]
  },
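  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each press ultimately turns per-token scores into a decision about which KV pairs to keep. The sketch below illustrates that last step under two assumptions that this notebook does not state: higher scores mean a token is kept, and `compression_ratio` is the fraction of KV pairs removed. `prune_by_score` is a hypothetical helper, not part of the kvpress API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def prune_by_score(keys, values, scores, compression_ratio):\n",
    "    # Hypothetical helper: keep the top-scoring (1 - compression_ratio) fraction of tokens per head\n",
    "    bsz, heads, seq_len, dim = keys.shape\n",
    "    n_kept = int(seq_len * (1 - compression_ratio))\n",
    "    idx = scores.topk(n_kept, dim=-1).indices  # (bsz, heads, n_kept)\n",
    "    idx = idx.sort(dim=-1).values              # restore original token order\n",
    "    gather_idx = idx.unsqueeze(-1).expand(-1, -1, -1, dim)\n",
    "    return keys.gather(2, gather_idx), values.gather(2, gather_idx)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "keys_demo = torch.randn(1, 4, 100, 64)\n",
    "values_demo = torch.randn(1, 4, 100, 64)\n",
    "scores_demo = torch.norm(keys_demo - keys_demo.mean(dim=2, keepdim=True), dim=-1)  # ManifoldKV score\n",
    "\n",
    "k_kept, v_kept = prune_by_score(keys_demo, values_demo, scores_demo, compression_ratio=0.2)\n",
    "print(f\"kept {k_kept.shape[2]} of {keys_demo.shape[2]} tokens per head\")"
   ]
  },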
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Windowed Centroids: Solving Centroid Dilution\n",
    "\n",
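    "At 64K+ tokens, the global centroid averages over many unrelated parts of the context and drifts away from each local cluster of keys. Distances to it then stop separating typical keys from unusual ones (centroid dilution). Windowed centroids instead compute the mean over fixed-size windows, which preserves local structure:"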
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def windowed_manifold_score(keys, window_size=4096):\n",
    "    \"\"\"Windowed ManifoldKV: Local centroids.\"\"\"\n",
    "    bsz, heads, seq_len, dim = keys.shape\n",
    "    scores = keys.new_zeros(bsz, heads, seq_len)  # match the dtype and device of keys\n",
    "    \n",
    "    for start in range(0, seq_len, window_size):\n",
    "        end = min(start + window_size, seq_len)\n",
    "        window = keys[:, :, start:end, :]\n",
    "        mu = window.mean(dim=2, keepdim=True)\n",
    "        scores[:, :, start:end] = torch.norm(window - mu, dim=-1)\n",
    "    \n",
    "    return scores\n",
    "\n",
    "# Simulate a long context (16K tokens here, to keep the demo fast) with distinct \"topics\"\n",
    "seq_len = 16384\n",
    "keys_long = torch.randn(1, 4, seq_len, 64)\n",
    "\n",
    "# Add structure: different \"topics\" in different regions\n",
    "for i in range(4):\n",
    "    start = i * seq_len // 4\n",
    "    end = (i+1) * seq_len // 4\n",
    "    topic_offset = torch.randn(64) * 5  # Different topic centers\n",
    "    keys_long[:, :, start:end, :] += topic_offset\n",
    "\n",
    "# Compare global vs windowed\n",
    "scores_global = manifold_kv_score(keys_long)\n",
    "scores_windowed = windowed_manifold_score(keys_long, window_size=4096)\n",
    "\n",
    "print(f\"Global centroid - score std: {scores_global.std():.4f}\")\n",
    "print(f\"Windowed (4K)   - score std: {scores_windowed.std():.4f}\")\n",
    "print(\"\\nHigher std = scores are more spread out, a rough proxy for discriminative power\")"
   ]
  },
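  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Score spread is only an indirect signal. Here is a more direct toy illustration of centroid dilution. It uses the same synthetic four-topic setup, scaled down to 4,096 tokens, one head and windows of 1,024 so it runs quickly. A token is planted inside topic 0 at exactly the global centroid. Measured against the global centroid it looks maximally typical, while against its own window's centroid it is a clear outlier. All sizes and the planted position are arbitrary choices for this sketch, and rank 0 means the highest score."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def manifold_kv_score(keys):\n",
    "    mu = keys.mean(dim=2, keepdim=True)  # global centroid over the whole sequence\n",
    "    return torch.norm(keys - mu, dim=-1)\n",
    "\n",
    "def windowed_manifold_score(keys, window_size=4096):\n",
    "    scores = keys.new_zeros(keys.shape[:-1])  # (bsz, heads, seq_len)\n",
    "    for start in range(0, keys.shape[2], window_size):\n",
    "        window = keys[:, :, start:start + window_size, :]\n",
    "        mu = window.mean(dim=2, keepdim=True)  # local centroid of this window\n",
    "        scores[:, :, start:start + window_size] = torch.norm(window - mu, dim=-1)\n",
    "    return scores\n",
    "\n",
    "def rank_of(scores, pos):\n",
    "    # Number of tokens in head 0 scoring strictly higher than token `pos` (0 = top)\n",
    "    return int((scores[0, 0] > scores[0, 0, pos]).sum())\n",
    "\n",
    "torch.manual_seed(0)\n",
    "n, w = 4096, 1024\n",
    "keys_topics = torch.randn(1, 1, n, 64)\n",
    "for i in range(4):\n",
    "    keys_topics[:, :, i * w:(i + 1) * w, :] += torch.randn(64) * 5  # one topic per window\n",
    "\n",
    "# Plant a token inside topic 0 that sits exactly at the global centroid\n",
    "planted = 100\n",
    "keys_topics[:, :, planted, :] = keys_topics.mean(dim=2)\n",
    "\n",
    "global_rank = rank_of(manifold_kv_score(keys_topics), planted)\n",
    "windowed_rank = rank_of(windowed_manifold_score(keys_topics, window_size=w), planted)\n",
    "print(f\"Rank of planted token, global centroid:   {global_rank} of {n}\")\n",
    "print(f\"Rank of planted token, windowed centroid: {windowed_rank} of {n}\")"
   ]
  },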
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "ManifoldKV improves KV cache compression by:\n",
    "\n",
    "1. **L2 Distance**: Captures both direction and magnitude (whereas cosine similarity captures only direction)\n",
    "2. **Radial Outliers**: Correctly identifies tokens with unusual magnitude\n",
    "3. **Windowed Centroids**: Solves centroid dilution at 64K+ contexts\n",
    "\n",
    "**Results**:\n",
    "- 95.73% on RULER (4K-16K)\n",
    "- 84.3% at 64K (recovering 49 points lost to centroid dilution)\n",
    "- +15 points on multi-key retrieval"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
