{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generating Outputs for Neuronpedia Upload\n",
    "\n",
    "We use Callum McDougall's `sae_vis` library for generating JSON data to upload to Neuronpedia.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_lens.toolkit.pretrained_saes import download_sae_from_hf\n",
    "import os\n",
    "\n",
    "MODEL_ID = \"gpt2-small\"\n",
    "SAE_ID = \"res-jb\"\n",
    "\n",
    "(_, SAE_WEIGHTS_PATH, _) = download_sae_from_hf(\n",
    "    \"jbloom/GPT2-Small-SAEs-Reformatted\", \"blocks.0.hook_resid_pre\"\n",
    ")\n",
    "\n",
    "SAE_PATH = os.path.dirname(SAE_WEIGHTS_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save JSON to neuronpedia_outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner\n",
    "\n",
    "print(SAE_PATH)\n",
    "NP_OUTPUT_FOLDER = \"../../neuronpedia_outputs/my_outputs\"\n",
    "\n",
    "runner = NeuronpediaRunner(\n",
    "    sae_id=SAE_ID,\n",
    "    sae_path=SAE_PATH,\n",
    "    outputs_dir=NP_OUTPUT_FOLDER,\n",
    "    sparsity_threshold=-5,\n",
    "    n_batches_to_sample_from=2**12,\n",
    "    n_prompts_to_select=4096*6,\n",
    "    n_features_at_a_time=24,\n",
    "    start_batch_inclusive=1,\n",
    "    end_batch_inclusive=1,\n",
    ")\n",
    "\n",
    "runner.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Upload to Neuronpedia\n",
    "#### This currently only works if you have admin access to the Neuronpedia database via localhost."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helpers that fix weird NaN stuff\n",
    "from decimal import Decimal\n",
    "from typing import Any\n",
    "import math\n",
    "import json\n",
    "import os\n",
    "import requests\n",
    "\n",
    "FEATURE_OUTPUTS_FOLDER = runner.outputs_dir\n",
    "\n",
    "def nanToNeg999(obj: Any) -> Any:\n",
    "    if isinstance(obj, dict):\n",
    "        return {k: nanToNeg999(v) for k, v in obj.items()}\n",
    "    elif isinstance(obj, list):\n",
    "        return [nanToNeg999(v) for v in obj]\n",
    "    elif (isinstance(obj, float) or isinstance(obj, Decimal)) and math.isnan(obj):\n",
    "        return -999\n",
    "    return obj\n",
    "\n",
    "\n",
    "class NanConverter(json.JSONEncoder):\n",
    "    def encode(self, o: Any, *args: Any, **kwargs: Any):\n",
    "        return super().encode(nanToNeg999(o), *args, **kwargs)\n",
    "\n",
    "\n",
    "# Server info\n",
    "host = \"http://localhost:3000\"\n",
    "\n",
    "# Upload alive features\n",
    "for file_name in os.listdir(FEATURE_OUTPUTS_FOLDER):\n",
    "    if file_name.startswith(\"batch-\") and file_name.endswith(\".json\"):\n",
    "        print(\"Uploading file: \" + file_name)\n",
    "        file_path = os.path.join(FEATURE_OUTPUTS_FOLDER, file_name)\n",
    "        f = open(file_path, \"r\")\n",
    "        data = json.load(f)\n",
    "\n",
    "        # Replace NaNs\n",
    "        data_fixed = json.dumps(data, cls=NanConverter)\n",
    "        data = json.loads(data_fixed)\n",
    "\n",
    "        url = host + \"/api/local/upload-features\"\n",
    "        resp = requests.post(\n",
    "            url,\n",
    "            json=data,\n",
    "        )\n",
    "\n",
    "# Upload dead feature stubs\n",
    "skipped_path = os.path.join(FEATURE_OUTPUTS_FOLDER, \"skipped_indexes.json\")\n",
    "f = open(skipped_path, \"r\")\n",
    "data = json.load(f)\n",
    "url = host + \"/api/local/upload-dead-features\"\n",
    "resp = requests.post(\n",
    "    url,\n",
    "    json=data,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TODO: Automatically validate the uploaded data"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mats_sae_training",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
