{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`eval_sae_as_classifier.py` does the following:\n",
    "\n",
    "A custom function could be `board_to_pin_state`, which, for every token in the PGN string, returns a \"state stack\", which is 0 or 1. 0 means \"There is not a pin on the board at this character\" and 1 means \"there is a pin on the board at this character\". Or it could be like `board_to_piece_state`, which returns a state stack one hot tensor of shape (8,8,13) or (rows, cols, num_classes), which returns the state of every square on the chess board.\n",
    "\n",
    "Over 1000's of input pgn strings, for every activation for every dictionary feature for a range of threshold values, we check if the activation is above every threshold. For every active activation, for every custom function, we add the state stack to the on_tracker. For every off activation, for every custom function, we add the state stack to the off tracker. This runs reasonably quickly - around 2 minutes on an RTX 3090 for every 1000 input pgn strings.\n",
    "\n",
    "on_tracker is shape (thresholds, features, rows, cols, classes).\n",
    "\n",
    "So, if for 100% of the times that a feature is active, the board has a corresponding state (such as there is a pinned piece on the board, or a white knight on f3), then it's likely that the feature corresponds to that state.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "import torch\n",
    "from typing import Callable\n",
    "import circuits.chess_utils as chess_utils\n",
    "from circuits.utils import to_cpu\n",
    "\n",
    "# This should have been downloaded and unzipped by setup.sh\n",
    "filename = \"group1_results/autoencoders_group1_ef=4_lr=1e-03_l1=1e-01_layer=5_results.pkl\"\n",
    "# filename = \"layer0_results/autoencoders_layer0_ef=4_lr=1e-03_l1=1e-01_layer=0_results.pkl\"\n",
    "# filename = \"group1_results/autoencoders_group1_ef=16_lr=1e-03_l1=1e-01_layer=5_results.pkl\"\n",
    "filename = \"layer5_large_sweep_results/autoencoders_layer5_large_sweep_ef=4_lr=1e-03_l1=1e-01_layer=5_results.pkl\"\n",
    "filename = \"layer5_large_sweep_results/autoencoders_layer5_large_sweep_ef=16_lr=1e-03_l1=1e-04_layer=5_results.pkl\"\n",
    "\n",
    "with open(filename, 'rb') as f:\n",
    "    results = pickle.load(f)\n",
    "\n",
    "# This usually isn't needed as eval_sae_as_classifier now does this, but I have some results that are on the GPU\n",
    "results = to_cpu(results)\n",
    "\n",
    "print(results.keys())\n",
    "print(\"\\nAs we can see, every custom function shares the same keys.\\n\")\n",
    "print(results['board_to_pin_state'].keys())\n",
    "print(results['board_to_piece_state'].keys())\n",
    "print(\"However, the shapes of the values are different.\\n\")\n",
    "print(results['board_to_pin_state']['on'].shape, results['board_to_pin_state']['off'].shape)\n",
    "print(results['board_to_piece_state']['on'].shape, results['board_to_piece_state']['off'].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have raw counts of how many times every state was active while a feature was on / off. We can convert these to percentages. For example, this state was active 100% of the time this feature was active."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "print(results['board_to_piece_state']['on'].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "\n",
    "def normalize_tracker(\n",
    "    results: dict, tracker_type: str, custom_functions: list[Callable], device: torch.device\n",
    "):\n",
    "    \"\"\"Normalize the specified tracker (on or off) values by its count using element-wise multiplication.\"\"\"\n",
    "    for custom_function in custom_functions:\n",
    "        counts_TF = results[f\"{tracker_type}_count\"]\n",
    "\n",
    "        # Calculate inverse of counts safely\n",
    "        inverse_counts_TF = torch.zeros_like(counts_TF).to(device)\n",
    "        non_zero_mask = counts_TF > 0\n",
    "        inverse_counts_TF[non_zero_mask] = 1 / counts_TF[non_zero_mask]\n",
    "\n",
    "        tracker_TFRRC = results[custom_function.__name__][tracker_type]\n",
    "\n",
    "        # Normalize using element-wise multiplication\n",
    "        normalized_tracker_TFRRC = tracker_TFRRC * inverse_counts_TF[:, :, None, None, None]\n",
    "\n",
    "        # Store the normalized results\n",
    "        results[custom_function.__name__][f\"{tracker_type}_normalized\"] = normalized_tracker_TFRRC\n",
    "\n",
    "    return results\n",
    "\n",
    "results = normalize_tracker(results, \"on\", [chess_utils.board_to_pin_state, chess_utils.board_to_piece_state], torch.device(\"cpu\"))\n",
    "results = normalize_tracker(results, \"off\", [chess_utils.board_to_pin_state, chess_utils.board_to_piece_state], torch.device(\"cpu\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "print(results['on_count'][:, :5].to(torch.int))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These results came from num_inputs pgn strings of len 256. So, if we sum across possible square states, every element == 256 * num_inputs, which is also the total number of tokens / characters the SAE was evaluated on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "if 'hyperparameters' in results:\n",
    "    n_inputs = results['hyperparameters']['n_inputs']\n",
    "    print(f\"Every square should sum to {n_inputs * 256}.\")\n",
    "\n",
    "print(results['board_to_piece_state']['on'][0].shape)\n",
    "print(results['board_to_piece_state']['off'][0].shape)\n",
    "\n",
    "on_tracker = results['board_to_piece_state']['on'][0].sum(dim=-1)\n",
    "off_tracker = results['board_to_piece_state']['off'][0].sum(dim=-1)\n",
    "\n",
    "compare = on_tracker + off_tracker\n",
    "print(compare.shape)\n",
    "print(compare[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In contrast, every pin state should sum to the number of characters where there was a pin on the board. It often seems to be about 10% of number of characters (the above number)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "print(results['board_to_pin_state']['on'].squeeze()[0])\n",
    "print(results['board_to_pin_state']['off'].squeeze()[0])\n",
    "\n",
    "compare = results['board_to_pin_state']['on'].squeeze()[1] + results['board_to_pin_state']['off'].squeeze()[1]\n",
    "print(compare[:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This next cell looks for elements that were active > some percentage of the time (high_threshold) whenever a feature was active. For example, maybe there was a pin on the board 98% of the time feature 253 was active above threshold idx 5/10 (maybe the threshold was 2.0 for this index).\n",
    "\n",
    "If this is the case, we also check that this feature was active at least `significance threshold` times. Otherwise, any feature that was active only 1 time would have many percentage matches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "high_threshold = 0.95\n",
    "significance_threshold = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def get_above_below_counts(\n",
    "    tracker_TF: torch.Tensor,\n",
    "    counts_TF: torch.Tensor,\n",
    "    low_threshold: float,\n",
    "    high_threshold: float,\n",
    "    significance_threshold: int = 10,\n",
    "    verbose: bool = False,\n",
    ") -> tuple[torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"Must be a 2D tensor matching shape annotation.\"\"\"\n",
    "\n",
    "    # Find all elements that were active more than x% of the time (high_threshold)\n",
    "    above_freq_TF_mask = tracker_TF >= high_threshold\n",
    "\n",
    "    # For the counts tensor, zero out all elements that were not active enough\n",
    "    above_counts_TF = counts_TF * above_freq_TF_mask\n",
    "\n",
    "    # Find all features that were active more than significance_threshold times\n",
    "    above_counts_TF_mask = above_counts_TF >= significance_threshold\n",
    "\n",
    "    # Zero out all elements that were not active enough\n",
    "    above_counts_TF = above_counts_TF * above_counts_TF_mask\n",
    "\n",
    "    # Count the number of elements that were active more than high_threshold % and significance_threshold times\n",
    "    above_counts_T = above_counts_TF_mask.sum(dim=(1))\n",
    "\n",
    "    # All nonzero elements are set to 1\n",
    "    above_counts_TF = (above_counts_TF != 0).int()\n",
    "\n",
    "    if verbose:\n",
    "        print(\n",
    "            f\"\\nThis is the number of elements that were active more than {high_threshold} and {significance_threshold} times.\"\n",
    "        )\n",
    "        print(\n",
    "            f\"Note that this shape is num_thresholds, and every element corresponds to a threshold.\"\n",
    "        )\n",
    "        print(above_counts_T)\n",
    "\n",
    "        above_T = above_freq_TF_mask.sum(dim=(1))\n",
    "\n",
    "        print(\n",
    "            f\"\\nThis is the number of elements that were active more than {high_threshold} percent.\"\n",
    "        )\n",
    "        print(above_T)\n",
    "\n",
    "    # Count the number of elements that were active less than low_threshold %\n",
    "    # below_T = below_freq_TF_mask.sum(dim=(1))\n",
    "    # # Count the number of elements that were active more than high_threshold %\n",
    "    # above_T = above_freq_TF_mask.sum(dim=(1))\n",
    "\n",
    "    # values_above_threshold = [tracker_TF[i, above_freq_TF_mask[i]] for i in range(tracker_TF.size(0))]\n",
    "    # counts_above_threshold = [counts_TF[i, above_freq_TF_mask[i]] for i in range(tracker_TF.size(0))]\n",
    "\n",
    "    # for i, values in enumerate(values_above_threshold):\n",
    "    #     print(f\"Row {i} values above {high_threshold}: {values.tolist()}\")\n",
    "\n",
    "    # for i, counts in enumerate(counts_above_threshold):\n",
    "    #     print(f\"Row {i} counts above {high_threshold}: {counts.tolist()}\")\n",
    "\n",
    "    return above_counts_T, above_counts_TF\n",
    "\n",
    "\n",
    "above_counts_T, above_counts_TF = get_above_below_counts(\n",
    "    results[\"board_to_pin_state\"][\"on_normalized\"].squeeze().clone(),\n",
    "    results[\"board_to_pin_state\"][\"on\"].squeeze().clone(),\n",
    "    0.00,\n",
    "    high_threshold,\n",
    "    significance_threshold=significance_threshold,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we flatten the `board_to_piece_state` tracker to shape (thresholds, (rows, cols, classes)). We do some masking of certain states, and rerun the same analysis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "import einops\n",
    "import chess\n",
    "\n",
    "def mask_initial_board_state(on_tracker: torch.Tensor, device: torch.device) -> torch.Tensor:\n",
    "    initial_board = chess.Board()\n",
    "    initial_state = chess_utils.board_to_piece_state(initial_board)\n",
    "    initial_state = initial_state.view(1, 1, 8, 8)\n",
    "    initial_one_hot = chess_utils.state_stack_to_one_hot(\n",
    "        chess_utils.piece_config, device, initial_state\n",
    "    ).squeeze()\n",
    "    mask = initial_one_hot == 1\n",
    "    on_tracker[:, :, mask] = 0\n",
    "\n",
    "    return on_tracker\n",
    "\n",
    "def analyze_board_tracker(\n",
    "    results: dict,\n",
    "    function: str,\n",
    "    key: str,\n",
    "    device: torch.device,\n",
    "    high_threshold: float,\n",
    "    significance_threshold: int,\n",
    ") -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"Prepare the board tracker for analysis.\"\"\"\n",
    "    normalized_key = key + \"_normalized\"\n",
    "\n",
    "    num_thresholds = results[function][normalized_key].shape[0]\n",
    "\n",
    "    piece_state_on_normalized = (\n",
    "        results[function][normalized_key].clone().view(num_thresholds, -1)\n",
    "    )\n",
    "    piece_state_on = results[function][key].clone()\n",
    "    original_shape = piece_state_on.shape\n",
    "\n",
    "    piece_state_on = mask_initial_board_state(piece_state_on, device)\n",
    "\n",
    "    # Optionally, we also mask off the blank class\n",
    "    piece_state_on[:, :, :, :, 6] = 0\n",
    "\n",
    "    # Flatten the tensor to a 2D shape for compatibility with get_above_below_counts()\n",
    "    piece_state_on = piece_state_on.view(num_thresholds, -1)\n",
    "\n",
    "    above_counts_T, above_counts_TF = get_above_below_counts(\n",
    "        piece_state_on_normalized,\n",
    "        piece_state_on,\n",
    "        0.00,\n",
    "        high_threshold,\n",
    "        significance_threshold=significance_threshold,\n",
    "    )\n",
    "\n",
    "    best_idx = torch.argmax(above_counts_T)\n",
    "\n",
    "    above_counts_TFRRC = above_counts_TF.view(original_shape)\n",
    "\n",
    "    best_counts_FRRC = above_counts_TFRRC[best_idx, ...]\n",
    "\n",
    "    summary_board_RR = einops.reduce(best_counts_FRRC, \"F R1 R2 C -> R1 R2\", \"sum\").to(torch.int)\n",
    "\n",
    "    class_dict_C = einops.reduce(best_counts_FRRC, \"F R1 R2 C -> C\", \"sum\").to(torch.int)\n",
    "\n",
    "    return above_counts_T, summary_board_RR, class_dict_C\n",
    "\n",
    "piece_state_above_counts_T, summary_board, class_dict = (\n",
    "    analyze_board_tracker(\n",
    "        results,\n",
    "        \"board_to_piece_state\",\n",
    "        \"on\",\n",
    "        torch.device(\"cpu\"),\n",
    "        high_threshold,\n",
    "        significance_threshold,\n",
    "    )\n",
    ")\n",
    "\n",
    "print(piece_state_above_counts_T)\n",
    "\n",
    "print(f\"\\nThis is the number of times each square was active more than 98% of the time above some significance_threshold.\")\n",
    "print(summary_board)\n",
    "\n",
    "print(f\"\\nThis is the number of times each piece was active more than 98% of the time. 0 == black king, 1 == black queen, 6 == blank, 7 == white pawn, etc.\")\n",
    "print(class_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "print(torch.argmax(piece_state_above_counts_T))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is an experiment looking at mine / yours vs white / black. It's half baked right now."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def transform_board_from_piece_color_to_piece(board: torch.Tensor) -> torch.Tensor:\n",
    "    new_board = torch.zeros(board.shape[:-1] + (7,), dtype=board.dtype, device=board.device)\n",
    "\n",
    "    for i in range(7):\n",
    "        if i == 6:\n",
    "            new_board[..., i] = board[..., 6]\n",
    "        else:\n",
    "            new_board[..., i] = board[..., i] + board[..., 12 - i]\n",
    "    return new_board\n",
    "\n",
    "results['board_to_piece_state']['on_piece'] = transform_board_from_piece_color_to_piece(results['board_to_piece_state']['on'])\n",
    "results['on_piece_count'] = results['on_count']\n",
    "results = normalize_tracker(results, \"on_piece\", [chess_utils.board_to_piece_state], torch.device(\"cpu\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "import einops\n",
    "import chess\n",
    "\n",
    "key = \"on_piece\"\n",
    "normalized_key = \"on_piece_normalized\"\n",
    "\n",
    "num_thresholds = results[\"board_to_piece_state\"][normalized_key].shape[0]\n",
    "\n",
    "piece_state_on_normalized = results[\"board_to_piece_state\"][normalized_key].clone().view(num_thresholds, -1)\n",
    "piece_state_on = results[\"board_to_piece_state\"][key].clone()\n",
    "original_shape = piece_state_on.shape\n",
    "\n",
    "\n",
    "initial_board = chess.Board()\n",
    "initial_state = chess_utils.board_to_piece_state(initial_board)\n",
    "initial_state = initial_state.view(1, 1, 8, 8)\n",
    "initial_one_hot = chess_utils.state_stack_to_one_hot(chess_utils.piece_config, \"cpu\", initial_state).squeeze()\n",
    "\n",
    "initial_one_hot = transform_board_from_piece_color_to_piece(initial_one_hot)\n",
    "\n",
    "mask = (initial_one_hot == 1)\n",
    "piece_state_on[:, :, mask] = 0\n",
    "\n",
    "piece_state_on[:, :, :, :, 6] = 0\n",
    "piece_state_on = piece_state_on.view(num_thresholds, -1)\n",
    "print(piece_state_on_normalized.shape)\n",
    "\n",
    "above_counts_TF = get_above_below_counts(piece_state_on_normalized, piece_state_on, 0.00, 0.98, significance_threshold=50)\n",
    "above_counts_TF = above_counts_TF.view(original_shape)\n",
    "\n",
    "summary_board_RR = einops.reduce(above_counts_TF, \"T F R1 R2 C -> R1 R2\", \"sum\").to(torch.int)\n",
    "print(summary_board_RR)\n",
    "\n",
    "class_dict_C = einops.reduce(above_counts_TF, \"T F R1 R2 C -> C\", \"sum\").to(torch.int)\n",
    "print(class_dict_C)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "circuits",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
