{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "32c44634",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import itertools\n",
    "\n",
    "import transformers\n",
    "import torch\n",
    "import datasets\n",
    "import plotly.express\n",
    "import einops\n",
    "import tqdm.auto\n",
    "from lovely_tensors import lovely"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "01d5be68",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_ckpt = \"meta-llama/Llama-3.2-1B\"\n",
    "model = transformers.AutoModelForCausalLM.from_pretrained(model_ckpt).eval()\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(model_ckpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dedbd09c",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = datasets.concatenate_datasets(\n",
    "    [\n",
    "        datasets.load_dataset(\"RealTimeData/bbc_news_alltime\", f\"2024-{i:02d}\", split=\"train\").select_columns(\"content\") for i in range(1, 13)\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6cf03738",
   "metadata": {},
   "outputs": [],
   "source": [
    "texts = list(set(ds[\"content\"]))\n",
    "texts.sort(key=lambda x: len(x), reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5545d567",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized = tokenizer(texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c8fd89b3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([11778, 512]), torch.Size([1000, 512]), torch.Size([1000, 512]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "desired_length = 512\n",
    "ds = [inp[:desired_length] for inp in tokenized.input_ids if len(inp) >= desired_length]\n",
    "random.Random(0).shuffle(ds)\n",
    "eval_size = 1000\n",
    "train_ds, valid_ds, test_ds = torch.tensor(ds, dtype=torch.long).tensor_split([-eval_size*2, -eval_size])\n",
    "train_ds.shape, valid_ds.shape, test_ds.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ee9aebd4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LlamaForCausalLM(\n",
       "  (model): LlamaModel(\n",
       "    (embed_tokens): Embedding(128256, 2048)\n",
       "    (layers): ModuleList(\n",
       "      (0-15): 16 x LlamaDecoderLayer(\n",
       "        (self_attn): LlamaAttention(\n",
       "          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
       "          (k_proj): Linear(in_features=2048, out_features=512, bias=False)\n",
       "          (v_proj): Linear(in_features=2048, out_features=512, bias=False)\n",
       "          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
       "        )\n",
       "        (mlp): LlamaMLP(\n",
       "          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)\n",
       "          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)\n",
       "          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)\n",
       "        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)\n",
       "      )\n",
       "    )\n",
       "    (norm): LlamaRMSNorm((2048,), eps=1e-05)\n",
       "    (rotary_emb): LlamaRotaryEmbedding()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=2048, out_features=128256, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = \"cuda:7\"\n",
    "dtype = torch.bfloat16\n",
    "model.to(device, dtype).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2de05a30",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5c5fe29e",
   "metadata": {},
   "outputs": [],
   "source": [
    "probes = {}\n",
    "optims = {}\n",
    "\n",
    "# layer_idcs = range(len(model.model.layers) + 1)\n",
    "layer_idcs = [0, 1, 2]\n",
    "offset_idcs = [0, 1, 2]\n",
    "\n",
    "for layer_idx in layer_idcs:\n",
    "    probes[layer_idx] = {}\n",
    "    optims[layer_idx] = {}\n",
    "    for offset_idx in offset_idcs:\n",
    "        probe = torch.nn.Linear(model.config.hidden_size, model.config.hidden_size, bias=False, device=model.device, dtype=model.dtype)\n",
    "        optim = torch.optim.AdamW(probe.parameters(), lr=1e-4)\n",
    "        probes[layer_idx][offset_idx] = probe\n",
    "        optims[layer_idx][offset_idx] = optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "36bbb4e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9a4511231c5944b8a827789977a49d88",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b113e3263bff465e92f915576e65c634",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=10\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.045\n",
      "  For offset: 2 - best layer: 2  from step: 10    - valid acc: 0.037 and train acc: 0.038\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "18fba6d6b76d4c11b033b01818e4b0d8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=20\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 20    - valid acc: 0.034 and train acc: 0.048\n",
      "  For offset: 2 - best layer: 1  from step: 20    - valid acc: 0.039 and train acc: 0.041\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6feb898f2e7b4c0b84e62279f99f594f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=30\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 30    - valid acc: 0.034 and train acc: 0.043\n",
      "  For offset: 2 - best layer: 1  from step: 30    - valid acc: 0.040 and train acc: 0.037\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b8279661cb5450d98488bdcc3a64623",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=40\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 30    - valid acc: 0.034 and train acc: 0.043\n",
      "  For offset: 2 - best layer: 1  from step: 40    - valid acc: 0.040 and train acc: 0.039\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d3b40bc8b0a54fec9be980dbf6d139f8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=50\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 30    - valid acc: 0.034 and train acc: 0.043\n",
      "  For offset: 2 - best layer: 1  from step: 40    - valid acc: 0.040 and train acc: 0.039\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3dfe4983dd21402ab75276b990e5bf88",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=60\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 30    - valid acc: 0.034 and train acc: 0.043\n",
      "  For offset: 2 - best layer: 1  from step: 40    - valid acc: 0.040 and train acc: 0.039\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4dec2b2521974ddeb2958bb539dac494",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=70\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 30    - valid acc: 0.034 and train acc: 0.043\n",
      "  For offset: 2 - best layer: 1  from step: 40    - valid acc: 0.040 and train acc: 0.039\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6bb662e41de04f7da9b758d0d5d05d11",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=80\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 30    - valid acc: 0.034 and train acc: 0.043\n",
      "  For offset: 2 - best layer: 1  from step: 40    - valid acc: 0.040 and train acc: 0.039\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "92eaf7d05d924da99606525b82e0f557",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=90\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 30    - valid acc: 0.034 and train acc: 0.043\n",
      "  For offset: 2 - best layer: 1  from step: 40    - valid acc: 0.040 and train acc: 0.039\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1272046a9ff94accbfc596675d235c69",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=100\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 30    - valid acc: 0.034 and train acc: 0.043\n",
      "  For offset: 2 - best layer: 1  from step: 40    - valid acc: 0.040 and train acc: 0.039\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "48e1ead24f514d46a0791efcae93c9d2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=110\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 30    - valid acc: 0.034 and train acc: 0.043\n",
      "  For offset: 2 - best layer: 1  from step: 40    - valid acc: 0.040 and train acc: 0.039\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "09a2b9c9c795464388fa18de2d980542",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=120\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 30    - valid acc: 0.034 and train acc: 0.043\n",
      "  For offset: 2 - best layer: 1  from step: 40    - valid acc: 0.040 and train acc: 0.039\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b5253a2d7de2439182bc258141128d5c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=130\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 130   - valid acc: 0.034 and train acc: 0.039\n",
      "  For offset: 2 - best layer: 1  from step: 130   - valid acc: 0.041 and train acc: 0.034\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "478e6d2bd74141c5817c0031e0bdb95c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=140\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 140   - valid acc: 0.035 and train acc: 0.050\n",
      "  For offset: 2 - best layer: 1  from step: 140   - valid acc: 0.041 and train acc: 0.043\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "abf06d7923fe4c7f85889d64640226c5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=150\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 150   - valid acc: 0.035 and train acc: 0.046\n",
      "  For offset: 2 - best layer: 1  from step: 150   - valid acc: 0.043 and train acc: 0.038\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d57210d980d74bc7ae1f84fc9cfee381",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=160\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 160   - valid acc: 0.035 and train acc: 0.043\n",
      "  For offset: 2 - best layer: 1  from step: 160   - valid acc: 0.046 and train acc: 0.034\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "51b7750c4f55447fba0e516a05344463",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=170\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 170   - valid acc: 0.036 and train acc: 0.054\n",
      "  For offset: 2 - best layer: 1  from step: 170   - valid acc: 0.050 and train acc: 0.041\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "05e1cd436b534fda8dc6b08e7271910c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=180\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 180   - valid acc: 0.036 and train acc: 0.057\n",
      "  For offset: 2 - best layer: 1  from step: 180   - valid acc: 0.055 and train acc: 0.039\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "37f4c2ab5aa04abfac080d4c21eaf44d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=190\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 190   - valid acc: 0.036 and train acc: 0.061\n",
      "  For offset: 2 - best layer: 2  from step: 190   - valid acc: 0.059 and train acc: 0.059\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "65ad516425ec46a39efb2c42ebbee386",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=200\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 1  from step: 200   - valid acc: 0.037 and train acc: 0.068\n",
      "  For offset: 2 - best layer: 2  from step: 200   - valid acc: 0.064 and train acc: 0.059\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a76e50c21502460fbec7f32d8d955e71",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=210\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 2  from step: 210   - valid acc: 0.039 and train acc: 0.109\n",
      "  For offset: 2 - best layer: 2  from step: 210   - valid acc: 0.069 and train acc: 0.071\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a30cf26cbe544dc7901b4f7e4187a66d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=220\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 2  from step: 220   - valid acc: 0.042 and train acc: 0.103\n",
      "  For offset: 2 - best layer: 2  from step: 220   - valid acc: 0.072 and train acc: 0.066\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "84ec29d2d55c4290a1bac873bfc5dacf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=230\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 2  from step: 230   - valid acc: 0.045 and train acc: 0.110\n",
      "  For offset: 2 - best layer: 2  from step: 230   - valid acc: 0.075 and train acc: 0.075\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e943fa25f3b64f8e840f48e848f4192f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=240\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 2  from step: 240   - valid acc: 0.047 and train acc: 0.110\n",
      "  For offset: 2 - best layer: 2  from step: 240   - valid acc: 0.077 and train acc: 0.077\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9a3491dede2346e0b18a55f9f3e6c983",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=250\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 2  from step: 250   - valid acc: 0.050 and train acc: 0.118\n",
      "  For offset: 2 - best layer: 2  from step: 250   - valid acc: 0.079 and train acc: 0.084\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "474c0da3e22e4dbcbb06c7ef2666f896",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=260\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 2  from step: 260   - valid acc: 0.052 and train acc: 0.111\n",
      "  For offset: 2 - best layer: 2  from step: 260   - valid acc: 0.080 and train acc: 0.078\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "30d4bfca71434b678dda06918e082965",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=270\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 2  from step: 270   - valid acc: 0.054 and train acc: 0.113\n",
      "  For offset: 2 - best layer: 2  from step: 270   - valid acc: 0.081 and train acc: 0.082\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0cee002c92ca44ca80ad68796408e927",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=280\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 2  from step: 280   - valid acc: 0.056 and train acc: 0.117\n",
      "  For offset: 2 - best layer: 2  from step: 280   - valid acc: 0.082 and train acc: 0.082\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "704aa03f02bd437ea995a0f288f5c8b3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=290\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 2  from step: 290   - valid acc: 0.057 and train acc: 0.125\n",
      "  For offset: 2 - best layer: 2  from step: 290   - valid acc: 0.083 and train acc: 0.085\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dfd30d2fe7c5426c82b1a5d778f19365",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=300\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 2  from step: 300   - valid acc: 0.058 and train acc: 0.134\n",
      "  For offset: 2 - best layer: 2  from step: 300   - valid acc: 0.083 and train acc: 0.091\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "77762dba54af429da005bdeec6981287",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=310\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 2  from step: 310   - valid acc: 0.059 and train acc: 0.127\n",
      "  For offset: 2 - best layer: 2  from step: 310   - valid acc: 0.084 and train acc: 0.082\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d52eb6afa15a4f4984518ed35b36851d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_step=320\n",
      "  For offset: 0 - best layer: 2  from step: 10    - valid acc: 0.034 and train acc: 0.101\n",
      "  For offset: 1 - best layer: 2  from step: 320   - valid acc: 0.059 and train acc: 0.123\n",
      "  For offset: 2 - best layer: 2  from step: 320   - valid acc: 0.084 and train acc: 0.080\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "71b1ee8d725e43e898e98a344f054c3e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mKeyboardInterrupt\u001b[39m                         Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[10]\u001b[39m\u001b[32m, line 57\u001b[39m\n\u001b[32m     55\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m val_batch_idx, valid_batch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(tqdm.auto.tqdm(valid_loader, desc=\u001b[33m\"\u001b[39m\u001b[33mValidating\u001b[39m\u001b[33m\"\u001b[39m, leave=\u001b[38;5;28;01mFalse\u001b[39;00m)):\n\u001b[32m     56\u001b[39m     valid_batch, = valid_batch \u001b[38;5;66;03m# loader outputs tuples even if there is only one x without y, we need to unpack\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m57\u001b[39m     valid_batch = \u001b[43mvalid_batch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m     58\u001b[39m     hidden_states = model(valid_batch, output_hidden_states=\u001b[38;5;28;01mTrue\u001b[39;00m).hidden_states\n\u001b[32m     60\u001b[39m     \u001b[38;5;28;01mfor\u001b[39;00m layer_idx, offset_idx \u001b[38;5;129;01min\u001b[39;00m itertools.product(layer_idcs, offset_idcs):\n",
      "\u001b[31mKeyboardInterrupt\u001b[39m: "
     ]
    }
   ],
   "source": [
    "batch_size = 16\n",
    "eval_every_n_steps = 10  # a full validation pass runs every this many training steps\n",
    "\n",
    "# Each TensorDataset wraps a single tensor, so both loaders yield 1-tuples.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    torch.utils.data.TensorDataset(train_ds),\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    pin_memory=True,\n",
    "    pin_memory_device=device\n",
    ")\n",
    "valid_loader = torch.utils.data.DataLoader(\n",
    "    torch.utils.data.TensorDataset(valid_ds),\n",
    "    batch_size=batch_size,\n",
    "    shuffle=False,\n",
    "    pin_memory=True,\n",
    "    pin_memory_device=device\n",
    ")\n",
    "\n",
    "# Endless stream of training batches. The loader is re-iterated for every epoch instead of\n",
    "# being wrapped in itertools.cycle: cycle stores the batches of the first epoch and replays\n",
    "# them, which defeats shuffle=True after epoch one and keeps every batch in memory.\n",
    "batches = (batch for _epoch in itertools.count() for batch in train_loader)\n",
    "pbar = tqdm.auto.tqdm(batches, desc=\"Training\")\n",
    "\n",
    "n_valid_batches = len(valid_loader)\n",
    "\n",
    "# Accuracy history per (layer, offset); one value is appended per validation pass.\n",
    "train_accs = {layer_idx: {offset_idx: [] for offset_idx in offset_idcs} for layer_idx in layer_idcs}\n",
    "valid_accs = {layer_idx: {offset_idx: [] for offset_idx in offset_idcs} for layer_idx in layer_idcs}\n",
    "\n",
    "for train_step, train_batch in enumerate(pbar):\n",
    "    train_batch, = train_batch # loader outputs tuples even if there is only one x without y, we need to unpack\n",
    "    train_batch = train_batch.to(device)\n",
    "\n",
    "    # Hidden states are extracted without tracking gradients.\n",
    "    with torch.no_grad():\n",
    "        hidden_states = model(train_batch, output_hidden_states=True).hidden_states\n",
    "\n",
    "    # NOTE(review): the accuracy tensors are indexed directly by layer_idx / offset_idx, which\n",
    "    # assumes layer_idcs and offset_idcs are 0-based and contiguous - confirm where they are defined.\n",
    "    curr_train_accs = torch.zeros((len(layer_idcs), len(offset_idcs)), device=device, dtype=torch.float32)\n",
    "    for layer_idx, offset_idx in itertools.product(layer_idcs, offset_idcs):\n",
    "        probe = probes[layer_idx][offset_idx]\n",
    "        optim = optims[layer_idx][offset_idx]\n",
    "\n",
    "        probe.train()\n",
    "        prediction: torch.Tensor = probe(hidden_states[layer_idx])\n",
    "        logits = model.lm_head(prediction)\n",
    "        logits = einops.rearrange(logits, \"batch seq vocab -> batch vocab seq\")\n",
    "        # The output at position t is scored against the input token at position t - offset_idx.\n",
    "        logits = logits[..., offset_idx:] # drop the first offset_idx positions\n",
    "        labels = train_batch[:, :logits.shape[-1]] # drop the last offset_idx tokens\n",
    "        optim.zero_grad()\n",
    "        loss = torch.nn.functional.cross_entropy(logits, labels)\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        curr_train_accs[layer_idx, offset_idx] = (logits.argmax(dim=1) == labels).float().mean()\n",
    "\n",
    "    if train_step % eval_every_n_steps == 0 and train_step != 0:\n",
    "        with torch.no_grad():\n",
    "            curr_valid_accs = torch.zeros((len(layer_idcs), len(offset_idcs), n_valid_batches), device=device, dtype=torch.float32)\n",
    "            for val_batch_idx, valid_batch in enumerate(tqdm.auto.tqdm(valid_loader, desc=\"Validating\", leave=False)):\n",
    "                valid_batch, = valid_batch # loader outputs tuples even if there is only one x without y, we need to unpack\n",
    "                valid_batch = valid_batch.to(device)\n",
    "                hidden_states = model(valid_batch, output_hidden_states=True).hidden_states\n",
    "\n",
    "                for layer_idx, offset_idx in itertools.product(layer_idcs, offset_idcs):\n",
    "                    # Score the probe that belongs to this (layer, offset) pair, in eval mode.\n",
    "                    # Reusing the `probe` left over from the training loop would evaluate one\n",
    "                    # and the same probe for every pair.\n",
    "                    probe = probes[layer_idx][offset_idx]\n",
    "                    probe.eval()\n",
    "                    prediction = probe(hidden_states[layer_idx])\n",
    "                    logits = model.lm_head(prediction)\n",
    "                    logits = einops.rearrange(logits, \"batch seq vocab -> batch vocab seq\")\n",
    "                    logits = logits[..., offset_idx:] # drop the first offset_idx positions\n",
    "                    labels = valid_batch[:, :logits.shape[-1]] # drop the last offset_idx tokens\n",
    "                    valid_acc_batch = (logits.argmax(dim=1) == labels).float().mean()\n",
    "                    curr_valid_accs[layer_idx, offset_idx, val_batch_idx] = valid_acc_batch\n",
    "            curr_valid_accs = curr_valid_accs.mean(dim=-1) # average over all valid batches\n",
    "\n",
    "        # Record the train accuracy of the current step next to the fresh validation accuracy.\n",
    "        for layer_idx, offset_idx in itertools.product(layer_idcs, offset_idcs):\n",
    "            train_accs[layer_idx][offset_idx].append(curr_train_accs[layer_idx][offset_idx].item())\n",
    "            valid_accs[layer_idx][offset_idx].append(curr_valid_accs[layer_idx][offset_idx].item())\n",
    "\n",
    "        print(f\"{train_step=}\")\n",
    "        for offset_idx in offset_idcs:\n",
    "            # find best performing layer and timestep\n",
    "            best_layer = max(layer_idcs, key=lambda l: max(valid_accs[l][offset_idx]))\n",
    "            best_step = max(range(len(valid_accs[best_layer][offset_idx])), key=lambda s: valid_accs[best_layer][offset_idx][s])\n",
    "            best_valid_acc = valid_accs[best_layer][offset_idx][best_step]\n",
    "            train_acc = train_accs[best_layer][offset_idx][best_step]\n",
    "            # History entry i was recorded at training step (i + 1) * eval_every_n_steps.\n",
    "            print(f\"  For offset: {offset_idx} - best layer: {best_layer:<2} from step: {(best_step+1)*eval_every_n_steps:<5} - valid acc: {best_valid_acc:.3f} and train acc: {train_acc:.3f}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbddd3d6",
   "metadata": {},
   "source": [
    "# Stuff Below\n",
    "The code below might be useful for experimenting on the number data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daae081c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fmt_num(seq: list[int]) -> str:\n",
    "    \"\"\"Render a sequence of integers as one contiguous digit string.\n",
    "\n",
    "    The first number is written as-is; every later number is zero-padded to at\n",
    "    least three digits. An empty sequence raises ValueError (nothing to unpack).\n",
    "    \"\"\"\n",
    "    head, *tail = seq\n",
    "    pieces = [str(head)]\n",
    "    pieces.extend(format(value, \"03d\") for value in tail)\n",
    "    return \"\".join(pieces)\n",
    "\n",
    "# Fixed seed so the synthetic number sequences are the same on every run.\n",
    "rng = random.Random(0)\n",
    "# 100 sequences, each holding 10 integers drawn from 0..999 (both ends included).\n",
    "nums = [[rng.randint(0, 999) for _col in range(10)] for _row in range(100)]\n",
    "# Tokenize the formatted digit strings and keep only the token ids.\n",
    "nums_input = tokenizer(list(map(fmt_num, nums)), return_tensors=\"pt\").input_ids\n",
    "nums_input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6434eee8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy of one probe on the synthetic number sequences.\n",
    "# NOTE(review): layer_idx, offset_idx and max_offset are not set in this cell; they are\n",
    "# whatever earlier cells left behind - confirm their values before trusting the result.\n",
    "# Select the probe explicitly so it always matches the layer and the label shift used\n",
    "# below, and switch it to eval mode (the training loop leaves probes in train mode).\n",
    "probe = probes[layer_idx][offset_idx]\n",
    "probe.eval()\n",
    "with torch.no_grad():\n",
    "    hidden_states = model(nums_input.to(model.device), output_hidden_states=True).hidden_states\n",
    "    prediction = probe(hidden_states[layer_idx])\n",
    "    logits = model.lm_head(prediction)\n",
    "    logits = logits[:, max_offset:, :]\n",
    "    logits = einops.rearrange(logits, \"b s v -> b v s\")\n",
    "    labels = nums_input[:, max_offset-offset_idx:nums_input.shape[1]-offset_idx] # shifted to predict the past token\n",
    "    predicted_ids = logits.argmax(dim=1)\n",
    "    print(predicted_ids.shape, labels.shape)\n",
    "    # labels were never moved off the CPU, so the predictions are brought back before comparing\n",
    "    valid_acc_batch = (predicted_ids.cpu() == labels).float().mean()\n",
    "    print(valid_acc_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a3a074e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the argmax token ids of `logits` (taken over dim 1) back into text.\n",
    "tokenizer.batch_decode(torch.argmax(logits, dim=1))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "numllama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
