{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SAE Bench Eval Template\n",
    "\n",
    "## Overview\n",
    "\n",
    "Every eval type has the following:\n",
    "1. A corresponding sub-package. \n",
    "2. A main.py, which includes:\n",
    "   1.  An argparse interface (`arg_parser`) for running the eval from the command line.\n",
    "   2.  A `run_eval` function that operates on a set of SAEs and produces a JSON file of results for each SAE.\n",
    "3. An eval_config.py with a pydantic dataclass inheriting from BaseEvalConfig, specific to that eval, with defaults set to recommended values.\n",
    "4. An eval_output.py with a pydantic dataclass subclassing from BaseEvalOutput, with output specific to that eval.\n",
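    "\n",
    "As a rough illustration of the \"one JSON file per SAE\" contract, here is a minimal stdlib-only sketch of a `run_eval` loop. The signature, the `compute_metrics` placeholder and the output filename scheme are all hypothetical and are not the actual SAE Bench API. A real eval would load the model and SAEs and build a full `BaseEvalOutput` subclass instead of a plain dict.\n",
    "\n",
    "```python\n",
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def compute_metrics(sae_release: str, sae_id: str) -> dict:\n",
    "    # Hypothetical placeholder: a real eval would load the model and SAE here.\n",
    "    return {\"placeholder_metric\": 0.0}\n",
    "\n",
    "\n",
    "def run_eval(\n",
    "    selected_saes: list[tuple[str, str]],\n",
    "    output_folder: str,\n",
    "    force_rerun: bool = False,\n",
    ") -> list[Path]:\n",
    "    # Hypothetical sketch: evaluate each (release, sae_id) pair and write one JSON file per SAE.\n",
    "    out_dir = Path(output_folder)\n",
    "    out_dir.mkdir(parents=True, exist_ok=True)\n",
    "    paths = []\n",
    "    for sae_release, sae_id in selected_saes:\n",
    "        # Flatten any \"/\" in the SAE ID so the result is a single file, not a subdirectory.\n",
    "        safe_id = sae_id.replace(\"/\", \"_\")\n",
    "        path = out_dir / f\"{sae_release}_{safe_id}_eval_results.json\"\n",
    "        # Skip SAEs that already have results unless a rerun is forced.\n",
    "        if force_rerun or not path.exists():\n",
    "            result = {\n",
    "                \"sae_lens_release_id\": sae_release,\n",
    "                \"sae_lens_id\": sae_id,\n",
    "                \"eval_result_metrics\": compute_metrics(sae_release, sae_id),\n",
    "            }\n",
    "            path.write_text(json.dumps(result, indent=2))\n",
    "        paths.append(path)\n",
    "    return paths\n",
    "```\n",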
    "\n",
    "## CLI and Eval Config\n",
    "\n",
    "The CLI takes a combination of common arguments (the same for all evals) and eval-specific arguments. Eval-specific arguments should match the fields in that sub-package's eval_config. The common arguments should include:\n",
    "- `sae_regex_pattern` and `sae_block_pattern`, regex patterns used to select SAEs from the SAE Lens library. The first is matched against release names and the second against SAE IDs within those releases.\n",
    "- `output_folder` to place the output in. \n",
    "- `model_name` for loading a model from TransformerLens.\n",
    "\n",
    "Eval configs should be pydantic dataclasses that inherit from BaseEvalConfig. This lets you add \"Title\" and \"Description\" annotations to each field. Use these annotations so that the fields have user-friendly names in the UI, along with hover-able descriptions that explain what each field means and why it matters. For an example, see `/evals/absorption/eval_config.py`.\n",
    "\n",
    "To see which SAEs you can select via the regex arguments, use the SAE selection utils like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_bench.sae_bench_utils.sae_selection_utils import (\n",
    "    print_all_sae_releases,\n",
    "    print_release_details,\n",
    ")\n",
    "\n",
    "# Each release has a corresponding model / repo_id. We recommend not selecting\n",
    "# releases with different models when running evals.\n",
    "print_all_sae_releases()\n",
    "\n",
    "# Each release contains a number of possible SAEs.\n",
    "print_release_details(\"gpt2-small-res-jb\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's an example call to the absorption eval. Note that we select a single release, narrowed by the block pattern to its `blocks.4.hook_resid_post` SAEs (though we could select more). We also use the default values for the eval-specific arguments by not setting them via the CLI.\n",
    "\n",
    "```bash\n",
    "python evals/absorption/main.py \\\n",
    "--sae_regex_pattern \"sae_bench_pythia70m_sweep_standard_ctx128_0712\" \\\n",
    "--sae_block_pattern \"blocks.4.hook_resid_post__trainer_.*\" \\\n",
    "--model_name pythia-70m-deduped \\\n",
    "--output_folder results\n",
    "```\n",
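    "\n",
    "To make the two patterns concrete, here is a small stdlib sketch of how such filtering could work. The inventory below is a hypothetical toy list (the real one comes from SAE Lens). Full-string matching with `re.fullmatch` is also an assumption, and the actual selection utility may use a different matching mode.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Hypothetical toy inventory mapping release names to SAE IDs; the real list comes from SAE Lens.\n",
    "AVAILABLE_SAES = {\n",
    "    \"sae_bench_pythia70m_sweep_standard_ctx128_0712\": [\n",
    "        \"blocks.3.hook_resid_post__trainer_10\",\n",
    "        \"blocks.4.hook_resid_post__trainer_0\",\n",
    "        \"blocks.4.hook_resid_post__trainer_1\",\n",
    "    ],\n",
    "    \"gpt2-small-res-jb\": [\"blocks.4.hook_resid_pre\"],\n",
    "}\n",
    "\n",
    "\n",
    "def select_saes(sae_regex_pattern: str, sae_block_pattern: str) -> list[tuple[str, str]]:\n",
    "    # Assumption: each pattern must match the whole string.\n",
    "    selected = []\n",
    "    for release, sae_ids in AVAILABLE_SAES.items():\n",
    "        if re.fullmatch(sae_regex_pattern, release):\n",
    "            selected.extend(\n",
    "                (release, sae_id)\n",
    "                for sae_id in sae_ids\n",
    "                if re.fullmatch(sae_block_pattern, sae_id)\n",
    "            )\n",
    "    return selected\n",
    "\n",
    "\n",
    "# With the patterns from the CLI example, only the two block-4 trainer SAEs are selected.\n",
    "print(\n",
    "    select_saes(\n",
    "        \"sae_bench_pythia70m_sweep_standard_ctx128_0712\",\n",
    "        \"blocks.4.hook_resid_post__trainer_.*\",\n",
    "    )\n",
    ")\n",
    "```\n",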
    "\n",
    "To create such an interface, define an `arg_parser` function in the main.py file (as below) and an `EvalConfig` class in an eval_config.py file inside the eval sub-package. Eval configs should be dataclasses with serializable values, so they are easy to save and load.\n",
    "\n",
    "You can test whether you've set up the `EvalConfig` and `arg_parser` correctly by using the `validate_eval_cli_interface` testing util. Feel free to change the CLI args / Eval Config to test the validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install pydantic\n",
    "\n",
    "import argparse\n",
    "\n",
    "from pydantic import Field\n",
    "from pydantic.dataclasses import dataclass\n",
    "\n",
    "from sae_bench.sae_bench_utils.testing_utils import validate_eval_cli_interface\n",
    "\n",
    "\n",
    "def arg_parser():\n",
    "    parser = argparse.ArgumentParser(description=\"Run absorption evaluation\")\n",
    "    parser.add_argument(\"--arg1\", type=int, default=42, help=\"Description for arg1\")\n",
    "    parser.add_argument(\"--arg2\", type=float, default=0.03, help=\"Description for arg2\")\n",
    "    parser.add_argument(\"--arg3\", type=int, default=10, help=\"Description for arg3\")\n",
    "    parser.add_argument(\n",
    "        \"--arg4\",\n",
    "        type=str,\n",
    "        default=\"{word} has the first letter:\",\n",
    "        help=\"Description for arg4\",\n",
    "    )\n",
    "    parser.add_argument(\"--arg5\", type=int, default=-6, help=\"Description for arg5\")\n",
    "    parser.add_argument(\n",
    "        \"--model_name\",\n",
    "        type=str,\n",
    "        default=\"pythia-70m-deduped\",\n",
    "        help=\"Description for arg6\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--sae_regex_pattern\",\n",
    "        type=str,\n",
    "        required=True,\n",
    "        help=\"Regex pattern for SAE selection\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--sae_block_pattern\",\n",
    "        type=str,\n",
    "        required=True,\n",
    "        help=\"Regex pattern for SAE block selection\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--output_folder\",\n",
    "        type=str,\n",
    "        default=\"eval_results/absorption\",\n",
    "        help=\"Output folder\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--force_rerun\", action=\"store_true\", help=\"Force rerun of experiments\"\n",
    "    )\n",
    "\n",
    "    return parser\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class EvalConfig:\n",
    "    arg1: int = Field(default=42, title=\"Arg1\", description=\"Description for arg1\")\n",
    "    arg2: float = Field(default=0.03, title=\"Arg2\", description=\"Description for arg2\")\n",
    "    arg3: int = Field(default=10, title=\"Arg3\", description=\"Description for arg3\")\n",
    "    arg4: str = Field(\n",
    "        default=\"{word} has the first letter:\",\n",
    "        title=\"Arg4\",\n",
    "        description=\"Description for arg4\",\n",
    "    )\n",
    "    arg5: int = Field(default=-6, title=\"Arg5\", description=\"Description for arg5\")\n",
    "    model_name: str = Field(\n",
    "        default=\"pythia-70m-deduped\",\n",
    "        title=\"Model Name\",\n",
    "        description=\"Description for model name\",\n",
    "    )\n",
    "\n",
    "\n",
    "validate_eval_cli_interface(arg_parser(), eval_config_cls=EvalConfig)"
   ]
  },
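  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once `arg_parser` and `EvalConfig` line up, `main.py` can build the config directly from the parsed arguments by keeping only the arguments that are also config fields. The sketch below shows one way to do this; it is not necessarily what SAE Bench does. To stay self-contained, it uses a hypothetical stdlib `ToyEvalConfig` in place of a pydantic `EvalConfig`. `dataclasses.fields` should also work on pydantic dataclasses, but treat that as an assumption to check.\n",
    "\n",
    "```python\n",
    "import argparse\n",
    "import dataclasses\n",
    "\n",
    "\n",
    "@dataclasses.dataclass\n",
    "class ToyEvalConfig:\n",
    "    # Hypothetical stand-in for an EvalConfig; field names borrowed from the absorption example.\n",
    "    random_seed: int = 42\n",
    "    model_name: str = \"pythia-70m-deduped\"\n",
    "\n",
    "\n",
    "def config_from_args(args: argparse.Namespace) -> ToyEvalConfig:\n",
    "    # Keep only config fields; common args such as output_folder are not part of the config.\n",
    "    config_fields = {f.name for f in dataclasses.fields(ToyEvalConfig)}\n",
    "    return ToyEvalConfig(**{k: v for k, v in vars(args).items() if k in config_fields})\n",
    "\n",
    "\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument(\"--random_seed\", type=int, default=42)\n",
    "parser.add_argument(\"--model_name\", type=str, default=\"pythia-70m-deduped\")\n",
    "parser.add_argument(\"--output_folder\", type=str, default=\"results\")\n",
    "\n",
    "config = config_from_args(parser.parse_args([\"--random_seed\", \"7\"]))\n",
    "print(config)  # ToyEvalConfig(random_seed=7, model_name='pythia-70m-deduped')\n",
    "```"
   ]
  },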
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## Output Format\n",
    "\n",
    "Each output JSON file corresponds to one SAE, and its base structure is defined by `BaseEvalOutput` in `evals/base_eval_output.py`.\n",
    "\n",
    "An Eval Output is a pydantic dataclass that inherits from `BaseEvalOutput`. This makes the JSON output files easy to verify (pydantic can generate a JSON Schema from a dataclass) and portable to other languages and apps. As with the Eval Config, it also lets you add \"Title\" and \"Description\" annotations to each field, which are saved in that JSON Schema. For an example, see `/evals/absorption/eval_output.py` and `/evals/absorption/eval_output_schema.json`.\n",
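    "\n",
    "To see what pydantic provides here, the sketch below uses a hypothetical toy dataclass (not an SAE Bench class) with pydantic v2's `TypeAdapter`. It generates the dataclass's JSON Schema and validates JSON strings against it. Note that the field's `title` and `description` are carried into the schema.\n",
    "\n",
    "```python\n",
    "from pydantic import Field, TypeAdapter, ValidationError\n",
    "from pydantic.dataclasses import dataclass\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ToyMetrics:\n",
    "    # Hypothetical output fragment used only for illustration.\n",
    "    mean_score: float = Field(title=\"Mean Score\", description=\"Mean of the per-item scores.\")\n",
    "\n",
    "\n",
    "adapter = TypeAdapter(ToyMetrics)\n",
    "schema = adapter.json_schema()\n",
    "print(schema[\"properties\"][\"mean_score\"])  # includes the title and description\n",
    "\n",
    "adapter.validate_json('{\"mean_score\": 0.25}')  # valid: returns a ToyMetrics instance\n",
    "try:\n",
    "    adapter.validate_json('{\"mean_score\": \"not a number\"}')\n",
    "except ValidationError as err:\n",
    "    print(f\"{err.error_count()} validation error(s)\")\n",
    "```\n",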
    "\n",
    "To build an output, create your own `eval_output.py` with a class that inherits from BaseEvalOutput. Here is an excerpt from `absorption/eval_output.py`:\n",
    "\n",
    "```python\n",
    "# Define the eval output, which includes the eval config, metrics, and result details.\n",
    "# The title will end up being the title of the eval in the UI.\n",
    "@dataclass(config=ConfigDict(title=\"Feature Absorption Evaluation - First Letter\"))\n",
    "class AbsorptionEvalOutput(\n",
    "    BaseEvalOutput[\n",
    "        AbsorptionEvalConfig, AbsorptionMetricCategories, AbsorptionResultDetail\n",
    "    ]\n",
    "):\n",
    "    # This will end up being the description of the eval in the UI.\n",
    "    \"\"\"\n",
    "    The output of a feature absorption evaluation looking at the first letter.\n",
    "    \"\"\"\n",
    "\n",
    "    eval_config: AbsorptionEvalConfig\n",
    "    eval_id: str\n",
    "    datetime_epoch_millis: int\n",
    "    eval_result_metrics: AbsorptionMetricCategories\n",
    "    eval_result_details: list[AbsorptionResultDetail] = Field(\n",
    "        default_factory=list,\n",
    "        title=\"Per-Letter Absorption Results\",\n",
    "        description=\"Each object is a stat on the first letter of the absorption.\",\n",
    "    )\n",
    "    eval_type_id: str = Field(\n",
    "        default=\"absorption_first_letter\",\n",
    "        title=\"Eval Type ID\",\n",
    "        description=\"The type of the evaluation\",\n",
    "    )\n",
    "```\n",
    "\n",
    "After running the eval, put the results into your eval output type (here, `AbsorptionEvalOutput`):\n",
    "\n",
    "```python\n",
    "eval_output = AbsorptionEvalOutput(\n",
    "    eval_type_id=\"absorption_first_letter\",\n",
    "    eval_config=config,\n",
    "    eval_id=get_eval_uuid(),\n",
    "    datetime_epoch_millis=int(datetime.now().timestamp() * 1000),\n",
    "    eval_result_metrics=AbsorptionMetricCategories(\n",
    "        mean=AbsorptionMeanMetrics(\n",
    "            mean_absorption_score=statistics.mean(absorption_rates),\n",
    "            mean_num_split_features=statistics.mean(num_split_features),\n",
    "        )\n",
    "    ),\n",
    "    eval_result_details=eval_result_details,\n",
    "    sae_bench_commit_hash=get_sae_bench_version(),\n",
    "    sae_lens_id=sae_id,\n",
    "    sae_lens_release_id=sae_release,\n",
    "    sae_lens_version=get_sae_lens_version(),\n",
    ")\n",
    "```\n",
    "\n",
    "Finally, write the output to a JSON file:\n",
    "```python\n",
    "eval_output.to_json_file(sae_result_path, indent=2)\n",
    "```\n",
    "\n",
    "Here's what that output would look like:\n",
    "```json\n",
    "{\n",
    "    \"eval_type_id\": \"absorption_first_letter\",\n",
    "    \"eval_config\": {\n",
    "        \"random_seed\": 42,\n",
    "        \"f1_jump_threshold\": 0.03,\n",
    "        \"max_k_value\": 10,\n",
    "        \"prompt_template\": \"{word} has the first letter:\",\n",
    "        \"prompt_token_pos\": -6,\n",
    "        \"model_name\": \"pythia-70m-deduped\"\n",
    "    },\n",
    "    \"eval_id\": \"0c057d5e-973e-410e-8e32-32569323b5e6\",\n",
    "    \"datetime_epoch_millis\": 1729834113150,\n",
    "    \"eval_result_metrics\": {\n",
    "        \"mean\": {\n",
    "            \"mean_absorption_score\": 2,\n",
    "            \"mean_num_split_features\": 3.5\n",
    "        }\n",
    "    },\n",
    "    \"eval_result_details\": [\n",
    "        {\n",
    "            \"first_letter\": \"a\",\n",
    "            \"num_absorption\": 177,\n",
    "            \"absorption_rate\": 0.28780487804878047,\n",
    "            \"num_probe_true_positives\": 615.0,\n",
    "            \"num_split_features\": 1\n",
    "        },\n",
    "        {\n",
    "            \"first_letter\": \"b\",\n",
    "            \"num_absorption\": 51,\n",
    "            \"absorption_rate\": 0.1650485436893204,\n",
    "            \"num_probe_true_positives\": 309.0,\n",
    "            \"num_split_features\": 1\n",
    "        }\n",
    "    ],\n",
    "    \"sae_bench_commit_hash\": \"57e9be0ac9199dba6b9f87fe92f80532e9aefced\",\n",
    "    \"sae_lens_id\": \"blocks.3.hook_resid_post__trainer_10\",\n",
    "    \"sae_lens_release_id\": \"sae_bench_pythia70m_sweep_standard_ctx128_0712\",\n",
    "    \"sae_lens_version\": \"4.0.0\"\n",
    "}\n",
    "```\n",
    "\n",
    "You can see tests for this under `tests/evals/absorption/test_eval_output.py`.\n",
    "\n",
    "Because the output is defined by a pydantic dataclass, you shouldn't need to re-verify it yourself. However, if you want to check whether a JSON file meets the defined output spec, you can call `validate_eval_output_format_file` or `validate_eval_output_format_str`. Feel free to break the JSON in the cell below and watch the validation fail (e.g., remove a field such as `sae_lens_release_id`).\n",
    "\n",
    "The JSON schema files themselves are generated by `evals/generate_json_schemas.py`. Regenerate them by running:\n",
    "```bash\n",
    "python evals/generate_json_schemas.py\n",
    "```\n",
    "\n",
    "### What if I have unstructured outputs I want to save into the JSON?\n",
    "Put unstructured outputs into the `eval_result_unstructured` field, which accepts data of any type. However, because these values have no structure, they are less likely to support sorting, filtering, or visualization. We strongly encourage you to use `eval_result_metrics` or `eval_result_details` whenever possible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "from sae_bench.evals.absorption.eval_output import AbsorptionEvalOutput\n",
    "from sae_bench.sae_bench_utils.testing_utils import validate_eval_output_format_file\n",
    "\n",
    "eval_results_temp = {\n",
    "    \"eval_type_id\": \"absorption_first_letter\",\n",
    "    \"eval_config\": {\n",
    "        \"random_seed\": 42,\n",
    "        \"f1_jump_threshold\": 0.03,\n",
    "        \"max_k_value\": 10,\n",
    "        \"prompt_template\": \"{word} has the first letter:\",\n",
    "        \"prompt_token_pos\": -6,\n",
    "        \"model_name\": \"pythia-70m-deduped\",\n",
    "    },\n",
    "    \"eval_id\": \"0c057d5e-973e-410e-8e32-32569323b5e6\",\n",
    "    \"datetime_epoch_millis\": 1729834113150,\n",
    "    \"eval_result_metrics\": {\n",
    "        \"mean\": {\n",
    "            \"mean_absorption_score\": 2,\n",
    "            \"mean_num_split_features\": 3.5,\n",
    "        }\n",
    "    },\n",
    "    \"eval_result_details\": [\n",
    "        {\n",
    "            \"first_letter\": \"a\",\n",
    "            \"num_absorption\": 177,\n",
    "            \"absorption_rate\": 0.28780487804878047,\n",
    "            \"num_probe_true_positives\": 615.0,\n",
    "            \"num_split_features\": 1,\n",
    "        },\n",
    "        {\n",
    "            \"first_letter\": \"b\",\n",
    "            \"num_absorption\": 51,\n",
    "            \"absorption_rate\": 0.1650485436893204,\n",
    "            \"num_probe_true_positives\": 309.0,\n",
    "            \"num_split_features\": 1,\n",
    "        },\n",
    "    ],\n",
    "    \"eval_result_unstructured\": {\n",
    "        \"pew pew\": \"pew pew\",\n",
    "        \"bar\": [\"woof\", 1, 3],\n",
    "        3: 3,\n",
    "    },\n",
    "    \"sae_bench_commit_hash\": \"57e9be0ac9199dba6b9f87fe92f80532e9aefced\",\n",
    "    \"sae_lens_id\": \"blocks.3.hook_resid_post__trainer_10\",\n",
    "    \"sae_lens_release_id\": \"sae_bench_pythia70m_sweep_standard_ctx128_0712\",\n",
    "    \"sae_lens_version\": \"4.0.0\",\n",
    "}\n",
    "\n",
    "\n",
    "# save to file\n",
    "with open(\"eval_results_temp.json\", \"w\") as f:\n",
    "    json.dump(eval_results_temp, f)\n",
    "\n",
    "validate_eval_output_format_file(\n",
    "    \"eval_results_temp.json\", eval_output_type=AbsorptionEvalOutput\n",
    ")\n",
    "\n",
    "# delete file\n",
    "os.remove(\"eval_results_temp.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can then load the eval result JSON files across many different SAEs and see exactly which evals were run with which parameters and which version of the code."
   ]
  },
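  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of that workflow, the helper below reads every `*.json` file directly inside one results folder. It keeps the fields that identify the SAE, the code version, the config and the metrics. The helper name and the flat single-folder layout are assumptions for illustration; adapt the glob to however your `output_folder` is organized.\n",
    "\n",
    "```python\n",
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def load_eval_results(results_dir: str) -> list[dict]:\n",
    "    # Hypothetical helper: one row per result file, sorted by filename.\n",
    "    rows = []\n",
    "    for path in sorted(Path(results_dir).glob(\"*.json\")):\n",
    "        result = json.loads(path.read_text())\n",
    "        rows.append(\n",
    "            {\n",
    "                \"eval_type_id\": result[\"eval_type_id\"],\n",
    "                \"sae_lens_release_id\": result[\"sae_lens_release_id\"],\n",
    "                \"sae_lens_id\": result[\"sae_lens_id\"],\n",
    "                \"sae_bench_commit_hash\": result[\"sae_bench_commit_hash\"],\n",
    "                \"eval_config\": result[\"eval_config\"],\n",
    "                \"eval_result_metrics\": result[\"eval_result_metrics\"],\n",
    "            }\n",
    "        )\n",
    "    return rows\n",
    "```"
   ]
  },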
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## A note on cached results\n",
    "\n",
    "A variety of evals can share the results of intermediate computations, such as model activations or trained probes. Most of these are model- and hook-point-specific, so they should be saved under a path of the form `f'{artifact_dir}/{eval_type}/{model}/{hook_point}/{artifact_id}'`.\n",
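    "\n",
    "Here is a minimal sketch of building such a path with `pathlib`. The helper name and the example `artifact_id` (`probe.pt`) are hypothetical; the hook point is taken from the example SAE IDs used earlier in this notebook.\n",
    "\n",
    "```python\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def get_artifact_path(\n",
    "    artifact_dir: str, eval_type: str, model: str, hook_point: str, artifact_id: str\n",
    ") -> Path:\n",
    "    # Same layout as f\"{artifact_dir}/{eval_type}/{model}/{hook_point}/{artifact_id}\".\n",
    "    return Path(artifact_dir) / eval_type / model / hook_point / artifact_id\n",
    "\n",
    "\n",
    "probe_path = get_artifact_path(\n",
    "    \"artifacts\", \"absorption\", \"pythia-70m-deduped\", \"blocks.4.hook_resid_post\", \"probe.pt\"\n",
    ")\n",
    "print(probe_path.as_posix())  # artifacts/absorption/pythia-70m-deduped/blocks.4.hook_resid_post/probe.pt\n",
    "```\n",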
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sae_bench_template",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
