{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import itertools\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import sys\n",
    "from collections import Counter, defaultdict\n",
    "from copy import deepcopy\n",
    "from dataclasses import dataclass\n",
    "from functools import partial\n",
    "from pathlib import Path\n",
    "from typing import Any, Callable, Literal, TypeAlias\n",
    "\n",
    "import einops\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch as t\n",
    "from datasets import load_dataset\n",
    "from IPython.display import clear_output, display\n",
    "from jaxtyping import Float, Int\n",
    "from rich import print as rprint\n",
    "from rich.table import Table\n",
    "\n",
    "from transformer_lens import HookedTransformer, HookedTransformerConfig\n",
    "from tabulate import tabulate\n",
    "from torch import Tensor, nn\n",
    "from torch.nn import functional as F\n",
    "from tqdm.auto import tqdm\n",
    "from transformer_lens import ActivationCache, loading_from_pretrained\n",
    "from transformer_lens.hook_points import HookPoint\n",
    "from transformer_lens.utils import get_act_name, to_numpy\n",
    "from transformer_lens import utils\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import torch\n",
    "\n",
    "from scipy.sparse import csr_array\n",
    "from scipy.sparse.csgraph import maximum_bipartite_matching, min_weight_full_bipartite_matching\n",
    "\n",
    "device = t.device(\"mps\" if t.backends.mps.is_available() else \"cuda\" if t.cuda.is_available() else \"cpu\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model & State_dict Loading "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "\n",
    "# ------------------- Load Model Config -------------------\n",
    "\n",
    "def load_named_config(module_name: str, config_name: str) -> dict:\n",
    "    \"\"\"\n",
    "    Import a module that defines CONFIGS: Dict[str, Dict[str, Any]]\n",
    "    and return CONFIGS[config_name].\n",
    "    \"\"\"\n",
    "    try:\n",
    "        mod = importlib.import_module(module_name)\n",
    "    except Exception as e:\n",
    "        raise ImportError(f\"Could not import config module '{module_name}': {e}\") from e\n",
    "\n",
    "    if not hasattr(mod, \"CONFIGS\"):\n",
    "        raise AttributeError(f\"Module '{module_name}' does not define CONFIGS.\")\n",
    "\n",
    "    CONFIGS = getattr(mod, \"CONFIGS\")\n",
    "    if config_name not in CONFIGS:\n",
    "        available = \", \".join(sorted(CONFIGS.keys()))\n",
    "        raise KeyError(f\"Config '{config_name}' not found in {module_name}. Available: {available}\")\n",
    "\n",
    "    return dict(CONFIGS[config_name])  # copy so we can tweak\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prompt Sentences Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 100 prompts from ./100_prompts.pkl\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "\n",
    "prompts_file = \"100_prompts\"\n",
    "\n",
    "with open(f\"./{prompts_file}.pkl\", \"rb\") as f:\n",
    "    prompts = pickle.load(f)\n",
    "\n",
    "print(f\"Loaded {len(prompts)} prompts from ./{prompts_file}.pkl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CKA Implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ------------------- CKA Implementation -------------------\n",
    "# Code adapted from CKA paper and also cited in our paper: https://github.com/google-research/google-research/blob/master/representation_similarity/cka.py\n",
    "\n",
    "def gram_linear(x):\n",
    "  \"\"\"Compute Gram (kernel) matrix for a linear kernel.\n",
    "\n",
    "  Args:\n",
    "    x: A num_examples x num_features matrix of features.\n",
    "\n",
    "  Returns:\n",
    "    A num_examples x num_examples Gram matrix of examples.\n",
    "  \"\"\"\n",
    "  return x.dot(x.T)\n",
    "\n",
    "\n",
    "def gram_rbf(x, threshold=1.0):\n",
    "  \"\"\"Compute Gram (kernel) matrix for an RBF kernel.\n",
    "\n",
    "  Args:\n",
    "    x: A num_examples x num_features matrix of features.\n",
    "    threshold: Fraction of median Euclidean distance to use as RBF kernel\n",
    "      bandwidth. (This is the heuristic we use in the paper. There are other\n",
    "      possible ways to set the bandwidth; we didn't try them.)\n",
    "\n",
    "  Returns:\n",
    "    A num_examples x num_examples Gram matrix of examples.\n",
    "  \"\"\"\n",
    "  dot_products = x.dot(x.T)\n",
    "  sq_norms = np.diag(dot_products)\n",
    "  sq_distances = -2 * dot_products + sq_norms[:, None] + sq_norms[None, :]\n",
    "  sq_median_distance = np.median(sq_distances)\n",
    "  return np.exp(-sq_distances / (2 * threshold ** 2 * sq_median_distance))\n",
    "\n",
    "\n",
    "def center_gram(gram, unbiased=False):\n",
    "  \"\"\"Center a symmetric Gram matrix.\n",
    "\n",
    "  This is equvialent to centering the (possibly infinite-dimensional) features\n",
    "  induced by the kernel before computing the Gram matrix.\n",
    "\n",
    "  Args:\n",
    "    gram: A num_examples x num_examples symmetric matrix.\n",
    "    unbiased: Whether to adjust the Gram matrix in order to compute an unbiased\n",
    "      estimate of HSIC. Note that this estimator may be negative.\n",
    "\n",
    "  Returns:\n",
    "    A symmetric matrix with centered columns and rows.\n",
    "  \"\"\"\n",
    "  if not np.allclose(gram, gram.T):\n",
    "    raise ValueError('Input must be a symmetric matrix.')\n",
    "  gram = gram.copy()\n",
    "\n",
    "  if unbiased:\n",
    "    # This formulation of the U-statistic, from Szekely, G. J., & Rizzo, M.\n",
    "    # L. (2014). Partial distance correlation with methods for dissimilarities.\n",
    "    # The Annals of Statistics, 42(6), 2382-2412, seems to be more numerically\n",
    "    # stable than the alternative from Song et al. (2007).\n",
    "    n = gram.shape[0]\n",
    "    np.fill_diagonal(gram, 0)\n",
    "    means = np.sum(gram, 0, dtype=np.float64) / (n - 2)\n",
    "    means -= np.sum(means) / (2 * (n - 1))\n",
    "    gram -= means[:, None]\n",
    "    gram -= means[None, :]\n",
    "    np.fill_diagonal(gram, 0)\n",
    "  else:\n",
    "    means = np.mean(gram, 0, dtype=np.float64)\n",
    "    means -= np.mean(means) / 2\n",
    "    gram -= means[:, None]\n",
    "    gram -= means[None, :]\n",
    "\n",
    "  return gram\n",
    "\n",
    "\n",
    "def cka(gram_x, gram_y, debiased=False):\n",
    "  \"\"\"Compute CKA.\n",
    "\n",
    "  Args:\n",
    "    gram_x: A num_examples x num_examples Gram matrix.\n",
    "    gram_y: A num_examples x num_examples Gram matrix.\n",
    "    debiased: Use unbiased estimator of HSIC. CKA may still be biased.\n",
    "\n",
    "  Returns:\n",
    "    The value of CKA between X and Y.\n",
    "  \"\"\"\n",
    "  gram_x = center_gram(gram_x, unbiased=debiased)\n",
    "  gram_y = center_gram(gram_y, unbiased=debiased)\n",
    "\n",
    "  # Note: To obtain HSIC, this should be divided by (n-1)**2 (biased variant) or\n",
    "  # n*(n-3) (unbiased variant), but this cancels for CKA.\n",
    "  scaled_hsic = gram_x.ravel().dot(gram_y.ravel())\n",
    "\n",
    "  normalization_x = np.linalg.norm(gram_x)\n",
    "  normalization_y = np.linalg.norm(gram_y)\n",
    "  return scaled_hsic / (normalization_x * normalization_y)\n",
    "\n",
    "\n",
    "def _debiased_dot_product_similarity_helper(\n",
    "    xty, sum_squared_rows_x, sum_squared_rows_y, squared_norm_x, squared_norm_y,\n",
    "    n):\n",
    "  \"\"\"Helper for computing debiased dot product similarity (i.e. linear HSIC).\"\"\"\n",
    "  # This formula can be derived by manipulating the unbiased estimator from\n",
    "  # Song et al. (2007).\n",
    "  return (\n",
    "      xty - n / (n - 2.) * sum_squared_rows_x.dot(sum_squared_rows_y)\n",
    "      + squared_norm_x * squared_norm_y / ((n - 1) * (n - 2)))\n",
    "\n",
    "\n",
    "def feature_space_linear_cka(features_x, features_y, debiased=False):\n",
    "  \"\"\"Compute CKA with a linear kernel, in feature space.\n",
    "\n",
    "  This is typically faster than computing the Gram matrix when there are fewer\n",
    "  features than examples.\n",
    "\n",
    "  Args:\n",
    "    features_x: A num_examples x num_features matrix of features.\n",
    "    features_y: A num_examples x num_features matrix of features.\n",
    "    debiased: Use unbiased estimator of dot product similarity. CKA may still be\n",
    "      biased. Note that this estimator may be negative.\n",
    "\n",
    "  Returns:\n",
    "    The value of CKA between X and Y.\n",
    "  \"\"\"\n",
    "  features_x = features_x - np.mean(features_x, 0, keepdims=True)\n",
    "  features_y = features_y - np.mean(features_y, 0, keepdims=True)\n",
    "\n",
    "  dot_product_similarity = np.linalg.norm(features_x.T.dot(features_y)) ** 2\n",
    "  normalization_x = np.linalg.norm(features_x.T.dot(features_x))\n",
    "  normalization_y = np.linalg.norm(features_y.T.dot(features_y))\n",
    "\n",
    "  if debiased:\n",
    "    n = features_x.shape[0]\n",
    "    # Equivalent to np.sum(features_x ** 2, 1) but avoids an intermediate array.\n",
    "    sum_squared_rows_x = np.einsum('ij,ij->i', features_x, features_x)\n",
    "    sum_squared_rows_y = np.einsum('ij,ij->i', features_y, features_y)\n",
    "    squared_norm_x = np.sum(sum_squared_rows_x)\n",
    "    squared_norm_y = np.sum(sum_squared_rows_y)\n",
    "\n",
    "    dot_product_similarity = _debiased_dot_product_similarity_helper(\n",
    "        dot_product_similarity, sum_squared_rows_x, sum_squared_rows_y,\n",
    "        squared_norm_x, squared_norm_y, n)\n",
    "    normalization_x = np.sqrt(_debiased_dot_product_similarity_helper(\n",
    "        normalization_x ** 2, sum_squared_rows_x, sum_squared_rows_x,\n",
    "        squared_norm_x, squared_norm_x, n))\n",
    "    normalization_y = np.sqrt(_debiased_dot_product_similarity_helper(\n",
    "        normalization_y ** 2, sum_squared_rows_y, sum_squared_rows_y,\n",
    "        squared_norm_y, squared_norm_y, n))\n",
    "\n",
    "  return dot_product_similarity / (normalization_x * normalization_y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stability of residual stream"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run for each arch but only for the single mentioned epoch (first entry of epoch_list).\n",
    "# Then plot all arch curves on the same plot.\n",
    "\n",
    "arch = \"gpt2\"\n",
    "arch_list = [f\"{arch}\", f\"{arch}_wd\"]\n",
    "print(arch_list)\n",
    "\n",
    "# pick only the mentioned epoch (first one). Change this if you want a different single epoch.\n",
    "epoch = 1\n",
    "\n",
    "# constants reused from your original cell\n",
    "SEEDS = [i for i in range(1, 51)]\n",
    "if arch == \"gpt2\":\n",
    "    SEEDS = [i for i in range(1, 6)]\n",
    "    shard = 9\n",
    "SCRATCH = \"Path to root directory\"\n",
    "chkpt_file = \"final.pt\"\n",
    "\n",
    "\n",
    "\n",
    "df = pd.DataFrame(columns=['arch', 'inst_1', 'inst_2', 'layer', \"cka_sim\"])\n",
    "\n",
    "for arch in arch_list:\n",
    "\n",
    "    # Models weights directory\n",
    "    chkpt_dir = SCRATCH + \"chkpts/\" + arch\n",
    "    print(f\"chkpt_dir: {chkpt_dir}/{arch}\")\n",
    "\n",
    "    print(f\"Processing arch {arch} (epoch {epoch}) ...\")\n",
    "\n",
    "    cfg_dict = load_named_config(\"model_configs\", arch)\n",
    "\n",
    "    # Build HookedTransformerConfig using the loaded config\n",
    "    cfg = HookedTransformerConfig(\n",
    "        n_layers=cfg_dict[\"n_layers\"],\n",
    "        d_model=cfg_dict[\"d_model\"],\n",
    "        n_heads=cfg_dict[\"n_heads\"],\n",
    "        d_head=cfg_dict[\"d_head\"],\n",
    "        d_mlp=cfg_dict.get(\"d_mlp\", None),\n",
    "        n_ctx=cfg_dict[\"n_ctx\"],\n",
    "        act_fn=cfg_dict.get(\"act_fn\", \"gelu\"),\n",
    "        d_vocab=cfg_dict[\"d_vocab\"],\n",
    "        init_weights=True,\n",
    "        tokenizer_name=cfg_dict[\"tokenizer_name\"],\n",
    "        model_name=cfg_dict.get(\"model_name\", arch),\n",
    "        attn_only=cfg_dict.get(\"attn_only\", False),\n",
    "    )\n",
    "\n",
    "    ATTN_ONLY = cfg.attn_only\n",
    "    NUM_LAYERS = cfg.n_layers\n",
    "    NUM_HEADS = cfg.n_heads\n",
    "\n",
    "    # Load models for this epoch\n",
    "    models = []\n",
    "    for SEED in SEEDS:\n",
    "        cfg.seed = SEED\n",
    "        cfg.init_weights = True\n",
    "        model = HookedTransformer(cfg)\n",
    "        models.append(model)\n",
    "\n",
    "    for ind, SEED in enumerate(SEEDS):\n",
    "        if (arch == \"gpt2\") or (arch == \"gpt2_wd\"):\n",
    "            model_state_dict = t.load(chkpt_dir + f\"/gpt2_seed{SEED}_shard{shard}_epoch{epoch}_owt/{chkpt_file}\")\n",
    "            models[ind].load_and_process_state_dict(model_state_dict, fold_ln=False)\n",
    "        else:\n",
    "            if ATTN_ONLY:\n",
    "                model_state_dict = t.load(\n",
    "                    chkpt_dir + f\"/causal_attn_only_l{NUM_LAYERS}_h{NUM_HEADS}_seed{SEED}_epoch{epoch}_c4_gelu/{chkpt_file}\"\n",
    "                )\n",
    "                models[ind].load_and_process_state_dict(model_state_dict, fold_ln=False)\n",
    "\n",
    "            else:\n",
    "                model_state_dict = t.load(\n",
    "                    chkpt_dir + f\"/causal_attn_l{NUM_LAYERS}_h{NUM_HEADS}_seed{SEED}_epoch{epoch}_c4_gelu/{chkpt_file}\"\n",
    "                )\n",
    "                models[ind].load_and_process_state_dict(model_state_dict, fold_ln=False)\n",
    "\n",
    "    # Setting device to CPU as GPU memory is insufficient for this computation, but for smaller number of prompts/models it can be set to GPU\n",
    "    device = 'cpu'\n",
    "\n",
    "    # run prompts to collect caches (using CPU to avoid CUDA OOM)\n",
    "    prompts_cache = []\n",
    "    for prompt in prompts:\n",
    "        cache_for_prompt = []\n",
    "        for ind in range(len(SEEDS)):\n",
    "            _, cache_i = models[ind].run_with_cache(prompt, remove_batch_dim=True)\n",
    "            # Keep cache on CPU\n",
    "            cache_i = cache_i.to('cpu')\n",
    "            cache_for_prompt.append(cache_i)\n",
    "        prompts_cache.append(cache_for_prompt)\n",
    "\n",
    "    # Constants\n",
    "    NUM_MODELS = len(models)\n",
    "    NUM_HEADS = models[0].cfg.n_heads\n",
    "    NUM_LAYERS = models[0].cfg.n_layers\n",
    "    NUM_PROMPTS = len(prompts_cache)\n",
    "\n",
    "\n",
    "    for inst_1 in range(NUM_MODELS):\n",
    "        print(f\"  Processing model instance {inst_1+1}\")\n",
    "        for inst_2 in range(NUM_MODELS):\n",
    "            for layer in range(NUM_LAYERS):\n",
    "                activations1 = []\n",
    "                activations2 = []\n",
    "\n",
    "                for cache in prompts_cache:\n",
    "                    activations1.append(cache[inst_1][utils.get_act_name('resid_pre', layer)].numpy())\n",
    "                    activations2.append(cache[inst_2][utils.get_act_name('resid_pre', layer)].numpy())\n",
    "\n",
    "                activations1 = np.vstack(activations1).astype(np.float64)\n",
    "                activations2 = np.vstack(activations2).astype(np.float64)\n",
    "\n",
    "                # compute gram matrices and enforce symmetry to avoid the ValueError\n",
    "                G1 = gram_rbf(activations1)\n",
    "                G2 = gram_rbf(activations2)\n",
    "                G1 = (G1 + G1.T) / 2.0\n",
    "                G2 = (G2 + G2.T) / 2.0\n",
    "\n",
    "                cka_sim = float(cka(G1, G2))\n",
    "\n",
    "                new_row = {'arch': arch, 'inst_1': inst_1+1, 'inst_2': inst_2+1, 'layer': layer,  'cka_sim': cka_sim}\n",
    "                df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>arch</th>\n",
       "      <th>inst_1</th>\n",
       "      <th>inst_2</th>\n",
       "      <th>layer</th>\n",
       "      <th>cka_sim</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>l2_h8</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>l2_h8</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>l2_h8</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0.938305</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>l2_h8</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0.953574</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>l2_h8</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>0.938515</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9995</th>\n",
       "      <td>l2_h8_wd</td>\n",
       "      <td>50</td>\n",
       "      <td>48</td>\n",
       "      <td>1</td>\n",
       "      <td>0.953675</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9996</th>\n",
       "      <td>l2_h8_wd</td>\n",
       "      <td>50</td>\n",
       "      <td>49</td>\n",
       "      <td>0</td>\n",
       "      <td>0.880930</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9997</th>\n",
       "      <td>l2_h8_wd</td>\n",
       "      <td>50</td>\n",
       "      <td>49</td>\n",
       "      <td>1</td>\n",
       "      <td>0.952866</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9998</th>\n",
       "      <td>l2_h8_wd</td>\n",
       "      <td>50</td>\n",
       "      <td>50</td>\n",
       "      <td>0</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9999</th>\n",
       "      <td>l2_h8_wd</td>\n",
       "      <td>50</td>\n",
       "      <td>50</td>\n",
       "      <td>1</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>10000 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "          arch inst_1 inst_2 layer   cka_sim\n",
       "0        l2_h8      1      1     0  1.000000\n",
       "1        l2_h8      1      1     1  1.000000\n",
       "2        l2_h8      1      2     0  0.938305\n",
       "3        l2_h8      1      2     1  0.953574\n",
       "4        l2_h8      1      3     0  0.938515\n",
       "...        ...    ...    ...   ...       ...\n",
       "9995  l2_h8_wd     50     48     1  0.953675\n",
       "9996  l2_h8_wd     50     49     0  0.880930\n",
       "9997  l2_h8_wd     50     49     1  0.952866\n",
       "9998  l2_h8_wd     50     50     0  1.000000\n",
       "9999  l2_h8_wd     50     50     1  1.000000\n",
       "\n",
       "[10000 rows x 5 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ticklabel_layers = [i for i in range(1, NUM_LAYERS+1)]\n",
    "\n",
    "# Combined plots: all architectures on same axes for the selected epoch\n",
    "plt.figure(figsize=(10,4))\n",
    "for arch in arch_list:\n",
    "    df_arch = df[df['arch'] == arch]\n",
    "    df_layer_sim = df_arch.groupby('layer')['cka_sim'].mean().reset_index()\n",
    "    plt.plot(ticklabel_layers, df_layer_sim['cka_sim'], marker='o', label='Adam' if arch == arch_list[0] else 'AdamW')\n",
    "plt.xlabel('Layer', fontsize=1)\n",
    "plt.ylabel('Stability (CKA)', fontsize=16)\n",
    "#plt.title(f'Stability comparison of residual stream for 8-layers 8-heads MLP architecture: Adam vs AdamW')\n",
    "plt.ylim(0.5,1)\n",
    "plt.legend()\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.10.4",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
