{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b7x5VMJx-qtk"
   },
   "outputs": [],
   "source": [
    "# Authenticate with the Hugging Face Hub for this session.\n",
    "# NOTE(review): presumably required for gated model / SAE downloads later in the notebook - confirm.\n",
    "from huggingface_hub import hf_hub_download, notebook_login\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kUYRF57KNxrL"
   },
   "source": [
    "# setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yNSK0f1A4sHu"
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "from google.colab import files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ta39w_IC-sHo"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "# Colab dependency setup; all pip output is suppressed by %%capture.\n",
    "# NOTE(review): sae-lens is unpinned - pin a known-good version for reproducible runs.\n",
    "%pip install sae-lens\n",
    "%pip install datasets==3.6.0\n",
    "# NOTE(review): force-reinstalling numpy/pandas typically needs a runtime restart\n",
    "# before later imports pick up the new versions - confirm on Colab.\n",
    "%pip install --upgrade --force-reinstall numpy pandas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aI2j6zRBC2t8"
   },
   "outputs": [],
   "source": [
    "from sae_lens import SAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SzJHEAM2QUDS"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "from torch import nn, Tensor\n",
    "\n",
    "from functools import partial\n",
    "from jaxtyping import Float, Int\n",
    "from typing import Optional, Callable, Union, List, Tuple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ffi6_HJh_okl"
   },
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zOQZCTiMKqCe"
   },
   "outputs": [],
   "source": [
    "import gc\n",
    "import pickle\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import json\n",
    "from fnmatch import fnmatch\n",
    "from pathlib import Path\n",
    "from typing import NamedTuple, Optional, Callable, Union, List, Tuple\n",
    "# from jaxtyping import Float, Int\n",
    "from collections import Counter\n",
    "\n",
    "import einops\n",
    "import torch\n",
    "from torch import Tensor, nn\n",
    "from huggingface_hub import snapshot_download\n",
    "from natsort import natsorted\n",
    "from safetensors.torch import load_model, save_model\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"  # run on the GPU whenever one is visible\n",
    "else:\n",
    "    device = \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YOYUhykrUbI8"
   },
   "source": [
    "## Correlation helpers (feature matching across models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SaotL1OHcZ3_"
   },
   "outputs": [],
   "source": [
    "def batched_correlation(reshaped_activations_A, reshaped_activations_B, batch_size=100):\n",
    "    \"\"\"For every column (feature) of B, find the most-correlated column of A.\n",
    "\n",
    "    Correlation is Pearson's r over rows (samples), so the rows of A and B must be\n",
    "    aligned. Columns of B are scored `batch_size` at a time, so only a\n",
    "    (num_features_A, batch_size) correlation block exists at any moment.\n",
    "\n",
    "    Args:\n",
    "        reshaped_activations_A: 2-D tensor (num_samples, num_features_A).\n",
    "        reshaped_activations_B: 2-D tensor (num_samples, num_features_B).\n",
    "        batch_size: number of B columns per matmul.\n",
    "\n",
    "    Returns:\n",
    "        corr_inds: np.ndarray (num_features_B,), index of the best-matching A column.\n",
    "        corr_vals: np.ndarray (num_features_B,), the corresponding correlation.\n",
    "    \"\"\"\n",
    "    # Ensure tensors are on GPU when one is available\n",
    "    use_cuda = torch.cuda.is_available()\n",
    "    if use_cuda:\n",
    "        reshaped_activations_A = reshaped_activations_A.to('cuda')\n",
    "        reshaped_activations_B = reshaped_activations_B.to('cuda')\n",
    "\n",
    "    def _standardize(acts):\n",
    "        # Population std (divide by N) so that (A_norm^T @ B_norm) / N below is exactly\n",
    "        # Pearson's r; the unbiased std (N - 1) would shrink every value by (N - 1) / N.\n",
    "        centered = acts - acts.mean(dim=0, keepdim=True)\n",
    "        std = centered.pow(2).mean(dim=0, keepdim=True).sqrt()\n",
    "        return centered / (std + 1e-8)  # Avoid division by zero for constant (e.g. dead) features\n",
    "\n",
    "    normalized_A = _standardize(reshaped_activations_A)\n",
    "    normalized_B = _standardize(reshaped_activations_B)\n",
    "    num_samples = normalized_A.shape[0]\n",
    "    num_cols_B = normalized_B.shape[1]\n",
    "\n",
    "    max_values = []\n",
    "    max_indices = []\n",
    "    for start in range(0, num_cols_B, batch_size):\n",
    "        end = min(start + batch_size, num_cols_B)\n",
    "        batch_corr_matrix = torch.matmul(normalized_A.t(), normalized_B[:, start:end]) / num_samples\n",
    "        max_val, max_idx = batch_corr_matrix.max(dim=0)\n",
    "        max_values.append(max_val)\n",
    "        max_indices.append(max_idx)\n",
    "\n",
    "        del batch_corr_matrix\n",
    "        if use_cuda:  # nothing to release when running on CPU\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "    corr_inds = torch.cat(max_indices).detach().cpu().numpy()\n",
    "    corr_vals = torch.cat(max_values).detach().cpu().numpy()\n",
    "    return corr_inds, corr_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3DHUuR_RVCR1"
   },
   "outputs": [],
   "source": [
    "def filter_corr_pairs(mixed_modA_feats, mixed_modB_feats, kept_modA_feats):\n",
    "    \"\"\"Filter (A, B) feature-match pairs so that most A features occur only once.\n",
    "\n",
    "    Every pair whose A feature is in `kept_modA_feats` is kept; for any other A\n",
    "    feature only its first pair (in input order) survives.\n",
    "\n",
    "    Args:\n",
    "        mixed_modA_feats: model-A feature index per pair.\n",
    "        mixed_modB_feats: model-B feature index per pair, aligned with mixed_modA_feats.\n",
    "        kept_modA_feats: A indices exempt from de-duplication (any iterable of\n",
    "            hashable indices, e.g. a list, set or 1-D numpy array).\n",
    "\n",
    "    Returns:\n",
    "        (filt_corr_ind_A, filt_corr_ind_B, num_unq_pairs), where num_unq_pairs is\n",
    "        the number of distinct A indices left after filtering.\n",
    "    \"\"\"\n",
    "    kept = set(kept_modA_feats)  # O(1) membership instead of a linear scan per pair\n",
    "    filt_corr_ind_A = []\n",
    "    filt_corr_ind_B = []\n",
    "    seen = set()\n",
    "    for ind_A, ind_B in zip(mixed_modA_feats, mixed_modB_feats):\n",
    "        if ind_A not in kept:\n",
    "            if ind_A in seen:\n",
    "                continue  # drop repeat matches of a non-exempt A feature\n",
    "            seen.add(ind_A)\n",
    "        filt_corr_ind_A.append(ind_A)\n",
    "        filt_corr_ind_B.append(ind_B)\n",
    "\n",
    "    num_unq_pairs = len(set(filt_corr_ind_A))\n",
    "    # Guard the ratio so an empty result does not raise ZeroDivisionError.\n",
    "    frac_unique = num_unq_pairs / len(filt_corr_ind_A) if filt_corr_ind_A else float(\"nan\")\n",
    "    print(\"% unique: \", frac_unique)\n",
    "    print(\"num 1-1 feats after filt: \", num_unq_pairs)\n",
    "    return filt_corr_ind_A, filt_corr_ind_B, num_unq_pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "l1Oa-qJbT1tw"
   },
   "outputs": [],
   "source": [
    "def get_new_mean_corr(modA_feats, modB_feats, corr_vals):\n",
    "    \"\"\"Mean correlation over the first pair of each distinct model-A feature.\n",
    "\n",
    "    Args:\n",
    "        modA_feats: model-A feature index per pair.\n",
    "        modB_feats: model-B feature index per pair, aligned with modA_feats.\n",
    "        corr_vals: sequence/array indexed by model-B feature index.\n",
    "\n",
    "    Returns:\n",
    "        Mean of corr_vals[ind_B] over the first occurrence of each ind_A, or NaN\n",
    "        when there are no pairs (instead of raising ZeroDivisionError).\n",
    "    \"\"\"\n",
    "    seen = set()\n",
    "    new_vals = []\n",
    "    for ind_A, ind_B in zip(modA_feats, modB_feats):\n",
    "        if ind_A in seen:\n",
    "            continue\n",
    "        seen.add(ind_A)\n",
    "        new_vals.append(corr_vals[ind_B])\n",
    "    if not new_vals:\n",
    "        return float(\"nan\")\n",
    "    return sum(new_vals) / len(new_vals)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vlKdEehFvC86"
   },
   "source": [
    "## Representational similarity functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VyutL2SbeZtw"
   },
   "source": [
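    "The following functions are adapted from the `repsim` measures in [mklabunde/resi](https://github.com/mklabunde/resi/tree/main/repsim/measures).\n",
    "\n",
    "Note: several helpers (e.g. `Pipeline`, `to_numpy_if_needed`, `adjust_dimensionality`) are defined in more than one cell below; the definition from the most recently executed cell is the one in effect."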
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Rl7IYESN1irP"
   },
   "outputs": [],
   "source": [
    "import functools\n",
    "from typing import Any, Callable, Dict, List, Tuple, Union\n",
    "\n",
    "import numpy as np\n",
    "import numpy.typing as npt\n",
    "import torch\n",
    "\n",
    "\n",
    "def to_numpy_if_needed(*args: Union[torch.Tensor, npt.NDArray]) -> List[npt.NDArray]:\n",
    "    \"\"\"Return every argument as a numpy array; numpy inputs are passed through as-is.\n",
    "\n",
    "    Torch tensors are detached and moved to CPU first, because `Tensor.numpy()`\n",
    "    raises for CUDA tensors and for tensors that require grad.\n",
    "    \"\"\"\n",
    "    def convert(x: Union[torch.Tensor, npt.NDArray]) -> npt.NDArray:\n",
    "        return x if isinstance(x, np.ndarray) else x.detach().cpu().numpy()\n",
    "\n",
    "    return list(map(convert, args))\n",
    "\n",
    "\n",
    "def to_torch_if_needed(*args: Union[torch.Tensor, npt.NDArray]) -> List[torch.Tensor]:\n",
    "    \"\"\"Return the arguments as torch tensors; numpy arrays are wrapped via torch.from_numpy.\"\"\"\n",
    "    tensors = []\n",
    "    for arg in args:\n",
    "        if not isinstance(arg, torch.Tensor):\n",
    "            arg = torch.from_numpy(arg)\n",
    "        tensors.append(arg)\n",
    "    return tensors\n",
    "\n",
    "\n",
    "def adjust_dimensionality(\n",
    "    R: npt.NDArray, Rp: npt.NDArray, strategy=\"zero_pad\"\n",
    ") -> Tuple[npt.NDArray, npt.NDArray]:\n",
    "    \"\"\"Zero-pad the narrower of two 2-D arrays so both have the same number of columns.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if `strategy` is anything other than \"zero_pad\".\n",
    "    \"\"\"\n",
    "    width_gap = R.shape[1] - Rp.shape[1]\n",
    "    if strategy != \"zero_pad\":\n",
    "        raise NotImplementedError()\n",
    "    if width_gap > 0:\n",
    "        return R, np.concatenate((Rp, np.zeros((Rp.shape[0], width_gap))), axis=1)\n",
    "    if width_gap < 0:\n",
    "        return np.concatenate((R, np.zeros((R.shape[0], -width_gap))), axis=1), Rp\n",
    "    return R, Rp\n",
    "\n",
    "\n",
    "def center_columns(R: npt.NDArray) -> npt.NDArray:\n",
    "    \"\"\"Subtract each column's mean so every column of the result has zero mean.\"\"\"\n",
    "    column_means = R.mean(axis=0)\n",
    "    return R - column_means[None, :]\n",
    "\n",
    "\n",
    "def normalize_matrix_norm(R: npt.NDArray) -> npt.NDArray:\n",
    "    \"\"\"Divide R by its Frobenius norm.\"\"\"\n",
    "    frobenius = np.linalg.norm(R, ord=\"fro\")\n",
    "    return R / frobenius\n",
    "\n",
    "\n",
    "def sim_random_baseline(\n",
    "    rep1: torch.Tensor,\n",
    "    rep2: torch.Tensor,\n",
    "    sim_func: Callable,\n",
    "    n_permutations: int = 10,\n",
    "    seed: int = 1234,\n",
    ") -> Dict[str, Any]:\n",
    "    \"\"\"Chance-level scores: `sim_func(rep1[perm, :], rep2)` for random row permutations.\n",
    "\n",
    "    Shuffling the rows (samples) of rep1 breaks its correspondence with rep2.\n",
    "\n",
    "    Args:\n",
    "        rep1, rep2: representations whose first axis indexes samples (torch or numpy).\n",
    "        sim_func: returns either a scalar or a mapping with a \"score\" entry.\n",
    "        n_permutations: number of permutations to score.\n",
    "        seed: seed of a private torch.Generator, so the global torch RNG state\n",
    "            is left untouched.\n",
    "\n",
    "    Returns:\n",
    "        {\"baseline_scores\": np.ndarray of shape (n_permutations,)}\n",
    "    \"\"\"\n",
    "    from collections.abc import Mapping\n",
    "\n",
    "    generator = torch.Generator().manual_seed(seed)\n",
    "    num_samples = rep1.shape[0]  # works for torch tensors and numpy arrays alike\n",
    "    scores = []\n",
    "    for _ in range(n_permutations):\n",
    "        perm = torch.randperm(num_samples, generator=generator)\n",
    "        if isinstance(rep1, np.ndarray):\n",
    "            perm = perm.numpy()\n",
    "\n",
    "        score = sim_func(rep1[perm, :], rep2)\n",
    "        # Accept mappings as well as any scalar type (float, numpy scalar, 0-d tensor).\n",
    "        score = score[\"score\"] if isinstance(score, Mapping) else float(score)\n",
    "\n",
    "        scores.append(score)\n",
    "\n",
    "    return {\"baseline_scores\": np.array(scores)}\n",
    "\n",
    "\n",
    "class Pipeline:\n",
    "    \"\"\"Run each preprocessing function on both inputs, then call the similarity function.\n",
    "\n",
    "    NOTE: a later cell redefines `Pipeline` with a `(R, Rp, shape)` call signature;\n",
    "    once that cell has run it shadows this class.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        preprocess_funcs: List[Callable[[npt.NDArray], npt.NDArray]],\n",
    "        similarity_func: Callable[[npt.NDArray, npt.NDArray], Dict[str, Any]],\n",
    "    ) -> None:\n",
    "        self.preprocess_funcs = preprocess_funcs\n",
    "        self.similarity_func = similarity_func\n",
    "\n",
    "    def __call__(self, R: npt.NDArray, Rp: npt.NDArray) -> Dict[str, Any]:\n",
    "        \"\"\"Apply the preprocessing steps in order to R and Rp, then compare them.\"\"\"\n",
    "        for step in self.preprocess_funcs:\n",
    "            R, Rp = step(R), step(Rp)\n",
    "        return self.similarity_func(R, Rp)\n",
    "\n",
    "    def __str__(self) -> str:\n",
    "        \"\"\"Render as `Pipeline(step1+step2+sim{partial keywords})`.\"\"\"\n",
    "\n",
    "        def func_name(func: Callable) -> str:\n",
    "            target = func.func if isinstance(func, functools.partial) else func\n",
    "            return target.__name__\n",
    "\n",
    "        sim = self.similarity_func\n",
    "        sim_keywords = str(sim.keywords) if isinstance(sim, functools.partial) else \"\"\n",
    "        steps = \"+\".join(func_name(f) for f in self.preprocess_funcs)\n",
    "        return \"Pipeline(\" + steps + \"+\" + func_name(sim) + sim_keywords + \")\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RgcXEjdcXAOj"
   },
   "outputs": [],
   "source": [
    "from typing import List, Set, Union\n",
    "\n",
    "import numpy as np\n",
    "import numpy.typing as npt\n",
    "import sklearn.neighbors\n",
    "import torch\n",
    "\n",
    "# from llmcomp.measures.utils import to_numpy_if_needed\n",
    "\n",
    "\n",
    "def _jac_sim_i(idx_R: Set[int], idx_Rp: Set[int]) -> float:\n",
    "    \"\"\"Jaccard index of two index sets: intersection size divided by union size.\"\"\"\n",
    "    overlap = len(idx_R.intersection(idx_Rp))\n",
    "    total = len(idx_R.union(idx_Rp))\n",
    "    return overlap / total\n",
    "\n",
    "\n",
    "def jaccard_similarity(\n",
    "    R: Union[torch.Tensor, npt.NDArray],\n",
    "    Rp: Union[torch.Tensor, npt.NDArray],\n",
    "    k: int = 10,\n",
    "    inner: str = \"cosine\",\n",
    "    n_jobs: int = 8,\n",
    ") -> float:\n",
    "    \"\"\"Average Jaccard overlap between each sample's k-NN set in R and in Rp.\n",
    "\n",
    "    Args:\n",
    "        R, Rp: 2-D representations of the same samples (row i <-> row i).\n",
    "        k: neighbourhood size; a sample is not counted as its own neighbour.\n",
    "        inner: sklearn metric used for the neighbour search.\n",
    "        n_jobs: parallel jobs for the neighbour search.\n",
    "\n",
    "    Returns:\n",
    "        Mean per-sample Jaccard index as a Python float.\n",
    "    \"\"\"\n",
    "    R, Rp = to_numpy_if_needed(R, Rp)\n",
    "    neighbours_R = nn_array_to_setlist(top_k_neighbors(R, k, inner, n_jobs))\n",
    "    neighbours_Rp = nn_array_to_setlist(top_k_neighbors(Rp, k, inner, n_jobs))\n",
    "    per_sample = [_jac_sim_i(a, b) for a, b in zip(neighbours_R, neighbours_Rp)]\n",
    "    return float(np.mean(per_sample))\n",
    "\n",
    "\n",
    "def top_k_neighbors(\n",
    "    R: npt.NDArray,\n",
    "    k: int,\n",
    "    inner: str,\n",
    "    n_jobs: int,\n",
    ") -> npt.NDArray:\n",
    "    \"\"\"Indices of the k nearest neighbours of every row of R, excluding the row itself.\n",
    "\n",
    "    Args:\n",
    "        R: 2-D data; neighbours are searched among R's own rows.\n",
    "        k: neighbours to return per row (sklearn needs k + 1 <= number of rows).\n",
    "        inner: sklearn metric name, e.g. \"cosine\".\n",
    "        n_jobs: parallel jobs for sklearn.\n",
    "\n",
    "    Returns:\n",
    "        (num_rows, k) integer array of row indices, nearest first.\n",
    "    \"\"\"\n",
    "    # k+1 nearest neighbors, because we pass in all the data, which means that a point\n",
    "    # will be the nearest neighbor to itself. We remove this point from the results and\n",
    "    # report only the k nearest neighbors distinct from the point itself.\n",
    "    nns = sklearn.neighbors.NearestNeighbors(\n",
    "        n_neighbors=k + 1, metric=inner, n_jobs=n_jobs\n",
    "    )\n",
    "    nns.fit(R)\n",
    "    _, idx = nns.kneighbors(R)\n",
    "\n",
    "    # Remove the query point by identity rather than assuming it sits in column 0:\n",
    "    # with duplicate rows (e.g. all-zero activations) a tied twin can come first,\n",
    "    # leaving the point itself among its own neighbours. If the point is missing\n",
    "    # from its row altogether, drop the farthest (last) neighbour instead.\n",
    "    not_self = idx != np.arange(idx.shape[0])[:, None]\n",
    "    self_missing = not_self.all(axis=1)\n",
    "    not_self[self_missing, -1] = False\n",
    "    return idx[not_self].reshape(idx.shape[0], k)\n",
    "\n",
    "\n",
    "def nn_array_to_setlist(nn: npt.NDArray) -> List[Set[int]]:\n",
    "    \"\"\"Turn a 2-D neighbour-index array into one set of indices per row.\"\"\"\n",
    "    return list(map(set, nn))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K8QY53-0umRk"
   },
   "outputs": [],
   "source": [
    "import functools\n",
    "import logging\n",
    "from abc import ABC\n",
    "from abc import abstractmethod\n",
    "from dataclasses import dataclass\n",
    "from dataclasses import field\n",
    "from typing import Any\n",
    "from typing import Callable\n",
    "from typing import get_args\n",
    "from typing import List\n",
    "from typing import Literal\n",
    "from typing import Optional\n",
    "from typing import Protocol\n",
    "from typing import Tuple\n",
    "from typing import Union\n",
    "\n",
    "import numpy as np\n",
    "import numpy.typing as npt\n",
    "import torch\n",
    "from einops import rearrange\n",
    "# from loguru import logger\n",
    "\n",
    "log = logging.getLogger(__name__)\n",
    "\n",
    "\n",
    "# Supported layouts of representation arrays (handled by `flatten` below):\n",
    "#   \"nd\"   -> (n, d), used as-is\n",
    "#   \"ntd\"  -> (n, t, d), first two axes merged into (n*t, d)\n",
    "#   \"nchw\" -> (n, c, h, w), reshaped to (n, c*h*w)\n",
    "SHAPE_TYPE = Literal[\"nd\", \"ntd\", \"nchw\"]\n",
    "\n",
    "ND_SHAPE, NTD_SHAPE, NCHW_SHAPE = get_args(SHAPE_TYPE)\n",
    "\n",
    "\n",
    "class SimilarityFunction(Protocol):\n",
    "    \"\"\"Call signature of a representational similarity function: (R, Rp, shape) -> score.\"\"\"\n",
    "\n",
    "    def __call__(self, R: torch.Tensor | npt.NDArray, Rp: torch.Tensor | npt.NDArray, shape: SHAPE_TYPE) -> float:\n",
    "        ...\n",
    "\n",
    "\n",
    "class RSMSimilarityFunction(Protocol):\n",
    "    \"\"\"Like SimilarityFunction, but additionally takes an `n_jobs` argument.\"\"\"\n",
    "\n",
    "    def __call__(\n",
    "        self,\n",
    "        R: torch.Tensor | npt.NDArray,\n",
    "        Rp: torch.Tensor | npt.NDArray,\n",
    "        shape: SHAPE_TYPE,\n",
    "        n_jobs: int,\n",
    "    ) -> float:\n",
    "        ...\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class BaseSimilarityMeasure(ABC):\n",
    "    \"\"\"Abstract base holding descriptive metadata about a similarity measure.\n",
    "\n",
    "    Subclasses implement `__call__`. The optional property flags default to None\n",
    "    (not specified); `name` is excluded from `__init__` and is set to the\n",
    "    concrete class name in `__post_init__`.\n",
    "    \"\"\"\n",
    "\n",
    "    # Orientation of the score: True if a larger value means more similar.\n",
    "    larger_is_more_similar: bool\n",
    "    # Declared symmetry in the two arguments (metadata only; not checked here).\n",
    "    is_symmetric: bool\n",
    "\n",
    "    # Optional declared properties / invariances (metadata only; not checked here).\n",
    "    is_metric: bool | None = None\n",
    "    invariant_to_affine: bool | None = None\n",
    "    invariant_to_invertible_linear: bool | None = None\n",
    "    invariant_to_ortho: bool | None = None\n",
    "    invariant_to_permutation: bool | None = None\n",
    "    invariant_to_isotropic_scaling: bool | None = None\n",
    "    invariant_to_translation: bool | None = None\n",
    "    # Filled in by __post_init__, not passed to __init__.\n",
    "    name: str = field(init=False)\n",
    "\n",
    "    def __post_init__(self):\n",
    "        # Called by the generated dataclass __init__; sets the non-init `name` field.\n",
    "        self.name = self.__class__.__name__\n",
    "\n",
    "    @abstractmethod\n",
    "    def __call__(self, *args: Any, **kwds: Any) -> Any:\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "class FunctionalSimilarityMeasure(BaseSimilarityMeasure):\n",
    "    \"\"\"Abstract measure whose `__call__` compares two outputs, `output_a` and `output_b`.\"\"\"\n",
    "\n",
    "    @abstractmethod\n",
    "    def __call__(\n",
    "        self,\n",
    "        output_a: torch.Tensor | npt.NDArray,\n",
    "        output_b: torch.Tensor | npt.NDArray,\n",
    "    ) -> float:\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "@dataclass(kw_only=True)\n",
    "class RepresentationalSimilarityMeasure(BaseSimilarityMeasure):\n",
    "    \"\"\"Similarity measure that delegates scoring to a wrapped `sim_func`.\"\"\"\n",
    "\n",
    "    # kw_only=True is what allows this default-less field to follow the\n",
    "    # inherited fields that have defaults.\n",
    "    sim_func: SimilarityFunction\n",
    "\n",
    "    def __call__(\n",
    "        self,\n",
    "        R: torch.Tensor | npt.NDArray,\n",
    "        Rp: torch.Tensor | npt.NDArray,\n",
    "        shape: SHAPE_TYPE,\n",
    "    ) -> float:\n",
    "        \"\"\"Return `sim_func(R, Rp, shape)`.\"\"\"\n",
    "        return self.sim_func(R, Rp, shape)\n",
    "\n",
    "\n",
    "class RSMSimilarityMeasure(RepresentationalSimilarityMeasure):\n",
    "    \"\"\"Representational measure whose sim_func also accepts an `n_jobs` argument.\"\"\"\n",
    "\n",
    "    sim_func: RSMSimilarityFunction\n",
    "\n",
    "    @staticmethod\n",
    "    def estimate_good_number_of_jobs(R: torch.Tensor | npt.NDArray, Rp: torch.Tensor | npt.NDArray) -> int:\n",
    "        \"\"\"Job count used when the caller passes none; currently always 1.\"\"\"\n",
    "        # One might scale jobs with the NxN (or DxD) RSM size, but sklearn-native metrics\n",
    "        # already parallelise internally; extra jobs each spawn their own threads and\n",
    "        # oversubscribe the cores. Empirically n_jobs=1 was the fastest setting.\n",
    "        return 1\n",
    "\n",
    "    def __call__(\n",
    "        self,\n",
    "        R: torch.Tensor | npt.NDArray,\n",
    "        Rp: torch.Tensor | npt.NDArray,\n",
    "        shape: SHAPE_TYPE,\n",
    "        n_jobs: Optional[int] = None,\n",
    "    ) -> float:\n",
    "        \"\"\"Score R vs Rp; n_jobs=None defers to estimate_good_number_of_jobs.\"\"\"\n",
    "        jobs = n_jobs if n_jobs is not None else self.estimate_good_number_of_jobs(R, Rp)\n",
    "        return self.sim_func(R, Rp, shape, n_jobs=jobs)\n",
    "\n",
    "\n",
    "def to_numpy_if_needed(*args: Union[torch.Tensor, npt.NDArray]) -> List[npt.NDArray]:\n",
    "    \"\"\"Return every argument as a numpy array; numpy inputs are passed through as-is.\n",
    "\n",
    "    Torch tensors are detached and moved to CPU first, because `Tensor.numpy()`\n",
    "    raises for CUDA tensors and for tensors that require grad.\n",
    "    \"\"\"\n",
    "    def convert(x: Union[torch.Tensor, npt.NDArray]) -> npt.NDArray:\n",
    "        return x if isinstance(x, np.ndarray) else x.detach().cpu().numpy()\n",
    "\n",
    "    return list(map(convert, args))\n",
    "\n",
    "\n",
    "def to_torch_if_needed(*args: Union[torch.Tensor, npt.NDArray]) -> List[torch.Tensor]:\n",
    "    \"\"\"Return the arguments as torch tensors; numpy arrays are wrapped via torch.from_numpy.\"\"\"\n",
    "    tensors = []\n",
    "    for arg in args:\n",
    "        if not isinstance(arg, torch.Tensor):\n",
    "            arg = torch.from_numpy(arg)\n",
    "        tensors.append(arg)\n",
    "    return tensors\n",
    "\n",
    "\n",
    "def adjust_dimensionality(R: npt.NDArray, Rp: npt.NDArray, strategy=\"zero_pad\") -> Tuple[npt.NDArray, npt.NDArray]:\n",
    "    \"\"\"Zero-pad the narrower of two 2-D arrays so both have the same number of columns.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if `strategy` is anything other than \"zero_pad\".\n",
    "    \"\"\"\n",
    "    width_gap = R.shape[1] - Rp.shape[1]\n",
    "    if strategy != \"zero_pad\":\n",
    "        raise NotImplementedError()\n",
    "    if width_gap > 0:\n",
    "        return R, np.concatenate((Rp, np.zeros((Rp.shape[0], width_gap))), axis=1)\n",
    "    if width_gap < 0:\n",
    "        return np.concatenate((R, np.zeros((R.shape[0], -width_gap))), axis=1), Rp\n",
    "    return R, Rp\n",
    "\n",
    "\n",
    "def center_columns(R: npt.NDArray) -> npt.NDArray:\n",
    "    \"\"\"Subtract each column's mean so every column of the result has zero mean.\"\"\"\n",
    "    column_means = R.mean(axis=0)\n",
    "    return R - column_means[None, :]\n",
    "\n",
    "\n",
    "def normalize_matrix_norm(R: npt.NDArray) -> npt.NDArray:\n",
    "    \"\"\"Divide R by its Frobenius norm.\"\"\"\n",
    "    frobenius = np.linalg.norm(R, ord=\"fro\")\n",
    "    return R / frobenius\n",
    "\n",
    "\n",
    "def normalize_row_norm(R: npt.NDArray) -> npt.NDArray:\n",
    "    \"\"\"Divide every row of R by its own Euclidean (L2) norm.\"\"\"\n",
    "    row_norms = np.linalg.norm(R, ord=2, axis=1, keepdims=True)\n",
    "    return R / row_norms\n",
    "\n",
    "\n",
    "def standardize(R: npt.NDArray) -> npt.NDArray:\n",
    "    \"\"\"Z-score each column: subtract its mean, divide by its (ddof=0) std.\"\"\"\n",
    "    centered = R - R.mean(axis=0, keepdims=True)\n",
    "    return centered / R.std(axis=0)\n",
    "\n",
    "\n",
    "def double_center(x: npt.NDArray) -> npt.NDArray:\n",
    "    \"\"\"Remove column and row means from a 2-D matrix, adding back the grand mean.\"\"\"\n",
    "    col_means = x.mean(axis=0, keepdims=True)\n",
    "    row_means = x.mean(axis=1, keepdims=True)\n",
    "    return x - col_means - row_means + x.mean()\n",
    "\n",
    "\n",
    "def align_spatial_dimensions(\n",
    "    R: npt.NDArray, Rp: npt.NDArray, max_samples: int = 5000\n",
    ") -> Tuple[npt.NDArray, npt.NDArray]:\n",
    "    \"\"\"\n",
    "    Aligns spatial representations by resizing them to the smallest spatial dimension.\n",
    "    Subsequent aligned spatial representations are flattened, with the spatial aligned representations\n",
    "    moving into the *sample* dimension.\n",
    "\n",
    "    Args:\n",
    "        R, Rp: (n, c, h, w) representations; their h/w may differ.\n",
    "        max_samples: if flattening yields more rows than this, every\n",
    "            (rows // max_samples)-th row is kept in both outputs.\n",
    "\n",
    "    Returns:\n",
    "        Two 2-D arrays (rows = n x common h x common w, columns = channels),\n",
    "        possibly row-subsampled.\n",
    "    \"\"\"\n",
    "    R_re, Rp_re = resize_wh_reps(R, Rp)\n",
    "    R_re = rearrange(R_re, \"n c h w -> (n h w) c\")\n",
    "    Rp_re = rearrange(Rp_re, \"n c h w -> (n h w) c\")\n",
    "    if R_re.shape[0] > max_samples:\n",
    "        # Log via this cell's module-level `log`: `logger` (loguru) is not imported\n",
    "        # in this notebook, so referencing it raised NameError on this branch.\n",
    "        log.info(f\"Got {R_re.shape[0]} samples in N after flattening. Subsampling to reduce compute.\")\n",
    "        subsample = R_re.shape[0] // max_samples\n",
    "        R_re = R_re[::subsample]\n",
    "        Rp_re = Rp_re[::subsample]\n",
    "\n",
    "    return R_re, Rp_re\n",
    "\n",
    "\n",
    "def average_pool_downsample(R, resize: bool, new_size: tuple[int, int]):\n",
    "    \"\"\"Adaptive-average-pool the last two (spatial) axes of R to `new_size`.\n",
    "\n",
    "    Returns R untouched when `resize` is False. numpy input gives numpy output,\n",
    "    torch input gives torch output.\n",
    "    \"\"\"\n",
    "    if not resize:\n",
    "        return R  # do nothing\n",
    "    was_numpy = isinstance(R, np.ndarray)\n",
    "    as_tensor = torch.from_numpy(R) if was_numpy else R\n",
    "    pooled = torch.nn.functional.adaptive_avg_pool2d(as_tensor, new_size)\n",
    "    return pooled.numpy() if was_numpy else pooled\n",
    "\n",
    "\n",
    "def resize_wh_reps(R: npt.NDArray, Rp: npt.NDArray) -> Tuple[npt.NDArray, npt.NDArray]:\n",
    "    \"\"\"\n",
    "    Bring two spatial representations to a common (height, width) by average pooling.\n",
    "\n",
    "    If the spatial sizes differ, both inputs are adaptive-average-pooled to the\n",
    "    element-wise minimum height and width; otherwise they are returned unchanged.\n",
    "    (No Fourier transform is used here; `fft_resize` is the FFT-based alternative.)\n",
    "\n",
    "    Args:\n",
    "        R: array/tensor of shape [batch_size, num_channels, height, width]\n",
    "        Rp: array/tensor of shape [batch_size, num_channels', height', width']\n",
    "\n",
    "    Returns:\n",
    "        avg_ds1, avg_ds2: R and Rp with spatial size [min(height, height'), min(width, width')].\n",
    "    \"\"\"\n",
    "    height1, width1 = R.shape[2], R.shape[3]\n",
    "    height2, width2 = Rp.shape[2], Rp.shape[3]\n",
    "    resize = (height1, width1) != (height2, width2)\n",
    "    new_size = [min(height1, height2), min(width1, width2)] if resize else None\n",
    "\n",
    "    avg_ds1 = average_pool_downsample(R, resize=resize, new_size=new_size)\n",
    "    avg_ds2 = average_pool_downsample(Rp, resize=resize, new_size=new_size)\n",
    "    return avg_ds1, avg_ds2\n",
    "\n",
    "\n",
    "def fft_resize(images, resize=False, new_size=None):\n",
    "    \"\"\"Function for applying DFT and resizing.\n",
    "\n",
    "    This function takes in an array of images, applies the 2-d fourier transform\n",
    "    and resizes them according to new_size, keeping the low frequencies that\n",
    "    overlap between the two sizes.\n",
    "\n",
    "    Args:\n",
    "              images: a numpy array with shape\n",
    "                      [batch_size, height, width, num_channels]\n",
    "              resize: boolean, whether or not to resize\n",
    "              new_size: a (new_height, new_width) tuple; odd sizes are honoured\n",
    "                        exactly\n",
    "\n",
    "    Returns:\n",
    "              im_fft_downsampled: a numpy array with shape\n",
    "                           [batch_size, new_height, new_width, num_channels]\n",
    "                           (or `images` itself when resize is False)\n",
    "    \"\"\"\n",
    "    assert len(images.shape) == 4, \"expecting images to be\" \"[batch_size, height, width, num_channels]\"\n",
    "    if not resize:\n",
    "        return images\n",
    "\n",
    "    height, width = images.shape[1], images.shape[2]\n",
    "    new_height, new_width = new_size\n",
    "    # FFT --> keep the centred low-frequency block --> inverse FFT\n",
    "    im_fft = np.fft.fft2(images.astype(\"complex64\"), axes=(1, 2))\n",
    "    im_shifted = np.fft.fftshift(im_fft, axes=(1, 2))\n",
    "\n",
    "    # fftshift places the DC term at index size // 2; crop a new_size window around it.\n",
    "    top = height // 2 - new_height // 2\n",
    "    left = width // 2 - new_width // 2\n",
    "    cropped_fft = im_shifted[:, top : top + new_height, left : left + new_width, :]\n",
    "\n",
    "    # Undo the shift before inverting: ifft2 expects DC at index 0, otherwise the\n",
    "    # output is modulated by a phase ramp (e.g. a (-1)**(y + x) checkerboard).\n",
    "    cropped_fft_shifted_back = np.fft.ifft2(np.fft.ifftshift(cropped_fft, axes=(1, 2)), axes=(1, 2))\n",
    "    # ifft2 normalises by the new pixel count; rescale so intensities (e.g. the mean) are preserved.\n",
    "    return cropped_fft_shifted_back.real * (new_height * new_width) / (height * width)\n",
    "\n",
    "\n",
    "class Pipeline:\n",
    "    \"\"\"Preprocess both representations identically, then score them with `similarity_func`.\n",
    "\n",
    "    A ValueError raised by any step is logged at INFO level and reported as NaN.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        preprocess_funcs: List[Callable[[npt.NDArray], npt.NDArray]],\n",
    "        similarity_func: Callable[[npt.NDArray, npt.NDArray, SHAPE_TYPE], float],\n",
    "    ) -> None:\n",
    "        self.preprocess_funcs = preprocess_funcs\n",
    "        self.similarity_func = similarity_func\n",
    "\n",
    "    def __call__(self, R: npt.NDArray, Rp: npt.NDArray, shape: SHAPE_TYPE) -> float:\n",
    "        \"\"\"Apply every preprocessing step to R and Rp in order, then compare them.\"\"\"\n",
    "        try:\n",
    "            for step in self.preprocess_funcs:\n",
    "                R, Rp = step(R), step(Rp)\n",
    "            return self.similarity_func(R, Rp, shape)\n",
    "        except ValueError as e:\n",
    "            log.info(f\"Pipeline failed: {e}\")\n",
    "            return np.nan\n",
    "\n",
    "    def __str__(self) -> str:\n",
    "        \"\"\"Render as `Pipeline(step1+step2+sim{partial keywords})`.\"\"\"\n",
    "\n",
    "        def func_name(func: Callable) -> str:\n",
    "            target = func.func if isinstance(func, functools.partial) else func\n",
    "            return target.__name__\n",
    "\n",
    "        sim = self.similarity_func\n",
    "        sim_keywords = str(sim.keywords) if isinstance(sim, functools.partial) else \"\"\n",
    "        steps = \"+\".join(func_name(f) for f in self.preprocess_funcs)\n",
    "        return \"Pipeline(\" + steps + \"+\" + func_name(sim) + sim_keywords + \")\"\n",
    "\n",
    "\n",
    "def flatten(*args: Union[torch.Tensor, npt.NDArray], shape: SHAPE_TYPE) -> List[Union[torch.Tensor, npt.NDArray]]:\n",
    "    \"\"\"Bring each representation to 2-D (samples, features) according to `shape`.\n",
    "\n",
    "    \"nd\" inputs are returned unchanged; \"ntd\" and \"nchw\" inputs are converted to\n",
    "    torch tensors and flattened.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `shape` is not one of \"nd\", \"ntd\", \"nchw\".\n",
    "    \"\"\"\n",
    "    if shape == \"ntd\":\n",
    "        flattener = flatten_nxtxd_to_ntxd\n",
    "    elif shape == \"nd\":\n",
    "        return list(args)\n",
    "    elif shape == \"nchw\":\n",
    "        flattener = flatten_nxcxhxw_to_nxchw  # Flattening non-trivial for nchw\n",
    "    else:\n",
    "        raise ValueError(\"Unknown shape of representations. Must be one of 'ntd', 'nchw', 'nd'.\")\n",
    "    return [flattener(arg) for arg in args]\n",
    "\n",
    "\n",
    "def flatten_nxtxd_to_ntxd(R: Union[torch.Tensor, npt.NDArray]) -> torch.Tensor:\n",
    "    \"\"\"Merge the first two axes, (n, t, d) -> (n*t, d), returning a torch tensor.\"\"\"\n",
    "    (tensor,) = to_torch_if_needed(R)\n",
    "    log.debug(\"Shape before flattening: %s\", str(tensor.shape))\n",
    "    flat = tensor.flatten(start_dim=0, end_dim=1)\n",
    "    log.debug(\"Shape after flattening: %s\", str(flat.shape))\n",
    "    return flat\n",
    "\n",
    "\n",
    "def flatten_nxcxhxw_to_nxchw(R: Union[torch.Tensor, npt.NDArray]) -> torch.Tensor:\n",
    "    \"\"\"Keep the first axis and merge all others, e.g. (n, c, h, w) -> (n, c*h*w).\"\"\"\n",
    "    (tensor,) = to_torch_if_needed(R)\n",
    "    log.debug(\"Shape before flattening: %s\", str(tensor.shape))\n",
    "    flat = tensor.reshape(tensor.shape[0], -1)\n",
    "    log.debug(\"Shape after flattening: %s\", str(flat.shape))\n",
    "    return flat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "esinf2X00sPl"
   },
   "outputs": [],
   "source": [
    "import scipy.optimize\n",
    "\n",
    "def permutation_procrustes(\n",
    "    R: Union[torch.Tensor, npt.NDArray],\n",
    "    Rp: Union[torch.Tensor, npt.NDArray],\n",
    "    shape: SHAPE_TYPE,\n",
    "    optimal_permutation_alignment: Optional[Tuple[npt.NDArray, npt.NDArray]] = None,\n",
    ") -> float:\n",
    "    \"\"\"Frobenius distance between R and Rp after optimally permuting their columns.\n",
    "\n",
    "    Both inputs are flattened to 2-D, converted to numpy and zero-padded to equal\n",
    "    width. The column matching maximises the summed inner products of matched\n",
    "    columns (scipy.optimize.linear_sum_assignment on R.T @ Rp).\n",
    "\n",
    "    Args:\n",
    "        optimal_permutation_alignment: precomputed (PR, PRp) column indices; when\n",
    "            falsy (e.g. None) the assignment is computed here.\n",
    "\n",
    "    Returns:\n",
    "        ||R[:, PR] - Rp[:, PRp]||_F as a Python float (smaller means closer).\n",
    "    \"\"\"\n",
    "    R, Rp = flatten(R, Rp, shape=shape)\n",
    "    R, Rp = to_numpy_if_needed(R, Rp)\n",
    "    R, Rp = adjust_dimensionality(R, Rp)\n",
    "\n",
    "    alignment = optimal_permutation_alignment\n",
    "    if not alignment:\n",
    "        # (row indices, matching column indices) maximising the total inner product\n",
    "        alignment = scipy.optimize.linear_sum_assignment(R.T @ Rp, maximize=True)\n",
    "    PR, PRp = alignment\n",
    "    return float(np.linalg.norm(R[:, PR] - Rp[:, PRp], ord=\"fro\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MU_QCO_UvFKl"
   },
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "from typing import Union\n",
    "\n",
    "import numpy as np\n",
    "import numpy.typing as npt\n",
    "import scipy.spatial.distance\n",
    "import scipy.stats\n",
    "import sklearn.metrics\n",
    "import torch\n",
    "# from repsim.measures.utils import flatten\n",
    "# from repsim.measures.utils import RSMSimilarityMeasure\n",
    "# from repsim.measures.utils import SHAPE_TYPE\n",
    "# from repsim.measures.utils import to_numpy_if_needed\n",
    "\n",
    "\n",
    "def representational_similarity_analysis(\n",
    "    R: Union[torch.Tensor, npt.NDArray],\n",
    "    Rp: Union[torch.Tensor, npt.NDArray],\n",
    "    shape: SHAPE_TYPE,\n",
    "    inner=\"correlation\",\n",
    "    outer=\"spearman\",\n",
    "    n_jobs: Optional[int] = None,\n",
    ") -> float:\n",
    "    \"\"\"Representational similarity analysis\n",
    "\n",
    "    Builds a sample-by-sample representational similarity matrix (RSM) for each\n",
    "    input, keeps its condensed off-diagonal entries, and compares the two.\n",
    "\n",
    "    Args:\n",
    "        R (Union[torch.Tensor, npt.NDArray]): N x D representation\n",
    "        Rp (Union[torch.Tensor, npt.NDArray]): N x D' representation\n",
    "        shape (SHAPE_TYPE): layout of R and Rp (\"nd\", \"ntd\" or \"nchw\"); both are\n",
    "            flattened to 2-D first.\n",
    "        inner (str, optional): inner similarity for the RSM, one of \"correlation\"\n",
    "            or \"euclidean\". Defaults to \"correlation\".\n",
    "        outer (str, optional): how the two RSMs are compared, one of \"spearman\"\n",
    "            or \"euclidean\". Defaults to \"spearman\".\n",
    "        n_jobs (int, optional): forwarded to sklearn.metrics.pairwise_distances.\n",
    "\n",
    "    Returns:\n",
    "        float: Spearman correlation between the RSMs (outer=\"spearman\") or the\n",
    "        L2 distance between them (outer=\"euclidean\").\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: for an unsupported `inner`.\n",
    "        ValueError: for an unsupported `outer`.\n",
    "    \"\"\"\n",
    "    R, Rp = flatten(R, Rp, shape=shape)\n",
    "    R, Rp = to_numpy_if_needed(R, Rp)\n",
    "\n",
    "    def condensed_rsm(X: npt.NDArray) -> npt.NDArray:\n",
    "        # Off-diagonal entries of X's RSM, via squareform(..., checks=False).\n",
    "        if inner == \"correlation\":\n",
    "            # n_jobs only works if metric is in PAIRWISE_DISTANCES as defined in sklearn, i.e., not for correlation.\n",
    "            # But correlation = 1 - cosine dist of row-centered data, so we use the faster cosine metric and center the data.\n",
    "            X = X - X.mean(axis=1, keepdims=True)\n",
    "            rsm = 1 - sklearn.metrics.pairwise_distances(X, metric=\"cosine\", n_jobs=n_jobs)  # type:ignore\n",
    "        elif inner == \"euclidean\":\n",
    "            rsm = sklearn.metrics.pairwise_distances(X, metric=inner, n_jobs=n_jobs)\n",
    "        else:\n",
    "            raise NotImplementedError(f\"{inner=}\")\n",
    "        return scipy.spatial.distance.squareform(rsm, checks=False)\n",
    "\n",
    "    S = condensed_rsm(R)\n",
    "    Sp = condensed_rsm(Rp)\n",
    "\n",
    "    if outer == \"spearman\":\n",
    "        return scipy.stats.spearmanr(S, Sp).statistic  # type:ignore\n",
    "    if outer == \"euclidean\":\n",
    "        return float(np.linalg.norm(S - Sp, ord=2))\n",
    "    raise ValueError(f\"Unknown outer similarity function: {outer}\")\n",
    "\n",
    "\n",
    "class RSA(RSMSimilarityMeasure):\n",
    "    \"\"\"RSA measure built on representational_similarity_analysis with its default inner/outer.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # RSMSimilarityMeasure.__call__ never forwards inner/outer, so the defaults\n",
    "        # (\"correlation\" / \"spearman\") always apply and these properties are fixed.\n",
    "        super().__init__(\n",
    "            sim_func=representational_similarity_analysis,\n",
    "            larger_is_more_similar=True,\n",
    "            is_symmetric=True,\n",
    "            is_metric=False,\n",
    "            invariant_to_affine=False,\n",
    "            invariant_to_invertible_linear=False,\n",
    "            invariant_to_ortho=False,\n",
    "            invariant_to_permutation=True,\n",
    "            invariant_to_isotropic_scaling=True,\n",
    "            invariant_to_translation=True,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LO8o8I5owA7p"
   },
   "outputs": [],
   "source": [
    "##################################################################################\n",
    "# Copied from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/cca_core.py\n",
    "# Copyright 2018 Google Inc.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "\"\"\"\n",
    "The core code for applying Canonical Correlation Analysis to deep networks.\n",
    "\n",
    "This module contains the core functions to apply canonical correlation analysis\n",
    "to deep neural networks. The main function is get_cca_similarity, which takes in\n",
    "two sets of activations, typically the neurons in two layers and their outputs\n",
    "on all of the datapoints D = [d_1,...,d_m] that have been passed through.\n",
    "\n",
    "Inputs have shape (num_neurons1, m), (num_neurons2, m). This can be directly\n",
    "applied to fully connected networks. For convolutional layers, the 3d block\n",
    "of neurons can either be flattened entirely, along channels, or alternatively,\n",
    "the dft_ccas (Discrete Fourier Transform) module can be used.\n",
    "\n",
    "See:\n",
    "https://arxiv.org/abs/1706.05806\n",
    "https://arxiv.org/abs/1806.05759\n",
    "for full details.\n",
    "\n",
    "\"\"\"\n",
    "import numpy as np\n",
    "# from repsim.measures.utils import align_spatial_dimensions\n",
    "\n",
    "# Number of attempts robust_cca_similarity makes (re-noising the activations\n",
    "# after each LinAlgError) before re-raising the error.\n",
    "num_cca_trials = 5\n",
    "\n",
    "\n",
    "def positivedef_matrix_sqrt(array):\n",
    "    \"\"\"Stable method for computing matrix square roots, supports complex matrices.\n",
    "\n",
    "    Args:\n",
    "              array: A numpy 2d array, can be complex valued that is a positive\n",
    "                     definite symmetric (or hermitian) matrix\n",
    "\n",
    "    Returns:\n",
    "              sqrtarray: The matrix square root of array\n",
    "    \"\"\"\n",
    "    w, v = np.linalg.eigh(array)\n",
    "    #  A - np.dot(v, np.dot(np.diag(w), v.T))\n",
    "    wsqrt = np.sqrt(w)\n",
    "    sqrtarray = np.dot(v, np.dot(np.diag(wsqrt), np.conj(v).T))\n",
    "    return sqrtarray\n",
    "\n",
    "\n",
    "def remove_small(sigma_xx, sigma_xy, sigma_yx, sigma_yy, epsilon):\n",
    "    \"\"\"Takes covariance between X, Y, and removes values of small magnitude.\n",
    "\n",
    "    Args:\n",
    "              sigma_xx: 2d numpy array, variance matrix for x\n",
    "              sigma_xy: 2d numpy array, crossvariance matrix for x,y\n",
    "              sigma_yx: 2d numpy array, crossvariance matrixy for x,y,\n",
    "                        (conjugate) transpose of sigma_xy\n",
    "              sigma_yy: 2d numpy array, variance matrix for y\n",
    "              epsilon : cutoff value for norm below which directions are thrown\n",
    "                         away\n",
    "\n",
    "    Returns:\n",
    "              sigma_xx_crop: 2d array with low x norm directions removed\n",
    "              sigma_xy_crop: 2d array with low x and y norm directions removed\n",
    "              sigma_yx_crop: 2d array with low x and y norm directiosn removed\n",
    "              sigma_yy_crop: 2d array with low y norm directions removed\n",
    "              x_idxs: indexes of sigma_xx that were removed\n",
    "              y_idxs: indexes of sigma_yy that were removed\n",
    "    \"\"\"\n",
    "\n",
    "    x_diag = np.abs(np.diagonal(sigma_xx))\n",
    "    y_diag = np.abs(np.diagonal(sigma_yy))\n",
    "    x_idxs = x_diag >= epsilon\n",
    "    y_idxs = y_diag >= epsilon\n",
    "\n",
    "    sigma_xx_crop = sigma_xx[x_idxs][:, x_idxs]\n",
    "    sigma_xy_crop = sigma_xy[x_idxs][:, y_idxs]\n",
    "    sigma_yx_crop = sigma_yx[y_idxs][:, x_idxs]\n",
    "    sigma_yy_crop = sigma_yy[y_idxs][:, y_idxs]\n",
    "\n",
    "    return (sigma_xx_crop, sigma_xy_crop, sigma_yx_crop, sigma_yy_crop, x_idxs, y_idxs)\n",
    "\n",
    "\n",
    "def compute_ccas(sigma_xx, sigma_xy, sigma_yx, sigma_yy, epsilon, verbose=True):\n",
    "    \"\"\"Main cca computation function, takes in variances and crossvariances.\n",
    "\n",
    "    This function takes in the covariances and cross covariances of X, Y,\n",
    "    preprocesses them (removing small magnitudes) and outputs the raw results of\n",
    "    the cca computation, including cca directions in a rotated space, and the\n",
    "    cca correlation coefficient values.\n",
    "\n",
    "    Args:\n",
    "              sigma_xx: 2d numpy array, (num_neurons_x, num_neurons_x)\n",
    "                        variance matrix for x\n",
    "              sigma_xy: 2d numpy array, (num_neurons_x, num_neurons_y)\n",
    "                        crossvariance matrix for x,y\n",
    "              sigma_yx: 2d numpy array, (num_neurons_y, num_neurons_x)\n",
    "                        crossvariance matrix for x,y (conj) transpose of sigma_xy\n",
    "              sigma_yy: 2d numpy array, (num_neurons_y, num_neurons_y)\n",
    "                        variance matrix for y\n",
    "              epsilon:  small float to help with stabilizing computations; used as\n",
    "                        the low-variance cutoff in remove_small and as the ridge\n",
    "                        added to the diagonals before inversion\n",
    "              verbose:  boolean on whether to print intermediate outputs\n",
    "\n",
    "    Returns:\n",
    "              [u, s, v]:    [numpy 2d array, numpy 1d array, numpy 2d array]\n",
    "                            np.linalg.svd of invsqrt_xx @ sigma_xy @ invsqrt_yy\n",
    "                            (on the pruned blocks). s holds the absolute singular\n",
    "                            values, i.e. the canonical correlation coefficients;\n",
    "                            u belongs to the X side and v to the Y side.\n",
    "                            In the degenerate case (no neuron kept on one side)\n",
    "                            this is the placeholder [0, 0, 0].\n",
    "              invsqrt_xx:   Inverse square root of sigma_xx to transform canonical\n",
    "                            directions back to original space\n",
    "              invsqrt_yy:   Same as above but for sigma_yy\n",
    "              x_idxs:       Boolean mask over the input sigma_xx rows, True for the\n",
    "                            neurons kept by remove_small\n",
    "              y_idxs:       Same as above but for sigma_yy\n",
    "    \"\"\"\n",
    "\n",
    "    (sigma_xx, sigma_xy, sigma_yx, sigma_yy, x_idxs, y_idxs) = remove_small(\n",
    "        sigma_xx, sigma_xy, sigma_yx, sigma_yy, epsilon\n",
    "    )\n",
    "\n",
    "    numx = sigma_xx.shape[0]\n",
    "    numy = sigma_yy.shape[0]\n",
    "\n",
    "    if numx == 0 or numy == 0:\n",
    "        # Degenerate case: every neuron on one side was pruned. Return the same\n",
    "        # 5-element layout as the normal path so callers can unpack it uniformly;\n",
    "        # the old 6-element tuple made `([u, s, v], invsqrt_xx, invsqrt_yy,\n",
    "        # x_idxs, y_idxs) = compute_ccas(...)` raise before any zero-result fallback.\n",
    "        return (\n",
    "            [0, 0, 0],\n",
    "            np.zeros_like(sigma_xx),\n",
    "            np.zeros_like(sigma_yy),\n",
    "            x_idxs,\n",
    "            y_idxs,\n",
    "        )\n",
    "\n",
    "    if verbose:\n",
    "        print(\"adding eps to diagonal and taking inverse\")\n",
    "    # In-place ridge on the cropped arrays returned by remove_small.\n",
    "    sigma_xx += epsilon * np.eye(numx)\n",
    "    sigma_yy += epsilon * np.eye(numy)\n",
    "    inv_xx = np.linalg.pinv(sigma_xx)\n",
    "    inv_yy = np.linalg.pinv(sigma_yy)\n",
    "\n",
    "    if verbose:\n",
    "        print(\"taking square root\")\n",
    "    invsqrt_xx = positivedef_matrix_sqrt(inv_xx)\n",
    "    invsqrt_yy = positivedef_matrix_sqrt(inv_yy)\n",
    "\n",
    "    if verbose:\n",
    "        print(\"dot products...\")\n",
    "    arr = np.dot(invsqrt_xx, np.dot(sigma_xy, invsqrt_yy))\n",
    "\n",
    "    if verbose:\n",
    "        print(\"trying to take final svd\")\n",
    "    u, s, v = np.linalg.svd(arr)\n",
    "\n",
    "    if verbose:\n",
    "        print(\"computed everything!\")\n",
    "\n",
    "    return [u, np.abs(s), v], invsqrt_xx, invsqrt_yy, x_idxs, y_idxs\n",
    "\n",
    "\n",
    "def sum_threshold(array, threshold):\n",
    "    \"\"\"Computes threshold index of decreasing nonnegative array by summing.\n",
    "\n",
    "    This function takes in a decreasing array nonnegative floats, and a\n",
    "    threshold between 0 and 1. It returns the index i at which the sum of the\n",
    "    array up to i is threshold*total mass of the array.\n",
    "\n",
    "    Args:\n",
    "              array: a 1d numpy array of decreasing, nonnegative floats\n",
    "              threshold: a number between 0 and 1\n",
    "\n",
    "    Returns:\n",
    "              i: index at which np.sum(array[:i]) >= threshold\n",
    "    \"\"\"\n",
    "    assert (threshold >= 0) and (threshold <= 1), \"print incorrect threshold\"\n",
    "\n",
    "    for i in range(len(array)):\n",
    "        if np.sum(array[:i]) / np.sum(array) >= threshold:\n",
    "            return i\n",
    "\n",
    "\n",
    "def create_zero_dict(compute_dirns, dimension):\n",
    "    \"\"\"Outputs a zero dict when neuron activation norms too small.\n",
    "\n",
    "    This function creates a return_dict with appropriately shaped zero entries\n",
    "    when all neuron activations are very small.\n",
    "\n",
    "    Args:\n",
    "              compute_dirns: boolean, whether to have zero vectors for directions\n",
    "              dimension: int, defines shape of directions\n",
    "\n",
    "    Returns:\n",
    "              return_dict: a dict of appropriately shaped zero entries\n",
    "    \"\"\"\n",
    "    return_dict = {}\n",
    "    return_dict[\"mean\"] = (np.asarray(0), np.asarray(0))\n",
    "    return_dict[\"sum\"] = (np.asarray(0), np.asarray(0))\n",
    "    return_dict[\"cca_coef1\"] = np.asarray(0)\n",
    "    return_dict[\"cca_coef2\"] = np.asarray(0)\n",
    "    return_dict[\"idx1\"] = 0\n",
    "    return_dict[\"idx2\"] = 0\n",
    "\n",
    "    if compute_dirns:\n",
    "        return_dict[\"cca_dirns1\"] = np.zeros((1, dimension))\n",
    "        return_dict[\"cca_dirns2\"] = np.zeros((1, dimension))\n",
    "\n",
    "    return return_dict\n",
    "\n",
    "\n",
    "def get_cca_similarity(\n",
    "    acts1,\n",
    "    acts2,\n",
    "    epsilon=0.0,\n",
    "    threshold=0.98,\n",
    "    compute_coefs=True,\n",
    "    compute_dirns=False,\n",
    "    verbose=True,\n",
    "):\n",
    "    \"\"\"The main function for computing cca similarities.\n",
    "\n",
    "    This function computes the cca similarity between two sets of activations,\n",
    "    returning a dict with the cca coefficients, a few statistics of the cca\n",
    "    coefficients, and (optionally) the actual directions.\n",
    "\n",
    "    Args:\n",
    "              acts1: (num_neurons1, data_points) a 2d numpy array of neurons by\n",
    "                     datapoints where entry (i,j) is the output of neuron i on\n",
    "                     datapoint j.\n",
    "              acts2: (num_neurons2, data_points) same as above, but (potentially)\n",
    "                     for a different set of neurons. Note that acts1 and acts2\n",
    "                     can have different numbers of neurons, but must agree on the\n",
    "                     number of datapoints\n",
    "\n",
    "              epsilon: small float to help stabilize computations; forwarded to\n",
    "                       compute_ccas (low-variance cutoff and diagonal ridge)\n",
    "\n",
    "              threshold: float between 0, 1 used to get rid of trailing zeros in\n",
    "                         the cca correlation coefficients to output more accurate\n",
    "                         summary statistics of correlations.\n",
    "\n",
    "              compute_coefs: boolean value determining whether coefficients\n",
    "                             over neurons are computed. Needed for computing\n",
    "                             directions, so it is treated as True whenever\n",
    "                             compute_dirns is True.\n",
    "\n",
    "              compute_dirns: boolean value determining whether actual cca\n",
    "                             directions are computed. (For very large neurons and\n",
    "                             datasets, may be better to compute these on the fly\n",
    "                             instead of store in memory.)\n",
    "\n",
    "              verbose: Boolean, whether intermediate outputs are printed\n",
    "\n",
    "    Returns:\n",
    "              return_dict: A dictionary with outputs from the cca computations.\n",
    "                           Contains neuron coefficients (combinations of neurons\n",
    "                           that correspond to cca directions), the cca correlation\n",
    "                           coefficients (how well aligned directions correlate),\n",
    "                           x and y idxs (for computing cca directions on the fly\n",
    "                           if compute_dirns=False), and summary statistics. If\n",
    "                           compute_dirns=True, the cca directions are also\n",
    "                           computed.\n",
    "    \"\"\"\n",
    "\n",
    "    # assert dimensionality equal\n",
    "    assert acts1.shape[1] == acts2.shape[1], \"dimensions don't match\"\n",
    "    # check that acts1, acts2 are transposition\n",
    "    assert acts1.shape[0] < acts1.shape[1], \"input must be number of neurons by datapoints\"\n",
    "    # Directions are built from full_coef_*/full_invsqrt_* and the neuron means,\n",
    "    # which only exist when coefficients are computed (previously\n",
    "    # compute_coefs=False with compute_dirns=True raised a NameError).\n",
    "    compute_coefs = compute_coefs or compute_dirns\n",
    "    return_dict = {}\n",
    "\n",
    "    # compute covariance with numpy function for extra stability\n",
    "    numx = acts1.shape[0]\n",
    "    numy = acts2.shape[0]\n",
    "\n",
    "    covariance = np.cov(acts1, acts2)\n",
    "    sigmaxx = covariance[:numx, :numx]\n",
    "    sigmaxy = covariance[:numx, numx:]\n",
    "    sigmayx = covariance[numx:, :numx]\n",
    "    sigmayy = covariance[numx:, numx:]\n",
    "\n",
    "    # rescale covariance to make cca computation more stable\n",
    "    xmax = np.max(np.abs(sigmaxx))\n",
    "    ymax = np.max(np.abs(sigmayy))\n",
    "    sigmaxx /= xmax\n",
    "    sigmayy /= ymax\n",
    "    sigmaxy /= np.sqrt(xmax * ymax)\n",
    "    sigmayx /= np.sqrt(xmax * ymax)\n",
    "\n",
    "    ([u, s, v], invsqrt_xx, invsqrt_yy, x_idxs, y_idxs) = compute_ccas(\n",
    "        sigmaxx, sigmaxy, sigmayx, sigmayy, epsilon=epsilon, verbose=verbose\n",
    "    )\n",
    "\n",
    "    # if x_idxs or y_idxs is all false, return_dict has zero entries\n",
    "    if (not np.any(x_idxs)) or (not np.any(y_idxs)):\n",
    "        return create_zero_dict(compute_dirns, acts1.shape[1])\n",
    "\n",
    "    if compute_coefs:\n",
    "        # also compute full coefficients over all neurons\n",
    "        x_mask = np.dot(x_idxs.reshape((-1, 1)), x_idxs.reshape((1, -1)))\n",
    "        y_mask = np.dot(y_idxs.reshape((-1, 1)), y_idxs.reshape((1, -1)))\n",
    "\n",
    "        return_dict[\"coef_x\"] = u.T\n",
    "        return_dict[\"invsqrt_xx\"] = invsqrt_xx\n",
    "        return_dict[\"full_coef_x\"] = np.zeros((numx, numx))\n",
    "        np.place(return_dict[\"full_coef_x\"], x_mask, return_dict[\"coef_x\"])\n",
    "        return_dict[\"full_invsqrt_xx\"] = np.zeros((numx, numx))\n",
    "        np.place(return_dict[\"full_invsqrt_xx\"], x_mask, return_dict[\"invsqrt_xx\"])\n",
    "\n",
    "        return_dict[\"coef_y\"] = v\n",
    "        return_dict[\"invsqrt_yy\"] = invsqrt_yy\n",
    "        return_dict[\"full_coef_y\"] = np.zeros((numy, numy))\n",
    "        np.place(return_dict[\"full_coef_y\"], y_mask, return_dict[\"coef_y\"])\n",
    "        return_dict[\"full_invsqrt_yy\"] = np.zeros((numy, numy))\n",
    "        np.place(return_dict[\"full_invsqrt_yy\"], y_mask, return_dict[\"invsqrt_yy\"])\n",
    "\n",
    "        # compute means\n",
    "        neuron_means1 = np.mean(acts1, axis=1, keepdims=True)\n",
    "        neuron_means2 = np.mean(acts2, axis=1, keepdims=True)\n",
    "        return_dict[\"neuron_means1\"] = neuron_means1\n",
    "        return_dict[\"neuron_means2\"] = neuron_means2\n",
    "\n",
    "    if compute_dirns:\n",
    "        # orthonormal directions that are CCA directions\n",
    "        cca_dirns1 = (\n",
    "            np.dot(\n",
    "                np.dot(return_dict[\"full_coef_x\"], return_dict[\"full_invsqrt_xx\"]),\n",
    "                (acts1 - neuron_means1),\n",
    "            )\n",
    "            + neuron_means1\n",
    "        )\n",
    "        cca_dirns2 = (\n",
    "            np.dot(\n",
    "                np.dot(return_dict[\"full_coef_y\"], return_dict[\"full_invsqrt_yy\"]),\n",
    "                (acts2 - neuron_means2),\n",
    "            )\n",
    "            + neuron_means2\n",
    "        )\n",
    "\n",
    "    # get rid of trailing zeros in the cca coefficients\n",
    "    idx1 = sum_threshold(s, threshold)\n",
    "    idx2 = sum_threshold(s, threshold)\n",
    "\n",
    "    return_dict[\"cca_coef1\"] = s\n",
    "    return_dict[\"cca_coef2\"] = s\n",
    "    return_dict[\"x_idxs\"] = x_idxs\n",
    "    return_dict[\"y_idxs\"] = y_idxs\n",
    "    # summary statistics\n",
    "    return_dict[\"mean\"] = (np.mean(s[:idx1]), np.mean(s[:idx2]))\n",
    "    return_dict[\"sum\"] = (np.sum(s), np.sum(s))\n",
    "\n",
    "    if compute_dirns:\n",
    "        return_dict[\"cca_dirns1\"] = cca_dirns1\n",
    "        return_dict[\"cca_dirns2\"] = cca_dirns2\n",
    "\n",
    "    return return_dict\n",
    "\n",
    "\n",
    "def robust_cca_similarity(acts1, acts2, threshold=0.98, epsilon=1e-6, compute_dirns=True):\n",
    "    \"\"\"Calls get_cca_similarity multiple times while adding noise.\n",
    "\n",
    "    This function is very similar to get_cca_similarity, and can be used if\n",
    "    get_cca_similarity doesn't converge for some pair of inputs. On a\n",
    "    LinAlgError the activations are scaled by 0.1 and perturbed with Gaussian\n",
    "    noise, then the computation is retried (up to num_cca_trials attempts).\n",
    "\n",
    "    Args:\n",
    "              acts1: (num_neurons1, data_points) a 2d numpy array of neurons by\n",
    "                     datapoints where entry (i,j) is the output of neuron i on\n",
    "                     datapoint j.\n",
    "              acts2: (num_neurons2, data_points) same as above, but (potentially)\n",
    "                     for a different set of neurons. Note that acts1 and acts2\n",
    "                     can have different numbers of neurons, but must agree on the\n",
    "                     number of datapoints\n",
    "\n",
    "              threshold: float between 0, 1 used to get rid of trailing zeros in\n",
    "                         the cca correlation coefficients to output more accurate\n",
    "                         summary statistics of correlations.\n",
    "\n",
    "              epsilon: standard deviation of the Gaussian noise added to the\n",
    "                       activations before each retry (it is not forwarded to\n",
    "                       get_cca_similarity, whose own epsilon keeps its default).\n",
    "\n",
    "              compute_dirns: boolean value determining whether actual cca\n",
    "                             directions are computed. (For very large neurons and\n",
    "                             datasets, may be better to compute these on the fly\n",
    "                             instead of store in memory.)\n",
    "\n",
    "    Returns:\n",
    "              return_dict: the dict returned by the first successful\n",
    "                           get_cca_similarity call.\n",
    "\n",
    "    Raises:\n",
    "              np.linalg.LinAlgError: if every one of the num_cca_trials attempts fails.\n",
    "    \"\"\"\n",
    "\n",
    "    for trial in range(num_cca_trials):\n",
    "        try:\n",
    "            # Keywords matter: positionally, `threshold` landed in get_cca_similarity's\n",
    "            # `epsilon` slot and `compute_dirns` in its `threshold` slot. Returning here\n",
    "            # also stops the loop on success instead of recomputing every trial.\n",
    "            return get_cca_similarity(acts1, acts2, threshold=threshold, compute_dirns=compute_dirns)\n",
    "        except np.linalg.LinAlgError:\n",
    "            if trial + 1 == num_cca_trials:\n",
    "                raise\n",
    "            # Each noise draw uses its own array's shape (acts1/acts2 may differ in neurons).\n",
    "            acts1 = acts1 * 1e-1 + np.random.normal(size=acts1.shape) * epsilon\n",
    "            acts2 = acts2 * 1e-1 + np.random.normal(size=acts2.shape) * epsilon\n",
    "\n",
    "    raise ValueError(\"num_cca_trials must be at least 1\")\n",
    "    # End of copy from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/cca_core.py\n",
    "\n",
    "\n",
    "def top_k_pca_comps(singular_values, threshold=0.99):\n",
    "    total_variance = np.sum(singular_values**2)\n",
    "    explained_variance = (singular_values**2) / total_variance\n",
    "    cumulative_variance = np.cumsum(explained_variance)\n",
    "    return np.argmax(cumulative_variance >= threshold * total_variance) + 1\n",
    "\n",
    "\n",
    "def _svcca_original(acts1, acts2):\n",
    "    \"\"\"SVCCA: mean CCA coefficient between the top-PCA subspaces of two activation sets.\n",
    "\n",
    "    Args:\n",
    "        acts1: (num_neurons1, num_datapoints) activations.\n",
    "        acts2: (num_neurons2, num_datapoints) activations (same datapoints).\n",
    "\n",
    "    Returns:\n",
    "        Mean of \"cca_coef1\" from get_cca_similarity on the SVD-reduced activations.\n",
    "    \"\"\"\n",
    "    # Copy from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/tutorials/001_Introduction.ipynb\n",
    "    # Modification: get_cca_similarity is in the same file.\n",
    "    # Modification: top-k PCA component selection s.t. explained variance > 0.99 total variance\n",
    "\n",
    "    def _reduce(acts):\n",
    "        # Mean-subtract each neuron, then keep the top-k singular directions.\n",
    "        centered = acts - np.mean(acts, axis=1, keepdims=True)\n",
    "        _, sing_vals, right_vecs = np.linalg.svd(centered, full_matrices=False)\n",
    "        k = top_k_pca_comps(sing_vals)\n",
    "        # can also compute as np.dot(U.T[:k], centered)\n",
    "        return np.dot(sing_vals[:k] * np.eye(k), right_vecs[:k])\n",
    "\n",
    "    reduced1 = _reduce(acts1)\n",
    "    reduced2 = _reduce(acts2)\n",
    "\n",
    "    svcca_results = get_cca_similarity(reduced1, reduced2, epsilon=1e-10, verbose=False)\n",
    "    # End of copy from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/tutorials/001_Introduction.ipynb\n",
    "    return np.mean(svcca_results[\"cca_coef1\"])\n",
    "\n",
    "\n",
    "# Copied from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/pwcca.py\n",
    "# Modification: get_cca_similarity is in the same file.\n",
    "def compute_pwcca(acts1, acts2, epsilon=0.0):\n",
    "    \"\"\"Computes projection weighting for weighting CCA coefficients\n",
    "\n",
    "    Args:\n",
    "         acts1: 2d numpy array, shaped (neurons, num_datapoints)\n",
    "         acts2: 2d numpy array, shaped (neurons, num_datapoints)\n",
    "         epsilon: forwarded to get_cca_similarity (low-variance cutoff and\n",
    "                  diagonal ridge there)\n",
    "\n",
    "    Returns:\n",
    "         (pwcca, weights, coefs): the projection-weighted sum of the CCA\n",
    "         coefficients, the normalised projection weights, and the CCA\n",
    "         coefficients themselves.\n",
    "\n",
    "    \"\"\"\n",
    "    sresults = get_cca_similarity(\n",
    "        acts1,\n",
    "        acts2,\n",
    "        epsilon=epsilon,\n",
    "        compute_dirns=False,\n",
    "        compute_coefs=True,\n",
    "        verbose=False,\n",
    "    )\n",
    "    # Build the CCA directions on the side with fewer surviving neurons.\n",
    "    if np.sum(sresults[\"x_idxs\"]) <= np.sum(sresults[\"y_idxs\"]):\n",
    "        coef = sresults[\"coef_x\"]\n",
    "        means = sresults[\"neuron_means1\"]\n",
    "        coefs = sresults[\"cca_coef1\"]\n",
    "        acts = acts1\n",
    "        idxs = sresults[\"x_idxs\"]\n",
    "    else:\n",
    "        # Bug fix: this branch used to project acts1 with the y-side coefficients\n",
    "        # and means; it must use acts2, matching `acts = acts2` below.\n",
    "        coef = sresults[\"coef_y\"]\n",
    "        means = sresults[\"neuron_means2\"]\n",
    "        coefs = sresults[\"cca_coef2\"]\n",
    "        acts = acts2\n",
    "        idxs = sresults[\"y_idxs\"]\n",
    "    kept_means = means[idxs]\n",
    "    dirns = np.dot(coef, acts[idxs] - kept_means) + kept_means\n",
    "    P, _ = np.linalg.qr(dirns.T)\n",
    "    weights = np.sum(np.abs(np.dot(P.T, acts[idxs].T)), axis=1)\n",
    "    weights = weights / np.sum(weights)\n",
    "\n",
    "    return np.sum(weights * coefs), weights, coefs\n",
    "    # End of copy from https://github.com/google/svcca/blob/1f3fbf19bd31bd9b76e728ef75842aa1d9a4cd2b/pwcca.py\n",
    "\n",
    "\n",
    "##################################################################################\n",
    "\n",
    "from typing import Union  # noqa:e402\n",
    "\n",
    "import numpy.typing as npt  # noqa:e402\n",
    "import torch  # noqa:e402\n",
    "\n",
    "# from repsim.measures.utils import (\n",
    "#     SHAPE_TYPE,\n",
    "#     flatten,\n",
    "#     resize_wh_reps,\n",
    "#     to_numpy_if_needed,\n",
    "#     RepresentationalSimilarityMeasure,\n",
    "# )  # noqa:e402\n",
    "\n",
    "\n",
    "def svcca(\n",
    "    R: Union[torch.Tensor, npt.NDArray],\n",
    "    Rp: Union[torch.Tensor, npt.NDArray],\n",
    "    shape: SHAPE_TYPE,\n",
    ") -> float:\n",
    "    \"\"\"SVCCA similarity of two representations.\n",
    "\n",
    "    Both inputs go through `flatten` and `to_numpy_if_needed` (defined elsewhere)\n",
    "    and are transposed, because `_svcca_original` works on (neurons, datapoints)\n",
    "    activations. NOTE(review): assumes the flattened arrays are\n",
    "    (datapoints, neurons) — confirm against `flatten`.\n",
    "    \"\"\"\n",
    "    flat_R, flat_Rp = flatten(R, Rp, shape=shape)\n",
    "    np_R, np_Rp = to_numpy_if_needed(flat_R, flat_Rp)\n",
    "    return _svcca_original(np_R.T, np_Rp.T)\n",
    "\n",
    "\n",
    "def pwcca(\n",
    "    R: Union[torch.Tensor, npt.NDArray],\n",
    "    Rp: Union[torch.Tensor, npt.NDArray],\n",
    "    shape: SHAPE_TYPE,\n",
    ") -> float:\n",
    "    \"\"\"Projection-weighted CCA similarity of two representations.\n",
    "\n",
    "    Inputs are flattened and converted to numpy (helpers defined elsewhere), then\n",
    "    transposed into the (neurons, datapoints) layout compute_pwcca documents.\n",
    "    Only the weighted score (first element of compute_pwcca's result) is returned.\n",
    "    \"\"\"\n",
    "    flat_R, flat_Rp = flatten(R, Rp, shape=shape)\n",
    "    np_R, np_Rp = to_numpy_if_needed(flat_R, flat_Rp)\n",
    "    weighted_score, _weights, _coefs = compute_pwcca(np_R.T, np_Rp.T)\n",
    "    return weighted_score\n",
    "\n",
    "\n",
    "class SVCCA(RepresentationalSimilarityMeasure):\n",
    "    \"\"\"Singular Vector CCA similarity measure wrapping `svcca`.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        properties = dict(\n",
    "            larger_is_more_similar=True,\n",
    "            is_metric=False,\n",
    "            is_symmetric=True,\n",
    "            invariant_to_affine=False,\n",
    "            invariant_to_invertible_linear=False,\n",
    "            invariant_to_ortho=True,\n",
    "            invariant_to_permutation=True,\n",
    "            invariant_to_isotropic_scaling=True,\n",
    "            invariant_to_translation=True,\n",
    "        )\n",
    "        super().__init__(sim_func=svcca, **properties)\n",
    "\n",
    "    def __call__(self, R: torch.Tensor | npt.NDArray, Rp: torch.Tensor | npt.NDArray, shape: SHAPE_TYPE) -> float:\n",
    "        if shape != \"nchw\":\n",
    "            return self.sim_func(R, Rp, shape)\n",
    "        # Fold spatial positions into the sample dimension (the helper resamples via\n",
    "        # FFT when spatial sizes differ); the result is then plain \"nd\" data.\n",
    "        R, Rp = align_spatial_dimensions(R, Rp)\n",
    "        return self.sim_func(R, Rp, \"nd\")\n",
    "\n",
    "\n",
    "class PWCCA(RepresentationalSimilarityMeasure):\n",
    "    \"\"\"Projection-Weighted CCA similarity measure wrapping `pwcca` (not symmetric).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        properties = dict(\n",
    "            larger_is_more_similar=True,\n",
    "            is_metric=False,\n",
    "            is_symmetric=False,\n",
    "            invariant_to_affine=False,\n",
    "            invariant_to_invertible_linear=False,\n",
    "            invariant_to_ortho=False,\n",
    "            invariant_to_permutation=False,\n",
    "            invariant_to_isotropic_scaling=True,\n",
    "            invariant_to_translation=True,\n",
    "        )\n",
    "        super().__init__(sim_func=pwcca, **properties)\n",
    "\n",
    "    def __call__(self, R: torch.Tensor | npt.NDArray, Rp: torch.Tensor | npt.NDArray, shape: SHAPE_TYPE) -> float:\n",
    "        if shape != \"nchw\":\n",
    "            return self.sim_func(R, Rp, shape)\n",
    "        # Fold spatial positions into the sample dimension (the helper resamples via\n",
    "        # FFT when spatial sizes differ); the result is then plain \"nd\" data.\n",
    "        R, Rp = align_spatial_dimensions(R, Rp)\n",
    "        return self.sim_func(R, Rp, \"nd\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aQ_3rxDtd3Mc"
   },
   "source": [
    "## get rand"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BqIiTtkid4qA"
   },
   "outputs": [],
   "source": [
    "def score_rand(num_runs, weight_matrix_np, weight_matrix_2, num_feats, sim_fn, shapereq_bool):\n",
    "    all_rand_scores = []\n",
    "    i = 0\n",
    "    # for i in range(num_runs):\n",
    "    while i < num_runs:\n",
    "        try:\n",
    "            rand_modA_feats = np.random.choice(range(weight_matrix_np.shape[0]), size=num_feats, replace=False).tolist()\n",
    "            rand_modB_feats = np.random.choice(range(weight_matrix_2.shape[0]), size=num_feats, replace=False).tolist()\n",
    "\n",
    "            if shapereq_bool:\n",
    "                score = sim_fn(weight_matrix_np[rand_modA_feats], weight_matrix_2[rand_modB_feats], \"nd\")\n",
    "            else:\n",
    "                score = sim_fn(weight_matrix_np[rand_modA_feats], weight_matrix_2[rand_modB_feats])\n",
    "            all_rand_scores.append(score)\n",
    "            i += 1\n",
    "        except:\n",
    "            continue\n",
    "    return sum(all_rand_scores) / len(all_rand_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dnIhrqIsNv6e"
   },
   "outputs": [],
   "source": [
    "def score_rand_corr(num_runs, weight_matrix_np, weight_matrix_2, num_feats, highest_correlations_indices_AB, sim_fn, shapereq_bool):\n",
    "    all_rand_scores = []\n",
    "    i = 0\n",
    "    # for i in range(num_runs):\n",
    "    while i < num_runs:\n",
    "        try:\n",
    "            rand_modB_feats = np.random.choice(range(weight_matrix_2.shape[0]), size=num_feats, replace=False).tolist()\n",
    "            rand_modA_feats = [highest_correlations_indices_AB[index] for index in rand_modB_feats]\n",
    "\n",
    "            if shapereq_bool:\n",
    "                score = sim_fn(weight_matrix_np[rand_modA_feats], weight_matrix_2[rand_modB_feats], \"nd\")\n",
    "            else:\n",
    "                score = sim_fn(weight_matrix_np[rand_modA_feats], weight_matrix_2[rand_modB_feats])\n",
    "            all_rand_scores.append(score)\n",
    "            i += 1\n",
    "        except:\n",
    "            continue\n",
    "    # print(sum(all_rand_scores) / len(all_rand_scores))\n",
    "    # plt.hist(all_rand_scores)\n",
    "    # plt.show()\n",
    "    return all_rand_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eC82qzWXPEFv"
   },
   "outputs": [],
   "source": [
    "import random\n",
    "def shuffle_rand(num_runs, weight_matrix_np, weight_matrix_2, num_feats, sim_fn, shapereq_bool):\n",
    "    \"\"\"Permutation null: score X against row-shuffled copies of Y.\n",
    "\n",
    "    For each of `num_runs` runs, draws a fresh permutation of range(num_feats) with\n",
    "    `random.shuffle` and computes sim_fn(weight_matrix_np, weight_matrix_2[perm]),\n",
    "    adding a trailing \"nd\" argument when `shapereq_bool` is True.\n",
    "\n",
    "    Returns:\n",
    "        list of the `num_runs` scores, in run order.\n",
    "    \"\"\"\n",
    "    extra_args = (\"nd\",) if shapereq_bool else ()\n",
    "\n",
    "    def score_one_permutation():\n",
    "        perm = list(range(num_feats))\n",
    "        random.shuffle(perm)\n",
    "        return sim_fn(weight_matrix_np, weight_matrix_2[perm], *extra_args)\n",
    "\n",
    "    return [score_one_permutation() for _ in range(num_runs)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PK8NhdojHdUW"
   },
   "source": [
    "## Interpretation functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XN2_7D7ADn8H"
   },
   "outputs": [],
   "source": [
    "def highest_activating_tokens(\n",
    "    feature_acts,\n",
    "    feature_idx: int,\n",
    "    k: int = 10,  # num batch_seq samples\n",
    "    batch_tokens=None\n",
    "): # -> Tuple[Int[Tensor, \"k 2\"], Float[Tensor, \"k\"]]:\n",
    "    '''\n",
    "    Return the (batch, seq) positions and values of the k largest activations of one feature.\n",
    "\n",
    "    Args:\n",
    "        feature_acts: 3-D activations indexed as [batch, seq, feature].\n",
    "        feature_idx: index of the feature (last axis) to rank.\n",
    "        k: number of positions to return; clamped to batch * seq so that small\n",
    "            inputs do not make `topk` raise.\n",
    "        batch_tokens: optional (batch, seq) token ids. The sequence length is now read\n",
    "            from `feature_acts`, so this may be None (the old code crashed on the\n",
    "            default); when given, it is only checked for a matching shape.\n",
    "\n",
    "    Returns:\n",
    "        (indices, values): `indices` is a (k, 2) tensor of [batch_idx, seq_idx] rows and\n",
    "        `values` the corresponding activations, largest first.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `batch_tokens` is given and its (batch, seq) shape differs from\n",
    "            that of `feature_acts` (the flat-to-2D index mapping would be wrong).\n",
    "    '''\n",
    "    n_batch, seq_len = feature_acts.shape[0], feature_acts.shape[1]\n",
    "    if batch_tokens is not None and tuple(batch_tokens.shape[:2]) != (n_batch, seq_len):\n",
    "        raise ValueError(\n",
    "            f\"batch_tokens shape {tuple(batch_tokens.shape)} does not match \"\n",
    "            f\"feature_acts (batch, seq) = {(n_batch, seq_len)}\"\n",
    "        )\n",
    "\n",
    "    # Flatten (batch, seq) so the top-k ranges over every token position at once,\n",
    "    # whichever sequence it belongs to.\n",
    "    flattened_feature_acts = feature_acts[:, :, feature_idx].reshape(-1)\n",
    "    k = min(k, flattened_feature_acts.numel())\n",
    "\n",
    "    top_acts_values, top_acts_indices = flattened_feature_acts.topk(k)\n",
    "    # Map flat positions back to (batch, seq) coordinates.\n",
    "    top_acts_batch = top_acts_indices // seq_len\n",
    "    top_acts_seq = top_acts_indices % seq_len\n",
    "\n",
    "    return torch.stack([top_acts_batch, top_acts_seq], dim=-1), top_acts_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Gr17XzJIhVUM"
   },
   "outputs": [],
   "source": [
    "def store_top_toks(top_acts_indices, top_acts_values, batch_tokens):\n",
    "    \"\"\"Decode the single token at each (batch_idx, seq_idx) position.\n",
    "\n",
    "    Uses the notebook-global `tokenizer`; newlines are rendered as a literal\n",
    "    backslash-n and \"<|BOS|>\" is shortened to \"|BOS|\". `top_acts_values` only\n",
    "    takes part in the zip that pairs it with the positions.\n",
    "\n",
    "    Returns:\n",
    "        list of decoded token strings, one per position.\n",
    "    \"\"\"\n",
    "    def decode_at(position):\n",
    "        batch_idx, seq_idx = position\n",
    "        text = tokenizer.decode(batch_tokens[batch_idx, seq_idx])\n",
    "        return text.replace(\"\\n\", \"\\\\n\").replace(\"<|BOS|>\", \"|BOS|\")\n",
    "\n",
    "    return [decode_at(position) for position, _ in zip(top_acts_indices, top_acts_values)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FEbmjLd7c-Ju"
   },
   "outputs": [],
   "source": [
    "def find_indices_with_keyword(fList, keyword):\n",
    "    \"\"\"\n",
    "    Find the indices of fList whose token list contains `keyword` as a whole token.\n",
    "\n",
    "    A token matches when, with all spaces removed and lower-cased, it equals\n",
    "    `keyword.lower()` exactly (so \" Day\" matches \"day\" but \"today\" does not).\n",
    "\n",
    "    Args:\n",
    "    fList (list of list of str): per-feature token strings (presumably the output of\n",
    "        store_top_toks for each feature).\n",
    "    keyword (str): keyword to match, case-insensitively.\n",
    "\n",
    "    Returns:\n",
    "    list of int: matching indices in ascending order, each listed once. (The previous\n",
    "    version appended an index once per matching token, producing duplicates.)\n",
    "    \"\"\"\n",
    "    target = keyword.lower()\n",
    "    return [\n",
    "        index\n",
    "        for index, split_list in enumerate(fList)\n",
    "        if any(tok.replace(' ', '').lower() == target for tok in split_list)\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tn42TylkNkfx"
   },
   "outputs": [],
   "source": [
    "from rich import print as rprint\n",
    "def display_top_sequences(top_acts_indices, top_acts_values, batch_tokens):\n",
    "    \"\"\"Print a short token window around each top activation, highlighting the token.\n",
    "\n",
    "    For each (batch_idx, seq_idx) position, decodes tokens seq_idx-5 .. seq_idx+4\n",
    "    (clipped to the sequence) with the notebook-global `tokenizer` and prints one line\n",
    "    per position through `rprint` (rich), the activating token in bold orange.\n",
    "\n",
    "    Decoded text is passed through `rich.markup.escape`, so brackets in the data\n",
    "    (e.g. \"[edit]\" or \"[/\") are shown literally instead of being parsed as rich\n",
    "    markup, which silently dropped text or raised MarkupError.\n",
    "    \"\"\"\n",
    "    from rich.markup import escape\n",
    "\n",
    "    s = \"\"\n",
    "    for (batch_idx, seq_idx), value in zip(top_acts_indices, top_acts_values):\n",
    "        batch_idx, seq_idx = int(batch_idx), int(seq_idx)\n",
    "        s += f'batchID: {batch_idx}, '\n",
    "        # Window of up to 5 tokens before and 4 after the activating token.\n",
    "        seq_start = max(seq_idx - 5, 0)\n",
    "        seq_end = min(seq_idx + 5, batch_tokens.shape[1])\n",
    "        seq = \"\"\n",
    "        for i in range(seq_start, seq_end):\n",
    "            new_str_token = tokenizer.decode([batch_tokens[batch_idx, i].item()]).replace(\"\\n\", \"\\\\n\").replace(\"<|BOS|>\", \"|BOS|\")\n",
    "            new_str_token = escape(new_str_token)\n",
    "            if i == seq_idx:\n",
    "                new_str_token = f\"[bold u dark_orange]{new_str_token}[/]\"\n",
    "            seq += new_str_token\n",
    "        s += f'Act = {float(value):.2f}, Seq = \"{seq}\"\\n'\n",
    "\n",
    "    rprint(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AsAw_EAwGwVc"
   },
   "outputs": [],
   "source": [
    "def store_top_seqs(top_acts_indices, top_acts_values, batch_tokens):\n",
    "    \"\"\"Collect a short decoded context window and the activating token per position.\n",
    "\n",
    "    For each (batch_idx, seq_idx), tokens seq_idx-2 .. seq_idx+1 (clipped to the\n",
    "    sequence) are decoded one at a time with the notebook-global `tokenizer`;\n",
    "    newlines become a literal backslash-n and \"<|BOS|>\" becomes \"|BOS|\".\n",
    "\n",
    "    Returns:\n",
    "        list of (window_text, activating_token_text) tuples, one per position.\n",
    "        `top_acts_values` only takes part in the zip that pairs it with the positions.\n",
    "    \"\"\"\n",
    "    def render(row, col):\n",
    "        piece = tokenizer.decode([batch_tokens[row, col].item()])\n",
    "        return piece.replace(\"\\n\", \"\\\\n\").replace(\"<|BOS|>\", \"|BOS|\")\n",
    "\n",
    "    samples = []\n",
    "    for (batch_idx, seq_idx), _ in zip(top_acts_indices, top_acts_values):\n",
    "        lo = max(seq_idx - 2, 0)\n",
    "        hi = min(seq_idx + 2, batch_tokens.shape[1])\n",
    "        pieces = []\n",
    "        for pos in range(lo, hi):\n",
    "            piece = render(batch_idx, pos)\n",
    "            if pos == seq_idx:\n",
    "                active_piece = piece\n",
    "            pieces.append(piece)\n",
    "        samples.append((\"\".join(pieces), active_piece))\n",
    "    return samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "U8ixT9tOHd3F"
   },
   "outputs": [],
   "source": [
    "def find_indices_with_keyword_bySeqs(fList_seqs, keyword):\n",
    "    \"\"\"Return indices of features whose top examples show `keyword` as a whole word.\n",
    "\n",
    "    fList_seqs[i] is a list of (window_text, activating_token) pairs, as built by\n",
    "    store_top_seqs. Feature i is selected when, for at least one pair,\n",
    "      * the activating token (spaces removed, lower-cased) equals the keyword, and\n",
    "      * some space-separated word of the window, after stripping . , ? ! and literal\n",
    "        backslash-n sequences, equals the keyword case-insensitively.\n",
    "    The second test rejects sub-word hits such as the token \"day\" inside \"today\".\n",
    "\n",
    "    Returns:\n",
    "        list of int: each matching feature index once, in ascending order.\n",
    "    \"\"\"\n",
    "    target = keyword.lower()\n",
    "\n",
    "    def window_has_word(seq):\n",
    "        for word in seq.split(' '):\n",
    "            cleaned = word.replace('.', '').replace(',', '').replace('?', '').replace('!', '').replace('\\\\n', '')\n",
    "            if cleaned.lower() == target:\n",
    "                return True\n",
    "        return False\n",
    "\n",
    "    def feature_matches(examples):\n",
    "        return any(\n",
    "            example[1].replace(' ', '').lower() == target and window_has_word(example[0])\n",
    "            for example in examples\n",
    "        )\n",
    "\n",
    "    return [feat_ind for feat_ind, examples in enumerate(fList_seqs) if feature_matches(examples)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pyfIXtTAOrAz"
   },
   "source": [
    "## get concept space features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eBMmyu0dWdgT"
   },
   "outputs": [],
   "source": [
    "def get_mixed_feats(fList_model_B, corr_inds, keywords):\n",
    "    \"\"\"Pair keyword-matched model-B features with their correlated model-A features.\n",
    "\n",
    "    Keywords are processed in order; for each model-B feature b found by\n",
    "    find_indices_with_keyword, the pair (corr_inds[b], b) is accepted only if neither\n",
    "    feature was accepted before, giving a one-to-one, index-aligned pairing.\n",
    "    Prints the number of unique model-A and model-B features.\n",
    "\n",
    "    Returns:\n",
    "        (mixed_modA_feats, mixed_modB_feats)\n",
    "    \"\"\"\n",
    "    candidate_pairs = (\n",
    "        (corr_inds[feat_b], feat_b)\n",
    "        for kw in keywords\n",
    "        for feat_b in find_indices_with_keyword(fList_model_B, kw)\n",
    "    )\n",
    "\n",
    "    mixed_modA_feats, mixed_modB_feats = [], []\n",
    "    used_a, used_b = set(), set()\n",
    "    for feat_a, feat_b in candidate_pairs:\n",
    "        if feat_a in used_a or feat_b in used_b:\n",
    "            continue  # keep the mapping one-to-one\n",
    "        used_a.add(feat_a)\n",
    "        used_b.add(feat_b)\n",
    "        mixed_modA_feats.append(feat_a)\n",
    "        mixed_modB_feats.append(feat_b)\n",
    "\n",
    "    print(\"Unique modA feats: \", len(mixed_modA_feats))\n",
    "    print(\"Unique modB feats: \", len(mixed_modB_feats))\n",
    "    return mixed_modA_feats, mixed_modB_feats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "y9-qjLEdeixo"
   },
   "outputs": [],
   "source": [
    "def get_mixed_feats_with_kwList(fList_model_B, corr_inds, keywords):\n",
    "    \"\"\"One-to-one (model-A, model-B) feature pairs for a keyword list, with per-keyword counts.\n",
    "\n",
    "    For each keyword in order, every model-B feature b returned by\n",
    "    find_indices_with_keyword is paired with corr_inds[b]; the pair is accepted only\n",
    "    if neither feature was accepted earlier, and is credited to that keyword.\n",
    "    Prints the number of unique model-A and model-B features.\n",
    "\n",
    "    Returns:\n",
    "        (mixed_modA_feats, mixed_modB_feats, keywords_to_feats): index-aligned feature\n",
    "        lists and a {keyword: number of accepted pairs} dict (repeated keywords share\n",
    "        one entry).\n",
    "    \"\"\"\n",
    "    accepted = []\n",
    "    used_a, used_b = set(), set()\n",
    "    kw_counts = dict.fromkeys(keywords, 0)\n",
    "\n",
    "    for kw in keywords:\n",
    "        for feat_b in find_indices_with_keyword(fList_model_B, kw):\n",
    "            feat_a = corr_inds[feat_b]\n",
    "            if feat_a in used_a or feat_b in used_b:\n",
    "                continue\n",
    "            accepted.append((feat_a, feat_b))\n",
    "            used_a.add(feat_a)\n",
    "            used_b.add(feat_b)\n",
    "            kw_counts[kw] += 1\n",
    "\n",
    "    mixed_modA_feats = [feat_a for feat_a, _ in accepted]\n",
    "    mixed_modB_feats = [feat_b for _, feat_b in accepted]\n",
    "    print(\"Unique modA feats: \", len(mixed_modA_feats))\n",
    "    print(\"Unique modB feats: \", len(mixed_modB_feats))\n",
    "    return mixed_modA_feats, mixed_modB_feats, kw_counts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AOif8cM_2zy5"
   },
   "source": [
    "## LLM activation functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QKwsR6-N21nY"
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "def get_llm_actvs_batch(model, inputs, layerID, batch_size=32):\n",
    "    \"\"\"Run `model` over `inputs` in mini-batches and return one entry of its hidden states.\n",
    "\n",
    "    Args:\n",
    "        model: causal LM called as model(input_ids=..., attention_mask=...,\n",
    "            output_hidden_states=True); must expose `.device` and return an object\n",
    "            with a `hidden_states` sequence.\n",
    "        inputs: dict with 'input_ids' and 'attention_mask' tensors sharing dim 0.\n",
    "        layerID: index into `outputs.hidden_states`.\n",
    "            NOTE(review): in Hugging Face models hidden_states[0] is the embedding\n",
    "            output, so hidden_states[i] is the output of decoder block i-1. Callers\n",
    "            pair this with SAEs named `blocks.{i}.hook_resid_post` / `layer_{i}`;\n",
    "            confirm this is not off by one.\n",
    "        batch_size: rows per forward pass (default 32, the previous fixed value).\n",
    "\n",
    "    Returns:\n",
    "        The selected hidden states for all rows, concatenated along dim 0 on the\n",
    "        model's device, or None if `inputs` has no rows.\n",
    "    \"\"\"\n",
    "    dataset = TensorDataset(inputs['input_ids'], inputs['attention_mask'])\n",
    "    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "    # Gather per-batch chunks and concatenate once; calling torch.cat inside the loop\n",
    "    # re-copied everything collected so far on every batch (quadratic copying).\n",
    "    layer_chunks = []\n",
    "    for input_ids, attention_mask in loader:\n",
    "        batch_inputs = {'input_ids': input_ids.to(model.device), 'attention_mask': attention_mask.to(model.device)}\n",
    "        with torch.no_grad():  # Disable gradient calculation for memory efficiency\n",
    "            outputs = model(**batch_inputs, output_hidden_states=True)\n",
    "            layer_chunks.append(outputs.hidden_states[layerID])\n",
    "\n",
    "        # Release the other layers' hidden states before the next forward pass.\n",
    "        del batch_inputs, outputs\n",
    "        torch.cuda.empty_cache()\n",
    "        gc.collect()\n",
    "\n",
    "    if not layer_chunks:\n",
    "        return None\n",
    "    return torch.cat(layer_chunks, dim=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dmai8v8hgqr7"
   },
   "source": [
    "## SAE weight and activation functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uzs6yngH-Chj"
   },
   "outputs": [],
   "source": [
    "def get_weights_and_acts(layer_id, outputs):\n",
    "    \"\"\"Load the Gemma-1 residual-stream SAE for `layer_id` and encode `outputs` with it.\n",
    "\n",
    "    Args:\n",
    "        layer_id: layer whose `blocks.{layer_id}.hook_resid_post` SAE is loaded from\n",
    "            the \"gemma-2b-res-jb\" release.\n",
    "        outputs: activations to encode; the first two axes are flattened below, so a\n",
    "            3-D (batch, seq, d_model) tensor is presumably expected.\n",
    "\n",
    "    Returns:\n",
    "        (weight_matrix, reshaped_activations): the SAE's `W_dec` as a NumPy array, and\n",
    "        the SAE feature activations reshaped to (batch * seq, n_features) on the CPU.\n",
    "    \"\"\"\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "    sae, cfg_dict, sparsity = SAE.from_pretrained(\n",
    "        release = \"gemma-2b-res-jb\",\n",
    "        sae_id = f\"blocks.{layer_id}.hook_resid_post\",\n",
    "    )\n",
    "    # Use the detected device; it was computed but ignored in favour of a hard-coded 'cuda'.\n",
    "    sae = sae.to(device)\n",
    "    sae.eval()  # inference only\n",
    "\n",
    "    weight_matrix = sae.W_dec.cpu().detach().numpy()\n",
    "\n",
    "    with torch.inference_mode():\n",
    "        orig = sae.encode(outputs.to(device))\n",
    "\n",
    "    first_dim_reshaped = orig.shape[0] * orig.shape[1]\n",
    "    reshaped_activations = orig.reshape(first_dim_reshaped, orig.shape[-1]).cpu()\n",
    "\n",
    "    return weight_matrix, reshaped_activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EAf3Sz_6nle6"
   },
   "outputs": [],
   "source": [
    "def count_zero_columns(tensor):\n",
    "    \"\"\"Count the columns (axis-0 slices) of `tensor` whose entries are all zero.\n",
    "\n",
    "    Returns:\n",
    "        (n_zero, zero_indices): the count and a NumPy array of those column indices.\n",
    "    \"\"\"\n",
    "    column_is_zero = np.all(tensor == 0, axis=0)\n",
    "    zero_indices = np.where(column_is_zero)[0]\n",
    "    n_zero = np.sum(column_is_zero)\n",
    "    return n_zero, zero_indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lBgRhpf9cLjy"
   },
   "source": [
    "## Experiment runner functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "s_0rkWR6zE9d"
   },
   "outputs": [],
   "source": [
    "def run_semantic_subspace_expms(fList_model_B_seqs, model_A_layer_labels, inputs, layers=None, num_runs=1000):\n",
    "    \"\"\"Compare keyword-defined semantic subspaces of Gemma-1 SAEs with one Gemma-2 SAE.\n",
    "\n",
    "    For each Gemma-1 layer and each concept in the global `keywords` dict:\n",
    "      1. select the model-A and model-B SAE features whose top examples contain one of\n",
    "         the concept's keywords (find_indices_with_keyword_bySeqs);\n",
    "      2. map each selected model-B feature to a selected model-A feature with\n",
    "         `batched_correlation`, keeping only the first model-B feature per model-A\n",
    "         feature so the pairing is one-to-one;\n",
    "      3. score the paired decoder-weight rows with every metric (SVCCA, RSA) and\n",
    "         compare against `num_runs` row-shuffled pairings (shuffle_rand).\n",
    "\n",
    "    NOTE(review): besides its arguments this reads the notebook globals `model`,\n",
    "    `keywords`, `reshaped_activations_B`, `weight_matrix_2`, `svcca`,\n",
    "    `representational_similarity_analysis` and `batched_correlation`.\n",
    "\n",
    "    Args:\n",
    "        fList_model_B_seqs: per-feature (window, token) examples for the model-B SAE.\n",
    "        model_A_layer_labels: {layer_id: per-feature (window, token) examples} for model A.\n",
    "        inputs: tokenized batch passed to get_llm_actvs_batch.\n",
    "        layers: Gemma-1 layers to evaluate (default [6, 10, 12, 17]).\n",
    "        num_runs: shuffled pairings per null distribution (default 1000; must be >= 1).\n",
    "\n",
    "    Returns:\n",
    "        {layer_id: {concept_name: scores_dict}}. scores_dict holds 'num_unq_pairs' and,\n",
    "        per metric and map type, 'paired_*', 'rand_shuff_mean_*' and 'rand_shuff_pval_*'\n",
    "        (fraction of shuffled scores >= the paired score). Values are None when the\n",
    "        concept has too few features or the metric raised.\n",
    "    \"\"\"\n",
    "    if num_runs < 1:\n",
    "        raise ValueError(f\"num_runs must be >= 1, got {num_runs}\")\n",
    "    metric_dict = {'SVCCA': svcca, 'RSA': representational_similarity_analysis}\n",
    "    map_type_list = ['1To1']  # 'manyTo1' is not implemented\n",
    "    gemma1_layers = [6, 10, 12, 17] if layers is None else list(layers)\n",
    "    layer_to_dictscores = {}\n",
    "\n",
    "    for layerID_2 in gemma1_layers:\n",
    "        print(\"Layer: \" + str(layerID_2))\n",
    "\n",
    "        with torch.inference_mode():\n",
    "            outputs = get_llm_actvs_batch(model, inputs, layerID_2)\n",
    "        weight_matrix_np, reshaped_activations_A = get_weights_and_acts(layerID_2, outputs)\n",
    "\n",
    "        # Model-A labels depend only on the layer, not on the concept.\n",
    "        fList_model_A_seqs = model_A_layer_labels[layerID_2]\n",
    "\n",
    "        all_scores_dict = {}\n",
    "        for concept_name, concept_keywords in keywords.items():\n",
    "            print(concept_name)\n",
    "            all_scores_dict[concept_name] = _score_concept(\n",
    "                concept_name, concept_keywords, fList_model_A_seqs, fList_model_B_seqs,\n",
    "                weight_matrix_np, reshaped_activations_A, metric_dict, map_type_list, num_runs,\n",
    "            )\n",
    "\n",
    "        for key, value in all_scores_dict.items():\n",
    "            print(key + \": \" + str(value))\n",
    "        print(\"\\n\")\n",
    "\n",
    "        layer_to_dictscores[layerID_2] = all_scores_dict\n",
    "    return layer_to_dictscores\n",
    "\n",
    "\n",
    "def _empty_concept_scores(metric_dict, map_type_list):\n",
    "    \"\"\"Scores dict with every entry None, used when a concept cannot be scored.\"\"\"\n",
    "    scores_dict = {'num_unq_pairs': None}\n",
    "    for map_type in map_type_list:\n",
    "        for metric_name in metric_dict:\n",
    "            scores_dict[f'paired_{metric_name}_{map_type}'] = None\n",
    "            scores_dict[f'rand_shuff_mean_{metric_name}_{map_type}'] = None\n",
    "            scores_dict[f'rand_shuff_pval_{metric_name}_{map_type}'] = None\n",
    "    return scores_dict\n",
    "\n",
    "\n",
    "def _score_concept(concept_name, concept_keywords, fList_model_A_seqs, fList_model_B_seqs,\n",
    "                   weight_matrix_np, reshaped_activations_A, metric_dict, map_type_list, num_runs):\n",
    "    \"\"\"Compute the scores dict for one concept at one layer (see run_semantic_subspace_expms).\"\"\"\n",
    "    mixed_modA_feats = set()\n",
    "    mixed_modB_feats = set()\n",
    "    for kw in concept_keywords:\n",
    "        mixed_modB_feats.update(find_indices_with_keyword_bySeqs(fList_model_B_seqs, kw))\n",
    "        mixed_modA_feats.update(find_indices_with_keyword_bySeqs(fList_model_A_seqs, kw))\n",
    "    mixed_modA_feats = list(mixed_modA_feats)\n",
    "    mixed_modB_feats = list(mixed_modB_feats)\n",
    "\n",
    "    # Too few features on either side for a meaningful subspace comparison.\n",
    "    if len(mixed_modA_feats) <= 2 or len(mixed_modB_feats) <= 2:\n",
    "        return _empty_concept_scores(metric_dict, map_type_list)\n",
    "\n",
    "    # subset_inds[b] is the position in mixed_modA_feats matched to mixed_modB_feats[b].\n",
    "    subset_inds, subset_vals = batched_correlation(reshaped_activations_A[:, mixed_modA_feats],\n",
    "                                                   reshaped_activations_B[:, mixed_modB_feats])\n",
    "\n",
    "    # Keep only the first model-B feature for each model-A feature (one-to-one pairs);\n",
    "    # this is what the earlier Counter-based filter computed. int() also keeps the\n",
    "    # de-duplication correct if the indices come back as 0-d tensors (hashed by identity).\n",
    "    filt_corr_ind_A = []\n",
    "    filt_corr_ind_B = []\n",
    "    seen = set()\n",
    "    for ind_B, ind_A in enumerate(subset_inds):\n",
    "        ind_A = int(ind_A)\n",
    "        if ind_A not in seen:\n",
    "            seen.add(ind_A)\n",
    "            filt_corr_ind_A.append(ind_A)\n",
    "            filt_corr_ind_B.append(ind_B)\n",
    "\n",
    "    original_A_indices = [mixed_modA_feats[ind] for ind in filt_corr_ind_A]\n",
    "    original_B_indices = [mixed_modB_feats[ind] for ind in filt_corr_ind_B]\n",
    "    if len(original_A_indices) <= 2 or len(original_B_indices) <= 2:\n",
    "        return _empty_concept_scores(metric_dict, map_type_list)\n",
    "\n",
    "    scores_dict = {'num_unq_pairs': len(filt_corr_ind_A)}\n",
    "    for map_type in map_type_list:\n",
    "        if map_type == '1To1':\n",
    "            X_subset = weight_matrix_np[original_A_indices]\n",
    "            Y_subset = weight_matrix_2[original_B_indices]\n",
    "        else:\n",
    "            raise ValueError(f\"unsupported map_type: {map_type!r}\")\n",
    "\n",
    "        for metric_name, metric_fn in metric_dict.items():\n",
    "            suffix = f'{metric_name}_{map_type}'\n",
    "            try:\n",
    "                paired_score = metric_fn(X_subset, Y_subset, \"nd\")\n",
    "            except Exception as err:  # record and move on (a bare except also hid KeyboardInterrupt)\n",
    "                print(f\"  {metric_name} failed for {concept_name} ({map_type}): {err!r}\")\n",
    "                scores_dict[f'paired_{suffix}'] = None\n",
    "                scores_dict[f'rand_shuff_mean_{suffix}'] = None\n",
    "                scores_dict[f'rand_shuff_pval_{suffix}'] = None\n",
    "                continue\n",
    "            scores_dict[f'paired_{suffix}'] = paired_score\n",
    "\n",
    "            rand_shuff_scores = shuffle_rand(num_runs, X_subset, Y_subset, Y_subset.shape[0],\n",
    "                                             metric_fn, shapereq_bool=True)\n",
    "            scores_dict[f'rand_shuff_mean_{suffix}'] = sum(rand_shuff_scores) / len(rand_shuff_scores)\n",
    "            scores_dict[f'rand_shuff_pval_{suffix}'] = np.mean(np.array(rand_shuff_scores) >= paired_score)\n",
    "    return scores_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TAPKX8PFH4dj"
   },
   "source": [
    "# load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JBiY-knxK5mL"
   },
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "# One tokenizer (Gemma 2's) is used for both models.\n",
    "# NOTE(review): this assumes google/gemma-2b and google/gemma-2-2b share a vocabulary -- confirm.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2-2b\")\n",
    "# Pad with EOS; the batch is tokenized with padding=True below.\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fOkujtWFPXIj"
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "# Streaming: samples are read lazily from the iterator created in the next cell.\n",
    "dataset = load_dataset(\"Skylion007/openwebtext\", split=\"train\", streaming=True, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "30z1aF3MPfBz"
   },
   "outputs": [],
   "source": [
    "# Number of OpenWebText documents in the probe batch, and the max tokens kept per document.\n",
    "batch_size = 150\n",
    "maxseqlen = 150\n",
    "\n",
    "def get_next_batch(dataset_iter, batch_size=100):\n",
    "    \"\"\"Pull up to `batch_size` samples from `dataset_iter` and return their 'text' fields.\n",
    "\n",
    "    Returns a shorter list when the iterator runs out first.\n",
    "    \"\"\"\n",
    "    exhausted = object()\n",
    "    texts = []\n",
    "    for _ in range(batch_size):\n",
    "        sample = next(dataset_iter, exhausted)\n",
    "        if sample is exhausted:\n",
    "            break\n",
    "        texts.append(sample['text'])\n",
    "    return texts\n",
    "\n",
    "dataset_iter = iter(dataset)\n",
    "batch = get_next_batch(dataset_iter, batch_size)\n",
    "\n",
    "# Tokenize once; these same `inputs` feed both models and all SAEs below, so token\n",
    "# positions line up across models.\n",
    "inputs = tokenizer(batch, return_tensors=\"pt\", padding=True, truncation=True, max_length=maxseqlen)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "P4H47kPJzWbI"
   },
   "source": [
    "# load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tFZp3gdZzXMw"
   },
   "outputs": [],
   "source": [
    "# Model A: Gemma 1 (google/gemma-2b). Model B: Gemma 2 (google/gemma-2-2b). Their SAE features are compared below.\n",
    "model = AutoModelForCausalLM.from_pretrained(\"google/gemma-2b\")\n",
    "model_2 = AutoModelForCausalLM.from_pretrained(\"google/gemma-2-2b\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "j1By-RK81bPN"
   },
   "outputs": [],
   "source": [
    "# The shared inputs and both models live on the GPU from here on.\n",
    "inputs = {k: v.to('cuda') for k, v in inputs.items()}\n",
    "model = model.to('cuda')\n",
    "model_2 = model_2.to('cuda')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "K2u-VPinLNoT"
   },
   "source": [
    "# concept keywords"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AJaBJ3ndNDO5"
   },
   "outputs": [],
   "source": [
    "# Concept name -> keyword list. Matching against SAE feature examples is case-insensitive\n",
    "# (see find_indices_with_keyword_bySeqs). Keep each list duplicate-free: a repeated\n",
    "# keyword only repeats the same feature search without selecting anything new.\n",
    "keywords = {}\n",
    "\n",
    "keywords['Time'] = [\n",
    "    \"day\", \"night\", \"week\", \"month\", \"year\", \"hour\", \"minute\", \"second\", \"now\", \"soon\",\n",
    "    \"later\", \"early\", \"late\", \"morning\", \"evening\", \"noon\", \"midnight\", \"dawn\", \"dusk\", \"past\",\n",
    "    \"present\", \"future\", \"before\", \"after\", \"yesterday\", \"today\", \"tomorrow\", \"next\", \"previous\",\n",
    "    \"instant\", \"era\", \"age\", \"decade\", \"century\", \"millennium\",\n",
    "    \"moment\", \"pause\", \"wait\", \"begin\", \"start\", \"end\", \"finish\", \"stop\", \"continue\",\n",
    "    \"forever\", \"constant\", \"frequent\",\n",
    "    \"occasion\", \"season\", \"spring\", \"summer\", \"autumn\", \"fall\", \"winter\", \"anniversary\", \"deadline\", \"schedule\",\n",
    "    \"calendar\", \"clock\", \"duration\", \"interval\", \"epoch\", \"generation\", \"period\", \"cycle\", \"timespan\",\n",
    "    \"shift\", \"quarter\", \"term\", \"phase\", \"lifetime\", \"timeline\", \"delay\",\n",
    "    \"prompt\", \"timely\", \"recurrent\", \"daily\", \"weekly\", \"monthly\", \"yearly\", \"annual\", \"biweekly\", \"timeframe\"\n",
    "]\n",
    "\n",
    "keywords['Calendar'] = [\n",
    "    \"day\", \"night\", \"week\", \"month\", \"year\", \"hour\", \"minute\", \"second\",\n",
    "    \"morning\", \"evening\", \"noon\", \"midnight\", \"dawn\", \"dusk\",\n",
    "    \"yesterday\", \"today\", \"tomorrow\",\n",
    "    \"decade\", \"century\", \"millennium\",\n",
    "    \"season\", \"spring\", \"summer\", \"autumn\", \"fall\", \"winter\",\n",
    "    \"calendar\", \"clock\",\n",
    "    \"daily\", \"weekly\", \"monthly\", \"yearly\", \"annual\", \"biweekly\", \"timeframe\"\n",
    "]\n",
    "\n",
    "keywords['People/Roles'] = [\n",
    "                \"man\", \"girl\", \"boy\", \"kid\", \"dad\", \"mom\", \"son\", \"sis\", \"bro\",\n",
    "                \"chief\", \"priest\", \"king\", \"queen\", \"duke\", \"lord\", \"friend\", \"clerk\", \"coach\",\n",
    "                \"nurse\", \"doc\", \"maid\", \"clown\", \"guest\", \"peer\",\n",
    "                \"punk\", \"nerd\", \"jock\"\n",
    "]\n",
    "\n",
    "keywords['Nature'] = [\n",
    "    \"tree\", \"grass\", \"stone\", \"rock\", \"cliff\", \"hill\",\n",
    "    \"dirt\", \"sand\", \"mud\", \"wind\", \"storm\", \"rain\", \"cloud\", \"sun\",\n",
    "    \"moon\", \"leaf\", \"branch\", \"twig\", \"root\", \"bark\", \"seed\",\n",
    "    \"tide\", \"lake\", \"pond\", \"creek\", \"sea\", \"wood\", \"field\",\n",
    "    \"shore\", \"snow\", \"ice\", \"flame\", \"fire\", \"fog\", \"dew\", \"hail\",\n",
    "    \"sky\", \"earth\", \"glade\", \"cave\", \"peak\", \"ridge\", \"dust\", \"air\",\n",
    "    \"mist\", \"heat\"\n",
    "]\n",
    "\n",
    "keywords['Emotions'] = [\n",
    "    \"joy\", \"glee\", \"pride\", \"grief\", \"fear\", \"hope\", \"love\", \"hate\", \"pain\", \"shame\",\n",
    "    \"bliss\", \"rage\", \"calm\", \"shock\", \"dread\", \"guilt\", \"peace\", \"trust\", \"scorn\", \"doubt\",\n",
    "    \"hurt\", \"wrath\", \"laugh\", \"cry\", \"smile\", \"frown\", \"gasp\", \"blush\", \"sigh\", \"grin\",\n",
    "    \"woe\", \"spite\", \"envy\", \"glow\", \"thrill\", \"mirth\", \"bored\", \"cheer\", \"charm\", \"grace\",\n",
    "    \"shy\", \"brave\", \"proud\", \"glad\", \"mad\", \"sad\", \"tense\", \"free\", \"kind\"\n",
    "]\n",
    "\n",
    "keywords['MonthNames'] = [\n",
    "    \"January\", \"February\", \"March\", \"April\", \"May\", \"June\", \"July\", \"August\", \"September\", \"October\", \"November\", \"December\"\n",
    "]\n",
    "\n",
    "keywords['Countries'] = [\n",
    "    \"USA\", \"Canada\", \"Brazil\", \"Mexico\", \"Germany\", \"France\", \"Italy\", \"Spain\", \"UK\", \"Australia\",\n",
    "    \"China\", \"Japan\", \"India\", \"Russia\", \"Korea\", \"Argentina\", \"Egypt\", \"Iran\", \"Turkey\"\n",
    "]\n",
    "\n",
    "keywords['Biology'] = [\n",
    "    \"gene\", \"DNA\", \"RNA\", \"virus\", \"bacteria\", \"fungus\",\n",
    "    \"brain\", \"lung\", \"blood\", \"bone\", \"skin\", \"muscle\", \"nerve\", \"vein\", \"organ\",\n",
    "    \"evolve\", \"enzyme\", \"protein\", \"lipid\", \"membrane\", \"antibody\", \"antigen\",\n",
    "    \"ligand\", \"substrate\", \"receptor\", \"cell\", \"chromosome\", \"nucleus\", \"cytoplasm\"\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aJDWTh4ixiiQ"
   },
   "source": [
    "# get labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aZVBjpOr35sX"
   },
   "source": [
    "## A layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "V14ywcWd35sY"
   },
   "outputs": [],
   "source": [
    "# For every Gemma-1 layer: encode the shared batch with that layer's SAE and record, for\n",
    "# each SAE feature, its top-`samp_m` (context window, activating token) pairs.\n",
    "model_A_layer_labels = {} # layerID : fList_model_A_seqs\n",
    "modeltype = 'gemma1'\n",
    "# NOTE(review): the same layer list is hard-coded in run_semantic_subspace_expms; keep them in sync.\n",
    "gemma1_layers = [6, 10, 12, 17]\n",
    "\n",
    "for layer_id in gemma1_layers:\n",
    "    with torch.inference_mode():\n",
    "        outputs = get_llm_actvs_batch(model, inputs, layer_id)\n",
    "\n",
    "    sae, cfg_dict, sparsity = SAE.from_pretrained(\n",
    "        release = \"gemma-2b-res-jb\",\n",
    "        sae_id = f\"blocks.{layer_id}.hook_resid_post\",\n",
    "    )\n",
    "    sae = sae.to('cuda')\n",
    "    sae.eval()  # the SAE is only used for inference here\n",
    "\n",
    "    # Not used further in this cell (left as a global after the loop).\n",
    "    weight_matrix_np = sae.W_dec.cpu().detach().numpy()\n",
    "\n",
    "    with torch.inference_mode():\n",
    "        # feature_acts_A = sae.pre_acts(outputs.hidden_states[layer_id].to('cuda'))\n",
    "        feature_acts_A = sae.encode(outputs)\n",
    "\n",
    "    # Flattened (batch*seq, n_features) CPU copy; also unused in this cell.\n",
    "    first_dim_reshaped = feature_acts_A.shape[0] * feature_acts_A.shape[1]\n",
    "    reshaped_activations_A = feature_acts_A.reshape(first_dim_reshaped, feature_acts_A.shape[-1]).cpu()\n",
    "\n",
    "    # store feature : lst of top strs\n",
    "    fList_model_A_seqs = []\n",
    "    samp_m = 5  # top examples kept per feature\n",
    "\n",
    "    for feature_idx in range(feature_acts_A.shape[-1]):\n",
    "        if feature_idx % 5000 == 0:\n",
    "            print('Feature: ', feature_idx)\n",
    "        ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_A, feature_idx, samp_m, batch_tokens= inputs['input_ids'])\n",
    "        fList_model_A_seqs.append(store_top_seqs(ds_top_acts_indices, ds_top_acts_values, inputs['input_ids']) )\n",
    "\n",
    "    model_A_layer_labels[layer_id] = fList_model_A_seqs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tMBq3iee35sY"
   },
   "outputs": [],
   "source": [
    "# Cache the model-A labels so the labelling loop above can be skipped in later sessions,\n",
    "# and download a copy from Colab.\n",
    "modeltype = 'gemma1'\n",
    "with open(f'fList_allLayers_{modeltype}.pkl', 'wb') as f:\n",
    "    pickle.dump(model_A_layer_labels, f)\n",
    "files.download(f'fList_allLayers_{modeltype}.pkl')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "X9njtqvcVFRR"
   },
   "source": [
    "### load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "e2x9jUmu35sY"
   },
   "outputs": [],
   "source": [
    "# Reload cached model-A labels instead of recomputing them.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load files you produced yourself.\n",
    "modeltype = 'gemma1'\n",
    "with open(f'fList_allLayers_{modeltype}.pkl', 'rb') as f:\n",
    "    model_A_layer_labels = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2jMKUm6K35sZ"
   },
   "outputs": [],
   "source": [
    "# Optional GPU clean-up, left disabled: run_semantic_subspace_expms still reads the\n",
    "# global `model`, so it must not be deleted before the experiments run.\n",
    "# del model\n",
    "# del feature_acts_A\n",
    "# torch.cuda.empty_cache()\n",
    "# gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Zn1WZHTTb7Ik"
   },
   "source": [
    "# Gemma2: L6"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qefncTyDb7Il"
   },
   "source": [
    "## get actvs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wmuPeyGdb7Il"
   },
   "outputs": [],
   "source": [
    "# Model-B (Gemma 2) layer to compare against, and the Gemma Scope SAE release to load for it.\n",
    "layer_id_2 = 6\n",
    "modeltype = 'gemma2'\n",
    "name = \"gemma-scope-2b-pt-res-canonical\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pVcirZtCb7Im"
   },
   "outputs": [],
   "source": [
    "# Model-B hidden states at `layer_id_2` for the shared batch.\n",
    "with torch.inference_mode():\n",
    "    outputs_2 = get_llm_actvs_batch(model_2, inputs, layer_id_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gA51p1qSb7Im"
   },
   "outputs": [],
   "source": [
    "# 16k-wide canonical Gemma Scope SAE for this layer. `weight_matrix_2` (rows indexed by\n",
    "# feature) and `reshaped_activations_B` are read as globals by run_semantic_subspace_expms.\n",
    "sae_2, cfg_dict, sparsity = SAE.from_pretrained(\n",
    "    release = name,\n",
    "    sae_id = f\"layer_{layer_id_2}/width_16k/canonical\",\n",
    ")\n",
    "weight_matrix_2 = sae_2.W_dec.cpu().detach().numpy()\n",
    "\n",
    "sae_2 = sae_2.to('cuda')\n",
    "sae_2.eval()  # the SAE is only used for inference here\n",
    "with torch.no_grad():\n",
    "    # reshaped_activations_B = sae_2.encode(outputs_2.hidden_states[layer_id_2])\n",
    "    feature_acts_B = sae_2.encode(outputs_2)\n",
    "# Flatten the first two axes: one row per token position.\n",
    "first_dim_reshaped = feature_acts_B.shape[0] * feature_acts_B.shape[1]\n",
    "reshaped_activations_B = feature_acts_B.reshape(first_dim_reshaped, feature_acts_B.shape[-1]).cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7Lj_CXchb7Im"
   },
   "source": [
    "## get labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qcMUOcS7b7Im"
   },
   "outputs": [],
   "source": [
    "# For every SAE feature, store the sequences behind its top `samp_m` activations.\n",
    "fList_model_B_seqs = []\n",
    "samp_m = 5\n",
    "\n",
    "for feature_idx in range(feature_acts_B.shape[-1]):\n",
    "    if not feature_idx % 5000:  # progress marker every 5000 features\n",
    "        print('Feature: ', feature_idx)\n",
    "    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(\n",
    "        feature_acts_B, feature_idx, samp_m, batch_tokens=inputs['input_ids']\n",
    "    )\n",
    "    fList_model_B_seqs.append(\n",
    "        store_top_seqs(ds_top_acts_indices, ds_top_acts_values, inputs['input_ids'])\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ySVvZeHQb7Im"
   },
   "outputs": [],
   "source": [
    "# Persist this layer's per-feature top sequences and download the file from Colab.\n",
    "with open(f'fList_L{layer_id_2}_{modeltype}.pkl', mode='wb') as f:\n",
    "    pickle.dump(fList_model_B_seqs, f)\n",
    "files.download(f'fList_L{layer_id_2}_{modeltype}.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "neEzYvJIb7In"
   },
   "outputs": [],
   "source": [
    "# Reload the pickled per-feature top sequences (only unpickle files you wrote yourself).\n",
    "with open(f'fList_L{layer_id_2}_{modeltype}.pkl', mode='rb') as f:\n",
    "    fList_model_B_seqs = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1lfa_Mo9b7In"
   },
   "source": [
    "## Run semantic-subspace experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "69M3Ow2wb7In"
   },
   "outputs": [],
   "source": [
    "layer_to_dictscores = run_semantic_subspace_expms(\n",
    "    fList_model_B_seqs,    # model_2 per-feature top sequences for this layer\n",
    "    model_A_layer_labels,  # presumably the reference model's labels from an earlier section\n",
    "    inputs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3g5yQzyQb7In"
   },
   "outputs": [],
   "source": [
    "# Display the result of run_semantic_subspace_expms for this layer.\n",
    "layer_to_dictscores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5t-0F0PVb7In"
   },
   "outputs": [],
   "source": [
    "# Save this layer's concept scores and download them from Colab.\n",
    "# The files are tagged 'gemma' (not the section's 'gemma2'). A separate name is used instead of\n",
    "# reassigning the global `modeltype`: overwriting it made re-running the fList save/load cells\n",
    "# above read/write `fList_L{layer}_gemma.pkl` instead of the 'gemma2' file.\n",
    "scores_modeltype = 'gemma'\n",
    "scores_path = f'concept_scores_L{layer_id_2}_{scores_modeltype}.pkl'\n",
    "with open(scores_path, 'wb') as f:\n",
    "    pickle.dump(layer_to_dictscores, f)\n",
    "files.download(scores_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nAeGX4-Kb34p"
   },
   "source": [
    "# Gemma2: L10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5nJb5zMXb34q"
   },
   "source": [
    "## Get activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WLmDSnuGb34q"
   },
   "outputs": [],
   "source": [
    "# Section configuration: Gemma Scope SAE release, model tag used in filenames, and layer index.\n",
    "name = \"gemma-scope-2b-pt-res-canonical\"\n",
    "modeltype = 'gemma2'\n",
    "layer_id_2 = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xIG5jolJb34q"
   },
   "outputs": [],
   "source": [
    "# Run model_2 over the shared `inputs` batch for layer `layer_id_2`; no autograd needed.\n",
    "# NOTE(review): presumably returns a tensor of that layer's activations (it is fed to sae_2.encode).\n",
    "with torch.inference_mode():\n",
    "    outputs_2 = get_llm_actvs_batch(\n",
    "        model_2,\n",
    "        inputs,\n",
    "        layer_id_2,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4sy3ECBkb34q"
   },
   "outputs": [],
   "source": [
    "# Load the 16k-width canonical SAE for this layer and encode model_2's activations with it.\n",
    "sae_2, cfg_dict, sparsity = SAE.from_pretrained(release=name, sae_id=f\"layer_{layer_id_2}/width_16k/canonical\")\n",
    "\n",
    "# NumPy copy of the decoder weights, taken before the SAE is moved to the GPU.\n",
    "weight_matrix_2 = sae_2.W_dec.detach().cpu().numpy()\n",
    "\n",
    "sae_2 = sae_2.to('cuda')\n",
    "sae_2.eval()  # switch the SAE module to eval mode before encoding\n",
    "with torch.no_grad():\n",
    "    feature_acts_B = sae_2.encode(outputs_2)\n",
    "\n",
    "# Merge the two leading axes (presumably batch and sequence position) so each row is one\n",
    "# token's SAE feature vector, then move the result to the CPU.\n",
    "first_dim_reshaped = feature_acts_B.size(0) * feature_acts_B.size(1)\n",
    "reshaped_activations_B = feature_acts_B.reshape(first_dim_reshaped, feature_acts_B.size(-1)).cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wuK4KeKYb34q"
   },
   "source": [
    "## Get feature labels (top-activating sequences)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "z8z9vjgmb34r"
   },
   "outputs": [],
   "source": [
    "# For every SAE feature, store the sequences behind its top `samp_m` activations.\n",
    "fList_model_B_seqs = []\n",
    "samp_m = 5\n",
    "\n",
    "for feature_idx in range(feature_acts_B.shape[-1]):\n",
    "    if not feature_idx % 5000:  # progress marker every 5000 features\n",
    "        print('Feature: ', feature_idx)\n",
    "    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(\n",
    "        feature_acts_B, feature_idx, samp_m, batch_tokens=inputs['input_ids']\n",
    "    )\n",
    "    fList_model_B_seqs.append(\n",
    "        store_top_seqs(ds_top_acts_indices, ds_top_acts_values, inputs['input_ids'])\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "th41klmXb34r"
   },
   "outputs": [],
   "source": [
    "# Persist this layer's per-feature top sequences and download the file from Colab.\n",
    "with open(f'fList_L{layer_id_2}_{modeltype}.pkl', mode='wb') as f:\n",
    "    pickle.dump(fList_model_B_seqs, f)\n",
    "files.download(f'fList_L{layer_id_2}_{modeltype}.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-OifJZIib34r"
   },
   "outputs": [],
   "source": [
    "# Reload the pickled per-feature top sequences (only unpickle files you wrote yourself).\n",
    "with open(f'fList_L{layer_id_2}_{modeltype}.pkl', mode='rb') as f:\n",
    "    fList_model_B_seqs = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4cgU6_0Pb34r"
   },
   "source": [
    "## Run semantic-subspace experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gupqSGfTb34r"
   },
   "outputs": [],
   "source": [
    "layer_to_dictscores = run_semantic_subspace_expms(\n",
    "    fList_model_B_seqs,    # model_2 per-feature top sequences for this layer\n",
    "    model_A_layer_labels,  # presumably the reference model's labels from an earlier section\n",
    "    inputs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KIaS2-Cjb34r"
   },
   "outputs": [],
   "source": [
    "# Display the result of run_semantic_subspace_expms for this layer.\n",
    "layer_to_dictscores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Mdu4XvhEb34s"
   },
   "outputs": [],
   "source": [
    "# Save this layer's concept scores and download them from Colab.\n",
    "# The files are tagged 'gemma' (not the section's 'gemma2'). A separate name is used instead of\n",
    "# reassigning the global `modeltype`: overwriting it made re-running the fList save/load cells\n",
    "# above read/write `fList_L{layer}_gemma.pkl` instead of the 'gemma2' file.\n",
    "scores_modeltype = 'gemma'\n",
    "scores_path = f'concept_scores_L{layer_id_2}_{scores_modeltype}.pkl'\n",
    "with open(scores_path, 'wb') as f:\n",
    "    pickle.dump(layer_to_dictscores, f)\n",
    "files.download(scores_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AVhppXax4arY"
   },
   "source": [
    "# Gemma2: L14"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "n6LSeanCXe4o"
   },
   "source": [
    "## Get activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oVyoQhIXVUU5"
   },
   "outputs": [],
   "source": [
    "# Section configuration: Gemma Scope SAE release, model tag used in filenames, and layer index.\n",
    "name = \"gemma-scope-2b-pt-res-canonical\"\n",
    "modeltype = 'gemma2'\n",
    "layer_id_2 = 14"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "348Tm9_BVUU5"
   },
   "outputs": [],
   "source": [
    "# Run model_2 over the shared `inputs` batch for layer `layer_id_2`; no autograd needed.\n",
    "# NOTE(review): presumably returns a tensor of that layer's activations (it is fed to sae_2.encode).\n",
    "with torch.inference_mode():\n",
    "    outputs_2 = get_llm_actvs_batch(\n",
    "        model_2,\n",
    "        inputs,\n",
    "        layer_id_2,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wn4N6rmMVUU5"
   },
   "outputs": [],
   "source": [
    "# Load the 16k-width canonical SAE for this layer and encode model_2's activations with it.\n",
    "sae_2, cfg_dict, sparsity = SAE.from_pretrained(release=name, sae_id=f\"layer_{layer_id_2}/width_16k/canonical\")\n",
    "\n",
    "# NumPy copy of the decoder weights, taken before the SAE is moved to the GPU.\n",
    "weight_matrix_2 = sae_2.W_dec.detach().cpu().numpy()\n",
    "\n",
    "sae_2 = sae_2.to('cuda')\n",
    "sae_2.eval()  # switch the SAE module to eval mode before encoding\n",
    "with torch.no_grad():\n",
    "    feature_acts_B = sae_2.encode(outputs_2)\n",
    "\n",
    "# Merge the two leading axes (presumably batch and sequence position) so each row is one\n",
    "# token's SAE feature vector, then move the result to the CPU.\n",
    "first_dim_reshaped = feature_acts_B.size(0) * feature_acts_B.size(1)\n",
    "reshaped_activations_B = feature_acts_B.reshape(first_dim_reshaped, feature_acts_B.size(-1)).cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "s4R1ZQz6XgUd"
   },
   "source": [
    "## Get feature labels (top-activating sequences)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SmTrYYY7Ai0T"
   },
   "outputs": [],
   "source": [
    "# For every SAE feature, store the sequences behind its top `samp_m` activations.\n",
    "fList_model_B_seqs = []\n",
    "samp_m = 5\n",
    "\n",
    "for feature_idx in range(feature_acts_B.shape[-1]):\n",
    "    if not feature_idx % 5000:  # progress marker every 5000 features\n",
    "        print('Feature: ', feature_idx)\n",
    "    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(\n",
    "        feature_acts_B, feature_idx, samp_m, batch_tokens=inputs['input_ids']\n",
    "    )\n",
    "    fList_model_B_seqs.append(\n",
    "        store_top_seqs(ds_top_acts_indices, ds_top_acts_values, inputs['input_ids'])\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zDoPdBnyAi0T"
   },
   "outputs": [],
   "source": [
    "# Persist this layer's per-feature top sequences and download the file from Colab.\n",
    "with open(f'fList_L{layer_id_2}_{modeltype}.pkl', mode='wb') as f:\n",
    "    pickle.dump(fList_model_B_seqs, f)\n",
    "files.download(f'fList_L{layer_id_2}_{modeltype}.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rsxn9_goAi0T"
   },
   "outputs": [],
   "source": [
    "# Reload the pickled per-feature top sequences (only unpickle files you wrote yourself).\n",
    "with open(f'fList_L{layer_id_2}_{modeltype}.pkl', mode='rb') as f:\n",
    "    fList_model_B_seqs = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gB9wxKvaXdJW"
   },
   "source": [
    "## Run semantic-subspace experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "shGyCwQ-4arZ"
   },
   "outputs": [],
   "source": [
    "layer_to_dictscores = run_semantic_subspace_expms(\n",
    "    fList_model_B_seqs,    # model_2 per-feature top sequences for this layer\n",
    "    model_A_layer_labels,  # presumably the reference model's labels from an earlier section\n",
    "    inputs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "k01Wglwk4arZ"
   },
   "outputs": [],
   "source": [
    "# Display the result of run_semantic_subspace_expms for this layer.\n",
    "layer_to_dictscores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6wdmtVF14arZ"
   },
   "outputs": [],
   "source": [
    "# Save this layer's concept scores and download them from Colab.\n",
    "# The files are tagged 'gemma' (not the section's 'gemma2'). A separate name is used instead of\n",
    "# reassigning the global `modeltype`: overwriting it made re-running the fList save/load cells\n",
    "# above read/write `fList_L{layer}_gemma.pkl` instead of the 'gemma2' file.\n",
    "scores_modeltype = 'gemma'\n",
    "scores_path = f'concept_scores_L{layer_id_2}_{scores_modeltype}.pkl'\n",
    "with open(scores_path, 'wb') as f:\n",
    "    pickle.dump(layer_to_dictscores, f)\n",
    "files.download(scores_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mNzXJkC-b_Cw"
   },
   "source": [
    "# Gemma2: L17"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2QnMheHfb_Cw"
   },
   "source": [
    "## Get activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ybYOqNzCb_Cw"
   },
   "outputs": [],
   "source": [
    "# Section configuration: Gemma Scope SAE release, model tag used in filenames, and layer index.\n",
    "name = \"gemma-scope-2b-pt-res-canonical\"\n",
    "modeltype = 'gemma2'\n",
    "layer_id_2 = 17"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4sXj7N4Xb_Cx"
   },
   "outputs": [],
   "source": [
    "# Run model_2 over the shared `inputs` batch for layer `layer_id_2`; no autograd needed.\n",
    "# NOTE(review): presumably returns a tensor of that layer's activations (it is fed to sae_2.encode).\n",
    "with torch.inference_mode():\n",
    "    outputs_2 = get_llm_actvs_batch(\n",
    "        model_2,\n",
    "        inputs,\n",
    "        layer_id_2,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "a5MXWs3Ib_Cx"
   },
   "outputs": [],
   "source": [
    "# Load the 16k-width canonical SAE for this layer and encode model_2's activations with it.\n",
    "sae_2, cfg_dict, sparsity = SAE.from_pretrained(release=name, sae_id=f\"layer_{layer_id_2}/width_16k/canonical\")\n",
    "\n",
    "# NumPy copy of the decoder weights, taken before the SAE is moved to the GPU.\n",
    "weight_matrix_2 = sae_2.W_dec.detach().cpu().numpy()\n",
    "\n",
    "sae_2 = sae_2.to('cuda')\n",
    "sae_2.eval()  # switch the SAE module to eval mode before encoding\n",
    "with torch.no_grad():\n",
    "    feature_acts_B = sae_2.encode(outputs_2)\n",
    "\n",
    "# Merge the two leading axes (presumably batch and sequence position) so each row is one\n",
    "# token's SAE feature vector, then move the result to the CPU.\n",
    "first_dim_reshaped = feature_acts_B.size(0) * feature_acts_B.size(1)\n",
    "reshaped_activations_B = feature_acts_B.reshape(first_dim_reshaped, feature_acts_B.size(-1)).cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wAcbPhDVb_Cx"
   },
   "source": [
    "## Get feature labels (top-activating sequences)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dhugq6mrb_Cx"
   },
   "outputs": [],
   "source": [
    "# For every SAE feature, store the sequences behind its top `samp_m` activations.\n",
    "fList_model_B_seqs = []\n",
    "samp_m = 5\n",
    "\n",
    "for feature_idx in range(feature_acts_B.shape[-1]):\n",
    "    if not feature_idx % 5000:  # progress marker every 5000 features\n",
    "        print('Feature: ', feature_idx)\n",
    "    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(\n",
    "        feature_acts_B, feature_idx, samp_m, batch_tokens=inputs['input_ids']\n",
    "    )\n",
    "    fList_model_B_seqs.append(\n",
    "        store_top_seqs(ds_top_acts_indices, ds_top_acts_values, inputs['input_ids'])\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UZfe1HYHb_Cx"
   },
   "outputs": [],
   "source": [
    "# Persist this layer's per-feature top sequences and download the file from Colab.\n",
    "with open(f'fList_L{layer_id_2}_{modeltype}.pkl', mode='wb') as f:\n",
    "    pickle.dump(fList_model_B_seqs, f)\n",
    "files.download(f'fList_L{layer_id_2}_{modeltype}.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZcnmB0eRb_Cx"
   },
   "outputs": [],
   "source": [
    "# Reload the pickled per-feature top sequences (only unpickle files you wrote yourself).\n",
    "with open(f'fList_L{layer_id_2}_{modeltype}.pkl', mode='rb') as f:\n",
    "    fList_model_B_seqs = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Pk6gNRX7b_Cx"
   },
   "source": [
    "## Run semantic-subspace experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "J6ynxoQUb_Cx"
   },
   "outputs": [],
   "source": [
    "layer_to_dictscores = run_semantic_subspace_expms(\n",
    "    fList_model_B_seqs,    # model_2 per-feature top sequences for this layer\n",
    "    model_A_layer_labels,  # presumably the reference model's labels from an earlier section\n",
    "    inputs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ObVtLIuab_Cx"
   },
   "outputs": [],
   "source": [
    "# Display the result of run_semantic_subspace_expms for this layer.\n",
    "layer_to_dictscores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0Si8KSpgb_Cx"
   },
   "outputs": [],
   "source": [
    "# Save this layer's concept scores and download them from Colab.\n",
    "# The files are tagged 'gemma' (not the section's 'gemma2'). A separate name is used instead of\n",
    "# reassigning the global `modeltype`: overwriting it made re-running the fList save/load cells\n",
    "# above read/write `fList_L{layer}_gemma.pkl` instead of the 'gemma2' file.\n",
    "scores_modeltype = 'gemma'\n",
    "scores_path = f'concept_scores_L{layer_id_2}_{scores_modeltype}.pkl'\n",
    "with open(scores_path, 'wb') as f:\n",
    "    pickle.dump(layer_to_dictscores, f)\n",
    "files.download(scores_path)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyPOKA1cxCmS5xej2bQKexqY",
   "collapsed_sections": [
    "vlKdEehFvC86",
    "K2u-VPinLNoT",
    "n6LSeanCXe4o"
   ],
   "gpuType": "A100",
   "machine_shape": "hm",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
