{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of vectors: 1\n",
      "Vector dimension: 768\n",
      "Vector names: ['sentiment_vector']\n"
     ]
    }
   ],
   "source": [
    "# Example usage\n",
    "import json\n",
    "import torch\n",
    "from pathlib import Path\n",
    "from sae_dashboard.neuronpedia.vector_set import VectorSet\n",
    "\n",
    "\n",
    "# Load vector from file. Note that the vectors should be stored in this format, as a list of lists of floats:\n",
    "# {\n",
    "#     \"vectors\": [\n",
    "#         [vector_1],\n",
    "#         [vector_2],\n",
    "#         ...\n",
    "#     ]\n",
    "# }\n",
    "json_path = Path(\"test_vectors/logistic_direction.json\")\n",
    "\n",
    "# Load the vector into a VectorSet\n",
    "vector_set = VectorSet.from_json(\n",
    "    json_path=json_path,\n",
    "    d_model=768,  # Example dimension for GPT-2 Small\n",
    "    hook_point=\"blocks.7.hook_resid_pre\",\n",
    "    hook_layer=7,\n",
    "    model_name=\"gpt2\",\n",
    "    names=[\"sentiment_vector\"],  # Optional custom name\n",
    ")\n",
    "\n",
    "# Now you can use the vector set\n",
    "print(f\"Number of vectors: {vector_set.vectors.shape[0]}\")\n",
    "print(f\"Vector dimension: {vector_set.vectors.shape[1]}\")\n",
    "print(f\"Vector names: {vector_set.names}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can also save and load the vector set as a VectorSet object as opposed to a simple list of lists of floats\n",
    "vector_set.save(Path(\"test_vectors/logistic_direction_vector_set.json\"))\n",
    "vector_set = VectorSet.load(Path(\"test_vectors/logistic_direction_vector_set.json\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_dashboard.neuronpedia.neuronpedia_vector_runner import (\n",
    "    NeuronpediaVectorRunner,\n",
    "    NeuronpediaVectorRunnerConfig,\n",
    ")\n",
    "\n",
    "cfg = NeuronpediaVectorRunnerConfig(\n",
    "    outputs_dir=\"test_outputs/\",\n",
    "    huggingface_dataset_path=\"monology/pile-uncopyrighted\",\n",
    "    vector_dtype=\"float32\",\n",
    "    model_dtype=\"float32\",\n",
    "    # Small test settings\n",
    "    n_prompts_total=16384,\n",
    "    n_tokens_in_prompt=128,  # Shorter sequences\n",
    "    n_prompts_in_forward_pass=256,\n",
    "    n_vectors_at_a_time=1,\n",
    "    use_wandb=False,  # Disable wandb for testing\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device Count: 1\n",
      "Using specified vector dtype: float32\n",
      "SAE Device: mps\n",
      "Model Device: mps\n",
      "Model Num Devices: 1\n",
      "Activation Store Device: mps\n",
      "Dataset Path: monology/pile-uncopyrighted\n",
      "Forward Pass size: 128\n",
      "Total number of tokens: 2097152\n",
      "Total number of contexts (prompts): 16384\n",
      "Vector DType: float32\n",
      "Model DType: float32\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/curttigges/miniconda3/envs/sae-d/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model gpt2 into HookedTransformer\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f1a49eee02cd482e9de6deaa88e4afde",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://jbloomaus.github.io/SAELens/training_saes/#pretokenizing-datasets for more info.\n",
      "Tokens don't exist, making them.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2048 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (3180 > 1024). Running this sequence through the model will result in indexing errors\n",
      "100%|██████████| 2048/2048 [00:18<00:00, 108.67it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========== Running Batch #0 ==========\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "418cf3ccf15d4ae597e06d24e4c89b11",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Forward passes to cache data for vis:   0%|          | 0/60 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2a0ddbda94d0407598edf564b4487407",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/curttigges/Projects/SAEDashboard/sae_dashboard/vector_data_generator.py:205: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  return torch.load(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "feature_indices: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━┳━━━━━━┳━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\"> Task </span>┃<span style=\"font-weight: bold\"> Time </span>┃<span style=\"font-weight: bold\"> Pct % </span>┃\n",
       "┡━━━━━━╇━━━━━━╇━━━━━━━┩\n",
       "└──────┴──────┴───────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━┳━━━━━━┳━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1mTask\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mTime\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mPct %\u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━╇━━━━━━╇━━━━━━━┩\n",
       "└──────┴──────┴───────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1it [00:02,  2.65s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output written to test_outputs/gpt2_blocks.7.hook_resid_pre/batch-0.json\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "runner = NeuronpediaVectorRunner(vector_set, cfg)\n",
    "runner.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sae-d",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
