{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A Mechanistic Interpretability Study of a Chess Playing Language Model\n",
    "\n",
    "There have been many reports of language models possessing linear representations of the board state in games. In particular, Karvonen<sup>1</sup> trained a GPT-2 style transformer (ChessGPT) on chess PGNs and was able to show that the model contained a linearly decodable representation of the board.\n",
    "\n",
    "However, how the transformer constructs this representation remains unclear. Recent interpretability efforts, such as those by Davis et. al.<sup>2</sup> classify some attention patterns and focus on finding where the model commits to its next move, but do not analyze the model's board representation. This work investigates the internal workings of ChessGPT and proposes methods that explain how the model computes the board state.\n",
    "\n",
    "---\n",
    "<sup>1</sup> Karvonen, A. Emergent world models and latent variable estimation in chess-playing language models. In Proceedings of the Conference on Language Modeling (COLM), 2024. URL https://openreview.net/forum?id=PPTrmvEnpW.\n",
    "Accepted at COLM 2024.\n",
    "\n",
    "<sup>2</sup> Davis, A. L. and Sukthankar, G. Decoding chess mastery: A mechanistic analysis of a chess language transformer model. In Artificial General Intelligence: 17th International Conference, AGI 2024, Seattle, WA, USA, August 13–16, 2024, Proceedings, pp. 63–72, Berlin, Heidelberg, 2024. Springer-Verlag. ISBN 978-3-031-65571-5. doi: 10.1007/978-3-031-65572-2 7. URL https://doi.org/10.1007/978-3-031-65572-2_7."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "  import google.colab\n",
    "  IN_COLAB = True\n",
    "  print(\"Running as a Colab notebook\")\n",
    "  %pip install chess\n",
    "  %pip install circuitsvis\n",
    "  %pip install transformer_lens\n",
    "except:\n",
    "  IN_COLAB = False\n",
    "  print(\"Running as a Jupyter notebook - intended for development only!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### import relevant packages\n",
    "\n",
    "import bisect\n",
    "import collections\n",
    "import os\n",
    "import pickle\n",
    "from dataclasses import dataclass, field\n",
    "from pathlib import Path\n",
    "\n",
    "import chess\n",
    "import circuitsvis as cv \n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from huggingface_hub import snapshot_download\n",
    "from torch.utils.data import DataLoader, TensorDataset, random_split\n",
    "from tqdm import tqdm\n",
    "from tabulate import tabulate\n",
    "\n",
    "from transformer_lens import HookedTransformer, HookedTransformerConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# chess utils\n",
    "PIECE_TO_INT = {\n",
    "    chess.PAWN: 1,\n",
    "    chess.KNIGHT: 2,\n",
    "    chess.BISHOP: 3,\n",
    "    chess.ROOK: 4,\n",
    "    chess.QUEEN: 5,\n",
    "    chess.KING: 6,\n",
    "}\n",
    "\n",
    "INT_TO_PIECE = {value: key for key, value in PIECE_TO_INT.items()}\n",
    "\n",
    "# model params\n",
    "D_MODEL = 512\n",
    "N_HEADS = 8\n",
    "\n",
    "MODEL_DIR = \"models/\"\n",
    "DATA_DIR = \"data/\"\n",
    "\n",
    "DEVICE = (\n",
    "    \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
    ")\n",
    "\n",
    "for d in (Path(MODEL_DIR), Path(DATA_DIR)):\n",
    "    d.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# download data and models from huggingface\n",
    "\n",
    "snapshot_download(\n",
    "    repo_id=\"spherical-chisel/ChessGPT-Interp\",\n",
    "    repo_type=\"dataset\",\n",
    "    local_dir=DATA_DIR,\n",
    ")\n",
    "\n",
    "snapshot_download(\n",
    "    repo_id=\"spherical-chisel/ChessGPT-Interp\",\n",
    "    repo_type=\"model\",\n",
    "    local_dir=MODEL_DIR,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# utils for loading the model\n",
    "with open(f\"{MODEL_DIR}meta.pkl\", \"rb\") as f:\n",
    "    meta = pickle.load(f)\n",
    "\n",
    "stoi, itos = meta[\"stoi\"], meta[\"itos\"]\n",
    "encode = lambda s: [stoi[c] for c in s]\n",
    "decode = lambda l: \"\".join([itos[i] for i in l])\n",
    "\n",
    "meta_round_trip_input = \"1.e4 e6 2.Nf3\"\n",
    "print(encode(meta_round_trip_input))\n",
    "print(\"Performing round trip test on meta\")\n",
    "assert decode(encode(meta_round_trip_input)) == meta_round_trip_input\n",
    "\n",
    "def get_transformer_lens_model(\n",
    "    model_name: str, n_layers: int, device: torch.device\n",
    ") -> HookedTransformer:\n",
    "\n",
    "    cfg = HookedTransformerConfig(\n",
    "        n_layers=n_layers,\n",
    "        d_model=D_MODEL,\n",
    "        d_head=int(D_MODEL / N_HEADS),\n",
    "        n_heads=N_HEADS,\n",
    "        d_mlp=D_MODEL * 4,\n",
    "        d_vocab=32,\n",
    "        n_ctx=1023,\n",
    "        act_fn=\"gelu\",\n",
    "        normalization_type=\"LNPre\",\n",
    "    )\n",
    "    model = HookedTransformer(cfg)\n",
    "    model.load_state_dict(torch.load(f\"{MODEL_DIR}{model_name}.pth\"))\n",
    "    model.to(device)\n",
    "    return model\n",
    "\n",
    "# convert each game transcript into a sequence of integer token IDs\n",
    "def get_board_seqs_int(df: pd.DataFrame):\n",
    "    encoded_df = df[\"transcript\"].apply(encode)\n",
    "    board_seqs_int_Bl = torch.tensor(encoded_df.apply(list).tolist())\n",
    "    return board_seqs_int_Bl \n",
    "\n",
    "# extract the game string from the dataframe\n",
    "def get_board_seqs_string(df: pd.DataFrame):\n",
    "\n",
    "    key = \"transcript\"\n",
    "    row_length = len(df[key].iloc[0])\n",
    "\n",
    "    assert all(\n",
    "        df[key].apply(lambda x: len(x) == row_length)\n",
    "    ), \"Not all transcripts are of length {}\".format(row_length)\n",
    "\n",
    "    board_seqs_string_Bl = df[key]\n",
    "\n",
    "    return board_seqs_string_Bl\n",
    "\n",
    "# load the model\n",
    "dataset_prefix = \"lichess_\"\n",
    "n_layers = 8\n",
    "model_name = f\"tf_lens_{dataset_prefix}{n_layers}layers_ckpt_no_optimizer\"\n",
    "\n",
    "model = get_transformer_lens_model(model_name, n_layers, DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load data\n",
    "input_file = \"data/lichess_train.csv\"\n",
    "\n",
    "df = pd.read_csv(input_file)\n",
    "df = df[:10000] # we use the first 10,000 games for analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Overview\n",
    "\n",
    "We examine a 8-layered GPT model pretrained by Karvonen. The model uses a 512-dimension hidden space and 8 heads per layer, with a total parameter count of 25 million. The input the model is a chess PGN (1.e4 e5 2.Nf3 ...) with a maximum length of 1023 characters. Each character represents an input token, and the model’s vocabulary is restricted to the 32 characters required to construct chess PGN strings. Additionally, each game string starts with a ';' token.\n",
    "\n",
    "The model was trained to autoregressively generate the next character of the PGN. As reported in Karvonen’s work, the model has a legal move rate of 99.6% and an ELO rating of roughly 1300 with a win rate of 46% against Stockfish 16 level 0.\n",
    "\n",
    "Using the post-MLP residual stream as input, we follow Karvonen and train a linear probe that classifies every square on the chessboard into one of 13 states: blank, or one of the six piece types (pawn, knight, bishop, rook, queen, king), each in white or black.\n",
    "\n",
    "We make the following key observations:\n",
    "- Accuracy jumps from 88.7 % to 99.0 % between layers 4 and 5.  \n",
    "- Fitting a probe at layer 5 on the pre-MLP residual stream also reaches 99.0 %, implying the board representation originates in attention rather than the MLP.  \n",
    "- There is no drop in accuracy when the probe is tested on randomly generated games\n",
    "\n",
    "## Data Overview\n",
    "\n",
    "The dataset we use is also obtained from Karvonen<sup>1</sup>. The data consists of games from the Lichess open database. During probing and analysis, we create a set of 10,000 games not found in the model’s training data. All games in this set were truncated to a length of 365 tokens, as this was the median length of a game.\n",
    "\n",
    "## Attention Overview\n",
    "\n",
    "The PGN strings that are input to the model represent game moves in standard algebraic notation. In general, each move is indicated by a letter denoting the piece type followed by the coordinates of the destination square. The coordinates are then further split into a letter a-h representing the file and a number 1-8 denoting the rank. For pawn moves, the letter indicating the piece is omitted (e.g., c5). Additionally, there is special notation for captures, checks, castling, and piece disambiguation.\n",
    "\n",
    "Furthermore, the PGN move list is serialized as a space-separated sequence with the following pattern:\n",
    "\n",
    "```<Move-number>. <White-move> <Black-move> <Move-number+1>. <White-move> <Black-move> ...```\n",
    "\n",
    "When probing for the board state, we use the residual stream vectors at each dot token. The equivalent token from the Black perspective is the whitespace after the white move.\n",
    "\n",
    "Previous literature shows that the model represents the pieces with a (Mine, Yours) scheme. Hence, all of our analysis is done solely from the White perspective, as the opposite side is likely symmetric. Some interesting observations (mostly by visual inspection) are highlighted:\n",
    "\n",
    "- In layer 2, information about the parity, piece type, and file are moved to the rank token. This is sensible because the rank number is the last token of a move and is the prime candidate for collecting the information in a move\n",
    "- In layer 3, we see a proto-\"previous move head\" (Head 3.2) for knights and bishops\n",
    "- In layer 4, there is a more robust previous move head (Head 4.0), a head that tracks the King position (Head 4.1), companion piece heads (Head 4.3), and a gather all head (Head 4.6).\n",
    "- In layer 5, Head 5.0 tracks opponent pawns, Head 5.1 tracks your pawns, Head 5.2 tracks your queen/rooks/knights, Head 5.3 tracks your kingside bishop, Head 5.4 sometimes attends to the piece being moved, Head 5.5 tracks opponent bishops/knights, Head 5.6 tracks your queenside bishop/knight, and Head 5.7 tracks opponent queen/rooks\n",
    "\n",
    "Moreover, once piece information has been aggregated in the rank token, heads in later layers primarily attend to the move’s rank token and the start of game delimiter with some exceptions (e.g., captures). Thus, when we refer to the token for a certain move, we specifically refer to the rank token of the move. The aforementioned descriptions of some of the heads will become more clear later on. Below is code for visualizing the attention that you can play around with.\n",
    "\n",
    "---\n",
    "<sup>1</sup> More details and code can be found: https://github.com/adamkarvonen/chess_llm_interpretability\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "game = 4 # interested game\n",
    "\n",
    "game_int_seqs = get_board_seqs_int(df[:100])\n",
    "\n",
    "raw_tokens_for_game = game_int_seqs[game] \n",
    "if isinstance(raw_tokens_for_game, torch.Tensor):\n",
    "    raw_tokens_for_game = raw_tokens_for_game.tolist()\n",
    "\n",
    "inp_tokens = torch.tensor([raw_tokens_for_game]).to(DEVICE) # [1, sequence_length]\n",
    "\n",
    "resid_post_dict_BLD_viz = {}\n",
    "with torch.inference_mode():\n",
    "    _, cache = model.run_with_cache(inp_tokens, return_type=None) # Get all activations \n",
    "    for layer in range(n_layers):\n",
    "        resid_post_dict_BLD_viz[layer] = cache[f\"blocks.{layer}.attn.hook_pattern\"]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualization parameters\n",
    "nm = 200       # number of tokens to visualize\n",
    "layer_to_viz = 5 # layer index to visualize attention from\n",
    "\n",
    "str_tokens = [itos[token_id] for token_id in raw_tokens_for_game]\n",
    "\n",
    "# ensure nm is not greater than the actual sequence length\n",
    "max_seq_len_viz = min(nm, len(str_tokens))\n",
    "token_sub = str_tokens[:max_seq_len_viz]\n",
    "\n",
    "attention_for_layer = resid_post_dict_BLD_viz[layer_to_viz]\n",
    "attention_to_visualize = attention_for_layer[0, :, :max_seq_len_viz, :max_seq_len_viz]\n",
    "\n",
    "display(cv.attention.attention_patterns(\n",
    "    tokens=token_sub,\n",
    "    attention=attention_to_visualize\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Previous Move Heads\n",
    "\n",
    "In algebraic notation a move only includes the destination square and omits information about the square the piece moved from (except in rare cases where disambiguation is necessary).\n",
    "\n",
    "Howeover, heads 3.2 and 4.0 capture this \"from square\" information by attending to the \"to square\" of the previous move by that piece. Moreover, it seems that while Head 4.0 is a generic previous move head, Head 3.2 is a weaker version that specializes in finding previous moves for knights, bishops and kings.\n",
    "\n",
    "We observe that for these heads, the rank token position of a move attends to the rank token position of its previous move. For example, if a knight moves from its initial square to c3 (denoted Nc3) and then later moves from c3 to d5 (denoted Nd5), the \"5\" will direct most of its attention to the \"3\".\n",
    "\n",
    "To quantify this, for moves where the piece has already moved at least once, we identify the token with the maximal attention coefficient. Then, we check whether this matches the true previous move token."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_prev_move_indices(moves_string: str, interested_pieces: list[int]) -> tuple[list[int]]:\n",
    "    \"\"\"\n",
    "    Scan a PGN move string and return (current move token index, previous move token index)\n",
    "    for any move involving a piece in 'interested_pieces' that has a corresponding\n",
    "    previous move\n",
    "\n",
    "    - Token indices point at the rank character of the move.\n",
    "    - Example: In \"Nc3 ... Nd5 ...\", the \"5\" in Nd5 pairs with the \"3\" in Nc3.\n",
    "    \"\"\"\n",
    "\n",
    "    # find the file/rank character of the pieces\n",
    "    indices = [idx for idx, char in enumerate(moves_string) if char == \" \"]\n",
    "    indices = indices[::2]\n",
    "    indices = [idx - 1 for idx in indices]\n",
    "\n",
    "    for i, idx in enumerate(indices):\n",
    "        if moves_string[idx] == \"+\":\n",
    "            indices[i] -= 1\n",
    "            \n",
    "    # the final move may also be valid\n",
    "    indices.append(len(moves_string) - 1)\n",
    "\n",
    "    # given a move, we know the to square and from square. There exists a corresponding\n",
    "    # previous move if a previous move with the same piece type moved to the from square\n",
    "    move_squares = [] \n",
    "    ret = []\n",
    "\n",
    "    moves = moves_string.strip().split(\".\")[1:]\n",
    "\n",
    "    # omit the first king move: it is unclear how castling is attended to\n",
    "    king_moved = False\n",
    "\n",
    "    board = chess.Board()\n",
    "    for i, move in enumerate(moves):\n",
    "        # process the white move\n",
    "        t = move.strip().split()\n",
    "        try:\n",
    "            white_move = t[0]\n",
    "            mv = board.parse_san(white_move)\n",
    "            piece = board.piece_at(mv.from_square)\n",
    "            board.push_san(white_move)\n",
    "        except:\n",
    "            break\n",
    "\n",
    "        to_square = mv.to_square \n",
    "\n",
    "        if piece.piece_type != chess.KING or king_moved:\n",
    "            move_squares.append((indices[i], to_square))\n",
    "\n",
    "        if piece.piece_type == chess.KING:\n",
    "            king_moved = True\n",
    "\n",
    "        # check if there is a corresponding previous move\n",
    "        if piece.piece_type in interested_pieces:\n",
    "            from_square = mv.from_square\n",
    "\n",
    "            for idx, square in reversed(move_squares):\n",
    "                if square == from_square:\n",
    "                    ret.append((indices[i], idx))\n",
    "                    break\n",
    "\n",
    "        # process the black move\n",
    "        try:\n",
    "            black_move = t[1]\n",
    "            board.push_san(black_move)\n",
    "        except:\n",
    "            break\n",
    "\n",
    "    return zip(*ret)\n",
    "\n",
    "def compute_prev_piece_acc(\n",
    "        df,\n",
    "        num_games=100,\n",
    "        interested_pieces=[chess.KNIGHT]):\n",
    "    \"\"\"\n",
    "    Measure how often the \"previous-move\" attention heads in layer 3 and 4 point correctly.\n",
    "\n",
    "    For each move with a corresponding previous move, see if the highest-attended token index matches the  \n",
    "    true index of that previous move.\n",
    "\n",
    "    Returns:\n",
    "        (accuracy_layer3, accuracy_layer4)\n",
    "    \"\"\"\n",
    "\n",
    "    def l3_attn_hook(attn, hook):\n",
    "        # attn: [batch, heads, seq, seq]\n",
    "        layer_attn[3] = attn.detach()\n",
    "        return attn\n",
    "\n",
    "    def l4_attn_hook(attn, hook):\n",
    "        # attn: [batch, heads, seq, seq]\n",
    "        layer_attn[4] = attn.detach()\n",
    "        return attn\n",
    "\n",
    "    # create attention hooks\n",
    "    l3_hook_name = \"blocks.3.attn.hook_pattern\"\n",
    "    l4_hook_name = \"blocks.4.attn.hook_pattern\"\n",
    "\n",
    "    hooks = [(l3_hook_name, l3_attn_hook), (l4_hook_name, l4_attn_hook)]\n",
    "\n",
    "    l3_head = 2\n",
    "    l4_head = 0\n",
    "\n",
    "    l3_correct = 0\n",
    "    l4_correct = 0\n",
    "    total = 0\n",
    "\n",
    "    prev_move_df = df[:num_games]\n",
    "\n",
    "    board_seqs_int_Bl = get_board_seqs_int(prev_move_df)\n",
    "    board_seqs_str_Bl = get_board_seqs_string(prev_move_df)\n",
    "\n",
    "    for seqs_int, seq_str in tqdm(zip(board_seqs_int_Bl, board_seqs_str_Bl),\n",
    "                                    total=len(board_seqs_int_Bl)):\n",
    "        try:\n",
    "            indices, prev_indices = filter_prev_move_indices(moves_string=seq_str, \n",
    "                                                            interested_pieces = interested_pieces)\n",
    "        except ValueError: \n",
    "            # sometimes there are no values to unpack\n",
    "            continue\n",
    "\n",
    "        layer_attn = {}\n",
    "\n",
    "        # hook into layers to find the attention matrices\n",
    "        with model.hooks(fwd_hooks=hooks):\n",
    "            model(seqs_int.unsqueeze(0))\n",
    "\n",
    "        indices_tensor = torch.tensor(indices)\n",
    "\n",
    "        l3_attn = layer_attn[3][0, l3_head, indices_tensor, :]\n",
    "        l4_attn = layer_attn[4][0, l4_head, indices_tensor, :]\n",
    "\n",
    "        # find the highest-attended tokens\n",
    "        l3_from = torch.argmax(l3_attn, dim=1)\n",
    "        l4_from = torch.argmax(l4_attn, dim=1)\n",
    "\n",
    "        # get the true previous move indices\n",
    "        prev_idx = torch.as_tensor(prev_indices, device=l3_attn.device)\n",
    "\n",
    "        # book‑keeping\n",
    "        batch_size = prev_idx.numel()\n",
    "        total      += batch_size\n",
    "\n",
    "        # correctness counts\n",
    "        l3_correct += (l3_from == prev_idx).sum().item()\n",
    "        l4_correct += (l4_from == prev_idx).sum().item()\n",
    "\n",
    "        del layer_attn\n",
    "        if DEVICE == \"mps\":\n",
    "            torch.mps.empty_cache()\n",
    "        elif DEVICE == \"cuda\":\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "    return (l3_correct / total, l4_correct / total)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How well the heads perform vary with the piece type. It looks like the heads have more difficulty with longe range pieces like rooks (\\~83% accurate) and queens (\\~91% accurate). The model can \"cheat\" for bishops by treating them as distinct; the c1 bishop will always be on a dark square, and the f1 bishop will always be on a light square. \n",
    "\n",
    "Furthermore, the model struggles with pawns (\\~56% accurate). Due to the limited movement range of a pawn, the destination square often already contains a lot of information about the origin square. Hence, the model may not need to explicitly find the previous move to infer the from square."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "interested_piece = chess.KNIGHT # change to the interested piece type (e.g., chess.BISHOP)\n",
    "\n",
    "l3_acc, l4_acc = compute_prev_piece_acc(\n",
    "    df=df,\n",
    "    num_games=1000, # change to the desired number of games.\n",
    "    interested_pieces = [interested_piece] # modify for the desired set of piece types\n",
    ")\n",
    "\n",
    "print(f\"Layer 3 accuracy: {(100 * l3_acc):.2f}\")\n",
    "print(f\"Layer 4 accuracy: {(100 * l4_acc):.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Layer 5 Pawn Average Head\n",
    "\n",
    "From the attention, we see that Head 5.1 attends mostly to pawn moves. Hence, we look for the linear representation of the pawn structure in this head. We propose the following method for computing the pawn structure:\n",
    "\n",
    "### Head Value Vector\n",
    "For every pawn move $t$, let from($t$) and to($t$) denote the from-square and to-square of the move. Then, let $v_s \\in \\mathbb{R}^d$ be a vector representation of the square $s$. We claim that the pre-projection value vector $z_t  = W^Vx_t$ computed by Head 5.1 satisfies:\n",
    "\n",
    "$$\n",
    "z_t \\approx -v_{\\text{from}(t)} + v_{\\text{to}(t)}\n",
    "$$\n",
    "\n",
    "### Averaging\n",
    "\n",
    "Then, at move $T$<sup>1</sup>, taking the average of $z_t$ over all pawn moves $t$ that occur before $T$ yields:\n",
    "\n",
    "$$\n",
    "Z_T = \\frac{1}{n} \\sum_{t \\leq T}z_t = \\frac{1}{n} \\left(-\\sum_{t \\leq T} v_{\\text{from}(t)} + \\sum_{t \\leq T} v_{\\text{to}(t)}\\right)\n",
    "$$\n",
    "\n",
    "Now, let us track a particular pawn. Every square, besides the initial square, that this pawn moves from must have previously been moved to. For example, if a pawn moves *from* e4 to e5, at some point, the pawn must have been moved *to* e4. Thus, the two sums nearly cancel out and we are left with:\n",
    "\n",
    "$$\n",
    "\\frac1n \\Bigl(\n",
    "      v_{\\text{final}(p)}\n",
    "      - v_{\\text{initial}(p)}\n",
    "   \\Bigr)\n",
    "$$\n",
    "\n",
    "If we let $P$ be the set of all moved pawns, we can sum over all pawns to obtain:\n",
    "$$\n",
    "Z_T = \\frac1n \\Bigl(\n",
    "      \\sum_{p \\in P} v_{\\text{final}(p)} - \\sum_{p \\in P} v_{\\text{initial}(p)}\n",
    "   \\Bigr)\n",
    "$$\n",
    "\n",
    "Unfortunately, this doesn't give us a completely clean linear representation of the pawn structure, which would be $\\sum v_{\\text{final}(p)}$; how the model deals with the $\\frac 1 n$ scaling and initial pawn positions is outside the scope of this work. However, we offer some discussion on this at the end of the section.\n",
    "\n",
    "### Captures\n",
    "When the opponent captures a pawn, the head attends to the dot token following the capturing move. Then, this capture token $t$ can erase the pawn on the square by having the value $z_t = -v_{\\text{capture square}}$. In our analysis, we treat these captures as \"pawn moves\" and include it in the $Z_T$ sum.\n",
    "\n",
    "---\n",
    "\n",
    "<sup>1</sup> When we say \"move $T$\", we generally refer to the dot token following that move. For example, given the game string:\n",
    "\n",
    "```text\n",
    "1.e4 c5 2.Nf3 Nc6 3.Nc3 e5 4.Bc4 d6 5.d3 Be7 6.Nd5\n",
    "```\n",
    "the $T$ that contains all of the pawn moves up to and including d3 would the index of the dot following the 6.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load precomputed data (code can be found in generate_data.py)\n",
    "\n",
    "cache_file = os.path.join(DATA_DIR, \"precomputed_game_cache.pt\")\n",
    "\n",
    "with open(cache_file, \"rb\") as f:\n",
    "    data = torch.load(f, weights_only=False)\n",
    "\n",
    "# per move data\n",
    "indices = data[\"index\"] # the indices of the rank token for each move\n",
    "head_v = data[\"head_v\"] # the value vectors associated with the rank token position\n",
    "to_squares = data[\"to\"] # for each move, the to-square\n",
    "frm_squares = data[\"from\"] # for each move, the from-square\n",
    "piece_types = data[\"piece_type\"] # for each move, the piece type (the captured piece if it was a capture by the opponent)\n",
    "\n",
    "# per dot data\n",
    "dots_indices = data[\"dots_game_index\"] # the indices of the dots\n",
    "dots_attn = data[\"dots_attn\"] # the attention matrices for the dots\n",
    "board_stacks = data[\"board_state\"] # the board state at the dot\n",
    "\n",
    "# perform sanity checks on the data\n",
    "n = len(to_squares)\n",
    "for i in range(n):\n",
    "    piece_len = to_squares[i].shape[0]\n",
    "\n",
    "    assert frm_squares[i].shape[0] == piece_len\n",
    "    assert len(piece_types[i]) == piece_len    \n",
    "    assert head_v[i].shape[1] == piece_len\n",
    "    assert indices[i].shape[0] == piece_len\n",
    "    assert board_stacks[i].shape[0] == piece_len\n",
    "\n",
    "    dots_len = dots_indices[i].shape[0]\n",
    "    \n",
    "    assert dots_attn[i].shape[2] == dots_len\n",
    "    assert dots_attn[i].shape[3] == piece_len"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cosine Similarity Analysis\n",
    "\n",
    "To see if the model is truly taking an average over the pawn moves $t$, we compare $Z_T$ to the true head contribution. We say that the true contribution is given as:\n",
    "\n",
    "$$\n",
    "H_T = \\sum_{t \\leq T} \\alpha_t z_t\n",
    "$$\n",
    "\n",
    "where $\\alpha_t$ is the attention score for token $t$ at move $T$. We emphasize that this **only includes the contributions from the rank tokens**, as they tend to have the largest attention weights. We exclude the delimiter (;) and all other non-rank tokens from the computation.\n",
    "\n",
    "In this section, we compute the cosine similarity between $H_T$ and $Z_T$.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_piece_sims(interested_head, alpha, piece_type, num_games=100, parity_req = None):\n",
    "    \"\"\"\n",
    "    Check cosine similarity between actual head and reconstructed head for piece piece_type.\n",
    "    The parity requirement restricts the partiy of the square (e.g., 0 for dark squares and 1 for light squares)\n",
    "    and is used to distinguish between the bishops.\n",
    "\n",
    "    Computes and returns the similarities between the head and:\n",
    "     - EMA with decay factor alpha\n",
    "     - most recent move\n",
    "     - average\n",
    "    \"\"\"\n",
    "    sims = []\n",
    "    last_sims = []\n",
    "    avg_sims = []\n",
    "\n",
    "    for i in tqdm(range(num_games)):\n",
    "        piece_len = to_squares[i].shape[0]\n",
    "        piece_indices = []\n",
    "\n",
    "        cur_ema = 0\n",
    "        cur_sum = 0\n",
    "\n",
    "        emas = []\n",
    "        sums = []\n",
    "\n",
    "        for j in range(piece_len):\n",
    "            piece = piece_types[i][j][1]\n",
    "            square = to_squares[i][j]\n",
    "            rank = square // 8\n",
    "            file = square % 8\n",
    "            parity = (rank + file) % 2\n",
    "            meets_parity = True\n",
    "            if parity_req is not None:\n",
    "                meets_parity = (parity == parity_req)\n",
    "\n",
    "            if piece == piece_type and meets_parity:\n",
    "                piece_indices.append(j)\n",
    "\n",
    "                piece_v = head_v[i][0, j, interested_head, :]\n",
    "\n",
    "                # update sum and ema\n",
    "                cur_ema = alpha * cur_ema + (1 - alpha) * piece_v\n",
    "                cur_sum = cur_sum + piece_v\n",
    "\n",
    "            emas.append(cur_ema) \n",
    "            sums.append(cur_sum)\n",
    "\n",
    "        dots_len = dots_indices[i].shape[0]\n",
    "        for j in range(dots_len):\n",
    "            dot_idx = dots_indices[i][j]\n",
    "            cur_piece_indices = [idx for idx in piece_indices if indices[i][idx] <= dot_idx]\n",
    "\n",
    "            # skip if no piece moves yet\n",
    "            if len(cur_piece_indices) == 0:\n",
    "                continue\n",
    "\n",
    "            last_idx = cur_piece_indices[-1]\n",
    "\n",
    "            # find the most recent avg and ema values at this move\n",
    "            last_ema = emas[last_idx]\n",
    "            last = head_v[i][0, last_idx, interested_head, :]\n",
    "            last_avg = sums[last_idx] / len(cur_piece_indices)\n",
    "\n",
    "            cur_piece_indices = torch.tensor(cur_piece_indices)\n",
    "\n",
    "            # compute the true head value from the piece\n",
    "            attn = dots_attn[i][0, interested_head, j, :][cur_piece_indices]\n",
    "            v = head_v[i][0, cur_piece_indices, interested_head, :]\n",
    "            act = torch.einsum(\"l, l d -> d\", attn, v)\n",
    "\n",
    "            # compute cosine similarities\n",
    "            sim = torch.nn.functional.cosine_similarity(last_ema, act, dim=0)\n",
    "            sims.append(sim.item())\n",
    "\n",
    "            sim = torch.nn.functional.cosine_similarity(last, act, dim=0)\n",
    "            last_sims.append(sim.item())\n",
    "\n",
    "            sim = torch.nn.functional.cosine_similarity(last_avg, act, dim=0)\n",
    "            avg_sims.append(sim.item())\n",
    "\n",
    "    return sims, last_sims, avg_sims"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The averaging scheme obtains a cosine similarity of roughly 0.9 with the true head contribution. The baseline of just using the last pawn move's value vector obtains a cosine similarity of roughly 0.44"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, last_sims, avg_sims = get_piece_sims(\n",
    "    interested_head=1,\n",
    "    alpha=0,\n",
    "    piece_type=chess.PAWN,\n",
    "    num_games=1000,\n",
    "    parity_req=None\n",
    ")\n",
    "\n",
    "print(f\"Cosine Similarity between most recent move and head output: {torch.mean(torch.tensor(last_sims)).item():.4f}\")\n",
    "print(f\"Cosine Similarity between average and head output: {torch.mean(torch.tensor(avg_sims)).item():.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Value Vector Analysis\n",
    "\n",
    "Earlier, we claimed that the value vector $z$ is the sum $-v_{\\text{from}} + v_{\\text{to}}$. Then, we should be able to extract the to and from squares from $z$ using a linear probe. Let $y_t \\in \\{0,1\\}^{64}$ be the one-hot encoding of the to-square of the move. We train a linear probe  \n",
    "$$\n",
    "\\hat{y}_t = \\mathrm{softmax}\\bigl(W\\,z_t\\bigr)\n",
    "$$\n",
    "\n",
    "by minimizing the average cross-entropy loss. We also repeat this process for the from squares.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_to_from_stack(head, piece_reqs, num_games=100, include_dots=False):\n",
    "    \"\"\"\n",
    "    Piece reqs is a tuple that contains the piece type (e.g., chess.KNIGHT) and the square parity:\n",
    "    None, 0, or 1 for no square restrictions, darks squares, and light squares respectively\n",
    "\n",
    "    For each move of the specified piece type, this function returns the head output, the to square, \n",
    "    and the from square\n",
    "    \"\"\"\n",
    "    piece_type = piece_reqs[0]\n",
    "    parity_req = piece_reqs[1]\n",
    "\n",
    "    piece_state_stack = []\n",
    "    piece_to_stack = []\n",
    "    piece_from_stack = []\n",
    "\n",
    "    for i in tqdm(range(num_games)):\n",
    "        piece_len = to_squares[i].shape[0]\n",
    "\n",
    "        dots_index_list = dots_indices[i].tolist()\n",
    "\n",
    "        for j in range(piece_len):\n",
    "            if not include_dots and indices[i][j] in dots_index_list: \n",
    "                continue\n",
    "\n",
    "            piece = piece_types[i][j][1]\n",
    "\n",
    "            square = to_squares[i][j]\n",
    "            rank = square // 8\n",
    "            file = square % 8\n",
    "            parity = (rank + file) % 2\n",
    "\n",
    "            meets_parity = True\n",
    "            if parity_req is not None:\n",
    "                meets_parity = (parity == parity_req)\n",
    "\n",
    "            if piece == piece_type and meets_parity:\n",
    "                piece_v = head_v[i][0, j, head, :]\n",
    "\n",
    "                piece_state_stack.append(piece_v)\n",
    "                piece_to_stack.append(to_squares[i][j])\n",
    "                piece_from_stack.append(frm_squares[i][j]) \n",
    "\n",
    "    return torch.stack(piece_state_stack), torch.stack(piece_to_stack), torch.stack(piece_from_stack)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The information for when the opponent captures a pawn is stored in the dots. Since the dots do not have \"to squares\", we ignore them for now."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pawn_state_stack, pawn_to_stack, pawn_from_stack = generate_to_from_stack(1, (chess.PAWN, None), num_games=10000, include_dots=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# probe training utils\n",
    "# the same default hyperparameters are used to train all probes\n",
    "\n",
    "@dataclass\n",
    "class TrainingParams:\n",
    "    wd: float = 0.01\n",
    "    lr: float = 0.001\n",
    "    beta1: float = 0.9\n",
    "    beta2: float = 0.99\n",
    "    max_train_games: int = 10000\n",
    "    max_test_games: int = 10000\n",
    "    max_val_games: int = 1000\n",
    "    max_iters: int = 50000\n",
    "    eval_iters: int = 50\n",
    "    num_epochs: int = 100\n",
    "\n",
    "@dataclass\n",
    "class LinearProbe:\n",
    "    linear_probe: torch.Tensor\n",
    "    probe_name: str\n",
    "    optimiser: torch.optim.AdamW\n",
    "    loss: torch.Tensor = torch.tensor(0.0)\n",
    "    accuracy: torch.Tensor = torch.tensor(0.0)\n",
    "    accuracy_queue: collections.deque = field(\n",
    "        default_factory=lambda: collections.deque(maxlen=1000)\n",
    "    )\n",
    "\n",
    "def create_linear_probe(train_params, num_classes, has_rc=False, num_rows=8, num_cols=8, dim=64):\n",
    "    linear_probe_name = \"probe\"\n",
    "\n",
    "    if has_rc:\n",
    "        linear_probe_DC = torch.randn(\n",
    "            dim,\n",
    "            num_rows,\n",
    "            num_cols,\n",
    "            num_classes,\n",
    "            requires_grad=False,\n",
    "            device=DEVICE,\n",
    "        ) / torch.sqrt(torch.tensor(D_MODEL))\n",
    "    else:\n",
    "        linear_probe_DC = torch.randn(\n",
    "            dim,\n",
    "            num_classes,\n",
    "            requires_grad=False,\n",
    "            device=DEVICE,\n",
    "        ) / torch.sqrt(torch.tensor(D_MODEL))\n",
    "\n",
    "    linear_probe_DC.requires_grad = True\n",
    "\n",
    "    optimiser = torch.optim.AdamW(\n",
    "        [linear_probe_DC],\n",
    "        lr=train_params.lr,\n",
    "        betas=(train_params.beta1, train_params.beta2),\n",
    "        weight_decay=train_params.wd,\n",
    "    )\n",
    "    linear_probe = LinearProbe(\n",
    "        linear_probe=linear_probe_DC,\n",
    "        probe_name=linear_probe_name,\n",
    "        optimiser=optimiser,\n",
    "    )\n",
    "    return linear_probe\n",
    "\n",
    "def linear_probe_forward_rc(probe, batch_data, batch_labels):\n",
    "    logits = torch.einsum(\"bd,dxyc->bcxy\", batch_data, probe.linear_probe)\n",
    "    loss = F.cross_entropy(logits, batch_labels)\n",
    "    return logits, loss\n",
    "\n",
    "def linear_probe_forward_mse(probe, batch_data, batch_labels):\n",
    "    logits = torch.einsum(\"bd,dc->bc\", batch_data, probe.linear_probe)\n",
    "    loss = F.mse_loss(logits, batch_labels)\n",
    "    return logits, loss\n",
    "\n",
    "def linear_probe_forward(probe, batch_data, batch_labels):\n",
    "    logits = torch.einsum(\"bd,dc->bc\", batch_data, probe.linear_probe)\n",
    "    loss = F.cross_entropy(logits, batch_labels)\n",
    "    return logits, loss\n",
    "\n",
    "def train_probe(probe, state_stack, label_stack, batch_size, num_epochs, probe_fwd):\n",
    "    \"\"\"\n",
    "    Trains the probe with the state_stack as inputs, label_stack as labels. The loss\n",
    "    function is specified in probe_fwd\n",
    "\n",
    "    Returns the validation/training loss and accuracy.\n",
    "    \"\"\"\n",
    "\n",
    "    labels = label_stack.to(DEVICE)\n",
    "\n",
    "    VAL_FRAC      = 0.10\n",
    "    num_samples   = state_stack.size(0)\n",
    "    num_val       = int(num_samples * VAL_FRAC)\n",
    "    num_train     = num_samples - num_val\n",
    "\n",
    "    dataset       = TensorDataset(state_stack, labels)\n",
    "    train_set, val_set = random_split(dataset, [num_train, num_val])\n",
    "\n",
    "    train_loader  = DataLoader(train_set, batch_size=batch_size,\n",
    "                            shuffle=True,  drop_last=False)\n",
    "    val_loader    = DataLoader(val_set,   batch_size=batch_size,\n",
    "                            shuffle=False, drop_last=False)\n",
    "\n",
    "    last_val_acc = None\n",
    "    last_val_loss = None\n",
    "    last_train_acc = None\n",
    "    last_train_loss = None\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        # train forward pass\n",
    "        for batch_data, batch_labels in tqdm(train_loader):\n",
    "            batch_data   = batch_data.to(DEVICE)\n",
    "            batch_labels = batch_labels.to(DEVICE)\n",
    "\n",
    "            logits, probe.loss = probe_fwd(probe, batch_data, batch_labels)\n",
    "\n",
    "            preds = logits.argmax(dim=1)\n",
    "\n",
    "            # if mse is the loss, accuracy is not a meaningful metric\n",
    "            if probe_fwd != linear_probe_forward_mse:\n",
    "                probe.accuracy = (preds == batch_labels).float().mean() \n",
    "                probe.accuracy_queue.append(probe.accuracy.item())\n",
    "                probe.accuracy = torch.tensor(sum(probe.accuracy_queue) / len(probe.accuracy_queue))\n",
    "\n",
    "            probe.optimiser.zero_grad()\n",
    "            probe.loss.backward()\n",
    "            probe.optimiser.step()\n",
    "\n",
    "        # validation\n",
    "        val_loss = 0.0\n",
    "        val_correct = 0\n",
    "        val_seen = 0\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for vdata, vlabels in val_loader:\n",
    "                vdata   = vdata.to(DEVICE)\n",
    "                vlabels = vlabels.to(DEVICE)\n",
    "\n",
    "                logits, loss = probe_fwd(probe, vdata, vlabels)\n",
    "\n",
    "                val_loss += loss.item() * vlabels.size(0)\n",
    "\n",
    "                preds = logits.argmax(dim=1)\n",
    "                if probe_fwd != linear_probe_forward_mse:\n",
    "                    val_correct += (preds == vlabels).float().mean() * vlabels.size(0)\n",
    "                val_seen += vlabels.size(0)\n",
    "\n",
    "        val_accuracy = val_correct / val_seen\n",
    "        val_loss /= val_seen\n",
    "\n",
    "        last_val_loss = val_loss\n",
    "        last_val_acc = val_accuracy\n",
    "        last_train_acc = probe.accuracy.item()\n",
    "        last_train_loss = probe.loss.item()\n",
    "\n",
    "        print(f\"Epoch {epoch:3d} │ \"\n",
    "            f\"train loss {probe.loss.item():.5f} │ \"\n",
    "            f\"train acc {probe.accuracy.item():.4f} │ \"\n",
    "            f\"val loss {val_loss:.5f} │ \"\n",
    "            f\"val acc {val_accuracy:.4f}\")\n",
    "\n",
    "    return last_val_loss, last_val_acc, last_train_loss, last_train_acc\n",
    "\n",
    "def test_probe(\n",
    "    probe,\n",
    "    state_stack,\n",
    "    label_stack,\n",
    "    batch_size=512,\n",
    "    probe_fwd=None,\n",
    "):\n",
    "    \"\"\"\n",
    "    Tests the probe with the state_stack as inputs, label_stack as labels. The loss\n",
    "    function is specified in probe_fwd.\n",
    "\n",
    "    Returns the loss and accuracy on the dataset.\n",
    "    \"\"\"\n",
    "\n",
    "    assert probe_fwd is not None, \"`probe_fwd` helper (forward pass) must be supplied\"\n",
    "\n",
    "    labels = label_stack.to(DEVICE)\n",
    "\n",
    "    loader = DataLoader(\n",
    "        TensorDataset(state_stack, labels),\n",
    "        batch_size=batch_size,\n",
    "        shuffle=False,\n",
    "        drop_last=False,\n",
    "    )\n",
    "\n",
    "    val_loss = 0.0\n",
    "    val_correct = 0\n",
    "    val_seen = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for vdata, vlabels in loader:\n",
    "            vdata   = vdata.to(DEVICE)\n",
    "            vlabels = vlabels.to(DEVICE)\n",
    "\n",
    "            logits, loss = probe_fwd(probe, vdata, vlabels)\n",
    "\n",
    "            val_loss += loss.item() * vlabels.size(0)\n",
    "\n",
    "            preds = logits.argmax(dim=1)\n",
    "            if probe_fwd != linear_probe_forward_mse:\n",
    "                val_correct += (preds == vlabels).float().mean() * vlabels.size(0)\n",
    "            val_seen += vlabels.size(0)\n",
    "\n",
    "    val_accuracy = val_correct / val_seen\n",
    "    val_loss /= val_seen\n",
    "\n",
    "    return val_loss, val_accuracy\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "probe_params = TrainingParams() # We use the same hyperparameters across all linear probes\n",
    "pawn_to_probe = create_linear_probe(probe_params, 64)\n",
    "pawn_from_probe = create_linear_probe(probe_params, 64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both probes achieve an accuracy of over 99.7%, which strongly indicates that $z_t$ encodes both the origin and destination square."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pawn_vl_to, pawn_va_to, pawn_tl_to, pawn_ta_to = train_probe(pawn_to_probe, \n",
    "                            pawn_state_stack,\n",
    "                            pawn_to_stack,\n",
    "                            batch_size=64,\n",
    "                            num_epochs=3,\n",
    "                            probe_fwd=linear_probe_forward\n",
    "                        ) \n",
    "\n",
    "pawn_vl_from, pawn_va_from, pawn_tl_from, pawn_ta_from = train_probe(pawn_from_probe, \n",
    "                            pawn_state_stack,\n",
    "                            pawn_from_stack,\n",
    "                            batch_size=64,\n",
    "                            num_epochs=3,\n",
    "                            probe_fwd=linear_probe_forward\n",
    "                        ) \n",
    "\n",
    "print(f\"To square probe final validation accuracy: {100 * pawn_va_to:.2f}\")\n",
    "print(f\"From square probe final validation accuracy: {100 * pawn_va_from:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Another assumption we make is that the to and from square vectors belong to the same subspace and point in opposite directions. This is important for cancellation during averaging, since \"moving to square $s$\" should negate \"moving from square $s$\".\n",
    "\n",
    "If this is truly the case, then negating $z_t$ should reverse the roles of the to and from-squares. Hence, we combine the to and from-square datasets such that for each move $t$, we include both $(z_t, \\text{to square}_t)$ and $(-z_t, \\text{from square}_t)$. "
   ]
  },
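  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of why this test is informative (it uses only the claimed decomposition, so it is an illustration rather than a new result), negation flips the sign of both terms:\n",
    "\n",
    "$$\n",
    "-z_t = -\\bigl(-v_{\\text{from}} + v_{\\text{to}}\\bigr) = -v_{\\text{to}} + v_{\\text{from}}\n",
    "$$\n",
    "\n",
    "so in $-z_t$ the from-square carries the positive sign that the to-square carries in $z_t$. A single probe that labels both $z_t$ and $-z_t$ correctly is what we would expect if the two roles share the same directions $v_s$ with opposite signs."
   ]
  },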
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "combined_pawn_state_stack = torch.cat([pawn_state_stack, -pawn_state_stack])\n",
    "combined_pawn_square_stack = torch.cat([pawn_to_stack, pawn_from_stack])\n",
    "\n",
    "combined_pawn_probe = create_linear_probe(probe_params, 64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The probe obtains an accuracy of 99.9%, demonstrating that negating the value vector reliably swaps the to and from labels.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pawn_vl, pawn_va, pawn_tl, pawn_ta = train_probe(combined_pawn_probe,\n",
    "                            combined_pawn_state_stack,\n",
    "                            combined_pawn_square_stack,\n",
    "                            batch_size=64,\n",
    "                            num_epochs=3,\n",
    "                            probe_fwd=linear_probe_forward\n",
    "                        ) \n",
    "\n",
    "print(f\"Combined final validation accuracy: {100 * pawn_va:.2f}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Value Vector Reconstruction\n",
    "\n",
    "Finally, we want to confirm that the to and from vectors make up most of the information that $z_t$ carries. Thus, we examine whether we can reconstruct the value vector using only the to and from square information (and just the from square for captures).\n",
    "\n",
    "\n",
    "In the following section, we create a 128-dimensional input for each pawn move by concatenating the two 64-dimensional one-hot square vectors. Then, we train a linear probe to reconstruct the head value vector using a MSE loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pawn_state_stack_all, pawn_to_stack_all, pawn_from_stack_all = generate_to_from_stack(1, (chess.PAWN, None), num_games=10000, include_dots=True)\n",
    "\n",
    "# split the dataset between the captures and normal moves (captures don't have \"to\" squares)\n",
    "normal_idx = (torch.nonzero((pawn_from_stack_all != -1), as_tuple=True))[0]\n",
    "capture_idx = (torch.nonzero((pawn_from_stack_all == -1), as_tuple=True))[0]\n",
    "\n",
    "# extract the move value vectors (the target vectors we want to reconstruct)\n",
    "normal_state_stack = pawn_state_stack_all[normal_idx]\n",
    "capture_state_stack = pawn_state_stack_all[capture_idx]\n",
    "\n",
    "normal_to = pawn_to_stack_all[normal_idx]\n",
    "normal_from = pawn_from_stack_all[normal_idx]\n",
    "\n",
    "capture_to = pawn_to_stack_all[capture_idx]\n",
    "\n",
    "to_idx = normal_to.squeeze(-1)\n",
    "from_idx = normal_from.squeeze(-1)\n",
    "\n",
    "to_onehot   = F.one_hot(to_idx,  num_classes=64).float()\n",
    "from_onehot = F.one_hot(from_idx, num_classes=64).float()\n",
    "\n",
    "# the to square should never equal the from squares\n",
    "assert (to_idx != from_idx).all(), \"Found a sample where to‑idx == from‑idx!\"\n",
    "\n",
    "to_from_stack = torch.cat([from_onehot, to_onehot], dim=1)\n",
    "\n",
    "to_idx_capture = capture_to.squeeze(-1)\n",
    "\n",
    "# for captures, the capture square is negated, so it should belong with the from squares\n",
    "from_onehot_capture = F.one_hot(to_idx_capture, num_classes=64).float()\n",
    "to_onehot_capture = torch.zeros_like(from_onehot_capture)\n",
    "\n",
    "to_from_stack_capture = torch.cat([from_onehot_capture, to_onehot_capture], dim=1)\n",
    "\n",
    "to_from_stack_all = torch.cat([to_from_stack, to_from_stack_capture])\n",
    "pawn_state_stack = torch.cat([normal_state_stack, capture_state_stack])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reconstruction probe obtains a validation MSE of around 0.19. The average variance of the value vector is \\~5 for capture moves and \\~1.8 for non-capture moves. Although these are different, we can take the overall average variance (\\~4.33) to estimate the $R^2$:\n",
    "\n",
    "$$\n",
    "R^2 = 1 - \\frac{\\text{MSE}}{{VAR}} = 1 - \\frac{0.19}{0.433} \\approx 0.956\n",
    "$$\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Mean variance of the value vectors: {torch.mean(torch.var(pawn_state_stack_all, dim=1)).item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pawn_reconstruction_probe = create_linear_probe(\n",
    "    train_params=probe_params,\n",
    "    num_classes=64,\n",
    "    dim=128\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pawn_reconstruction_vl, _, pawn_reconstruction_tl, _ = train_probe(\n",
    "    pawn_reconstruction_probe,\n",
    "    to_from_stack_all,\n",
    "    pawn_state_stack,\n",
    "    64,\n",
    "    5,\n",
    "    linear_probe_forward_mse\n",
    ")\n",
    "\n",
    "print(f\"Final validation MSE loss for pawn reconstruction: {pawn_reconstruction_vl:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We observe that the norms of the capture moves is lower than the norm of the non capture moves. This makes sense since $z_t = -v_{\\text{from}(t)}$ for captures whereas $z_t = -v_{\\text{from}(t)} + v_{\\text{to}(t)}$ for non capture moves. If the square vectors were all the same magnitude, then we would expect the non capture magnitudes to be roughly $\\sqrt{2}$ times the capture magnitudes."
   ]
  },
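  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see where the factor $\\sqrt{2}$ would come from, suppose (as an assumption that is not checked in this notebook) that every square vector has norm $r$ and that distinct square vectors are roughly orthogonal. Then for a non-capture move:\n",
    "\n",
    "$$\n",
    "\\|z_t\\|^2 = \\|v_{\\text{to}(t)} - v_{\\text{from}(t)}\\|^2 = \\|v_{\\text{to}(t)}\\|^2 + \\|v_{\\text{from}(t)}\\|^2 - 2 \\langle v_{\\text{to}(t)}, v_{\\text{from}(t)} \\rangle \\approx 2r^2\n",
    "$$\n",
    "\n",
    "so $\\|z_t\\| \\approx \\sqrt{2}\\, r$ for non-capture moves, compared with $\\|z_t\\| = r$ for captures."
   ]
  },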
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "norms = torch.norm(capture_state_stack, dim=1)\n",
    "print(f\"Mean of norms for capture states: {torch.mean(norms).item():2f}\")\n",
    "print(f\"Standard deviation of norms for capture states: {torch.std(norms).item():2f}\")\n",
    "norms = torch.norm(normal_state_stack, dim=1)\n",
    "print(f\"Mean of norms for non capture states: {torch.mean(norms).item():2f}\")\n",
    "print(f\"Standard deviation of norms for non capture states: {torch.std(norms).item():2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Board Reconstruction\n",
    "\n",
    "We want to verify that the average $Z_T$ is capable of recovering the full pawn structure. For each of the 64 squares, we train a linear probe to classify whether the square is blank or contains a white pawn. More formally, the probe computes:\n",
    "\n",
    "$$\n",
    "\\hat y_{i,j, T} = \\text{Softmax}(W_i Z_T)\n",
    "$$\n",
    "\n",
    "where $\\hat y_{i,j, T}$ is the probability distribution over the two classes $j \\in \\{0, 1\\}$ for square $i$ on move $T$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_avg_stack(head, piece_reqs, num_games=100):\n",
    "    \"\"\"\n",
    "    For each dot index, we look at all of the moves by the piece (specified in piece reqs) and\n",
    "    average the head output those moves. This information is bundled with the true head contribution\n",
    "    as well as the board state.\n",
    "    \"\"\"\n",
    "\n",
    "    piece_type = piece_reqs[0]\n",
    "    parity_req = piece_reqs[1]\n",
    "\n",
    "    sims = []\n",
    "\n",
    "    head_state_stack = []\n",
    "    avg_state_stack = []\n",
    "    piece_board_stack = []\n",
    "\n",
    "    for i in tqdm(range(num_games)):\n",
    "        piece_len = to_squares[i].shape[0]\n",
    "\n",
    "        dots_index_list = dots_indices[i].tolist()\n",
    "\n",
    "        cur_sum = torch.zeros_like(head_v[i][0, 0, head, :])\n",
    "\n",
    "        cur_piece_indices = []\n",
    "        nm = 0\n",
    "\n",
    "        for j in range(piece_len):\n",
    "            piece = piece_types[i][j][1]\n",
    "\n",
    "            square = to_squares[i][j]\n",
    "            rank = square // 8\n",
    "            file = square % 8\n",
    "            parity = (rank + file) % 2\n",
    "\n",
    "            meets_parity = True\n",
    "            if parity_req is not None:\n",
    "                meets_parity = (parity == parity_req)\n",
    "\n",
    "            if piece == piece_type and meets_parity:\n",
    "                cur_piece_indices.append(j)\n",
    "                piece_v = head_v[i][0, j, head, :]\n",
    "\n",
    "                # The board state at the dot includes the most recent black move. So, if the black move\n",
    "                # was a capture, we need to replace the previous information\n",
    "                replace = False\n",
    "                nm += 1\n",
    "\n",
    "                if indices[i][j] not in dots_index_list:\n",
    "                    cur_sum = cur_sum + piece_v\n",
    "                else:\n",
    "                    # if the character is a dot, it was a capture by black\n",
    "                    replace = True\n",
    "                    cur_sum = cur_sum + piece_v\n",
    "\n",
    "                # find the next dot\n",
    "                dot_idx = bisect.bisect_left(dots_index_list, indices[i][j])\n",
    "\n",
    "                if dot_idx >= len(dots_index_list):\n",
    "                    break \n",
    "                \n",
    "                # compute the true head contribution\n",
    "                attn = dots_attn[i][0, head, dot_idx, :][torch.tensor(cur_piece_indices)]\n",
    "                v = head_v[i][0, torch.tensor(cur_piece_indices), head, :]\n",
    "\n",
    "                # multiply v and attn\n",
    "                act = torch.einsum(\"l, l d -> d\", attn, v)\n",
    "\n",
    "                # compute cosine similarity between act and cur_sum \n",
    "                sim = torch.nn.functional.cosine_similarity(cur_sum, act, dim=0)\n",
    "\n",
    "                cur_board_stack = board_stacks[i][j].clone()\n",
    "                for r in range(8):\n",
    "                    for c in range(8):\n",
    "                        parity = (r + c) % 2\n",
    "                        meets_parity = True\n",
    "                        if parity_req is not None:\n",
    "                            meets_parity = (parity == parity_req)\n",
    "\n",
    "                        if not meets_parity or cur_board_stack[r][c] != piece_type:\n",
    "                            cur_board_stack[r][c] = 0\n",
    "                        else:\n",
    "                            cur_board_stack[r][c] = 1\n",
    "\n",
    "                # the previous board stack was invalid because it did not include the black capture\n",
    "                if replace:\n",
    "                    piece_board_stack[-1] = cur_board_stack\n",
    "                    avg_state_stack[-1] = cur_sum / nm\n",
    "                    head_state_stack[-1] = act\n",
    "                    sims[-1] = sim.item()\n",
    "                else:\n",
    "                    piece_board_stack.append(cur_board_stack)\n",
    "                    avg_state_stack.append(cur_sum / nm)\n",
    "                    head_state_stack.append(act)\n",
    "                    sims.append(sim.item())\n",
    "\n",
    "    # print the cosine similarity for checks\n",
    "    print(f\"Final similarity: {torch.mean(torch.tensor(sims)).item():.4f}\")\n",
    "\n",
    "    return torch.stack(avg_state_stack), torch.stack(piece_board_stack), torch.stack(head_state_stack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pawn_avg_stack, pawn_board_stack, pawn_head_out = construct_avg_stack(1, (chess.PAWN, None), num_games=10000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This board probe achieves an accuracy of 99.5% (compared to a baseline of \\~91% if all squares are guessed to be blank), confirming that the averaging scheme can accurately recover the pawn structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pawn_board_probe = create_linear_probe(train_params=probe_params,\n",
    "                                       num_classes=2,\n",
    "                                       has_rc=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vl, va, tl, ta = train_probe(\n",
    "    pawn_board_probe,\n",
    "    pawn_avg_stack,\n",
    "    pawn_board_stack.long(),\n",
    "    64,\n",
    "    5,\n",
    "    linear_probe_forward_rc\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, to see if this average scheme is similar to what the model is computing, we can apply the probe (trained on the averages) to the true head contribution. This achieves an accuracy of 98.9%, which suggests that the average and the true contribution encode the pawn structure similarly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vl, va = test_probe(\n",
    "    pawn_board_probe,\n",
    "    pawn_avg_stack,\n",
    "    pawn_board_stack.long(),\n",
    "    probe_fwd=linear_probe_forward_rc\n",
    ")\n",
    "\n",
    "print(f\"Accuracy using the computed average: {(100 * va.item()):.2f}\")\n",
    "\n",
    "vl, va = test_probe(\n",
    "    pawn_board_probe,\n",
    "    pawn_head_out,\n",
    "    pawn_board_stack.long(),\n",
    "    probe_fwd=linear_probe_forward_rc\n",
    ")\n",
    "\n",
    "print(f\"Accuracy using the true head: {(100 * va.item()):.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Open Questions\n",
    "\n",
    "As mentioned earlier, $Z_T = \\frac{1}{n} \\left( \\sum v_{\\text{final}(p)} - \\sum v_{\\text{initial}(p)} \\right)$ is not a perfect representation. We would like to scale by $n$ and add $v_{\\text{initial}}$ for every single pawn.\n",
    "\n",
    "We would also like to note that even a linear probe can somewhat overcome these limitations as the board probe could retrieve the pawn structure with relatively high accuracy. However, things get messier when multiple piece types are involved and the dimensions may become overloaded due to superposition.\n",
    "\n",
    "#### Scaling\n",
    "One way that the model can remedy the scaling issue is by using the ';' token at the beginning of the game string. When there are fewer pawn moves, it can dump more attention into the ';' token and even out the scaling as pawn moves are made. We anecdotally observe an increase in overall attention on the pawn moves (and less on the ';') as pawn moves are made, but have not investigated this rigorously.\n",
    "\n",
    "Moreover, there are a finite number of pawn moves (48) and not *too* many pawn moves per game (especially not until the endgame), so it is also possible that not having the exact scale is ok.\n",
    "\n",
    "#### Initial Pawn Positions\n",
    "Another question that arises is how the model knows the \"initial pawn positions.\" Since the initial positions are fixed, it is possible that the model could add a constant for each pawn (including unmoved pawns). \n",
    "\n",
    "However, a linear probe is expressive enough to deal with this issue and adopts an alternative strategy. We train a board probe (not included in this notebook) using the average of the 128-dimension to/from vectors in the value vector reconstruction section. We find that for most to-square dimensions, the probe injects a small positive bias onto all eight starting pawn squares. These small contributions are enough to mark the square as occupied, but when a pawn actually leaves the square, the subtraction term for the from-square overwhelms that bias.\n",
    "\n",
    "In addition, the probe's learned \"occupied\" and \"blank\" vectors are almost antipodal, with a cosine similarity of -0.999. Because the probe's task here is binary classification, it can afford to have a large negative direction for blank squares. However, this strategy may not be viable when multiple piece types (and not enough dimensions) are involved.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Layer 5 Exponential Moving Average Heads\n",
    "\n",
    "In layer 5, multiple heads work together to construct the positions of the remaining pieces (except the King). Although taking the average (similar to the pawn average head) is a theoretically viable strategy, it seems that the model opts for a different tactic. \n",
    "\n",
    "We begin by **fixing** a piece type (e.g., knights).\n",
    "\n",
    "### Head Value Vector\n",
    "For every knight move $t$, let from($t$) and to($t$) denote the from-square and to-square of the move. Then, let $v_s \\in \\mathbb{R}^d$ be a vector representation of the square $s$ for knights. We claim that the pre-projection value vector $z_t  = W^Vx_t$ computed by Head 5.6 satisfies:\n",
    "\n",
    "$$\n",
    "z_t \\approx -\\alpha \\: v_{\\text{from}(t)} + \\: v_{\\text{to}(t)} + u_t\n",
    "$$\n",
    "\n",
    "with fixed decay factor $\\alpha \\in (0, 1)$, and auxiliary information $u_t$.\n",
    "\n",
    "### EMA\n",
    "\n",
    "Now, we claim the model keeps track of an exponential moving average $Z_T$ where $Z_0 = \\mathbf{0}$ and:\n",
    "\n",
    "$$\n",
    "    Z_T = \\alpha \\: Z_{T-1} + (1 - \\alpha) \\: z_T\n",
    "$$\n",
    "\n",
    "If there were $n$ piece moves before $T$, we can denote $t_k$ as the $k$ th piece move. Expanding the recurrence yields:\n",
    "\n",
    "$$\n",
    "    Z_T = (1 - \\alpha) \\sum_{k=1}^n \\alpha^{n-k} z_{t_k} \n",
    "$$\n",
    "\n",
    "If we track only a single piece, then $v_{\\text{from}(t_k)}$ = $v_{\\text{to}(t_{k-1})}$ and cancel out. Then, we are left with \n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "Z_T &= (1 - \\alpha) \\: v_{\\text{final}} - (1-\\alpha)\\alpha^{n-1} \\: v_{\\text{initial}} + u \\\\\n",
    "&\\approx (1 - \\alpha) \\: v_{\\text{final}} + u\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "where $u$ is the sum over all $u_t$. This essentially gives us the final position of the piece, as the initial vector decays exponentially. Similar to the averaging scheme, on a capture, we can let $\\text{from}(t)$ be the capture square and $v_{\\text{to}(t)} = \\mathbf{0}$.\n",
    "\n",
    "### Persisting Companion Pieces\n",
    "\n",
    "The previous strategy works for \"unique\" pieces like bishops (the model distinguishes between the dark squared bishop and the light squared bishop) and queens<sup>1</sup>. However, we can extend the previous strategy for single pieces to pairs of pieces like knights and rooks. In particular, we can let\n",
    "\n",
    "$$\n",
    "u_{t} = (1 - \\alpha) \\: v_{\\text{companion}(t)}\n",
    "$$\n",
    "\n",
    "where $\\text{companion}(t)$ represents the square of the \"other\" knight/rook ($u_t = \\textbf{0}$ if the other piece was captured). This effectively renews the position of the companion as  $\\alpha(1 − \\alpha) + (1 − \\alpha)^2 = (1 − \\alpha)$. Therefore, up to normalization, we are left with a linear representation of both the piece’s own position and that of its symmetric partner.\n",
    "\n",
    "---\n",
    "<sup>1</sup> Perhaps it is naive to treat the queen as a unique piece. While underpromotion to a minor piece (knight, bishop) is extremely rare, it is not uncommon to promote a pawn to a queen. So, it is possible (and even likely) that queens also have \"companions,\" though this requires more investigation.\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def grid_search(head, piece_type, parity_req, num_games):\n",
    "    \"\"\"\n",
    "    Grid search for the best decay value alpha. Looks at all values between 0.1 and 0.9 inclusive\n",
    "    \"\"\"\n",
    "    lst = []\n",
    "    avgs = None\n",
    "    last = None\n",
    "    for i in range(9):\n",
    "        alpha = (i + 1) * 0.1\n",
    "        sims, last, avgs  = get_piece_sims(\n",
    "            interested_head=head,\n",
    "            alpha=alpha,\n",
    "            piece_type=piece_type,\n",
    "            num_games=num_games,\n",
    "            parity_req=parity_req,\n",
    "        )\n",
    "        lst.append((alpha, torch.mean(torch.tensor(sims)).item()))\n",
    "    print(f\"Sim between true contribution and average: {torch.mean(torch.tensor(avgs)).item()}\")\n",
    "    print(f\"Sim between true contribution and last: {torch.mean(torch.tensor(last)).item()}\")\n",
    "    return lst, torch.mean(torch.tensor(avgs)).item(), torch.mean(torch.tensor(last)).item()\n",
    "\n",
    "def create_line_graph(data_points,\n",
    "                      avg,\n",
    "                      last,\n",
    "                      title,\n",
    "                      x_label=\"Decay factor (α)\",\n",
    "                      y_label=\"Cosine Similarity\"):\n",
    "\n",
    "    if not data_points:\n",
    "        raise ValueError(\"data_points list is empty\")\n",
    "\n",
    "    x_vals, y_vals = zip(*data_points)\n",
    "\n",
    "    plt.figure(figsize=(10, 6))\n",
    "    \n",
    "    plt.plot(x_vals, y_vals,\n",
    "             marker='o', markersize=6,\n",
    "             linewidth=2, label='EMA')\n",
    "\n",
    "    \n",
    "    plt.axhline(avg,  color='tab:green', linestyle='--',\n",
    "                linewidth=2, label='Average')\n",
    "    plt.axhline(last, color='tab:red',  linestyle=':',\n",
    "                linewidth=2, label='Last-move')\n",
    "\n",
    "    plt.xlabel(x_label, fontsize=16)\n",
    "    plt.ylabel(y_label, fontsize=16)\n",
    "    plt.xticks(fontsize=14)\n",
    "    plt.yticks(fontsize=14)\n",
    "\n",
    "    plt.grid(False)\n",
    "    plt.legend(fontsize=14)\n",
    "    plt.title(title, fontsize=20)\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Finding the Decay Factor (and cos sim analysis)\n",
    "\n",
    "It is theoretically challenging to find the decay factor $\\alpha$. So, we perform a grid search over $\\alpha \\in [0.1, 0.2, ... , 0.9]$ and empirically pick the $\\alpha$ that maximizes the cosine similarity with the true contribution $H_T$. For most pieces, we obtain high cosine similarities of at least 0.97."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Additional piece types:\n",
    "rook_lst, rook_avg, rook_last = grid_search(2, chess.ROOK, None, 100)\n",
    "create_line_graph(\n",
    "    data_points=rook_lst,\n",
    "    title=\"Rook similarities at Head 5.2\",\n",
    "    avg=rook_avg,\n",
    "    last=rook_last,\n",
    "    x_label=\"Decay factor (α)\",\n",
    "    y_label=\"Cosine Simlarity\"\n",
    ")\n",
    "queen_lst, queen_avg, queen_last = grid_search(2, chess.QUEEN, None, 100)\n",
    "create_line_graph(\n",
    "    data_points=queen_lst,\n",
    "    title=\"Queen similarities at Head 5.2\",\n",
    "    avg=queen_avg,\n",
    "    last=queen_last,\n",
    "    x_label=\"Decay factor (α)\",\n",
    "    y_label=\"Cosine Simlarity\"\n",
    ")\n",
    "dsbishop_lst, dsb_avg, dsb_last = grid_search(6, chess.BISHOP, 0, 100)\n",
    "create_line_graph(\n",
    "    data_points=dsbishop_lst,\n",
    "    title=\"Dark-squared Bishop similarities at Head 5.6\",\n",
    "    avg=dsb_avg,\n",
    "    last=dsb_last,\n",
    "    x_label=\"Decay factor (α)\",\n",
    "    y_label=\"Cosine Simlarity\"\n",
    ")\n",
    "lsbishop_list, lsb_avg, lsb_last = grid_search(3, chess.BISHOP, 1, 100)\n",
    "create_line_graph(\n",
    "    data_points=lsbishop_list,\n",
    "    title=\"Light-square Bishop similarities at Head 5.3\",\n",
    "    avg=lsb_avg,\n",
    "    last=lsb_last,\n",
    "    x_label=\"Decay factor (α)\",\n",
    "    y_label=\"Cosine Simlarity\"\n",
    ")\n",
    "\n",
    "knight_lst, knight_avg, knight_last = grid_search(6, chess.KNIGHT, None, 100)\n",
    "create_line_graph(\n",
    "    data_points=knight_lst,\n",
    "    title=\"Knight similarities at Head 5.6\",\n",
    "    avg=knight_avg,\n",
    "    last=knight_last,\n",
    "    x_label=\"Decay factor (α)\",\n",
    "    y_label=\"Cosine Simlarity\"\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### Value Vector Analysis\n",
    "\n",
    "We repeat the previous value vector analysis. Since the value vector $z_t$ is the sum $-\\alpha \\: v_{\\text{from}(t)} + \\: v_{\\text{to}(t)} + u_t$, we should again be able to extract the to and from squares using a linear probe. We also show that the to and from contribution vectors are opposite directions, and construct the combined to/from dataset in the same way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "piece_head_pairs = [\n",
    "    (\"knight\", 6, (chess.KNIGHT, None, 0.6)),\n",
    "    (\"queen\", 2, (chess.QUEEN, None, 0.3)),\n",
    "    (\"rook\", 2, (chess.ROOK, None, 0.5)),\n",
    "    (\"dsbishop\", 6, (chess.BISHOP, 0, 0.4)),\n",
    "    (\"lsbishop\", 3, (chess.BISHOP, 1, 0.4)),\n",
    "]\n",
    "\n",
    "piece_state_stacks = {}\n",
    "piece_to_stacks = {}\n",
    "piece_from_stacks = {}\n",
    "\n",
    "for name, head, reqs in piece_head_pairs:\n",
    "    piece_state_stack, piece_to_stack, piece_from_stack = generate_to_from_stack(head, reqs, num_games=10000)\n",
    "    piece_state_stacks[name] = piece_state_stack\n",
    "    piece_to_stacks[name] = piece_to_stack\n",
    "    piece_from_stacks[name] = piece_from_stack\n",
    "\n",
    "    print(f\"Computed to/from data for: {name}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Although the magnitudes of the to and from-square vector contributions differ, a single linear probe trained on the combined set still achieves high accuracy, confirming that the to-square and from-square vector representations lie in the same subspace."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "piece_names = [\"knight\", \"queen\", \"rook\", \"dsbishop\", \"lsbishop\"]\n",
    "\n",
    "piece_to_probes = {}\n",
    "piece_from_probes = {}\n",
    "piece_to_from_probes = {}\n",
    "\n",
    "piece_to_from_probe_results = {}\n",
    "piece_to_probe_results = {}\n",
    "piece_from_probe_results = {}\n",
    "\n",
    "num_iters = 10\n",
    "\n",
    "for name in piece_names:\n",
    "    to_probe = create_linear_probe(probe_params, 64)\n",
    "    from_probe = create_linear_probe(probe_params, 64)\n",
    "    to_from_probe = create_linear_probe(probe_params, 64)\n",
    "\n",
    "    piece_to_probes[name] = to_probe\n",
    "    piece_from_probes[name] = from_probe\n",
    "    piece_to_from_probes[name] = to_from_probe\n",
    "\n",
    "    piece_to_probe_results[name] = []\n",
    "    piece_from_probe_results[name] = []\n",
    "    piece_to_from_probe_results[name] = []\n",
    "\n",
    "    vl, va, tl, ta = train_probe(to_probe, \n",
    "                    piece_state_stacks[name],\n",
    "                    piece_to_stacks[name],\n",
    "                    batch_size=64,\n",
    "                    num_epochs=num_iters,\n",
    "                    probe_fwd=linear_probe_forward\n",
    "                ) \n",
    "    \n",
    "    piece_to_probe_results[name].append({\n",
    "        \"val_loss\": vl,\n",
    "        \"val_acc\": va,\n",
    "        \"train_loss\": tl,\n",
    "        \"train_acc\": ta,\n",
    "    })\n",
    "    \n",
    "    vl, va, tl, ta = train_probe(from_probe, \n",
    "                    piece_state_stacks[name],\n",
    "                    piece_from_stacks[name],\n",
    "                    batch_size=64,\n",
    "                    num_epochs=num_iters,\n",
    "                    probe_fwd=linear_probe_forward\n",
    "                ) \n",
    "    \n",
    "    piece_from_probe_results[name].append({\n",
    "        \"val_loss\": vl,\n",
    "        \"val_acc\": va,\n",
    "        \"train_loss\": tl,\n",
    "        \"train_acc\": ta,\n",
    "    })\n",
    "    \n",
    "    vl, va, tl, ta = train_probe(to_from_probe, \n",
    "                    torch.cat([piece_state_stacks[name], -piece_state_stacks[name]]),\n",
    "                    torch.cat([piece_to_stacks[name], piece_from_stacks[name]]),\n",
    "                    batch_size=64,\n",
    "                    num_epochs=num_iters,\n",
    "                    probe_fwd=linear_probe_forward\n",
    "                ) \n",
    "    \n",
    "    piece_to_from_probe_results[name].append({\n",
    "        \"val_loss\": vl,\n",
    "        \"val_acc\": va,\n",
    "        \"train_loss\": tl,\n",
    "        \"train_acc\": ta,\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def summarize_to_from(piece_to_probe_results,\n",
    "                         piece_from_probe_results,\n",
    "                         how=\"last\"):\n",
    "    rows = []\n",
    "\n",
    "    def _select(res_list):\n",
    "        if how == \"last\":\n",
    "            return res_list[-1]\n",
    "        elif how == \"best\":\n",
    "            return max(res_list, key=lambda d: d[\"val_acc\"])\n",
    "        else:\n",
    "            raise ValueError(\"how must be 'last' or 'best'\")\n",
    "\n",
    "    for piece_name, res_list in piece_to_probe_results.items():\n",
    "        r = _select(res_list)\n",
    "        rows.append({\n",
    "            \"piece\": piece_name,\n",
    "            \"probe\": \"to\",\n",
    "            \"train_loss\": r[\"train_loss\"],\n",
    "            \"train_acc\": r[\"train_acc\"],\n",
    "            \"val_loss\":   r[\"val_loss\"],\n",
    "            \"val_acc\":    r[\"val_acc\"],\n",
    "        })\n",
    "\n",
    "    for piece_name, res_list in piece_from_probe_results.items():\n",
    "        r = _select(res_list)\n",
    "        rows.append({\n",
    "            \"piece\": piece_name,\n",
    "            \"probe\": \"from\",\n",
    "            \"train_loss\": r[\"train_loss\"],\n",
    "            \"train_acc\": r[\"train_acc\"],\n",
    "            \"val_loss\":   r[\"val_loss\"],\n",
    "            \"val_acc\":    r[\"val_acc\"],\n",
    "        })\n",
    "\n",
    "    df = pd.DataFrame(rows)\n",
    "\n",
    "    df.sort_values([\"piece\", \"probe\"], inplace=True)\n",
    "\n",
    "    print(\"\\n=== Probe summary ({}) ===\".format(how))\n",
    "    print(tabulate(df,\n",
    "                   headers=\"keys\",\n",
    "                   tablefmt=\"github\",\n",
    "                   floatfmt=\".4f\",\n",
    "                   showindex=False))\n",
    "\n",
    "    return df\n",
    "\n",
    "def summarize_combined(piece_to_from_probe_results, how=\"last\"):\n",
    "    def _select(res_list):\n",
    "        if how == \"last\":\n",
    "            return res_list[-1]\n",
    "        if how == \"best\":\n",
    "            return max(res_list, key=lambda d: d[\"val_acc\"])\n",
    "        raise ValueError(\"how must be 'last' or 'best'\")\n",
    "\n",
    "    rows = []\n",
    "    for piece_name, res_list in piece_to_from_probe_results.items():\n",
    "        r = _select(res_list)\n",
    "        rows.append({\n",
    "            \"piece\":       piece_name,\n",
    "            \"train_loss\":  r[\"train_loss\"],\n",
    "            \"train_acc\":   r[\"train_acc\"],\n",
    "            \"val_loss\":    r[\"val_loss\"],\n",
    "            \"val_acc\":     r[\"val_acc\"],\n",
    "        })\n",
    "\n",
    "    df = pd.DataFrame(rows).sort_values(\"piece\")\n",
    "\n",
    "    print(f\"\\n=== Combined to-from probe summary ({how}) ===\")\n",
    "    print(tabulate(df,\n",
    "                   headers=\"keys\",\n",
    "                   tablefmt=\"github\",\n",
    "                   floatfmt=\".4f\",\n",
    "                   showindex=False))\n",
    "\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = summarize_to_from(\n",
    "    piece_to_probe_results=piece_to_probe_results,\n",
    "    piece_from_probe_results=piece_from_probe_results,\n",
    "    how=\"best\",\n",
    ")\n",
    "\n",
    "_ = summarize_combined(\n",
    "    piece_to_from_probe_results=piece_to_from_probe_results,\n",
    "    how=\"best\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Board Reconstruction\n",
    "\n",
    "We want to verify that the exponential moving average $Z_T$ is capable of recovering the piece positions. For each of the 64 squares, we train a linear probe to classify whether the square is blank or contains the piece. For probe details, refer to the pawn board reconstruction section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_ema_stack(head, piece_reqs, num_games=100):\n",
    "    \"\"\"\n",
    "    For each dot index, we look at all of the moves by the piece (specified in piece reqs) and\n",
    "    take the EMA of the head output of those moves. This information is bundled with the true head \n",
    "    contribution as well as the board state.\n",
    "    \"\"\"\n",
    "    piece_type = piece_reqs[0]\n",
    "    parity_req = piece_reqs[1]\n",
    "    alpha = piece_reqs[2]\n",
    "\n",
    "    sims = []\n",
    "\n",
    "    head_state_stack = []\n",
    "    ema_state_stack = []\n",
    "    piece_board_stack = []\n",
    "\n",
    "    for i in tqdm(range(num_games)):\n",
    "        piece_len = to_squares[i].shape[0]\n",
    "\n",
    "        dots_index_list = dots_indices[i].tolist()\n",
    "\n",
    "        cur_ema = torch.zeros_like(head_v[i][0, 0, head, :])\n",
    "\n",
    "        cur_piece_indices = []\n",
    "\n",
    "        for j in range(piece_len):\n",
    "            piece = piece_types[i][j][1]\n",
    "\n",
    "            square = to_squares[i][j]\n",
    "            rank = square // 8\n",
    "            file = square % 8\n",
    "            parity = (rank + file) % 2\n",
    "\n",
    "            meets_parity = True\n",
    "            if parity_req is not None:\n",
    "                meets_parity = (parity == parity_req)\n",
    "\n",
    "            if piece == piece_type and meets_parity:\n",
    "                cur_piece_indices.append(j)\n",
    "                piece_v = head_v[i][0, j, head, :]\n",
    "\n",
    "                # The board state at the dot includes the most recent black move. So, if the black move\n",
    "                # was a capture, we need to replace the previous information\n",
    "                replace = False\n",
    "\n",
    "                # update the ema\n",
    "                if indices[i][j] not in dots_index_list:\n",
    "                    cur_ema = alpha * cur_ema + (1 - alpha) * piece_v\n",
    "                else:\n",
    "                    # compute the true head contribution\n",
    "                    replace = True\n",
    "                    cur_ema = alpha * cur_ema + (1 - alpha) * piece_v\n",
    "\n",
    "                # find the next dot\n",
    "                dot_idx = bisect.bisect_left(dots_index_list, indices[i][j])\n",
    "\n",
    "                if dot_idx >= len(dots_index_list):\n",
    "                    break \n",
    "                \n",
    "                attn = dots_attn[i][0, head, dot_idx, :][torch.tensor(cur_piece_indices)]\n",
    "                v = head_v[i][0, torch.tensor(cur_piece_indices), head, :]\n",
    "\n",
    "                # multiply v and attn\n",
    "                act = torch.einsum(\"l, l d -> d\", attn, v)\n",
    "\n",
    "                # compute cosine similarity between act and cur_ema \n",
    "                sim = torch.nn.functional.cosine_similarity(cur_ema, act, dim=0)\n",
    "\n",
    "                cur_board_stack = board_stacks[i][j].clone()\n",
    "                for r in range(8):\n",
    "                    for c in range(8):\n",
    "                        parity = (r + c) % 2\n",
    "                        meets_parity = True\n",
    "                        if parity_req is not None:\n",
    "                            meets_parity = (parity == parity_req)\n",
    "\n",
    "                        if not meets_parity or cur_board_stack[r][c] != piece_type:\n",
    "                            cur_board_stack[r][c] = 0\n",
    "                        else:\n",
    "                            cur_board_stack[r][c] = 1\n",
    "\n",
    "                # the previous board stack was invalid because it did not include the black capture\n",
    "                if replace:\n",
    "                    piece_board_stack[-1] = cur_board_stack\n",
    "                    ema_state_stack[-1] = cur_ema\n",
    "                    head_state_stack[-1] = act\n",
    "                    sims[-1] = sim.item()\n",
    "                else:\n",
    "                    piece_board_stack.append(cur_board_stack)\n",
    "                    ema_state_stack.append(cur_ema)\n",
    "                    head_state_stack.append(act)\n",
    "                    sims.append(sim.item())\n",
    "\n",
    "    # print the cosine similarity for checks\n",
    "    print(f\"Sim: {torch.mean(torch.tensor(sims)).item():.4f}\")\n",
    "\n",
    "    return torch.stack(ema_state_stack), torch.stack(piece_board_stack), torch.stack(head_state_stack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "piece_ema_stacks = {}\n",
    "piece_board_stacks = {}\n",
    "piece_head_stacks = {}\n",
    "\n",
    "for name, head, reqs in piece_head_pairs:\n",
    "    piece_ema_stack, piece_board_stack, piece_head_stack = construct_ema_stack(head, reqs, num_games=10000)\n",
    "\n",
    "    piece_ema_stacks[name] = piece_ema_stack\n",
    "    piece_board_stacks[name] = piece_board_stack\n",
    "    piece_head_stacks[name] = piece_head_stack\n",
    "\n",
    "    print(f\"Computed ema data for: {name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "board_probes = {}\n",
    "\n",
    "for name in piece_names:\n",
    "    board_probes[name] = create_linear_probe(train_params=probe_params,\n",
    "                                        num_classes=2,\n",
    "                                        has_rc=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for name in piece_names:\n",
    "    vl, va, tl, ta = train_probe(\n",
    "        board_probes[name],\n",
    "        piece_ema_stacks[name],\n",
    "        piece_board_stacks[name].long(),\n",
    "        64,\n",
    "        5,\n",
    "        linear_probe_forward_rc\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, to see if the EMA scheme is similar to what the model is computing, we can apply the probe (trained on the averages) to the true head contribution. The accuracy remains high, which suggests that the EMA and the true contribution encode the piece positions similarly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for name in piece_names:\n",
    "    print(f\"=== {name} ===\")\n",
    "\n",
    "    vl, va = test_probe(\n",
    "        board_probes[name],\n",
    "        piece_ema_stacks[name],\n",
    "        piece_board_stacks[name].long(),\n",
    "        probe_fwd=linear_probe_forward_rc\n",
    "    )\n",
    "\n",
    "    print(f\"Accuracy using the computed ema: {(100 * va.item()):.2f}\")\n",
    "\n",
    "    vl, va = test_probe(\n",
    "        board_probes[name],\n",
    "        piece_head_stacks[name],\n",
    "        piece_board_stacks[name].long(),\n",
    "        probe_fwd=linear_probe_forward_rc\n",
    "    )\n",
    "\n",
    "    print(f\"Accuracy using the true head: {(100 * va.item()):.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Companion pieces\n",
    "\n",
    "For symmetric pieces like knights and rooks, we hypothesize that when one of the pieces is moved the position of the other piece is renewed. To test this, we train a probe to examine whether the square of the unmoved piece can be extracted from the value vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_to_board_stack(head, piece_reqs, num_games=100, include_dots=False):\n",
    "    \"\"\"\n",
    "    For each move of the specified piece type, this function returns the head's value vector, the to square, \n",
    "    and the board state\n",
    "    \"\"\"\n",
    "    piece_type = piece_reqs[0]\n",
    "    parity_req = piece_reqs[1]\n",
    "\n",
    "    piece_state_stack = []\n",
    "    piece_to_stack = []\n",
    "    piece_board_stack = []\n",
    "\n",
    "    for i in tqdm(range(num_games)):\n",
    "        piece_len = to_squares[i].shape[0]\n",
    "\n",
    "        dots_index_list = dots_indices[i].tolist()\n",
    "\n",
    "        for j in range(piece_len):\n",
    "            if not include_dots and indices[i][j] in dots_index_list: \n",
    "                continue\n",
    "\n",
    "            if indices[i][j] in dots_index_list:\n",
    "                assert frm_squares[i][j] == -1\n",
    "\n",
    "            piece = piece_types[i][j][1]\n",
    "\n",
    "            square = to_squares[i][j]\n",
    "            rank = square // 8\n",
    "            file = square % 8\n",
    "            parity = (rank + file) % 2\n",
    "\n",
    "            meets_parity = True\n",
    "            if parity_req is not None:\n",
    "                meets_parity = (parity == parity_req)\n",
    "\n",
    "            if piece == piece_type and meets_parity:\n",
    "                piece_v = head_v[i][0, j, head, :]\n",
    "\n",
    "                piece_state_stack.append(piece_v)\n",
    "                piece_to_stack.append(to_squares[i][j])\n",
    "\n",
    "                cur_board_stack = board_stacks[i][j].clone()\n",
    "                for r in range(8):\n",
    "                    for c in range(8):\n",
    "                        if cur_board_stack[r][c] != piece_type:\n",
    "                            cur_board_stack[r][c] = 0\n",
    "                        else:\n",
    "                            cur_board_stack[r][c] = 1\n",
    "\n",
    "                piece_board_stack.append(cur_board_stack)\n",
    "\n",
    "    \n",
    "    return torch.stack(piece_state_stack), torch.stack(piece_to_stack), torch.stack(piece_board_stack)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "knight_state_stack, knight_to_stack, knight_board_stack = generate_to_board_stack(6, (chess.KNIGHT, None), num_games=10000)\n",
    "rook_state_stack, rook_to_stack, rook_board_stack = generate_to_board_stack(2, (chess.ROOK, None), num_games=10000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_companion_data(state_stack, board_stack, to_stack, ignore_squares=[]):\n",
    "    \"\"\"\n",
    "    For symmetric pieces like knights and rooks, find the square of the \"other\" knight/rook\n",
    "    Assumes that there are at most 2 knights/rooks on the board\n",
    "    \"\"\"\n",
    "\n",
    "    train_indices = []\n",
    "    labels = []\n",
    "\n",
    "    for i in tqdm(range(state_stack.shape[0])):\n",
    "        board = board_stack[i]    \n",
    "        to = to_stack[i].item()\n",
    "\n",
    "        found = False\n",
    "\n",
    "        # check if there is another piece not on the to square\n",
    "        for r in range(8):\n",
    "            for c in range(8): \n",
    "                idx = r * 8 + c\n",
    "                if idx == to or idx in ignore_squares:\n",
    "                    continue\n",
    "                \n",
    "                if board[r][c]:\n",
    "                    found = True\n",
    "                    labels.append(idx)\n",
    "                    train_indices.append(i)\n",
    "\n",
    "            if found:\n",
    "                break\n",
    "\n",
    "    inp = state_stack[torch.tensor(train_indices)]\n",
    "    labels = torch.tensor(labels)\n",
    "\n",
    "    return inp, labels\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "knight_inp, knight_labels = construct_companion_data(knight_state_stack, knight_board_stack, knight_to_stack)\n",
    "rook_inp, rook_labels = construct_companion_data(rook_state_stack, rook_board_stack, rook_to_stack)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For knights, the probe achieves an accuracy of \\~91%, whereas for rooks, the probe achieves an accuracy of \\~80%. This means that while there is some information of the companion piece's location, the signal is not perfect.\n",
    "\n",
    "As for the exact mechanism, we believe that Heads 4.1, 4.3 and 4.6 could be responsible for passing information of the companion piece, though more investigation is needed.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "knight_companion_probe = create_linear_probe(probe_params, 64)\n",
    "rook_companion_probe = create_linear_probe(probe_params, 64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "knight_companian_vl, knight_companion_va, knight_companion_tl, knight_companion_ta = train_probe(\n",
    "    knight_companion_probe,\n",
    "    knight_inp,\n",
    "    knight_labels,\n",
    "    64,\n",
    "    16,\n",
    "    linear_probe_forward\n",
    ")\n",
    "\n",
    "print(f\"Final knight companion validation accuracy: {100 * knight_companion_va:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rook_companian_vl, rook_companion_va, rook_companion_tl, rook_companion_ta = train_probe(\n",
    "    rook_companion_probe,\n",
    "    rook_inp,\n",
    "    rook_labels,\n",
    "    64,\n",
    "    16,\n",
    "    linear_probe_forward\n",
    ")\n",
    "\n",
    "print(f\"Final rook companion validation accuracy: {100 * rook_companion_va:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Open Questions and Future Directions\n",
    "\n",
    "Although we believe we have uncovered the main circuits responsible for board reconstruction, some important details remain. First, it could be interesting to see how the transformer find the \"previous move\" of a piece, as this requires knowledge of how a piece moves. Whereas for bishops a simple (albeit slightly incorrect) mechanism like \"find the last bishop move on a dark/light square\" is sufficient, it could be interesting to see how the transformer handles other pieces.\n",
    "\n",
    "For the pawn average head, it is unclear how the transformer deals with scaling and how it incoporates the initial pawn positions. More specifically, it may be worthwhile to investigate how the heads use the ; token at the beginning of each game.\n",
    "\n",
    "The EMA heads are even more complicated. First, it is theoretically challenging to find the decay factor $\\alpha$. Moreover, the heads do not have clear cut piece responsibilities. Even though most piece types have a \"main head,\" sometimes other heads will still attend to their moves. Also, the companion piece mechanism is also unknown, and may not be accurate enough to support the hypothesized piece persistence.\n",
    "\n",
    "Finally, although a large chunk of the attention comes from the rank token of the moves, a significant portion of \"ambient\" attention is spread out throughout the rest of the tokens. It could be interesting to explore why this is the case.\n",
    "\n",
    "### Future Directions\n",
    "\n",
    "In this work, a lot of the comparisons made between hypothesized methods and the true head contributions relied on cosine similarities. A more direct approach worth exploring could be activation patching. However, getting the correct scale of the contributions could present a challenge.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "thesis4",
   "language": "python",
   "name": "thesis4"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
