{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8489839-96f8-40c2-ba34-cd5e7fe5d932",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "%pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7ed5e72-5320-43dc-8fe5-fb4a1fde3b06",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import pickle\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import transformer_lens.utils as utils\n",
    "from tqdm.auto import tqdm\n",
    "from transformer_lens import HookedTransformer\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "from src.collector import CollectionConfig, LatentCollector\n",
    "from src.utils import load_trained_sae"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b69b033-4af7-444e-a9fa-c3d06d5318ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "name = \"topk\"\n",
    "# Loads the SAE and its config from trained/<name>/ (the .pt and config files must be there).\n",
    "cfg, sae = load_trained_sae(f\"trained/{name}\")\n",
    "\n",
    "model_name, layer, dict_size = cfg.model_name, cfg.layer, cfg.dict_size\n",
    "hook_point = utils.get_act_name(cfg.hook_point, layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae493bcb-f120-496d-8bf0-fc4a2efabc99",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda'\n",
    "model = HookedTransformer.from_pretrained(model_name, device=device)\n",
    "# Tokenizer handed to `show()` when printing feature examples further down.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f018ff1f-7048-4111-bb89-c4db7ffc5cf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free leftover memory. Collect unreachable Python objects first so any tensors they\n",
    "# still held go back to PyTorch's caching allocator; empty_cache() then releases\n",
    "# those unoccupied cached blocks.\n",
    "gc.collect()\n",
    "if device == 'cuda':\n",
    "    torch.cuda.empty_cache()\n",
    "elif device == 'mps':\n",
    "    torch.mps.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87a8d534-80d6-4499-a125-5bcc1c394fef",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomCollector(LatentCollector):\n",
    "    \"\"\"LatentCollector that records the second value returned by `sae.encode`.\"\"\"\n",
    "\n",
    "    def _get_activations(self, hidden_state):\n",
    "        # hidden_state is 3-D (presumably batch, seq, d_model); flatten tokens before encoding.\n",
    "        B, T, d = hidden_state.shape\n",
    "        _, activations = self.sae.encode(hidden_state.reshape(B * T, -1))\n",
    "        return activations[:, self.feature_indices]\n",
    "\n",
    "\n",
    "class PreCollector(LatentCollector):\n",
    "    \"\"\"LatentCollector that records SAE pre-activations (`encode_pre`, then `get_masked_pre_acts`).\"\"\"\n",
    "\n",
    "    def _get_activations(self, hidden_state):\n",
    "        B, T, d = hidden_state.shape\n",
    "        activations = self.sae.encode_pre(hidden_state.reshape(B * T, -1))\n",
    "        activations = self.sae.get_masked_pre_acts(activations)\n",
    "        # back to one row per token, whatever shape the masking step returned\n",
    "        activations = activations.reshape(B * T, -1)\n",
    "        return activations[:, self.feature_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87a8d534-collection-config",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = CollectionConfig(\n",
    "    model_name=model_name,\n",
    "    hook_point=hook_point,\n",
    "    layer=layer,\n",
    "    dict_size=dict_size,\n",
    "    batch_size=128,\n",
    "    device=device,  # reuse the device the model was loaded on instead of hard-coding it\n",
    "    feature_indices=np.arange(3072, 16384).tolist(),  # latents 3072..16383 (end exclusive)\n",
    "\n",
    "    # ---------- Dataset -----------\n",
    "    dataset_path = \"HuggingFaceFW/fineweb-edu\",\n",
    "    dataset_name = \"CC-MAIN-2024-51\",\n",
    "    dataset_split = \"train\",\n",
    "    streaming = True,\n",
    "\n",
    "    # ----------- Buffer -----------\n",
    "    pos_buffer_size = 384,\n",
    "    neg_buffer_size = 128,\n",
    "    seq_len = 256,\n",
    "    pack_size = 256,\n",
    "\n",
    "    # --------- Histogram -----------\n",
    "    hist_bins = 20,\n",
    "    hist_min = 1e-1,\n",
    "    hist_max = 1e2,\n",
    "    pos_bins = 20,\n",
    "\n",
    "    # ---------- Early Exit -----------\n",
    "    filled_percent = 100.0,\n",
    "    avg_fill_rate = 1.0,\n",
    "    min_fill_rate = 1.0,\n",
    "    exit_strategy = \"tokens\",\n",
    "\n",
    "    # ---------- Various -----------\n",
    "    move_pack_to_cpu = False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87a8d534-collect-stats",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): CustomCollector is defined above but never instantiated; confirm whether it\n",
    "# was meant to be used here. Swap in PreCollector to record masked pre-activations instead.\n",
    "collector = LatentCollector(model, sae, config)\n",
    "# collector = PreCollector(model, sae, config)\n",
    "stats = collector.collect(max_tokens=1024000 * 24)\n",
    "\n",
    "# embedding = model.embed.W_E.detach().cpu()\n",
    "# embedding = embedding / embedding.norm(dim=1, keepdim=True)\n",
    "\n",
    "for feature in tqdm(stats.values()):\n",
    "    feature.compute_token_statistics()\n",
    "    # feature.compute_cosine_statistics(sae.W_dec[feature.index].detach().cpu(), embedding.detach().cpu())\n",
    "    feature.examples = feature.split_examples(context_window = 15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87a8d534-save-stats",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save under the name of the SAE that was loaded, creating results/ if it is missing so the\n",
    "# output of the long collection run is not lost to a FileNotFoundError at save time.\n",
    "results_path = Path('results') / f'{name}.pkl'\n",
    "results_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "with open(results_path, 'wb') as f:\n",
    "    pickle.dump(stats, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efbef468-8193-476c-9b1b-1a7818b8e5f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check one randomly chosen feature (no seed is set in this notebook, so each run differs).\n",
    "random_feature = np.random.choice(list(stats.keys()), 1)[0]\n",
    "report = stats[random_feature].show(\n",
    "    tokenizer=tokenizer,\n",
    "    context_window=15,\n",
    "    max_example_length=100,\n",
    "    examples_per_quantile=7,\n",
    ")\n",
    "print(report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11ff10d1-0c41-48f5-a0dc-56d83fba877a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def post_act_to_triplet(post_act_idx: int, m: int, n: int):\n",
    "    \"\"\"Map a flat post-activation index to its (head, row, column) position.\n",
    "\n",
    "    Each head owns a contiguous block of m * n post-activations stored row-major,\n",
    "    so the index splits into a head number and a (row, column) cell inside it.\n",
    "\n",
    "    Args:\n",
    "        post_act_idx: Integer index from post-activations\n",
    "        m: Number of rows per head\n",
    "        n: Number of columns per head\n",
    "\n",
    "    Returns:\n",
    "        Tuple (head_idx, row_idx, col_idx)\n",
    "    \"\"\"\n",
    "    head_idx, offset = divmod(post_act_idx, m * n)\n",
    "    row_idx, col_idx = divmod(offset, n)\n",
    "    return head_idx, row_idx, col_idx\n",
    "\n",
    "\n",
    "def triplet_to_pre_act_indices(head_idx: int, row_idx: int, col_idx: int, m: int, n: int):\n",
    "    \"\"\"Map a (head, row, column) triplet to its row and column pre-activation indices.\n",
    "\n",
    "    Each head owns m + n consecutive pre-activations: its m row features come first,\n",
    "    followed by its n column features.\n",
    "\n",
    "    Args:\n",
    "        head_idx: Head index\n",
    "        row_idx: Row index\n",
    "        col_idx: Column index\n",
    "        m: Number of rows per head\n",
    "        n: Number of columns per head\n",
    "\n",
    "    Returns:\n",
    "        Tuple (row_pre_act_idx, column_pre_act_idx)\n",
    "    \"\"\"\n",
    "    head_start = head_idx * (m + n)\n",
    "    return head_start + row_idx, head_start + m + col_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd75752b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stand-alone scratch arithmetic (evaluates to 163,840,000); nothing in this notebook uses it.\n",
    "# TODO: note what this product is sizing, or delete the cell.\n",
    "40_000 * 4096"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc57f035-58de-4712-90d0-40cb0f0c8f46",
   "metadata": {},
   "outputs": [],
   "source": [
    "post_act_to_triplet(post_act_idx=75, m=4, n=4)  # -> (4, 2, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46f7f857-4cea-46cd-9cbb-481c79c33a29",
   "metadata": {},
   "outputs": [],
   "source": [
    "triplet_to_pre_act_indices(head_idx=4, row_idx=2, col_idx=3, m=4, n=4)  # -> (34, 39)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a08b26f3-ba31-46e8-bf97-b2ade4f70407",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the collection config tracks latents 3072..16383, so this lookup assumes `stats`\n",
    "# is not keyed by latent index (or was built with a different config) -- confirm key 39 exists.\n",
    "feature_39 = stats[39]\n",
    "print(feature_39.show(tokenizer=tokenizer, context_window=15,\n",
    "                      max_example_length=100, examples_per_quantile=10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f9f85b8-6d19-423a-9e21-6bcd1bc224d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same spot-check as the earlier random-feature cell; re-run it to sample another feature.\n",
    "feature_keys = list(stats.keys())\n",
    "sampled_key = np.random.choice(feature_keys, 1)[0]\n",
    "print(stats[sampled_key].show(tokenizer=tokenizer, context_window=15,\n",
    "                              max_example_length=100, examples_per_quantile=7))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ec21401-ffb9-41f3-96d1-fc87570a4d67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the collection config tracks latents 3072..16383, so this lookup assumes `stats`\n",
    "# is not keyed by latent index (or was built with a different config) -- confirm key 1387 exists.\n",
    "feature_1387 = stats[1387]\n",
    "print(feature_1387.show(tokenizer=tokenizer, context_window=15,\n",
    "                        max_example_length=100, examples_per_quantile=8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49b90d4d-5fd8-40a8-9c41-218c412ba7c8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
