{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "21340c9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sae\n",
    "import importlib\n",
    "import jax\n",
    "import optax\n",
    "import equinox as eqx\n",
    "import jax.numpy as jnp\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "675950ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2-2b\")\n",
    "vocab = tokenizer.get_vocab()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e8eca61",
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_generator(data, batch_size):\n",
    "    num_samples = data.shape[0]\n",
    "    key = jax.random.PRNGKey(0)\n",
    "\n",
    "    while True:\n",
    "        key, subkey = jax.random.split(key)\n",
    "        indices = jax.random.permutation(subkey, num_samples)\n",
    "        \n",
    "        for start in range(0, num_samples, batch_size):\n",
    "            end = min(start + batch_size, num_samples)\n",
    "            batch_indices = indices[start:end]\n",
    "            batch = data[batch_indices]\n",
    "            \n",
    "            # # Normalize the batch\n",
    "            batch = batch - jnp.mean(batch, axis=1, keepdims=True)\n",
    "            norms = jnp.linalg.norm(batch, axis=1, keepdims=True)\n",
    "            batch = batch / (norms + 1e-8)  # Add small epsilon to avoid division by zero\n",
    "            \n",
    "            yield batch\n",
    "\n",
    "def create_data_loader(data, batch_size):\n",
    "    return data_generator(data, batch_size)\n",
    "\n",
    "# Load unembeddings\n",
    "g = jnp.load(DIR)\n",
    "\n",
    "g = g * jnp.sqrt(g.shape[0] / g.shape[1]) # scaling for norm\n",
    "\n",
    "batch_size = 8192\n",
    "train_loader = create_data_loader(g, batch_size)\n",
    "example_batch = next(train_loader)\n",
    "input_dim = example_batch[0].shape[0]\n",
    "latent_dim = 2**12\n",
    "key = jax.random.PRNGKey(0)\n",
    "k = 5\n",
    "model = sae.Autoencoder(latent_dim, input_dim, use_bias=True, key=key, k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9df37db0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(8192, 2304)\n"
     ]
    }
   ],
   "source": [
    "print(next(train_loader).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "36aa2242",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-05-14 21:58:35.430120: W external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1061] Compiling 128 configs for 4 fusions on a single thread.\n",
      "2025-05-14 21:59:05.626423: W external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1061] Compiling 128 configs for 4 fusions on a single thread.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4, Step 100: Loss = 0.8028\n",
      "Epoch 7, Step 200: Loss = 0.6944\n",
      "Epoch 10, Step 300: Loss = 0.6479\n",
      "Epoch 13, Step 400: Loss = 0.6260\n",
      "Epoch 17, Step 500: Loss = 0.6162\n",
      "Epoch 20, Step 600: Loss = 0.6082\n",
      "Epoch 23, Step 700: Loss = 0.6009\n",
      "Epoch 26, Step 800: Loss = 0.5942\n",
      "Epoch 30, Step 900: Loss = 0.5915\n",
      "Epoch 33, Step 1000: Loss = 0.5914\n",
      "Epoch 36, Step 1100: Loss = 0.5889\n",
      "Epoch 39, Step 1200: Loss = 0.5907\n",
      "Epoch 42, Step 1300: Loss = 0.5900\n",
      "Epoch 46, Step 1400: Loss = 0.5865\n",
      "Epoch 49, Step 1500: Loss = 0.5865\n",
      "Epoch 52, Step 1600: Loss = 0.5798\n",
      "Epoch 55, Step 1700: Loss = 0.5828\n",
      "Epoch 59, Step 1800: Loss = 0.5807\n",
      "Epoch 62, Step 1900: Loss = 0.5835\n",
      "Epoch 65, Step 2000: Loss = 0.5821\n",
      "Epoch 68, Step 2100: Loss = 0.5813\n",
      "Epoch 71, Step 2200: Loss = 0.5817\n",
      "Epoch 75, Step 2300: Loss = 0.5802\n",
      "Epoch 78, Step 2400: Loss = 0.5774\n",
      "Epoch 81, Step 2500: Loss = 0.5738\n",
      "Epoch 84, Step 2600: Loss = 0.5772\n",
      "Epoch 88, Step 2700: Loss = 0.5763\n",
      "Epoch 91, Step 2800: Loss = 0.5752\n",
      "Epoch 94, Step 2900: Loss = 0.5774\n",
      "Epoch 97, Step 3000: Loss = 0.5790\n",
      "Epoch 100, Step 3100: Loss = 0.5796\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 100\n",
    "learning_rate = 3e-4 # kaparthy constant\n",
    "\n",
    "optimizer = optax.adam(learning_rate)\n",
    "\n",
    "opt_state = optimizer.init(eqx.filter(model, eqx.is_array))\n",
    "\n",
    "step = 0\n",
    "steps_per_epoch =  g.shape[0] // batch_size\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for step_in_epoch in range(steps_per_epoch):\n",
    "        batch = next(train_loader)\n",
    "        model, opt_state, loss = sae.train_step(model, batch, opt_state, optimizer)\n",
    "        step += 1\n",
    "        if step % 100 == 0:\n",
    "            print(f\"Epoch {epoch + 1}, Step {step}: Loss = {loss.item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ff439933",
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab_list = [None] * (max(vocab.values()) + 1)\n",
    "for word, index in vocab.items():\n",
    "    vocab_list[index] = word"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "39b7694e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building SAE Activation Cache:   0%|          | 0/32 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building SAE Activation Cache: 100%|██████████| 32/32 [00:06<00:00,  5.16it/s]\n"
     ]
    }
   ],
   "source": [
    "import tqdm\n",
    "\n",
    "class SAEActivationCache:\n",
    "    def __init__(self, vocab_size, latent_dim, max_examples_per_latent=10):\n",
    "        self.vocab_size = vocab_size\n",
    "        self.latent_dim = latent_dim\n",
    "        self.max_examples_per_latent = max_examples_per_latent\n",
    "\n",
    "        self.activations = np.full((latent_dim, max_examples_per_latent), -np.inf, dtype=np.float32)\n",
    "        self.words = np.full((latent_dim, max_examples_per_latent), -1, dtype=np.int32)\n",
    "\n",
    "    def update(self, batch_words, latents):\n",
    "        batch_size = latents.shape[0]\n",
    "        for i in range(batch_size):\n",
    "            word_idx = int(batch_words[i])\n",
    "            top_latent = int(np.argmax(latents[i]))\n",
    "            activation = float(latents[i, top_latent])\n",
    "\n",
    "            current_acts = self.activations[top_latent]\n",
    "            min_idx = np.argmin(current_acts)\n",
    "            if activation > current_acts[min_idx]:\n",
    "                self.activations[top_latent, min_idx] = activation\n",
    "                self.words[top_latent, min_idx] = word_idx\n",
    "\n",
    "def build_sae_activation_cache(dataset, model, batch_size, vocab_embeddings, vocab_indices, latent_dim, max_examples_per_latent=10):\n",
    "    activation_cache = SAEActivationCache(\n",
    "        vocab_size=len(vocab_list),\n",
    "        latent_dim=latent_dim,\n",
    "        max_examples_per_latent=max_examples_per_latent\n",
    "    )\n",
    "\n",
    "    num_batches = int(np.ceil(len(dataset) / batch_size))\n",
    "\n",
    "    for batch_start in tqdm.tqdm(range(0, len(dataset), batch_size), total=num_batches, desc=\"Building SAE Activation Cache\"):\n",
    "        batch_end = min(batch_start + batch_size, len(dataset))\n",
    "        batch = vocab_embeddings[vocab_indices[batch_start:batch_end]]\n",
    "\n",
    "        latents = jax.vmap(model.top_k_encode)(batch)\n",
    "        latents = np.array(latents)\n",
    "        batch_words = np.array(vocab_indices[batch_start:batch_end])\n",
    "\n",
    "        activation_cache.update(batch_words=batch_words, latents=latents)\n",
    "\n",
    "    return activation_cache\n",
    "\n",
    "activation_cache = build_sae_activation_cache(\n",
    "    dataset=g,\n",
    "    model=model,\n",
    "    batch_size=batch_size,\n",
    "    vocab_embeddings=g,\n",
    "    vocab_indices=np.arange(g.shape[0]),\n",
    "    latent_dim=latent_dim,\n",
    "    max_examples_per_latent=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "9e97d9a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def explain_word(word, top_k=k):\n",
    "    word_vec = g[vocab[word]]\n",
    "    latents = np.array(model.top_k_encode(word_vec))\n",
    "\n",
    "    print(f\"\\nExplanations for word: '{word.strip()}'\\n{'='*50}\")\n",
    "\n",
    "    top_latents = np.argsort(-latents)[:top_k]\n",
    "    for idx in top_latents:\n",
    "        activation = latents[idx]\n",
    "        print(f\"\\n🔹 Feature {idx} (Activation: {activation:.4f})\")\n",
    "        print(\"  Words that maximally activate this Feature:\")\n",
    "\n",
    "        top_examples = np.argsort(-activation_cache.activations[idx])\n",
    "        top_words = [vocab_list[int(w)] for w in activation_cache.words[idx, top_examples] if w >= 0]\n",
    "        print(f\"   {top_words}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "d5e4bea3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'puppy'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 2758 (Activation: 7.4287)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁dog', '▁dogs', '▁Dog', 'Dog', 'dog', '▁Dogs', '▁DOG', 'Dogs', 'DOG', 'dogs']\n",
      "\n",
      "🔹 Feature 920 (Activation: 6.8790)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['^(@)', '<unused47>', '<unused43>', '<unused80>', '<unused99>', '<unused23>', '<unused52>', '<unused42>', '<unused21>', '<unused79>']\n",
      "\n",
      "🔹 Feature 964 (Activation: 5.6121)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['suicide', 'pupil', 'mortgage', 'governance', 'religion', 'ecosystem', 'warranty', 'tournament', 'worship', 'physician']\n",
      "\n",
      "🔹 Feature 3247 (Activation: 4.9620)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁myſelf', '▁itſelf', '▁Efq', '▁Monfieur', '▁pleaſure', '▁Jefus', '▁purpoſe', '▁Theſe', '▁Majefty', '▁Anſ']\n",
      "\n",
      "🔹 Feature 2751 (Activation: 4.2003)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['Unemployment', 'Malware', 'Addiction']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"puppy\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "aa2cf007",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'Queen'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 2669 (Activation: 9.3682)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁prince', '▁Prince', 'Prince', '▁Queen', '▁queen', '▁Princess', '▁princess', 'prince', 'Queen', '▁PRINCE']\n",
      "\n",
      "🔹 Feature 3418 (Activation: 3.6822)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁QU', '▁Qu', 'Qu', 'QU', '▁qu', '▁Q', '▁q', 'qu', '▁Quin', '▁Que']\n",
      "\n",
      "🔹 Feature 2621 (Activation: 3.2669)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁mother', '▁Mother', 'Mother', 'mother', '▁mom', '▁mothers', '▁MOTHER', 'MOTHER', '▁Mom', 'Mom']\n",
      "\n",
      "🔹 Feature 1990 (Activation: 3.0304)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁King', '▁king', 'King', '▁KING', '▁Kings', '▁kings', 'Kings', 'king', 'kings', 'KING']\n",
      "\n",
      "🔹 Feature 920 (Activation: 2.0817)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['^(@)', '<unused47>', '<unused43>', '<unused80>', '<unused99>', '<unused23>', '<unused52>', '<unused42>', '<unused21>', '<unused79>']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"Queen\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "894e334d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'Chicago'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 1828 (Activation: 7.8370)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['Chicago', '▁Chicago', '▁CHICAGO', 'CHICAGO', '▁chicago', '▁Detroit', 'Detroit', '▁Milwaukee', 'Milwaukee', 'chicago']\n",
      "\n",
      "🔹 Feature 2783 (Activation: 6.1555)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Bangkok', '▁Nairobi', '▁Madrid', '▁Jakarta', 'Madrid', '▁Istanbul', '▁Berlin', '▁Manila', '▁Beijing', 'Berlin']\n",
      "\n",
      "🔹 Feature 246 (Activation: 3.0201)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Texas', 'Texas', '▁Florida', '▁Tennessee', '▁Alabama', '▁Pennsylvania', 'Florida', '▁Kentucky', '▁Louisiana', 'Tennessee']\n",
      "\n",
      "🔹 Feature 1352 (Activation: 2.8685)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁CALIFORNIA', '▁Califor', 'California', '▁California', '▁Kalifor', 'CALIFORNIA', '▁Californian', '▁Calif', '▁california', 'california']\n",
      "\n",
      "🔹 Feature 487 (Activation: 2.5884)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Greensboro', '▁Chattanooga', '▁Peterborough', '▁Asheville', '▁Sarasota', '▁Tulsa', '▁Knoxville', '▁Ipswich', '▁Wichita', '▁Portsmouth']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"Chicago\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "d1b0db31",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'London'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 3406 (Activation: 8.6483)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['London', '▁London', '▁LONDON', '▁london', 'LONDON', 'london', '▁Londres', '▁Lond', '▁Лон', '▁paris']\n",
      "\n",
      "🔹 Feature 2666 (Activation: 7.0375)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁British', 'British', '▁Britain', 'Britain', '▁UK', '▁BRITISH', '▁british', '▁brit', 'BRIT', 'UK']\n",
      "\n",
      "🔹 Feature 2783 (Activation: 6.8139)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Bangkok', '▁Nairobi', '▁Madrid', '▁Jakarta', 'Madrid', '▁Istanbul', '▁Berlin', '▁Manila', '▁Beijing', 'Berlin']\n",
      "\n",
      "🔹 Feature 2289 (Activation: 2.7443)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['Roberts', 'Gordon', 'Robertson', 'Allen', 'Howard', 'Meyer', 'Kelly', 'Leslie']\n",
      "\n",
      "🔹 Feature 1943 (Activation: 2.2330)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Spain', '▁Poland', 'Spain', '▁Hungary', '▁Brazil', 'Poland', '▁España', '▁Portugal', '▁Romania', '▁Austria']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"London\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "1a33b9d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'Twitter'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 3011 (Activation: 9.1133)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁tweet', '▁Tweet', '▁tweets', '▁tweeting', '▁Tweets', 'Tweet', '▁tweeted', 'tweet', 'Tweets', '▁Twe']\n",
      "\n",
      "🔹 Feature 2949 (Activation: 8.6393)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Instagram', '▁Facebook', 'Instagram', '▁instagram', 'Facebook', '▁facebook', '▁YouTube', 'instagram', '▁Youtube', 'facebook']\n",
      "\n",
      "🔹 Feature 2324 (Activation: 2.6326)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Python', '▁Java', 'Python', '▁Linux', 'Java', 'Linux', 'JAVA', '▁python', '▁MATLAB', '▁PHP']\n",
      "\n",
      "🔹 Feature 3939 (Activation: 2.5824)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['Fortnite', '▁Fortnite', '▁Esports', 'Overwatch', '▁Minecraft', 'Minecraft', 'NFT', 'Elon', 'minecraft']\n",
      "\n",
      "🔹 Feature 3247 (Activation: 2.2320)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁myſelf', '▁itſelf', '▁Efq', '▁Monfieur', '▁pleaſure', '▁Jefus', '▁purpoſe', '▁Theſe', '▁Majefty', '▁Anſ']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"Twitter\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "93a335c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'python'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 2324 (Activation: 5.9169)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Python', '▁Java', 'Python', '▁Linux', 'Java', 'Linux', 'JAVA', '▁python', '▁MATLAB', '▁PHP']\n",
      "\n",
      "🔹 Feature 97 (Activation: 3.6341)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁pi', '▁Pi', 'pi', 'Pi', 'PI', '▁PI', '▁Py', 'Py', '▁py', '▁PY']\n",
      "\n",
      "🔹 Feature 3710 (Activation: 3.6227)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁snake', '▁Snake', 'Snake', '▁snakes', 'snake', '▁Snakes', '▁serpent', '▁Serpent', '▁serpiente', '蛇']\n",
      "\n",
      "🔹 Feature 2063 (Activation: 3.5208)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁JavaScript', 'JavaScript', 'js', '▁javascript', '▁Javascript', 'Javascript', 'javascript', '▁js', 'JS', '▁JS']\n",
      "\n",
      "🔹 Feature 993 (Activation: 2.1818)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁script', '▁Script', '▁scripts', 'script', 'Script', '▁SCRIPT', '▁Scripts', 'SCRIPT', 'scripts', 'Scripts']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"python\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "712d35d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'Bayesian'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 920 (Activation: 6.8839)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['^(@)', '<unused47>', '<unused43>', '<unused80>', '<unused99>', '<unused23>', '<unused52>', '<unused42>', '<unused21>', '<unused79>']\n",
      "\n",
      "🔹 Feature 2656 (Activation: 5.2432)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Bay', '▁bay', 'Bay', '▁BAY', 'bay', 'BAY', '▁bays', '▁Bays', '▁bahía', '湾']\n",
      "\n",
      "🔹 Feature 3247 (Activation: 5.0911)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁myſelf', '▁itſelf', '▁Efq', '▁Monfieur', '▁pleaſure', '▁Jefus', '▁purpoſe', '▁Theſe', '▁Majefty', '▁Anſ']\n",
      "\n",
      "🔹 Feature 2751 (Activation: 4.6057)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['Unemployment', 'Malware', 'Addiction']\n",
      "\n",
      "🔹 Feature 964 (Activation: 4.1847)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['suicide', 'pupil', 'mortgage', 'governance', 'religion', 'ecosystem', 'warranty', 'tournament', 'worship', 'physician']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"Bayesian\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "b1a0e3b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'SUV'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 920 (Activation: 6.6659)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['^(@)', '<unused47>', '<unused43>', '<unused80>', '<unused99>', '<unused23>', '<unused52>', '<unused42>', '<unused21>', '<unused79>']\n",
      "\n",
      "🔹 Feature 3247 (Activation: 6.1519)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁myſelf', '▁itſelf', '▁Efq', '▁Monfieur', '▁pleaſure', '▁Jefus', '▁purpoſe', '▁Theſe', '▁Majefty', '▁Anſ']\n",
      "\n",
      "🔹 Feature 1099 (Activation: 5.0219)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['TDS', 'SRS', 'BPS', 'MBR', 'SBS', 'PMA', 'TGA', 'PPI', 'CTO', 'SLS']\n",
      "\n",
      "🔹 Feature 1089 (Activation: 4.8235)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Playlist', '▁Webinar', '▁Questionnaire', '▁Brochure', 'Webinar', '▁Checklist', '▁Newsletters', '▁Newsletter', 'Newsletter', '▁playlist']\n",
      "\n",
      "🔹 Feature 719 (Activation: 4.4372)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['Metallica', 'Shrek', 'Bentley', 'Loki', 'Juventus', 'Sevilla', 'LSU', 'Ares', 'Zelda', 'Napoli']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"SUV\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "e8a18ec9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'MIT'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 2997 (Activation: 8.4866)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Mit', 'mit', 'Mit', 'MIT', '▁mit', '▁Mis', '▁MIT', 'mis', 'Mis', '▁MIS']\n",
      "\n",
      "🔹 Feature 3402 (Activation: 4.3481)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Harvard', 'Harvard', '▁Yale', '▁HARVARD', 'Yale', '▁Stanford', 'Stanford', '▁Cornell', '▁Princeton', '▁UCLA']\n",
      "\n",
      "🔹 Feature 3313 (Activation: 2.8401)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁PT', 'CTT', '▁DT', 'IFT', 'cdt', 'mmt', 'ibt', 'hlt', 'ngt', 'ikt']\n",
      "\n",
      "🔹 Feature 603 (Activation: 2.5626)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['->___', '(\"/:', \"('/:\", '[++', ':\\\\/\\\\/', '▁ldc', '獷', 'җ', '▁Kelle', 'romes']\n",
      "\n",
      "🔹 Feature 1573 (Activation: 2.4520)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['ites', 'ite', 'ITA', 'ITE', 'ita', 'IT', 'ITO', 'ITES', 'iter', 'itas']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"MIT\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "4a1f5716",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'swim'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 1118 (Activation: 7.9468)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁pool', '▁Pool', '▁pools', 'Pool', 'pool', '▁Pools', '▁POOL', 'POOL', 'pools', 'Pools']\n",
      "\n",
      "🔹 Feature 920 (Activation: 5.8672)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['^(@)', '<unused47>', '<unused43>', '<unused80>', '<unused99>', '<unused23>', '<unused52>', '<unused42>', '<unused21>', '<unused79>']\n",
      "\n",
      "🔹 Feature 964 (Activation: 5.2352)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['suicide', 'pupil', 'mortgage', 'governance', 'religion', 'ecosystem', 'warranty', 'tournament', 'worship', 'physician']\n",
      "\n",
      "🔹 Feature 1381 (Activation: 5.1449)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁bath', '▁baths', '▁BATH', '▁Bath', 'bath', 'Bath', '▁Baths', '▁shower', '▁bathe', 'BATH']\n",
      "\n",
      "🔹 Feature 2056 (Activation: 4.8179)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁paddle', '▁Paddle', 'Paddle', '▁paddling', '▁kayaking', '▁padd', 'paddle', '▁kayak', '▁canoe', '▁sail']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"swim\", top_k=k)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sae",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
