{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a542467",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "device = \"cuda\"\n",
    "model_ckpt = \"meta-llama/Llama-3.2-1B\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b9c875c",
   "metadata": {},
   "source": [
    "### Preliminaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b08b50c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import random\n",
    "import collections\n",
    "\n",
    "import transformers\n",
    "import torch\n",
    "import tqdm.auto\n",
    "import plotly.express\n",
    "import plotly.graph_objects\n",
    "import sklearn.decomposition\n",
    "import matplotlib\n",
    "import PIL.Image\n",
    "import numpy as np\n",
    "from torch import Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da5bca2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sinusoidal_encode(\n",
    "    x: Tensor,\n",
    "    embedding_dim: int,\n",
    "    min_value: int,\n",
    "    max_value: int,\n",
    "    use_l2_norm: bool = False,\n",
    "    norm_const: float | None = None,\n",
    ") -> Tensor:\n",
    "    \"\"\"\n",
    "    Encodes a tensor of numbers into a sinusoidal representation, inspired by how absolute positional\n",
    "    encoding works in transformers.\n",
    "\n",
    "    The encoding is an evaluation of a sine and cosine function at different frequencies, where the\n",
    "    frequency is determined by the embedding dimension and the allowed range of the input values.\n",
    "\n",
    "    >>> sinusoidal_encode(\n",
    "    ...     torch.tensor([-5, 2, 1, 0]),\n",
    "    ...     embedding_dim=6,\n",
    "    ...     min_value=-5,\n",
    "    ...     max_value=5,\n",
    "    ... )\n",
    "    tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],\n",
    "            [ 0.6570,  0.7539, -0.1073, -0.9942,  0.9980,  0.0627],\n",
    "            [-0.2794,  0.9602,  0.3491, -0.9371,  0.9616,  0.2746],\n",
    "            [-0.9589,  0.2837,  0.7317, -0.6816,  0.8806,  0.4738]])\n",
    "    \"\"\"\n",
    "\n",
    "    if embedding_dim % 2 != 0 and not use_l2_norm:\n",
    "        raise ValueError(\"Embedding dimension must be even\")\n",
    "\n",
    "    if use_l2_norm:\n",
    "        if embedding_dim % 2 == 0:\n",
    "            reserved_dim = 2\n",
    "        else:\n",
    "            reserved_dim = 1\n",
    "        embedding_dim -= reserved_dim\n",
    "    else:\n",
    "        reserved_dim = 0  # will not be used\n",
    "\n",
    "    domain = max_value - min_value\n",
    "    y_shape = x.shape + (embedding_dim,)\n",
    "    y = torch.zeros(y_shape, device=x.device)\n",
    "    even_indices = torch.arange(0, embedding_dim, 2)\n",
    "    log_term = torch.log(torch.tensor(domain)) / embedding_dim\n",
    "    div_term = torch.exp(even_indices * -log_term)\n",
    "    x = x - min_value\n",
    "    values = x.unsqueeze(-1).float() * div_term\n",
    "    y[..., 0::2] = torch.sin(values)\n",
    "    y[..., 1::2] = torch.cos(values)\n",
    "\n",
    "    if use_l2_norm:\n",
    "        y = torch.cat([y, torch.ones_like(y[..., :reserved_dim])], dim=-1)\n",
    "        y /= y.norm(dim=-1, keepdim=True, p=2)\n",
    "\n",
    "    if norm_const is not None:\n",
    "        y *= norm_const\n",
    "\n",
    "    return y\n",
    "\n",
    "def binary_encode(\n",
    "    x: Tensor,\n",
    "    embedding_dim: int,\n",
    "    min_value: int | float,\n",
    "    max_value: int | float,\n",
    "    use_l2_norm: bool = False,\n",
    "    norm_const: float | None = None,\n",
    ") -> Tensor:\n",
    "    y = torch.zeros(x.shape + (embedding_dim,), device=x.device)\n",
    "    reserve_dim = 0 if not use_l2_norm else 1\n",
    "    x = x - min_value\n",
    "    maximum = x.max()\n",
    "    for i in range(embedding_dim - reserve_dim):\n",
    "        coeff = 2**i\n",
    "        if maximum < coeff:\n",
    "            break\n",
    "        y[..., -i - 1] = torch.floor(x / coeff) % 2\n",
    "        x = x - coeff * y[..., -i - 1]\n",
    "    if use_l2_norm:\n",
    "        y = torch.cat([y, torch.ones_like(y[..., :reserve_dim])], dim=-1)\n",
    "        y /= y.norm(dim=-1, keepdim=True, p=2)\n",
    "    if norm_const is not None:\n",
    "        y *= norm_const\n",
    "    return y"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ff97ee3",
   "metadata": {},
   "source": [
    "### Prepare model and data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8723d3d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = transformers.AutoModel.from_pretrained(model_ckpt).eval()\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(model_ckpt)\n",
    "model = model.half().to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b394470",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_values = torch.arange(0, 1000)\n",
    "mask = torch.rand(len(all_values), generator=torch.Generator().manual_seed(0))\n",
    "train_mask = mask < 0.9\n",
    "valid_mask = ~train_mask & (mask < 0.95)\n",
    "test_mask = ~train_mask & ~valid_mask\n",
    "\n",
    "train_values = all_values[train_mask]\n",
    "valid_values = all_values[valid_mask]\n",
    "test_values = all_values[test_mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c15ab591",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_inputs = [(x1, x2) for x1, x2 in itertools.product(all_values.tolist(), repeat=2) if x1 + x2 < 1000]\n",
    "train_values_set = set(train_values.tolist())\n",
    "valid_values_set = set(valid_values.tolist())\n",
    "test_values_set = set(test_values.tolist())\n",
    "        \n",
    "train_inputs = [(x1, x2) for x1, x2 in all_inputs if x2 in train_values_set]\n",
    "valid_inputs = [(x1, x2) for x1, x2 in all_inputs if x2 in valid_values_set]\n",
    "test_inputs = [(x1, x2) for x1, x2 in all_inputs if x2 in test_values_set]\n",
    "\n",
    "# sanity check\n",
    "assert set(train_inputs) & set(valid_inputs) == set()\n",
    "assert set(train_inputs) & set(test_inputs) == set()\n",
    "assert set(valid_inputs) & set(test_inputs) == set()\n",
    "\n",
    "rng_py = random.Random(0)\n",
    "rng_py.shuffle(train_inputs)\n",
    "rng_py.shuffle(valid_inputs)\n",
    "rng_py.shuffle(test_inputs)\n",
    "valid_size = 4096\n",
    "train_size = 100_000\n",
    "train_inputs = train_inputs[:train_size]\n",
    "valid_inputs = valid_inputs[:valid_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c41d1f88",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_str_input(operands: tuple[int, int] | list[int]) -> str:\n",
    "    x1, x2 = operands\n",
    "    return f\"{x1} + {x2}\"\n",
    "\n",
    "make_str_input((3, 500)), make_str_input((3, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36dd2f3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_hidden_states(model, str_inputs: list[str], batch_size: int) -> collections.defaultdict[int, Tensor]:\n",
    "    model.eval()\n",
    "    hidden_states = collections.defaultdict(list)\n",
    "    with torch.no_grad():\n",
    "        num_batches = (len(str_inputs) + batch_size - 1) // batch_size\n",
    "        for batch_str in tqdm.auto.tqdm(itertools.batched(str_inputs, n=batch_size), total=num_batches):\n",
    "            batch_inputs = tokenizer(batch_str, return_tensors=\"pt\")\n",
    "            hidden_reprs = model(**batch_inputs.to(model.device), output_hidden_states=True).hidden_states\n",
    "            for layer_idx, hidden_state in enumerate(hidden_reprs):\n",
    "                hidden_states[layer_idx].extend(hidden_state[:, -1, :].detach().cpu())\n",
    "    return {k: torch.stack(v) for k, v in hidden_states.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da87aca4",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 1024\n",
    "train_hidden_states = get_hidden_states(model, [make_str_input(val) for val in train_inputs], batch_size)\n",
    "valid_hidden_states = get_hidden_states(model, [make_str_input(val) for val in valid_inputs], batch_size)\n",
    "test_hidden_states = get_hidden_states(model, [make_str_input(val) for val in test_inputs], batch_size)\n",
    "\n",
    "train_labels = torch.tensor([x2 for x1, x2 in train_inputs])\n",
    "valid_labels = torch.tensor([x2 for x1, x2 in valid_inputs]).to(device)\n",
    "test_labels = torch.tensor([x2 for x1, x2 in test_inputs]).to(device) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6288a397",
   "metadata": {},
   "source": [
    "### Probe definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cff73d20",
   "metadata": {},
   "outputs": [],
   "source": [
    "basis_embs_sin = sinusoidal_encode(\n",
    "    torch.arange(1000),\n",
    "    min_value=0,\n",
    "    max_value=1000,\n",
    "    embedding_dim=train_hidden_states[0].shape[-1],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "184992d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ClassifierProbe(torch.nn.Module):\n",
    "    def __init__(self, emb_dim: int, hidden_dim: int, basis: torch.Tensor, heldout_mask: torch.Tensor):\n",
    "        super().__init__()\n",
    "        self.emb_to_latent = torch.nn.Linear(emb_dim, hidden_dim, bias=True)\n",
    "        self.basis_to_latent = torch.nn.Linear(basis.shape[-1], hidden_dim, bias=True)\n",
    "        self.basis: torch.nn.Buffer\n",
    "        self.heldout_mask: torch.nn.Buffer\n",
    "        self.register_buffer(\"basis\", basis)\n",
    "        self.register_buffer(\"heldout_mask\", heldout_mask)\n",
    "        \n",
    "    def forward(self, x: Tensor, holdout_eval_tokens: bool) -> Tensor:\n",
    "        latent_x = self.emb_to_latent(x)\n",
    "        # during training, model learns to choose among only training tokens\n",
    "        # but during eval, model must choose among all tokens\n",
    "        # this means that the model is never exposed to the eval tokens during training\n",
    "        latent_choices = self.basis_to_latent(self.basis)\n",
    "        logits = latent_x @ latent_choices.T\n",
    "        if holdout_eval_tokens:\n",
    "            logits[:, self.heldout_mask] = float(\"-inf\")\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e7b9947",
   "metadata": {},
   "source": [
    "### L1-regularized probes and evaluation of cross-layer transfer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d3d2c7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "probes_l1 = {}\n",
    "\n",
    "histories_l1 = []\n",
    "for layer_idx in range(0, len(train_hidden_states)):\n",
    "\n",
    "    torch.manual_seed(0)\n",
    "    probe = ClassifierProbe(\n",
    "        emb_dim=train_hidden_states[0].shape[-1],\n",
    "        hidden_dim=100,\n",
    "        basis=basis_embs_sin,\n",
    "        heldout_mask=test_mask,\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.Adam(probe.parameters(), lr=1e-4, weight_decay=0)\n",
    "\n",
    "    rng = torch.Generator().manual_seed(0)\n",
    "    best_val_acc = -1\n",
    "    best_ckpt = None\n",
    "    for step in range(10000+1):\n",
    "        probe.train()\n",
    "        optimizer.zero_grad()\n",
    "        minibatch_idcs = torch.randint(len(train_labels), size=(1024,), generator=rng)\n",
    "        x = train_hidden_states[layer_idx][minibatch_idcs].float().to(device)\n",
    "        y = train_labels[minibatch_idcs].to(device)\n",
    "        train_logits = probe(x, holdout_eval_tokens=True)\n",
    "        loss = torch.nn.functional.cross_entropy(train_logits, y)\n",
    "        loss += 1e-3 * sum(p.abs().sum() for p in probe.parameters()) # L1 regularization\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if step % 1000 == 0:\n",
    "            train_acc = (train_logits.argmax(dim=-1) == y).float().mean().item()\n",
    "            probe.eval()\n",
    "            with torch.no_grad():\n",
    "                valid_logits = probe(valid_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)\n",
    "                valid_loss = torch.nn.functional.cross_entropy(valid_logits, valid_labels)\n",
    "                valid_acc = (valid_logits.argmax(dim=-1) == valid_labels).float().mean().item()\n",
    "                if valid_acc > best_val_acc:\n",
    "                    best_val_acc = valid_acc\n",
    "                    best_ckpt = probe.state_dict()\n",
    "            entry = {\"layer\": layer_idx, \"step\": step, \"train_loss\": loss.item(), \"train_acc\": train_acc, \"valid_loss\": valid_loss.item(), \"valid_acc\": valid_acc}\n",
    "            histories_l1.append(entry)\n",
    "            print(f\"{layer_idx=:<3}  {step=:<7}  {loss=:<7.2f}  {train_acc=:<8.2%}  {valid_loss=:<7.2f}  {valid_acc=:<7.2%}\")\n",
    "    print()\n",
    "    probe.load_state_dict(best_ckpt)\n",
    "    probes_l1[layer_idx] = probe\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06cf9529",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_accuracies_l1 = torch.zeros((len(probes_l1), len(test_hidden_states))) - float(\"nan\")\n",
    "for probe_idx, probe in enumerate(probes_l1.values()):\n",
    "    probe.eval()\n",
    "    for layer_idx in range(len(test_hidden_states)):\n",
    "        with torch.no_grad():\n",
    "            test_logits = probe(test_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)\n",
    "            test_accuracy = (test_logits.argmax(dim=-1) == test_labels).float().mean().item()\n",
    "        test_accuracies_l1[probe_idx, layer_idx] = test_accuracy\n",
    "\n",
    "plotly.express.imshow(\n",
    "    test_accuracies_l1,\n",
    "    labels={\"y\": \"Layer fit idx\", \"x\": \"Layer eval idx\", \"color\": \"Test accuracy\"},\n",
    "    color_continuous_scale=\"Reds\"\n",
    ").update_layout(\n",
    "    yaxis=dict(tickvals=list(range(len(probes_l1))), ticktext=list(probes_l1.keys()))\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dad1281b",
   "metadata": {},
   "source": [
     "### L2-regularized probes and evaluation of cross-layer transfer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b79c0165",
   "metadata": {},
   "outputs": [],
   "source": [
    "probes_l2 = {}\n",
    "\n",
    "histories_l2 = []\n",
    "for layer_idx in range(0, len(train_hidden_states)):\n",
    "\n",
    "    torch.manual_seed(0)\n",
    "    probe = ClassifierProbe(\n",
    "        emb_dim=train_hidden_states[0].shape[-1],\n",
    "        hidden_dim=100,\n",
    "        basis=basis_embs_sin,\n",
    "        heldout_mask=test_mask,\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.Adam(probe.parameters(), lr=1e-4, weight_decay=1e-3)\n",
    "\n",
    "    rng = torch.Generator().manual_seed(0)\n",
    "    best_val_acc = -1\n",
    "    best_ckpt = None\n",
    "    for step in range(10000+1):\n",
    "        probe.train()\n",
    "        optimizer.zero_grad()\n",
    "        minibatch_idcs = torch.randint(len(train_labels), size=(1024,), generator=rng)\n",
    "        x = train_hidden_states[layer_idx][minibatch_idcs].float().to(device)\n",
    "        y = train_labels[minibatch_idcs].to(device)\n",
    "        train_logits = probe(x, holdout_eval_tokens=True)\n",
    "        loss = torch.nn.functional.cross_entropy(train_logits, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if step % 1000 == 0:\n",
    "            train_acc = (train_logits.argmax(dim=-1) == y).float().mean().item()\n",
    "            probe.eval()\n",
    "            with torch.no_grad():\n",
    "                valid_logits = probe(valid_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)\n",
    "                valid_loss = torch.nn.functional.cross_entropy(valid_logits, valid_labels)\n",
    "                valid_acc = (valid_logits.argmax(dim=-1) == valid_labels).float().mean().item()\n",
    "                if valid_acc > best_val_acc:\n",
    "                    best_val_acc = valid_acc\n",
    "                    best_ckpt = probe.state_dict()\n",
    "            entry = {\"layer\": layer_idx, \"step\": step, \"train_loss\": loss.item(), \"train_acc\": train_acc, \"valid_loss\": valid_loss.item(), \"valid_acc\": valid_acc}\n",
    "            histories_l2.append(entry)\n",
    "            print(f\"{layer_idx=:<3}  {step=:<7}  {loss=:<7.2f}  {train_acc=:<8.2%}  {valid_loss=:<7.2f}  {valid_acc=:<7.2%}\")\n",
    "    print()\n",
    "    probe.load_state_dict(best_ckpt)\n",
    "    probes_l2[layer_idx] = probe\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1536cda",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_accuracies_l2 = torch.zeros((len(probes_l2), len(test_hidden_states))) - float(\"nan\")\n",
    "for probe_idx, probe in enumerate(probes_l2.values()):\n",
    "    probe.eval()\n",
    "    for layer_idx in range(len(test_hidden_states)):\n",
    "        with torch.no_grad():\n",
    "            test_logits = probe(test_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)\n",
    "            test_accuracy = (test_logits.argmax(dim=-1) == test_labels).float().mean().item()\n",
    "        test_accuracies_l2[probe_idx, layer_idx] = test_accuracy\n",
    "\n",
    "plotly.express.imshow(\n",
    "    test_accuracies_l2,\n",
    "    labels={\"y\": \"Layer fit idx\", \"x\": \"Layer eval idx\", \"color\": \"Test accuracy\"},\n",
    "    color_continuous_scale=\"Reds\"\n",
    ").update_layout(\n",
    "    yaxis=dict(tickvals=list(range(len(probes_l2))), ticktext=list(probes_l2.keys()))\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8107644c",
   "metadata": {},
   "source": [
     "### Leave-one-layer-out cross-layer transfer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0577fd10",
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = torch.Generator().manual_seed(0)\n",
    "rng_py = random.Random(0)\n",
    "\n",
    "assert list(train_hidden_states.keys()) == list(range(len(train_hidden_states)))\n",
    "train_hidden_states_tensor = torch.stack(list(train_hidden_states.values()), dim=0)\n",
    "\n",
    "heldoneout_probes = {}\n",
    "heldoneout_histories = []\n",
    "\n",
    "for heldout_layer_idx in range(len(train_hidden_states)):\n",
    "\n",
    "    torch.manual_seed(0)\n",
    "    probe = ClassifierProbe(\n",
    "        emb_dim=train_hidden_states[0].shape[-1],\n",
    "        hidden_dim=100,\n",
    "        basis=basis_embs_sin,\n",
    "        heldout_mask=test_mask,\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.Adam(probe.parameters(), lr=1e-4, weight_decay=0)\n",
    "\n",
    "    train_layers = [i for i in range(len(train_hidden_states)) if i != heldout_layer_idx]\n",
    "\n",
    "    print(\"HELDOUT LAYER:\", heldout_layer_idx)\n",
    "    for step in range(10000+1):\n",
    "        probe.train()\n",
    "        optimizer.zero_grad()\n",
    "        layer_idcs = torch.tensor(rng_py.choices(train_layers, k=1024))\n",
    "        minibatch_idcs = torch.randint(len(train_labels), size=(1024,), generator=rng)\n",
    "        x = train_hidden_states_tensor[layer_idcs, minibatch_idcs].float().to(device)\n",
    "        y = train_labels[minibatch_idcs].to(device)\n",
    "        train_logits = probe(x, holdout_eval_tokens=True)\n",
    "        loss = torch.nn.functional.cross_entropy(train_logits, y)\n",
    "        loss += 1e-3 * sum(p.abs().sum() for p in probe.parameters()) # L1 regularization\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        best_val_acc = -1\n",
    "        best_ckpt = probe.state_dict()\n",
    "\n",
    "        if step % 1000 == 0:\n",
    "            probe.eval()\n",
    "            valid_accs = []\n",
    "            with torch.no_grad():\n",
    "                print(f\"{step=:<5}\", end=\"  \")\n",
    "                for layer_idx in range(0, len(train_hidden_states)):\n",
    "                    valid_logits = probe(valid_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)\n",
    "                    valid_acc = (valid_logits.argmax(dim=-1) == valid_labels).float().mean().item()\n",
    "                    valid_accs.append(valid_acc)\n",
    "                    heldoneout_histories.append({\"heldout_layer\": heldout_layer_idx, \"step\": step, \"eval_layer\": layer_idx, \"valid_acc\": valid_acc})\n",
    "                    acc_out = f\"{valid_acc:>6.1%}\"\n",
    "                    if layer_idx not in train_layers:\n",
    "                        print('\\033[94m' + acc_out + '\\033[0m', end=\" \")\n",
    "                    else:\n",
    "                        print(acc_out, end=\" \")\n",
    "                print()\n",
    "                if valid_accs[heldout_layer_idx] > best_val_acc:\n",
    "                    best_val_acc = valid_accs[heldout_layer_idx]\n",
    "                    best_ckpt = probe.state_dict()\n",
    "\n",
    "        probe.load_state_dict(best_ckpt)\n",
    "        heldoneout_probes[heldout_layer_idx] = probe\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68466a24",
   "metadata": {},
   "outputs": [],
   "source": [
    "heldoneout_test_accs = torch.zeros(len(heldoneout_probes)) - float(\"nan\")\n",
    "for heldout_layer_idx, probe in heldoneout_probes.items():\n",
    "    probe.eval()\n",
    "    with torch.no_grad():\n",
    "        test_logits = probe(test_hidden_states[heldout_layer_idx].float().to(device), holdout_eval_tokens=False)\n",
    "        test_accuracy = (test_logits.argmax(dim=-1) == test_labels).float().mean().item()\n",
    "        heldoneout_test_accs[heldout_layer_idx] = test_accuracy\n",
    "\n",
    "plotly.express.bar(\n",
    "    x=list(heldoneout_probes.keys()),\n",
    "    y=heldoneout_test_accs.numpy(),\n",
    "    labels={\"x\": \"Heldout layer idx\", \"y\": \"Test accuracy\"},\n",
    "    title=\"Heldout Layer Probe Test Accuracies\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7062ff2c",
   "metadata": {},
   "source": [
    "### Select layers for further visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "119c612c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# take first 3, last 3 and middle 3 from sorted(probes_l1.keys())\n",
    "layer_idcs_to_plot = sorted(probes_l1.keys())\n",
    "layer_idcs_to_plot = layer_idcs_to_plot[:3] + layer_idcs_to_plot[len(layer_idcs_to_plot)//2-2:len(layer_idcs_to_plot)//2+1] + layer_idcs_to_plot[-3:]\n",
    "layer_idcs_to_plot = sorted(set(layer_idcs_to_plot))\n",
    "layer_idcs_to_plot"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fec31bd6",
   "metadata": {},
   "source": [
     "### Fourier spectra of PCA-reduced activations across layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "662789d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pca(embs: Tensor, low_dim: int) -> tuple[Tensor, Tensor]:\n",
    "    pca = sklearn.decomposition.PCA(n_components=low_dim)\n",
    "    reduced_embs = pca.fit_transform(embs.detach().numpy())\n",
    "    return torch.tensor(reduced_embs), torch.tensor(pca.explained_variance_ratio_)\n",
    "\n",
    "def fourier(embs: Tensor) -> Tensor:\n",
    "    return torch.fft.fft(embs, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0e85e8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_values = torch.arange(0, 1000)\n",
    "inputs = tokenizer([str(x) for x in all_values.tolist()], return_tensors=\"pt\")\n",
    "all_representations = torch.stack(model(**inputs.to(device), output_hidden_states=True).hidden_states)\n",
    "all_representations = all_representations[:, :, -1, :] # get the last token (remove BOS)\n",
    "all_representations.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "758adf40",
   "metadata": {},
   "outputs": [],
   "source": [
    "layer_ffts = []\n",
    "for layer_idx in range(len(all_representations)):\n",
    "    representations = all_representations[layer_idx].cpu()\n",
    "    repr_pca, explained_var = pca(representations.float(), low_dim=100)\n",
    "    repr_fft = fourier(repr_pca).abs().T.sum(dim=0)\n",
    "    layer_ffts.append(repr_fft)\n",
    "\n",
    "layer_ffts = torch.stack(layer_ffts)\n",
    "fft_correlations = torch.corrcoef(layer_ffts)\n",
    "\n",
    "plotly.express.imshow(\n",
    "    # correlations (similarity) between layers' Fourier spectra\n",
    "    fft_correlations,\n",
    "    zmin=0,\n",
    "    zmax=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85d3085d",
   "metadata": {},
   "outputs": [],
   "source": [
    "for layer_idx in layer_idcs_to_plot:\n",
    "    representations = all_representations[layer_idx].cpu()\n",
    "    repr_pca, explained_var = pca(representations.float(), low_dim=100)\n",
    "    repr_fft = fourier(repr_pca)\n",
    "    \n",
    "    plotly.express.bar(\n",
    "        torch.abs(repr_fft).T.sum(dim=0).cpu().detach(),\n",
    "    ).update_layout(title=f\"Layer {layer_idx} PCA Fourier Frequencies\", showlegend=False).update_xaxes(title=\"Frequency\").show()\n",
    "\n",
    "    # plotly.express.imshow(\n",
    "    #     torch.abs(repr_fft).T.cpu().detach(),\n",
    "    #     color_continuous_scale=\"Reds\",\n",
    "    #     aspect='auto',\n",
    "    # ).update_layout(title=f\"Layer {layer_idx} PCA Fourier Frequency Heatmap\").update_xaxes(title=\"Frequency\").update_yaxes(title=\"Feature\").show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1f760ed",
   "metadata": {},
   "source": [
     "### Probe sparsity: how many hidden-state features each probe uses, and how scattered they are"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98822c28",
   "metadata": {},
   "outputs": [],
   "source": [
    "usages = []\n",
    "for constant in [1e-1, 1e-2, 1e-3, 1e-4]:\n",
    "    usages.append([])\n",
    "    for layer_idx in range(len(probes_l1)):\n",
    "        probe = probes_l1[layer_idx]\n",
    "        feature_usage = (probe.emb_to_latent.weight.abs() > constant).any(dim=0).float().mean().item()\n",
    "        usages[-1].append(feature_usage)\n",
    "usages = torch.tensor(usages)\n",
    "usages = usages.diff(dim=0, prepend=torch.zeros((1, usages.shape[1])))\n",
    "\n",
    "plotly.express.bar(\n",
    "    x=list(probes_l1.keys()),\n",
    "    y=usages.tolist().copy(),\n",
    "    # make a stacked bar chart with darkest color at the bottom, pick predefined colors\n",
    "    color_discrete_sequence=plotly.colors.sequential.Blues[5::-1],\n",
    ").update_layout(title=\"How many features are used by the probe\").update_xaxes(title=\"Layer idx\").show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de53d2c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "for layer_idx in layer_idcs_to_plot:\n",
    "    probe = probes_l1[layer_idx]\n",
    "    x: Tensor = probe.emb_to_latent.weight.cpu().detach().abs()\n",
    "    x_normalized = (x - x.min()) / (x.max() - x.min())\n",
    "    vis = PIL.Image.fromarray((x_normalized * 255).byte().numpy())\n",
    "    print(f\"Layer {layer_idx} Probe's Weights:\", flush=True)\n",
    "    display(vis)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "numllama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
