{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_bench.sae_bench_utils.sae_selection_utils import (\n",
    "    get_saes_from_regex,\n",
    "    print_all_sae_releases,\n",
    "    print_release_details,\n",
    ")\n",
    "\n",
    "# Print an overview of the SAE releases, then the details of a single release.\n",
    "# The printed layout is Callum's format; it is kept because it reads well visually.\n",
    "print_all_sae_releases()\n",
    "\n",
    "print_release_details(\"gpt2-small-res-jb\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test cases:\n",
    "- Select all canonical Gemma Scope 2b res SAEs for all sizes, layer 12\n",
    "- Select all canonical Gemma Scope 2b SAEs: for layers 5, 12 and 19, get all res, mlp and attn SAEs of size 16k or 65k\n",
    "- Select all Gemma Scope 2b 16k res SAEs of all sparsities\n",
    "- Select all SAE Bench Gemma 2 2b Vanilla and TopK SAEs, sizes 4k and 8k (both expansion factors, all sparsities)\n",
    "- Select all layer 3 and 4 Pythia 70m SAEs: Vanilla, TopK, Gated, P-Anneal, all sparsities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example (release regex, SAE id regex) pairs to pass to get_saes_from_regex.\n",
    "# Exactly one pair is run per execution: the one named by EXAMPLE_NAME below.\n",
    "SAE_SELECTION_EXAMPLES = {\n",
    "    # all canonical Gemma Scope 2b res SAEs for all sizes, layer 12\n",
    "    \"gemma_scope_res_canonical_layer_12\": (\n",
    "        r\"gemma-scope-2b-pt-res-canonical\",\n",
    "        r\".*layer_12.*\",\n",
    "    ),\n",
    "    # all mlp Gemma Scope 2b SAEs for 16k size, layer 5\n",
    "    \"gemma_scope_mlp_16k_layer_5\": (\n",
    "        r\"gemma-scope-2b-pt-mlp\",\n",
    "        r\".*layer_5.*(16k).*\",\n",
    "    ),\n",
    "    # canonical Gemma Scope 2b, layers 5, 12 and 19: all res, mlp and attn SAEs of size 16k or 65k\n",
    "    \"gemma_scope_canonical_layers_5_12_19\": (\n",
    "        r\"(gemma-scope-2b-pt-(res|att|mlp)-canonical)\",\n",
    "        r\".*layer_(5|12|19).*(16k|65k).*\",\n",
    "    ),\n",
    "    # all sae bench layer 19 gemma-2-2b topk 16k width SAEs (excluding checkpoints)\n",
    "    \"sae_bench_gemma_topk_layer_19_no_checkpoints\": (\n",
    "        r\"sae_bench_gemma-2-2b_topk_width-2pow14_date-1109\",\n",
    "        r\".*blocks\\.19(?!.*step).*\",\n",
    "    ),\n",
    "    # all sae bench layer 5 gemma-2-2b topk 16k width SAEs (including checkpoints)\n",
    "    \"sae_bench_gemma_topk_layer_5_with_checkpoints\": (\n",
    "        r\"sae_bench_gemma-2-2b_topk_width-2pow14_date-1109\",\n",
    "        r\".*blocks\\.5.*\",\n",
    "    ),\n",
    "    # all sae bench layer 3 and 4 pythia 70M SAEs: Vanilla, TopK, Gated, P-Anneal, all sparsities\n",
    "    # `[34]` matches layer 3 or 4; a `|` inside a character class would be a literal pipe.\n",
    "    \"sae_bench_pythia70m_layers_3_4\": (\n",
    "        r\"sae_bench_pythia70m_sweep.*_ctx128_.*\",\n",
    "        r\".*blocks\\.([34])\\.hook_resid_post__trainer_.*\",\n",
    "    ),\n",
    "    # trainer ids 2 and 10 for layer 4 pythia 70m topk SAEs\n",
    "    \"sae_bench_pythia70m_topk_layer_4_trainers_2_10\": (\n",
    "        r\"(sae_bench_pythia70m_sweep_topk_ctx128_0730).*\",\n",
    "        r\".*blocks\\.([4])\\.hook_resid_post__trainer_(2|10)$\",\n",
    "    ),\n",
    "}\n",
    "\n",
    "# Change this key to run a different example from the table above.\n",
    "EXAMPLE_NAME = \"sae_bench_pythia70m_topk_layer_4_trainers_2_10\"\n",
    "sae_regex_pattern, sae_block_pattern = SAE_SELECTION_EXAMPLES[EXAMPLE_NAME]\n",
    "\n",
    "selected_saes = get_saes_from_regex(sae_regex_pattern, sae_block_pattern)\n",
    "\n",
    "print(f\"Selected {len(selected_saes)} SAEs:\")\n",
    "\n",
    "# Each selected entry is unpacked as a (release, SAE id) pair.\n",
    "for sae_release, sae_id in selected_saes:\n",
    "    print(f\"{sae_release} - {sae_id}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sae_bench_template",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
