{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9952e040",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the saved mapper checkpoints. The loading cell appends the file name\n",
    "# with plain string concatenation, so include a trailing path separator.\n",
    "mapper_path = \"INSERT YOURS\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dad6c170",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Checkpoints mapper_lookahead_1.pt ... mapper_lookahead_5.pt, loaded in that order.\n",
    "# NOTE: torch.load unpickles the file, so only point this at checkpoints you trust.\n",
    "mapper_files = [mapper_path + f\"mapper_lookahead_{i}.pt\" for i in range(1, 6)]\n",
    "all_mappers = [torch.load(mapper_file) for mapper_file in mapper_files]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d6c1ab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import trange\n",
    "import torch\n",
    "\n",
    "def get_projector(Vh, r):\n",
    "    \"\"\"Return ``Vh[:r]^H @ Vh[:r]``, built from the first ``r`` rows of ``Vh``.\n",
    "\n",
    "    When the rows of ``Vh`` are orthonormal (e.g. the ``Vh`` factor of an SVD),\n",
    "    this is the orthogonal projector onto the span of those ``r`` rows.\n",
    "\n",
    "    Args:\n",
    "        Vh: Tensor whose leading dimension indexes the basis rows; for a 2-D\n",
    "            ``(k, d)`` input the result has shape ``(d, d)``.\n",
    "        r: Number of leading rows to keep.\n",
    "\n",
    "    Returns:\n",
    "        The product of the adjoint of the kept rows with the kept rows.\n",
    "    \"\"\"\n",
    "    # Slice instead of zeroing ``Vh[r:]`` in place: the product is the same matrix,\n",
    "    # but the caller's ``Vh`` is left untouched, so calls may come in any order.\n",
    "    top_rows = Vh[:r]\n",
    "    return top_rows.adjoint() @ top_rows\n",
    "\n",
    "def get_projections(Vh, E, r):\n",
    "    \"\"\"Left-multiply ``E`` by the matrix returned by ``get_projector(Vh, r)``.\"\"\"\n",
    "    projector = get_projector(Vh, r)\n",
    "    return projector @ E\n",
    "\n",
    "def get_cos(Vh, E, r):\n",
    "    \"\"\"Cosine similarity, along dim 0, between ``E`` and its projection.\n",
    "\n",
    "    For a 2-D ``E`` each column is compared with the matching column of\n",
    "    ``get_projections(Vh, E, r)``, giving one value per column.\n",
    "    \"\"\"\n",
    "    projected = get_projections(Vh, E, r)\n",
    "    return torch.nn.functional.cosine_similarity(E, projected, dim=0)\n",
    "\n",
    "def get_all_cos(W, E):\n",
    "    \"\"\"Cosine similarity of every column of ``E`` with its rank-``r`` projection.\n",
    "\n",
    "    The projection for rank ``r`` is onto the first ``r`` right singular vectors\n",
    "    of ``W``. Row ``r`` of the result holds the similarities for rank ``r``;\n",
    "    row 0 (the empty projection) is all zeros.\n",
    "\n",
    "    Args:\n",
    "        W: 2-D matrix of shape ``(m, d)`` to decompose.\n",
    "        E: 2-D matrix of shape ``(d, n)`` whose columns are the vectors to project.\n",
    "\n",
    "    Returns:\n",
    "        CPU tensor of shape ``(min(m, d) + 1, n)`` with dtype ``E.dtype``.\n",
    "    \"\"\"\n",
    "    # Only Vh is needed. Run the SVD on E's device and cast to E's dtype so the\n",
    "    # matmul below cannot fail on a device/dtype mismatch between W and E.\n",
    "    _, _, Vh = torch.linalg.svd(W.to(E.device), full_matrices=False)\n",
    "    Vh = Vh.to(E.dtype)\n",
    "    # Rows of Vh are orthonormal, so for the projector P_r onto the first r rows:\n",
    "    #   cos(e, P_r e) = |P_r e| / |e|   and   |P_r e|^2 = sum_{i < r} |(Vh @ e)_i|^2.\n",
    "    # One matmul plus a cumulative sum therefore covers every rank at once,\n",
    "    # instead of rebuilding a (d, d) projector and a (d, n) product per rank.\n",
    "    coeffs = Vh @ E\n",
    "    proj_norm = coeffs.abs().square().cumsum(dim=0).sqrt()\n",
    "    # Same eps as the default of torch's cosine_similarity; avoids 0 / 0 on zero columns.\n",
    "    e_norm = torch.linalg.vector_norm(E, dim=0).clamp_min(1e-8)\n",
    "    result = torch.zeros((Vh.size(0) + 1, E.size(1)), dtype=E.dtype, device=\"cpu\")\n",
    "    # Clamp away rounding overshoot so values stay within the cosine range.\n",
    "    result[1:] = (proj_norm / e_norm).clamp(max=1.0).cpu()\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5abddbee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the projection weight matrix from each loaded mapper checkpoint.\n",
    "mapper_matrices = [checkpoint[\"projection.weight\"] for checkpoint in all_mappers]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e3b8d99",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_lens import SAE\n",
    "\n",
    "# Use the GPU when one is present; otherwise fall back to CPU so the cell still runs\n",
    "# on hosts without CUDA instead of failing inside from_pretrained.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Load the pretrained SAE selected by the release / SAE id below.\n",
    "sae, cfg_dict, sparsity = SAE.from_pretrained(\n",
    "    release=\"gemma-scope-2b-pt-res-canonical\",  # <- Release name\n",
    "    sae_id=\"layer_15/width_16k/canonical\",  # <- SAE id (not always a hook point!)\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# sae.W_enc is the matrix passed as E to get_all_cos in the next cell.\n",
    "sae_matrix = sae.W_enc\n",
    "sae_matrix.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8540f6dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pure analysis: run without autograd tracking. One result per mapper matrix,\n",
    "# in the same order as mapper_matrices.\n",
    "with torch.no_grad():\n",
    "    cos_results = [get_all_cos(W, sae_matrix) for W in mapper_matrices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4dbb55a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the list of per-mapper cosine tables; replace the placeholder with an output path.\n",
    "torch.save(cos_results, \"INSERT YOURS\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
