{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "import sae_bench.custom_saes.batch_topk_sae as batch_topk_sae\n",
    "import sae_bench.custom_saes.gated_sae as gated_sae\n",
    "import sae_bench.custom_saes.jumprelu_sae as jumprelu_sae\n",
    "import sae_bench.custom_saes.relu_sae as relu_sae\n",
    "import sae_bench.custom_saes.topk_sae as topk_sae\n",
    "\n",
    "# Map from a trainer name (as it appears in a checkpoint path) to the sae_bench loader\n",
    "# used for that checkpoint. Keys are matched as substrings of the path, and some keys\n",
    "# contain others (\"TopKTrainer\" is a substring of \"BatchTopKTrainer\").\n",
    "TRAINER_LOADERS = dict(\n",
    "    StandardTrainer=relu_sae.load_dictionary_learning_relu_sae,\n",
    "    StandardTrainerAprilUpdate=relu_sae.load_dictionary_learning_relu_sae,\n",
    "    PAnnealTrainer=relu_sae.load_dictionary_learning_relu_sae,\n",
    "    TopKTrainer=topk_sae.load_dictionary_learning_topk_sae,\n",
    "    JumpReluTrainer=jumprelu_sae.load_dictionary_learning_jump_relu_sae,\n",
    "    BatchTopKTrainer=batch_topk_sae.load_dictionary_learning_batch_topk_sae,\n",
    "    GatedSAETrainer=gated_sae.load_dictionary_learning_gated_sae,\n",
    ")\n",
    "\n",
    "\n",
    "def get_all_hf_repo_autoencoders(\n",
    "    repo_id: str, download_location: str = \"downloaded_saes\"\n",
    ") -> list[str]:\n",
    "    download_location = os.path.join(download_location, repo_id.replace(\"/\", \"_\"))\n",
    "    config_dir = snapshot_download(\n",
    "        repo_id,\n",
    "        allow_patterns=[\"*config.json\"],\n",
    "        local_dir=download_location,\n",
    "        force_download=False,\n",
    "    )\n",
    "\n",
    "    configs = []\n",
    "\n",
    "    for root, _, files in os.walk(config_dir):\n",
    "        for file in files:\n",
    "            if file == \"config.json\":\n",
    "                configs.append(os.path.join(root, file))\n",
    "\n",
    "    repo_locations = []\n",
    "\n",
    "    for config in configs:\n",
    "        repo_location = config.split(f\"{download_location}/\")[1].split(\"config.json\")[0]\n",
    "        repo_locations.append(repo_location)\n",
    "\n",
    "    return repo_locations\n",
    "\n",
    "\n",
    "repo_id = \"adamkarvonen/sae_test\"\n",
    "locations = get_all_hf_repo_autoencoders(repo_id)\n",
    "# Fail here with a clear message rather than with an IndexError on locations[0] later.\n",
    "assert locations, f\"No config.json files found in {repo_id}\"\n",
    "\n",
    "locations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sae_bench.custom_saes.base_sae as base_sae\n",
    "\n",
    "layer = 3\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "dtype = torch.float32\n",
    "\n",
    "model_name = \"EleutherAI/pythia-70m-deduped\"\n",
    "\n",
    "\n",
    "def load_dictionary_learning_sae(\n",
    "    repo_id: str, location: str, layer: int, model_name, device: str, dtype: torch.dtype\n",
    ") -> base_sae.BaseSAE:\n",
    "    for key, loader in TRAINER_LOADERS.items():\n",
    "        if key in location:\n",
    "            sae = loader(\n",
    "                repo_id=repo_id,\n",
    "                filename=location,\n",
    "                layer=layer,\n",
    "                model_name=model_name,\n",
    "                device=device,\n",
    "                dtype=dtype,\n",
    "            )\n",
    "            return sae\n",
    "\n",
    "    raise ValueError(f\"Could not find a loader for {location}\")\n",
    "\n",
    "\n",
    "# Load the ae.pt stored next to the first config.json found above, then run test_sae on it.\n",
    "sae = load_dictionary_learning_sae(\n",
    "    repo_id=repo_id,\n",
    "    location=f\"{locations[0]}ae.pt\",\n",
    "    layer=layer,\n",
    "    model_name=model_name,\n",
    "    device=device,\n",
    "    dtype=dtype,\n",
    ")\n",
    "sae.test_sae(model_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
