{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fa3215a9",
   "metadata": {},
   "source": [
    "#### This notebook creates one dataframe per architecture, saved as `df_avg_sim_{arch}_epoch{epoch}.pkl`, holding its \"Attention-Head Stability\" results.\n",
    "#### The stability of a head in the anchor model is its cosine similarity (over lower-triangular attention patterns, averaged over prompts) with the best-matching head in the same layer of each other seed's model, averaged over those models.\n",
    "#### We create such dataframes for all 26 architectures.\n",
    "#### All 26 dataframes are then used in \"most_and_least_stable_layers.ipynb\" to create the plots for Section 4.6 (Most- and least-stable layers)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21bc05df",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import itertools\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import sys\n",
    "from collections import Counter, defaultdict\n",
    "from copy import deepcopy\n",
    "from dataclasses import dataclass\n",
    "from functools import partial\n",
    "from pathlib import Path\n",
    "from typing import Any, Callable, Literal, TypeAlias\n",
    "\n",
    "import einops\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch as t\n",
    "from datasets import load_dataset\n",
    "from IPython.display import clear_output, display\n",
    "from jaxtyping import Float, Int\n",
    "from rich import print as rprint\n",
    "from rich.table import Table\n",
    "\n",
    "from transformer_lens import HookedTransformer, HookedTransformerConfig\n",
    "from tabulate import tabulate\n",
    "from torch import Tensor, nn\n",
    "from torch.nn import functional as F\n",
    "from tqdm.auto import tqdm\n",
    "from transformer_lens import ActivationCache, loading_from_pretrained\n",
    "from transformer_lens.hook_points import HookPoint\n",
    "from transformer_lens.utils import get_act_name, to_numpy\n",
    "from transformer_lens import utils\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import torch\n",
    "\n",
    "from scipy.sparse import csr_array\n",
    "from scipy.sparse.csgraph import maximum_bipartite_matching, min_weight_full_bipartite_matching\n",
    "\n",
    "device = t.device(\"mps\" if t.backends.mps.is_available() else \"cuda\" if t.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aba7f609",
   "metadata": {},
   "source": [
    "## Prompt Sentences Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94a6f9b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "prompts_file = \"100_prompts\"\n",
    "# Resolved relative to the kernel's working directory: the pickle lives one level up.\n",
    "prompts_path = f\"../{prompts_file}.pkl\"\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load trusted, locally generated files.\n",
    "with open(prompts_path, \"rb\") as f:\n",
    "    prompts = pickle.load(f)\n",
    "\n",
    "# Report the path that was actually read.\n",
    "print(f\"Loaded {len(prompts)} prompts from {prompts_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0972cefe",
   "metadata": {},
   "source": [
    "## Model Configuration & State-Dict Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05ad4b0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "\n",
    "# ------------------- Specify Model Arch -------------------\n",
    "arch = \"l8_h8\"\n",
    "\n",
    "# ------------------- Load Model Config -------------------\n",
    "\n",
    "def load_named_config(module_name: str, config_name: str) -> dict:\n",
    "    \"\"\"\n",
    "    Look up one named config in a config module.\n",
    "\n",
    "    `module_name` must be importable and define a module-level `CONFIGS`\n",
    "    dict (config name -> config dict). A shallow copy of\n",
    "    `CONFIGS[config_name]` is returned so the caller may tweak it freely.\n",
    "\n",
    "    Raises:\n",
    "        ImportError: `module_name` could not be imported (the cause is chained).\n",
    "        AttributeError: the module has no `CONFIGS` attribute.\n",
    "        KeyError: `config_name` is missing; the message lists the available names.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        config_module = importlib.import_module(module_name)\n",
    "    except Exception as err:\n",
    "        raise ImportError(f\"Could not import config module '{module_name}': {err}\") from err\n",
    "\n",
    "    if not hasattr(config_module, \"CONFIGS\"):\n",
    "        raise AttributeError(f\"Module '{module_name}' does not define CONFIGS.\")\n",
    "    registry = config_module.CONFIGS\n",
    "\n",
    "    if config_name in registry:\n",
    "        return dict(registry[config_name])\n",
    "\n",
    "    available = \", \".join(sorted(registry.keys()))\n",
    "    raise KeyError(f\"Config '{config_name}' not found in {module_name}. Available: {available}\")\n",
    "\n",
    "\n",
    "cfg_dict = load_named_config(\"model_configs\", arch)\n",
    "\n",
    "# ------------------- Model Configuration -------------------\n",
    "\n",
    "# Map the raw config dict onto a HookedTransformerConfig; optional keys fall back to the defaults given here.\n",
    "cfg = HookedTransformerConfig(\n",
    "    n_layers=cfg_dict[\"n_layers\"],\n",
    "    d_model=cfg_dict[\"d_model\"],\n",
    "    n_heads=cfg_dict[\"n_heads\"],\n",
    "    d_head=cfg_dict[\"d_head\"],\n",
    "    d_mlp=cfg_dict.get(\"d_mlp\"),\n",
    "    n_ctx=cfg_dict[\"n_ctx\"],\n",
    "    act_fn=cfg_dict.get(\"act_fn\", \"gelu\"),\n",
    "    d_vocab=cfg_dict[\"d_vocab\"],\n",
    "    init_weights=True,\n",
    "    tokenizer_name=cfg_dict[\"tokenizer_name\"],\n",
    "    model_name=cfg_dict.get(\"model_name\", arch),\n",
    "    attn_only=cfg_dict.get(\"attn_only\", False),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6efde1a",
   "metadata": {},
   "source": [
    "### Run the next cell to process each architecture variant in `arch_list`, saving one pickle file per variant.\n",
    "### Skip it if the pickle files have already been created."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7dc1812",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setting device to CPU as GPU memory is insufficient for this computation, but for smaller number of prompts/models it can be set to GPU\n",
    "device = 'cpu'\n",
    "\n",
    "# Process the base architecture selected above and its weight-decay variant (\"_wd\") at a single epoch.\n",
    "# The loop below iterates with `arch_name` and never reassigns `arch`, so re-running this cell\n",
    "# does not compound the \"_wd\" suffix.\n",
    "arch_list = [f\"{arch}\", f\"{arch}_wd\"]\n",
    "\n",
    "SEEDS = list(range(1, 51))\n",
    "if \"gpt2\" in arch_list[0]:\n",
    "    SEEDS = list(range(1, 6))\n",
    "\n",
    "avg_sim_dir = \"df/avg_sim\"\n",
    "# Constants\n",
    "SCRATCH = \"Path to root directory\"  # TODO: set to the root directory that contains chkpts/ and df/\n",
    "chkpt_file = \"final.pt\"\n",
    "shard = 9\n",
    "epoch = 1\n",
    "\n",
    "# Anchor models are seed indices 0 .. anchor_range - 1. Default: only the first seed.\n",
    "anchor_range = 1\n",
    "# eps guarding the cosine-similarity normalisation against zero-norm vectors.\n",
    "EPS = 1e-08\n",
    "\n",
    "# Each anchor is compared with the *other* seeds, so at least two seeds and at most len(SEEDS) anchors are needed.\n",
    "assert len(SEEDS) >= 2 and len(SEEDS) >= anchor_range >= 1, \"need at least 2 seeds and anchor_range in [1, len(SEEDS)]\"\n",
    "\n",
    "\n",
    "def build_hooked_config(arch_name: str) -> HookedTransformerConfig:\n",
    "    \"\"\"Build the HookedTransformerConfig for `arch_name` from model_configs, placing the model on `device`.\"\"\"\n",
    "    cfg_dict = load_named_config(\"model_configs\", arch_name)\n",
    "    return HookedTransformerConfig(\n",
    "        n_layers=cfg_dict[\"n_layers\"],\n",
    "        d_model=cfg_dict[\"d_model\"],\n",
    "        n_heads=cfg_dict[\"n_heads\"],\n",
    "        d_head=cfg_dict[\"d_head\"],\n",
    "        d_mlp=cfg_dict.get(\"d_mlp\", None),\n",
    "        n_ctx=cfg_dict[\"n_ctx\"],\n",
    "        act_fn=cfg_dict.get(\"act_fn\", \"gelu\"),\n",
    "        d_vocab=cfg_dict[\"d_vocab\"],\n",
    "        init_weights=True,\n",
    "        tokenizer_name=cfg_dict[\"tokenizer_name\"],\n",
    "        model_name=cfg_dict.get(\"model_name\", arch_name),\n",
    "        attn_only=cfg_dict.get(\"attn_only\", False),\n",
    "        device=device,\n",
    "    )\n",
    "\n",
    "\n",
    "def checkpoint_path(arch_name: str, cfg: HookedTransformerConfig, seed: int) -> str:\n",
    "    \"\"\"Return the path of `chkpt_file` for one training seed of `arch_name`.\n",
    "\n",
    "    Layer/head counts and attn_only are read from this architecture's own config,\n",
    "    not from whichever architecture was configured earlier in the notebook.\n",
    "    \"\"\"\n",
    "    if arch_name in (\"gpt2\", \"gpt2_wd\"):\n",
    "        run_name = f\"gpt2_seed{seed}_shard{shard}_epoch{epoch}_owt\"\n",
    "    elif cfg.attn_only:\n",
    "        run_name = f\"causal_attn_only_l{cfg.n_layers}_h{cfg.n_heads}_seed{seed}_epoch{epoch}_c4_gelu\"\n",
    "    else:\n",
    "        run_name = f\"causal_attn_l{cfg.n_layers}_h{cfg.n_heads}_seed{seed}_epoch{epoch}_c4_gelu\"\n",
    "    # os.path.join works whether or not SCRATCH ends with a path separator.\n",
    "    return os.path.join(SCRATCH, \"chkpts\", arch_name, run_name, chkpt_file)\n",
    "\n",
    "\n",
    "def load_seed_models(arch_name: str, cfg: HookedTransformerConfig) -> list:\n",
    "    \"\"\"Create one HookedTransformer per seed in SEEDS and load its trained weights onto `device`.\"\"\"\n",
    "    models = []\n",
    "    for seed in SEEDS:\n",
    "        cfg.seed = seed\n",
    "        model = HookedTransformer(cfg)\n",
    "        # map_location lets checkpoints saved on a GPU load on a CPU-only machine.\n",
    "        state_dict = t.load(checkpoint_path(arch_name, cfg, seed), map_location=device)\n",
    "        model.load_and_process_state_dict(state_dict, fold_ln=False)\n",
    "        models.append(model)\n",
    "    return models\n",
    "\n",
    "\n",
    "def lower_triang(mat):\n",
    "    \"\"\"Lower-triangular entries (incl. diagonal) of the last two dims of `mat`, flattened row-major.\n",
    "\n",
    "    A single [q, k] matrix gives a 1-D vector; a batch [..., q, k] gives [..., n_tri].\n",
    "    \"\"\"\n",
    "    mask = torch.tril(torch.ones(mat.shape[-2:], device=mat.device)).bool()\n",
    "    return mat.flatten(-2)[..., mask.flatten()]\n",
    "\n",
    "\n",
    "def collect_patterns(models: list, prompt, n_layers: int) -> Tensor:\n",
    "    \"\"\"Run `prompt` through every model; return its attention patterns as [layer, model, head, q_pos, k_pos].\"\"\"\n",
    "    pattern_names = [utils.get_act_name(\"pattern\", layer, \"a\") for layer in range(n_layers)]\n",
    "    per_model = []\n",
    "    with torch.no_grad():\n",
    "        for model in models:\n",
    "            # Cache only the attention patterns instead of every activation.\n",
    "            _, cache = model.run_with_cache(prompt, remove_batch_dim=True, names_filter=pattern_names)\n",
    "            per_model.append(torch.stack([cache[name] for name in pattern_names]).to(device))\n",
    "    return torch.stack(per_model, dim=1)\n",
    "\n",
    "\n",
    "def head_similarity(patterns: Tensor, anchor: int) -> Tensor:\n",
    "    \"\"\"Cosine similarity of each anchor-model head with every head of every model, for one layer and prompt.\n",
    "\n",
    "    patterns: [model, head, q_pos, k_pos]; returns [head_anchor, model, head_pair].\n",
    "    Each head is represented by its lower-triangular pattern vector; for non-zero vectors this\n",
    "    matches nn.CosineSimilarity(eps=EPS) applied pair by pair.\n",
    "    \"\"\"\n",
    "    unit = F.normalize(lower_triang(patterns), dim=-1, eps=EPS)  # [model, head, n_tri]\n",
    "    return torch.einsum(\"an,mpn->amp\", unit[anchor], unit)\n",
    "\n",
    "\n",
    "for arch_name in arch_list:\n",
    "    print(f\"Processing arch: {arch_name}\")\n",
    "    arch_cfg = build_hooked_config(arch_name)\n",
    "    n_layers, n_heads = arch_cfg.n_layers, arch_cfg.n_heads\n",
    "\n",
    "    models = load_seed_models(arch_name, arch_cfg)\n",
    "    n_models, n_prompts = len(models), len(prompts)\n",
    "\n",
    "    print(\"Computing cosine similarity matrix...\")\n",
    "    # [anchor, layer, head_anchor, model_i, head_pair, prompt]\n",
    "    prompts_cs_matrix = t.empty((anchor_range, n_layers, n_heads, n_models, n_heads, n_prompts), device=device)\n",
    "    for ind_prompt, prompt in enumerate(tqdm(prompts, desc=arch_name)):\n",
    "        # Stream one prompt at a time: only this prompt's patterns are held in memory, not every model's full cache.\n",
    "        patterns = collect_patterns(models, prompt, n_layers)\n",
    "        for anchor in range(anchor_range):\n",
    "            for layer in range(n_layers):\n",
    "                prompts_cs_matrix[anchor, layer, ..., ind_prompt] = head_similarity(patterns[layer], anchor)\n",
    "        del patterns\n",
    "\n",
    "    # Free memory held by models\n",
    "    del models\n",
    "    gc.collect()\n",
    "\n",
    "    all_results = []\n",
    "    for anchor in range(anchor_range):\n",
    "        # Mean over prompts, then keep the best-matching paired head -> [layer, head_anchor, model_i].\n",
    "        best_match = prompts_cs_matrix[anchor].mean(dim=-1).max(dim=-1).values\n",
    "        # Drop the anchor model's comparison with itself, then average over the remaining models.\n",
    "        others = [m for m in range(n_models) if m != anchor]\n",
    "        stability = best_match[..., others].mean(dim=-1)  # [layer, head_anchor]\n",
    "\n",
    "        for layer_i in range(n_layers):\n",
    "            for head_anchor in range(n_heads):\n",
    "                all_results.append({\n",
    "                    \"arch\": arch_name,\n",
    "                    \"epoch\": epoch,\n",
    "                    \"anchor_idx\": anchor,\n",
    "                    \"layer\": layer_i + 1,\n",
    "                    \"head_anchor\": head_anchor + 1,\n",
    "                    \"cos_sim\": stability[layer_i, head_anchor].item(),\n",
    "                })\n",
    "\n",
    "    # One row per (anchor, layer, head_anchor); one pickle per architecture.\n",
    "    df_all_anchors = pd.DataFrame(all_results)\n",
    "    out_dir = os.path.join(SCRATCH, avg_sim_dir)\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    out_path = os.path.join(out_dir, f\"df_avg_sim_{arch_name}_epoch{epoch}.pkl\")\n",
    "    df_all_anchors.to_pickle(out_path)\n",
    "    print(f\"Saved df_all_anchors to {out_path}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.10.4",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
