{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "497cdda2-e525-4ad9-8303-0790e44c2feb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "\n",
    "import equinox as eqx\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "import optax\n",
    "from openai import OpenAI\n",
    "from tqdm.auto import tqdm\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "import sae\n",
    "# Number of SAE features listed per word in each explanation (passed as `top_k`).\n",
    "k = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98c95297-a26b-4ea5-8a93-731efa6d5a81",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _normalize_rows(x, eps=1e-8):\n",
    "    \"\"\"Center each row across its features and rescale it to unit L2 norm.\n",
    "\n",
    "    This is the preprocessing the SAE is trained on, so it must also be applied\n",
    "    to every vector encoded at inference time. `eps` avoids division by zero\n",
    "    for constant rows.\n",
    "    \"\"\"\n",
    "    x = x - jnp.mean(x, axis=1, keepdims=True)\n",
    "    return x / (jnp.linalg.norm(x, axis=1, keepdims=True) + eps)\n",
    "\n",
    "\n",
    "def _whiten(g, eps=1e-6):\n",
    "    \"\"\"Mean-center the rows of `g` and ZCA-whiten them with their empirical covariance.\"\"\"\n",
    "    g = g - g.mean(axis=0)\n",
    "    cov = (g.T @ g) / g.shape[0]\n",
    "    eigvals, eigvecs = jnp.linalg.eigh(cov)\n",
    "    # Clamp tiny negative eigenvalues from round-off; eps keeps the inverse sqrt finite.\n",
    "    inv_sqrt_eigvals = 1.0 / jnp.sqrt(jnp.maximum(eigvals, 0.0) + eps)\n",
    "    return g @ (eigvecs @ jnp.diag(inv_sqrt_eigvals) @ eigvecs.T)\n",
    "\n",
    "\n",
    "def _batch_generator(data, batch_size, key):\n",
    "    \"\"\"Yield shuffled, equally sized batches of rows of `data` forever.\n",
    "\n",
    "    Rows are reshuffled on every pass. The trailing partial batch is dropped so\n",
    "    one pass is exactly `len(data) // batch_size` steps and every batch has the\n",
    "    same shape (a jitted train step would otherwise be re-traced for it).\n",
    "    \"\"\"\n",
    "    num_samples = data.shape[0]\n",
    "    if not 0 < batch_size <= num_samples:\n",
    "        raise ValueError(f\"batch_size must be in [1, {num_samples}], got {batch_size}\")\n",
    "    while True:\n",
    "        key, subkey = jax.random.split(key)\n",
    "        indices = jax.random.permutation(subkey, num_samples)\n",
    "        for start in range(0, num_samples - batch_size + 1, batch_size):\n",
    "            yield data[indices[start:start + batch_size]]\n",
    "\n",
    "\n",
    "class SAEActivationCache:\n",
    "    \"\"\"For each latent, the words whose strongest latent it is, with the largest activations.\n",
    "\n",
    "    `activations[l]` / `words[l]` hold up to `max_examples_per_latent` unsorted\n",
    "    entries for latent `l`; empty slots are marked with -inf / -1.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, latent_dim, max_examples_per_latent=10):\n",
    "        self.vocab_size = vocab_size\n",
    "        self.latent_dim = latent_dim\n",
    "        self.max_examples_per_latent = max_examples_per_latent\n",
    "        self.activations = np.full((latent_dim, max_examples_per_latent), -np.inf, dtype=np.float32)\n",
    "        self.words = np.full((latent_dim, max_examples_per_latent), -1, dtype=np.int32)\n",
    "\n",
    "    def update(self, batch_words, latents):\n",
    "        \"\"\"Offer each row's arg-max latent; keep it if it beats that latent's weakest slot.\n",
    "\n",
    "        batch_words: (batch,) word ids; latents: (batch, latent_dim) encodings.\n",
    "        \"\"\"\n",
    "        top_latents = np.argmax(latents, axis=1)\n",
    "        top_acts = latents[np.arange(latents.shape[0]), top_latents]\n",
    "        for word_idx, latent, activation in zip(batch_words, top_latents, top_acts):\n",
    "            slot = np.argmin(self.activations[latent])\n",
    "            if activation > self.activations[latent, slot]:\n",
    "                self.activations[latent, slot] = activation\n",
    "                self.words[latent, slot] = word_idx\n",
    "\n",
    "\n",
    "def _build_activation_cache(model, embeddings, vocab_size, latent_dim, batch_size, max_examples_per_latent=10):\n",
    "    \"\"\"Encode every row of `embeddings` with the SAE and record it in a SAEActivationCache.\n",
    "\n",
    "    Row i of `embeddings` is treated as the vector of word id i.\n",
    "    \"\"\"\n",
    "    cache = SAEActivationCache(vocab_size, latent_dim, max_examples_per_latent)\n",
    "    encode = jax.vmap(model.top_k_encode)\n",
    "    num_rows = embeddings.shape[0]\n",
    "    for start in tqdm(range(0, num_rows, batch_size), desc=\"Building SAE Activation Cache\"):\n",
    "        end = min(start + batch_size, num_rows)\n",
    "        latents = np.asarray(encode(embeddings[start:end]))\n",
    "        cache.update(batch_words=np.arange(start, end), latents=latents)\n",
    "    return cache\n",
    "\n",
    "\n",
    "def prepare_sae(whitening=True, data_path='your/path/here', latent_dim=2**12, k=5,\n",
    "                batch_size=8192, num_epochs=100, learning_rate=3e-4, seed=0):\n",
    "    \"\"\"Train a top-k SAE on token vectors loaded from `data_path` and return a word explainer.\n",
    "\n",
    "    Words are mapped to rows via the google/gemma-2-2b tokenizer vocabulary.\n",
    "\n",
    "    Args:\n",
    "        whitening: ZCA-whiten the vectors before training if True.\n",
    "        data_path: file read with `jnp.load`; row i must be the vector of token id i\n",
    "            (presumably the unembedding matrix - TODO confirm).\n",
    "        latent_dim: number of SAE latents.\n",
    "        k: `k` passed to `sae.Autoencoder` (presumably the number of active latents);\n",
    "            also the default `top_k` of the returned explainer.\n",
    "        batch_size: training/encoding batch size (clamped to the number of rows).\n",
    "        num_epochs: passes over all rows during training.\n",
    "        learning_rate: Adam learning rate.\n",
    "        seed: seeds model initialization and data shuffling (with separate keys).\n",
    "\n",
    "    Returns:\n",
    "        explain_word(word, top_k=k) -> str listing the `top_k` most active latents\n",
    "        for `word` and, for each, the vocabulary items that activate it most.\n",
    "        explain_word raises KeyError if `word` is not in the tokenizer vocabulary.\n",
    "    \"\"\"\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2-2b\")\n",
    "    vocab = tokenizer.get_vocab()\n",
    "    vocab_list = [None] * (max(vocab.values()) + 1)\n",
    "    for token, index in vocab.items():\n",
    "        vocab_list[index] = token\n",
    "\n",
    "    g = jnp.load(data_path)\n",
    "    if whitening:\n",
    "        g = _whiten(g)\n",
    "    # Normalize once: the SAE is trained on these rows, and the activation cache\n",
    "    # and explain_word below must encode exactly the same preprocessed vectors.\n",
    "    g = _normalize_rows(g)\n",
    "    num_samples, input_dim = g.shape\n",
    "    batch_size = min(batch_size, num_samples)\n",
    "    steps_per_epoch = num_samples // batch_size\n",
    "\n",
    "    data_key, model_key = jax.random.split(jax.random.PRNGKey(seed))\n",
    "    train_loader = _batch_generator(g, batch_size, data_key)\n",
    "    model = sae.Autoencoder(latent_dim, input_dim, use_bias=True, key=model_key, k=k)\n",
    "\n",
    "    optimizer = optax.adam(learning_rate)\n",
    "    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))\n",
    "\n",
    "    step = 0\n",
    "    for epoch in range(num_epochs):\n",
    "        for _ in range(steps_per_epoch):\n",
    "            batch = next(train_loader)\n",
    "            model, opt_state, loss = sae.train_step(model, batch, opt_state, optimizer)\n",
    "            step += 1\n",
    "            if step % 100 == 0:\n",
    "                print(f\"Epoch {epoch + 1}, Step {step}: Loss = {loss.item():.4f}\")\n",
    "\n",
    "    activation_cache = _build_activation_cache(\n",
    "        model, g, vocab_size=len(vocab_list), latent_dim=latent_dim, batch_size=batch_size)\n",
    "\n",
    "    def explain_word(word, top_k=k):\n",
    "        \"\"\"Describe `word` by its `top_k` most active latents and their top-activating words.\"\"\"\n",
    "        if word not in vocab:\n",
    "            raise KeyError(f\"{word!r} is not a token in the tokenizer vocabulary\")\n",
    "        latents = np.asarray(model.top_k_encode(g[vocab[word]]))\n",
    "\n",
    "        ret_string = f\"Explanations for word: '{word.strip()}'\\n{'='*50} \\n\"\n",
    "        for idx in np.argsort(-latents)[:top_k]:\n",
    "            ret_string += f\"\\n🔹 Feature {idx} \\n\"\n",
    "            ret_string += \"  Words that maximally activate this Feature: \\n\"\n",
    "            order = np.argsort(-activation_cache.activations[idx])\n",
    "            top_words = [vocab_list[int(w)] for w in activation_cache.words[idx, order] if w >= 0]\n",
    "            ret_string += f\"   {top_words} \\n\"\n",
    "        return ret_string\n",
    "\n",
    "    return explain_word"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76f63b67-88da-4d93-ae4c-7ef5051e5801",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train one SAE per preprocessing variant: whitened first, then raw.\n",
    "explain_word_whitening = prepare_sae(whitening=True)\n",
    "explain_word = prepare_sae(whitening=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "effda4dd-54cf-4ec0-ac40-13a0e93c245e",
   "metadata": {},
   "outputs": [],
   "source": [
    "word = 'puppy'\n",
    "# A: SAE trained on the raw vectors; B: SAE trained on the whitened vectors.\n",
    "A_string = explain_word(word, top_k=k)\n",
    "B_string = explain_word_whitening(word, top_k=k)\n",
    "print(A_string, B_string, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a81fcb0c-ba91-4b77-a64a-b26a79b5ba05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The 100 common English words the two SAEs are judged on (10 per line).\n",
    "words = \"\"\"\n",
    "apple book chair table car house tree dog cat water\n",
    "phone door window bed pen paper shoe bag road city\n",
    "park school bus train light clock street friend family hand\n",
    "foot head face smile shop money key room work play\n",
    "sun rain sky drink game child parent food music computer\n",
    "river mountain ocean bookstore library garden flower bird fish bread\n",
    "milk egg cheese salt sugar coffee tea plate cup fork\n",
    "knife spoon puppy sandwich pizza burger fruit grape orange banana\n",
    "lemon strawberry cherry peach plum melon cookie cake chocolate ice\n",
    "snow wind cloud star moon earth map flag ball toy\n",
    "\"\"\".split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "935c6dcc-f5ad-44c9-80ad-83324bc9bbd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LLM judge (GPT-4o). The API key is read from the OPENAI_API_KEY environment\n",
    "# variable instead of being hard-coded in the notebook.\n",
    "client = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n",
    "\n",
    "\n",
    "def judge_responses(word, response_a, response_b, model=\"gpt-4o\", temperature=0.0):\n",
    "    \"\"\"Ask an LLM judge which of two explanations better captures the meaning of `word`.\n",
    "\n",
    "    Args:\n",
    "        word: the target word.\n",
    "        response_a: explanation shown to the judge as \"Explanation A\".\n",
    "        response_b: explanation shown to the judge as \"Explanation B\".\n",
    "        model: chat-completions model used as the judge.\n",
    "        temperature: sampling temperature; 0 keeps verdicts as repeatable as the API allows.\n",
    "\n",
    "    Returns:\n",
    "        The judge's raw reply, which the prompt asks to contain '[[A]]', '[[B]]' or\n",
    "        '[[C]]' (tie); an empty string if the reply has no text content.\n",
    "    \"\"\"\n",
    "    prompt = (\n",
    "        \"Please act as an impartial judge and evaluate which set of features best captures \"\n",
    "        \"the meaning of the target word. \"\n",
    "        \"Give a short verdict and provide a short explanation for the decision, \"\n",
    "        \"following this format: '[[A]]' if explanation A is substantially better, \"\n",
    "        \"'[[B]]' if explanation B is substantially better, and '[[C]]' when there's a tie. \\n\"\n",
    "        f\"[Target word] {word} \\n\"\n",
    "        f\"[Explanation A] {response_a} \\n [End of Explanation A]\\n\\n\"\n",
    "        f\"[Explanation B] {response_b} \\n [End of Explanation B]\\n\\n\"\n",
    "    )\n",
    "    response = client.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=[{\"role\": \"user\", \"content\": prompt}],\n",
    "        max_tokens=256,\n",
    "        temperature=temperature,\n",
    "    )\n",
    "    # `message.content` is Optional in the OpenAI SDK; normalize None to \"\".\n",
    "    return response.choices[0].message.content or \"\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4750a10a-6cca-423b-87aa-266a6cd5f339",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Judge the two SAEs word by word. Which method is shown as 'A' is randomized\n",
    "# per word (with a fixed seed) so the judge's position bias cannot systematically\n",
    "# favour one method.\n",
    "rng = random.Random(0)\n",
    "results = []\n",
    "skipped_words = []\n",
    "win_counts = {\"whiten\": 0, \"non-whiten\": 0, \"tie\": 0}\n",
    "\n",
    "for word in tqdm(words, desc=\"Judging\"):\n",
    "    try:\n",
    "        explanations = {\n",
    "            \"whiten\": explain_word_whitening(word, top_k=k),\n",
    "            \"non-whiten\": explain_word(word, top_k=k),\n",
    "        }\n",
    "    except KeyError:\n",
    "        # The explainers look `word` up in the tokenizer vocabulary; skip unknown\n",
    "        # words instead of aborting the whole run.\n",
    "        skipped_words.append(word)\n",
    "        continue\n",
    "\n",
    "    order = [\"whiten\", \"non-whiten\"]\n",
    "    rng.shuffle(order)\n",
    "    my_keys = {\"a\": order[0], \"b\": order[1]}\n",
    "\n",
    "    # Guard against an empty/None reply, which then counts as a tie.\n",
    "    judgement = judge_responses(word, explanations[my_keys[\"a\"]], explanations[my_keys[\"b\"]]) or \"\"\n",
    "    if \"[[A]]\" in judgement:\n",
    "        winner = (\"a\", my_keys[\"a\"])\n",
    "    elif \"[[B]]\" in judgement:\n",
    "        winner = (\"b\", my_keys[\"b\"])\n",
    "    else:\n",
    "        winner = (\"tie\", \"tie\")\n",
    "    win_counts[winner[1]] += 1\n",
    "\n",
    "    results.append(\n",
    "        {\n",
    "            \"word\": word,\n",
    "            \"response_whiten\": explanations[\"whiten\"],\n",
    "            \"response_non_whiten\": explanations[\"non-whiten\"],\n",
    "            \"presentation_order\": my_keys,\n",
    "            \"judgement\": judgement.strip(),\n",
    "            \"winner\": winner,\n",
    "        }\n",
    "    )\n",
    "\n",
    "if skipped_words:\n",
    "    print(f\"Skipped {len(skipped_words)} word(s) not in the vocabulary: {skipped_words}\")\n",
    "\n",
    "# Calculate win rates\n",
    "total = sum(win_counts.values())\n",
    "whiten_win_rate = (win_counts[\"whiten\"] / total) * 100 if total > 0 else 0\n",
    "non_whiten_win_rate = (win_counts[\"non-whiten\"] / total) * 100 if total > 0 else 0\n",
    "tie_rate = (win_counts[\"tie\"] / total) * 100 if total > 0 else 0\n",
    "\n",
    "print(f\"Win Rate for Whiten Method: {whiten_win_rate:.2f}%\")\n",
    "print(f\"Win Rate for Non Whiten Method: {non_whiten_win_rate:.2f}%\")\n",
    "print(f\"Tie Rate: {tie_rate:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1490826-dcc7-4078-a9b3-c50c3b1ca0be",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "007b51bc-3b0e-48b1-afe8-c5f9ec6a523c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9369b41e-fb58-41da-be8b-aeb7daceb15b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
