{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nnsight import LanguageModel\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import chess\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "from typing import Callable\n",
    "\n",
    "from dictionary_learning import ActivationBuffer\n",
    "from nanogpt_to_hf_transformers import NanogptTokenizer, convert_nanogpt_model\n",
    "from dictionary_learning.utils import hf_dataset_to_generator\n",
    "from dictionary_learning import AutoEncoder\n",
    "\n",
    "from circuits.utils import get_ae_bundle, AutoEncoderBundle, get_feature\n",
    "\n",
    "import chess_utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I run this notebook on my laptop, which has 4GB of VRAM in a RTX 3050, and 64 GB of RAM.\n",
    "\n",
    "I believe this notebook is actually more CPU bound than GPU bound, as my laptop runs this notebook in less time than a vast.ai RTX 3090.\n",
    "\n",
    "Step 1: Load the model, dictionary, data, and activation buffers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "MODEL_PATH = \"../models/lichess_8layers_ckpt_no_optimizer.pt\"\n",
    "batch_size = 25\n",
    "\n",
    "autoencoder1_path = \"../autoencoders/group0/ef=4_lr=1e-03_l1=1e-01_layer=5/\"\n",
    "autoencoder2_path = \"../autoencoders/group0/ef=8_lr=1e-04_l1=1e-03_layer=5/\"\n",
    "\n",
    "# chess_sae_test is 100MB of data, so no big deal to download it\n",
    "data = hf_dataset_to_generator(\"redacted/chess_sae_test\", streaming=False)\n",
    "\n",
    "ae_bundle1 = get_ae_bundle(autoencoder1_path, DEVICE, data, batch_size, model_path=\"../models/\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Collect feature activations on total_inputs inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_inputs = 8000\n",
    "assert total_inputs % batch_size == 0\n",
    "num_iters = total_inputs // batch_size\n",
    "\n",
    "features = torch.zeros((total_inputs, ae_bundle1.dictionary_size), device=DEVICE)\n",
    "for i in tqdm(range(num_iters), total=num_iters, desc=\"Extracting features\"):\n",
    "    feature = get_feature(ae_bundle1.buffer, ae_bundle1.ae, DEVICE)  # (batch_size, dictionary_size)\n",
    "    features[i * batch_size : (i + 1) * batch_size, :] = feature"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A few plots about various statistics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "firing_rate_per_feature = (features != 0).float().sum(dim=0).cpu() / total_inputs\n",
    "\n",
    "# Creating the histogram\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.hist(firing_rate_per_feature, bins=50, alpha=0.75, color=\"blue\")\n",
    "plt.title(\"Histogram of firing rates for features\")\n",
    "plt.xlabel(\"Probability\")\n",
    "plt.ylabel(\"Frequency\")\n",
    "plt.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "firing_rate_per_input = (features != 0).float().sum(dim=-1).cpu() / total_inputs\n",
    "\n",
    "# Creating the histogram\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.hist(firing_rate_per_input, bins=50, alpha=0.75, color=\"blue\")\n",
    "plt.title(\"Percentage of features firing per input\")\n",
    "plt.xlabel(\"Probability\")\n",
    "plt.ylabel(\"Frequency\")\n",
    "plt.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I got this from: https://colab.research.google.com/drive/19Qo9wj5rGLjb6KsB9NkKNJkMiHcQhLqo?usp=sharing#scrollTo=WZMhAzLTvw-u"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "feat_prob = features.mean(0)\n",
    "print(feat_prob.shape)\n",
    "log_freq = (feat_prob + 1e-10).log10()\n",
    "print(log_freq.shape)\n",
    "\n",
    "log_freq_np = log_freq.cpu().numpy()\n",
    "\n",
    "# Creating the histogram\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.hist(log_freq_np, bins=50, alpha=0.75, color=\"blue\")\n",
    "plt.title(\"Histogram of log10 of Feature Probabilities\")\n",
    "plt.xlabel(\"log10(Probability)\")\n",
    "plt.ylabel(\"Frequency\")\n",
    "plt.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the L0 statistic. Then, get a list of indices for features that fire between 0 and 50% of the time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(features.shape)\n",
    "l0 = (features != 0).float().sum(dim=-1).mean()\n",
    "print(f\"l0: {l0}\")\n",
    "\n",
    "firing_rate_per_feature = (features != 0).float().sum(dim=0) / total_inputs\n",
    "\n",
    "assert firing_rate_per_feature.shape[0] == ae_bundle1.dictionary_size\n",
    "\n",
    "mask = (firing_rate_per_feature > 0) & (firing_rate_per_feature < 0.5)\n",
    "idx = torch.nonzero(mask, as_tuple=False).squeeze()\n",
    "print(idx.shape)\n",
    "print(f\"\\n\\nWe have {idx.shape[0]} features that fire between 0 and 50% of the time.\")\n",
    "print(idx[:10])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we collect per dim stats, which include the top tokens it fires on, and the top k inputs and activations per input token.\n",
    "\n",
    "Rough ballpark times on my RTX 3050: \n",
    "\n",
    "2000 dims, 1500 inputs, batch size 50 = 23 seconds\n",
    "\n",
    "Note that I perform the activation processing on my CPU. This is comparable speed, but much lower VRAM usage.\n",
    "\n",
    "The current state of the code uses around 30-40GB of RAM of n_dims == 6000, n_inputs == 5000. If we want to scale to more dims and / or inputs, the code should be refactored to \"stream\" rather than save intermediate results. I have TODOs for this in the code.\n",
    "\n",
    "If you run out of RAM or VRAM, reduce n_inputs and max_dims."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "import chess_interp\n",
    "\n",
    "importlib.reload(chess_interp)\n",
    "\n",
    "max_dims = 2000\n",
    "top_k = 30\n",
    "n_inputs = 1500\n",
    "batch_size = 25\n",
    "\n",
    "\n",
    "per_dim_stats1 = chess_interp.examine_dimension_chess(\n",
    "    ae_bundle1,\n",
    "    dims=idx[:max_dims],\n",
    "    n_inputs=n_inputs,\n",
    "    k=top_k + 1,\n",
    "    batch_size=batch_size,\n",
    "    processing_device=\"cpu\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_acts = []\n",
    "mean_acts = []\n",
    "for key in per_dim_stats1.keys():\n",
    "    max_acts.append(per_dim_stats1[key]['activations'][0][-1])\n",
    "    for i in range(top_k):\n",
    "        mean_acts.append(per_dim_stats1[key]['activations'][i][-1])\n",
    "\n",
    "# print(per_dim_stats1[0]['activations'][0][-1])\n",
    "print(max(max_acts), min(max_acts))\n",
    "print(max(mean_acts), min(mean_acts))\n",
    "print(sum(max_acts) / len(max_acts))\n",
    "print(sum(mean_acts) / len(mean_acts))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All of the above steps can be performed with `get_ae_stats()`, which we will do for the second autoencoder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import circuits.sae_stats_collection\n",
    "importlib.reload(circuits.sae_stats_collection)\n",
    "from circuits.sae_stats_collection import get_ae_stats\n",
    "\n",
    " # TODO getting the eval_results is broken. I think it's because I switched to SAEs on the residual stream, not the mlp output\n",
    " # The residual stream returns a tuple of parameters, not a single parameter. This is just a guess though.\n",
    "per_dim_stats2, eval_result2 = get_ae_stats(autoencoder2_path, max_dims, n_inputs, top_k, batch_size, DEVICE, \"../models/\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see below, autoencoder 2 has a terrible L0 of 1911. It has around 1500 features that fire between 0 and 50% of the time. It is expansion factor 8 on d_model == 512."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Number of features firing between 0 and 50% of the time: {len(per_dim_stats2)}\\n\")\n",
    "\n",
    "for key, value in eval_result2.items():\n",
    "    print(f\"{key}: {value}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This cell looks at syntax related features. Specifically, it looks for features that always fire on a PGN \"counting number\". In this PGN, I've wrapped the \"counting numbers\" in brackets.\n",
    "\n",
    ";<1.>e4 e5 <2.>Nf3 ...\n",
    "\n",
    "We can easily analyze different syntax related attributes by just passing in a different syntax function, such as one that just finds space indices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "importlib.reload(chess_utils)\n",
    "from pydantic import BaseModel\n",
    "from chess_interp import initialize_feature_dictionary\n",
    "from typing import Optional\n",
    "\n",
    "\n",
    "class SyntaxResultsConfig(BaseModel):\n",
    "    dim_count: int = 0\n",
    "    nonzero_count: int = 0\n",
    "    syntax_match_idx_count: int = 0\n",
    "    average_input_length: float = 0.0\n",
    "\n",
    "# Copy pasted directly from chess_interp.py\n",
    "def syntax_analysis(\n",
    "    per_dim_stats: dict,\n",
    "    minimum_number_of_activations: int,\n",
    "    top_k: int,\n",
    "    max_dims: int,\n",
    "    syntax_function: Callable,\n",
    "    feature_dict: Optional[dict[int, list[dict]]] = None,\n",
    "    notebook_usage: bool = False,\n",
    "    verbose: bool = False,\n",
    ") -> tuple[SyntaxResultsConfig, dict[int, list[dict]]]:\n",
    "\n",
    "    if feature_dict is None:\n",
    "        feature_dict = initialize_feature_dictionary(per_dim_stats)\n",
    "\n",
    "    results = SyntaxResultsConfig()\n",
    "\n",
    "    for dim in per_dim_stats:\n",
    "        results.dim_count += 1\n",
    "        if results.dim_count >= max_dims:\n",
    "            break\n",
    "\n",
    "        decoded_tokens = per_dim_stats[dim][\"decoded_tokens\"]\n",
    "        activations = per_dim_stats[dim][\"activations\"]\n",
    "        # If the dim doesn't have at least min_num firing activations, skip it\n",
    "        if activations[minimum_number_of_activations][-1].item() == 0:\n",
    "            continue\n",
    "        results.nonzero_count += 1\n",
    "\n",
    "        inputs = [\"\".join(string) for string in decoded_tokens]\n",
    "        inputs = inputs[:top_k]\n",
    "\n",
    "        num_indices = []\n",
    "        count = 0\n",
    "\n",
    "        for i, pgn in enumerate(inputs[:top_k]):\n",
    "            nums = syntax_function(pgn)\n",
    "            num_indices.append(nums)\n",
    "\n",
    "            # If the last token (which contains the max activation for that context) is a number\n",
    "            # Then we count this firing as a \"number index firing\"\n",
    "            if (len(pgn) - 1) in nums:\n",
    "                count += 1\n",
    "\n",
    "        if count == top_k:\n",
    "            if notebook_usage:\n",
    "                for pgn in inputs[:top_k]:\n",
    "                    print(pgn)\n",
    "                print(f\"All top {top_k} activations in dim: {dim} are on num indices\")\n",
    "            results.syntax_match_idx_count += 1\n",
    "            average_input_length = sum(len(pgn) for pgn in inputs[:top_k]) / len(inputs[:top_k])\n",
    "            results.average_input_length += average_input_length\n",
    "            feature_dict[dim].append({\"name\": syntax_function.__name__})\n",
    "\n",
    "    if results.syntax_match_idx_count > 0:\n",
    "        results.average_input_length /= results.syntax_match_idx_count\n",
    "\n",
    "    if verbose:\n",
    "        print(\n",
    "            f\"Out of {results.dim_count} features, {results.nonzero_count} had at least {minimum_number_of_activations} activations.\"\n",
    "        )\n",
    "        print(\n",
    "            f\"{results.syntax_match_idx_count} features matched on all top {top_k} inputs for our syntax function {syntax_function.__name__}\"\n",
    "        )\n",
    "        print(\n",
    "            f\"The average length of inputs of pattern matching features was {results.average_input_length:.2f}\"\n",
    "        )\n",
    "\n",
    "    return results, feature_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first 5 features in the second autoencoder are all syntax features. We can see them below.\n",
    "\n",
    "Example:\n",
    "\n",
    ";1.d4 d5 2.h3 c5 3.a3 Nc6 4.e3 e5 5.dxc5 Bxc5 6.b4 Bb6 7\n",
    "\n",
    ";1.e4 e5 2.c3 d5 3.exd5 Qxd5 4.d4 exd4 5.cxd4 Bb4+ 6\n",
    "\n",
    ";1.Nh3 e5 2.Na3 d5 3.e3 Bxa3 4.bxa3 Bxh3 5.gxh3 c5 6\n",
    "\n",
    ";1.e4 e6 2.c4 d5 3.d3 dxc4 4.dxc4 Qxd1+ 5.Kxd1 Nc6 6\n",
    "\n",
    ";1.d4 d6 2.c4 Nd7 3.Nc3 e5 4.d5 Ne7 5.e4 Ng6 6.Bd3 Be7 7\n",
    "\n",
    "All top 10 activations in dim: 3 are on num indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prev_top_k = top_k\n",
    "top_k = 10\n",
    "syntax_analysis(per_dim_stats2, top_k, top_k, max_dims=5, syntax_function=chess_utils.find_num_indices, notebook_usage=True, verbose=True)\n",
    "top_k = prev_top_k"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In contrast, in the first 25 features on the first autoencoder, which has an L0 of 25, only three of the first 25 features match our `find_num_indices` pattern match. But, they are probably common opening features, as the top k inputs for each feature are identical, so it's debateable if they are actually syntax matches. This is one of the challenges of automatic evaluations. Maybe enough heuristics would make this acceptable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prev_top_k = top_k\n",
    "top_k = 10\n",
    "syntax_analysis(per_dim_stats1, top_k, top_k, max_dims=25, syntax_function=chess_utils.find_num_indices, notebook_usage=True, verbose=True)\n",
    "top_k = prev_top_k"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can also do programmatic analysis of board states at max activating input tokens. The procedure is the following:\n",
    "\n",
    "At each max activating token in an input pgn string, convert the pgn string to a chess board object. Then we can find common board states, such as all inputs have this piece on this square, or all inputs have a pinned piece on the board.\n",
    "\n",
    "In the case of common board states, we convert every chess board object to a one hot tensor of shape (8, 8, 13) or (rows, cols, num_options). Num options is (blank, white / black pawn, knight, bishop, rook, queen, king). Now, we have the chess boards tensor of shape (top_k, rows, cols, num_options).\n",
    "\n",
    "So, we just look at every square in (rows, cols, num_options) and see if 100% of squares are 1. Any square with a 100% match is added to common_indices. For example, if every e4 contains a white pawn and every a2 contains a white rook, then we would have (3, 4, 6) and (1, 0, 2) or (4, e, white pawn) and (2, a, white rook).\n",
    "\n",
    "Note that we also make a one hot tensor for the initial board state, and mask off all initial board state squares from every chess board tensor. If we didn't, any short pgn string would have many activations.\n",
    "\n",
    "Note that this `board_analysis()` function takes a `Config`, which contains a function to convert a chess board object to a tensor. All of the above steps can be repeated for any Config object. For example, we have a `threat_config`, which makes a one hot tensor of shape (8, 8, 2) or (rows, cols, is_threatened), where is_threatened is if the square is threatened by an opponent. Or we can have a `pin_config`, which makes a one hot tensor of shape (1, 1, 2), which is active if a piece on the board is pinned to its king.\n",
    "\n",
    "The masking also applies to shape configs like `pin_config`. If we didn't didn't mask off the initial board state (with no pins), the vast majority of features would activate for `no pins on board`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "importlib.reload(chess_utils)\n",
    "from chess_utils import Config, get_num_classes\n",
    "\n",
    "from pydantic import BaseModel\n",
    "\n",
    "\n",
    "class BoardResultsConfig(BaseModel):\n",
    "    dim_count: int = 0\n",
    "    nonzero_count: int = 0\n",
    "    pattern_match_count: int = 0\n",
    "    total_average_length: float = 0.0\n",
    "    average_matches_per_dim: float = 0.0\n",
    "    per_class_dict: dict[int, int]\n",
    "    board_tracker: list[list[int]]  # shape: (num_rows, num_cols)\n",
    "\n",
    "\n",
    "# copy pasted directly from chess_interp.py\n",
    "def board_analysis(\n",
    "    per_dim_stats: dict,\n",
    "    minimum_number_of_activations: int,\n",
    "    top_k: int,\n",
    "    max_dims: int,\n",
    "    threshold: float,\n",
    "    configs: list[Config],\n",
    "    feature_dict: Optional[dict[int, list[dict]]] = None,\n",
    "    device: str = \"cpu\",\n",
    "    notebook_usage: bool = False,\n",
    "    verbose: bool = False,\n",
    ") -> tuple[dict[str, BoardResultsConfig], dict[int, list[dict]]]:\n",
    "\n",
    "    if feature_dict is None:\n",
    "        feature_dict = initialize_feature_dictionary(per_dim_stats)\n",
    "\n",
    "    nonzero_count = 0\n",
    "    dim_count = 0\n",
    "\n",
    "    results: dict[str, BoardResultsConfig] = {}\n",
    "\n",
    "    for config in configs:\n",
    "        board_tracker = torch.zeros(config.num_rows, config.num_cols).tolist()\n",
    "        per_class_dict = {key: 0 for key in range(0, get_num_classes(config))}\n",
    "\n",
    "        results[config.custom_board_state_function.__name__] = BoardResultsConfig(\n",
    "            per_class_dict=per_class_dict,\n",
    "            board_tracker=board_tracker,\n",
    "        )\n",
    "\n",
    "    for dim in tqdm(per_dim_stats, total=len(per_dim_stats), desc=\"Processing chess pgn strings\"):\n",
    "        dim_count += 1\n",
    "        if dim_count >= max_dims:\n",
    "            break\n",
    "\n",
    "        decoded_tokens = per_dim_stats[dim][\"decoded_tokens\"]\n",
    "        activations = per_dim_stats[dim][\"activations\"]\n",
    "        # If the dim doesn't have at least minimum_number_of_activations firing activations, skip it\n",
    "        if activations[minimum_number_of_activations][-1].item() == 0:\n",
    "            continue\n",
    "        nonzero_count += 1\n",
    "\n",
    "        inputs = [\"\".join(string) for string in decoded_tokens]\n",
    "        inputs = inputs[:top_k]\n",
    "\n",
    "        count = 0\n",
    "\n",
    "        chess_boards = [\n",
    "            chess_utils.pgn_string_to_board(pgn, allow_exception=True) for pgn in inputs\n",
    "        ]\n",
    "\n",
    "        for config in configs:\n",
    "\n",
    "            config_name = config.custom_board_state_function.__name__\n",
    "\n",
    "            # See function definitions for jaxtyped shapes\n",
    "            one_hot_list = chess_utils.chess_boards_to_state_stack(chess_boards, device, config)\n",
    "            one_hot_list = chess_utils.mask_initial_board_states(one_hot_list, device, config)\n",
    "            averaged_one_hot = chess_utils.get_averaged_states(one_hot_list)\n",
    "            common_indices = chess_utils.find_common_states(averaged_one_hot, threshold)\n",
    "\n",
    "            if any(len(idx) > 0 for idx in common_indices):  # if at least one square matches\n",
    "                results[config_name].pattern_match_count += 1\n",
    "                average_input_length = sum(len(pgn) for pgn in inputs) / len(inputs)\n",
    "                results[config_name].total_average_length += average_input_length\n",
    "\n",
    "                if notebook_usage:\n",
    "                    for i, pgn in enumerate(inputs):\n",
    "                        if i >= 10:\n",
    "                            break\n",
    "                        print(pgn)\n",
    "\n",
    "                common_board_state = torch.zeros(\n",
    "                    config.num_rows,\n",
    "                    config.num_cols,\n",
    "                    get_num_classes(config),\n",
    "                    device=device,\n",
    "                    dtype=torch.int8,\n",
    "                )\n",
    "\n",
    "                for idx in zip(*common_indices):\n",
    "                    results[config_name].board_tracker[idx[0]][idx[1]] += 1\n",
    "                    results[config_name].per_class_dict[idx[2].item()] += 1\n",
    "                    results[config_name].average_matches_per_dim += 1\n",
    "\n",
    "                    common_board_state[idx[0], idx[1], idx[2]] = 1\n",
    "                    if notebook_usage:\n",
    "                        print(f\"Dim: {dim}, Index: {idx}\")\n",
    "\n",
    "                feature_info = {\n",
    "                    \"name\": config.custom_board_state_function.__name__,\n",
    "                    \"max_activation\": activations[0][-1].item(),\n",
    "                    \"board_state\": common_board_state,\n",
    "                }\n",
    "\n",
    "                feature_dict[dim].append(feature_info)\n",
    "\n",
    "    for config in configs:\n",
    "        config_name = config.custom_board_state_function.__name__\n",
    "        match_count = results[config_name].pattern_match_count\n",
    "        results[config_name].dim_count = dim_count\n",
    "        results[config_name].nonzero_count = nonzero_count\n",
    "        results[config_name].board_tracker = results[config_name].board_tracker\n",
    "        if match_count > 0:\n",
    "            results[config_name].total_average_length /= match_count\n",
    "            results[config_name].average_matches_per_dim /= match_count\n",
    "\n",
    "    if verbose:\n",
    "        for config in configs:\n",
    "            config_name = config.custom_board_state_function.__name__\n",
    "            pattern_match_count = results[config_name].pattern_match_count\n",
    "            total_average_length = results[config_name].total_average_length\n",
    "            print(f\"\\n{config_name} Results:\")\n",
    "            print(\n",
    "                f\"Out of {dim_count} features, {nonzero_count} had at least {minimum_number_of_activations} activations.\"\n",
    "            )\n",
    "            print(\n",
    "                f\"{pattern_match_count} features matched on all top {top_k} inputs for our board to state function {config_name}\"\n",
    "            )\n",
    "            print(\n",
    "                f\"The average length of inputs of pattern matching features was {total_average_length}\"\n",
    "            )\n",
    "\n",
    "            if config.num_rows == 8:\n",
    "                board_tracker = results[config_name].board_tracker\n",
    "                print(f\"\\nThe following square states had the following number of occurances:\")\n",
    "                for key, count in results[config_name].per_class_dict.items():\n",
    "                    print(f\"Index: {key}, Count: {count}\")\n",
    "\n",
    "                print(f\"\\nHere are the most common squares:\")\n",
    "                board_tracker = torch.tensor(board_tracker).flip(0)\n",
    "                print(board_tracker)  # torch.tensor has a cleaner printout\n",
    "\n",
    "    return results, feature_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In our first SAE, which has L0 of 25, we have many features that contain these shared board states. For example, in the PGN strings for Dim 18, all pgn strings contain the following 8 identical board squares:\n",
    "\n",
    ";1.e4 e5 2.Bc4 Nf6 3.d3 c6 4.Nf3 d5 5.exd5 cxd5 6.Bb3 e4 7.dxe4 Nxe4 8.O-O Bc5 9.Qxd5 Qxd5 10.Bxd5 Nf6 11.Bb3 O-O 12.\n",
    "\n",
    ";1.e4 c5 2.Nf3 Nc6 3.d4 cxd4 4.Nxd4 e6 5.Nc3 a6 6.Nb3 Bb4 7.Bd2 d5 8.exd5 exd5 9.Bd3 Nf6 10.O-O O-O 11.\n",
    "\n",
    ";1.e4 c5 2.Nf3 d6 3.d4 cxd4 4.Nxd4 Nf6 5.Nc3 a6 6.Bd3 g6 7.O-O Bg7 8.Nf3 O-O 9.Bg5 Nc6 10.Qd2 \n",
    "\n",
    ";1.e4 e5 2.Nf3 Nc6 3.Bb5 Nge7 4.Bxc6 Nxc6 5.d3 d6 6.O-O Be7 7.h3 O-O 8.b3 a6 9.Bb2 f5 10.exf5 Bxf5 11.Nc3 \n",
    "\n",
    ";1.e4 e5 2.Nf3 Nc6 3.Bb5 Nge7 4.Bxc6 Nxc6 5.d3 d6 6.O-O Be7 7.h3 O-O 8.b3 a6 9.Bb2 f5 10.exf5 Bxf5 11.Nc3 Bg5 12.Qe2 \n",
    "\n",
    ";1.e4 c5 2.c3 Nc6 3.Nf3 g6 4.Bb5 Bg7 5.O-O d6 6.d4 cxd4 7.cxd4 Nf6 8.d5 a6 9.Ba4 b5 10.dxc6 bxa4 11.Qxa4 O-O 12.Nc3 B\n",
    "\n",
    ";1.e4 e5 2.Nf3 Nc6 3.d4 exd4 4.Nxd4 Bc5 5.Be3 Nf6 6.Nxc6 bxc6 7.Bxc5 Nxe4 8.Be3 Rb8 9.Bd3 Qe7 10.O-O O-O 11.Nc3 Nf6 12.Bg\n",
    "\n",
    ";1.e4 c5 2.Nf3 d6 3.d4 cxd4 4.Nxd4 Nf6 5.Nc3 a6 6.Bd3 g6 7.O-O Bg7 8.Nf3 O-O 9.Bg5 Nc6 10.Qd2 B\n",
    "\n",
    ";1.e4 e5 2.Nf3 Nf6 3.Nxe5 d6 4.Nf3 Nxe4 5.d4 d5 6.Bd3 Be7 7.O-O O-O 8.c4 c6 9.Nc3 Nxc3 10.bxc3 Bg\n",
    "\n",
    ";1.e4 e5 2.Nf3 Nc6 3.Bb5 Nge7 4.Bxc6 Nxc6 5.d3 d6 6.O-O Be7 7.h3 O-O 8.b3 a6 9.Bb2 f5 10.exf5 Bxf5 11.Nc3 Bg5 12.Qe2 Bh6 13.Ne4 \n",
    "\n",
    "Dim: 18, Index: (tensor(0), tensor(4), tensor(6))\n",
    "\n",
    "Dim: 18, Index: (tensor(0), tensor(5), tensor(10))\n",
    "\n",
    "Dim: 18, Index: (tensor(0), tensor(6), tensor(12))\n",
    "\n",
    "Dim: 18, Index: (tensor(0), tensor(7), tensor(6))\n",
    "\n",
    "Dim: 18, Index: (tensor(7), tensor(4), tensor(6))\n",
    "\n",
    "Dim: 18, Index: (tensor(7), tensor(5), tensor(2))\n",
    "\n",
    "Dim: 18, Index: (tensor(7), tensor(6), tensor(0))\n",
    "\n",
    "Dim: 18, Index: (tensor(7), tensor(7), tensor(6))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prev_top_k = top_k\n",
    "top_k = 30\n",
    "results, feature_dict1 = board_analysis(\n",
    "    per_dim_stats1, top_k, top_k, 50, 0.99, [chess_utils.piece_config], device=\"cpu\", notebook_usage=True, verbose=True\n",
    ")\n",
    "top_k = prev_top_k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(feature_dict1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prev_top_k = top_k\n",
    "top_k = 10\n",
    "board_analysis(\n",
    "    per_dim_stats1, top_k, top_k, 5, 0.99, [chess_utils.piece_config], device=\"cpu\", notebook_usage=True, verbose=True\n",
    ")\n",
    "top_k = prev_top_k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "board_analysis(\n",
    "    per_dim_stats1, top_k, top_k, 5, 0.99, [chess_utils.piece_config], device=\"cpu\", notebook_usage=True, verbose=True\n",
    ")\n",
    "top_k = prev_top_k"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For our first SAE, with L0 of 25, we have around 10% of features matching one of our syntax match filters (~200 out of 2048)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "syntax_analysis(per_dim_stats1, top_k, top_k, max_dims, chess_utils.find_num_indices, verbose=True)\n",
    "print()\n",
    "syntax_analysis(per_dim_stats1, top_k, top_k, max_dims, chess_utils.find_spaces_indices, verbose=True)\n",
    "print()\n",
    "syntax_analysis(per_dim_stats1, top_k, top_k, max_dims, chess_utils.find_dots_indices, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In contrast, for the second SAE with an L0 of ~2000, 75% of 1561 features match the `find_num_indices` filter!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "syntax_analysis(per_dim_stats2, top_k, top_k, max_dims, chess_utils.find_num_indices, verbose=True)\n",
    "print()\n",
    "syntax_analysis(per_dim_stats2, top_k, top_k, max_dims, chess_utils.find_spaces_indices, verbose=True)\n",
    "print()\n",
    "syntax_analysis(per_dim_stats2, top_k, top_k, max_dims, chess_utils.find_dots_indices, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For board analysis on our first autoencder, we find that 50% of 2000 features match our `piece_config` filter. We also print the number of matches per class:\n",
    "\n",
    "The following square states had the following number of occurances:\n",
    "\n",
    "Index: 0, Count: 27\n",
    "\n",
    "Index: 1, Count: 11\n",
    "\n",
    "Index: 2, Count: 10\n",
    "\n",
    "Index: 3, Count: 12\n",
    "\n",
    "Index: 4, Count: 126\n",
    "\n",
    "Index: 5, Count: 344\n",
    "\n",
    "Index: 6, Count: 1781\n",
    "\n",
    "Index: 7, Count: 439\n",
    "\n",
    "Index: 8, Count: 178\n",
    "\n",
    "Index: 9, Count: 44\n",
    "\n",
    "Index: 10, Count: 24\n",
    "\n",
    "Index: 11, Count: 6\n",
    "\n",
    "Index: 12, Count: 37\n",
    "\n",
    "Unsurprisingly, most matches are for idx 6, the blank class. Many of these are potentially false positives. Understanding this further is probably important. We could potentially maybe compare to common board state stastics in real chess games (which we could gather using this repo).\n",
    "\n",
    "We also print the match count per square:\n",
    "\n",
    "Here are the most common squares:\n",
    "\n",
    "        [  0, 126,  19,  31,  12,  17,  69,  28],\n",
    "\n",
    "        [  8,   2,  90, 194, 160,   1,   9,   3],\n",
    "\n",
    "        [  6,   2,  63,  35,  54,  73,  11,   2],\n",
    "\n",
    "        [  4,  13,  26, 125,  89,   0,   0,   4],\n",
    "\n",
    "        [  2,   0,  41, 166, 203,   9,   1,   0],\n",
    "\n",
    "        [  1,   6,  31,   9,  15, 136,  11,   1],\n",
    "\n",
    "        [  4,   8,  45, 272, 349,  10,  10,   2],\n",
    "\n",
    "        [  1, 120,  26,  22,  27,  56, 137,  42]])\n",
    "\n",
    "One potential measure of quality could be a more even spread of pattern matches per square or over classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "board_analysis(\n",
    "    per_dim_stats1, top_k, top_k, max_dims, 0.99, [chess_utils.piece_config], device=\"cpu\", verbose=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In contrast, only 436 out of 1561 features in our second autoencoder match a `piece_config` filter. This seems concerning! 75% of features are syntax features, rather than true underlying semantic features.\n",
    "\n",
    "Obviously, this second autoencoder has a pretty terrible L0. It would be interesting to do more analysis to see what this distribution looks like for more reasonable L0s."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "board_analysis(\n",
    "    per_dim_stats2, top_k, top_k, max_dims, 0.99, [chess_utils.piece_config], device=\"cpu\", verbose=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our first autoencoder finds 11 features matching `pin_config`, while the second only has 1 feature match for `pin_config`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "board_analysis(per_dim_stats1, top_k, top_k, 2500, 0.99, [chess_utils.pin_config], verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "board_analysis(per_dim_stats2, top_k, top_k, 2500, 0.99, [chess_utils.pin_config], verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also pass in a list of Configs for around a 2x speedup."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "board_analysis(per_dim_stats1, top_k, top_k, 2500, 0.99, [chess_utils.threat_config, chess_utils.check_config], verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use this to get a dictionary for every feature in a dictionary of which filters it matches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_dict1 = initialize_feature_dictionary(per_dim_stats1)\n",
    "feature_dict2 = initialize_feature_dictionary(per_dim_stats2)\n",
    "\n",
    "board_results, feature_dict1 = board_analysis(\n",
    "    per_dim_stats1,\n",
    "    top_k,\n",
    "    top_k,\n",
    "    2500,\n",
    "    0.99,\n",
    "    [\n",
    "        chess_utils.pin_config,\n",
    "        chess_utils.piece_config,\n",
    "        chess_utils.threat_config,\n",
    "        chess_utils.check_config,\n",
    "    ],\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, key in enumerate(feature_dict1):\n",
    "    print(key)\n",
    "    print(feature_dict1[key])\n",
    "\n",
    "    if i > 5:\n",
    "        break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "circuits",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
