{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "21340c9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sae\n",
    "import jax\n",
    "import optax\n",
    "import equinox as eqx\n",
    "import jax.numpy as jnp\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "675950ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2-2b\")\n",
    "vocab = tokenizer.get_vocab()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e8eca61",
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_generator(data, batch_size):\n",
    "    num_samples = data.shape[0]\n",
    "    key = jax.random.PRNGKey(0)\n",
    "\n",
    "    while True:\n",
    "        key, subkey = jax.random.split(key)\n",
    "        indices = jax.random.permutation(subkey, num_samples)\n",
    "        \n",
    "        for start in range(0, num_samples, batch_size):\n",
    "            end = min(start + batch_size, num_samples)\n",
    "            batch_indices = indices[start:end]\n",
    "            batch = data[batch_indices]\n",
    "            \n",
    "            # # Normalize the batch\n",
    "            batch = batch - jnp.mean(batch, axis=1, keepdims=True)\n",
    "            norms = jnp.linalg.norm(batch, axis=1, keepdims=True)\n",
    "            batch = batch / (norms + 1e-8)  # Add small epsilon to avoid division by zero\n",
    "            \n",
    "            yield batch\n",
    "\n",
    "def create_data_loader(data, batch_size):\n",
    "    return data_generator(data, batch_size)\n",
    "\n",
    "# Load unembeddings and whiten\n",
    "g = jnp.load(DIR)\n",
    "g = g - g.mean(axis=0)\n",
    "u, s, vt = jnp.linalg.svd(g, full_matrices=False)\n",
    "eps = 1e-6  # Numerical stability to avoid division by zero\n",
    "cov = (g.T @ g) / g.shape[0]\n",
    "eigvals, eigvecs = jnp.linalg.eigh(cov)\n",
    "inv_sqrt_eigvals = 1.0 / jnp.sqrt(eigvals + eps)\n",
    "cov_inv_sqrt = eigvecs @ jnp.diag(inv_sqrt_eigvals) @ eigvecs.T\n",
    "g = g @ cov_inv_sqrt\n",
    "\n",
    "g = g * jnp.sqrt(g.shape[0] / g.shape[1]) # scaling for norm\n",
    "\n",
    "batch_size = 8192\n",
    "train_loader = create_data_loader(g, batch_size)\n",
    "example_batch = next(train_loader)\n",
    "input_dim = example_batch[0].shape[0]\n",
    "latent_dim = 2**12\n",
    "key = jax.random.PRNGKey(0)\n",
    "k = 5\n",
    "model = sae.Autoencoder(latent_dim, input_dim, use_bias=True, key=key, k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9df37db0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(8192, 2304)\n"
     ]
    }
   ],
   "source": [
    "print(next(train_loader).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "36aa2242",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4, Step 100: Loss = 0.7594\n",
      "Epoch 7, Step 200: Loss = 0.7554\n",
      "Epoch 10, Step 300: Loss = 0.7526\n",
      "Epoch 13, Step 400: Loss = 0.7556\n",
      "Epoch 17, Step 500: Loss = 0.7563\n",
      "Epoch 20, Step 600: Loss = 0.7594\n",
      "Epoch 23, Step 700: Loss = 0.7560\n",
      "Epoch 26, Step 800: Loss = 0.7551\n",
      "Epoch 30, Step 900: Loss = 0.7590\n",
      "Epoch 33, Step 1000: Loss = 0.7573\n",
      "Epoch 36, Step 1100: Loss = 0.7522\n",
      "Epoch 39, Step 1200: Loss = 0.7490\n",
      "Epoch 42, Step 1300: Loss = 0.7523\n",
      "Epoch 46, Step 1400: Loss = 0.7561\n",
      "Epoch 49, Step 1500: Loss = 0.7520\n",
      "Epoch 52, Step 1600: Loss = 0.7525\n",
      "Epoch 55, Step 1700: Loss = 0.7535\n",
      "Epoch 59, Step 1800: Loss = 0.7489\n",
      "Epoch 62, Step 1900: Loss = 0.7484\n",
      "Epoch 65, Step 2000: Loss = 0.7502\n",
      "Epoch 68, Step 2100: Loss = 0.7497\n",
      "Epoch 71, Step 2200: Loss = 0.7549\n",
      "Epoch 75, Step 2300: Loss = 0.7526\n",
      "Epoch 78, Step 2400: Loss = 0.7550\n",
      "Epoch 81, Step 2500: Loss = 0.7559\n",
      "Epoch 84, Step 2600: Loss = 0.7509\n",
      "Epoch 88, Step 2700: Loss = 0.7481\n",
      "Epoch 91, Step 2800: Loss = 0.7518\n",
      "Epoch 94, Step 2900: Loss = 0.7524\n",
      "Epoch 97, Step 3000: Loss = 0.7527\n",
      "Epoch 100, Step 3100: Loss = 0.7539\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 100\n",
    "learning_rate = 3e-4 # kaparthy constant\n",
    "\n",
    "optimizer = optax.adam(learning_rate)\n",
    "\n",
    "opt_state = optimizer.init(eqx.filter(model, eqx.is_array))\n",
    "\n",
    "step = 0\n",
    "steps_per_epoch =  g.shape[0] // batch_size\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for step_in_epoch in range(steps_per_epoch):\n",
    "        batch = next(train_loader)\n",
    "        model, opt_state, loss = sae.train_step(model, batch, opt_state, optimizer)\n",
    "        step += 1\n",
    "        if step % 100 == 0:\n",
    "            print(f\"Epoch {epoch + 1}, Step {step}: Loss = {loss.item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ff439933",
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab_list = [None] * (max(vocab.values()) + 1)\n",
    "for word, index in vocab.items():\n",
    "    vocab_list[index] = word"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "39b7694e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building SAE Activation Cache: 100%|██████████| 32/32 [00:06<00:00,  5.00it/s]\n"
     ]
    }
   ],
   "source": [
    "import tqdm\n",
    "\n",
    "class SAEActivationCache:\n",
    "    def __init__(self, vocab_size, latent_dim, max_examples_per_latent=10):\n",
    "        self.vocab_size = vocab_size\n",
    "        self.latent_dim = latent_dim\n",
    "        self.max_examples_per_latent = max_examples_per_latent\n",
    "\n",
    "        self.activations = np.full((latent_dim, max_examples_per_latent), -np.inf, dtype=np.float32)\n",
    "        self.words = np.full((latent_dim, max_examples_per_latent), -1, dtype=np.int32)\n",
    "\n",
    "    def update(self, batch_words, latents):\n",
    "        batch_size = latents.shape[0]\n",
    "        for i in range(batch_size):\n",
    "            word_idx = int(batch_words[i])\n",
    "            top_latent = int(np.argmax(latents[i]))\n",
    "            activation = float(latents[i, top_latent])\n",
    "\n",
    "            current_acts = self.activations[top_latent]\n",
    "            min_idx = np.argmin(current_acts)\n",
    "            if activation > current_acts[min_idx]:\n",
    "                self.activations[top_latent, min_idx] = activation\n",
    "                self.words[top_latent, min_idx] = word_idx\n",
    "\n",
    "def build_sae_activation_cache(dataset, model, batch_size, vocab_embeddings, vocab_indices, latent_dim, max_examples_per_latent=10):\n",
    "    activation_cache = SAEActivationCache(\n",
    "        vocab_size=len(vocab_list),\n",
    "        latent_dim=latent_dim,\n",
    "        max_examples_per_latent=max_examples_per_latent\n",
    "    )\n",
    "\n",
    "    num_batches = int(np.ceil(len(dataset) / batch_size))\n",
    "\n",
    "    for batch_start in tqdm.tqdm(range(0, len(dataset), batch_size), total=num_batches, desc=\"Building SAE Activation Cache\"):\n",
    "        batch_end = min(batch_start + batch_size, len(dataset))\n",
    "        batch = vocab_embeddings[vocab_indices[batch_start:batch_end]]\n",
    "\n",
    "        latents = jax.vmap(model.top_k_encode)(batch)\n",
    "        latents = np.array(latents)\n",
    "        batch_words = np.array(vocab_indices[batch_start:batch_end])\n",
    "\n",
    "        activation_cache.update(batch_words=batch_words, latents=latents)\n",
    "\n",
    "    return activation_cache\n",
    "\n",
    "activation_cache = build_sae_activation_cache(\n",
    "    dataset=g,\n",
    "    model=model,\n",
    "    batch_size=batch_size,\n",
    "    vocab_embeddings=g,\n",
    "    vocab_indices=np.arange(g.shape[0]),\n",
    "    latent_dim=latent_dim,\n",
    "    max_examples_per_latent=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9e97d9a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def explain_word(word, top_k=k):\n",
    "    word_vec = g[vocab[word]]\n",
    "    latents = np.array(model.top_k_encode(word_vec))\n",
    "\n",
    "    print(f\"\\nExplanations for word: '{word.strip()}'\\n{'='*50}\")\n",
    "\n",
    "    top_latents = np.argsort(-latents)[:top_k]\n",
    "    for idx in top_latents:\n",
    "        activation = latents[idx]\n",
    "        print(f\"\\n🔹 Feature {idx} (Activation: {activation:.4f})\")\n",
    "        print(\"  Words that maximally activate this Feature:\")\n",
    "\n",
    "        top_examples = np.argsort(-activation_cache.activations[idx])\n",
    "        top_words = [vocab_list[int(w)] for w in activation_cache.words[idx, top_examples] if w >= 0]\n",
    "        print(f\"   {top_words}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "d5e4bea3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'puppy'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 2980 (Activation: 161.4889)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁puppy', '▁puppies', '▁Puppy', '▁Paws', '▁kitten', '▁kittens', 'Puppy', '▁Puppies', 'puppy', '▁Kennel']\n",
      "\n",
      "🔹 Feature 1149 (Activation: 137.3440)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['Pup', '▁pup', 'pup', '▁Pup', '▁PUP', '▁puppet', '▁pupil', '▁Pupil', 'puppet', '▁Puppet']\n",
      "\n",
      "🔹 Feature 1577 (Activation: 101.8598)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁dog', 'dog', '▁Dog', 'Dog', '▁dogs', '▁DOG', '▁Dogs', 'dogs', 'DOG', 'Dogs']\n",
      "\n",
      "🔹 Feature 904 (Activation: 87.4508)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁baby', '▁Baby', 'baby', 'Baby', '▁BABY', '▁babies', 'BABY', '▁Babies', '▁infant', '▁bébé']\n",
      "\n",
      "🔹 Feature 3103 (Activation: 51.7833)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁childhood', '▁Childhood', 'Childhood', '▁pediatric', '▁infancia', '▁enfance', '▁kindergarten', '▁infantil', '▁Kindheit', '▁Pediatric']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"puppy\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "aa2cf007",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'Queen'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 418 (Activation: 343.9613)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Queen', 'Queen', '▁queen', '▁QUEEN', 'queen', '▁queens', 'QUEEN', '▁Queens', 'queens', 'Queens']\n",
      "\n",
      "🔹 Feature 2669 (Activation: 210.6978)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁King', '▁king', 'King', '▁KING', '▁kings', '▁Kings', '▁royal', 'Kings', 'KING', 'kings']\n",
      "\n",
      "🔹 Feature 3236 (Activation: 84.3711)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Q', 'Q', '▁q', 'q', 'Ｑ', 'Qs', 'QS', 'qs', 'QL', 'QN']\n",
      "\n",
      "🔹 Feature 2995 (Activation: 63.6071)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁mother', '▁Mother', 'Mother', 'mother', '▁MOTHER', 'MOTHER', '▁mothers', '▁madre', '▁Mothers', '▁moeder']\n",
      "\n",
      "🔹 Feature 2929 (Activation: 61.0109)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁girl', '▁girls', '▁Girls', 'girls', 'Girl', 'Girls', '▁GIRLS', 'GIRLS', '▁Mädchen', '▁meisje']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"Queen\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "894e334d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'Chicago'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 623 (Activation: 218.2256)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Detroit', '▁Chicago', 'Chicago', 'Detroit', '▁Milwaukee', '▁Cincinnati', 'Milwaukee', 'Cincinnati', '▁Pittsburgh', '▁Minneapolis']\n",
      "\n",
      "🔹 Feature 3963 (Activation: 144.1507)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁London', 'London', '▁Paris', '▁Tokyo', '▁Berlin', '▁Madrid', 'Paris', '▁Beijing', '▁Delhi', '▁Jakarta']\n",
      "\n",
      "🔹 Feature 4059 (Activation: 71.2980)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['il', 'IL', 'ils', '▁IL', '▁il', 'Il', '▁Il', 'ILS', 'ili', 'ile']\n",
      "\n",
      "🔹 Feature 1442 (Activation: 56.6519)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁American', 'American', '▁America', '▁AMERICAN', '▁american', 'american', '▁Americans', 'America', 'AMERICAN', '▁AMERICA']\n",
      "\n",
      "🔹 Feature 1460 (Activation: 44.4602)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁NY', 'NY', '▁NYC', 'NYC', '▁Ny', '▁ny', '▁NYS', '▁NYPD', 'nyc', 'Ny']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"Chicago\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d1b0db31",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'London'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 3963 (Activation: 227.1938)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁London', 'London', '▁Paris', '▁Tokyo', '▁Berlin', '▁Madrid', 'Paris', '▁Beijing', '▁Delhi', '▁Jakarta']\n",
      "\n",
      "🔹 Feature 2096 (Activation: 189.4191)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁British', 'British', '▁Britain', '▁UK', '▁BRITISH', 'Britain', '▁british', 'UK', '▁brit', 'british']\n",
      "\n",
      "🔹 Feature 1009 (Activation: 122.3449)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Lon', 'Lon', 'LON', '▁LON', 'lon', '▁Longitudinal', '▁lon', '▁longitudinal', '▁Longitud', '▁Lone']\n",
      "\n",
      "🔹 Feature 3464 (Activation: 85.4231)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Nottingham', '▁Leicester', '▁Bristol', '▁Manchester', '▁Essex', '▁Southampton', '▁Sheffield', '▁Coventry', '▁Leeds', 'Bristol']\n",
      "\n",
      "🔹 Feature 1008 (Activation: 49.0102)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Italy', '▁Spain', '▁France', '▁Germany', 'Italy', '▁India', 'France', 'Germany', 'Spain', '▁Poland']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"London\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1a33b9d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'Twitter'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 582 (Activation: 275.5387)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁tweet', '▁tweets', '▁Tweet', '▁tweeting', '▁tweeted', '▁Tweets', 'Tweet', 'tweet', '▁Twitter', 'Tweets']\n",
      "\n",
      "🔹 Feature 2949 (Activation: 204.4818)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Facebook', 'Facebook', '▁Instagram', 'Instagram', '▁facebook', '▁YouTube', '▁LinkedIn', '▁Pinterest', '▁instagram', '▁Youtube']\n",
      "\n",
      "🔹 Feature 2389 (Activation: 57.7809)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁blog', '▁Blog', 'Blog', 'blog', '▁blogs', '▁BLOG', 'BLOG', '▁Blogs', '▁bloggers', 'blogs']\n",
      "\n",
      "🔹 Feature 2966 (Activation: 55.9109)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁social', 'social', 'Social', '▁Social', '▁SOCIAL', 'SOCIAL', '▁sociale', '▁sosial', '▁sociales', '▁sozialen']\n",
      "\n",
      "🔹 Feature 754 (Activation: 52.6369)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁@', '@', '(@', '=@', ',@', '>@', \"▁'@\", '.@', '-@', '\">@']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"Twitter\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "93a335c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'python'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 2738 (Activation: 198.7793)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Python', '▁Java', '▁python', 'Python', 'Java', 'python', 'java', '▁java', '▁JAVA', '▁PHP']\n",
      "\n",
      "🔹 Feature 646 (Activation: 116.3763)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Pi', 'pi', 'Pi', '▁pi', 'Py', '▁Py', '▁py', 'PI', 'py', '▁PY']\n",
      "\n",
      "🔹 Feature 505 (Activation: 76.6764)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁lizard', '▁Lizard', '▁turtle', 'Lizard', '▁frog', '▁Frog', '▁Turtle', '▁reptile', '▁alligator', 'Frog']\n",
      "\n",
      "🔹 Feature 993 (Activation: 67.0827)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁script', '▁Script', '▁scripts', 'script', 'Script', '▁Scripts', '▁SCRIPT', 'SCRIPT', 'cript', 'Scripts']\n",
      "\n",
      "🔹 Feature 1821 (Activation: 57.3671)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['p', 'pt', 'pd', 'pf', 'pto', 'pg', 'pte', 'pk', 'pj', 'pb']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"python\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "712d35d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'Bayesian'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 2570 (Activation: 150.5196)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Bay', '▁bay', 'Bay', '▁BAY', 'bay', 'BAY', '▁bays', '▁Bays', '▁bahía', '▁Bayer']\n",
      "\n",
      "🔹 Feature 3022 (Activation: 81.2093)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁linear', '▁Linear', 'linear', 'Linear', '▁LINEAR', '▁liné', '线性', '▁linearly', 'LINEAR', '▁nonlinear']\n",
      "\n",
      "🔹 Feature 2298 (Activation: 69.7305)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁probability', '▁Probability', 'probability', 'Probability', '▁probabilities', '▁likelihood', '▁Probab', '▁probab', 'probab', '▁probabilidad']\n",
      "\n",
      "🔹 Feature 383 (Activation: 54.9099)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁miR', 'miR', '▁miRNA', 'TGF', '▁miRNAs', '▁VEGF', '▁TGF', 'VEGF', 'RNAs', 'poptotic']\n",
      "\n",
      "🔹 Feature 3341 (Activation: 54.8417)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Muslim', '▁Catholic', 'Muslim', 'Catholic', '▁Islamic', '▁Protestant', '▁muslim', '▁Muslims', '▁Catholics', '▁Hindu']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"Bayesian\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b1a0e3b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'SUV'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 2474 (Activation: 122.5702)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Ford', '▁Toyota', '▁Chevrolet', '▁Nissan', '▁Volkswagen', 'Ford', '▁Porsche', '▁Jeep', '▁Mercedes', '▁Chevy']\n",
      "\n",
      "🔹 Feature 4021 (Activation: 106.8759)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['Su', '▁Su', '▁su', 'su', '▁SU', 'SU', '▁Suk', '▁Су', 'Су', '▁Suz']\n",
      "\n",
      "🔹 Feature 3127 (Activation: 57.7977)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁vehicle', '▁vehicles', 'vehicle', '▁Vehicle', '▁Vehicles', 'Vehicle', '▁VEHICLE', 'vehicles', 'Vehicles', '▁VEH']\n",
      "\n",
      "🔹 Feature 25 (Activation: 57.1076)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['HV', '▁pv', 'pv', '▁LV', 'LV', 'PV', 'MV', 'bv', '▁MV', 'BV']\n",
      "\n",
      "🔹 Feature 3583 (Activation: 54.2649)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁drone', '▁Drone', '▁drones', 'Drone', 'drone', '▁Drones', '▁UAV', '▁helicopter', '▁dron', '▁Helicopter']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"SUV\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "e8a18ec9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'MIT'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 1106 (Activation: 308.0102)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['mit', '▁Mit', 'Mit', '▁mit', 'MIT', '▁MIT', 'mitt', 'mits', '▁mits', '▁Mito']\n",
      "\n",
      "🔹 Feature 1729 (Activation: 83.6036)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Mi', 'Mi', '▁Mis', 'Mis', 'mis', '▁mis', '▁mi', 'mi', '▁MIS', 'MIS']\n",
      "\n",
      "🔹 Feature 100 (Activation: 66.2066)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁university', '▁University', 'University', '▁UNIVERSITY', '▁college', 'university', '▁College', 'College', '▁universities', 'college']\n",
      "\n",
      "🔹 Feature 1179 (Activation: 61.7111)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁USDA', '▁FDA', 'USDA', '▁UNESCO', '▁NASA', 'UNESCO', '▁NOAA', 'NASA', '▁OSHA', '▁USGS']\n",
      "\n",
      "🔹 Feature 640 (Activation: 52.8986)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁M', 'M', '▁М', '▁m', 'getM', 'Ｍ', '▁getM', 'М', '▁м', '▁Ml']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"MIT\", top_k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "4a1f5716",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Explanations for word: 'swim'\n",
      "==================================================\n",
      "\n",
      "🔹 Feature 1118 (Activation: 337.5678)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁swim', '▁swimming', '▁Swim', 'swim', 'Swim', '▁swims', '▁Swimming', 'swimming', '▁swam', 'Swimming']\n",
      "\n",
      "🔹 Feature 3638 (Activation: 130.1786)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['s', 'sho', 'som', 'sti', 'sz', 'ska', 'sin', 'sis', 'shi', 'sf']\n",
      "\n",
      "🔹 Feature 1297 (Activation: 86.6638)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁Sw', 'Sw', '▁sw', 'sw', '▁SW', 'SW', 'Swi', '▁Swat', '▁Swing', '▁swing']\n",
      "\n",
      "🔹 Feature 2383 (Activation: 56.6680)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁dive', '▁diving', '▁Dive', 'dive', '▁Diving', 'Dive', '▁dives', 'Diving', 'diving', '▁dived']\n",
      "\n",
      "🔹 Feature 1713 (Activation: 56.2625)\n",
      "  Words that maximally activate this Feature:\n",
      "   ['▁bath', '▁shower', '▁baths', '▁Shower', '▁BATH', '▁Bath', 'bath', 'shower', 'Shower', 'Bath']\n"
     ]
    }
   ],
   "source": [
    "explain_word(\"swim\", top_k=k)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sae",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
