{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# to import functions from `patching_utils.py` and `plotly_utils.py`,\n",
    "# we need to add the repository directory to the system path.\n",
    "current_dir = os.path.dirname(os.getcwd())\n",
    "print(current_dir)\n",
    "if current_dir not in sys.path:\n",
    "    sys.path.append(current_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as t\n",
    "from torch import Tensor\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from transformer_lens import HookedTransformer\n",
    "\n",
    "t.set_grad_enabled(False)\n",
    "\n",
    "from jaxtyping import Float, Int, Bool, Union\n",
    "\n",
    "import circuitsvis as cv\n",
    "from IPython.display import display\n",
    "from plotly_utils import imshow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from patching_utils import batched_get_path_patch_to_repr, batched_get_path_patch_to_head, prepare_data_for_fwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def analyze_path_to_repr(\n",
    "    receiver_layers: list[int],\n",
    "    receiver_input: str,\n",
    "    model: HookedTransformer,\n",
    "    data: list[dict],\n",
    "    begin_layer: int,\n",
    "    batch_size: int\n",
    "):\n",
    "    t.cuda.empty_cache()\n",
    "    patched_logit_diff, normal_logit_diff, contrast_logit_diff = batched_get_path_patch_to_repr(receiver_layers, receiver_input, model, data, begin_layer, batch_size)\n",
    "    results = (patched_logit_diff - contrast_logit_diff) / (contrast_logit_diff  - normal_logit_diff)\n",
    "    results[:begin_layer, :] = 0.0\n",
    "    end_layer=max(receiver_layers)\n",
    "    results[end_layer+1:, :] = 0.0\n",
    "    visualize_heatmap(results)\n",
    "    return results\n",
    "\n",
    "def analyze_path_to_head(\n",
    "    receiver_heads: list[int],\n",
    "    receiver_input: str,\n",
    "    model: HookedTransformer,\n",
    "    data: list[dict],\n",
    "    begin_layer: int,\n",
    "    batch_size: int\n",
    "):\n",
    "    t.cuda.empty_cache()\n",
    "    patched_logit_diff, normal_logit_diff, contrast_logit_diff = batched_get_path_patch_to_head(receiver_heads, receiver_input, model, data, begin_layer, batch_size)\n",
    "    results = (patched_logit_diff - contrast_logit_diff) / (contrast_logit_diff  - normal_logit_diff)\n",
    "    results[:begin_layer, :] = 0.0\n",
    "    end_layer=max([head[0] for head in receiver_heads])\n",
    "    results[end_layer:, :] = 0.0\n",
    "    visualize_heatmap(results)\n",
    "    return results\n",
    "\n",
    "def visualize_heatmap(results, info=\"\"):\n",
    "    imshow(\n",
    "        100 * results.t(),\n",
    "        title=\"Path patching results {}\".format(info),\n",
    "        labels={\"x\":\"Layer\", \"y\":\"Head\", \"color\": \"Logit diff. variation\"},\n",
    "        coloraxis=dict(colorbar_ticksuffix = \"%\"),\n",
    "        width=500,\n",
    "        height=400,\n",
    "    )\n",
    "\n",
    "def visualize_heads(heads, model, data, mode=\"contrast\"):\n",
    "    data = data[:4]\n",
    "    t.cuda.empty_cache()\n",
    "    normal_input, contrast_input, _, _, normal_cache, contrast_cache, _, _ = prepare_data_for_fwd(model, data)\n",
    "    \n",
    "    if mode == \"contrast\":\n",
    "        cache = contrast_cache\n",
    "        tokens = model.to_tokens(contrast_input)[0]\n",
    "    elif mode == \"normal\":\n",
    "        cache = normal_cache\n",
    "        tokens = model.to_tokens(normal_input)[0]\n",
    "\n",
    "    attn_patterns_for_important_heads: Float[Tensor, \"head q k\"] = t.stack([\n",
    "        cache[\"pattern\", layer][:, head].mean(0)\n",
    "            for layer, head in heads\n",
    "    ])\n",
    "    \n",
    "    display(cv.attention.attention_patterns(\n",
    "        attention = attn_patterns_for_important_heads,\n",
    "        tokens = model.to_str_tokens(tokens),\n",
    "        attention_head_names = [f\"{layer}.{head}\" for layer, head in heads],\n",
    "    ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.91it/s]\n",
      "WARNING:root:You are not using LayerNorm, so the writing weights can't be centered! Skipping\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model mistralai/Mistral-7B-v0.1 into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "device = \"cuda:2\"\n",
    "model_name = \"mistralai/Mistral-7B-v0.1\"\n",
    "model = HookedTransformer.from_pretrained(model_name, device=device)\n",
    "model.set_ungroup_grouped_query_attention(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data\n",
    "from data_utils import process_dataset, read_jsonl\n",
    "\n",
    "setting, nmax, offset, n_icl_examples = \"setting1\", 9, 1, 4\n",
    "filename = f\"../data/addition/{setting}/addition_nmax{nmax}_offset{offset}.jsonl\"\n",
    "data = read_jsonl(filename)\n",
    "processed_data = process_dataset(data, n_icl_examples=n_icl_examples, offset=offset)\n",
    "\n",
    "# load 4 examples for demo\n",
    "processed_data = processed_data[:4]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Path Patching Demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "batch: 100%|██████████| 1/1 [05:15<00:00, 315.66s/it]\n"
     ]
    },
    {
     "data": {
      "application/vnd.plotly.v1+json": {
       "config": {
        "plotlyServerURL": "https://plot.ly",
        "staticPlot": false
       },
       "data": [
        {
         "coloraxis": "coloraxis",
         "hovertemplate": "Layer: %{x}<br>Head: %{y}<br>Logit diff. variation: %{z}<extra></extra>",
         "name": "0",
         "type": "heatmap",
         "xaxis": "x",
         "yaxis": "y",
         "z": {
          "bdata": "18UguTtjh7dnPUC8KvKjug/a/TnDKHG7O2OHuJmSb7vu1Fi7NJjFvAHzyrwOi1O7R/AcvA+huTsUs9i8gknSvKlht72HgtW9+7eDPUEBXDpM90g7DovTuqaewrlw2Q85WjHtuyR3rbqsUn07j4+Kux73DT6oXTC+RC2oPH+gsj87Ywc5HeZQu6aeQrkP2v0618WgueUKODoL3kS8WSZ6O0ZDjjvfj0G71AdVPbpe0DoSWTu8Z2yyPOJHwz2ZY/09+jekPnpakL5RmIm/E28hPEjv+7xWngs7T0ZzOixwwDwLPCm7Es0FO0EB3LgGzKU8ZfkIO0mSuLtCSGi/BmUZwA/a/TjN5Rm79p75udkUyzhaPOC5cii6uZUtX7tUH848pIhcu9ngLz3wACW6gNYovHcBlTyn3VC9foSZPf/iLj3D1iA9zv0zP6p+6j6wibw7iuWhPJUt3zqsUv27O2MHOmcyTTo+sjE7WjxgOpe3D7z2B9G8S1WtPCFVJsDcVBZAQQFcudkUy7jdQJe7wNnGOtkUy7jzT0+6r0WFOxZg57tw2Q+5Vp4LPUDr9TvTmdQ7AqmIPeF2tTz4UVK+bK3DuoIlU7zt57m+ZR4qPn6R0LznKxG8Y9ivvI6y6DkHEN262RTLt+T0UbteaKw6pp5Cu+K7DbvSYBC7+8ycwLb8JD5yKLo4O2OHucDZRrpyKLo6xD5XO5Lpp7tHcd87iRSUuT2nPjzQqA69eeiIvAXBMj0iKIO+lQzFPTftub3nHbm9878SvkU4m7vC7yy7JHctu3Nh/jteaCy7von7PDk3uzuuaGM8pIhcu1CvSr24Gpm7KNHKuzPCDr2sC/PBftKBPgs8qbjwAKW5dXdkucDZRjqQMSa8ZvjnvNFKKju5vDQ8IDV7u+uFrjroWeI6Pae+u6jt7DgYsva9geD6vPphbrwJWl498aSDvsBCHjyTIw295W1FPZDOGD24Gpk7ZOMiOpD3QL3/Rby6Ali7PucrkTz/rpO8jdcIPVHJNz59R8e/O2OHNxmBQLp1d2S6yrlNO8mACbpbGYI7TPfIO/AApbq4Gpk7sOcgPJ58Bz2rsQI9Q4sMvWzcNT4yj/a+0eFSPGbVibztPTC9qtRgvYSkkLxuXbe9pIvBvt1AF71GQw47cb2evvvgq7sMUg+9vYqcuvhagT0Oi1M7McSyPgN6Fj4P2v2418UgOsij5zqVCgG6sIm8OzFJm7xkEXS8Y9ivu2QR9Ds6qmQ93GN1OmuXXTyux2m+8W6lvfVlNbudOhS+tw+mPLrCDj/pYfA9T1Hmu+Bgz7z8H7q+spQvOv9FPLqS3rQ75aFgPPnKxbpgt1Y7ugyAvOcgHjxJfIE/Zb+jvAs8qTjZFMu4QQFcuuhZYjqOsui6qxm5OjtjB7fS97g8uvaZPCBwAbzyosC8vp2dvdna5byaqFW6okSlvIwAMT33Qwm/J8ZXu4Q7uTzQAgDAN/ODOwGQPb1yKDq5Vbb2O7AK/zweZQ68rjqSPMgMP7wLavo7ylD2POoG68CaAHC+Sqgeur2teju141k7dVQGuzBtGjraiBW84rsNuljttboBleY72dsGPQJyiDpw4Z29JHetOoYXOj3+0pI8Sc5hPjfzAzwfzUQ9yziLvLWLPz5AnIo9aFFjPqRai7wcrYw6ZmcJPeuz/ztK15C91rRjvpqoVTpA9ug7tw+mPe93pT8LPKm4O2MHN9kUyzdWnou5WO21OboMAD2ZbxE76difvDJUjrwoAL08xKcuvBs/DD2jT5i6kdPBO0ccKr4V5/O9H2q3vYmjrj0BivO8K1pavztjB7draYw78jlpO6NPmLr+oyA97+o+OxQcsLtrolC897UAPVaei7qAz8LAVQmZwDtjB7c+sjG5Vp4LunDZD7oUHDC7ZO6VO4pC5TuZbxG7qmwqPLPY5jslsHG7DMAPvSBA7jvuEaM9EEh+vSOd8D37uwq+yN3MPMXsBr32pEM+ndQhu7gaGbwiJZ49IigDOnwIwTuJoEk8NF7gPJOGGr2mnsK6PJHYO5oAcD3FPt+/CzypuFaeCzpyKLq4xD7Xu/WfmjwkbLo7xo4iPdRrg7uiOTI7349BPcVUvTt+gTQ9ha+DvT0QFrxn1Gi8dVNlPTUA/Lym2Ke8fNnOPI6y6LttvgA9evLauq0vnzvtm5S74rsNuhc9CTz74Cu734TOuwWSwL2EmR087630vXESVDw7Y4e4qO1suQs8KTgJmg28sIk8vMDZRjpyKLo5KonMPOK7DbsGnbO8F8N0vRB9ujzTwx09BOraPYvBIryC94G8pfymvGq+wT0kdy08ImFHu53JrruOdr89Zz1APPkoKjx/V+u7JerWvOGlJzsOi1M6R7PSPacSDby1v9o8rLBhPKNPmDjZFMu3D9p9uRytjLoHEN27l1krumbViTw7Y4c5jO9zPKyNAz2XTji9lTn0vrb/CT0nLy88UjjavI+4sr1odeI+VIglu3V35DmB4Rs83dhhvgt1bbvlCri7DFIPPBZV9LqgGNm7n7r0vEwmOz1jnko9IUvhuuwtlD2crX69AAAAgEEB3LgP2v04IUvhOqaeQrlnMs26PZzLOxHlcDxPRnO6EDhivHDZDzxK1m88QMgXuzvvvLysfEY9PDSVvJmd4rvPaF+9XFJGu4L3gbyWqTc+DFIPuwrI3ryf6gc7iDfyOqo60z78GfA7Dv8dvOoMOz00XuC8qkMCvqwkLD2EmZ26PrIxOWwWG7zwAKU6zynRuya8hbw+srE5O2OHN2C31jriUja8KvKju4KDNz3btuY8RT1EvRFT8T2h0h2+uWPZvilFFTzMCPi7qTEjvooTcz02IVU7FjKWugs8Kbkw+U88YKxjPOwjQz6jcvY664Wuuzfzg7stzjQ/va36uvAApTnZFMu4D9r9uPAApbkt5Io82RTLN4xjvrptuDa8EsKSOzvvPLzUpMc7awA1vNUYEj0ZgwM+aK0DPn+Hf76+Xo++6TaEugksDb28f6m7jrCkPcLejr5w2Y85g7x7ujRpU7xrC6i8kWi2v8LvrLsp3L27fimau9SaAcDg93c9D9p9uqNPGLlw2Q85UVxZu8ijZ7udya47PrIxurUosjyVCoG6Kdw9u83lGbu5JYy7QMgXPEXPQ7xO00m+lj6cPvGXzbz5ykW7OoeGvQY6Jr1xElQ72yqxuyEdkLsXA6S8UC6IOyKC4T01Umu/eVA/PBc9ibzN5Zm8RqWHwGb4Zz1ogXc6PrIxOVsZAjvzT0+664WuuudDfDtKy/y6FlV0vAt17bviu406pFqLu4RtkL0p3L089CYnvQJsPr0SyiC+RC2ounVOPD7ji/o8m1VkvCxwwLyYe4Y+ZsqWPCIog7mP4do8yrnNO1waJD6e35Q77tRYO+pvyDt6Ic28o0+YuajtbLhBAdy4cii6OOhZYrr74Cs793ubOg/afTtyHcc7tw8mPIofh7vx0TI9lJY2PEX+tTwWMpY78PWxPawkLL0wCo09jyPOveVzD7xYhN69g/E3vaPmwDwuV7Q8iRSUOtj+ZDt8zts993ubOmcyTbzyOek7jGO+PCLKHjwkrFo/O2MHtzbokLqqpe68eKOwPHIourp0PiC7g7x7uxzQ6jqXwgK9u9IaPHygCr09BSM8C3Xtu6aeQjmg7zC9QNMKvfd7mzpgSdY9IUvhunTbEr2+gYy+g7z7PDlCrrsiYUe7IigDunaCVzy5Wae9s0E+vBh8l73MoME89DIsP6qmDz69ipy6O2MHOJLeNDtELSg7uSUMu8vPMzrGX7A7++Aru0baNrxsrcO6mWP9vBQUoj0fx3o9r9wtvYO8e7qFRiw8GmMLvaxSfbzgmrQ8hbQsPSmPBj/JgAm8lQqBuoz6ZryOsug5S+zVvEjMnT2ylC87geD6PKNy9rpQIxW8mOahvqjt7Lh1d+S5ZBxnuzk3u7pk4yI60UqqOsmAibqd1KE67+o+u4vBIjxxEtQ7MoOAu96ul70+4aM8Tg0vvFkDnDv/Rby72KGhPUEMz7tF/jW9A6DZvfAApbm149k63GN1uWuXXbs0ng+9rI2DPBBDVTyaqFU7gDNsPSZaGj7Io2c7g7z7unDZDzkP2v24l1krutfFIDp+KZq7ZxQYPYTS4buX5WA97ZuUO+UKOLu85Zu9ad/bvQUpab2IZmQ9/PaRuo6ndbwWL7G9XFLGvCnih71Jkji86TYEu/AApTzasT29lkNFO41usTt5uZa7CY8avJu+OzyEmZ26i9fpPov1vT35ysU63GN1OoO8ezqvRYW72dsGOz6yMTmEmZ06IUthO8mAiTpY7TU7bwfhPNL3ODxEW3k8+Hr6vWWPD77vKT2/3HArvqxN1D6+Zh2+IsoePOn7fTx4ozA7DotTO9oliD3iu427vmO4vZdZq7oSzQW7cRJUu5U40jsgRRY+KVCIvHDZDznDKPG5Vp6LOXG0b7wOi9M6LkFOOxZV9DqG6Ee7hUYsvYkUlLkUEb07MRUAPVVkpj1uxo49jhIQPn3mhL78GXC7mz+dPvVlNTvI0tk8KGhzPGh3Jb1M98g6gGLeuwtq+ruvfkk70DREuyIogztBAVy5gNYoPMlOsj0tSdy918WgOTULb7rZFEu4qO3suog38rrZFMu364UuO6JEJbzSjmE9djCHvOPp3jsTnhM9hXUePfD1sTxjP0Q+l7FFPXyfaTxfkpW+3XlbuzqHhrxHfFK8yKPnunz9zbuo+YC8+b/SO7tAG71AyJc7TPdIuoKOKjxrolA8MnU4vyAHqjzZFMu4jrLouajtbDl+Hie6pp7COYbox7r3e5s6VcHpun1M+Dugl5a8+30ePYO8e7ylzbS8dVSGu8fH5r3t2iI9lTjSOzgC/70Q46w9Zw7OvNVVPT/oWeI6Tg2vO9JVnTv74Cu7j48KvAs8qbo6dkm96TYEPDULb7r/og4/EVQSP6NPmLhPRnO6IigDuhmBwDohHZC7DovTug6L0zuMY746ZOMiOrCJvLsBx729SPVFPYrlobyd1KG6A6vMPLqdXr2M1CM+QQFcOun2VD2rGTm8d8VsvuuFrroCA2e9rCSsvNj+ZLvdS4q7jyYzvDpNobtX4sI7WO21OQP4kj6/cZC8o08YuQXBsro7Ywc6dVSGOtkUS7j9JGM8/PYRu4boxzsMUg87MrFRvTQwD7x2jco8ul7QOiBv4LyMaQg++Fb7vendSL3pNgS8CFU1PWNBh7xsRY08PZzLu9eRhb0sCIq8hJmduifGV7u9w+C73GP1OmuiUD0iYce7JvAgvdMdk78kdy26J8ZXuuuFrjriu426QQHcOBHl8LvcY3U734/BOhCsLDypyo48lCg2PbDm/7xF/jU9DSJ8PEczEr9Cdaa8DIGBvA6ckL5Uudu9cNkPuX/inT7MCHi725OIvILsDj2XWSu6UC6IO+oRZDy9w2A8oZb1vIKDN7ynrl69IX98PQ==",
          "dtype": "f4",
          "shape": "32, 32"
         }
        }
       ],
       "layout": {
        "coloraxis": {
         "cmid": 0,
         "colorbar": {
          "ticksuffix": "%",
          "title": {
           "text": "Logit diff. variation"
          }
         },
         "colorscale": [
          [
           0,
           "rgb(103,0,31)"
          ],
          [
           0.1,
           "rgb(178,24,43)"
          ],
          [
           0.2,
           "rgb(214,96,77)"
          ],
          [
           0.3,
           "rgb(244,165,130)"
          ],
          [
           0.4,
           "rgb(253,219,199)"
          ],
          [
           0.5,
           "rgb(247,247,247)"
          ],
          [
           0.6,
           "rgb(209,229,240)"
          ],
          [
           0.7,
           "rgb(146,197,222)"
          ],
          [
           0.8,
           "rgb(67,147,195)"
          ],
          [
           0.9,
           "rgb(33,102,172)"
          ],
          [
           1,
           "rgb(5,48,97)"
          ]
         ]
        },
        "height": 400,
        "template": {
         "data": {
          "bar": [
           {
            "error_x": {
             "color": "#2a3f5f"
            },
            "error_y": {
             "color": "#2a3f5f"
            },
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "bar"
           }
          ],
          "barpolar": [
           {
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "barpolar"
           }
          ],
          "carpet": [
           {
            "aaxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "baxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "type": "carpet"
           }
          ],
          "choropleth": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "choropleth"
           }
          ],
          "contour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "contour"
           }
          ],
          "contourcarpet": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "contourcarpet"
           }
          ],
          "heatmap": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "heatmap"
           }
          ],
          "histogram": [
           {
            "marker": {
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "histogram"
           }
          ],
          "histogram2d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2d"
           }
          ],
          "histogram2dcontour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2dcontour"
           }
          ],
          "mesh3d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "mesh3d"
           }
          ],
          "parcoords": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "parcoords"
           }
          ],
          "pie": [
           {
            "automargin": true,
            "type": "pie"
           }
          ],
          "scatter": [
           {
            "fillpattern": {
             "fillmode": "overlay",
             "size": 10,
             "solidity": 0.2
            },
            "type": "scatter"
           }
          ],
          "scatter3d": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatter3d"
           }
          ],
          "scattercarpet": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattercarpet"
           }
          ],
          "scattergeo": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergeo"
           }
          ],
          "scattergl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergl"
           }
          ],
          "scattermap": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermap"
           }
          ],
          "scattermapbox": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermapbox"
           }
          ],
          "scatterpolar": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolar"
           }
          ],
          "scatterpolargl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolargl"
           }
          ],
          "scatterternary": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterternary"
           }
          ],
          "surface": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "surface"
           }
          ],
          "table": [
           {
            "cells": {
             "fill": {
              "color": "#EBF0F8"
             },
             "line": {
              "color": "white"
             }
            },
            "header": {
             "fill": {
              "color": "#C8D4E3"
             },
             "line": {
              "color": "white"
             }
            },
            "type": "table"
           }
          ]
         },
         "layout": {
          "annotationdefaults": {
           "arrowcolor": "#2a3f5f",
           "arrowhead": 0,
           "arrowwidth": 1
          },
          "autotypenumbers": "strict",
          "coloraxis": {
           "colorbar": {
            "outlinewidth": 0,
            "ticks": ""
           }
          },
          "colorscale": {
           "diverging": [
            [
             0,
             "#8e0152"
            ],
            [
             0.1,
             "#c51b7d"
            ],
            [
             0.2,
             "#de77ae"
            ],
            [
             0.3,
             "#f1b6da"
            ],
            [
             0.4,
             "#fde0ef"
            ],
            [
             0.5,
             "#f7f7f7"
            ],
            [
             0.6,
             "#e6f5d0"
            ],
            [
             0.7,
             "#b8e186"
            ],
            [
             0.8,
             "#7fbc41"
            ],
            [
             0.9,
             "#4d9221"
            ],
            [
             1,
             "#276419"
            ]
           ],
           "sequential": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ],
           "sequentialminus": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ]
          },
          "colorway": [
           "#636efa",
           "#EF553B",
           "#00cc96",
           "#ab63fa",
           "#FFA15A",
           "#19d3f3",
           "#FF6692",
           "#B6E880",
           "#FF97FF",
           "#FECB52"
          ],
          "font": {
           "color": "#2a3f5f"
          },
          "geo": {
           "bgcolor": "white",
           "lakecolor": "white",
           "landcolor": "#E5ECF6",
           "showlakes": true,
           "showland": true,
           "subunitcolor": "white"
          },
          "hoverlabel": {
           "align": "left"
          },
          "hovermode": "closest",
          "mapbox": {
           "style": "light"
          },
          "paper_bgcolor": "white",
          "plot_bgcolor": "#E5ECF6",
          "polar": {
           "angularaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "radialaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "scene": {
           "xaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "yaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "zaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           }
          },
          "shapedefaults": {
           "line": {
            "color": "#2a3f5f"
           }
          },
          "ternary": {
           "aaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "baxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "caxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "title": {
           "x": 0.05
          },
          "xaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          },
          "yaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          }
         }
        },
        "title": {
         "text": "Path patching results "
        },
        "width": 500,
        "xaxis": {
         "anchor": "y",
         "constrain": "domain",
         "domain": [
          0,
          1
         ],
         "scaleanchor": "y",
         "title": {
          "text": "Layer"
         }
        },
        "yaxis": {
         "anchor": "x",
         "autorange": "reversed",
         "constrain": "domain",
         "domain": [
          0,
          1
         ],
         "title": {
          "text": "Head"
         }
        }
       }
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "results_resid_post_31 = analyze_path_to_repr([31], \"resid_post\", model, processed_data, begin_layer=0, batch_size=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "batch:   0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "layer: 100%|██████████| 31/31 [05:04<00:00,  9.82s/it]\n",
      "batch: 100%|██████████| 1/1 [05:04<00:00, 304.66s/it]\n"
     ]
    },
    {
     "data": {
      "application/vnd.plotly.v1+json": {
       "config": {
        "plotlyServerURL": "https://plot.ly",
        "staticPlot": false
       },
       "data": [
        {
         "coloraxis": "coloraxis",
         "hovertemplate": "Layer: %{x}<br>Head: %{y}<br>Logit diff. variation: %{z}<extra></extra>",
         "name": "0",
         "type": "heatmap",
         "xaxis": "x",
         "yaxis": "y",
         "z": {
          "bdata": "qO1suEEB3DhODS873GN1OdxjdTk7Ywe6CzwpONfFILqXWau6QQFcOiASHbxQIxW7fAjBO2TulTtL7NW7SqieOsDZxrv5Mx08jGM+Oz6yMTnwAKW5lQoBur2KnDkq8iM7DovTOSR3rTlVwek6fh4nvFaeC7vSYBC7AAAAAAAAAACo7ey4/0U8OpUKAbqjT5g4CzypuHIoOrnA2UY6ZOMiusZfsLt1VIY6ImFHO9kUSzn4kQE7x2ojPOcgHjsDq0y7IiiDu/tJAzxQLog7MoMAu1aeizm9ipy5DotTOcvPM7xbGYI6T0ZzOj6yMblw2Q87wyhxOgGVZroAAAAAAAAAAPAApbmo7Ww5cii6ONxjdTkdt948o0+YuNxj9TqH87q74d5rO6sZOTxgfhI7D9p9OaNy9jrdbmi8uUhqPJwyBrwc0Gq66TYEPLpeUDyDvHu64rsNOjboELsXPYm72QnYO4SZHbsTBkq7JHctu7kljLtyKLo4QQHcOgAAAAAAAAAA2RTLt9kUS7gS+9Y8qO3sOFAjlTrXxSC5yrnNu9kUSzryFgu8wyjxOgs8KTm5JQw8PJHYu7KfortWnou74+neu6ji+TtKqJ47CSZDvKnKjjt4ozC764WuujtjBzizqpW7Vp4Lugs8qTlM90g6FmvaOztYlLsnxtc5AAAAAAAAAACjTxi5o0+YOHIoujnZ2wa7O2OHOK9FBboCcog6ZOOiumcyTbpkHGc7yYCJug/afbmgI8y6PrIxuQ7/HbwNXYI7o0+Yuuk2BLu0tYi7cii6OPNPzzrZFMu4o0+YORmBwDr/RTy6U6ykPHijsLqIN3I7YKHwu1YoDcAAAAAAAAAAADtjB7ijT5i4wNlGuobox7oLPCk4l1mruiFL4Tq9ipy6J8ZXunijsLoiKAM7evLaOjtjh7cS8GM8LDZbPKY1a7z2nvk6aV6ZOxYylrqEmR07o0+Yu1sOjzvLzzM6O2MHNysraLujTxi5mW8RvPaeebuo7Ww50EYTvwAAAAAAAAAAO2MHuKjtbDgyvEQ6O2MHuAs8qbhXSxo8O2OHOtxjdTmjT5i64ruNuvaeebqf6gc7mqhVO1dueDwutZg8n+oHuwJyiDr4kYE719ATPA/afbmbvjs7qmwqvDXdHTwLPKm5ZzLNPLYEs7sLPCk5ih+Hu+ADDLzEKBrBAAAAAAAAAABBAdy4O2MHOkEB3LhGQw67dXfkOb2KHLriu406SqgeOi+Q+LrisJo7k/Qau5T/jTtBAVw5WQOcuyIoA7o5N7u6u2nDO4TSYTs77zy8vYqcuXzZzry2bQo8O2MHuQs8qbnXxSC7IUvhOkd80jo1C2+6PIbluksTN8EAAAAAAAAAADtjh7hY7TW6l1krunDZDzkOi9O6CzwpudkUyzguQc46RWZsu7KULzpgT6C8AmeVO1AuCDxM90i6WjzgOZ4N5jtqjOq79p75uYWvA7uQPBm8WO21OT6yMTrrha66rmjjOhQRvTxBAVy58hYLPHijsDohHRA8vYocOwAAAAAAAAAAQQHcOBzb3buVCgG6cNkPOwAAAIDcY/W5OTc7O1JyvzqF3dQ7bwfhu/aeeTuEpBA8ncmuO84eXrqf6gc7L5B4OyR3LbrwXgk8MoMAO/AApTq5JQw8q3cdvEEB3DgLPCk4ss1zOz6yMTk7Ywc4qw7GOyFL4boCcgi8AAAAAAAAAABw2Q+5O2MHuKNPGDmjT5i58ACluQs8KTg7Ywe3xD5Xuy5XNDuJFJQ6jGO+OcDZRjo+sjG66c0sPI/ITjtelv27kdPBO8QQhrtgt1a7SSlhPAGVZjqvfsm7o08YOaaewrogEp07ZOMiOh3mULuylK+9SqgeuzTSqjsAAAAAAAAAAAs8KbiMY766cii6OdkUS7kP2n26o0+YuJqo1briuw26nt+UOztjh7k26JC7Vp4LOjlwfzs9BSO8XV25uws8qTg26JC6RkOOO1AjlTuaqNU7IiiDOdkUSzoG15i7/PaROvz2EToyvES6iDdyOxh0Cb5ELSg7SteQvAAAAAAAAAAAo0+YuAs8KbgP2v0480/PuR/8trppXhm7YKFwOy+QeLtyKLo6MG2auvAApblraQw8wNnGus4e3jq9w+C7PrKxO9uTCDww1dA86TaEOg/a/bnSYJC7wgWTOkOLjLxyKLo4IiiDud5Nzr2G6Ee7houEPKjt7LgCcgi7AAAAAAAAAAAAAACAAAAAgAs8KTlBAVy5cig6OTtjBzdw2Y+5jGO+OdxjdbnDKHE6j4+KOyWw8buo7Ww4+crFupdOuDtScr+6C2r6uxZrWjtPRnM6iRSUuqaeQjknxte56FliulAjlbomjRM7LkFOO380DbvRSio7wViEvGiB97sAAAAAAAAAADtjB7ijT5i4O2MHuB/8NjtQLog7wNnGOjtjhzrZFMs6T0ZzuuhZYjt0PiC7r0UFuiJhRzvUDZ88wzSFvLKfort0SZM7S7ORuw6LUzk7Ywe6va16O4SZnboN9Cq8jGM+OqnKjroTbyG8O2MHOe/qPrtUiKU8wgUTOwAAAAAAAAAAcii6uAs8KbijT5i4CzwpuEQtqDrXxSC5O2OHOIBt0TpyKDq5yYAJO8DZxrkHEN27cpwEvMLvLDvIo2c7tVekPEfwHLx1VIY680/POXIzrbugGNm7j48KO0d80ruMY745l1mruq46krtHfNK6G5emu1jttbrRSqq6AAAAAAAAAAAP2n25D9p9OTtjh7io7Ww5cNkPOWcyTboOi1M5Vp6LOfAApblyKLq6NuiQurlIarsG+na7EwbKOlkm+rs7Ywc3Dv8dPEEB3LimnsK6AnKIuyfGV7uJFJQ6yHWWO9xjdbkG+na7R4dFvH80DTvZFEu4ebkWu9+PQTsAAAAAAAAAAFaei7lyKLq4O2MHudkUy7hBAVy5j4+KOljttbl7lHY8O2MHOsVUvTpM98g6ZOOiugJyiDtKqJ47Lle0O4xjvrtogfc6ikLlO1sZgjpuwym7NuiQu6nKDjvZFEu43GN1OQs8KTkOOYM8eEVMvXIoOrmDvPs6OTe7ugAAAAAAAAAAPrIxOkEB3LhBAdw4goO3uyR3LbpBAdy4D9r9OF7RgzxWnou5PIblujyG5ToBlWY7BwXqO1ox7bs4LMg7cig6OtkUSztD7hm9qdWBuxmBQLtY7bU5JHetuT6yMbp53HS77ZuUuwt1bTthwsm8g7x7OyfG17nzT8+5AAAAAAAAAACo7Wy52RTLOAs8qTh1d+S5g7z7OnDZj7l1VIa6uwBsvNfFIDlKqB46DFIPu+K7DTqOsug6mycTPCIoAzz74Cs8j4+KOoKDtzs+sjG7hujHOnIouriF3dQ7FjIWu31M+LukWos7LkFOO27x+rtw2Y+5DotTugs8KbsAAAAAAAAAADtjB7e9ipy52RTLNxmBQLrJgAk6spQvOrKUr7qylC8618UgOg6L0zn89pE7N5WfPCR3LbrEScq7cNkPOzyG5TpgiYW7sVrKvDBtGjtl+Yi7ALoGvDtjh7eylC87CzypOXDZD7nD+f68O2OHuFaeCzzv6j47tLWIOwAAAAAAAAAAqO1suCIog7meDeY72RRLulo84LncY3U5dXfkuXm5lrt1VIY7wu8sO5lvkTtk46K7WxmCuoxjPrtBAVw6Ly4MPaI5srt1d+Q6PZxLO0d80rp1d+Q5PrIxOkrL/LqaqFU6CzypuUEBXLqsUn27FjKWu50xZb26XtC6AAAAAAAAAABw2Q862RRLOMJi1rw+srG5qO3sOZqoVbrDKPE5J8bXucmAibvwAKW7ZzJNu53JrrtWngs7DIBgPDKm3jvj6d68++ArO0FqMzxZYQC8iP6tO5/qBzupyo46P72kOw6L07mmnkK6OGjxvTBtGjpaPOA5/S/Wu72KnDkAAAAAAAAAADtjB7cLPKm4CzypuSR3LToLPCk4o0+YOFjttbpyKLo5AnKIuqeptTtM98i7JHetumLNPLsWMha7QQFcO5UtXzwWMha6XwpIveUKOLqiObI7l7cPPHV3ZLlyKLq5QQHcOvd7m7ppU6Y7va36uljtNTo1C287jGM+OwAAAAAAAAAAO2MHudkUy7c7Ywe3Vp6LuaNPGLmd1KG6MrzEu+hZYjvA2ca5qO1sOkDIF7togXe6jrLouu2mhzviuw08qO3suF3GEDz89pG7O2OHu2cyzTqOsug5xOGTvGTjorrW8/E7wgWTusmACTpKqB66r0UFugOrTLs7Ywc5AAAAAAAAAACjTxi5WjxgO0EB3Dhw2Y86IiiDOQGVZrrOHl462RRLuMMo8Tn76x68NQtvOmlemTpfrYQ8zeWZu7DyE7ymnsI6T0bzOg/a/TpyMy08r9wtvJLeNLuCjiq82RRLuNsfvjtnMk26o08YuuuFrjrW6P66VtdPu318CzwAAAAAAAAAAJUKATrcY3U5gD+AOwvkDr1aPOC5O2MHtz6yMbsOi1O5D6E5u54NZjtQLoi75Qo4u+uFrjoaxfe7Rto2PHV3ZDl/NA07bK3DOrS1iLtX4sK7qO3sOPKtMzw+sjG5Vp6LOjpNITwLPCk5KvKjOi+QeLuaqFU7T0ZzugAAAAAAAAAAQQFcOXDZj7mo7Wy5PrIxOVaeC7oZgcC6EvBjO4kUFDoiYUc7WFYNPKNPGLvtmxS793ubuhMGSjsQrCw8xb0UPAbMpTvu1Ng8iRSUuv7GfryS3rS6N4qsPNxjdTmG6Ee7D9r9uDUL77qjT5i4O2MHOSR3rTpHfNI6AAAAAAAAAADZFMu43GN1OdxjdbkOi9M5PrIxua9FhTpyKLo4jGM+uvNPzzmWQ8W72dsGu8mACbtHfNI6rl3wO/aeeTrwAKU6VcHpu00CPDxv/O06Czypu3KcBDzJgIm69p75OQXBsjrLzzM7ih+HOwOg2Ttl+Yi7ZfkIuy9ip7sAAAAAAAAAAMDZRjpyKDq52RTLN3IoOrnZFEu56FliOq46krvZFMu3fh4nu91Alzokd625Jo0TOzK8xDv4tN87IDX7u+uFrjuOsug7WO01Or2KnDl1d2S6ZzLNupLetDvyFou7mqhVOnVUhjr4Vns8T0ZzOqaeQro+sjG5evJaPAAAAAAAAAAAdXdkOY6y6DkiKIM5o0+YuGTjIjpk46K7QQHcOaNPGDrZFEu480/PutkUy7g8huW6HK0MuxQcMLx3mD082yoxO5Ut3ztaPGA6ZBznu2TjIrs7Y4c4qgNTO9sqsTuf6oc77tRYu6jt7Dprl1075Qo4Ouk2BLs9nMs7AAAAAAAAAADZFMs42RRLOdxjdTlBAVy5bK3DumTjIjpWngs6J8bXOXijsLo8huW6Es0Fu4xjvrmJFBS7V+LCO/D1sbskd606JbDxOyryo7rN5Zm72RRLOkmdKzxScr87+JEBu1jtNbriu406bsMpu3V3ZLkkdy06++ArvEO+hb4AAAAAAAAAAA==",
          "dtype": "f4",
          "shape": "32, 32"
         }
        }
       ],
       "layout": {
        "coloraxis": {
         "cmid": 0,
         "colorbar": {
          "ticksuffix": "%",
          "title": {
           "text": "Logit diff. variation"
          }
         },
         "colorscale": [
          [
           0,
           "rgb(103,0,31)"
          ],
          [
           0.1,
           "rgb(178,24,43)"
          ],
          [
           0.2,
           "rgb(214,96,77)"
          ],
          [
           0.3,
           "rgb(244,165,130)"
          ],
          [
           0.4,
           "rgb(253,219,199)"
          ],
          [
           0.5,
           "rgb(247,247,247)"
          ],
          [
           0.6,
           "rgb(209,229,240)"
          ],
          [
           0.7,
           "rgb(146,197,222)"
          ],
          [
           0.8,
           "rgb(67,147,195)"
          ],
          [
           0.9,
           "rgb(33,102,172)"
          ],
          [
           1,
           "rgb(5,48,97)"
          ]
         ]
        },
        "height": 400,
        "template": {
         "data": {
          "bar": [
           {
            "error_x": {
             "color": "#2a3f5f"
            },
            "error_y": {
             "color": "#2a3f5f"
            },
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "bar"
           }
          ],
          "barpolar": [
           {
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "barpolar"
           }
          ],
          "carpet": [
           {
            "aaxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "baxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "type": "carpet"
           }
          ],
          "choropleth": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "choropleth"
           }
          ],
          "contour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "contour"
           }
          ],
          "contourcarpet": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "contourcarpet"
           }
          ],
          "heatmap": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "heatmap"
           }
          ],
          "histogram": [
           {
            "marker": {
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "histogram"
           }
          ],
          "histogram2d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2d"
           }
          ],
          "histogram2dcontour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2dcontour"
           }
          ],
          "mesh3d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "mesh3d"
           }
          ],
          "parcoords": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "parcoords"
           }
          ],
          "pie": [
           {
            "automargin": true,
            "type": "pie"
           }
          ],
          "scatter": [
           {
            "fillpattern": {
             "fillmode": "overlay",
             "size": 10,
             "solidity": 0.2
            },
            "type": "scatter"
           }
          ],
          "scatter3d": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatter3d"
           }
          ],
          "scattercarpet": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattercarpet"
           }
          ],
          "scattergeo": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergeo"
           }
          ],
          "scattergl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergl"
           }
          ],
          "scattermap": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermap"
           }
          ],
          "scattermapbox": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermapbox"
           }
          ],
          "scatterpolar": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolar"
           }
          ],
          "scatterpolargl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolargl"
           }
          ],
          "scatterternary": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterternary"
           }
          ],
          "surface": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "surface"
           }
          ],
          "table": [
           {
            "cells": {
             "fill": {
              "color": "#EBF0F8"
             },
             "line": {
              "color": "white"
             }
            },
            "header": {
             "fill": {
              "color": "#C8D4E3"
             },
             "line": {
              "color": "white"
             }
            },
            "type": "table"
           }
          ]
         },
         "layout": {
          "annotationdefaults": {
           "arrowcolor": "#2a3f5f",
           "arrowhead": 0,
           "arrowwidth": 1
          },
          "autotypenumbers": "strict",
          "coloraxis": {
           "colorbar": {
            "outlinewidth": 0,
            "ticks": ""
           }
          },
          "colorscale": {
           "diverging": [
            [
             0,
             "#8e0152"
            ],
            [
             0.1,
             "#c51b7d"
            ],
            [
             0.2,
             "#de77ae"
            ],
            [
             0.3,
             "#f1b6da"
            ],
            [
             0.4,
             "#fde0ef"
            ],
            [
             0.5,
             "#f7f7f7"
            ],
            [
             0.6,
             "#e6f5d0"
            ],
            [
             0.7,
             "#b8e186"
            ],
            [
             0.8,
             "#7fbc41"
            ],
            [
             0.9,
             "#4d9221"
            ],
            [
             1,
             "#276419"
            ]
           ],
           "sequential": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ],
           "sequentialminus": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ]
          },
          "colorway": [
           "#636efa",
           "#EF553B",
           "#00cc96",
           "#ab63fa",
           "#FFA15A",
           "#19d3f3",
           "#FF6692",
           "#B6E880",
           "#FF97FF",
           "#FECB52"
          ],
          "font": {
           "color": "#2a3f5f"
          },
          "geo": {
           "bgcolor": "white",
           "lakecolor": "white",
           "landcolor": "#E5ECF6",
           "showlakes": true,
           "showland": true,
           "subunitcolor": "white"
          },
          "hoverlabel": {
           "align": "left"
          },
          "hovermode": "closest",
          "mapbox": {
           "style": "light"
          },
          "paper_bgcolor": "white",
          "plot_bgcolor": "#E5ECF6",
          "polar": {
           "angularaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "radialaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "scene": {
           "xaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "yaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "zaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           }
          },
          "shapedefaults": {
           "line": {
            "color": "#2a3f5f"
           }
          },
          "ternary": {
           "aaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "baxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "caxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "title": {
           "x": 0.05
          },
          "xaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          },
          "yaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          }
         }
        },
        "title": {
         "text": "Path patching results "
        },
        "width": 500,
        "xaxis": {
         "anchor": "y",
         "constrain": "domain",
         "domain": [
          0,
          1
         ],
         "scaleanchor": "y",
         "title": {
          "text": "Layer"
         }
        },
        "yaxis": {
         "anchor": "x",
         "autorange": "reversed",
         "constrain": "domain",
         "domain": [
          0,
          1
         ],
         "title": {
          "text": "Head"
         }
        }
       }
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "results_h30_4_v = analyze_path_to_head([(30, 4)], \"v\", model, processed_data, begin_layer=0, batch_size=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fi4",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
