{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Clean Gemma Unembeddings\n",
    "\n",
    "Gemma vocab includes a lot of junk. Here's a particular scheme for tossing the junk and whitening the embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "print(\"Devices:\", jax.devices())\n",
    "import jax.numpy as jnp\n",
    "import json\n",
    "from transformers import AutoTokenizer\n",
    "from jax.experimental import mesh_utils\n",
    "from jax.sharding import PositionalSharding\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2-2b\")\n",
    "vocab = tokenizer.get_vocab()\n",
    "\n",
    "num_devices = len(jax.devices())\n",
    "sharding = PositionalSharding(mesh_utils.create_device_mesh((num_devices,1)))\n",
    "import numpy as np\n",
    "out_dir = \"\" # where to store the unembeddings and such?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = jax.device_put(np.load('path/to/gemma_unembedding_matrix/raw_unembeddings.npy'), sharding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import treescope\n",
    "treescope.basic_interactive_setup(autovisualize_arrays=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = g - g.mean(axis=0)\n",
    "u, s, vt = jnp.linalg.svd(g, full_matrices=False)\n",
    "g = u @ vt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sentencepiece as spm\n",
    "vocab = spm.SentencePieceProcessor()\n",
    "vocab.Load(\"path/to/gemma_2_tokenizer/tokenizer.model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dog_idx = jax.lax.top_k(g @ g[vocab.EncodeAsIds(' dog')[0]], 20)[1]\n",
    "vocab.DecodeIds(dog_idx.tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# heuristic to filter junk words out of the vocab\n",
    "norms = jnp.linalg.norm(g, axis=1)\n",
    "acceptable_vocab = jnp.where((norms < 0.11558999) & (norms > 0.07008683))\n",
    "g = g[acceptable_vocab]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# try reindexing\n",
    "vocab_dict = {}\n",
    "for new_idx, orig_idx in enumerate(acceptable_vocab[0].tolist()):\n",
    "  vocab_dict[vocab.DecodeIds([orig_idx])] = new_idx\n",
    "\n",
    "vocab_list = [None] * (max(vocab_dict.values()) + 1)\n",
    "for word, index in vocab_dict.items():\n",
    "    vocab_list[index] = word\n",
    "\n",
    "dog_idx = jax.lax.top_k(g @ g[vocab_dict[' dog']], 20)[1]\n",
    "print([vocab_list[idx] for idx in dog_idx.tolist()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# whiten reindexed\n",
    "g = g - g.mean(axis=0)\n",
    "u, s, vt = jnp.linalg.svd(g, full_matrices=False)\n",
    "g = u @ vt\n",
    "\n",
    "# check that rewhitening doesn't break anything\n",
    "dog_idx = jax.lax.top_k(g @ g[vocab_dict[' dog']], 20)[1]\n",
    "print([vocab_list[idx] for idx in dog_idx.tolist()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "jnp.save(f'{out_dir}/clean_unembeddings.npy', g)\n",
    "jnp.save(f'{out_dir}/clean_unembeddings_indices.npy', acceptable_vocab)\n",
    "with open(f'{out_dir}clean_vocab_dict.json', 'w') as fout:\n",
    "  json.dump(vocab_dict, fout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
