{
 "cells": [
  {
   "cell_type": "code",
   "id": "2955d0f927ac4732",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:43:06.403897Z",
     "start_time": "2025-09-21T19:43:06.399421Z"
    }
   },
   "source": [
    "device = \"cuda:0\""
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "id": "532e6d5a5bfb58d7",
   "metadata": {},
   "source": [
    "### Preliminaries"
   ]
  },
  {
   "cell_type": "code",
   "id": "dcda88145fb79c1c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:43:08.404383Z",
     "start_time": "2025-09-21T19:43:06.856269Z"
    }
   },
   "source": [
    "import itertools\n",
    "import random\n",
    "import collections\n",
    "\n",
    "\n",
    "import transformers\n",
    "import torch\n",
    "import tqdm.auto\n",
    "from torch import Tensor"
   ],
   "outputs": [],
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "id": "614f8a6b0c65d853",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:43:08.619985Z",
     "start_time": "2025-09-21T19:43:08.484931Z"
    }
   },
   "source": [
    "def sinusoidal_encode(\n",
    "    x: Tensor,\n",
    "    embedding_dim: int,\n",
    "    min_value: int,\n",
    "    max_value: int,\n",
    "    use_l2_norm: bool = False,\n",
    "    norm_const: float | None = None,\n",
    ") -> Tensor:\n",
    "    \"\"\"\n",
    "    Encodes a tensor of numbers into a sinusoidal representation, inspired by how absolute positional\n",
    "    encoding works in transformers.\n",
    "\n",
    "    The encoding is an evaluation of a sine and cosine function at different frequencies, where the\n",
    "    frequency is determined by the embedding dimension and the allowed range of the input values.\n",
    "\n",
    "    >>> sinusoidal_encode(\n",
    "    ...     torch.tensor([-5, 2, 1, 0]),\n",
    "    ...     embedding_dim=6,\n",
    "    ...     min_value=-5,\n",
    "    ...     max_value=5,\n",
    "    ... )\n",
    "    tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],\n",
    "            [ 0.6570,  0.7539, -0.1073, -0.9942,  0.9980,  0.0627],\n",
    "            [-0.2794,  0.9602,  0.3491, -0.9371,  0.9616,  0.2746],\n",
    "            [-0.9589,  0.2837,  0.7317, -0.6816,  0.8806,  0.4738]])\n",
    "    \"\"\"\n",
    "\n",
    "    if embedding_dim % 2 != 0 and not use_l2_norm:\n",
    "        raise ValueError(\"Embedding dimension must be even\")\n",
    "\n",
    "    if use_l2_norm:\n",
    "        if embedding_dim % 2 == 0:\n",
    "            reserved_dim = 2\n",
    "        else:\n",
    "            reserved_dim = 1\n",
    "        embedding_dim -= reserved_dim\n",
    "    else:\n",
    "        reserved_dim = 0  # will not be used\n",
    "\n",
    "    domain = max_value - min_value\n",
    "    y_shape = x.shape + (embedding_dim,)\n",
    "    y = torch.zeros(y_shape, device=x.device)\n",
    "    even_indices = torch.arange(0, embedding_dim, 2)\n",
    "    log_term = torch.log(torch.tensor(domain)) / embedding_dim\n",
    "    div_term = torch.exp(even_indices * -log_term)\n",
    "    x = x - min_value\n",
    "    values = x.unsqueeze(-1).float() * div_term\n",
    "    y[..., 0::2] = torch.sin(values)\n",
    "    y[..., 1::2] = torch.cos(values)\n",
    "\n",
    "    if use_l2_norm:\n",
    "        y = torch.cat([y, torch.ones_like(y[..., :reserved_dim])], dim=-1)\n",
    "        y /= y.norm(dim=-1, keepdim=True, p=2)\n",
    "\n",
    "    if norm_const is not None:\n",
    "        y *= norm_const\n",
    "\n",
    "    return y\n",
    "\n",
    "\n",
    "def binary_encode(\n",
    "    x: Tensor,\n",
    "    embedding_dim: int,\n",
    "    min_value: int | float,\n",
    "    max_value: int | float,\n",
    "    use_l2_norm: bool = False,\n",
    "    norm_const: float | None = None,\n",
    ") -> Tensor:\n",
    "    y = torch.zeros(x.shape + (embedding_dim,), device=x.device)\n",
    "    reserve_dim = 0 if not use_l2_norm else 1\n",
    "    x = x - min_value\n",
    "    maximum = x.max()\n",
    "    for i in range(embedding_dim - reserve_dim):\n",
    "        coeff = 2**i\n",
    "        if maximum < coeff:\n",
    "            break\n",
    "        y[..., -i - 1] = torch.floor(x / coeff) % 2\n",
    "        x = x - coeff * y[..., -i - 1]\n",
    "    if use_l2_norm:\n",
    "        y = torch.cat([y, torch.ones_like(y[..., :reserve_dim])], dim=-1)\n",
    "        y /= y.norm(dim=-1, keepdim=True, p=2)\n",
    "    if norm_const is not None:\n",
    "        y *= norm_const\n",
    "    return y"
   ],
   "outputs": [],
   "execution_count": 3
  },
  {
   "cell_type": "markdown",
   "id": "d0d14df112e1e26c",
   "metadata": {},
   "source": [
    "### Prepare model and data"
   ]
  },
  {
   "cell_type": "code",
   "id": "25e209717977de2b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:43:16.254954Z",
     "start_time": "2025-09-21T19:43:08.744443Z"
    }
   },
   "source": [
    "model_ckpt = \"meta-llama/Llama-3.2-3B\"\n",
    "model = transformers.AutoModel.from_pretrained(model_ckpt).eval()\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(model_ckpt)\n",
    "model = model.half().to(device).eval()"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ],
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "0bed2996545140b1b6efa3cdf6e21993"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "code",
   "id": "98e6df87a1182d14",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:43:16.376757Z",
     "start_time": "2025-09-21T19:43:16.372027Z"
    }
   },
   "source": [
    "all_values = torch.arange(0, 1000)\n",
    "mask = torch.rand(len(all_values), generator=torch.Generator().manual_seed(0))\n",
    "train_mask = mask < 0.9\n",
    "valid_mask = ~train_mask & (mask < 0.95)\n",
    "test_mask = ~train_mask & ~valid_mask\n",
    "\n",
    "train_values = all_values[train_mask]\n",
    "valid_values = all_values[valid_mask]\n",
    "test_values = all_values[test_mask]"
   ],
   "outputs": [],
   "execution_count": 5
  },
  {
   "cell_type": "code",
   "id": "f00e5e4134d983a7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:43:17.458856Z",
     "start_time": "2025-09-21T19:43:16.534783Z"
    }
   },
   "source": [
    "all_inputs = [(x1, x2) for x1, x2 in itertools.product(all_values.tolist(), repeat=2) if x2 and x1 / x2 < 1000]\n",
    "train_values_set = set(train_values.tolist())\n",
    "valid_values_set = set(valid_values.tolist())\n",
    "test_values_set = set(test_values.tolist())\n",
    "\n",
    "all_inputs_add = [(x1, x2) for x1, x2 in itertools.product(all_values.tolist(), repeat=2) if x1 + x2 < 1000]\n",
    "train_values_set = set(train_values.tolist())\n",
    "valid_values_set = set(valid_values.tolist())\n",
    "test_values_set = set(test_values.tolist())\n",
    "\n",
    "train_inputs = [(x1, x2) for x1, x2 in all_inputs if x1 / x2 in train_values_set]\n",
    "train_inputs_add = [(x1, x2) for x1, x2 in all_inputs_add if x1 + x2 in train_values_set]\n",
    "valid_inputs = [(x1, x2) for x1, x2 in all_inputs if x1 / x2 in valid_values_set]\n",
    "valid_inputs_add = [(x1, x2) for x1, x2 in all_inputs_add if x1 + x2 in valid_values_set]\n",
    "test_inputs = [(x1, x2) for x1, x2 in all_inputs if x1 / x2 in test_values_set]\n",
    "test_inputs_add = [(x1, x2) for x1, x2 in all_inputs_add if x1 + x2 in test_values_set]\n",
    "\n",
    "# sanity check\n",
    "assert set(train_inputs) & set(valid_inputs) == set()\n",
    "assert set(train_inputs) & set(test_inputs) == set()\n",
    "assert set(valid_inputs) & set(test_inputs) == set()\n",
    "\n",
    "assert set(train_inputs_add) & set(valid_inputs_add) == set()\n",
    "assert set(train_inputs_add) & set(test_inputs_add) == set()\n",
    "assert set(valid_inputs_add) & set(test_inputs_add) == set()\n",
    "\n",
    "random.seed(0)\n",
    "random.shuffle(train_inputs)\n",
    "random.shuffle(valid_inputs)\n",
    "random.shuffle(test_inputs)\n",
    "\n",
    "random.shuffle(train_inputs_add)\n",
    "random.shuffle(valid_inputs_add)\n",
    "random.shuffle(test_inputs_add)\n",
    "\n",
    "valid_size = 4096\n",
    "train_size = 50_000  # TODO: change back to 100_000\n",
    "train_inputs = train_inputs[:train_size]\n",
    "train_inputs_add = train_inputs_add[:train_size]\n",
    "valid_inputs = valid_inputs[:valid_size]"
   ],
   "outputs": [],
   "execution_count": 6
  },
  {
   "cell_type": "code",
   "id": "41e33bc2ed4f3ecd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:43:17.682341Z",
     "start_time": "2025-09-21T19:43:17.673192Z"
    }
   },
   "source": [
    "num_templates = 5  # TODO revert\n",
    "\n",
    "def make_str_input(operands: tuple[int, int] | list[int], template_idx: int = 1) -> str:\n",
    "    x1, x2 = operands\n",
    "    options = [\n",
    "        f\"{x1} divided by {x2} is \",\n",
    "        f\"{x1} divided by {x2} equals to \",\n",
    "        f\"{x1} / {x2} = \",\n",
    "        f\"A division of {x1} by {x2} equals to \",\n",
    "        f\"A result of dividing {x1} by {x2} is \",\n",
    "    ]\n",
    "    # assert num_templates == len(options)\n",
    "    # return f\"{x1} times {x2} is \"  # 0.78\n",
    "    # return f\"{x1} multiplied by {x2} is \"  # 90.38\n",
    "    return options[template_idx]\n",
    "\n",
    "make_str_input((3, 500)), make_str_input((3, 0))"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('3 divided by 500 equals to ', '3 divided by 0 equals to ')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "code",
   "id": "4d58550541c86d6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:43:17.879759Z",
     "start_time": "2025-09-21T19:43:17.870575Z"
    }
   },
   "source": [
    "def make_str_input_add(operands: tuple[int, int] | list[int]) -> str:\n",
    "    x1, x2 = operands\n",
    "    return f\"{x1} plus {x2} is equal to \"  # TODO: maybe switch back\n",
    "\n",
    "make_str_input_add((3, 500)), make_str_input_add((3, 0))"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('3 plus 500 is equal to ', '3 plus 0 is equal to ')"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 8
  },
  {
   "cell_type": "code",
   "id": "97d1f38e76bac3aa",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:43:18.097332Z",
     "start_time": "2025-09-21T19:43:18.084422Z"
    }
   },
   "source": [
    "def get_hidden_states_and_preds(model, str_inputs: list[str], batch_size: int) -> tuple[dict[int, Tensor], list[str]]:\n",
    "    model.eval()\n",
    "    hidden_states = collections.defaultdict(list)\n",
    "    model_preds = []\n",
    "    with torch.no_grad():\n",
    "        num_batches = (len(str_inputs) + batch_size - 1) // batch_size\n",
    "        for batch_str in tqdm.auto.tqdm(itertools.batched(str_inputs, n=batch_size), total=num_batches, desc=\"Inferring model hidden states\"):\n",
    "            batch_inputs = tokenizer(batch_str, return_tensors=\"pt\")\n",
    "            model_outputs = model(**batch_inputs.to(model.device), output_hidden_states=True)\n",
    "            hidden_reprs = model_outputs.hidden_states\n",
    "            logits = model_outputs.last_hidden_state @ model.embed_tokens.weight.T\n",
    "            next_token_ids = logits[:, -1, :].argmax(dim=-1)\n",
    "            model_preds.extend(tokenizer.batch_decode(next_token_ids))\n",
    "\n",
    "            for layer_idx, hidden_state in enumerate(hidden_reprs):\n",
    "                hidden_states[layer_idx].extend(hidden_state[:, -1, :].detach().cpu())\n",
    "    return {k: torch.stack(v) for k, v in hidden_states.items()}, model_preds"
   ],
   "outputs": [],
   "execution_count": 9
  },
  {
   "cell_type": "code",
   "id": "8f26e3c9f4c6e5bd",
   "metadata": {
    "jupyter": {
     "is_executing": true
    },
    "ExecuteTime": {
     "start_time": "2025-09-21T19:43:18.254707Z"
    }
   },
   "source": [
    "batch_size = 1024\n",
    "# train_hidden_states, train_preds = get_hidden_states_and_preds(\n",
    "#         model,\n",
    "#         [make_str_input(val, 0) for val in train_inputs] + [make_str_input(val, 1) for val in train_inputs] + [make_str_input(val, 2) for val in train_inputs],\n",
    "#         batch_size\n",
    "# )\n",
    "states_preds = [get_hidden_states_and_preds(model, [make_str_input(val, i) for val in train_inputs], batch_size) for i in range(num_templates)]\n",
    "\n",
    "hidden_states_all = [x[0] for x in states_preds]\n",
    "preds_all = [x[1] for x in states_preds]\n",
    "\n",
    "train_hidden_states = {k: torch.concat([hidden_states_all[i][k] for i in range(num_templates)]) for k in hidden_states_all[0].keys()}\n",
    "train_preds = list(itertools.chain(*preds_all))\n",
    "\n",
    "# train_hidden_states, train_preds = get_hidden_states_and_preds(\n",
    "#         model,\n",
    "#         [make_str_input(val) for val in train_inputs],\n",
    "#         batch_size\n",
    "# )\n",
    "# train_hidden_states, train_preds = get_hidden_states_and_preds(\n",
    "#         model,\n",
    "#         [make_str_input_add(val) for val in train_inputs_add],\n",
    "#         batch_size\n",
    "# )\n",
    "# valid_hidden_states, valid_preds = get_hidden_states_and_preds(\n",
    "#         model,\n",
    "#         [make_str_input(val) for val in valid_inputs],\n",
    "#         batch_size\n",
    "# )\n",
    "# test_hidden_states, test_preds = get_hidden_states_and_preds(\n",
    "#         model,\n",
    "#         [make_str_input(val) for val in test_inputs],\n",
    "#         batch_size\n",
    "# )"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Inferring model hidden states:   0%|          | 0/8 [00:00<?, ?it/s]"
      ],
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "e8fc6bcfad8e43e7b9e64e7344501ffb"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "5264b095f386b894",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:39:56.228177Z",
     "start_time": "2025-09-21T19:39:55.942182Z"
    }
   },
   "source": [
    "valid_hidden_states, valid_preds = get_hidden_states_and_preds(\n",
    "        model,\n",
    "        [make_str_input(val) for val in valid_inputs],\n",
    "        batch_size\n",
    ")\n",
    "test_hidden_states, test_preds = get_hidden_states_and_preds(\n",
    "        model,\n",
    "        [make_str_input(val) for val in test_inputs],\n",
    "        batch_size\n",
    ")"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Inferring model hidden states:   0%|          | 0/1 [00:00<?, ?it/s]"
      ],
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "b5742cae3bd84604bd7080ae34f23350"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Inferring model hidden states:   0%|          | 0/1 [00:00<?, ?it/s]"
      ],
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "eb2c2de928254953905e09a17480d396"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "execution_count": 42
  },
  {
   "cell_type": "code",
   "id": "a987d1e1dd401fe2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:39:56.320338Z",
     "start_time": "2025-09-21T19:39:56.307681Z"
    }
   },
   "source": [
    "train_inputs_t = torch.tensor(train_inputs)\n",
    "\n",
    "train_inputs_t[:, 0] / train_inputs_t[:, 1]"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 1.,  8.,  1.,  ...,  2., 22.,  3.])"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 43
  },
  {
   "cell_type": "code",
   "id": "598f9a35335acb13",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:39:56.516299Z",
     "start_time": "2025-09-21T19:39:56.426294Z"
    }
   },
   "source": [
    "def sanitize_pred(pred: str) -> int:\n",
    "    try:\n",
    "        return int(pred)\n",
    "    except ValueError:\n",
    "        return -1"
   ],
   "outputs": [],
   "execution_count": 44
  },
  {
   "cell_type": "code",
   "id": "fbb0c06e4e6a5d91",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:39:56.766923Z",
     "start_time": "2025-09-21T19:39:56.651554Z"
    }
   },
   "source": [
    "test_inputs_t = torch.tensor(test_inputs)\n",
    "\n",
    "train_preds_t = torch.tensor([sanitize_pred(pred) for pred in train_preds])\n",
    "valid_preds_t = torch.tensor([sanitize_pred(pred) for pred in valid_preds])\n",
    "test_preds_t = torch.tensor([sanitize_pred(pred) for pred in test_preds])\n",
    "\n",
    "# ratio of properly extracted train predictions\n",
    "sum(train_preds_t != -1) / len(train_preds_t)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.)"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 45
  },
  {
   "cell_type": "code",
   "id": "40311667-9942-46fc-8e54-8b886b4a2bdf",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:39:56.910775Z",
     "start_time": "2025-09-21T19:39:56.896673Z"
    }
   },
   "source": [
    "# absolute model accuracy on test set\n",
    "test_labels_ref = torch.tensor([x1 / x2 for x1, x2 in test_inputs])\n",
    "sum(test_preds_t == test_labels_ref) / len(test_inputs)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.8941)"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 46
  },
  {
   "cell_type": "markdown",
   "id": "e246be877b3eee1c",
   "metadata": {},
   "source": [
    "### Probing"
   ]
  },
  {
   "cell_type": "code",
   "id": "95b127e4fa0b35a1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:39:57.091012Z",
     "start_time": "2025-09-21T19:39:57.074743Z"
    }
   },
   "source": [
    "basis_embs_sin = sinusoidal_encode(\n",
    "    torch.arange(1000),\n",
    "    min_value=0,\n",
    "    max_value=1000,\n",
    "    embedding_dim=train_hidden_states[0].shape[-1],\n",
    ")\n",
    "\n",
    "\n",
    "basis_embs_bin = binary_encode(\n",
    "    torch.arange(1000),\n",
    "    min_value=0,\n",
    "    max_value=1000,\n",
    "    embedding_dim=10,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": 47
  },
  {
   "cell_type": "code",
   "id": "364edd5b36dfd13b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:39:57.252486Z",
     "start_time": "2025-09-21T19:39:57.228798Z"
    }
   },
   "source": [
    "class ClassifierProbe(torch.nn.Module):\n",
    "    def __init__(self, emb_dim: int, hidden_dim: int, basis: torch.Tensor, heldout_mask: torch.Tensor):\n",
    "        super().__init__()\n",
    "        self.emb_to_latent = torch.nn.Linear(emb_dim, hidden_dim, bias=True)\n",
    "        self.basis_to_latent = torch.nn.Linear(basis.shape[-1], hidden_dim, bias=True)\n",
    "        self.basis: torch.nn.Buffer\n",
    "        self.heldout_mask: torch.nn.Buffer\n",
    "        self.register_buffer(\"basis\", basis)\n",
    "        self.register_buffer(\"heldout_mask\", heldout_mask)\n",
    "    def forward(self, x: Tensor, holdout_eval_tokens: bool) -> Tensor:\n",
    "        latent_x = self.emb_to_latent(x)\n",
    "        # during training, model learns to choose among only training tokens\n",
    "        # but during eval, model must choose among all tokens\n",
    "        # this means that the model is never exposed to the eval tokens during training\n",
    "        latent_choices = self.basis_to_latent(self.basis)\n",
    "        logits = latent_x @ latent_choices.T\n",
    "        if holdout_eval_tokens:\n",
    "            logits[:, self.heldout_mask] = float(\"-inf\")\n",
    "        return logits"
   ],
   "outputs": [],
   "execution_count": 48
  },
  {
   "cell_type": "code",
   "id": "708ffb5bef3dbd49",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:39:57.440473Z",
     "start_time": "2025-09-21T19:39:57.393783Z"
    }
   },
   "source": [
    "# train_labels_ref = torch.tensor([x1 * x2 for x1, x2 in train_inputs])\n",
    "# valid_labels_ref = torch.tensor([x1 * x2 for x1, x2 in valid_inputs]).to(device)\n",
    "# test_labels_ref = torch.tensor([x1 * x2 for x1, x2 in test_inputs]).to(device)\n"
   ],
   "outputs": [],
   "execution_count": 49
  },
  {
   "cell_type": "code",
   "id": "4451d0a344301b40",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:39:57.669358Z",
     "start_time": "2025-09-21T19:39:57.659716Z"
    }
   },
   "source": [
    "train_hidden_states[0].shape"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([7636, 3072])"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 50
  },
  {
   "cell_type": "code",
   "id": "c810d88ec56cede3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:39:58.050551Z",
     "start_time": "2025-09-21T19:39:57.829573Z"
    }
   },
   "source": [
    "train_labels_ref = torch.tensor([x1 / x2 for x1, x2 in train_inputs]*num_templates)\n",
    "# train_labels_ref = torch.tensor([x1 + x2 for x1, x2 in train_inputs])\n",
    "\n",
    "valid_labels_ref = torch.tensor([x1 / x2 for x1, x2 in valid_inputs]).to(device)\n",
    "test_labels_ref = torch.tensor([x1 / x2 for x1, x2 in test_inputs]).to(device)\n",
    "\n",
    "train_labels = train_preds_t.detach().clone()\n",
    "train_hidden_states = {k: v[train_labels != -1] for k, v in train_hidden_states.items()}\n",
    "train_labels = train_labels[train_labels != -1]\n",
    "\n",
    "valid_labels = valid_preds_t.detach().clone()\n",
    "valid_hidden_states = {k: v[valid_labels != -1] for k, v in valid_hidden_states.items()}\n",
    "valid_labels = valid_labels[valid_labels != -1].to(device)\n",
    "\n",
    "test_labels = test_preds_t.detach().clone()\n",
    "test_hidden_states = {k: v[test_labels != -1] for k, v in test_hidden_states.items()}\n",
    "test_labels = test_labels[test_labels != -1].to(device)"
   ],
   "outputs": [],
   "execution_count": 51
  },
  {
   "cell_type": "code",
   "id": "6e0936f779eacc40",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:42:04.734460Z",
     "start_time": "2025-09-21T19:39:58.181647Z"
    }
   },
   "source": [
    "test_extracted = {}\n",
    "\n",
    "test_accuracies = {\"sin\": {}, \"bin\": {}, \"lin\": {}, \"log\": {}}\n",
    "\n",
    "basis_name = \"sin\"\n",
    "basis_embs = basis_embs_sin\n",
    "\n",
    "for layer_idx in reversed(range(len(train_hidden_states))):\n",
    "\n",
    "    torch.manual_seed(0)\n",
    "    probe = ClassifierProbe(\n",
    "        emb_dim=train_hidden_states[0].shape[-1],\n",
    "        hidden_dim=100,\n",
    "        basis=basis_embs,\n",
    "        heldout_mask=test_mask,\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)  # TODO: try with weight_decay=1e-3\n",
    "    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=30000)\n",
    "\n",
    "    rng = torch.Generator().manual_seed(0)\n",
    "    best_val_acc = -1\n",
    "    best_ckpt = None\n",
    "    for i in range(50000+1):\n",
    "        probe.train()\n",
    "        optimizer.zero_grad()\n",
    "        minibatch_idcs = torch.randint(len(train_labels), size=(128,), generator=rng)\n",
    "        x = train_hidden_states[layer_idx][minibatch_idcs].float().to(device)\n",
    "        y = train_labels[minibatch_idcs].to(device)\n",
    "        logits = probe(x, holdout_eval_tokens=False)\n",
    "        # add l1 regularization of all params to the loss\n",
    "        loss = torch.nn.functional.cross_entropy(logits, y)\n",
    "        loss += 0.01 * sum(p.abs().sum() for p in probe.parameters())  # L1-reg\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "        if i % 500 == 0:\n",
    "            train_acc = (logits.argmax(dim=-1) == y).float().mean().item()\n",
    "            probe.eval()\n",
    "            with torch.no_grad():\n",
    "                valid_logits = probe(valid_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)  # TODO: holdout_eval_tokens switched to False -- incompatible with using model's own predictions as labels!\n",
    "                valid_loss = torch.nn.functional.cross_entropy(valid_logits, valid_labels)\n",
    "                valid_accuracy = (valid_logits.argmax(dim=-1) == valid_labels).float().mean().item()\n",
    "                if valid_accuracy > best_val_acc:\n",
    "                    best_val_acc = valid_accuracy\n",
    "                    best_ckpt = probe.state_dict()\n",
    "            print(f\"{basis_name} {i=:>5} train loss: {loss.item():5.2f}  train acc: {train_acc:.2f}  val loss: {valid_loss.item():5.2f}  valid acc: {valid_accuracy:.2f}\")\n",
    "    probe.load_state_dict(best_ckpt)\n",
    "    probe.eval()\n",
    "    with torch.no_grad():\n",
    "        test_logits = probe(test_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)\n",
    "        test_extracted[layer_idx] = test_logits.argmax(dim=-1)\n",
    "        test_accuracy = (test_extracted[layer_idx] == test_labels).float().mean().item()\n",
    "\n",
    "    test_accuracies[basis_name][layer_idx] = test_accuracy\n",
    "    print(f\"-> {basis_name} layer idx: {layer_idx:<3}, best valid accuracy: {best_val_acc:.2f}, test accuracy: {test_accuracy:.2f}\")\n",
    "    # best test_accuracy so far=0.64"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sin i=    0 train loss: 68.16  train acc: 0.00  val loss: 44.59  valid acc: 0.00\n",
      "sin i=  500 train loss:  6.24  train acc: 0.76  val loss:  3.63  valid acc: 0.19\n",
      "sin i= 1000 train loss:  2.62  train acc: 0.66  val loss:  2.93  valid acc: 0.18\n",
      "sin i= 1500 train loss:  2.15  train acc: 0.78  val loss:  2.67  valid acc: 0.25\n",
      "sin i= 2000 train loss:  2.08  train acc: 0.76  val loss:  2.44  valid acc: 0.30\n",
      "sin i= 2500 train loss:  2.17  train acc: 0.77  val loss:  2.51  valid acc: 0.31\n",
      "sin i= 3000 train loss:  2.04  train acc: 0.87  val loss:  2.29  valid acc: 0.29\n",
      "sin i= 3500 train loss:  2.14  train acc: 0.86  val loss:  2.14  valid acc: 0.34\n",
      "sin i= 4000 train loss:  2.13  train acc: 0.88  val loss:  1.89  valid acc: 0.43\n",
      "sin i= 4500 train loss:  2.06  train acc: 0.88  val loss:  2.28  valid acc: 0.35\n",
      "sin i= 5000 train loss:  2.08  train acc: 0.86  val loss:  2.45  valid acc: 0.26\n",
      "sin i= 5500 train loss:  2.19  train acc: 0.81  val loss:  1.81  valid acc: 0.43\n",
      "sin i= 6000 train loss:  2.14  train acc: 0.89  val loss:  2.13  valid acc: 0.40\n",
      "sin i= 6500 train loss:  2.32  train acc: 0.77  val loss:  2.19  valid acc: 0.30\n",
      "sin i= 7000 train loss:  2.16  train acc: 0.83  val loss:  1.85  valid acc: 0.43\n",
      "sin i= 7500 train loss:  2.19  train acc: 0.86  val loss:  2.12  valid acc: 0.35\n",
      "sin i= 8000 train loss:  2.25  train acc: 0.83  val loss:  1.83  valid acc: 0.39\n",
      "sin i= 8500 train loss:  2.17  train acc: 0.80  val loss:  1.84  valid acc: 0.43\n",
      "sin i= 9000 train loss:  2.15  train acc: 0.84  val loss:  2.03  valid acc: 0.40\n",
      "sin i= 9500 train loss:  2.06  train acc: 0.92  val loss:  1.95  valid acc: 0.40\n",
      "sin i=10000 train loss:  2.12  train acc: 0.80  val loss:  1.87  valid acc: 0.47\n",
      "sin i=10500 train loss:  2.02  train acc: 0.83  val loss:  1.90  valid acc: 0.43\n",
      "sin i=11000 train loss:  2.02  train acc: 0.88  val loss:  1.86  valid acc: 0.46\n",
      "sin i=11500 train loss:  2.10  train acc: 0.83  val loss:  2.28  valid acc: 0.39\n",
      "sin i=12000 train loss:  2.09  train acc: 0.88  val loss:  1.96  valid acc: 0.42\n",
      "sin i=12500 train loss:  1.93  train acc: 0.88  val loss:  1.90  valid acc: 0.42\n",
      "sin i=13000 train loss:  1.94  train acc: 0.88  val loss:  2.11  valid acc: 0.40\n",
      "sin i=13500 train loss:  1.87  train acc: 0.88  val loss:  2.33  valid acc: 0.29\n",
      "sin i=14000 train loss:  1.98  train acc: 0.81  val loss:  2.33  valid acc: 0.40\n",
      "sin i=14500 train loss:  1.96  train acc: 0.83  val loss:  2.16  valid acc: 0.37\n",
      "sin i=15000 train loss:  1.75  train acc: 0.92  val loss:  2.20  valid acc: 0.41\n",
      "sin i=15500 train loss:  1.88  train acc: 0.80  val loss:  2.14  valid acc: 0.42\n",
      "sin i=16000 train loss:  1.81  train acc: 0.90  val loss:  2.36  valid acc: 0.40\n",
      "sin i=16500 train loss:  1.78  train acc: 0.88  val loss:  1.93  valid acc: 0.45\n",
      "sin i=17000 train loss:  1.81  train acc: 0.88  val loss:  2.06  valid acc: 0.40\n",
      "sin i=17500 train loss:  1.80  train acc: 0.88  val loss:  2.10  valid acc: 0.43\n",
      "sin i=18000 train loss:  1.62  train acc: 0.93  val loss:  1.86  valid acc: 0.48\n",
      "sin i=18500 train loss:  1.66  train acc: 0.92  val loss:  2.05  valid acc: 0.39\n",
      "sin i=19000 train loss:  1.72  train acc: 0.91  val loss:  2.03  valid acc: 0.45\n",
      "sin i=19500 train loss:  1.66  train acc: 0.90  val loss:  1.87  valid acc: 0.45\n",
      "sin i=20000 train loss:  1.66  train acc: 0.84  val loss:  1.88  valid acc: 0.48\n",
      "sin i=20500 train loss:  1.65  train acc: 0.86  val loss:  1.79  valid acc: 0.47\n",
      "sin i=21000 train loss:  1.62  train acc: 0.88  val loss:  2.14  valid acc: 0.43\n",
      "sin i=21500 train loss:  1.57  train acc: 0.92  val loss:  1.87  valid acc: 0.43\n",
      "sin i=22000 train loss:  1.66  train acc: 0.86  val loss:  2.18  valid acc: 0.38\n",
      "sin i=22500 train loss:  1.68  train acc: 0.85  val loss:  2.12  valid acc: 0.41\n",
      "sin i=23000 train loss:  1.51  train acc: 0.89  val loss:  1.98  valid acc: 0.39\n",
      "sin i=23500 train loss:  1.60  train acc: 0.93  val loss:  1.90  valid acc: 0.41\n",
      "sin i=24000 train loss:  1.57  train acc: 0.88  val loss:  2.06  valid acc: 0.39\n",
      "sin i=24500 train loss:  1.48  train acc: 0.94  val loss:  1.96  valid acc: 0.45\n",
      "sin i=25000 train loss:  1.45  train acc: 0.93  val loss:  1.98  valid acc: 0.47\n",
      "sin i=25500 train loss:  1.47  train acc: 0.95  val loss:  2.03  valid acc: 0.45\n",
      "sin i=26000 train loss:  1.47  train acc: 0.89  val loss:  1.88  valid acc: 0.46\n",
      "sin i=26500 train loss:  1.46  train acc: 0.91  val loss:  2.00  valid acc: 0.42\n",
      "sin i=27000 train loss:  1.42  train acc: 0.91  val loss:  1.98  valid acc: 0.41\n",
      "sin i=27500 train loss:  1.44  train acc: 0.90  val loss:  1.93  valid acc: 0.48\n",
      "sin i=28000 train loss:  1.33  train acc: 0.96  val loss:  1.92  valid acc: 0.44\n",
      "sin i=28500 train loss:  1.44  train acc: 0.92  val loss:  2.02  valid acc: 0.41\n",
      "sin i=29000 train loss:  1.41  train acc: 0.88  val loss:  1.93  valid acc: 0.46\n",
      "sin i=29500 train loss:  1.38  train acc: 0.94  val loss:  1.94  valid acc: 0.43\n",
      "sin i=30000 train loss:  1.41  train acc: 0.88  val loss:  1.95  valid acc: 0.44\n",
      "sin i=30500 train loss:  1.31  train acc: 0.92  val loss:  1.99  valid acc: 0.45\n",
      "sin i=31000 train loss:  1.30  train acc: 0.93  val loss:  1.95  valid acc: 0.44\n",
      "sin i=31500 train loss:  1.36  train acc: 0.96  val loss:  1.93  valid acc: 0.45\n",
      "sin i=32000 train loss:  1.30  train acc: 0.91  val loss:  1.90  valid acc: 0.44\n",
      "sin i=32500 train loss:  1.37  train acc: 0.90  val loss:  1.88  valid acc: 0.45\n",
      "sin i=33000 train loss:  1.32  train acc: 0.91  val loss:  1.97  valid acc: 0.45\n",
      "sin i=33500 train loss:  1.33  train acc: 0.92  val loss:  1.95  valid acc: 0.43\n",
      "sin i=34000 train loss:  1.29  train acc: 0.96  val loss:  1.92  valid acc: 0.46\n",
      "sin i=34500 train loss:  1.32  train acc: 0.93  val loss:  1.93  valid acc: 0.46\n",
      "sin i=35000 train loss:  1.36  train acc: 0.91  val loss:  1.95  valid acc: 0.45\n",
      "sin i=35500 train loss:  1.36  train acc: 0.91  val loss:  1.91  valid acc: 0.44\n",
      "sin i=36000 train loss:  1.35  train acc: 0.89  val loss:  1.91  valid acc: 0.45\n",
      "sin i=36500 train loss:  1.29  train acc: 0.94  val loss:  1.95  valid acc: 0.43\n",
      "sin i=37000 train loss:  1.35  train acc: 0.91  val loss:  1.93  valid acc: 0.44\n",
      "sin i=37500 train loss:  1.33  train acc: 0.91  val loss:  1.90  valid acc: 0.45\n",
      "sin i=38000 train loss:  1.36  train acc: 0.95  val loss:  1.91  valid acc: 0.43\n",
      "sin i=38500 train loss:  1.39  train acc: 0.91  val loss:  1.88  valid acc: 0.43\n",
      "sin i=39000 train loss:  1.37  train acc: 0.94  val loss:  1.91  valid acc: 0.45\n",
      "sin i=39500 train loss:  1.34  train acc: 0.94  val loss:  1.87  valid acc: 0.45\n",
      "sin i=40000 train loss:  1.36  train acc: 0.91  val loss:  1.96  valid acc: 0.45\n",
      "sin i=40500 train loss:  1.35  train acc: 0.91  val loss:  1.94  valid acc: 0.44\n",
      "sin i=41000 train loss:  1.26  train acc: 0.94  val loss:  1.93  valid acc: 0.46\n",
      "sin i=41500 train loss:  1.30  train acc: 0.92  val loss:  1.88  valid acc: 0.44\n",
      "sin i=42000 train loss:  1.33  train acc: 0.91  val loss:  1.92  valid acc: 0.44\n",
      "sin i=42500 train loss:  1.39  train acc: 0.90  val loss:  1.92  valid acc: 0.46\n",
      "sin i=43000 train loss:  1.33  train acc: 0.92  val loss:  1.93  valid acc: 0.45\n",
      "sin i=43500 train loss:  1.29  train acc: 0.94  val loss:  1.91  valid acc: 0.44\n",
      "sin i=44000 train loss:  1.30  train acc: 0.92  val loss:  1.90  valid acc: 0.44\n",
      "sin i=44500 train loss:  1.33  train acc: 0.95  val loss:  1.90  valid acc: 0.45\n",
      "sin i=45000 train loss:  1.23  train acc: 0.95  val loss:  1.91  valid acc: 0.44\n",
      "sin i=45500 train loss:  1.43  train acc: 0.86  val loss:  1.94  valid acc: 0.45\n",
      "sin i=46000 train loss:  1.28  train acc: 0.94  val loss:  1.89  valid acc: 0.46\n",
      "sin i=46500 train loss:  1.28  train acc: 0.95  val loss:  1.90  valid acc: 0.45\n",
      "sin i=47000 train loss:  1.31  train acc: 0.92  val loss:  1.93  valid acc: 0.45\n",
      "sin i=47500 train loss:  1.39  train acc: 0.88  val loss:  1.93  valid acc: 0.44\n",
      "sin i=48000 train loss:  1.37  train acc: 0.94  val loss:  1.95  valid acc: 0.43\n",
      "sin i=48500 train loss:  1.31  train acc: 0.92  val loss:  1.90  valid acc: 0.45\n",
      "sin i=49000 train loss:  1.27  train acc: 0.94  val loss:  1.91  valid acc: 0.44\n",
      "sin i=49500 train loss:  1.27  train acc: 0.96  val loss:  1.92  valid acc: 0.43\n",
      "sin i=50000 train loss:  1.35  train acc: 0.91  val loss:  1.90  valid acc: 0.45\n",
      "-> sin layer idx: 28 , best valid accuracy: 0.48, test accuracy: 0.62\n",
      "sin i=    0 train loss: 62.79  train acc: 0.01  val loss: 33.26  valid acc: 0.00\n",
      "sin i=  500 train loss:  2.53  train acc: 0.73  val loss:  3.05  valid acc: 0.15\n",
      "sin i= 1000 train loss:  2.65  train acc: 0.75  val loss:  2.97  valid acc: 0.17\n",
      "sin i= 1500 train loss:  2.43  train acc: 0.84  val loss:  2.73  valid acc: 0.21\n",
      "sin i= 2000 train loss:  2.38  train acc: 0.84  val loss:  2.80  valid acc: 0.20\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001B[31m---------------------------------------------------------------------------\u001B[39m",
      "\u001B[31mKeyboardInterrupt\u001B[39m                         Traceback (most recent call last)",
      "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[52]\u001B[39m\u001B[32m, line 28\u001B[39m\n\u001B[32m     26\u001B[39m optimizer.zero_grad()\n\u001B[32m     27\u001B[39m minibatch_idcs = torch.randint(\u001B[38;5;28mlen\u001B[39m(train_labels), size=(\u001B[32m128\u001B[39m,), generator=rng)\n\u001B[32m---> \u001B[39m\u001B[32m28\u001B[39m x = \u001B[43mtrain_hidden_states\u001B[49m\u001B[43m[\u001B[49m\u001B[43mlayer_idx\u001B[49m\u001B[43m]\u001B[49m\u001B[43m[\u001B[49m\u001B[43mminibatch_idcs\u001B[49m\u001B[43m]\u001B[49m\u001B[43m.\u001B[49m\u001B[43mfloat\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m.\u001B[49m\u001B[43mto\u001B[49m\u001B[43m(\u001B[49m\u001B[43mdevice\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m     29\u001B[39m y = train_labels[minibatch_idcs].to(device)\n\u001B[32m     30\u001B[39m logits = probe(x, holdout_eval_tokens=\u001B[38;5;28;01mFalse\u001B[39;00m)\n",
      "\u001B[31mKeyboardInterrupt\u001B[39m: "
     ]
    }
   ],
   "execution_count": 52
  },
  {
   "cell_type": "code",
   "execution_count": 174,
   "id": "48a64b28-a3f0-42d3-abff-22add531a611",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train loss: 64.2337875366211 train acc: 0.0009765625 LR: [0.0009999505]\n",
      "step=0        0.0% \n",
      "Train loss: 3.203324317932129 train acc: 0.6171875 LR: [0.0009504505000000151]\n",
      "step=1000    37.1% \n",
      "Train loss: 3.225451946258545 train acc: 0.6845703125 LR: [0.0009009505000000161]\n",
      "step=2000    49.6% \n",
      "Train loss: 3.116640090942383 train acc: 0.6923828125 LR: [0.0008514505000000177]\n",
      "step=3000    51.1% \n",
      "Train loss: 3.18668532371521 train acc: 0.673828125 LR: [0.0008019505000000194]\n",
      "step=4000    51.7% \n",
      "Train loss: 3.021243095397949 train acc: 0.744140625 LR: [0.0007524505000000225]\n",
      "step=5000    54.5% \n",
      "Train loss: 3.089547634124756 train acc: 0.720703125 LR: [0.0007029505000000275]\n",
      "step=6000    52.2% \n",
      "Train loss: 2.987320899963379 train acc: 0.724609375 LR: [0.0006534505000000315]\n",
      "step=7000    48.1% \n",
      "Train loss: 2.9233715534210205 train acc: 0.724609375 LR: [0.0006039505000000363]\n",
      "step=8000    51.7% \n",
      "Train loss: 2.9112982749938965 train acc: 0.7109375 LR: [0.0005544505000000445]\n",
      "step=9000    58.0% \n",
      "Train loss: 2.858201026916504 train acc: 0.7373046875 LR: [0.0005049505000000547]\n",
      "step=10000   52.8% \n",
      "Train loss: 2.807520627975464 train acc: 0.7509765625 LR: [0.0004554505000000548]\n",
      "step=11000   51.1% \n",
      "Train loss: 2.818568229675293 train acc: 0.73046875 LR: [0.0004059505000000473]\n",
      "step=12000   55.8% \n",
      "Train loss: 2.7056360244750977 train acc: 0.7431640625 LR: [0.0003564505000000392]\n",
      "step=13000   50.9% \n",
      "Train loss: 2.5240285396575928 train acc: 0.7841796875 LR: [0.000306950500000029]\n",
      "step=14000   57.8% \n",
      "Train loss: 2.6640219688415527 train acc: 0.7509765625 LR: [0.00025745050000001773]\n",
      "step=15000   53.0% \n",
      "Train loss: 2.56095027923584 train acc: 0.7568359375 LR: [0.0002079505000000123]\n",
      "step=16000   57.1% \n",
      "Train loss: 2.543759346008301 train acc: 0.7548828125 LR: [0.00015845050000001058]\n",
      "step=17000   56.0% \n",
      "Train loss: 2.46024227142334 train acc: 0.7783203125 LR: [0.00010895050000000908]\n",
      "step=18000   55.2% \n",
      "Train loss: 2.447075843811035 train acc: 0.7607421875 LR: [5.945050000000361e-05]\n",
      "step=19000   57.3% \n",
      "Train loss: 2.419114589691162 train acc: 0.7529296875 LR: [1.000000000000058e-05]\n",
      "step=20000   55.2% \n",
      "Train loss: 2.437913417816162 train acc: 0.78125 LR: [1.000000000000058e-05]\n",
      "step=21000   55.8% \n",
      "Train loss: 2.3444290161132812 train acc: 0.763671875 LR: [1.000000000000058e-05]\n",
      "step=22000   56.9% \n",
      "Train loss: 2.484499216079712 train acc: 0.736328125 LR: [1.000000000000058e-05]\n",
      "step=23000   55.6% \n",
      "Train loss: 2.4394032955169678 train acc: 0.7509765625 LR: [1.000000000000058e-05]\n",
      "step=24000   57.8% \n",
      "Train loss: 2.321702480316162 train acc: 0.791015625 LR: [1.000000000000058e-05]\n",
      "step=25000   55.2% \n",
      "Train loss: 2.3407979011535645 train acc: 0.7783203125 LR: [1.000000000000058e-05]\n",
      "step=26000   57.5% \n",
      "Train loss: 2.3601112365722656 train acc: 0.755859375 LR: [1.000000000000058e-05]\n",
      "step=27000   56.5% \n",
      "Train loss: 2.4831812381744385 train acc: 0.7412109375 LR: [1.000000000000058e-05]\n",
      "step=28000   56.9% \n",
      "Train loss: 2.353656530380249 train acc: 0.7763671875 LR: [1.000000000000058e-05]\n",
      "step=29000   55.0% \n",
      "Train loss: 2.388251304626465 train acc: 0.7509765625 LR: [1.000000000000058e-05]\n",
      "step=30000   57.3% \n",
      "Test accuracy: 0.42278480529785156\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): this whole cell is commented out; the output saved with it comes from an earlier run.\n",
    "# Known issues to fix before re-enabling it:\n",
    "#  - `best_val_acc` / `best_ckpt` are re-initialised on every training step (inside the loop), so the\n",
    "#    best-checkpoint selection never compares against earlier evaluations; hoist them above the loop.\n",
    "#  - `probe.state_dict()` returns references to the live parameter tensors, so `best_ckpt` keeps tracking\n",
    "#    the latest weights; snapshot it with copy.deepcopy(...) or {k: v.clone() for k, v in ...}.\n",
    "#  - layers are sampled with the global `random` module, not the seeded `rng_py` defined right below.\n",
    "# Purpose: train one ClassifierProbe jointly on the last two layers (`train_layers`), validate on the last\n",
    "# layer every 1000 steps, then report test accuracy on the last layer and store its argmax predictions in\n",
    "# `test_extracted`.\n",
    "# rng = torch.Generator().manual_seed(0)\n",
    "# rng_py = random.Random(0)\n",
    "#\n",
    "# assert list(train_hidden_states.keys()) == list(range(len(train_hidden_states)))\n",
    "# train_hidden_states_tensor = torch.stack(list(train_hidden_states.values()), dim=0)\n",
    "#\n",
    "# histories = []\n",
    "#\n",
    "# probe = ClassifierProbe(\n",
    "#     emb_dim=train_hidden_states[0].shape[-1],\n",
    "#     hidden_dim=100,\n",
    "#     basis=basis_embs_sin,\n",
    "#     heldout_mask=test_mask,\n",
    "# ).to(device)\n",
    "#\n",
    "# optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3, weight_decay=0)\n",
    "# scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=20000)\n",
    "#\n",
    "# train_layers = list(range(len(train_hidden_states)-2, len(train_hidden_states)))\n",
    "#\n",
    "# for step in range(30000+1):\n",
    "#     probe.train()\n",
    "#     optimizer.zero_grad()\n",
    "#     layer_idcs = torch.tensor(random.choices(train_layers, k=1024))\n",
    "#     minibatch_idcs = torch.randint(len(train_labels), size=(1024,), generator=rng)\n",
    "#     x = train_hidden_states_tensor[layer_idcs, minibatch_idcs].float().to(device)\n",
    "#     y = train_labels[minibatch_idcs].to(device)\n",
    "#     train_logits = probe(x, holdout_eval_tokens=False)\n",
    "#     loss = torch.nn.functional.cross_entropy(train_logits, y)\n",
    "#     loss += 1e-2 * sum(p.abs().sum() for p in probe.parameters()) # L1 regularization\n",
    "#     loss.backward()\n",
    "#     optimizer.step()\n",
    "#     scheduler.step()\n",
    "#\n",
    "#     # NOTE(review): reset every step -> see the best-checkpoint issues listed at the top of this cell.\n",
    "#     best_val_acc = -1\n",
    "#     best_ckpt = probe.state_dict()\n",
    "#\n",
    "#     if step % 1000 == 0:\n",
    "#         print(\"Train loss: %s train acc: %s LR: %s\" % (loss.item(),\n",
    "#                                                        sum(train_logits.argmax(-1) == y).item() / len(y),\n",
    "#                                                        scheduler.get_last_lr()))\n",
    "#         probe.eval()\n",
    "#         valid_accs = []\n",
    "#         with torch.no_grad():\n",
    "#             print(f\"{step=:<5}\", end=\"  \")\n",
    "#             # for layer_idx in range(0, len(train_hidden_states)):\n",
    "#             layer_idx = len(train_hidden_states)-1\n",
    "#             valid_logits = probe(valid_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)\n",
    "#             valid_acc = (valid_logits.argmax(dim=-1) == valid_labels).float().mean().item()\n",
    "#             valid_accs.append(valid_acc)\n",
    "#             histories.append({\"step\": step, \"eval_layer\": layer_idx, \"valid_acc\": valid_acc})\n",
    "#             acc_out = f\"{valid_acc:>6.1%}\"\n",
    "#             if layer_idx not in train_layers:\n",
    "#                 print('\\033[94m' + acc_out + '\\033[0m', end=\" \")\n",
    "#             else:\n",
    "#                 print(acc_out, end=\" \")\n",
    "#             print()\n",
    "#             valid_acc = sum(valid_accs) / len(valid_accs)\n",
    "#             if valid_acc > best_val_acc:\n",
    "#                 best_val_acc = valid_acc\n",
    "#                 best_ckpt = probe.state_dict()\n",
    "#\n",
    "# probe.load_state_dict(best_ckpt)\n",
    "# probe.eval()\n",
    "# with torch.no_grad():\n",
    "#     layer_idx = len(train_hidden_states)-1\n",
    "#     test_logits = probe(test_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)\n",
    "#     test_extracted[layer_idx] = test_logits.argmax(dim=-1)\n",
    "#     test_accuracy = (test_extracted[layer_idx] == test_labels).float().mean().item()\n",
    "#     print(\"Test accuracy: %s\" % test_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "id": "f8349f5c7a383567",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T19:39:24.712312Z",
     "start_time": "2025-09-20T22:04:30.577310Z"
    }
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# Per-example table: each probed layer's extracted value next to the model's own prediction, input and label.\n",
    "# Written under a path relative to the kernel's working directory instead of a machine-specific absolute path;\n",
    "# assumes the kernel runs from the notebooks/ folder that the previous absolute path pointed into.\n",
    "LOGS_DIR = Path(\"logs\")\n",
    "LOGS_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "df = pd.DataFrame({\"probe_l%s\" % k: v.cpu() for k, v in test_extracted.items()})\n",
    "df[\"model_predictions\"] = test_preds_t.cpu()\n",
    "df[\"inputs\"] = [make_str_input(op) for op in test_inputs]\n",
    "df[\"labels\"] = test_labels_ref.cpu()\n",
    "df.to_csv(LOGS_DIR / \"model_vs_probes_preds_llama3b_div_210925.csv\", index=False)  # TODO: visualize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "id": "9f1525c48a64f3ee",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(device(type='cpu'), device(type='cuda', index=0))"
      ]
     },
     "execution_count": 176,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: which device does each tensor live on?\n",
    "tuple(t.device for t in (test_preds_t, test_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "id": "e253b515feab91e9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T10:18:23.312638Z",
     "start_time": "2025-09-21T10:18:23.287473Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ever computed correctly internally and NOT correctly returned (out of incorrect): tensor(0.2105)\n",
      "NOT ever computed correctly internally and correctly returned: (out of correct) tensor(0.2717)\n",
      "Ever computed correctly internally and correctly returned (out of correct): tensor(0.7283)\n",
      "NOT ever computed correctly internally and NOT correctly returned (out of incorrect): tensor(0.7895)\n",
      "Ever computed as returned: tensor(0.6734)\n"
     ]
    }
   ],
   "source": [
    "# One row per probed layer: did that layer's probe extract the correct result?\n",
    "is_result_computed_per_l = torch.vstack([test_extracted[l_key] == test_labels_ref for l_key in test_extracted])\n",
    "# Correct result extracted by the probe of at least one layer.\n",
    "is_result_computed_internally = torch.any(is_result_computed_per_l, dim=0).cpu()\n",
    "# One row per probed layer: did that layer's probe extract whatever the model returned (right or wrong)?\n",
    "returned_val_is_computed = torch.vstack([test_extracted[l_key].cpu() == test_preds_t for l_key in test_extracted])\n",
    "is_result_returned = (test_preds_t == test_labels_ref.cpu())\n",
    "\n",
    "\n",
    "def share_within(mask: Tensor, population: Tensor) -> float:\n",
    "    \"\"\"\n",
    "    Fraction of the examples selected by the boolean ``population`` mask that ``mask`` also selects.\n",
    "\n",
    "    Returns a plain float (not a 0-d tensor) so the report prints cleanly, and NaN for an empty population.\n",
    "    \"\"\"\n",
    "    n_population = int(population.sum())\n",
    "    if n_population == 0:\n",
    "        return float(\"nan\")\n",
    "    return int((mask & population).sum()) / n_population\n",
    "\n",
    "\n",
    "print(\n",
    "    \"Ever computed correctly internally and NOT correctly returned (out of incorrect): %.4f\\n\"\n",
    "    \"NOT ever computed correctly internally and correctly returned (out of correct): %.4f\\n\"\n",
    "    \"Ever computed correctly internally and correctly returned (out of correct): %.4f\\n\"\n",
    "    \"NOT ever computed correctly internally and NOT correctly returned (out of incorrect): %.4f\\n\"\n",
    "    \"Ever computed as returned: %.4f\"\n",
    "    % (\n",
    "        share_within(is_result_computed_internally, ~is_result_returned),\n",
    "        share_within(~is_result_computed_internally, is_result_returned),\n",
    "        share_within(is_result_computed_internally, is_result_returned),\n",
    "        share_within(~is_result_computed_internally, ~is_result_returned),\n",
    "        returned_val_is_computed.any(dim=0).float().mean().item(),\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "a3f7ba0a324efb68",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-21T10:19:37.230223Z",
     "start_time": "2025-09-21T10:19:37.219407Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(140, tensor(30))"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Of the correct answers whose result no layer's probe ever extracted, how many have an input containing\n",
    "# 1 or 2 (a multiplication effectively solvable by addition)? Originally 15 of 25; only 30 of 140 for Llama 3B.\n",
    "correct_but_never_extracted = ~is_result_computed_internally & is_result_returned\n",
    "(\n",
    "    len(test_labels_ref[correct_but_never_extracted]),\n",
    "    torch.isin(test_inputs_t[correct_but_never_extracted], torch.tensor([1, 2])).any(dim=1).sum(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c6a0ac0e3171398",
   "metadata": {},
   "outputs": [],
   "source": [
    "def solve_linear_layer(x: Tensor, y: Tensor) -> torch.nn.Linear:\n",
    "    \"\"\"\n",
    "    Fits a linear layer ``y ~= x @ weight.T + bias`` in closed form with ordinary least squares.\n",
    "\n",
    "    :param x: floating-point input matrix of shape (n_samples, in_features).\n",
    "    :param y: targets of shape (n_samples,) or (n_samples, out_features); integer targets are accepted and\n",
    "        are cast (together with their device) to match ``x``.\n",
    "    :return: a ``torch.nn.Linear(in_features, out_features)`` on ``x``'s device and dtype holding the\n",
    "        least-squares weight and bias (the bias is fit through an appended column of ones).\n",
    "    :raises ValueError: if ``x`` is not 2-D or ``x`` and ``y`` have a different number of rows.\n",
    "    \"\"\"\n",
    "    if x.ndim != 2:\n",
    "        raise ValueError(f\"x must be 2-D (n_samples, in_features), got shape {tuple(x.shape)}\")\n",
    "    if y.ndim == 1:\n",
    "        y = y.unsqueeze(-1)\n",
    "    if len(x) != len(y):\n",
    "        raise ValueError(f\"x and y must have the same number of rows, got {len(x)} and {len(y)}\")\n",
    "    # lstsq needs both operands on the same device with the same floating dtype.\n",
    "    y = y.to(device=x.device, dtype=x.dtype)\n",
    "\n",
    "    # A constant-1 column makes the last row of the solution the bias term.\n",
    "    ones = torch.ones(len(x), 1, device=x.device, dtype=x.dtype)\n",
    "    x_aug = torch.cat([x, ones], dim=1)\n",
    "    # NOTE: on CUDA torch.linalg.lstsq only supports the 'gels' driver, which assumes x_aug has full rank.\n",
    "    coeffs = torch.linalg.lstsq(x_aug, y).solution  # (in_features + 1, out_features)\n",
    "    weight, bias = coeffs[:-1], coeffs[-1]\n",
    "\n",
    "    lin = torch.nn.Linear(x.shape[-1], y.shape[-1], device=x.device, dtype=x.dtype)\n",
    "    with torch.no_grad():\n",
    "        lin.weight.copy_(weight.T)\n",
    "        lin.bias.copy_(bias)\n",
    "    return lin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "743ac235337625af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Closed-form least-squares regression probes per layer:\n",
    "#  - \"lin\" regresses the label value directly;\n",
    "#  - \"log\" regresses log1p(label) and maps predictions back with expm1, the exact inverse of log1p.\n",
    "# NOTE(review): the log probe assumes all labels are > -1 (log1p is not finite otherwise) -- confirm for this task.\n",
    "for layer_idx in range(len(train_hidden_states)):\n",
    "    # Materialize each layer's activations on the device once and reuse them for both probes.\n",
    "    train_x = train_hidden_states[layer_idx].float().to(device)\n",
    "    test_x = test_hidden_states[layer_idx].float().to(device)\n",
    "\n",
    "    lin_probe = solve_linear_layer(train_x, train_labels.to(device))\n",
    "    log_probe = solve_linear_layer(train_x, train_labels.log1p().to(device))\n",
    "\n",
    "    with torch.no_grad():\n",
    "        lin_test_pred = lin_probe(test_x).flatten().round().int()\n",
    "        # expm1(z) = exp(z) - 1 undoes log1p; exp(z) + 1 would be off by 2 for every prediction.\n",
    "        log_test_pred = log_probe(test_x).flatten().expm1().round().int()\n",
    "    lin_test_accuracy = (lin_test_pred == test_labels).float().mean().item()\n",
    "    log_test_accuracy = (log_test_pred == test_labels).float().mean().item()\n",
    "\n",
    "    test_accuracies[\"lin\"][layer_idx] = lin_test_accuracy\n",
    "    test_accuracies[\"log\"][layer_idx] = log_test_accuracy\n",
    "\n",
    "    print(f\"layer idx: {layer_idx:<3}, linear probe acc: {lin_test_accuracy:.2f}, log probe acc: {log_test_accuracy:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "265eb99ebd08e3d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One markdown-style table row per probe kind, with layers in ascending order.\n",
    "for probe_kind, acc_by_layer in test_accuracies.items():\n",
    "    row_cells = \" | \".join(f\"{acc:.0%}\" for _, acc in sorted(acc_by_layer.items()))\n",
    "    print(f\"{probe_kind} accs: | {row_cells} |\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
