{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To use: move into the repo root directory, then run all cells.\n",
    "# Warning: later cells rewrite and delete result JSON files in place and upload\n",
    "# the local `eval_results` folder to the adamkarvonen/sae_bench_results dataset repo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_lens.toolkit.pretrained_sae_loaders import (\n",
    "    SAEConfigLoadOptions,\n",
    "    get_sae_config,\n",
    ")\n",
    "\n",
    "from sae_bench.sae_bench_utils.sae_selection_utils import (\n",
    "    select_saes_multiple_patterns,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared configuration for the download and post-processing cells below.\n",
    "# The download itself happens in the next cell, which skips the very large\n",
    "# `autointerp_with_generations` files; downloading here as well (without that\n",
    "# ignore pattern) would fetch the entire dataset on every \"Run All\".\n",
    "import os\n",
    "\n",
    "hf_repo_id = \"neuronpedia/sae-evals\"\n",
    "local_dir = \"./temp_sae_evals\"\n",
    "os.makedirs(local_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the SAE Bench eval results into `local_dir`, skipping the raw\n",
    "# autointerp LLM generation dumps (they are very large).\n",
    "import os\n",
    "\n",
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "hf_repo_id = \"neuronpedia/sae-evals\"\n",
    "local_dir = \"./temp_sae_evals\"\n",
    "os.makedirs(local_dir, exist_ok=True)\n",
    "\n",
    "skip_patterns = [\"*autointerp_with_generations*\"]\n",
    "snapshot_download(\n",
    "    repo_id=hf_repo_id,\n",
    "    repo_type=\"dataset\",\n",
    "    local_dir=local_dir,\n",
    "    ignore_patterns=skip_patterns,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SAE releases to post-process; each entry is passed to\n",
    "# select_saes_multiple_patterns as a regex (presumably matched against release names).\n",
    "gemma_sae_bench_patterns = [\n",
    "    rf\"sae_bench_gemma-2-2b_{arch}_width-2pow{width}_date-1109\"\n",
    "    for arch in (\"topk\", \"vanilla\")\n",
    "    for width in (12, 14, 16)\n",
    "]\n",
    "gemma_scope_patterns = [\n",
    "    r\"(gemma-scope-2b-pt-res)\",\n",
    "    r\"(gemma-scope-9b-pt-res)\",\n",
    "    r\"(gemma-scope-2b-pt-res-canonical)\",\n",
    "    r\"(gemma-scope-9b-pt-res-canonical)\",\n",
    "]\n",
    "pythia_sae_bench_patterns = [\n",
    "    r\"sae_bench_pythia70m_sweep_gated_ctx128_0730\",\n",
    "    r\"sae_bench_pythia70m_sweep_panneal_ctx128_0730\",\n",
    "    r\"sae_bench_pythia70m_sweep_standard_ctx128_0712\",\n",
    "    r\"sae_bench_pythia70m_sweep_topk_ctx128_0730\",\n",
    "]\n",
    "sae_regex_patterns = gemma_sae_bench_patterns + gemma_scope_patterns + pythia_sae_bench_patterns\n",
    "\n",
    "# Include checkpoints (not relevant to Gemma-Scope): a match-all block pattern per release.\n",
    "sae_block_pattern = [\".*\" for _ in sae_regex_patterns]\n",
    "\n",
    "# NOTE(review): if re-enabled, the list below needs one entry per release pattern\n",
    "# above, otherwise the length assert fails.\n",
    "# Exclude checkpoints\n",
    "# sae_block_pattern = [\n",
    "#     # rf\".*blocks\\.{layer}(?!.*step).*\",\n",
    "#     # rf\".*blocks\\.{layer}(?!.*step).*\",\n",
    "#     rf\".*blocks\\.{layer}(?!.*step).*\",\n",
    "#     rf\".*blocks\\.{layer}(?!.*step).*\",\n",
    "#     # rf\".*layer_({layer}).*(16k).*\", # For Gemma-Scope\n",
    "# ]\n",
    "\n",
    "assert len(sae_regex_patterns) == len(sae_block_pattern)\n",
    "\n",
    "selected_saes = select_saes_multiple_patterns(sae_regex_patterns, sae_block_pattern)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Core results: for SAE ids containing \"blocks\", rename\n",
    "# `<release>_<sae_id with '.' -> '_'>_128_Skylion007_openwebtext.json` to use the raw\n",
    "# sae_id, attach the SAE config under \"sae_cfg_dict\", then delete the old file.\n",
    "import json\n",
    "import os\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "folders = [f\"{local_dir}/core\"]\n",
    "\n",
    "total = 0\n",
    "total_updated = 0\n",
    "\n",
    "for folder in tqdm(folders):\n",
    "    for sae_release, sae_id in tqdm(selected_saes):\n",
    "        total += 1\n",
    "\n",
    "        if \"blocks\" not in sae_id:\n",
    "            continue\n",
    "\n",
    "        old_filename = f\"{folder}/{sae_release}/{sae_release}_{sae_id.replace('.', '_')}_128_Skylion007_openwebtext.json\"\n",
    "        new_filename = f\"{folder}/{sae_release}/{sae_release}_{sae_id}_128_Skylion007_openwebtext.json\"\n",
    "\n",
    "        if not os.path.exists(old_filename):\n",
    "            continue\n",
    "\n",
    "        sae_cfg = get_sae_config(\n",
    "            sae_release,\n",
    "            sae_id,\n",
    "            options=SAEConfigLoadOptions(),\n",
    "        )\n",
    "\n",
    "        with open(old_filename) as f:\n",
    "            eval_results = json.load(f)\n",
    "\n",
    "        eval_results[\"sae_cfg_dict\"] = sae_cfg\n",
    "\n",
    "        with open(new_filename, \"w\") as f:\n",
    "            json.dump(eval_results, f, indent=4)\n",
    "\n",
    "        # If sae_id contains no '.', both names are identical; removing the \"old\"\n",
    "        # file would then delete the file that was just written.\n",
    "        if old_filename != new_filename:\n",
    "            os.remove(old_filename)\n",
    "\n",
    "        total_updated += 1\n",
    "\n",
    "print(total, total_updated)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Non-core results: attach the SAE config under \"sae_cfg_dict\", rewriting each\n",
    "# `<release>_<sae_id with '/' -> '_'>_eval_results.json` file in place.\n",
    "import json\n",
    "import os\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "eval_types = [\"absorption\", \"autointerp\", \"scr\", \"sparse_probing\", \"tpp\", \"unlearning\"]\n",
    "folders = [f\"{local_dir}/{eval_type}\" for eval_type in eval_types]\n",
    "\n",
    "total = 0\n",
    "total_updated = 0\n",
    "\n",
    "for folder in tqdm(folders):\n",
    "    for sae_release, sae_id in tqdm(selected_saes):\n",
    "        total += 1\n",
    "        flat_id = sae_id.replace(\"/\", \"_\")\n",
    "        results_path = f\"{folder}/{sae_release}/{sae_release}_{flat_id}_eval_results.json\"\n",
    "\n",
    "        if os.path.exists(results_path):\n",
    "            cfg = get_sae_config(\n",
    "                sae_release,\n",
    "                sae_id,\n",
    "                options=SAEConfigLoadOptions(),\n",
    "            )\n",
    "            with open(results_path) as f:\n",
    "                eval_results = json.load(f)\n",
    "            eval_results[\"sae_cfg_dict\"] = cfg\n",
    "            with open(results_path, \"w\") as f:\n",
    "                json.dump(eval_results, f, indent=4)\n",
    "            total_updated += 1\n",
    "\n",
    "print(total, total_updated)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "\n",
    "def inspect_local_directory(directory, chunk_size=1 << 20):\n",
    "    \"\"\"Walk `directory` recursively and print any file that cannot be fully read.\n",
    "\n",
    "    Each file is read end to end in `chunk_size`-byte chunks, so read errors surface\n",
    "    without holding whole files in memory. Problems are printed, not raised.\n",
    "    \"\"\"\n",
    "    if not os.path.exists(directory):\n",
    "        print(f\"Directory does not exist: {directory}\")\n",
    "        return\n",
    "\n",
    "    if not os.path.isdir(directory):\n",
    "        print(f\"Path is not a directory: {directory}\")\n",
    "        return\n",
    "\n",
    "    print(f\"Inspecting files in directory: {directory}\")\n",
    "    for root, _dirs, files in os.walk(directory):\n",
    "        for file in files:\n",
    "            file_path = os.path.join(root, file)\n",
    "            try:\n",
    "                # Read the whole file in bounded chunks to surface I/O errors.\n",
    "                with open(file_path, \"rb\") as f:\n",
    "                    while f.read(chunk_size):\n",
    "                        pass\n",
    "            except Exception as e:\n",
    "                print(f\"Error accessing file {file_path}: {e}\")\n",
    "\n",
    "\n",
    "inspect_local_directory(local_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: earlier upload of the local `temp_sae_evals_2` folder to the\n",
    "# adamkarvonen/sae_bench_results dataset repo. Kept for reference; uncomment to reuse.\n",
    "# from huggingface_hub import HfApi\n",
    "# api = HfApi()\n",
    "\n",
    "# api.upload_folder(\n",
    "#     folder_path=\"temp_sae_evals_2\",\n",
    "#     path_in_repo=\"\",\n",
    "#     repo_id=\"adamkarvonen/sae_bench_results\",\n",
    "#     repo_type=\"dataset\",\n",
    "#     # allow_patterns=\"*eval_results.json\"\n",
    "#     ignore_patterns=[\".DS_Store\"]\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload the local `eval_results` folder to the root of the\n",
    "# adamkarvonen/sae_bench_results dataset repo on the Hugging Face Hub.\n",
    "# NOTE(review): this runs under \"Run All\" and needs write access to that repo.\n",
    "from huggingface_hub import HfApi\n",
    "\n",
    "api = HfApi()\n",
    "\n",
    "test_dir = \"eval_results\"\n",
    "\n",
    "api.upload_folder(\n",
    "    folder_path=test_dir,\n",
    "    path_in_repo=\"\",\n",
    "    repo_id=\"adamkarvonen/sae_bench_results\",\n",
    "    repo_type=\"dataset\",\n",
    "    # allow_patterns=\"*eval_results.json\"\n",
    "    # Patterns are fnmatch-style on repo-relative paths: \".DS_Store\" alone only\n",
    "    # matches the top-level file, so \"*/.DS_Store\" is needed for subfolders.\n",
    "    ignore_patterns=[\".DS_Store\", \"*/.DS_Store\", \".git\", \".git/**\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# purpose: all files are dumped into the same results folder.\n",
    "# separate by sae release name\n",
    "\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "\n",
    "def organize_files(directory=\".\", prefixes=None):\n",
    "    \"\"\"Move top-level `*.json` files in `directory` into one subfolder per release prefix.\n",
    "\n",
    "    Each file goes to the first prefix it starts with; files matching no prefix stay\n",
    "    where they are. Subfolders are named after the prefix and created if missing.\n",
    "\n",
    "    Args:\n",
    "        directory: Folder containing the flat result files (default: current dir).\n",
    "        prefixes: Release-name prefixes; defaults to the two 2pow16 SAE Bench sweeps.\n",
    "    \"\"\"\n",
    "    if prefixes is None:\n",
    "        prefixes = [\n",
    "            \"sae_bench_gemma-2-2b_topk_width-2pow16_date-1109\",\n",
    "            \"sae_bench_gemma-2-2b_vanilla_width-2pow16_date-1109\",\n",
    "        ]\n",
    "\n",
    "    for prefix in prefixes:\n",
    "        os.makedirs(os.path.join(directory, prefix), exist_ok=True)\n",
    "\n",
    "    json_files = [\n",
    "        name\n",
    "        for name in os.listdir(directory)\n",
    "        if name.endswith(\".json\") and os.path.isfile(os.path.join(directory, name))\n",
    "    ]\n",
    "\n",
    "    for name in json_files:\n",
    "        for prefix in prefixes:\n",
    "            if name.startswith(prefix):\n",
    "                shutil.move(os.path.join(directory, name), os.path.join(directory, prefix, name))\n",
    "                print(f\"Moved {name} to {prefix}/\")\n",
    "                break\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    organize_files()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# purpose: remove the llm generations, which is over 99% of the file size\n",
    "\n",
    "import json\n",
    "import os\n",
    "\n",
    "\n",
    "def process_json_files(directory):\n",
    "    \"\"\"Recursively strip the \"eval_result_unstructured\" key from JSON files under `directory`.\n",
    "\n",
    "    Only files whose top-level object contains the key are rewritten (compactly,\n",
    "    without indentation); every other JSON file is left untouched. Per-file errors\n",
    "    are printed and skipped.\n",
    "\n",
    "    Returns:\n",
    "        The number of files modified.\n",
    "    \"\"\"\n",
    "    count = 0\n",
    "    for root, _, files in os.walk(directory):\n",
    "        for file in files:\n",
    "            if not file.endswith(\".json\"):\n",
    "                continue\n",
    "            filepath = os.path.join(root, file)\n",
    "            try:\n",
    "                with open(filepath) as f:\n",
    "                    data = json.load(f)\n",
    "\n",
    "                # Rewriting files without the key would needlessly reformat them (and,\n",
    "                # when run from the repo root, every JSON file in the tree), so skip them.\n",
    "                if not isinstance(data, dict) or \"eval_result_unstructured\" not in data:\n",
    "                    continue\n",
    "\n",
    "                del data[\"eval_result_unstructured\"]\n",
    "\n",
    "                with open(filepath, \"w\") as f:\n",
    "                    json.dump(data, f)\n",
    "\n",
    "                count += 1\n",
    "                print(f\"Processed: {filepath}\")\n",
    "\n",
    "            except Exception as e:\n",
    "                print(f\"Error processing {filepath}: {e}\")\n",
    "\n",
    "    return count\n",
    "\n",
    "\n",
    "# Process files starting from the current directory\n",
    "starting_dir = \".\"\n",
    "files_modified = process_json_files(starting_dir)\n",
    "print(f\"\\nCompleted! Modified {files_modified} files.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This cell is to add \"training_tokens\" to the sae_cfg_dict for all sae bench sae results\n",
    "# found under `local_dir` (as set by the download cells above).\n",
    "\n",
    "import json\n",
    "import os\n",
    "import re\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "from sae_bench.sae_bench_utils.sae_selection_utils import select_saes_multiple_patterns\n",
    "\n",
    "folders = [\n",
    "    f\"{local_dir}/absorption\",\n",
    "    f\"{local_dir}/autointerp\",\n",
    "    # NOTE(review): an earlier cell names core results `*_128_Skylion007_openwebtext.json`,\n",
    "    # so the `*_eval_results.json` path built below may never match in this folder.\n",
    "    f\"{local_dir}/core\",\n",
    "    f\"{local_dir}/scr\",\n",
    "    f\"{local_dir}/sparse_probing\",\n",
    "    f\"{local_dir}/tpp\",\n",
    "    f\"{local_dir}/unlearning\",\n",
    "]\n",
    "\n",
    "sae_regex_patterns = [\n",
    "    r\"sae_bench_gemma-2-2b_topk_width-2pow12_date-1109\",\n",
    "    r\"sae_bench_gemma-2-2b_topk_width-2pow14_date-1109\",\n",
    "    r\"sae_bench_gemma-2-2b_topk_width-2pow16_date-1109\",\n",
    "    r\"sae_bench_gemma-2-2b_vanilla_width-2pow12_date-1109\",\n",
    "    r\"sae_bench_gemma-2-2b_vanilla_width-2pow14_date-1109\",\n",
    "    r\"sae_bench_gemma-2-2b_vanilla_width-2pow16_date-1109\",\n",
    "    r\"sae_bench_pythia70m_sweep_gated_ctx128_0730\",\n",
    "    r\"sae_bench_pythia70m_sweep_panneal_ctx128_0730\",\n",
    "    r\"sae_bench_pythia70m_sweep_standard_ctx128_0712\",\n",
    "    r\"sae_bench_pythia70m_sweep_topk_ctx128_0730\",\n",
    "]\n",
    "\n",
    "# Include checkpoints (not relevant to Gemma-Scope)\n",
    "sae_block_pattern = [\".*\"] * len(sae_regex_patterns)\n",
    "\n",
    "# Exclude checkpoints\n",
    "# sae_block_pattern = [\n",
    "#     # rf\".*blocks\\.{layer}(?!.*step).*\",\n",
    "#     # rf\".*blocks\\.{layer}(?!.*step).*\",\n",
    "#     rf\".*blocks\\.{layer}(?!.*step).*\",\n",
    "#     rf\".*blocks\\.{layer}(?!.*step).*\",\n",
    "#     # rf\".*layer_({layer}).*(16k).*\", # For Gemma-Scope\n",
    "# ]\n",
    "\n",
    "\n",
    "assert len(sae_regex_patterns) == len(sae_block_pattern)\n",
    "\n",
    "selected_saes = select_saes_multiple_patterns(sae_regex_patterns, sae_block_pattern)\n",
    "\n",
    "\n",
    "def get_sae_bench_train_tokens(sae_release: str, sae_id: str) -> int:\n",
    "    \"\"\"Return the number of training tokens for an SAE Bench SAE.\n",
    "\n",
    "    This is for SAE Bench internal use. The SAE cfg does not contain the number of\n",
    "    training tokens, so it is reconstructed as steps * batch_size, using hardcoded\n",
    "    per-release step counts, or the `step_<n>` component of checkpoint ids.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the release is not an SAE Bench release, its width is not\n",
    "            recognized, or a checkpoint id has no `step_<n>` component.\n",
    "    \"\"\"\n",
    "    if \"sae_bench\" not in sae_release:\n",
    "        raise ValueError(\"This function is only for SAE Bench releases\")\n",
    "\n",
    "    batch_size = 4096 if \"pythia\" in sae_release else 2048\n",
    "\n",
    "    if \"step\" in sae_id:\n",
    "        # Intermediate checkpoint: tokens seen so far = checkpoint step * batch size.\n",
    "        match = re.search(r\"step_(\\d+)\", sae_id)\n",
    "        if match is None:\n",
    "            raise ValueError(\"No step match found\")\n",
    "        return int(match.group(1)) * batch_size\n",
    "\n",
    "    # Each membership test is spelled out: `\"2pow12\" or \"2pow16\" in s` is always\n",
    "    # truthy and would make the error branch unreachable.\n",
    "    if \"pythia\" in sae_release:\n",
    "        steps = 48828\n",
    "    elif \"2pow14\" in sae_release:\n",
    "        steps = 146484\n",
    "    elif \"2pow12\" in sae_release or \"2pow16\" in sae_release:\n",
    "        steps = 97656\n",
    "    else:\n",
    "        raise ValueError(f\"sae release {sae_release} not recognized\")\n",
    "\n",
    "    return steps * batch_size\n",
    "\n",
    "\n",
    "total = 0\n",
    "total_updated = 0\n",
    "\n",
    "for folder in tqdm(folders):\n",
    "    for sae_release, sae_id in tqdm(selected_saes):\n",
    "        total += 1\n",
    "        old_filename = f\"{folder}/{sae_release}/{sae_release}_{sae_id.replace('/', '_')}_eval_results.json\"\n",
    "\n",
    "        if not os.path.exists(old_filename):\n",
    "            continue\n",
    "\n",
    "        with open(old_filename) as f:\n",
    "            eval_results = json.load(f)\n",
    "\n",
    "        # Expects \"sae_cfg_dict\" to have been attached by an earlier cell; a KeyError\n",
    "        # here means that step was skipped for this file.\n",
    "        eval_results[\"sae_cfg_dict\"][\"training_tokens\"] = get_sae_bench_train_tokens(\n",
    "            sae_release, sae_id\n",
    "        )\n",
    "\n",
    "        with open(old_filename, \"w\") as f:\n",
    "            json.dump(eval_results, f, indent=4)\n",
    "\n",
    "        total_updated += 1\n",
    "\n",
    "print(total, total_updated)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
