{
 "cells": [
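  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DataFrame comparing SAE statistics (currently chess only)\n",
    "\n",
    "For each SAE in the selected autoencoder groups, this notebook computes standard SAE metrics (e.g. L0, fraction of variance explained) and board-state classification metrics (precision, recall and F1 of SAE features as board-state classifiers), and collects them in a DataFrame."
   ]
  },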
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/share/u/can/miniconda3/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# Imports\n",
    "\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import torch\n",
    "import einops\n",
    "from datasets import load_dataset\n",
    "from typing import Callable, Optional\n",
    "import math\n",
    "import os\n",
    "import itertools\n",
    "import json\n",
    "import gc\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "from dataclasses import dataclass\n",
    "import torch\n",
    "from nnsight import NNsight\n",
    "import json\n",
    "from typing import Any\n",
    "from datasets import load_dataset\n",
    "from einops import rearrange\n",
    "from jaxtyping import Int, Float, jaxtyped\n",
    "from torch import Tensor\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "from transformers import GPT2LMHeadModel\n",
    "from transformer_lens import HookedTransformer\n",
    "\n",
    "from circuits.othello_buffer import OthelloActivationBuffer\n",
    "from circuits.dictionary_learning import AutoEncoder\n",
    "from circuits.chess_utils import encode_string\n",
    "from circuits.dictionary_learning import ActivationBuffer\n",
    "from circuits.dictionary_learning.dictionary import AutoEncoder, GatedAutoEncoder\n",
    "from circuits.dictionary_learning.trainers.gated_anneal import GatedAnnealTrainer\n",
    "from circuits.dictionary_learning.trainers.gdm import GatedSAETrainer\n",
    "from circuits.dictionary_learning.trainers.p_anneal import PAnnealTrainer\n",
    "from circuits.dictionary_learning.trainers.standard import StandardTrainer\n",
    "from circuits.dictionary_learning.evaluation import evaluate\n",
    "from circuits.nanogpt_to_hf_transformers import NanogptTokenizer, convert_nanogpt_model\n",
    "from circuits.eval_sae_as_classifier import (\n",
    "    initialize_results_dict, \n",
    "    get_data_batch, \n",
    "    apply_indexing_function,\n",
    "    construct_eval_dataset,\n",
    "    construct_othello_dataset,\n",
    "    prep_firing_rate_data,\n",
    ")\n",
    "from circuits.utils import (\n",
    "    get_model, \n",
    "    get_submodule,\n",
    "    get_ae_bundle,\n",
    "    collect_activations_batch,\n",
    "    get_nested_folders,\n",
    "    get_firing_features,\n",
    "    to_device,\n",
    "    AutoEncoderBundle,\n",
    ")\n",
    "import circuits.chess_utils as chess_utils\n",
    "import circuits.othello_utils as othello_utils\n",
    "import circuits.othello_engine_utils as othello_engine_utils\n",
    "\n",
    "from circuits.dictionary_learning.evaluation import evaluate\n",
    "\n",
    "from IPython import embed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Globals\n",
    "\n",
    "# Dimension key (from https://medium.com/@NoamShazeer/shape-suffixes-good-coding-style-f836e72e24fd):\n",
    "# F = features, or the feature minibatch size, depending on the context\n",
    "# B = batch_size\n",
    "# L = seq length (context length)\n",
    "# T = thresholds\n",
    "# R = rows (or cols)\n",
    "# C = classes for one hot encoding\n",
    "\n",
    "# Paths default to the original author's machine. Override them with environment\n",
    "# variables instead of editing the notebook.\n",
    "home_dir = os.environ.get(\"SAE_EVAL_HOME_DIR\", '/share/u/can')\n",
    "repo_dir = os.environ.get(\"SAE_EVAL_REPO_DIR\", f'{home_dir}/chess-gpt-circuits')\n",
    "\n",
    "DEVICE = 'cuda:0'\n",
    "torch.set_grad_enabled(False)  # evaluation only: disable autograd globally\n",
    "batch_size = 32\n",
    "feature_batch_size = batch_size  # features processed per minibatch in eval_custom_fn\n",
    "n_inputs = 2048 # Length of the eval dataset\n",
    "GAME = \"chess\" # \"chess\" or \"othello\"\n",
    "\n",
    "models_path = repo_dir + \"/models/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Game-specific configuration. The eval dataset itself is built per autoencoder in\n",
    "# the SAE loop below, because prep_firing_rate_data manipulates the loaded data.\n",
    "\n",
    "if GAME not in (\"chess\", \"othello\"):\n",
    "    raise ValueError(\"Invalid game\")\n",
    "\n",
    "othello = GAME == \"othello\"\n",
    "\n",
    "if othello:\n",
    "    # Alternative group: \"autoencoders/othello_layer5_ef4/\"\n",
    "    autoencoder_group_paths = [\"/autoencoders/othello_layer0/\"]\n",
    "    custom_functions = [\n",
    "        # othello_utils.games_batch_no_last_move_to_state_stack_BLRRC,\n",
    "        othello_utils.games_batch_to_state_stack_BLRRC,\n",
    "        othello_utils.games_batch_to_state_stack_mine_yours_BLRRC,\n",
    "    ]\n",
    "    model_name = \"Baidicoot/Othello-GPT-Transformer-Lens\"\n",
    "    indexing_functions = [None]  # I'm experimenting with these for Othello\n",
    "else:\n",
    "    autoencoder_group_paths = [\"/autoencoders/group1/\"]\n",
    "    custom_functions = [chess_utils.board_to_piece_state]  # optionally add chess_utils.board_to_pin_state\n",
    "    model_name = \"redacted/8LayerChessGPT2\"\n",
    "    indexing_functions = [chess_utils.get_even_list_indices]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## General dataset statistics\n",
    "\n",
    "These statistics depend only on the dataset, not on the SAE, so they can be computed once after the dataset is loaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_true_board_state_counts(pgn_strings, custom_function=None, device=None):\n",
    "    \"\"\"Count how often each board state occurs in the dataset.\n",
    "\n",
    "    Builds the state stacks of `pgn_strings` with `custom_function`, one-hot encodes\n",
    "    them using chess_utils.config_lookup for that function, and sums over the first\n",
    "    two dimensions (presumably games and positions -- confirm against\n",
    "    chess_utils.create_state_stacks). This only depends on the dataset, not on the SAE.\n",
    "\n",
    "    Args:\n",
    "        pgn_strings: game strings to evaluate.\n",
    "        custom_function: board-state function to count. Defaults to\n",
    "            chess_utils.board_to_piece_state.\n",
    "        device: device for the one-hot tensor. Defaults to the global DEVICE\n",
    "            (read at call time).\n",
    "\n",
    "    Returns:\n",
    "        Tensor of occurrence counts per board state, shape [R, R, C].\n",
    "\n",
    "    NOTE(review): these counts cover every position of every game in `pgn_strings`,\n",
    "    whereas eval_custom_fn only uses the first `n_inputs` games and the positions kept\n",
    "    by the indexing function. Recall computed against these counts may therefore be\n",
    "    underestimated -- confirm whether the denominator should be restricted the same way.\n",
    "    \"\"\"\n",
    "    if custom_function is None:\n",
    "        custom_function = chess_utils.board_to_piece_state\n",
    "    if device is None:\n",
    "        device = DEVICE\n",
    "\n",
    "    state_stacks = chess_utils.create_state_stacks(pgn_strings, custom_function)\n",
    "    one_hot_states = chess_utils.state_stack_to_one_hot(\n",
    "        chess_utils.config_lookup[custom_function.__name__],\n",
    "        device,\n",
    "        state_stacks,\n",
    "    )\n",
    "    true_board_states_counts_RRC = one_hot_states.sum(dim=(0, 1))\n",
    "    return true_board_states_counts_RRC"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SAE-specific statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard evals\n",
    "def do_standard_evals(results, ae_bundle):\n",
    "    \"\"\"Run dictionary_learning's `evaluate` on one SAE and merge its metrics into `results`.\n",
    "\n",
    "    Every key/value pair returned by `evaluate` is written into `results`, overwriting\n",
    "    existing entries with the same key. The same (mutated) dict is returned.\n",
    "    \"\"\"\n",
    "    # min(n_eval_samples, activation_buffer_out_batch_size) matters\n",
    "    eval_batch_size = min(512, batch_size)\n",
    "    eval_results = evaluate(\n",
    "        ae_bundle.ae,\n",
    "        ae_bundle.buffer,\n",
    "        max_len=ae_bundle.context_length,\n",
    "        batch_size=eval_batch_size,\n",
    "        io=\"out\",\n",
    "        device=DEVICE,\n",
    "        n_batches=1000,\n",
    "    )\n",
    "    for metric_name, metric_value in eval_results.items():\n",
    "        results[metric_name] = metric_value\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation of custom functions\n",
    "def eval_custom_fn(\n",
    "    results,\n",
    "    n_act_threshs,\n",
    "    alive_features_F,\n",
    "    max_activations_F,\n",
    "    ae_bundle,\n",
    "    pgn_strings,\n",
    "    custom_functions,\n",
    "    encoded_inputs,\n",
    "    firing_rate_n_inputs,\n",
    "    indexing_function,\n",
    "    eval_data=None,\n",
    "):\n",
    "    \"\"\"Accumulate per-feature, per-threshold firing counts and board-state sums.\n",
    "\n",
    "    For each alive feature and each of `n_act_threshs` thresholds (linearly spaced\n",
    "    fractions in [0, 1] of that feature's max activation), counts how often the feature\n",
    "    is on / off and sums the board states of every custom function over the positions\n",
    "    where it is on / off. All sums are added in place into the tensors of `results`.\n",
    "\n",
    "    Args:\n",
    "        results: dict with \"on_count\" / \"off_count\" tensors of shape [T, F] and, per\n",
    "            custom function name, {\"on\": [T, F, R, R, C], \"off\": [T, F, R, R, C]}.\n",
    "        n_act_threshs: number of activation thresholds T.\n",
    "        alive_features_F: alive features to evaluate (presumably indices into the SAE).\n",
    "        max_activations_F: max activation per alive feature; scales the thresholds.\n",
    "        ae_bundle: bundle of SAE + model used to collect activations.\n",
    "        pgn_strings: game strings; only the first `n_inputs` (global) are used.\n",
    "        custom_functions: board-state functions whose outputs are aggregated.\n",
    "        encoded_inputs: encoded games, presumably aligned with `pgn_strings`.\n",
    "        firing_rate_n_inputs: number of activations used to find alive features\n",
    "            (only reported in the printout).\n",
    "        indexing_function: optional function selecting which positions to keep.\n",
    "        eval_data: eval dataset passed to get_data_batch. Defaults to the notebook\n",
    "            global `data` for backward compatibility; pass it explicitly to avoid\n",
    "            depending on hidden global state.\n",
    "\n",
    "    Returns:\n",
    "        The updated `results` dict.\n",
    "    \"\"\"\n",
    "    if eval_data is None:\n",
    "        # Backward-compatible fallback to the global assigned in the SAE loop.\n",
    "        eval_data = data\n",
    "\n",
    "    num_features = len(alive_features_F)\n",
    "    print(\n",
    "        f\"Out of {ae_bundle.dictionary_size} features, on {firing_rate_n_inputs} activations, {num_features} are alive.\"\n",
    "    )\n",
    "\n",
    "    assert len(pgn_strings) >= n_inputs\n",
    "    assert n_inputs % batch_size == 0\n",
    "\n",
    "    n_iters = n_inputs // batch_size\n",
    "    # We round up to ensure we don't ignore the remainder of features\n",
    "    num_feature_iters = math.ceil(num_features / feature_batch_size)\n",
    "\n",
    "    # Per-feature thresholds: fractions 0..1 of each feature's max activation\n",
    "    thresholds_T = torch.linspace(0, 1, n_act_threshs).to(DEVICE)\n",
    "    thresholds_TF11 = einops.repeat(thresholds_T, \"T -> T F 1 1\", F=num_features)\n",
    "    max_activations_1F11 = einops.repeat(max_activations_F, \"F -> 1 F 1 1\")\n",
    "    thresholds_TF11 = thresholds_TF11 * max_activations_1F11\n",
    "\n",
    "    for i in tqdm(range(n_iters), desc=\"Aggregating statistics\"):\n",
    "        start = i * batch_size\n",
    "        end = (i + 1) * batch_size\n",
    "        pgn_strings_BL = pgn_strings[start:end]\n",
    "        encoded_inputs_BL = torch.tensor(encoded_inputs[start:end]).to(DEVICE)\n",
    "\n",
    "        batch_data = get_data_batch(eval_data, pgn_strings_BL, start, end, custom_functions, DEVICE)\n",
    "\n",
    "        # The second return value (encoded token inputs) is not needed here.\n",
    "        all_activations_FBL, _ = collect_activations_batch(\n",
    "            ae_bundle, encoded_inputs_BL, alive_features_F\n",
    "        )\n",
    "\n",
    "        if indexing_function is not None:\n",
    "            all_activations_FBL, batch_data = apply_indexing_function(\n",
    "                pgn_strings_BL, all_activations_FBL, batch_data, DEVICE, indexing_function\n",
    "            )\n",
    "\n",
    "        # Board sums over all kept positions. They are shared by every feature and\n",
    "        # threshold, so \"off\" sums can be derived as (total - on): active and ~active\n",
    "        # partition the positions.\n",
    "        total_boards_RRC = {\n",
    "            fn.__name__: batch_data[fn.__name__].sum(dim=(0, 1)) for fn in custom_functions\n",
    "        }\n",
    "\n",
    "        # For thousands of features, this would be many GB of memory. So, we minibatch.\n",
    "        for feature in range(num_feature_iters):\n",
    "            f_start = feature * feature_batch_size\n",
    "            f_end = min((feature + 1) * feature_batch_size, num_features)\n",
    "\n",
    "            # NOTE: inside this loop F == f_end - f_start (the feature minibatch size)\n",
    "            activations_FBL = all_activations_FBL[f_start:f_end]\n",
    "            thresholds_TF11_slice = thresholds_TF11[:, f_start:f_end, :, :]\n",
    "\n",
    "            ### Aggregate batch statistics\n",
    "            active_indices_TFBL = activations_FBL > thresholds_TF11_slice\n",
    "            active_counts_TF = einops.reduce(active_indices_TFBL, \"T F B L -> T F\", \"sum\")\n",
    "            off_counts_TF = einops.reduce(~active_indices_TFBL, \"T F B L -> T F\", \"sum\")\n",
    "\n",
    "            results[\"on_count\"][:, f_start:f_end] += active_counts_TF\n",
    "            results[\"off_count\"][:, f_start:f_end] += off_counts_TF\n",
    "\n",
    "            for custom_function in custom_functions:\n",
    "                fn_name = custom_function.__name__\n",
    "                on_tracker_TFRRC = results[fn_name][\"on\"]\n",
    "                off_tracker_TFRRC = results[fn_name][\"off\"]\n",
    "                boards_BLRRC = batch_data[fn_name]\n",
    "\n",
    "                # Contract over B and L directly instead of materializing a\n",
    "                # [T, F, B, L, R, R, C] tensor twice (this used to dominate runtime and\n",
    "                # memory). Accumulate in float32/float64 so the small integer counts\n",
    "                # involved stay exact.\n",
    "                compute_dtype = torch.float64 if boards_BLRRC.dtype == torch.float64 else torch.float32\n",
    "                active_boards_sum_TFRRC = torch.einsum(\n",
    "                    \"tfbl,blrsc->tfrsc\",\n",
    "                    active_indices_TFBL.to(compute_dtype),\n",
    "                    boards_BLRRC.to(compute_dtype),\n",
    "                )\n",
    "                off_boards_sum_TFRRC = (\n",
    "                    total_boards_RRC[fn_name].to(compute_dtype)[None, None] - active_boards_sum_TFRRC\n",
    "                )\n",
    "\n",
    "                on_tracker_TFRRC[:, f_start:f_end] += active_boards_sum_TFRRC.to(on_tracker_TFRRC.dtype)\n",
    "                off_tracker_TFRRC[:, f_start:f_end] += off_boards_sum_TFRRC.to(off_tracker_TFRRC.dtype)\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Precision, recall, and F1\n",
    "\n",
    "def get_classification_metrics(\n",
    "    results,\n",
    "    true_board_states_counts,\n",
    "    custom_function_name='board_to_piece_state',\n",
    "    precision_thresh=0.9,\n",
    "    recall_thresh=0.01,\n",
    "    f1_thresh=0.01,\n",
    "    empty_square_index=6,\n",
    "):\n",
    "    \"\"\"Summarize how well SAE features act as binary classifiers of board states.\n",
    "\n",
    "    Each (activation threshold T, feature F) pair is scored against each board state\n",
    "    (square R x R, class C):\n",
    "        precision = times the feature was on while the state held / times the feature was on\n",
    "        recall    = times the feature was on while the state held / occurrences of the state\n",
    "        f1        = harmonic mean of precision and recall\n",
    "    A metric counts as high when it exceeds its threshold. The empty-square class is\n",
    "    excluded. For each metric, four fractions are written into `results`:\n",
    "        frac_any_board_per_feature_act-{nonzero,best}_{metric}-{thresh}: fraction of\n",
    "            features with a high metric on at least one board state, using activation\n",
    "            threshold index 0 (\"nonzero\") or any activation threshold (\"best\").\n",
    "        frac_any_feature_per_board_act-{nonzero,best}_{metric}-{thresh}: fraction of\n",
    "            non-empty board states for which at least one feature has a high metric.\n",
    "\n",
    "    Args:\n",
    "        results: dict with results[custom_function_name][\"on\"] of shape [T, F, R, R, C]\n",
    "            and results[\"on_count\"] of shape [T, F], as filled by eval_custom_fn.\n",
    "        true_board_states_counts: occurrence counts per board state, shape [R, R, C].\n",
    "        custom_function_name: key of the board-state statistics in `results`.\n",
    "        precision_thresh, recall_thresh, f1_thresh: cutoffs for a \"high\" metric.\n",
    "            They also appear in the result keys.\n",
    "        empty_square_index: class index excluded from all metrics (6, the empty\n",
    "            square, for the 13-class piece encoding).\n",
    "\n",
    "    Returns:\n",
    "        `results` with the 12 fraction entries added as Python floats.\n",
    "    \"\"\"\n",
    "    threshs = [precision_thresh, recall_thresh, f1_thresh]\n",
    "    eps = 1e-8\n",
    "\n",
    "    true_pos_TFRRC = results[custom_function_name]['on']\n",
    "    pos_all_TF = results['on_count']\n",
    "    true_all_RRC = true_board_states_counts\n",
    "    # Board dimensions come from the statistics themselves (8 x 8 x 13 for piece states).\n",
    "    R1, R2, C = true_pos_TFRRC.shape[-3:]\n",
    "\n",
    "    # Note that a feature which always fires (piece present/absent) will have a precision of 1\n",
    "    precision = true_pos_TFRRC / (pos_all_TF[:, :, None, None, None] + eps)\n",
    "    recall = true_pos_TFRRC / (true_all_RRC[None, None, :, :, :] + eps)\n",
    "    f1 = 2 * (precision * recall) / (precision + recall + eps)\n",
    "    metrics_TFRRC = [precision, recall, f1]\n",
    "\n",
    "    # Apply threshold, then drop the empty-square class from every metric\n",
    "    counts_TFRRC = [metric > thresh for metric, thresh in zip(metrics_TFRRC, threshs)]\n",
    "    for count_TFRRC in counts_TFRRC:\n",
    "        count_TFRRC[..., empty_square_index] = False\n",
    "    num_board_states = R1 * R2 * (C - 1)\n",
    "\n",
    "    ### Fraction of features with high metric on at least one board state\n",
    "    counts_any_board_TF = [count.any(dim=(-1, -2, -3)) for count in counts_TFRRC]\n",
    "\n",
    "    # Activation threshold index 0 is a threshold of 0, i.e. the feature is active at all\n",
    "    frac_any_board_nonzero_1 = [count[0].float().mean() for count in counts_any_board_TF]\n",
    "\n",
    "    # Any activation threshold (effectively the best threshold per feature)\n",
    "    frac_any_board_best_1 = [count.any(dim=0).float().mean() for count in counts_any_board_TF]\n",
    "\n",
    "    ### Fraction of board states that have at least one feature with high metric\n",
    "    # Using activation threshold 0\n",
    "    counts_any_feature_nonzero_RRC = [count[0].any(dim=0) for count in counts_TFRRC]\n",
    "    # Using any activation threshold\n",
    "    counts_any_feature_best_RRC = [count.any(dim=(0, 1)) for count in counts_TFRRC]\n",
    "\n",
    "    frac_any_feature_nonzero_RRC = [count.sum() / num_board_states for count in counts_any_feature_nonzero_RRC]\n",
    "    frac_any_feature_best_RRC = [count.sum() / num_board_states for count in counts_any_feature_best_RRC]\n",
    "\n",
    "    print(frac_any_board_nonzero_1)\n",
    "    print(frac_any_board_best_1)\n",
    "    print(frac_any_feature_nonzero_RRC)\n",
    "    print(frac_any_feature_best_RRC)\n",
    "\n",
    "    names = ['precision', 'recall', 'f1']\n",
    "    for i, (name, t) in enumerate(zip(names, threshs)):\n",
    "        results[f'frac_any_board_per_feature_act-nonzero_{name}-{t}'] = frac_any_board_nonzero_1[i].item()\n",
    "        results[f'frac_any_board_per_feature_act-best_{name}-{t}'] = frac_any_board_best_1[i].item()\n",
    "        results[f'frac_any_feature_per_board_act-nonzero_{name}-{t}'] = frac_any_feature_nonzero_RRC[i].item()\n",
    "        results[f'frac_any_feature_per_board_act-best_{name}-{t}'] = frac_any_feature_best_RRC[i].item()\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loop over SAEs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ae_dir: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=16_lr=1e-03_l1=1e-03_layer=5/\n",
      "idx_fn: <function get_even_list_indices at 0x7f2fdcb6d8a0>\n",
      "\n",
      "ae_dir: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=4_lr=1e-03_l1=1e-03_layer=5/\n",
      "idx_fn: <function get_even_list_indices at 0x7f2fdcb6d8a0>\n",
      "\n",
      "ae_dir: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=4_lr=1e-03_l1=1e-01_layer=5/\n",
      "idx_fn: <function get_even_list_indices at 0x7f2fdcb6d8a0>\n",
      "\n",
      "ae_dir: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=16_lr=1e-03_l1=1e-01_layer=5/\n",
      "idx_fn: <function get_even_list_indices at 0x7f2fdcb6d8a0>\n",
      "\n",
      "ae_dir: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=16_lr=1e-03_l1=3e-01_layer=5/\n",
      "idx_fn: <function get_even_list_indices at 0x7f2fdcb6d8a0>\n",
      "\n",
      "ae_dir: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=4_lr=1e-03_l1=3e-01_layer=5/\n",
      "idx_fn: <function get_even_list_indices at 0x7f2fdcb6d8a0>\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Autoencoder loop:   0%|          | 0/6 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Autoencoder: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=16_lr=1e-03_l1=1e-03_layer=5/\n",
      "Indexing function: get_even_list_indices\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/share/u/can/miniconda3/lib/python3.12/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: using manual setting of layer to 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Collecting features: 100%|██████████| 8000/8000 [00:05<00:00, 1595.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "do_standard_evals\n",
      "do custom eval metrics\n",
      "Out of 8192 features, on 256000 activations, 4870 are alive.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Aggregating statistics: 100%|██████████| 64/64 [02:01<00:00,  1.90s/it]\n",
      "Autoencoder loop:  17%|█▋        | 1/6 [02:23<11:55, 143.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor(0.8639, device='cuda:0'), tensor(0.5805, device='cuda:0'), tensor(0.6335, device='cuda:0')]\n",
      "[tensor(0.8639, device='cuda:0'), tensor(0.5805, device='cuda:0'), tensor(0.6349, device='cuda:0')]\n",
      "[tensor(0.1406, device='cuda:0'), tensor(0.1237, device='cuda:0'), tensor(0.1055, device='cuda:0')]\n",
      "[tensor(0.2161, device='cuda:0'), tensor(0.1237, device='cuda:0'), tensor(0.1120, device='cuda:0')]\n",
      "Autoencoder: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=4_lr=1e-03_l1=1e-03_layer=5/\n",
      "Indexing function: get_even_list_indices\n",
      "WARNING: using manual setting of layer to 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/share/u/can/miniconda3/lib/python3.12/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n",
      "Collecting features: 100%|██████████| 8000/8000 [00:05<00:00, 1566.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "do_standard_evals\n",
      "do custom eval metrics\n",
      "Out of 2048 features, on 256000 activations, 1696 are alive.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Aggregating statistics: 100%|██████████| 64/64 [00:55<00:00,  1.15it/s]\n",
      "Autoencoder loop:  33%|███▎      | 2/6 [03:39<06:56, 104.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor(0.9693, device='cuda:0'), tensor(0.8644, device='cuda:0'), tensor(0.8939, device='cuda:0')]\n",
      "[tensor(0.9693, device='cuda:0'), tensor(0.8644, device='cuda:0'), tensor(0.8939, device='cuda:0')]\n",
      "[tensor(0.0794, device='cuda:0'), tensor(0.1237, device='cuda:0'), tensor(0.1003, device='cuda:0')]\n",
      "[tensor(0.2135, device='cuda:0'), tensor(0.1237, device='cuda:0'), tensor(0.1198, device='cuda:0')]\n",
      "Autoencoder: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=4_lr=1e-03_l1=1e-01_layer=5/\n",
      "Indexing function: get_even_list_indices\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/share/u/can/miniconda3/lib/python3.12/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: using manual setting of layer to 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Collecting features: 100%|██████████| 8000/8000 [00:05<00:00, 1580.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "do_standard_evals\n",
      "do custom eval metrics\n",
      "Out of 2048 features, on 256000 activations, 2039 are alive.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Aggregating statistics: 100%|██████████| 64/64 [01:02<00:00,  1.02it/s]\n",
      "Autoencoder loop:  50%|█████     | 3/6 [05:02<04:43, 94.50s/it] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor(0.9652, device='cuda:0'), tensor(0.1535, device='cuda:0'), tensor(0.2315, device='cuda:0')]\n",
      "[tensor(0.9652, device='cuda:0'), tensor(0.1535, device='cuda:0'), tensor(0.2379, device='cuda:0')]\n",
      "[tensor(0.1432, device='cuda:0'), tensor(0.0859, device='cuda:0'), tensor(0.1094, device='cuda:0')]\n",
      "[tensor(0.2799, device='cuda:0'), tensor(0.0859, device='cuda:0'), tensor(0.1107, device='cuda:0')]\n",
      "Autoencoder: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=16_lr=1e-03_l1=1e-01_layer=5/\n",
      "Indexing function: get_even_list_indices\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/share/u/can/miniconda3/lib/python3.12/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: using manual setting of layer to 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Collecting features: 100%|██████████| 8000/8000 [00:05<00:00, 1578.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "do_standard_evals\n",
      "do custom eval metrics\n",
      "Out of 8192 features, on 256000 activations, 1444 are alive.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Aggregating statistics: 100%|██████████| 64/64 [00:50<00:00,  1.27it/s]\n",
      "Autoencoder loop:  67%|██████▋   | 4/6 [06:14<02:50, 85.27s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor(0.6579, device='cuda:0'), tensor(0.1690, device='cuda:0'), tensor(0.2154, device='cuda:0')]\n",
      "[tensor(0.6579, device='cuda:0'), tensor(0.1690, device='cuda:0'), tensor(0.2202, device='cuda:0')]\n",
      "[tensor(0.1419, device='cuda:0'), tensor(0.0990, device='cuda:0'), tensor(0.1185, device='cuda:0')]\n",
      "[tensor(0.2253, device='cuda:0'), tensor(0.0990, device='cuda:0'), tensor(0.1198, device='cuda:0')]\n",
      "Autoencoder: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=16_lr=1e-03_l1=3e-01_layer=5/\n",
      "Indexing function: get_even_list_indices\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/share/u/can/miniconda3/lib/python3.12/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: using manual setting of layer to 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Collecting features: 100%|██████████| 8000/8000 [00:05<00:00, 1583.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "do_standard_evals\n",
      "do custom eval metrics\n",
      "Out of 8192 features, on 256000 activations, 657 are alive.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Aggregating statistics: 100%|██████████| 64/64 [00:34<00:00,  1.86it/s]\n",
      "Autoencoder loop:  83%|████████▎ | 5/6 [07:08<01:14, 74.25s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor(0.3668, device='cuda:0'), tensor(0.0365, device='cuda:0'), tensor(0.0426, device='cuda:0')]\n",
      "[tensor(0.3668, device='cuda:0'), tensor(0.0365, device='cuda:0'), tensor(0.0426, device='cuda:0')]\n",
      "[tensor(0.1172, device='cuda:0'), tensor(0.1211, device='cuda:0'), tensor(0.0924, device='cuda:0')]\n",
      "[tensor(0.1406, device='cuda:0'), tensor(0.1211, device='cuda:0'), tensor(0.1003, device='cuda:0')]\n",
      "Autoencoder: /share/u/can/chess-gpt-circuits/autoencoders/group1/ef=4_lr=1e-03_l1=3e-01_layer=5/\n",
      "Indexing function: get_even_list_indices\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/share/u/can/miniconda3/lib/python3.12/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: using manual setting of layer to 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Collecting features: 100%|██████████| 8000/8000 [00:04<00:00, 1601.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "do_standard_evals\n",
      "do custom eval metrics\n",
      "Out of 2048 features, on 256000 activations, 1078 are alive.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Aggregating statistics: 100%|██████████| 64/64 [00:42<00:00,  1.50it/s]\n",
      "Autoencoder loop: 100%|██████████| 6/6 [08:11<00:00, 81.86s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor(0.8386, device='cuda:0'), tensor(0.0455, device='cuda:0'), tensor(0.1002, device='cuda:0')]\n",
      "[tensor(0.8386, device='cuda:0'), tensor(0.0455, device='cuda:0'), tensor(0.1030, device='cuda:0')]\n",
      "[tensor(0.1667, device='cuda:0'), tensor(0.0547, device='cuda:0'), tensor(0.0872, device='cuda:0')]\n",
      "[tensor(0.2344, device='cuda:0'), tensor(0.0547, device='cuda:0'), tensor(0.0885, device='cuda:0')]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Choose aes and indexing functions\n",
    "\n",
    "# This could be computed once before the loop if adapting loading pgn_strings\n",
    "# true_board_state_counts = get_true_board_state_counts(pgn_strings)\n",
    "\n",
    "n_act_threshs = 10  # number of activation thresholds per feature\n",
    "\n",
    "sweep_results = {}\n",
    "# Must match the keys written by do_standard_evals / get_classification_metrics\n",
    "sweep_result_keys = ['l0', 'frac_variance_explained', 'cossim', 'l2_ratio'] + [\n",
    "    f'frac_any_{scope}_act-{mode}_{name}-{t}'\n",
    "    for name, t in [('precision', 0.9), ('recall', 0.01), ('f1', 0.01)]\n",
    "    for scope in ['board_per_feature', 'feature_per_board']\n",
    "    for mode in ['nonzero', 'best']\n",
    "]\n",
    "\n",
    "all_autoencoder_paths = []\n",
    "for group_path in autoencoder_group_paths:\n",
    "    all_autoencoder_paths += get_nested_folders(repo_dir + group_path)\n",
    "\n",
    "param_combinations = list(itertools.product(all_autoencoder_paths, indexing_functions))\n",
    "\n",
    "for ae_dir, idx_fn in param_combinations:\n",
    "    print(f'ae_dir: {ae_dir}')\n",
    "    print(f'idx_fn: {idx_fn}\\n')\n",
    "\n",
    "for autoencoder_path, indexing_function in tqdm(param_combinations, desc=\"Autoencoder loop\", total=len(param_combinations)):\n",
    "    # Collect garbage before emptying the CUDA cache, so memory freed by the\n",
    "    # collection can actually be released.\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    indexing_function_name = \"None\" if indexing_function is None else indexing_function.__name__\n",
    "    print(f\"Autoencoder: {autoencoder_path}\")\n",
    "    print(f\"Indexing function: {indexing_function_name}\")\n",
    "\n",
    "    # TODO Function below manipulates the loaded data. If we change that, we can load data once and for all at the top of the file\n",
    "    # NOTE: eval_custom_fn reads the global `data` assigned here.\n",
    "    data = construct_eval_dataset(custom_functions, n_inputs, models_path=models_path, device=DEVICE)\n",
    "    data, ae_bundle, pgn_strings, encoded_inputs = prep_firing_rate_data(\n",
    "        autoencoder_path, batch_size, models_path, model_name, data, DEVICE, n_inputs, othello\n",
    "    )\n",
    "\n",
    "    firing_rate_n_inputs = min(int(n_inputs * 0.5), 1000) * ae_bundle.context_length\n",
    "    alive_features_F, max_activations_F = get_firing_features(\n",
    "        ae_bundle, firing_rate_n_inputs, batch_size, DEVICE\n",
    "    )\n",
    "    true_board_states_counts = get_true_board_state_counts(pgn_strings)\n",
    "    assert true_board_states_counts is not None\n",
    "\n",
    "    # initialize result dictionary\n",
    "    results = initialize_results_dict(custom_functions, n_act_threshs, alive_features_F, DEVICE)\n",
    "\n",
    "    # Standard evaluation metrics\n",
    "    print('do_standard_evals')\n",
    "    results = do_standard_evals(results, ae_bundle)\n",
    "    # The activation buffer is not used after the standard evals; free it.\n",
    "    del ae_bundle.buffer\n",
    "\n",
    "    # Do custom eval metrics\n",
    "    print('do custom eval metrics')\n",
    "    results = eval_custom_fn(\n",
    "        results,\n",
    "        n_act_threshs,\n",
    "        alive_features_F,\n",
    "        max_activations_F,\n",
    "        ae_bundle,\n",
    "        pgn_strings,\n",
    "        custom_functions,\n",
    "        encoded_inputs,\n",
    "        firing_rate_n_inputs,\n",
    "        indexing_function,\n",
    "    )\n",
    "\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    results = get_classification_metrics(results, true_board_states_counts)\n",
    "    # Autoencoder folder name, independent of whether the path has a trailing slash\n",
    "    ae_name = os.path.basename(os.path.normpath(autoencoder_path))\n",
    "    sweep_results[ae_name] = {key: results[key] for key in sweep_result_keys}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>frac_variance_explained</th>\n",
       "      <th>l0</th>\n",
       "      <th>frac_any_board_per_feature_act-nonzero_precision-0.9</th>\n",
       "      <th>frac_any_board_per_feature_act-best_precision-0.9</th>\n",
       "      <th>frac_any_feature_per_board_act-nonzero_precision-0.9</th>\n",
       "      <th>frac_any_feature_per_board_act-best_precision-0.9</th>\n",
       "      <th>frac_any_board_per_feature_act-nonzero_recall-0.01</th>\n",
       "      <th>frac_any_board_per_feature_act-best_recall-0.01</th>\n",
       "      <th>frac_any_feature_per_board_act-nonzero_recall-0.01</th>\n",
       "      <th>frac_any_feature_per_board_act-best_recall-0.01</th>\n",
       "      <th>frac_any_board_per_feature_act-nonzero_f1-0.01</th>\n",
       "      <th>frac_any_board_per_feature_act-best_f1-0.01</th>\n",
       "      <th>frac_any_feature_per_board_act-nonzero_f1-0.01</th>\n",
       "      <th>frac_any_feature_per_board_act-best_f1-0.01</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>ef=16_lr=1e-03_l1=3e-01_layer=5</th>\n",
       "      <td>-1391422.000</td>\n",
       "      <td>2.131</td>\n",
       "      <td>0.367</td>\n",
       "      <td>0.367</td>\n",
       "      <td>0.117</td>\n",
       "      <td>0.141</td>\n",
       "      <td>0.037</td>\n",
       "      <td>0.037</td>\n",
       "      <td>0.121</td>\n",
       "      <td>0.121</td>\n",
       "      <td>0.043</td>\n",
       "      <td>0.043</td>\n",
       "      <td>0.092</td>\n",
       "      <td>0.100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ef=4_lr=1e-03_l1=3e-01_layer=5</th>\n",
       "      <td>0.879</td>\n",
       "      <td>3.811</td>\n",
       "      <td>0.839</td>\n",
       "      <td>0.839</td>\n",
       "      <td>0.167</td>\n",
       "      <td>0.234</td>\n",
       "      <td>0.045</td>\n",
       "      <td>0.045</td>\n",
       "      <td>0.055</td>\n",
       "      <td>0.055</td>\n",
       "      <td>0.100</td>\n",
       "      <td>0.103</td>\n",
       "      <td>0.087</td>\n",
       "      <td>0.089</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ef=16_lr=1e-03_l1=1e-01_layer=5</th>\n",
       "      <td>0.930</td>\n",
       "      <td>20.474</td>\n",
       "      <td>0.658</td>\n",
       "      <td>0.658</td>\n",
       "      <td>0.142</td>\n",
       "      <td>0.225</td>\n",
       "      <td>0.169</td>\n",
       "      <td>0.169</td>\n",
       "      <td>0.099</td>\n",
       "      <td>0.099</td>\n",
       "      <td>0.215</td>\n",
       "      <td>0.220</td>\n",
       "      <td>0.118</td>\n",
       "      <td>0.120</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ef=4_lr=1e-03_l1=1e-01_layer=5</th>\n",
       "      <td>0.956</td>\n",
       "      <td>24.885</td>\n",
       "      <td>0.965</td>\n",
       "      <td>0.965</td>\n",
       "      <td>0.143</td>\n",
       "      <td>0.280</td>\n",
       "      <td>0.154</td>\n",
       "      <td>0.154</td>\n",
       "      <td>0.086</td>\n",
       "      <td>0.086</td>\n",
       "      <td>0.231</td>\n",
       "      <td>0.238</td>\n",
       "      <td>0.109</td>\n",
       "      <td>0.111</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ef=16_lr=1e-03_l1=1e-03_layer=5</th>\n",
       "      <td>-4.485</td>\n",
       "      <td>508.757</td>\n",
       "      <td>0.864</td>\n",
       "      <td>0.864</td>\n",
       "      <td>0.141</td>\n",
       "      <td>0.216</td>\n",
       "      <td>0.580</td>\n",
       "      <td>0.580</td>\n",
       "      <td>0.124</td>\n",
       "      <td>0.124</td>\n",
       "      <td>0.633</td>\n",
       "      <td>0.635</td>\n",
       "      <td>0.105</td>\n",
       "      <td>0.112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ef=4_lr=1e-03_l1=1e-03_layer=5</th>\n",
       "      <td>0.999</td>\n",
       "      <td>1064.868</td>\n",
       "      <td>0.969</td>\n",
       "      <td>0.969</td>\n",
       "      <td>0.079</td>\n",
       "      <td>0.214</td>\n",
       "      <td>0.864</td>\n",
       "      <td>0.864</td>\n",
       "      <td>0.124</td>\n",
       "      <td>0.124</td>\n",
       "      <td>0.894</td>\n",
       "      <td>0.894</td>\n",
       "      <td>0.100</td>\n",
       "      <td>0.120</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                 frac_variance_explained        l0  \\\n",
       "ef=16_lr=1e-03_l1=3e-01_layer=5             -1391422.000     2.131   \n",
       "ef=4_lr=1e-03_l1=3e-01_layer=5                     0.879     3.811   \n",
       "ef=16_lr=1e-03_l1=1e-01_layer=5                    0.930    20.474   \n",
       "ef=4_lr=1e-03_l1=1e-01_layer=5                     0.956    24.885   \n",
       "ef=16_lr=1e-03_l1=1e-03_layer=5                   -4.485   508.757   \n",
       "ef=4_lr=1e-03_l1=1e-03_layer=5                     0.999  1064.868   \n",
       "\n",
       "                                 frac_any_board_per_feature_act-nonzero_precision-0.9  \\\n",
       "ef=16_lr=1e-03_l1=3e-01_layer=5                                              0.367      \n",
       "ef=4_lr=1e-03_l1=3e-01_layer=5                                               0.839      \n",
       "ef=16_lr=1e-03_l1=1e-01_layer=5                                              0.658      \n",
       "ef=4_lr=1e-03_l1=1e-01_layer=5                                               0.965      \n",
       "ef=16_lr=1e-03_l1=1e-03_layer=5                                              0.864      \n",
       "ef=4_lr=1e-03_l1=1e-03_layer=5                                               0.969      \n",
       "\n",
       "                                 frac_any_board_per_feature_act-best_precision-0.9  \\\n",
       "ef=16_lr=1e-03_l1=3e-01_layer=5                                              0.367   \n",
       "ef=4_lr=1e-03_l1=3e-01_layer=5                                               0.839   \n",
       "ef=16_lr=1e-03_l1=1e-01_layer=5                                              0.658   \n",
       "ef=4_lr=1e-03_l1=1e-01_layer=5                                               0.965   \n",
       "ef=16_lr=1e-03_l1=1e-03_layer=5                                              0.864   \n",
       "ef=4_lr=1e-03_l1=1e-03_layer=5                                               0.969   \n",
       "\n",
       "                                 frac_any_feature_per_board_act-nonzero_precision-0.9  \\\n",
       "ef=16_lr=1e-03_l1=3e-01_layer=5                                              0.117      \n",
       "ef=4_lr=1e-03_l1=3e-01_layer=5                                               0.167      \n",
       "ef=16_lr=1e-03_l1=1e-01_layer=5                                              0.142      \n",
       "ef=4_lr=1e-03_l1=1e-01_layer=5                                               0.143      \n",
       "ef=16_lr=1e-03_l1=1e-03_layer=5                                              0.141      \n",
       "ef=4_lr=1e-03_l1=1e-03_layer=5                                               0.079      \n",
       "\n",
       "                                 frac_any_feature_per_board_act-best_precision-0.9  \\\n",
       "ef=16_lr=1e-03_l1=3e-01_layer=5                                              0.141   \n",
       "ef=4_lr=1e-03_l1=3e-01_layer=5                                               0.234   \n",
       "ef=16_lr=1e-03_l1=1e-01_layer=5                                              0.225   \n",
       "ef=4_lr=1e-03_l1=1e-01_layer=5                                               0.280   \n",
       "ef=16_lr=1e-03_l1=1e-03_layer=5                                              0.216   \n",
       "ef=4_lr=1e-03_l1=1e-03_layer=5                                               0.214   \n",
       "\n",
       "                                 frac_any_board_per_feature_act-nonzero_recall-0.01  \\\n",
       "ef=16_lr=1e-03_l1=3e-01_layer=5                                              0.037    \n",
       "ef=4_lr=1e-03_l1=3e-01_layer=5                                               0.045    \n",
       "ef=16_lr=1e-03_l1=1e-01_layer=5                                              0.169    \n",
       "ef=4_lr=1e-03_l1=1e-01_layer=5                                               0.154    \n",
       "ef=16_lr=1e-03_l1=1e-03_layer=5                                              0.580    \n",
       "ef=4_lr=1e-03_l1=1e-03_layer=5                                               0.864    \n",
       "\n",
       "                                 frac_any_board_per_feature_act-best_recall-0.01  \\\n",
       "ef=16_lr=1e-03_l1=3e-01_layer=5                                            0.037   \n",
       "ef=4_lr=1e-03_l1=3e-01_layer=5                                             0.045   \n",
       "ef=16_lr=1e-03_l1=1e-01_layer=5                                            0.169   \n",
       "ef=4_lr=1e-03_l1=1e-01_layer=5                                             0.154   \n",
       "ef=16_lr=1e-03_l1=1e-03_layer=5                                            0.580   \n",
       "ef=4_lr=1e-03_l1=1e-03_layer=5                                             0.864   \n",
       "\n",
       "                                 frac_any_feature_per_board_act-nonzero_recall-0.01  \\\n",
       "ef=16_lr=1e-03_l1=3e-01_layer=5                                              0.121    \n",
       "ef=4_lr=1e-03_l1=3e-01_layer=5                                               0.055    \n",
       "ef=16_lr=1e-03_l1=1e-01_layer=5                                              0.099    \n",
       "ef=4_lr=1e-03_l1=1e-01_layer=5                                               0.086    \n",
       "ef=16_lr=1e-03_l1=1e-03_layer=5                                              0.124    \n",
       "ef=4_lr=1e-03_l1=1e-03_layer=5                                               0.124    \n",
       "\n",
       "                                 frac_any_feature_per_board_act-best_recall-0.01  \\\n",
       "ef=16_lr=1e-03_l1=3e-01_layer=5                                            0.121   \n",
       "ef=4_lr=1e-03_l1=3e-01_layer=5                                             0.055   \n",
       "ef=16_lr=1e-03_l1=1e-01_layer=5                                            0.099   \n",
       "ef=4_lr=1e-03_l1=1e-01_layer=5                                             0.086   \n",
       "ef=16_lr=1e-03_l1=1e-03_layer=5                                            0.124   \n",
       "ef=4_lr=1e-03_l1=1e-03_layer=5                                             0.124   \n",
       "\n",
       "                                 frac_any_board_per_feature_act-nonzero_f1-0.01  \\\n",
       "ef=16_lr=1e-03_l1=3e-01_layer=5                                           0.043   \n",
       "ef=4_lr=1e-03_l1=3e-01_layer=5                                            0.100   \n",
       "ef=16_lr=1e-03_l1=1e-01_layer=5                                           0.215   \n",
       "ef=4_lr=1e-03_l1=1e-01_layer=5                                            0.231   \n",
       "ef=16_lr=1e-03_l1=1e-03_layer=5                                           0.633   \n",
       "ef=4_lr=1e-03_l1=1e-03_layer=5                                            0.894   \n",
       "\n",
       "                                 frac_any_board_per_feature_act-best_f1-0.01  \\\n",
       "ef=16_lr=1e-03_l1=3e-01_layer=5                                        0.043   \n",
       "ef=4_lr=1e-03_l1=3e-01_layer=5                                         0.103   \n",
       "ef=16_lr=1e-03_l1=1e-01_layer=5                                        0.220   \n",
       "ef=4_lr=1e-03_l1=1e-01_layer=5                                         0.238   \n",
       "ef=16_lr=1e-03_l1=1e-03_layer=5                                        0.635   \n",
       "ef=4_lr=1e-03_l1=1e-03_layer=5                                         0.894   \n",
       "\n",
       "                                 frac_any_feature_per_board_act-nonzero_f1-0.01  \\\n",
       "ef=16_lr=1e-03_l1=3e-01_layer=5                                           0.092   \n",
       "ef=4_lr=1e-03_l1=3e-01_layer=5                                            0.087   \n",
       "ef=16_lr=1e-03_l1=1e-01_layer=5                                           0.118   \n",
       "ef=4_lr=1e-03_l1=1e-01_layer=5                                            0.109   \n",
       "ef=16_lr=1e-03_l1=1e-03_layer=5                                           0.105   \n",
       "ef=4_lr=1e-03_l1=1e-03_layer=5                                            0.100   \n",
       "\n",
       "                                 frac_any_feature_per_board_act-best_f1-0.01  \n",
       "ef=16_lr=1e-03_l1=3e-01_layer=5                                        0.100  \n",
       "ef=4_lr=1e-03_l1=3e-01_layer=5                                         0.089  \n",
       "ef=16_lr=1e-03_l1=1e-01_layer=5                                        0.120  \n",
       "ef=4_lr=1e-03_l1=1e-01_layer=5                                         0.111  \n",
       "ef=16_lr=1e-03_l1=1e-03_layer=5                                        0.112  \n",
       "ef=4_lr=1e-03_l1=1e-03_layer=5                                         0.120  "
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Variance explained, L0, and the board-state classification metrics (cossim and l2_ratio are left out).\n",
    "# Derived from sweep_result_keys so the table cannot drift from the metrics actually collected.\n",
    "display_columns = ['frac_variance_explained', 'l0'] + [key for key in sweep_result_keys if key.startswith('frac_any_')]\n",
    "df = pd.DataFrame.from_dict(sweep_results, orient='index').sort_values('l0').round(3)\n",
    "df[display_columns]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "circuits",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
