{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6f0374ad",
   "metadata": {},
   "source": [
    "Baseline: Yes L1, Yes Orthog"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ad7b442-a158-46ba-8181-251c1fbdfa43",
   "metadata": {},
   "outputs": [],
   "source": [
    "from moe_eqx import MixtureOfExperts_v2, train_step, mask_codes\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import optax\n",
    "import equinox as eqx\n",
    "unembeds_dir = \"\" # where do you have the downloaded or generated vocab and unembeddings saved?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7d02913",
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_generator(data, batch_size):\n",
    "    num_samples = data.shape[0]\n",
    "    key = jax.random.PRNGKey(0)\n",
    "\n",
    "    while True:\n",
    "        key, subkey = jax.random.split(key)\n",
    "        indices = jax.random.permutation(subkey, num_samples)\n",
    "        \n",
    "        for start in range(0, num_samples, batch_size):\n",
    "            end = min(start + batch_size, num_samples)\n",
    "            batch_indices = indices[start:end]\n",
    "            batch = data[batch_indices]\n",
    "            \n",
    "            # # Normalize the batch\n",
    "            batch = batch - jnp.mean(batch, axis=1, keepdims=True)\n",
    "            norms = jnp.linalg.norm(batch, axis=1, keepdims=True)\n",
    "            batch = batch / (norms + 1e-8)  # Add small epsilon to avoid division by zero\n",
    "            \n",
    "            yield batch\n",
    "\n",
    "def create_data_loader(data, batch_size):\n",
    "    return data_generator(data, batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8502337-d77b-4f21-b43a-9ce7e2bacff1",
   "metadata": {},
   "outputs": [],
   "source": [
    "g = jnp.load(f'{unembeds_dir}/clean_gemma_embeddings.npy')\n",
    "g = g * jnp.sqrt(g.shape[0] / g.shape[1]) # set the norms to be close to 1\n",
    "\n",
    "\n",
    "# Set up the data loader\n",
    "# batch_size = jax.local_device_count() * 4096\n",
    "batch_size = 8192\n",
    "train_loader = create_data_loader(g, batch_size)\n",
    "example_batch = next(train_loader)\n",
    "input_dim = example_batch[0].shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef13c185",
   "metadata": {},
   "outputs": [],
   "source": [
    "example_batch.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab6cdf1b-a096-492d-82ec-574b10bd8cd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_dim = example_batch[0].shape[0]\n",
    "subspace_dim = 5\n",
    "num_experts = 2**11\n",
    "atoms_per_subspace = int(2**5)\n",
    "k = 5\n",
    "\n",
    "key = jax.random.PRNGKey(0)\n",
    "\n",
    "model = MixtureOfExperts_v2(\n",
    "    input_dim=input_dim,\n",
    "    subspace_dim=subspace_dim,\n",
    "    atoms_per_subspace=atoms_per_subspace,\n",
    "    num_experts=num_experts,\n",
    "    k=k,\n",
    "    key=key\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0ec799a",
   "metadata": {},
   "outputs": [],
   "source": [
    "l1_penalty = 2.5e-3\n",
    "ortho_penalty = 2.5e-2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10cc3409-16ff-4bde-a922-b2baaba2fa98",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "num_steps = 5000\n",
    "warmup_steps = 250\n",
    "\n",
    "def cosine_decay_schedule(init_value, peak_value, warmup_steps, decay_steps, alpha=0.0):\n",
    "    warmup_fn = optax.linear_schedule(\n",
    "        init_value=init_value,\n",
    "        end_value=peak_value,\n",
    "        transition_steps=warmup_steps\n",
    "    )\n",
    "    cosine_decay_fn = optax.cosine_decay_schedule(\n",
    "        init_value=peak_value,\n",
    "        decay_steps=decay_steps,\n",
    "        alpha=alpha\n",
    "    )\n",
    "    return optax.join_schedules(\n",
    "        schedules=[warmup_fn, cosine_decay_fn],\n",
    "        boundaries=[warmup_steps]\n",
    "    )\n",
    "\n",
    "\n",
    "learning_rate_fn = cosine_decay_schedule(\n",
    "    init_value=batch_size * 1e-15,\n",
    "    peak_value=batch_size * 1e-6,\n",
    "    warmup_steps=warmup_steps,\n",
    "    decay_steps=num_steps - warmup_steps,\n",
    "    alpha=0.1  # Final learning rate will be 10% of peak value\n",
    ")\n",
    "\n",
    "optimizer = optax.chain(\n",
    "    optax.clip_by_global_norm(1.0),\n",
    "    optax.adam(learning_rate_fn, b1=0.9, b2=0.999)\n",
    ")\n",
    "\n",
    "opt_state = optimizer.init(eqx.filter(model, eqx.is_array))\n",
    "\n",
    "start_time = time.time()\n",
    "for step in range(num_steps):\n",
    "    batch = next(train_loader)\n",
    "    model, opt_state, loss, aux_out = train_step(model, batch, opt_state, l1_penalty, ortho_penalty, optimizer)\n",
    "    if step % 50 == 0:\n",
    "        print(f\"Step {step}, Loss: {loss:.4f}\")\n",
    "end_time = time.time()\n",
    "print(f\"Training took {end_time - start_time:.2f} seconds.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5675685a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Load unembeddings\n",
    "g = jnp.load(f'{unembeds_dir}/clean_gemma_embeddings.npy')\n",
    "g = g * jnp.sqrt(g.shape[0] / g.shape[1])\n",
    "\n",
    "# Load vocabulary\n",
    "with open(f'{unembeds_dir}/clean_gemma_vocab_dict.json', 'r') as fout:\n",
    "    vocab_dict = json.load(fout)\n",
    "\n",
    "vocab_list = [None] * (max(vocab_dict.values()) + 1)\n",
    "for word, index in vocab_dict.items():\n",
    "    vocab_list[index] = word"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3df9d0ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = g[vocab_dict[word]][None, :]\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b376a5d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_autoencoder(x):\n",
    "    if x.ndim == 1:\n",
    "        x = x[None, :]  # Add batch dimension\n",
    "\n",
    "    # Wrap model.encode in vmap for consistency\n",
    "    top_level_latent_codes, expert_specific_codes, top_k_indices, top_k_values = jax.vmap(model.encode)(x)\n",
    "    masked_top_level_latent_codes, masked_expert_specific_codes = mask_codes(\n",
    "        top_level_latent_codes, expert_specific_codes, top_k_indices, top_k_values\n",
    "    )\n",
    "    x_hat, _, _ = jax.vmap(model.decode)(masked_expert_specific_codes, top_k_indices, top_k_values)\n",
    "    return x_hat\n",
    "\n",
    "word=' dog'\n",
    "# the encoder expects batches so we need to add a batch dimension\n",
    "\n",
    "x = g[vocab_dict[word]]\n",
    "dog_hat = run_autoencoder(x)\n",
    "\n",
    "print(jnp.square(dog_hat - g[vocab_dict[word]]).sum())\n",
    "print(jnp.square(g[vocab_dict[word]]).sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b11159e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tqdm\n",
    "\n",
    "class ActivationCache:\n",
    "    def __init__(self, vocab_size, num_experts, atoms_per_subspace, max_examples_per_latent=10):\n",
    "        self.vocab_size = vocab_size\n",
    "        self.num_experts = num_experts\n",
    "        self.atoms_per_subspace = atoms_per_subspace\n",
    "        self.max_examples_per_latent = max_examples_per_latent\n",
    "\n",
    "        self.top_activations = np.full((num_experts, max_examples_per_latent), -np.inf, dtype=np.float32)\n",
    "        self.top_words = np.full((num_experts, max_examples_per_latent), -1, dtype=np.int32)\n",
    "\n",
    "        self.low_activations = np.full((num_experts, atoms_per_subspace, max_examples_per_latent), -np.inf, dtype=np.float32)\n",
    "        self.low_words = np.full((num_experts, atoms_per_subspace, max_examples_per_latent), -1, dtype=np.int32)\n",
    "\n",
    "    def update(self, batch_words, top_latents, expert_latents):\n",
    "        batch_size = top_latents.shape[0]\n",
    "\n",
    "        for i in range(batch_size):\n",
    "            word_idx = int(batch_words[i])\n",
    "            top_latent = int(np.argmax(top_latents[i]))\n",
    "            top_activation = float(top_latents[i, top_latent])\n",
    "\n",
    "            # Top-level cache update\n",
    "            current_acts = self.top_activations[top_latent]\n",
    "            min_idx = np.argmin(current_acts)\n",
    "            if top_activation > current_acts[min_idx]:\n",
    "                self.top_activations[top_latent, min_idx] = top_activation\n",
    "                self.top_words[top_latent, min_idx] = word_idx\n",
    "\n",
    "            # Low-level cache update\n",
    "            for expert in range(expert_latents.shape[1]):\n",
    "                low_latent = int(np.argmax(expert_latents[i, expert]))\n",
    "                low_activation = float(expert_latents[i, expert, low_latent])\n",
    "\n",
    "                current_low_acts = self.low_activations[top_latent, low_latent]\n",
    "                min_low_idx = np.argmin(current_low_acts)\n",
    "                if low_activation > current_low_acts[min_low_idx]:\n",
    "                    self.low_activations[top_latent, low_latent, min_low_idx] = low_activation\n",
    "                    self.low_words[top_latent, low_latent, min_low_idx] = word_idx\n",
    "\n",
    "def build_activation_cache(dataset, model, batch_size, vocab_embeddings, vocab_indices, num_experts, atoms_per_subspace, top_k, max_examples_per_latent=5):\n",
    "    activation_cache = ActivationCache(\n",
    "        vocab_size=len(vocab_list),\n",
    "        num_experts=num_experts,\n",
    "        atoms_per_subspace=atoms_per_subspace,\n",
    "        max_examples_per_latent=max_examples_per_latent\n",
    "    )\n",
    "\n",
    "    num_batches = int(np.ceil(len(dataset) / batch_size))\n",
    "\n",
    "    for batch_start in tqdm.tqdm(range(0, len(dataset), batch_size), total=num_batches, desc=\"Building Activation Cache\"):\n",
    "        batch_end = min(batch_start + batch_size, len(dataset))\n",
    "        batch = vocab_embeddings[vocab_indices[batch_start:batch_end]]\n",
    "\n",
    "        # Run encoding on device and pull results to CPU for fast numpy updates\n",
    "        top_latents, expert_latents, _, _ = jax.vmap(model.encode)(batch)\n",
    "        top_latents = np.array(top_latents)\n",
    "        expert_latents = np.array(expert_latents)\n",
    "        batch_words = np.array(vocab_indices[batch_start:batch_end])\n",
    "\n",
    "        activation_cache.update(batch_words=batch_words, \n",
    "                                top_latents=top_latents, \n",
    "                                expert_latents=expert_latents)\n",
    "\n",
    "    return activation_cache\n",
    "\n",
    "activation_cache = build_activation_cache(\n",
    "    dataset=g,\n",
    "    model=model,\n",
    "    batch_size=8192,\n",
    "    vocab_embeddings=g,\n",
    "    vocab_indices=jnp.arange(g.shape[0]),\n",
    "    num_experts=num_experts,\n",
    "    atoms_per_subspace=atoms_per_subspace,\n",
    "    top_k=k\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2fec1c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def explain_word(word, top_k=5, low_k=5):\n",
    "    word_vec = g[vocab_dict[word]]\n",
    "    top_latents, low_latents, top_k_indices, _ = model.encode(word_vec)\n",
    "\n",
    "    # Remove batch dimension\n",
    "    top_latents = np.array(top_latents).squeeze()\n",
    "    low_latents = np.array(low_latents).squeeze()\n",
    "    top_k_indices = np.array(top_k_indices).squeeze()\n",
    "\n",
    "    print(f\"\\nExplanations for word: '{word.strip()}'\\n{'='*50}\")\n",
    "\n",
    "    for idx in range(min(top_k, len(top_k_indices))):\n",
    "        topic = int(top_k_indices[idx])\n",
    "        activation = top_latents[topic]\n",
    "        print(f\"\\n🔹 Top-Level Feature {topic} (Activation: {activation:.4f})\")\n",
    "        print(\"  Words that maximally activate this feature:\")\n",
    "\n",
    "        # Top-level activating words\n",
    "        top_examples = np.argsort(-activation_cache.top_activations[topic])[:low_k]\n",
    "        top_words = [vocab_list[int(w)] for w in activation_cache.top_words[topic, top_examples] if w >= 0]\n",
    "        print(f\"   {top_words}\")\n",
    "\n",
    "        # Low-level activation for this top-level feature\n",
    "        low_latent = int(np.argmax(low_latents[idx, :]))\n",
    "        low_activation = low_latents[idx, low_latent]\n",
    "        print(f\"    ↳ Low-Level Feature: {low_latent} (Activation: {low_activation:.8f})\")\n",
    "\n",
    "        print(\"      Words that maximally activate this low-level feature:\")\n",
    "        low_examples = np.argsort(-activation_cache.low_activations[topic, low_latent])[:low_k]\n",
    "        low_words = [vocab_list[int(w)] for w in activation_cache.low_words[topic, low_latent, low_examples] if w >= 0]\n",
    "        print(f\"       {low_words}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d675dc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "explain_word(' puppy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cc4e797",
   "metadata": {},
   "outputs": [],
   "source": [
    "explain_word(' Queen')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3d8cffc",
   "metadata": {},
   "outputs": [],
   "source": [
    "explain_word(' Chicago')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7326c939",
   "metadata": {},
   "outputs": [],
   "source": [
    "explain_word(' London')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4fbb39e",
   "metadata": {},
   "outputs": [],
   "source": [
    "explain_word(' Twitter')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f450075",
   "metadata": {},
   "outputs": [],
   "source": [
    "explain_word(' python')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84f6668d",
   "metadata": {},
   "outputs": [],
   "source": [
    "explain_word(' Bayesian')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sae",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
