{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c1c45b7a",
   "metadata": {},
   "source": [
    "## Gemma 2 2b L12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f258af5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemma-sae_trained_saes__google_gemma-2-2b_jump_relu_batch_top_k_standard_new_resid_post_layer_12_trainer_2_custom_sae_eval_results.json → avg_score = 0.7910\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/standard_april_update_309_0.7910.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_gated_top_k_resid_post_layer_12_trainer_2_custom_sae_eval_results.json → avg_score = 0.7676\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/gated_340_0.7676.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_gated_top_k_resid_post_layer_12_trainer_0_custom_sae_eval_results.json → avg_score = 0.7593\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/gated_948_0.7593.csv\n",
      "\n",
      "gemma-sae_trained_saes_2__google_gemma-2-2b_batch_top_k_resid_post_layer_12_trainer_0_custom_sae_eval_results.json → avg_score = 0.8407\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/batch_topk_50_0.8407.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_jump_relu_batch_top_k_standard_new_resid_post_layer_12_trainer_7_custom_sae_eval_results.json → avg_score = 0.8073\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/batch_topk_160_0.8073.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_jump_relu_batch_top_k_standard_new_resid_post_layer_12_trainer_4_custom_sae_eval_results.json → avg_score = 0.8147\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/standard_april_update_99_0.8147.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_jump_relu_batch_top_k_standard_new_resid_post_layer_12_trainer_5_custom_sae_eval_results.json → avg_score = 0.8291\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/standard_april_update_54_0.8291.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_jump_relu_batch_top_k_standard_new_resid_post_layer_12_trainer_3_custom_sae_eval_results.json → avg_score = 0.8024\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/standard_april_update_156_0.8024.csv\n",
      "\n",
      "gemma-sae_trained_saes_2__google_gemma-2-2b_top_k_jump_relu_resid_post_layer_12_trainer_3_custom_sae_eval_results.json → avg_score = 0.8151\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/topk_820_0.8151.csv\n",
      "\n",
      "gemma-sae_trained_saes_2__google_gemma-2-2b_top_k_jump_relu_resid_post_layer_12_trainer_0_custom_sae_eval_results.json → avg_score = 0.8417\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/topk_50_0.8417.csv\n",
      "\n",
      "gemma-sae_trained_saes_2__google_gemma-2-2b_top_k_jump_relu_resid_post_layer_12_trainer_4_custom_sae_eval_results.json → avg_score = 0.8257\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/jumprelu_52_0.8257.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_gated_top_k_resid_post_layer_12_trainer_1_custom_sae_eval_results.json → avg_score = 0.7586\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/gated_547_0.7586.csv\n",
      "\n",
      "gemma-sae_trained_saes_2__google_gemma-2-2b_batch_top_k_resid_post_layer_12_trainer_1_custom_sae_eval_results.json → avg_score = 0.8146\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/batch_topk_320_0.8146.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_gated_top_k_resid_post_layer_12_trainer_7_custom_sae_eval_results.json → avg_score = 0.8046\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/topk_160_0.8046.csv\n",
      "\n",
      "gemma-sae_trained_saes_2__google_gemma-2-2b_top_k_jump_relu_resid_post_layer_12_trainer_5_custom_sae_eval_results.json → avg_score = 0.7906\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/jumprelu_330_0.7906.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_gated_top_k_resid_post_layer_12_trainer_6_custom_sae_eval_results.json → avg_score = 0.8253\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/topk_80_0.8253.csv\n",
      "\n",
      "gemma-sae_trained_saes_2__google_gemma-2-2b_top_k_jump_relu_resid_post_layer_12_trainer_2_custom_sae_eval_results.json → avg_score = 0.8254\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/topk_520_0.8254.csv\n",
      "\n",
      "gemma-sae_trained_saes_2__google_gemma-2-2b_batch_top_k_resid_post_layer_12_trainer_2_custom_sae_eval_results.json → avg_score = 0.7991\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/batch_topk_520_0.7991.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_gated_top_k_resid_post_layer_12_trainer_3_custom_sae_eval_results.json → avg_score = 0.7844\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/gated_148_0.7844.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_gated_top_k_resid_post_layer_12_trainer_4_custom_sae_eval_results.json → avg_score = 0.8043\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/gated_78_0.8043.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_jump_relu_batch_top_k_standard_new_resid_post_layer_12_trainer_1_custom_sae_eval_results.json → avg_score = 0.7789\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/standard_april_update_507_0.7789.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_gated_top_k_resid_post_layer_12_trainer_5_custom_sae_eval_results.json → avg_score = 0.8117\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/gated_49_0.8117.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_jump_relu_batch_top_k_standard_new_resid_post_layer_12_trainer_0_custom_sae_eval_results.json → avg_score = 0.7629\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/standard_april_update_733_0.7629.csv\n",
      "\n",
      "gemma-sae_trained_saes_2__google_gemma-2-2b_top_k_jump_relu_resid_post_layer_12_trainer_1_custom_sae_eval_results.json → avg_score = 0.7997\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/topk_320_0.7997.csv\n",
      "\n",
      "gemma-sae_trained_saes_2__google_gemma-2-2b_top_k_jump_relu_resid_post_layer_12_trainer_6_custom_sae_eval_results.json → avg_score = 0.8060\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/jumprelu_538_0.8060.csv\n",
      "\n",
      "gemma-sae_trained_saes_2__google_gemma-2-2b_batch_top_k_resid_post_layer_12_trainer_3_custom_sae_eval_results.json → avg_score = 0.7877\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/batch_topk_820_0.7877.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_jump_relu_batch_top_k_standard_new_resid_post_layer_12_trainer_6_custom_sae_eval_results.json → avg_score = 0.8223\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/batch_topk_80_0.8223.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_jump_relu_batch_top_k_standard_new_resid_post_layer_12_trainer_9_custom_sae_eval_results.json → avg_score = 0.7974\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/jumprelu_165_0.7974.csv\n",
      "\n",
      "gemma-sae_trained_saes__google_gemma-2-2b_jump_relu_batch_top_k_standard_new_resid_post_layer_12_trainer_8_custom_sae_eval_results.json → avg_score = 0.8134\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/jumprelu_83_0.8134.csv\n",
      "\n",
      "gemma-sae_trained_saes_2__google_gemma-2-2b_top_k_jump_relu_resid_post_layer_12_trainer_7_custom_sae_eval_results.json → avg_score = 0.7939\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept500_gemma2_2b_data/jumprelu_779_0.7939.csv\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "import re\n",
    "import random\n",
    "import csv\n",
    "\n",
    "# For every autointerp eval-result JSON in `input_dir`, write a header-less CSV of\n",
    "# (explanation, \"<ae.pt path>/<latent>\") rows for a random sample of latents.\n",
    "# Each CSV is named \"<architecture>_<id>_<avg_score>.csv\", where <id> is looked up in `mapping`.\n",
    "\n",
    "# 1. Mapping from architecture+trainer to ID\n",
    "mapping = {\n",
    "    \"batch_topk_trainer0\": 50,\n",
    "    \"batch_topk_trainer1\": 320,\n",
    "    \"batch_topk_trainer2\": 520,\n",
    "    \"batch_topk_trainer3\": 820,\n",
    "    \"batch_topk_trainer6\": 80,\n",
    "    \"batch_topk_trainer7\": 160,\n",
    "    \"gated_trainer0\": 948,\n",
    "    \"gated_trainer1\": 547,\n",
    "    \"gated_trainer2\": 340,\n",
    "    \"gated_trainer3\": 148,\n",
    "    \"gated_trainer4\": 78,\n",
    "    \"gated_trainer5\": 49,\n",
    "    \"jumprelu_trainer4\": 52,\n",
    "    \"jumprelu_trainer5\": 330,\n",
    "    \"jumprelu_trainer6\": 538,\n",
    "    \"jumprelu_trainer7\": 779,\n",
    "    \"jumprelu_trainer8\": 83,\n",
    "    \"jumprelu_trainer9\": 165,\n",
    "    \"standard_april_update_trainer0\": 733,\n",
    "    \"standard_april_update_trainer1\": 507,\n",
    "    \"standard_april_update_trainer2\": 309,\n",
    "    \"standard_april_update_trainer3\": 156,\n",
    "    \"standard_april_update_trainer4\": 99,\n",
    "    \"standard_april_update_trainer5\": 54,\n",
    "    \"topk_trainer0\": 50,\n",
    "    \"topk_trainer1\": 320,\n",
    "    \"topk_trainer2\": 520,\n",
    "    \"topk_trainer3\": 820,\n",
    "    \"topk_trainer6\": 80,\n",
    "    \"topk_trainer7\": 160,\n",
    "}\n",
    "\n",
    "# 2. Paths configuration\n",
    "input_dir = \"/home/dslabra5/sae4steer/SAEBench/sae_bench/custom_saes/eval_results/gemma2_2b/autointerp\"\n",
    "downloaded_root = \"/home/dslabra5/sae4steer/SAEBench/sae_bench/custom_saes/downloaded_saes\"\n",
    "output_dir = \"/home/dslabra5/sae4steer/axbench/axbench/concept100_gemma2_2b_data\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# 3. Sampling configuration\n",
    "N_LATENTS = 100  # maximum number of latents exported per SAE\n",
    "SEED = 42\n",
    "\n",
    "# 4. Iterate through JSON files in the input directory.\n",
    "# os.listdir order is filesystem-dependent, so sort it for a reproducible run order.\n",
    "for fname in sorted(os.listdir(input_dir)):\n",
    "    if not fname.endswith(\".json\"):\n",
    "        continue\n",
    "\n",
    "    json_path = os.path.join(input_dir, fname)\n",
    "\n",
    "    # (1) Extract trainer number from the filename\n",
    "    m = re.search(r\"trainer_(\\d+)\", fname)\n",
    "    if not m:\n",
    "        print(f\"Skipping {fname}: trainer number not found\")\n",
    "        continue\n",
    "    trainer_num = m.group(1)\n",
    "\n",
    "    # (2) Read the JSON file; the architecture name comes from its SAE config\n",
    "    with open(json_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    arch = data.get(\"sae_cfg_dict\", {}).get(\"architecture\")\n",
    "    if arch is None:\n",
    "        print(f\"Skipping {fname}: missing architecture\")\n",
    "        continue\n",
    "\n",
    "    key = f\"{arch}_trainer{trainer_num}\"\n",
    "    id_val = mapping.get(key)\n",
    "    if id_val is None:\n",
    "        print(f\"Skipping {fname}: key '{key}' not in mapping\")\n",
    "        continue\n",
    "\n",
    "    # (3) Locate ae.pt based on sae_lens_release_id:\n",
    "    #     <downloaded_root>/<loc_root>/resid_post_layer_<L>/trainer_<N>/ae.pt,\n",
    "    #     where loc_root is the release id after its first \"_\" and before \"_resid_post_layer_\"\n",
    "    release_id = data.get(\"sae_lens_release_id\", \"\")\n",
    "    if \"_\" not in release_id:\n",
    "        print(f\"Skipping {fname}: release_id format incorrect\")\n",
    "        continue\n",
    "    loc_rest = release_id.split(\"_\", 1)[1]\n",
    "    if \"_resid_post_layer_\" not in loc_rest:\n",
    "        print(f\"Skipping {fname}: '_resid_post_layer_' not found in loc_rest\")\n",
    "        continue\n",
    "    loc_root, rest = loc_rest.split(\"_resid_post_layer_\", 1)\n",
    "    hook_layer_str = rest.split(\"_trainer_\")[0]\n",
    "    hook_folder = f\"resid_post_layer_{hook_layer_str}\"\n",
    "    trainer_folder = f\"trainer_{trainer_num}\"\n",
    "    ae_path = os.path.join(\n",
    "        downloaded_root,\n",
    "        loc_root,\n",
    "        hook_folder,\n",
    "        trainer_folder,\n",
    "        \"ae.pt\",\n",
    "    )\n",
    "\n",
    "    # (4) Randomly select up to N_LATENTS latent entries. The RNG is seeded per file\n",
    "    # (from SEED and the filename), so each file's sample does not depend on which\n",
    "    # other files exist in input_dir or on the order they are processed in.\n",
    "    unstruct = data.get(\"eval_result_unstructured\", {})\n",
    "    items = list(unstruct.items())  # [(latent, detail), ...]\n",
    "    rng = random.Random(f\"{SEED}:{fname}\")\n",
    "    chosen = items if len(items) <= N_LATENTS else rng.sample(items, N_LATENTS)\n",
    "\n",
    "    # (5) Compute and print the average score (latents without a \"score\" count as 0.0)\n",
    "    scores = [detail.get(\"score\", 0.0) for _, detail in chosen]\n",
    "    avg_score = sum(scores) / len(scores) if scores else 0.0\n",
    "    print(f\"{fname} → avg_score = {avg_score:.4f}\")\n",
    "\n",
    "    # (6) Write to CSV (no header), appending latent ID to ae.pt path\n",
    "    csv_name = f\"{arch}_{id_val}_{avg_score:.4f}.csv\"\n",
    "    csv_path = os.path.join(output_dir, csv_name)\n",
    "\n",
    "    with open(csv_path, \"w\", newline=\"\", encoding=\"utf-8\") as csvfile:\n",
    "        writer = csv.writer(csvfile)\n",
    "        for latent, detail in chosen:\n",
    "            explanation = detail.get(\"explanation\", \"\")\n",
    "            # append latent ID to the ae.pt path\n",
    "            writer.writerow([explanation, f\"{ae_path}/{latent}\"])\n",
    "\n",
    "    print(f\"Written to {csv_path}\\n\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e056156",
   "metadata": {},
   "source": [
    "## Gemma 2 2b L20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e70d4f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import re\n",
    "import random\n",
    "import csv\n",
    "\n",
    "# For every autointerp eval-result JSON in `input_dir`, write a header-less CSV of\n",
    "# (explanation, \"<ae.pt path>/<latent>\") rows for a random sample of latents.\n",
    "# Each CSV is named \"<architecture>_<id>_<avg_score>.csv\", where <id> is looked up in `mapping`.\n",
    "\n",
    "# 1. Mapping from architecture+trainer to ID\n",
    "mapping = {\n",
    "    # gated\n",
    "    \"gated_trainer0\": 653,\n",
    "    \"gated_trainer1\": 338,\n",
    "    \"gated_trainer2\": 225,\n",
    "    \"gated_trainer3\": 113,\n",
    "    \"gated_trainer4\": 66,\n",
    "    \"gated_trainer5\": 45,\n",
    "\n",
    "    # jumprelu\n",
    "    \"jumprelu_trainer12\": 50,\n",
    "    \"jumprelu_trainer13\": 81,\n",
    "    \"jumprelu_trainer14\": 163,\n",
    "    \"jumprelu_trainer15\": 325,\n",
    "    \"jumprelu_trainer16\": 505,\n",
    "    \"jumprelu_trainer17\": 778,\n",
    "\n",
    "    # standard_april_update\n",
    "    \"standard_april_update_trainer0\": 460,\n",
    "    \"standard_april_update_trainer1\": 320,\n",
    "    \"standard_april_update_trainer2\": 208,\n",
    "    \"standard_april_update_trainer3\": 117,\n",
    "    \"standard_april_update_trainer4\": 79,\n",
    "    \"standard_april_update_trainer5\": 47,\n",
    "\n",
    "    # batch_topk\n",
    "    \"batch_topk_trainer6\": 50,\n",
    "    \"batch_topk_trainer7\": 80,\n",
    "    \"batch_topk_trainer8\": 160,\n",
    "    \"batch_topk_trainer9\": 320,\n",
    "    \"batch_topk_trainer10\": 520,\n",
    "    \"batch_topk_trainer11\": 820,\n",
    "\n",
    "    # topk\n",
    "    \"topk_trainer6\": 50,\n",
    "    \"topk_trainer7\": 80,\n",
    "    \"topk_trainer8\": 160,\n",
    "    \"topk_trainer9\": 320,\n",
    "    \"topk_trainer10\": 520,\n",
    "    \"topk_trainer11\": 820,\n",
    "}\n",
    "\n",
    "# 2. Paths configuration\n",
    "input_dir = \"/home/dslabra5/sae4steer/SAEBench/sae_bench/custom_saes/eval_results/gemma2_2b/l20/autointerp\"\n",
    "downloaded_root = \"/home/dslabra5/sae4steer/SAEBench/sae_bench/custom_saes/downloaded_saes\"\n",
    "# Gemma layer-20 CSVs get their own directory so they are not mixed with (or\n",
    "# overwrite) the layer-12 Gemma or Qwen exports, which reuse some IDs (e.g. 338, 320).\n",
    "output_dir = \"/home/dslabra5/sae4steer/axbench/axbench/concept100_gemma2_2b_l20_data\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# 3. Sampling configuration\n",
    "N_LATENTS = 100  # maximum number of latents exported per SAE\n",
    "SEED = 42\n",
    "\n",
    "# 4. Iterate through JSON files in the input directory.\n",
    "# os.listdir order is filesystem-dependent, so sort it for a reproducible run order.\n",
    "for fname in sorted(os.listdir(input_dir)):\n",
    "    if not fname.endswith(\".json\"):\n",
    "        continue\n",
    "\n",
    "    json_path = os.path.join(input_dir, fname)\n",
    "\n",
    "    # (1) Extract trainer number from the filename\n",
    "    m = re.search(r\"trainer_(\\d+)\", fname)\n",
    "    if not m:\n",
    "        print(f\"Skipping {fname}: trainer number not found\")\n",
    "        continue\n",
    "    trainer_num = m.group(1)\n",
    "\n",
    "    # (2) Read the JSON file; the architecture name comes from its SAE config\n",
    "    with open(json_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    arch = data.get(\"sae_cfg_dict\", {}).get(\"architecture\")\n",
    "    if arch is None:\n",
    "        print(f\"Skipping {fname}: missing architecture\")\n",
    "        continue\n",
    "\n",
    "    key = f\"{arch}_trainer{trainer_num}\"\n",
    "    id_val = mapping.get(key)\n",
    "    if id_val is None:\n",
    "        print(f\"Skipping {fname}: key '{key}' not in mapping\")\n",
    "        continue\n",
    "\n",
    "    # (3) Locate ae.pt based on sae_lens_release_id:\n",
    "    #     <downloaded_root>/<loc_root>/resid_post_layer_<L>/trainer_<N>/ae.pt,\n",
    "    #     where loc_root is the release id after its first \"_\" and before \"_resid_post_layer_\"\n",
    "    release_id = data.get(\"sae_lens_release_id\", \"\")\n",
    "    if \"_\" not in release_id:\n",
    "        print(f\"Skipping {fname}: release_id format incorrect\")\n",
    "        continue\n",
    "    loc_rest = release_id.split(\"_\", 1)[1]\n",
    "    if \"_resid_post_layer_\" not in loc_rest:\n",
    "        print(f\"Skipping {fname}: '_resid_post_layer_' not found in loc_rest\")\n",
    "        continue\n",
    "    loc_root, rest = loc_rest.split(\"_resid_post_layer_\", 1)\n",
    "    hook_layer_str = rest.split(\"_trainer_\")[0]\n",
    "    hook_folder = f\"resid_post_layer_{hook_layer_str}\"\n",
    "    trainer_folder = f\"trainer_{trainer_num}\"\n",
    "    ae_path = os.path.join(\n",
    "        downloaded_root,\n",
    "        loc_root,\n",
    "        hook_folder,\n",
    "        trainer_folder,\n",
    "        \"ae.pt\",\n",
    "    )\n",
    "\n",
    "    # (4) Randomly select up to N_LATENTS latent entries. The RNG is seeded per file\n",
    "    # (from SEED and the filename), so each file's sample does not depend on which\n",
    "    # other files exist in input_dir or on the order they are processed in.\n",
    "    unstruct = data.get(\"eval_result_unstructured\", {})\n",
    "    items = list(unstruct.items())  # [(latent, detail), ...]\n",
    "    rng = random.Random(f\"{SEED}:{fname}\")\n",
    "    chosen = items if len(items) <= N_LATENTS else rng.sample(items, N_LATENTS)\n",
    "\n",
    "    # (5) Compute and print the average score (latents without a \"score\" count as 0.0)\n",
    "    scores = [detail.get(\"score\", 0.0) for _, detail in chosen]\n",
    "    avg_score = sum(scores) / len(scores) if scores else 0.0\n",
    "    print(f\"{fname} → avg_score = {avg_score:.4f}\")\n",
    "\n",
    "    # (6) Write to CSV (no header), appending latent ID to ae.pt path\n",
    "    csv_name = f\"{arch}_{id_val}_{avg_score:.4f}.csv\"\n",
    "    csv_path = os.path.join(output_dir, csv_name)\n",
    "\n",
    "    with open(csv_path, \"w\", newline=\"\", encoding=\"utf-8\") as csvfile:\n",
    "        writer = csv.writer(csvfile)\n",
    "        for latent, detail in chosen:\n",
    "            explanation = detail.get(\"explanation\", \"\")\n",
    "            # append latent ID to the ae.pt path\n",
    "            writer.writerow([explanation, f\"{ae_path}/{latent}\"])\n",
    "\n",
    "    print(f\"Written to {csv_path}\\n\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68fa7e38",
   "metadata": {},
   "source": [
    "## Qwen 2.5 3b L17"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "61a37085",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemma-sae_saes_Qwen_Qwen2.5-3B_gated_jump_relu_resid_post_layer_17_trainer_0_custom_sae_eval_results.json → avg_score = 0.7371\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/gated_999_0.7371.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_12_custom_sae_eval_results.json → avg_score = 0.8457\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/batch_topk_50_0.8457.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_gated_jump_relu_resid_post_layer_17_trainer_3_custom_sae_eval_results.json → avg_score = 0.7900\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/gated_141_0.7900.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_11_custom_sae_eval_results.json → avg_score = 0.7157\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/topk_820_0.7157.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_4_custom_sae_eval_results.json → avg_score = 0.8286\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/standard_april_update_108_0.8286.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_14_custom_sae_eval_results.json → avg_score = 0.8300\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/batch_topk_160_0.8300.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_gated_jump_relu_resid_post_layer_17_trainer_9_custom_sae_eval_results.json → avg_score = 0.7857\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/jumprelu_323_0.7857.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_gated_jump_relu_resid_post_layer_17_trainer_1_custom_sae_eval_results.json → avg_score = 0.7450\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/gated_565_0.7450.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_gated_jump_relu_resid_post_layer_17_trainer_11_custom_sae_eval_results.json → avg_score = 0.7593\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/jumprelu_754_0.7593.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_0_custom_sae_eval_results.json → avg_score = 0.7793\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/standard_april_update_762_0.7793.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_16_custom_sae_eval_results.json → avg_score = 0.7407\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/batch_topk_520_0.7407.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_7_custom_sae_eval_results.json → avg_score = 0.8343\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/topk_80_0.8343.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_9_custom_sae_eval_results.json → avg_score = 0.8243\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/topk_320_0.8243.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_2_custom_sae_eval_results.json → avg_score = 0.8143\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/standard_april_update_321_0.8143.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_gated_jump_relu_resid_post_layer_17_trainer_2_custom_sae_eval_results.json → avg_score = 0.7779\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/gated_338_0.7779.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_5_custom_sae_eval_results.json → avg_score = 0.8257\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/standard_april_update_61_0.8257.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_1_custom_sae_eval_results.json → avg_score = 0.7907\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/standard_april_update_523_0.7907.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_13_custom_sae_eval_results.json → avg_score = 0.8229\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/batch_topk_80_0.8229.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_15_custom_sae_eval_results.json → avg_score = 0.7850\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/batch_topk_320_0.7850.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_3_custom_sae_eval_results.json → avg_score = 0.8121\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/standard_april_update_167_0.8121.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_gated_jump_relu_resid_post_layer_17_trainer_7_custom_sae_eval_results.json → avg_score = 0.8379\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/jumprelu_82_0.8379.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_gated_jump_relu_resid_post_layer_17_trainer_8_custom_sae_eval_results.json → avg_score = 0.8107\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/jumprelu_166_0.8107.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_gated_jump_relu_resid_post_layer_17_trainer_10_custom_sae_eval_results.json → avg_score = 0.7743\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/jumprelu_494_0.7743.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_17_custom_sae_eval_results.json → avg_score = 0.7357\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/batch_topk_820_0.7357.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_8_custom_sae_eval_results.json → avg_score = 0.8129\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/topk_160_0.8129.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_gated_jump_relu_resid_post_layer_17_trainer_6_custom_sae_eval_results.json → avg_score = 0.8279\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/jumprelu_51_0.8279.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_10_custom_sae_eval_results.json → avg_score = 0.7814\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/topk_520_0.7814.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_gated_jump_relu_resid_post_layer_17_trainer_4_custom_sae_eval_results.json → avg_score = 0.8143\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/gated_72_0.8143.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_gated_jump_relu_resid_post_layer_17_trainer_5_custom_sae_eval_results.json → avg_score = 0.8364\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/gated_46_0.8364.csv\n",
      "\n",
      "gemma-sae_saes_Qwen_Qwen2.5-3B_batch_top_k_top_k_standard_new_resid_post_layer_17_trainer_6_custom_sae_eval_results.json → avg_score = 0.8457\n",
      "Written to /home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data/topk_50_0.8457.csv\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "import re\n",
    "import random\n",
    "import csv\n",
    "\n",
    "# For every autointerp eval-result JSON in `input_dir`, write a header-less CSV of\n",
    "# (explanation, \"<ae.pt path>/<latent>\") rows for a random sample of latents.\n",
    "# Each CSV is named \"<architecture>_<id>_<avg_score>.csv\", where <id> is looked up in `mapping`.\n",
    "\n",
    "# 1. Mapping from architecture+trainer to ID\n",
    "mapping = {\n",
    "    # gated\n",
    "    \"gated_trainer0\": 999,\n",
    "    \"gated_trainer1\": 565,\n",
    "    \"gated_trainer2\": 338,\n",
    "    \"gated_trainer3\": 141,\n",
    "    \"gated_trainer4\": 72,\n",
    "    \"gated_trainer5\": 46,\n",
    "\n",
    "    # jumprelu\n",
    "    \"jumprelu_trainer6\": 51,\n",
    "    \"jumprelu_trainer7\": 82,\n",
    "    \"jumprelu_trainer8\": 166,\n",
    "    \"jumprelu_trainer9\": 323,\n",
    "    \"jumprelu_trainer10\": 494,\n",
    "    \"jumprelu_trainer11\": 754,\n",
    "\n",
    "    # standard_april_update\n",
    "    \"standard_april_update_trainer0\": 762,\n",
    "    \"standard_april_update_trainer1\": 523,\n",
    "    \"standard_april_update_trainer2\": 321,\n",
    "    \"standard_april_update_trainer3\": 167,\n",
    "    \"standard_april_update_trainer4\": 108,\n",
    "    \"standard_april_update_trainer5\": 61,\n",
    "\n",
    "    # batch_topk\n",
    "    \"batch_topk_trainer12\": 50,\n",
    "    \"batch_topk_trainer13\": 80,\n",
    "    \"batch_topk_trainer14\": 160,\n",
    "    \"batch_topk_trainer15\": 320,\n",
    "    \"batch_topk_trainer16\": 520,\n",
    "    \"batch_topk_trainer17\": 820,\n",
    "\n",
    "    # topk\n",
    "    \"topk_trainer6\": 50,\n",
    "    \"topk_trainer7\": 80,\n",
    "    \"topk_trainer8\": 160,\n",
    "    \"topk_trainer9\": 320,\n",
    "    \"topk_trainer10\": 520,\n",
    "    \"topk_trainer11\": 820,\n",
    "}\n",
    "\n",
    "# 2. Paths configuration\n",
    "input_dir = \"/home/dslabra5/sae4steer/SAEBench/sae_bench/custom_saes/eval_results/qwen2.5_3b/l17/autointerp\"\n",
    "downloaded_root = \"/home/dslabra5/sae4steer/SAEBench/sae_bench/custom_saes/downloaded_saes\"\n",
    "output_dir = \"/home/dslabra5/sae4steer/axbench/axbench/concept100_qwen2.5_3b_data\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# 3. Sampling configuration\n",
    "N_LATENTS = 100  # maximum number of latents exported per SAE\n",
    "SEED = 42\n",
    "\n",
    "# 4. Iterate through JSON files in the input directory.\n",
    "# os.listdir order is filesystem-dependent, so sort it for a reproducible run order.\n",
    "for fname in sorted(os.listdir(input_dir)):\n",
    "    if not fname.endswith(\".json\"):\n",
    "        continue\n",
    "\n",
    "    json_path = os.path.join(input_dir, fname)\n",
    "\n",
    "    # (1) Extract trainer number from the filename\n",
    "    m = re.search(r\"trainer_(\\d+)\", fname)\n",
    "    if not m:\n",
    "        print(f\"Skipping {fname}: trainer number not found\")\n",
    "        continue\n",
    "    trainer_num = m.group(1)\n",
    "\n",
    "    # (2) Read the JSON file; the architecture name comes from its SAE config\n",
    "    with open(json_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    arch = data.get(\"sae_cfg_dict\", {}).get(\"architecture\")\n",
    "    if arch is None:\n",
    "        print(f\"Skipping {fname}: missing architecture\")\n",
    "        continue\n",
    "\n",
    "    key = f\"{arch}_trainer{trainer_num}\"\n",
    "    id_val = mapping.get(key)\n",
    "    if id_val is None:\n",
    "        print(f\"Skipping {fname}: key '{key}' not in mapping\")\n",
    "        continue\n",
    "\n",
    "    # (3) Locate ae.pt based on sae_lens_release_id:\n",
    "    #     <downloaded_root>/<loc_root>/resid_post_layer_<L>/trainer_<N>/ae.pt,\n",
    "    #     where loc_root is the release id after its first \"_\" and before \"_resid_post_layer_\"\n",
    "    release_id = data.get(\"sae_lens_release_id\", \"\")\n",
    "    if \"_\" not in release_id:\n",
    "        print(f\"Skipping {fname}: release_id format incorrect\")\n",
    "        continue\n",
    "    loc_rest = release_id.split(\"_\", 1)[1]\n",
    "    if \"_resid_post_layer_\" not in loc_rest:\n",
    "        print(f\"Skipping {fname}: '_resid_post_layer_' not found in loc_rest\")\n",
    "        continue\n",
    "    loc_root, rest = loc_rest.split(\"_resid_post_layer_\", 1)\n",
    "    hook_layer_str = rest.split(\"_trainer_\")[0]\n",
    "    hook_folder = f\"resid_post_layer_{hook_layer_str}\"\n",
    "    trainer_folder = f\"trainer_{trainer_num}\"\n",
    "    ae_path = os.path.join(\n",
    "        downloaded_root,\n",
    "        loc_root,\n",
    "        hook_folder,\n",
    "        trainer_folder,\n",
    "        \"ae.pt\",\n",
    "    )\n",
    "\n",
    "    # (4) Randomly select up to N_LATENTS latent entries. The RNG is seeded per file\n",
    "    # (from SEED and the filename), so each file's sample does not depend on which\n",
    "    # other files exist in input_dir or on the order they are processed in.\n",
    "    unstruct = data.get(\"eval_result_unstructured\", {})\n",
    "    items = list(unstruct.items())  # [(latent, detail), ...]\n",
    "    rng = random.Random(f\"{SEED}:{fname}\")\n",
    "    chosen = items if len(items) <= N_LATENTS else rng.sample(items, N_LATENTS)\n",
    "\n",
    "    # (5) Compute and print the average score (latents without a \"score\" count as 0.0)\n",
    "    scores = [detail.get(\"score\", 0.0) for _, detail in chosen]\n",
    "    avg_score = sum(scores) / len(scores) if scores else 0.0\n",
    "    print(f\"{fname} → avg_score = {avg_score:.4f}\")\n",
    "\n",
    "    # (6) Write to CSV (no header), appending latent ID to ae.pt path\n",
    "    csv_name = f\"{arch}_{id_val}_{avg_score:.4f}.csv\"\n",
    "    csv_path = os.path.join(output_dir, csv_name)\n",
    "\n",
    "    with open(csv_path, \"w\", newline=\"\", encoding=\"utf-8\") as csvfile:\n",
    "        writer = csv.writer(csvfile)\n",
    "        for latent, detail in chosen:\n",
    "            explanation = detail.get(\"explanation\", \"\")\n",
    "            # append latent ID to the ae.pt path\n",
    "            writer.writerow([explanation, f\"{ae_path}/{latent}\"])\n",
    "\n",
    "    print(f\"Written to {csv_path}\\n\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c509518",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sae4steer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
