{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee35b20c-78c4-428b-8a50-8a866d44bfd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_lens import HookedTransformer\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import numpy as np\n",
    "from sklearn.cross_decomposition import CCA\n",
    "import torch\n",
    "import os\n",
    "import gc\n",
    "import yaml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1210bce-96c8-4d3a-a062-1670c2538b2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('global_config.yaml') as global_stream:\n",
    "    global_cfg = yaml.safe_load(global_stream)\n",
    "os.environ[\"HF_TOKEN\"] = global_cfg['hf_access_token']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29977fdd-85de-4f63-a630-089fd2e2a481",
   "metadata": {},
   "source": [
    "# Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9dcced7-7098-408f-9661-165b01f51411",
   "metadata": {},
   "outputs": [],
   "source": [
    "TRUNCATION_LENGTH = 512\n",
    "PADDING_TOKEN = 0\n",
    "CACHE_DIR = global_cfg['CACHE_DIR']\n",
    "modelA_name = 'gpt2-small'#'gemma-2-2b'#'pythia-70m-deduped'#\n",
    "modelB_name = 'gpt2-medium'#'pythia-160m-deduped'#'gemma-2-9b'#'pythia-160m-deduped'#\n",
    "tokenized_dataset = {}\n",
    "for dataset_key in ['train', 'test']:\n",
    "    tokenized_dataset[dataset_key] = torch.load(f'data/{modelA_name}_tokenized_dataset_200000_{dataset_key}_{TRUNCATION_LENGTH}.pt', weights_only=True)\n",
    "device = 'cuda'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de402d64-34cf-48e1-b05e-8b44517794e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "#model_name = 'gemma-2-9b'\n",
    "#tokenizer = AutoTokenizer.from_pretrained(f'google/{model_name}', trust_remote_code=True, cache_dir=CACHE_DIR, token=access_token)\n",
    "#hf_model = AutoModelForCausalLM.from_pretrained(f'google/{model_name}', cache_dir=CACHE_DIR, token=access_token, torch_dtype=torch.float16, device_map='cpu')\n",
    "#modelA = HookedTransformer.from_pretrained_no_processing(model_name=model_name, device='cuda', cache_dir=CACHE_DIR, torch_dtype=torch.float16)\n",
    "#del hf_model\n",
    "#del tokenizer\n",
    "#gc.collect()\n",
    "#torch.cuda.empty_cache()\n",
    "modelA = HookedTransformer.from_pretrained(modelA_name, cache_dir=CACHE_DIR, device=device)\n",
    "modelB = HookedTransformer.from_pretrained(modelB_name, cache_dir=CACHE_DIR, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bbaa366-4433-468b-8367-76087cad6f4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#modelA = modelA.to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "facfdc25-2921-4512-a6a4-ab38d7d1aa20",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9477be0-2851-41b1-9331-838a412d0f1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.inference_mode()\n",
    "def get_activations(dataloader, model, layer):\n",
    "    all_acts = []\n",
    "    for i, data in enumerate(dataloader):\n",
    "        mask = data == model.tokenizer.bos_token_id #(data != model.tokenizer.pad_token_id) & (data != model.tokenizer.eos_token_id) & (data != model.tokenizer.bos_token_id) & (data != model.tokenizer.unk_token_id)\n",
    "        acts = model(data, stop_at_layer=layer)\n",
    "        all_acts.append(acts[~mask].cpu())\n",
    "    return torch.cat(all_acts)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddc60641-0e0e-406f-b057-411b2b96ff37",
   "metadata": {},
   "source": [
    "# Cache Activations\n",
    "You can run this one model at a time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1eb8bae-b424-4004-a6f9-b5335aa5b6d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c30a444-f285-444f-b8bc-ca22182f5d4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.inference_mode()\n",
    "def cache_all_activations(model, dataloader):\n",
    "    os.makedirs(f'activation_store/{model.cfg.model_name}/', exist_ok=True)\n",
    "    print(f\"Created dir {f'activation_store/{model.cfg.model_name}/'}\")\n",
    "    stop_layers = list(range(1, model.cfg.n_layers+1))\n",
    "    for layer in tqdm(stop_layers):\n",
    "        torch.save(get_activations(dataloader, model, layer), f\"activation_store/{model.cfg.model_name}/{layer}_svcca_activations.pt\")\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "621bb9ab-b54c-42a8-a1f0-e312dec40908",
   "metadata": {},
   "outputs": [],
   "source": [
    "cache_all_activations(modelA, torch.utils.data.DataLoader(tokenized_dataset['train'][:100],batch_size=10,shuffle=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35028e49-1fb5-4b93-8059-2d5e7dc7f8df",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.listdir(f'activation_store/{modelA.cfg.model_name}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ba20c07-1fda-4d6d-9caf-d102ecd5e82d",
   "metadata": {},
   "outputs": [],
   "source": [
    "cache_all_activations(modelB, torch.utils.data.DataLoader(tokenized_dataset['train'][:100],batch_size=10,shuffle=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5099ba1c-041d-426f-975f-9e3bb4fb9f59",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.listdir(f'activation_store/{modelB_name}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5a990f3-949f-4b4a-8253-f4e3727ead02",
   "metadata": {},
   "source": [
    "# Compute SVCCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "552bd9c6-6a49-40c6-afbd-447d4d6c31cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from typing import Tuple\n",
    "\n",
    "def svcca_torch(\n",
    "    X: torch.Tensor,\n",
    "    Y: torch.Tensor,\n",
    "    explained_var: float = 0.99,\n",
    "    eps: float = 1e-10,\n",
    ") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"\n",
    "    SVCCA in pure PyTorch.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X, Y : (d, n) torch.Tensor\n",
    "        d = feature / neuron / residual‑stream size,\n",
    "        n = samples (e.g. batch × sequence positions).\n",
    "    explained_var : float\n",
    "        Fraction of variance to keep in the SVD compression step.\n",
    "    eps : float\n",
    "        Ridge term for numerical stability in covariance inversion.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    rho : (k,) torch.Tensor\n",
    "        Canonical correlations (descending).\n",
    "    Wx, Wy : torch.Tensor\n",
    "        CCA weight matrices mapping original spaces → canonical space.\n",
    "        Shapes (k, d_X) and (k, d_Y).\n",
    "    Xc, Yc : torch.Tensor\n",
    "        Compressed (SVD‑reduced) representations fed into CCA.\n",
    "    \"\"\"\n",
    "\n",
    "    device = X.device\n",
    "    assert X.shape[1] == Y.shape[1], \"X and Y must have the same #samples (columns)\"\n",
    "\n",
    "    # ---------- helper: SVD compression ----------\n",
    "    def _svd_compress(A: torch.Tensor, thresh: float):\n",
    "        A = A - A.mean(dim=1, keepdim=True)                       # centre\n",
    "        U, S, _ = torch.linalg.svd(A, full_matrices=False)\n",
    "        var = S ** 2\n",
    "        var = var / var.sum()\n",
    "        r = int((torch.cumsum(var, dim=0) < thresh).sum() + 1)    # minimal r\n",
    "        Ur = U[:, :r]                                             # (d, r)\n",
    "        Ar = Ur.T @ A                                             # (r, n)\n",
    "        return Ar, Ur\n",
    "\n",
    "    # ---------- compress each view ----------\n",
    "    Xc, Ux = _svd_compress(X, explained_var)\n",
    "    Yc, Uy = _svd_compress(Y, explained_var)\n",
    "\n",
    "    kx, n = Xc.shape\n",
    "    ky     = Yc.shape[0]\n",
    "\n",
    "    # ---------- covariances ----------\n",
    "    Cxx = (Xc @ Xc.T) / (n - 1) + eps * torch.eye(kx, device=device)\n",
    "    Cyy = (Yc @ Yc.T) / (n - 1) + eps * torch.eye(ky, device=device)\n",
    "    Cxy = (Xc @ Yc.T) / (n - 1)\n",
    "\n",
    "    # ---------- matrix inverse‑square‑root via eigendecomp ----------\n",
    "    def _inv_sqrtm(C):\n",
    "        eigval, eigvec = torch.linalg.eigh(C)           # SPD → real eigs\n",
    "        eigval = torch.clamp(eigval, min=eps)\n",
    "        return eigvec @ torch.diag(eigval.rsqrt()) @ eigvec.T\n",
    "\n",
    "    Cxx_inv_sqrt = _inv_sqrtm(Cxx)\n",
    "    Cyy_inv_sqrt = _inv_sqrtm(Cyy)\n",
    "\n",
    "    # ---------- CCA core ----------\n",
    "    T = Cxx_inv_sqrt @ Cxy @ Cyy_inv_sqrt\n",
    "    U, S, Vh = torch.linalg.svd(T, full_matrices=False)\n",
    "    rho = S                                   # canonical correlations\n",
    "\n",
    "    # ---------- back‑map CCA weights to original dims ----------\n",
    "    Wx = (U.T @ Cxx_inv_sqrt @ Ux.T)          # (k, d_X)\n",
    "    Wy = (Vh @ Cyy_inv_sqrt @ Uy.T)           # (k, d_Y)\n",
    "\n",
    "    return rho, Wx, Wy, Xc, Yc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f59263df-f89c-4631-b900-3b493167120b",
   "metadata": {},
   "outputs": [],
   "source": [
    "modelA_name = 'gemma-2-2b'\n",
    "modelB_name = 'gemma-2-9b'\n",
    "#modelA_name = 'gpt2'\n",
    "#modelB_name = 'gpt2-medium'\n",
    "model_names = 'gemma-2'\n",
    "acts_A = {}\n",
    "for filename in os.listdir(f'activation_store/{modelA_name}'):\n",
    "    acts_A[int(filename.split('_')[0])] = torch.load(os.path.join('activation_store/', modelA_name, filename), weights_only=True)\n",
    "\n",
    "acts_B = {}\n",
    "for filename in os.listdir(f'activation_store/{modelB_name}'):\n",
    "    acts_B[int(filename.split('_')[0])] = torch.load(os.path.join('activation_store/', modelB_name, filename), weights_only=True).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06bbf8c6-7098-46a2-83a7-4cfcec8197ad",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "corrs = {}\n",
    "for i, X in acts_A.items():\n",
    "    temp = {}\n",
    "    for j, Y in acts_B.items():\n",
    "        #mean_corr, cca_values = compute_svcca_similarity(X.cuda(), Y.cuda())\n",
    "        rho, _, _, _, _ = svcca_torch(X.cuda().T, Y.cuda().T)\n",
    "        mean_corr = rho.mean().item()\n",
    "        print(\"stop layer model A\", i, \"stop layer model B\", j, mean_corr)\n",
    "        temp[j] = mean_corr\n",
    "    corrs[i] = temp\n",
    "#print(f\"Mean SVCCA similarity: {mean_corr}\")\n",
    "#print(f\"Canonical correlations: {cca_values}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cb6cbe1-0fe8-4ee0-9bcf-433e2e785270",
   "metadata": {},
   "outputs": [],
   "source": [
    "# stop layer 13, stop layer 26"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f016c2d-df0c-4ddf-8242-82c5c1d3c127",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f737eafa-fae2-46f8-94d4-a87bbeacc761",
   "metadata": {},
   "outputs": [],
   "source": [
    "cmap = plt.cm.viridis(np.linspace(0, 1, len(corrs)))  # Get num_lines discrete colors from viridis\n",
    "for post_layer in np.arange(len(corrs)):\n",
    "    i = post_layer + 1\n",
    "    corr_dict = corrs[i]\n",
    "    layers = []\n",
    "    b_corrs = []\n",
    "    for j, value in corr_dict.items():\n",
    "        layers.append(j)\n",
    "        b_corrs.append(value)\n",
    "    layers = np.array(layers)\n",
    "    b_corrs = np.array(b_corrs)\n",
    "    idxs = layers.argsort()\n",
    "    layers = layers[idxs]\n",
    "    b_corrs = b_corrs[idxs]\n",
    "    plt.plot(layers, b_corrs, label=f'{modelA_name}.{i}.pre', color=cmap[i-1])\n",
    "plt.xlabel(f'pre-layer in {modelB_name}')\n",
    "plt.ylabel('svcca')\n",
    "plt.legend(bbox_to_anchor=(1, 1.05))\n",
    "plt.savefig(f'results/figures/{model_names}_svcca.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "facb72ff-8253-47f0-8b6e-0fe94eaee305",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(corrs, f\"results/{model_names}_svcca_results.pt\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
