{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "header",
   "metadata": {},
   "source": [
    "# EMP on Transformer Attention Scores"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "setup-header",
   "metadata": {},
   "source": [
    "## 1. Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "imports",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import math\n",
    "import time\n",
    "import copy\n",
    "from dataclasses import dataclass\n",
    "from typing import Optional, Tuple, Dict, List\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR\n",
    "\n",
    "# For WikiText-2 dataset\n",
    "try:\n",
    "    from datasets import load_dataset\n",
    "    HAS_DATASETS = True\n",
    "except ImportError:\n",
    "    HAS_DATASETS = False\n",
    "    print(\"Installing datasets library...\")\n",
    "    !pip install datasets -q\n",
    "    from datasets import load_dataset\n",
    "    HAS_DATASETS = True\n",
    "\n",
    "# Pick the compute device once; later cells refer to this module-level `device`.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Device: {device}\")\n",
    "if device == \"cuda\":\n",
    "    # Report the name of GPU index 0.\n",
    "    print(f\"GPU: {torch.cuda.get_device_name(0)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "config-header",
   "metadata": {},
   "source": [
    "## 2. Model Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class TransformerConfig:\n",
    "    \"\"\"Configuration for GPT-style transformer.\n",
    "\n",
    "    Holds the model architecture fields together with the training\n",
    "    hyperparameters in a single mutable dataclass.\n",
    "    \"\"\"\n",
    "    vocab_size: int = 50257  # GPT-2 vocab size, will be overridden\n",
    "    block_size: int = 256    # context length\n",
    "    n_layer: int = 6         # number of transformer blocks\n",
    "    n_head: int = 8          # number of attention heads\n",
    "    n_embd: int = 512        # embedding dimension\n",
    "    dropout: float = 0.1\n",
    "    bias: bool = True        # use bias in Linear and LayerNorm\n",
    "    \n",
    "    # Training config\n",
    "    batch_size: int = 32\n",
    "    epochs: int = 10\n",
    "    lr: float = 3e-4\n",
    "    weight_decay: float = 0.01\n",
    "    warmup_epochs: int = 1\n",
    "\n",
    "# Single shared configuration instance used by the rest of the notebook.\n",
    "config = TransformerConfig()\n",
    "print(f\"Config: {config}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dataset-header",
   "metadata": {},
   "source": [
    "## 3. Dataset: WikiText-2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "tokenizer",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharTokenizer:\n",
    "    \"\"\"Character-level tokenizer built from the distinct characters of a text.\n",
    "\n",
    "    Ids follow the sorted order of the characters. A character that is not in\n",
    "    the vocabulary encodes to id 0; an id that is not in the vocabulary decodes\n",
    "    to the empty string.\n",
    "    \"\"\"\n",
    "    def __init__(self, text: str):\n",
    "        alphabet = sorted(set(text))\n",
    "        self.char_to_idx = {}\n",
    "        self.idx_to_char = {}\n",
    "        for position, symbol in enumerate(alphabet):\n",
    "            self.char_to_idx[symbol] = position\n",
    "            self.idx_to_char[position] = symbol\n",
    "        self.vocab_size = len(alphabet)\n",
    "\n",
    "    def encode(self, text: str) -> List[int]:\n",
    "        \"\"\"Map every character of ``text`` to its id (0 when unknown).\"\"\"\n",
    "        lookup = self.char_to_idx\n",
    "        return [lookup[symbol] if symbol in lookup else 0 for symbol in text]\n",
    "\n",
    "    def decode(self, indices: List[int]) -> str:\n",
    "        \"\"\"Concatenate the characters for ``indices``, skipping unknown ids.\"\"\"\n",
    "        pieces = []\n",
    "        for token_id in indices:\n",
    "            pieces.append(self.idx_to_char.get(token_id, ''))\n",
    "        return ''.join(pieces)\n",
    "\n",
    "\n",
    "class WordTokenizer:\n",
    "    \"\"\"Simple word-level tokenizer with special tokens.\n",
    "\n",
    "    The vocabulary is ``<PAD>`` (id 0) and ``<UNK>`` (id 1) followed by the most\n",
    "    frequent whitespace-separated words of ``text``; words with equal counts\n",
    "    keep their first-seen order. The two special tokens are always present, so\n",
    "    the vocabulary holds at most ``max(2, max_vocab)`` entries.\n",
    "\n",
    "    Args:\n",
    "        text: Corpus used to build the vocabulary.\n",
    "        max_vocab: Maximum vocabulary size, special tokens included.\n",
    "    \"\"\"\n",
    "    def __init__(self, text: str, max_vocab: int = 30000):\n",
    "        # Special tokens\n",
    "        self.pad_token = '<PAD>'\n",
    "        self.unk_token = '<UNK>'\n",
    "        special_tokens = [self.pad_token, self.unk_token]\n",
    "\n",
    "        # Count word frequencies\n",
    "        word_freq = {}\n",
    "        for w in text.split():\n",
    "            word_freq[w] = word_freq.get(w, 0) + 1\n",
    "\n",
    "        # Keep most frequent words. A corpus word spelled like a special token\n",
    "        # is skipped: listing it twice would make word_to_idx, idx_to_word and\n",
    "        # vocab_size disagree with each other.\n",
    "        sorted_words = sorted(word_freq.items(), key=lambda x: -x[1])\n",
    "        candidates = [w for w, _ in sorted_words if w not in special_tokens]\n",
    "        # Clamp at 0: a negative slice bound would count from the end of the list\n",
    "        # and keep almost every word instead of none.\n",
    "        n_regular = max(0, max_vocab - len(special_tokens))\n",
    "        vocab_words = candidates[:n_regular]\n",
    "\n",
    "        all_tokens = special_tokens + vocab_words\n",
    "        self.word_to_idx = {w: i for i, w in enumerate(all_tokens)}\n",
    "        self.idx_to_word = {i: w for i, w in enumerate(all_tokens)}\n",
    "        self.vocab_size = len(all_tokens)\n",
    "        self.unk_idx = self.word_to_idx[self.unk_token]\n",
    "\n",
    "    def encode(self, text: str) -> List[int]:\n",
    "        \"\"\"Split ``text`` on whitespace and map each word to its id (``unk_idx`` when unknown).\"\"\"\n",
    "        return [self.word_to_idx.get(w, self.unk_idx) for w in text.split()]\n",
    "\n",
    "    def decode(self, indices: List[int]) -> str:\n",
    "        \"\"\"Join the words for ``indices`` with single spaces (``<UNK>`` for unknown ids).\"\"\"\n",
    "        return ' '.join([self.idx_to_word.get(i, self.unk_token) for i in indices])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dataset-class",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LMDataset(Dataset):\n",
    "    \"\"\"Language Modeling Dataset.\n",
    "\n",
    "    Item ``idx`` is the pair ``(x, y)`` with ``x = data[idx : idx + block_size]``\n",
    "    and ``y`` the same window shifted one position to the right, i.e. the\n",
    "    next-token targets for ``x``.\n",
    "\n",
    "    Args:\n",
    "        data: 1-D tensor of token ids.\n",
    "        block_size: Number of tokens in each input window.\n",
    "    \"\"\"\n",
    "    def __init__(self, data: torch.Tensor, block_size: int):\n",
    "        self.data = data\n",
    "        self.block_size = block_size\n",
    "\n",
    "    def __len__(self):\n",
    "        # The last usable start index is len(data) - block_size - 1: its target\n",
    "        # window ends exactly on the final token. That makes\n",
    "        # len(data) - block_size windows in total.\n",
    "        return max(0, len(self.data) - self.block_size)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # Slicing never fails, so an out-of-range start would silently return\n",
    "        # truncated tensors; reject it explicitly instead.\n",
    "        if not 0 <= idx < len(self):\n",
    "            raise IndexError(f\"index {idx} out of range for LMDataset of length {len(self)}\")\n",
    "        x = self.data[idx:idx + self.block_size]\n",
    "        y = self.data[idx + 1:idx + self.block_size + 1]\n",
    "        return x, y\n",
    "\n",
    "\n",
    "def load_wikitext2(config: TransformerConfig, use_char_level: bool = False):\n",
    "    \"\"\"Fetch WikiText-2 (raw) and wrap its three splits in DataLoaders.\n",
    "\n",
    "    The tokenizer is fitted on the training split only and then used to encode\n",
    "    all three splits. Only the training loader shuffles.\n",
    "\n",
    "    Args:\n",
    "        config: Supplies ``block_size`` and ``batch_size``.\n",
    "        use_char_level: Use ``CharTokenizer`` instead of ``WordTokenizer``.\n",
    "\n",
    "    Returns:\n",
    "        ``(train_loader, val_loader, test_loader, tokenizer)``.\n",
    "    \"\"\"\n",
    "    print(\"Loading WikiText-2...\")\n",
    "    dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')\n",
    "\n",
    "    # One newline-joined string per split\n",
    "    split_names = ('train', 'validation', 'test')\n",
    "    split_text = {split: '\\n'.join(dataset[split]['text']) for split in split_names}\n",
    "\n",
    "    # Vocabulary comes from the training text only\n",
    "    if not use_char_level:\n",
    "        tokenizer = WordTokenizer(split_text['train'], max_vocab=20000)\n",
    "    else:\n",
    "        tokenizer = CharTokenizer(split_text['train'])\n",
    "\n",
    "    print(f\"Vocabulary size: {tokenizer.vocab_size}\")\n",
    "\n",
    "    # Token-id tensors, in train / validation / test order\n",
    "    train_ids, val_ids, test_ids = (\n",
    "        torch.tensor(tokenizer.encode(split_text[split]), dtype=torch.long)\n",
    "        for split in split_names\n",
    "    )\n",
    "\n",
    "    print(f\"Train tokens: {len(train_ids):,}\")\n",
    "    print(f\"Val tokens: {len(val_ids):,}\")\n",
    "    print(f\"Test tokens: {len(test_ids):,}\")\n",
    "\n",
    "    # No worker processes on Windows; pinned memory only with CUDA and workers\n",
    "    worker_count = 2 if os.name != \"nt\" else 0\n",
    "    pin_memory = torch.cuda.is_available() and worker_count > 0\n",
    "\n",
    "    def make_loader(token_ids, shuffle):\n",
    "        # Shared DataLoader settings for all three splits\n",
    "        return DataLoader(\n",
    "            LMDataset(token_ids, config.block_size),\n",
    "            batch_size=config.batch_size,\n",
    "            shuffle=shuffle,\n",
    "            num_workers=worker_count,\n",
    "            pin_memory=pin_memory,\n",
    "        )\n",
    "\n",
    "    train_loader = make_loader(train_ids, True)\n",
    "    val_loader = make_loader(val_ids, False)\n",
    "    test_loader = make_loader(test_ids, False)\n",
    "\n",
    "    return train_loader, val_loader, test_loader, tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "load-data",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data (word-level tokenization)\n",
    "train_loader, val_loader, test_loader, tokenizer = load_wikitext2(config, use_char_level=False)\n",
    "# Replace the default vocab size with the tokenizer's actual size; the model\n",
    "# built in a later cell sizes its embedding and output layers from this field.\n",
    "config.vocab_size = tokenizer.vocab_size\n",
    "print(f\"\\nUpdated vocab_size: {config.vocab_size}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "model-header",
   "metadata": {},
   "source": [
    "## 4. Transformer Model with EMP-compatible Attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "attention",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CausalSelfAttention(nn.Module):\n",
    "    \"\"\"Multi-head causal self-attention with EMP support.\n",
    "\n",
    "    Supports pruning attention scores (pre-softmax) via EMP: when ``emp_beta``\n",
    "    is a positive number, only the ``floor(emp_beta * N_eff)`` largest-magnitude\n",
    "    causally visible scores are kept (per head or across all heads, depending\n",
    "    on ``emp_mode``) and every other score is set to ``-inf`` before the softmax.\n",
    "    With ``emp_beta`` left at ``None`` this is ordinary causal attention.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config: TransformerConfig):\n",
    "        super().__init__()\n",
    "        assert config.n_embd % config.n_head == 0\n",
    "\n",
    "        self.n_head = config.n_head\n",
    "        self.n_embd = config.n_embd\n",
    "        self.head_dim = config.n_embd // config.n_head\n",
    "        self.dropout = config.dropout\n",
    "\n",
    "        # QKV projection\n",
    "        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)\n",
    "        # Output projection\n",
    "        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)\n",
    "\n",
    "        # Regularization\n",
    "        self.attn_dropout = nn.Dropout(config.dropout)\n",
    "        self.resid_dropout = nn.Dropout(config.dropout)\n",
    "\n",
    "        # Causal mask (lower-triangular ones, broadcast over batch and heads)\n",
    "        self.register_buffer(\"mask\", torch.tril(torch.ones(config.block_size, config.block_size))\n",
    "                                          .view(1, 1, config.block_size, config.block_size))\n",
    "\n",
    "        # EMP pruning config (set externally)\n",
    "        self.emp_beta = None  # If set, apply EMP\n",
    "        self.emp_mode = 'per_head'  # 'per_head' or 'global'\n",
    "\n",
    "        # Storage for attention scores (for analysis)\n",
    "        self.last_attn_scores = None\n",
    "        self.last_attn_mask = None\n",
    "\n",
    "    def compute_neff(self, scores: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Compute effective number N_eff from attention scores.\n",
    "\n",
    "        ``N_eff = 1 / sum_i(omega_i ** 2)`` with ``omega_i = |s_i| / sum_j |s_j|``.\n",
    "        Non-finite entries (the ``-inf`` written by the causal mask) are ignored;\n",
    "        including them would turn the normalisation into ``inf / inf = NaN``.\n",
    "\n",
    "        Args:\n",
    "            scores: Attention scores of any shape; flattened internally.\n",
    "\n",
    "        Returns:\n",
    "            0-dim tensor holding N_eff. When there are no finite entries, or\n",
    "            their absolute sum is below 1e-10, the number of finite entries is\n",
    "            returned instead.\n",
    "        \"\"\"\n",
    "        # Flatten to 1D and keep only the finite entries\n",
    "        s = scores.reshape(-1)\n",
    "        s_abs = torch.abs(s[torch.isfinite(s)])\n",
    "        n_valid = s_abs.numel()\n",
    "\n",
    "        # Normalize by absolute sum\n",
    "        s_sum = s_abs.sum()\n",
    "        if n_valid == 0 or s_sum < 1e-10:\n",
    "            # Returned as a tensor so callers can always apply torch ops to it.\n",
    "            return torch.tensor(float(n_valid), device=scores.device)\n",
    "\n",
    "        omega = s_abs / s_sum\n",
    "\n",
    "        # N_eff = 1 / sum(omega^2)\n",
    "        return 1.0 / (omega ** 2).sum()\n",
    "\n",
    "    def _emp_keep_mask(self, scores: torch.Tensor, beta: float) -> torch.Tensor:\n",
    "        \"\"\"Return a bool mask of the ``floor(beta * N_eff)`` largest-|score| finite entries.\n",
    "\n",
    "        Ties at the threshold are all kept. Non-finite entries are never kept.\n",
    "        \"\"\"\n",
    "        valid = torch.isfinite(scores)\n",
    "        n_valid = int(valid.sum())\n",
    "        if n_valid == 0:\n",
    "            return valid\n",
    "\n",
    "        neff = self.compute_neff(scores)\n",
    "        r_neff = int(torch.floor(beta * neff).clamp(1, n_valid))\n",
    "        if r_neff >= n_valid:\n",
    "            return valid\n",
    "\n",
    "        magnitude = scores.abs()\n",
    "        # The r-th largest magnitude is the (n_valid - r + 1)-th smallest one.\n",
    "        thresh = torch.kthvalue(magnitude[valid], n_valid - r_neff + 1).values\n",
    "        return valid & (magnitude >= thresh)\n",
    "\n",
    "    def apply_emp_mask(self, scores: torch.Tensor, beta: float, mode: str = 'per_head') -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "        \"\"\"Apply EMP-based masking to attention scores.\n",
    "\n",
    "        Args:\n",
    "            scores: Pre-softmax attention scores (B, n_head, T, T); causally\n",
    "                masked positions may already be ``-inf``.\n",
    "            beta: EMP coefficient\n",
    "            mode: 'per_head' (one threshold per head, pooled over the batch) or\n",
    "                'global' (one threshold for the whole tensor)\n",
    "\n",
    "        Returns:\n",
    "            masked_scores: Scores with low-importance entries set to -inf\n",
    "            mask: Bool mask of retained entries; positions that were already\n",
    "                non-finite in ``scores`` are reported as not retained.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If ``mode`` is neither 'per_head' nor 'global'.\n",
    "        \"\"\"\n",
    "        B, H, T, T2 = scores.shape\n",
    "\n",
    "        if mode not in ('per_head', 'global'):\n",
    "            raise ValueError(f\"Unknown EMP mode: {mode}\")\n",
    "\n",
    "        # The mask is a discrete selection; no gradient flows through its construction.\n",
    "        with torch.no_grad():\n",
    "            if mode == 'per_head':\n",
    "                # Apply EMP independently per head\n",
    "                mask = torch.stack(\n",
    "                    [self._emp_keep_mask(scores[:, h, :, :], beta) for h in range(H)],\n",
    "                    dim=1,\n",
    "                )  # (B, H, T, T)\n",
    "            else:\n",
    "                # Apply EMP globally across all heads\n",
    "                mask = self._emp_keep_mask(scores, beta)\n",
    "\n",
    "            # A query row whose entries all fall below the shared threshold would\n",
    "            # become all -inf and softmax would return NaN for it. Such rows keep\n",
    "            # their single largest-magnitude visible entry.\n",
    "            valid = torch.isfinite(scores)\n",
    "            strongest = scores.abs().masked_fill(~valid, -1.0).argmax(dim=-1, keepdim=True)\n",
    "            is_strongest = torch.arange(T2, device=scores.device) == strongest\n",
    "            row_lost = valid.any(dim=-1, keepdim=True) & ~mask.any(dim=-1, keepdim=True)\n",
    "            mask = mask | (is_strongest & row_lost)\n",
    "\n",
    "        # Apply mask: set pruned entries to -inf so they become 0 after softmax\n",
    "        masked_scores = scores.masked_fill(~mask, float('-inf'))\n",
    "\n",
    "        return masked_scores, mask\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Run causal self-attention on ``x`` of shape (B, T, C) and return the same shape.\n",
    "\n",
    "        Also records the causally masked pre-softmax scores in\n",
    "        ``last_attn_scores`` and the EMP keep-mask (or ``None``) in\n",
    "        ``last_attn_mask``.\n",
    "        \"\"\"\n",
    "        B, T, C = x.shape\n",
    "\n",
    "        # QKV projection\n",
    "        qkv = self.c_attn(x)\n",
    "        q, k, v = qkv.split(self.n_embd, dim=2)\n",
    "\n",
    "        # Reshape for multi-head attention\n",
    "        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, H, T, D)\n",
    "        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)\n",
    "        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)\n",
    "\n",
    "        # Attention scores\n",
    "        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))  # (B, H, T, T)\n",
    "\n",
    "        # Apply causal mask\n",
    "        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))\n",
    "\n",
    "        # Store pre-softmax scores for analysis. Nothing below writes to\n",
    "        # `scores` in place, so a detached view is enough and no copy is needed.\n",
    "        self.last_attn_scores = scores.detach()\n",
    "\n",
    "        # Apply EMP if configured\n",
    "        if self.emp_beta is not None and self.emp_beta > 0:\n",
    "            scores, emp_mask = self.apply_emp_mask(scores, self.emp_beta, self.emp_mode)\n",
    "            self.last_attn_mask = emp_mask\n",
    "        else:\n",
    "            self.last_attn_mask = None\n",
    "\n",
    "        # Softmax and dropout\n",
    "        attn_weights = F.softmax(scores, dim=-1)\n",
    "        attn_weights = self.attn_dropout(attn_weights)\n",
    "\n",
    "        # Apply attention\n",
    "        out = attn_weights @ v  # (B, H, T, D)\n",
    "        out = out.transpose(1, 2).contiguous().view(B, T, C)\n",
    "\n",
    "        # Output projection\n",
    "        out = self.resid_dropout(self.c_proj(out))\n",
    "\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "mlp",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"Feed-forward sub-layer: Linear -> GELU -> Linear -> Dropout.\n",
    "\n",
    "    The hidden layer is four times as wide as the embedding dimension.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config: TransformerConfig):\n",
    "        super().__init__()\n",
    "        width = config.n_embd\n",
    "        hidden_width = 4 * width\n",
    "        use_bias = config.bias\n",
    "        self.c_fc = nn.Linear(width, hidden_width, bias=use_bias)\n",
    "        self.gelu = nn.GELU()\n",
    "        self.c_proj = nn.Linear(hidden_width, width, bias=use_bias)\n",
    "        self.dropout = nn.Dropout(config.dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        hidden = self.gelu(self.c_fc(x))\n",
    "        return self.dropout(self.c_proj(hidden))\n",
    "\n",
    "\n",
    "class TransformerBlock(nn.Module):\n",
    "    \"\"\"Pre-LayerNorm block: a residual attention step followed by a residual MLP step.\"\"\"\n",
    "\n",
    "    def __init__(self, config: TransformerConfig):\n",
    "        super().__init__()\n",
    "        dim = config.n_embd\n",
    "        self.ln_1 = nn.LayerNorm(dim)\n",
    "        self.attn = CausalSelfAttention(config)\n",
    "        self.ln_2 = nn.LayerNorm(dim)\n",
    "        self.mlp = MLP(config)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Each sub-layer sees the normalised input and is added back to the raw stream\n",
    "        attn_out = self.attn(self.ln_1(x))\n",
    "        x = x + attn_out\n",
    "        mlp_out = self.mlp(self.ln_2(x))\n",
    "        return x + mlp_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "gpt-model",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPT(nn.Module):\n",
    "    \"\"\"GPT-style language model: token + position embeddings, a stack of\n",
    "    transformer blocks, a final LayerNorm and a linear head whose weight is\n",
    "    shared with the token embedding.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config: TransformerConfig):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "\n",
    "        self.transformer = nn.ModuleDict(dict(\n",
    "            wte = nn.Embedding(config.vocab_size, config.n_embd),\n",
    "            wpe = nn.Embedding(config.block_size, config.n_embd),\n",
    "            drop = nn.Dropout(config.dropout),\n",
    "            h = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)]),\n",
    "            ln_f = nn.LayerNorm(config.n_embd),\n",
    "        ))\n",
    "        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
    "\n",
    "        # Weight tying\n",
    "        self.transformer.wte.weight = self.lm_head.weight\n",
    "\n",
    "        # Initialize weights\n",
    "        self.apply(self._init_weights)\n",
    "\n",
    "        # Count parameters\n",
    "        n_params = sum(p.numel() for p in self.parameters())\n",
    "        print(f\"Model parameters: {n_params/1e6:.2f}M\")\n",
    "\n",
    "    def _init_weights(self, module):\n",
    "        \"\"\"Normal(0, 0.02) init for Linear/Embedding weights; zero init for Linear biases.\"\"\"\n",
    "        if isinstance(module, nn.Linear):\n",
    "            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
    "            if module.bias is not None:\n",
    "                torch.nn.init.zeros_(module.bias)\n",
    "        elif isinstance(module, nn.Embedding):\n",
    "            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
    "\n",
    "    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):\n",
    "        \"\"\"Return ``(logits, loss)`` for token ids ``idx`` of shape (B, T).\n",
    "\n",
    "        ``loss`` is the cross-entropy against ``targets`` when they are given,\n",
    "        otherwise ``None``. ``T`` must not exceed ``config.block_size``.\n",
    "        \"\"\"\n",
    "        B, T = idx.shape\n",
    "        assert T <= self.config.block_size, f\"Sequence length {T} > block_size {self.config.block_size}\"\n",
    "\n",
    "        # Embeddings\n",
    "        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)\n",
    "        tok_emb = self.transformer.wte(idx)  # (B, T, C)\n",
    "        pos_emb = self.transformer.wpe(pos)  # (T, C)\n",
    "        x = self.transformer.drop(tok_emb + pos_emb)\n",
    "\n",
    "        # Transformer blocks\n",
    "        for block in self.transformer.h:\n",
    "            x = block(x)\n",
    "\n",
    "        x = self.transformer.ln_f(x)\n",
    "\n",
    "        # Output\n",
    "        logits = self.lm_head(x)  # (B, T, vocab_size)\n",
    "\n",
    "        loss = None\n",
    "        if targets is not None:\n",
    "            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))\n",
    "\n",
    "        return logits, loss\n",
    "\n",
    "    def set_emp(self, beta: Optional[float], mode: str = 'per_head'):\n",
    "        \"\"\"Set EMP pruning for all attention layers.\n",
    "\n",
    "        Args:\n",
    "            beta: EMP coefficient. None or 0 to disable.\n",
    "            mode: 'per_head' or 'global'\n",
    "        \"\"\"\n",
    "        for block in self.transformer.h:\n",
    "            block.attn.emp_beta = beta\n",
    "            block.attn.emp_mode = mode\n",
    "\n",
    "    def get_attention_stats(self) -> Dict:\n",
    "        \"\"\"Get attention statistics from all layers.\n",
    "\n",
    "        Uses the scores (and EMP mask, if any) recorded by the most recent\n",
    "        forward pass; layers that have not run yet are omitted.\n",
    "\n",
    "        Returns:\n",
    "            Dict keyed ``'layer_<i>'`` with ``scores_mean``, ``scores_std``,\n",
    "            ``scores_min`` and ``scores_max`` computed over the finite scores\n",
    "            only, plus ``sparsity`` (fraction of finite-score positions that\n",
    "            the EMP mask dropped) when a mask was recorded.\n",
    "        \"\"\"\n",
    "        stats = {}\n",
    "        for i, block in enumerate(self.transformer.h):\n",
    "            attn = block.attn\n",
    "            scores = attn.last_attn_scores\n",
    "            if scores is None:\n",
    "                continue\n",
    "\n",
    "            # Recorded scores hold -inf at causally masked positions; averaging\n",
    "            # over them would give -inf / NaN, so only finite entries count.\n",
    "            finite = torch.isfinite(scores)\n",
    "            visible = scores[finite]\n",
    "            has_visible = visible.numel() > 0\n",
    "            layer_stats = {\n",
    "                'scores_mean': visible.mean().item() if has_visible else 0,\n",
    "                # std of a single value is undefined; report 0 instead of NaN\n",
    "                'scores_std': visible.std().item() if visible.numel() > 1 else 0,\n",
    "                'scores_min': visible.min().item() if has_visible else 0,\n",
    "                'scores_max': visible.max().item() if has_visible else 0,\n",
    "            }\n",
    "\n",
    "            mask = attn.last_attn_mask\n",
    "            if mask is not None:\n",
    "                # Measured against visible positions only, so the causal\n",
    "                # triangle does not count as pruning.\n",
    "                kept = mask[finite].float().mean().item() if has_visible else 1.0\n",
    "                layer_stats['sparsity'] = 1.0 - kept\n",
    "\n",
    "            stats[f'layer_{i}'] = layer_stats\n",
    "        return stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "build-model",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build model from the shared config and move it to the selected device\n",
    "model = GPT(config).to(device)\n",
    "print(f\"\\nModel architecture:\")\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "training-header",
   "metadata": {},
   "source": [
    "## 5. Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "training-utils",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def evaluate(model: nn.Module, loader: DataLoader) -> Tuple[float, float]:\n",
    "    \"\"\"Evaluate model and return average loss and perplexity.\n",
    "\n",
    "    Puts the model in eval mode (it is not switched back). The loss is averaged\n",
    "    per token over every batch of ``loader``; perplexity is ``exp`` of that\n",
    "    average and is reported as ``inf`` when it overflows a float.\n",
    "\n",
    "    Args:\n",
    "        model: Callable as ``model(x, y)`` returning ``(logits, loss)``.\n",
    "        loader: Iterable of ``(x, y)`` batches.\n",
    "\n",
    "    Returns:\n",
    "        ``(avg_loss, perplexity)``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``loader`` yields no tokens.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    # Send batches to wherever the model's parameters live instead of relying\n",
    "    # on a notebook-level global; a parameter-free model leaves batches as is.\n",
    "    first_param = next(model.parameters(), None)\n",
    "    target_device = first_param.device if first_param is not None else None\n",
    "\n",
    "    total_loss = 0.0\n",
    "    total_tokens = 0\n",
    "\n",
    "    for x, y in loader:\n",
    "        if target_device is not None:\n",
    "            x, y = x.to(target_device), y.to(target_device)\n",
    "        _, loss = model(x, y)\n",
    "        # Weight each batch by its token count so a short last batch is not over-counted\n",
    "        total_loss += loss.item() * y.numel()\n",
    "        total_tokens += y.numel()\n",
    "\n",
    "    if total_tokens == 0:\n",
    "        raise ValueError(\"evaluate() received no tokens: the loader is empty\")\n",
    "\n",
    "    avg_loss = total_loss / total_tokens\n",
    "    try:\n",
    "        perplexity = math.exp(avg_loss)\n",
    "    except OverflowError:\n",
    "        # math.exp raises for very large losses; report that as infinite perplexity\n",
    "        perplexity = float('inf')\n",
    "\n",
    "    return avg_loss, perplexity\n",
    "\n",
    "\n",
    "def train_epoch(model: nn.Module, loader: DataLoader, optimizer, scheduler=None) -> float:\n",
    "    \"\"\"Train for one epoch.\n",
    "\n",
    "    Runs one optimizer step per batch with gradients clipped to norm 1.0, then\n",
    "    steps ``scheduler`` once (if given) after the last batch.\n",
    "\n",
    "    Args:\n",
    "        model: Callable as ``model(x, y)`` returning ``(logits, loss)``.\n",
    "        loader: Iterable of ``(x, y)`` batches.\n",
    "        optimizer: Optimizer over ``model``'s parameters.\n",
    "        scheduler: Optional per-epoch learning-rate scheduler.\n",
    "\n",
    "    Returns:\n",
    "        Per-token average training loss over the epoch.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``loader`` yields no tokens (the scheduler is not stepped).\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    # Send batches to wherever the model's parameters live instead of relying\n",
    "    # on a notebook-level global; a parameter-free model leaves batches as is.\n",
    "    first_param = next(model.parameters(), None)\n",
    "    target_device = first_param.device if first_param is not None else None\n",
    "\n",
    "    total_loss = 0.0\n",
    "    total_tokens = 0\n",
    "\n",
    "    for x, y in loader:\n",
    "        if target_device is not None:\n",
    "            x, y = x.to(target_device), y.to(target_device)\n",
    "\n",
    "        optimizer.zero_grad(set_to_none=True)\n",
    "        _, loss = model(x, y)\n",
    "        loss.backward()\n",
    "\n",
    "        # Gradient clipping\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item() * y.numel()\n",
    "        total_tokens += y.numel()\n",
    "\n",
    "    if total_tokens == 0:\n",
    "        raise ValueError(\"train_epoch() received no tokens: the loader is empty\")\n",
    "\n",
    "    if scheduler is not None:\n",
    "        scheduler.step()\n",
    "\n",
    "    return total_loss / total_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "train-model",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, \n",
    "                config: TransformerConfig, save_path: str):\n",
    "    \"\"\"Full training loop.\n",
    "\n",
    "    Trains with AdamW for ``config.epochs`` epochs under a linear-warmup then\n",
    "    cosine learning-rate schedule (stepped once per epoch), evaluates on\n",
    "    ``val_loader`` after every epoch and writes a checkpoint to ``save_path``\n",
    "    whenever validation perplexity improves.\n",
    "\n",
    "    Args:\n",
    "        model: Model to train in place.\n",
    "        train_loader: Training batches.\n",
    "        val_loader: Validation batches.\n",
    "        config: Supplies ``lr``, ``weight_decay``, ``epochs``, ``warmup_epochs``.\n",
    "        save_path: Checkpoint file; its parent directory is created if needed.\n",
    "\n",
    "    Returns:\n",
    "        Best validation perplexity seen.\n",
    "    \"\"\"\n",
    "\n",
    "    # Optimizer\n",
    "    optimizer = torch.optim.AdamW(\n",
    "        model.parameters(),\n",
    "        lr=config.lr,\n",
    "        weight_decay=config.weight_decay,\n",
    "        betas=(0.9, 0.95)\n",
    "    )\n",
    "\n",
    "    # Scheduler. Warmup is clamped to [0, epochs] and the cosine phase to at\n",
    "    # least one step: T_max <= 0 makes CosineAnnealingLR divide by zero.\n",
    "    warmup_epochs = max(0, min(config.warmup_epochs, config.epochs))\n",
    "    cosine_epochs = max(1, config.epochs - warmup_epochs)\n",
    "    if warmup_epochs > 0:\n",
    "        warmup_scheduler = LinearLR(optimizer, start_factor=0.01, end_factor=1.0, \n",
    "                                     total_iters=warmup_epochs)\n",
    "        cosine_scheduler = CosineAnnealingLR(optimizer, T_max=cosine_epochs)\n",
    "        scheduler = SequentialLR(optimizer, [warmup_scheduler, cosine_scheduler], \n",
    "                                  milestones=[warmup_epochs])\n",
    "    else:\n",
    "        scheduler = CosineAnnealingLR(optimizer, T_max=cosine_epochs)\n",
    "\n",
    "    best_val_ppl = float('inf')\n",
    "\n",
    "    # dirname is '' for a bare file name, and os.makedirs('') raises\n",
    "    save_dir = os.path.dirname(save_path)\n",
    "\n",
    "    print(f\"\\nStarting training for {config.epochs} epochs...\")\n",
    "    print(\"=\" * 80)\n",
    "\n",
    "    for epoch in range(config.epochs):\n",
    "        start_time = time.time()\n",
    "\n",
    "        # Read the LR before training: train_epoch steps the scheduler at its\n",
    "        # end, so reading afterwards would log the next epoch's rate.\n",
    "        lr = optimizer.param_groups[0]['lr']\n",
    "\n",
    "        # Train\n",
    "        train_loss = train_epoch(model, train_loader, optimizer, scheduler)\n",
    "        train_ppl = math.exp(train_loss)\n",
    "\n",
    "        # Evaluate\n",
    "        val_loss, val_ppl = evaluate(model, val_loader)\n",
    "\n",
    "        elapsed = time.time() - start_time\n",
    "\n",
    "        # Save best model\n",
    "        if val_ppl < best_val_ppl:\n",
    "            best_val_ppl = val_ppl\n",
    "            if save_dir:\n",
    "                os.makedirs(save_dir, exist_ok=True)\n",
    "            torch.save({\n",
    "                'model': model.state_dict(),\n",
    "                'config': config.__dict__,\n",
    "                'epoch': epoch,\n",
    "                'val_ppl': val_ppl,\n",
    "            }, save_path)\n",
    "            marker = \" *\"\n",
    "        else:\n",
    "            marker = \"\"\n",
    "\n",
    "        print(f\"Epoch {epoch+1:3d}/{config.epochs} | \"\n",
    "              f\"Train Loss: {train_loss:.4f} PPL: {train_ppl:8.2f} | \"\n",
    "              f\"Val Loss: {val_loss:.4f} PPL: {val_ppl:8.2f} | \"\n",
    "              f\"LR: {lr:.2e} | Time: {elapsed:.1f}s{marker}\")\n",
    "\n",
    "    print(\"=\" * 80)\n",
    "    print(f\"Best validation perplexity: {best_val_ppl:.2f}\")\n",
    "\n",
    "    return best_val_ppl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "run-training",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training\n",
    "save_path = \"./checkpoints/GPT_WikiText2_best.pth\"\n",
    "\n",
    "# Check if already trained: an existing checkpoint is loaded instead of retraining\n",
    "if os.path.exists(save_path):\n",
    "    print(f\"Loading pre-trained model from {save_path}\")\n",
    "    checkpoint = torch.load(save_path, map_location=device)\n",
    "    # Restore only the weights; the checkpoint's stored config is not applied here\n",
    "    model.load_state_dict(checkpoint['model'])\n",
    "    print(f\"Loaded model with val PPL: {checkpoint['val_ppl']:.2f}\")\n",
    "else:\n",
    "    best_ppl = train_model(model, train_loader, val_loader, config, save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "test-baseline",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate baseline (no EMP): pruning is switched off on every attention layer first\n",
    "model.set_emp(None)\n",
    "val_loss, val_ppl = evaluate(model, val_loader)\n",
    "test_loss, test_ppl = evaluate(model, test_loader)\n",
    "\n",
    "# Report validation and test loss / perplexity\n",
    "print(f\"\\n{'='*60}\")\n",
    "print(f\"BASELINE RESULTS (No EMP)\")\n",
    "print(f\"{'='*60}\")\n",
    "print(f\"Validation Loss: {val_loss:.4f} | Perplexity: {val_ppl:.2f}\")\n",
    "print(f\"Test Loss:       {test_loss:.4f} | Perplexity: {test_ppl:.2f}\")\n",
    "print(f\"{'='*60}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "emp-header",
   "metadata": {},
   "source": [
    "## 6. EMP Attention Pruning Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "emp-analysis",
   "metadata": {},
   "outputs": [],
   "source": [
    "def analyze_attention_emp(model: nn.Module, loader: DataLoader, beta: float, mode: str) -> Dict:\n",
    "    \"\"\"Measure attention sparsity under EMP on the first batch of ``loader``.\n",
    "\n",
    "    Side effects: the model is put in eval mode and is left with EMP enabled\n",
    "    at the given ``beta`` / ``mode``.\n",
    "\n",
    "    Returns:\n",
    "        Dict with ``'layer_stats'`` (output of ``model.get_attention_stats()``)\n",
    "        and ``'avg_sparsity'`` (mean of the per-layer ``sparsity`` values, a\n",
    "        missing value counting as 0; 0 when there are no layers).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    model.set_emp(beta, mode)\n",
    "\n",
    "    # Only the first batch is inspected\n",
    "    first_batch = next(iter(loader))\n",
    "    x, y = first_batch\n",
    "    x = x.to(device)\n",
    "    y = y.to(device)\n",
    "\n",
    "    # One forward pass so every attention layer records its scores and mask\n",
    "    with torch.no_grad():\n",
    "        _ = model(x)\n",
    "\n",
    "    layer_stats = model.get_attention_stats()\n",
    "\n",
    "    # Average the per-layer sparsity\n",
    "    sparsity_total = 0\n",
    "    layer_count = 0\n",
    "    for entry in layer_stats.values():\n",
    "        sparsity_total += entry.get('sparsity', 0)\n",
    "        layer_count += 1\n",
    "\n",
    "    if layer_count:\n",
    "        avg_sparsity = sparsity_total / layer_count\n",
    "    else:\n",
    "        avg_sparsity = 0\n",
    "\n",
    "    return {'layer_stats': layer_stats, 'avg_sparsity': avg_sparsity}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "emp-experiments",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_emp_experiments(model: nn.Module, val_loader: DataLoader, test_loader: DataLoader,\n",
    "                        beta_list: List[float], modes: List[str]):\n",
    "    \"\"\"Run EMP experiments with different beta values and modes.\n",
    "\n",
    "    Evaluates the unpruned model first, then every ``(mode, beta)`` pair, and\n",
    "    prints one table row per pair. EMP is switched off again on exit, including\n",
    "    when the sweep is interrupted or raises.\n",
    "\n",
    "    Args:\n",
    "        model: Model exposing ``set_emp``.\n",
    "        val_loader: Validation batches (also used for the sparsity measurement).\n",
    "        test_loader: Test batches.\n",
    "        beta_list: EMP coefficients to try.\n",
    "        modes: EMP modes to try.\n",
    "\n",
    "    Returns:\n",
    "        List of result dicts, the baseline first, each with keys ``mode``,\n",
    "        ``beta``, ``val_loss``, ``val_ppl``, ``test_loss``, ``test_ppl``,\n",
    "        ``sparsity``, ``delta_val_ppl`` and ``delta_test_ppl``.\n",
    "    \"\"\"\n",
    "\n",
    "    results = []\n",
    "\n",
    "    try:\n",
    "        # Baseline\n",
    "        model.set_emp(None)\n",
    "        baseline_val_loss, baseline_val_ppl = evaluate(model, val_loader)\n",
    "        baseline_test_loss, baseline_test_ppl = evaluate(model, test_loader)\n",
    "\n",
    "        results.append({\n",
    "            'mode': 'baseline',\n",
    "            'beta': 0,\n",
    "            'val_loss': baseline_val_loss,\n",
    "            'val_ppl': baseline_val_ppl,\n",
    "            'test_loss': baseline_test_loss,\n",
    "            'test_ppl': baseline_test_ppl,\n",
    "            'sparsity': 0.0,\n",
    "            'delta_val_ppl': 0.0,\n",
    "            'delta_test_ppl': 0.0,\n",
    "        })\n",
    "\n",
    "        print(f\"\\n{'='*100}\")\n",
    "        print(f\"EMP ATTENTION PRUNING EXPERIMENTS\")\n",
    "        print(f\"{'='*100}\")\n",
    "        print(f\"Baseline - Val PPL: {baseline_val_ppl:.2f} | Test PPL: {baseline_test_ppl:.2f}\")\n",
    "        print(f\"{'='*100}\")\n",
    "\n",
    "        for mode in modes:\n",
    "            print(f\"\\n--- Mode: {mode.upper()} ---\")\n",
    "            print(f\"{'Beta':>8} | {'Val Loss':>10} | {'Val PPL':>10} | {'Test Loss':>10} | {'Test PPL':>10} | \"\n",
    "                  f\"{'Sparsity':>10} | {'ΔPPL (Val)':>12} | {'ΔPPL (Test)':>12}\")\n",
    "            print(\"-\" * 110)\n",
    "\n",
    "            for beta in beta_list:\n",
    "                # Set EMP\n",
    "                model.set_emp(beta, mode)\n",
    "\n",
    "                # Evaluate\n",
    "                val_loss, val_ppl = evaluate(model, val_loader)\n",
    "                test_loss, test_ppl = evaluate(model, test_loader)\n",
    "\n",
    "                # Analyze sparsity\n",
    "                analysis = analyze_attention_emp(model, val_loader, beta, mode)\n",
    "                sparsity = analysis['avg_sparsity']\n",
    "\n",
    "                delta_val = val_ppl - baseline_val_ppl\n",
    "                delta_test = test_ppl - baseline_test_ppl\n",
    "\n",
    "                results.append({\n",
    "                    'mode': mode,\n",
    "                    'beta': beta,\n",
    "                    'val_loss': val_loss,\n",
    "                    'val_ppl': val_ppl,\n",
    "                    'test_loss': test_loss,\n",
    "                    'test_ppl': test_ppl,\n",
    "                    'sparsity': sparsity,\n",
    "                    'delta_val_ppl': delta_val,\n",
    "                    'delta_test_ppl': delta_test,\n",
    "                })\n",
    "\n",
    "                print(f\"{beta:8.2f} | {val_loss:10.4f} | {val_ppl:10.2f} | {test_loss:10.4f} | {test_ppl:10.2f} | \"\n",
    "                      f\"{sparsity*100:9.2f}% | {delta_val:+12.2f} | {delta_test:+12.2f}\")\n",
    "    finally:\n",
    "        # Reset to baseline even when the sweep is interrupted, so later cells\n",
    "        # never run against a model that is still being pruned.\n",
    "        model.set_emp(None)\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "run-experiments",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run experiments\n",
    "beta_list = [0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 2.0]\n",
    "modes = ['per_head', 'global']\n",
    "\n",
    "results = run_emp_experiments(model, val_loader, test_loader, beta_list, modes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "results-table",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create results DataFrame\n",
    "import pandas as pd\n",
    "\n",
    "df = pd.DataFrame(results)\n",
    "print(\"\\n\" + \"=\"*80)\n",
    "print(\"SUMMARY TABLE\")\n",
    "print(\"=\"*80)\n",
    "print(df.to_string(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "visualization",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Plot results\n",
    "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
    "\n",
    "# Colors for modes\n",
    "colors = {'per_head': 'blue', 'global': 'red'}\n",
    "markers = {'per_head': 'o', 'global': 's'}\n",
    "\n",
    "for mode in modes:\n",
    "    mode_df = df[df['mode'] == mode]\n",
    "    \n",
    "    # Plot 1: PPL vs Beta\n",
    "    axes[0].plot(mode_df['beta'], mode_df['test_ppl'], \n",
    "                 color=colors[mode], marker=markers[mode], label=mode)\n",
    "\n",
    "# Add baseline\n",
    "baseline_ppl = df[df['mode'] == 'baseline']['test_ppl'].values[0]\n",
    "axes[0].axhline(y=baseline_ppl, color='green', linestyle='--', label='baseline')\n",
    "axes[0].set_xlabel('Beta')\n",
    "axes[0].set_ylabel('Test Perplexity')\n",
    "axes[0].set_title('Test Perplexity vs Beta')\n",
    "axes[0].legend()\n",
    "axes[0].grid(True, alpha=0.3)\n",
    "\n",
    "for mode in modes:\n",
    "    mode_df = df[df['mode'] == mode]\n",
    "    \n",
    "    # Plot 2: Sparsity vs Beta\n",
    "    axes[1].plot(mode_df['beta'], mode_df['sparsity'] * 100, \n",
    "                 color=colors[mode], marker=markers[mode], label=mode)\n",
    "\n",
    "axes[1].set_xlabel('Beta')\n",
    "axes[1].set_ylabel('Attention Sparsity (%)')\n",
    "axes[1].set_title('Attention Sparsity vs Beta')\n",
    "axes[1].legend()\n",
    "axes[1].grid(True, alpha=0.3)\n",
    "\n",
    "for mode in modes:\n",
    "    mode_df = df[df['mode'] == mode]\n",
    "    \n",
    "    # Plot 3: Delta PPL vs Sparsity\n",
    "    axes[2].plot(mode_df['sparsity'] * 100, mode_df['delta_test_ppl'], \n",
    "                 color=colors[mode], marker=markers[mode], label=mode)\n",
    "\n",
    "axes[2].axhline(y=0, color='green', linestyle='--', label='baseline')\n",
    "axes[2].set_xlabel('Attention Sparsity (%)')\n",
    "axes[2].set_ylabel('ΔPPL (Test)')\n",
    "axes[2].set_title('PPL Change vs Sparsity')\n",
    "axes[2].legend()\n",
    "axes[2].grid(True, alpha=0.3)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('./checkpoints/emp_attention_results.png', dpi=150, bbox_inches='tight')\n",
    "plt.show()\n",
    "print(\"\\nPlot saved to ./checkpoints/emp_attention_results.png\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "layer-analysis-header",
   "metadata": {},
   "source": [
    "## 7. Per-Layer Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "layer-analysis",
   "metadata": {},
   "outputs": [],
   "source": [
    "def analyze_per_layer_emp(model: nn.Module, loader: DataLoader, beta: float = 1.0):\n",
    "    \"\"\"Detailed per-layer analysis of EMP attention pruning.\"\"\"\n",
    "    model.eval()\n",
    "    model.set_emp(beta, 'per_head')\n",
    "    \n",
    "    # Get one batch\n",
    "    x, y = next(iter(loader))\n",
    "    x, y = x.to(device), y.to(device)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        _ = model(x)\n",
    "    \n",
    "    print(f\"\\n{'='*80}\")\n",
    "    print(f\"PER-LAYER ATTENTION ANALYSIS (beta={beta})\")\n",
    "    print(f\"{'='*80}\")\n",
    "    \n",
    "    for i, block in enumerate(model.transformer.h):\n",
    "        attn = block.attn\n",
    "        scores = attn.last_attn_scores\n",
    "        mask = attn.last_attn_mask\n",
    "        \n",
    "        if scores is not None:\n",
    "            # Compute statistics per head\n",
    "            B, H, T, T2 = scores.shape\n",
    "            \n",
    "            print(f\"\\nLayer {i}:\")\n",
    "            print(f\"  Shape: {scores.shape}\")\n",
    "            \n",
    "            for h in range(H):\n",
    "                head_scores = scores[:, h, :, :]\n",
    "                valid_scores = head_scores[head_scores != float('-inf')]\n",
    "                \n",
    "                # Compute N_eff for this head\n",
    "                s_abs = torch.abs(valid_scores)\n",
    "                s_sum = s_abs.sum()\n",
    "                if s_sum > 1e-10:\n",
    "                    omega = s_abs / s_sum\n",
    "                    neff = (1.0 / (omega ** 2).sum()).item()\n",
    "                else:\n",
    "                    neff = 0\n",
    "                \n",
    "                if mask is not None:\n",
    "                    head_sparsity = 1.0 - mask[:, h, :, :].float().mean().item()\n",
    "                else:\n",
    "                    head_sparsity = 0\n",
    "                \n",
    "                print(f\"  Head {h}: N_eff={neff:8.1f} | Sparsity={head_sparsity*100:5.1f}% | \"\n",
    "                      f\"Mean={valid_scores.mean().item():7.3f} | Std={valid_scores.std().item():7.3f}\")\n",
    "\n",
    "# Run analysis\n",
    "analyze_per_layer_emp(model, val_loader, beta=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "neff-distribution",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_neff_distribution(model: nn.Module, loader: DataLoader):\n",
    "    \"\"\"Plot N_eff distribution across layers and heads.\"\"\"\n",
    "    model.eval()\n",
    "    model.set_emp(None)  # No pruning, just analyze\n",
    "    \n",
    "    x, y = next(iter(loader))\n",
    "    x, y = x.to(device), y.to(device)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        _ = model(x)\n",
    "    \n",
    "    n_layers = len(model.transformer.h)\n",
    "    n_heads = model.config.n_head\n",
    "    \n",
    "    neff_matrix = torch.zeros(n_layers, n_heads)\n",
    "    \n",
    "    for i, block in enumerate(model.transformer.h):\n",
    "        scores = block.attn.last_attn_scores\n",
    "        if scores is None:\n",
    "            continue\n",
    "            \n",
    "        for h in range(n_heads):\n",
    "            head_scores = scores[:, h, :, :]\n",
    "            valid_scores = head_scores[head_scores != float('-inf')]\n",
    "            \n",
    "            s_abs = torch.abs(valid_scores)\n",
    "            s_sum = s_abs.sum()\n",
    "            if s_sum > 1e-10:\n",
    "                omega = s_abs / s_sum\n",
    "                neff = (1.0 / (omega ** 2).sum()).item()\n",
    "            else:\n",
    "                neff = 0\n",
    "            \n",
    "            neff_matrix[i, h] = neff\n",
    "    \n",
    "    # Plot heatmap\n",
    "    fig, ax = plt.subplots(figsize=(10, 6))\n",
    "    im = ax.imshow(neff_matrix.cpu().numpy(), cmap='viridis', aspect='auto')\n",
    "    \n",
    "    ax.set_xlabel('Head')\n",
    "    ax.set_ylabel('Layer')\n",
    "    ax.set_title('N_eff Distribution Across Layers and Heads')\n",
    "    ax.set_xticks(range(n_heads))\n",
    "    ax.set_yticks(range(n_layers))\n",
    "    \n",
    "    cbar = plt.colorbar(im, ax=ax)\n",
    "    cbar.set_label('N_eff')\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig('./checkpoints/neff_distribution.png', dpi=150, bbox_inches='tight')\n",
    "    plt.show()\n",
    "    \n",
    "    return neff_matrix\n",
    "\n",
    "neff_matrix = plot_neff_distribution(model, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ablation-header",
   "metadata": {},
   "source": [
    "## 8. Extended Ablation Study"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fine-grained-beta",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-grained beta sweep around the optimal region\n",
    "fine_beta_list = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5]\n",
    "\n",
    "print(\"\\n\" + \"=\"*100)\n",
    "print(\"FINE-GRAINED BETA ABLATION (Per-Head Mode)\")\n",
    "print(\"=\"*100)\n",
    "print(f\"{'Beta':>8} | {'Val Loss':>10} | {'Val PPL':>10} | {'Test Loss':>10} | {'Test PPL':>10} | \"\n",
    "      f\"{'Sparsity':>10} | {'ΔPPL (Val)':>12} | {'ΔPPL (Test)':>12}\")\n",
    "print(\"-\" * 110)\n",
    "\n",
    "# Get baseline\n",
    "model.set_emp(None)\n",
    "baseline_val_loss, baseline_val_ppl = evaluate(model, val_loader)\n",
    "baseline_test_loss, baseline_test_ppl = evaluate(model, test_loader)\n",
    "\n",
    "fine_results = []\n",
    "for beta in fine_beta_list:\n",
    "    model.set_emp(beta, 'per_head')\n",
    "    \n",
    "    val_loss, val_ppl = evaluate(model, val_loader)\n",
    "    test_loss, test_ppl = evaluate(model, test_loader)\n",
    "    \n",
    "    analysis = analyze_attention_emp(model, val_loader, beta, 'per_head')\n",
    "    sparsity = analysis['avg_sparsity']\n",
    "    \n",
    "    delta_val = val_ppl - baseline_val_ppl\n",
    "    delta_test = test_ppl - baseline_test_ppl\n",
    "    \n",
    "    fine_results.append({\n",
    "        'beta': beta,\n",
    "        'val_ppl': val_ppl,\n",
    "        'test_ppl': test_ppl,\n",
    "        'sparsity': sparsity,\n",
    "        'delta_val_ppl': delta_val,\n",
    "        'delta_test_ppl': delta_test,\n",
    "    })\n",
    "    \n",
    "    print(f\"{beta:8.2f} | {val_loss:10.4f} | {val_ppl:10.2f} | {test_loss:10.4f} | {test_ppl:10.2f} | \"\n",
    "          f\"{sparsity*100:9.2f}% | {delta_val:+12.2f} | {delta_test:+12.2f}\")\n",
    "\n",
    "model.set_emp(None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "save-final-results",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save all results\n",
    "all_results = {\n",
    "    'config': config.__dict__,\n",
    "    'baseline_val_ppl': baseline_val_ppl,\n",
    "    'baseline_test_ppl': baseline_test_ppl,\n",
    "    'main_experiments': results,\n",
    "    'fine_grained_ablation': fine_results,\n",
    "    'neff_matrix': neff_matrix.cpu().numpy().tolist(),\n",
    "}\n",
    "\n",
    "import json\n",
    "with open('./checkpoints/emp_attention_results.json', 'w') as f:\n",
    "    json.dump(all_results, f, indent=2)\n",
    "\n",
    "print(\"\\nResults saved to ./checkpoints/emp_attention_results.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "summary-header",
   "metadata": {},
   "source": [
    "## 9. Summary and Conclusions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "summary",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"\\n\" + \"=\"*80)\n",
    "print(\"EXPERIMENT SUMMARY\")\n",
    "print(\"=\"*80)\n",
    "\n",
    "print(f\"\\nModel Configuration:\")\n",
    "print(f\"  - Layers: {config.n_layer}\")\n",
    "print(f\"  - Heads: {config.n_head}\")\n",
    "print(f\"  - Embedding dim: {config.n_embd}\")\n",
    "print(f\"  - Context length: {config.block_size}\")\n",
    "print(f\"  - Vocab size: {config.vocab_size}\")\n",
    "\n",
    "print(f\"\\nBaseline Performance:\")\n",
    "print(f\"  - Val PPL: {baseline_val_ppl:.2f}\")\n",
    "print(f\"  - Test PPL: {baseline_test_ppl:.2f}\")\n",
    "\n",
    "# Find optimal beta for each mode\n",
    "df_results = pd.DataFrame(results)\n",
    "for mode in modes:\n",
    "    mode_df = df_results[df_results['mode'] == mode]\n",
    "    if len(mode_df) > 0:\n",
    "        # Find beta with minimal PPL increase while having some sparsity\n",
    "        mode_df_sparse = mode_df[mode_df['sparsity'] > 0.01]\n",
    "        if len(mode_df_sparse) > 0:\n",
    "            best_row = mode_df_sparse.loc[mode_df_sparse['delta_test_ppl'].idxmin()]\n",
    "            print(f\"\\nBest {mode.upper()} setting:\")\n",
    "            print(f\"  - Beta: {best_row['beta']:.2f}\")\n",
    "            print(f\"  - Sparsity: {best_row['sparsity']*100:.1f}%\")\n",
    "            print(f\"  - Test PPL: {best_row['test_ppl']:.2f} (Δ={best_row['delta_test_ppl']:+.2f})\")\n",
    "\n",
    "print(\"\\n\" + \"=\"*80)\n",
    "print(\"KEY FINDINGS:\")\n",
    "print(\"=\"*80)\n",
    "print(\"\"\"\n",
    "1. EMP can be effectively applied to transformer attention scores (pre-softmax)\n",
    "2. Per-head EMP allows independent pruning of each attention head\n",
    "3. Global EMP applies uniform pruning across all heads\n",
    "4. Beta ≈ 1.0 typically provides a good balance between sparsity and performance\n",
    "5. Lower beta values increase sparsity but may degrade performance\n",
    "6. Higher beta values (>1.5) retain most attention patterns with minimal sparsity\n",
    "\"\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "final-cleanup",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final model state\n",
    "model.set_emp(None)\n",
    "print(\"\\nModel reset to baseline (no EMP pruning)\")\n",
    "print(f\"\\nAll outputs saved to ./checkpoints/\")\n",
    "print(\"  - GPT_WikiText2_best.pth (model checkpoint)\")\n",
    "print(\"  - emp_attention_results.json (experiment results)\")\n",
    "print(\"  - emp_attention_results.png (visualization)\")\n",
    "print(\"  - neff_distribution.png (N_eff heatmap)\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
