{
 "cells": [
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T18:37:20.804054Z",
     "start_time": "2025-09-15T18:37:20.801755Z"
    }
   },
   "cell_type": "code",
   "source": "device = \"cuda\"",
   "id": "966f831c279e30cc",
   "outputs": [],
   "execution_count": 1
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Preliminaries",
   "id": "560b3e3e4dbf7eed"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T18:37:22.889726Z",
     "start_time": "2025-09-15T18:37:20.934117Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import itertools\n",
    "import random\n",
    "import collections\n",
    "\n",
    "\n",
    "import transformers\n",
    "import torch\n",
    "import tqdm.auto\n",
    "from torch import Tensor"
   ],
   "id": "6267d52fa4cf1f38",
   "outputs": [],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T18:37:22.977730Z",
     "start_time": "2025-09-15T18:37:22.970194Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def sinusoidal_encode(\n",
    "    x: Tensor,\n",
    "    embedding_dim: int,\n",
    "    min_value: int,\n",
    "    max_value: int,\n",
    "    use_l2_norm: bool = False,\n",
    "    norm_const: float | None = None,\n",
    ") -> Tensor:\n",
    "    \"\"\"\n",
    "    Encodes a tensor of numbers into a sinusoidal representation, inspired by how absolute positional\n",
    "    encoding works in transformers.\n",
    "\n",
    "    The encoding is an evaluation of a sine and cosine function at different frequencies, where the\n",
    "    frequency is determined by the embedding dimension and the allowed range of the input values.\n",
    "\n",
    "    >>> sinusoidal_encode(\n",
    "    ...     torch.tensor([-5, 2, 1, 0]),\n",
    "    ...     embedding_dim=6,\n",
    "    ...     min_value=-5,\n",
    "    ...     max_value=5,\n",
    "    ... )\n",
    "    tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],\n",
    "            [ 0.6570,  0.7539, -0.1073, -0.9942,  0.9980,  0.0627],\n",
    "            [-0.2794,  0.9602,  0.3491, -0.9371,  0.9616,  0.2746],\n",
    "            [-0.9589,  0.2837,  0.7317, -0.6816,  0.8806,  0.4738]])\n",
    "    \"\"\"\n",
    "\n",
    "    if embedding_dim % 2 != 0 and not use_l2_norm:\n",
    "        raise ValueError(\"Embedding dimension must be even\")\n",
    "\n",
    "    if use_l2_norm:\n",
    "        if embedding_dim % 2 == 0:\n",
    "            reserved_dim = 2\n",
    "        else:\n",
    "            reserved_dim = 1\n",
    "        embedding_dim -= reserved_dim\n",
    "    else:\n",
    "        reserved_dim = 0  # will not be used\n",
    "\n",
    "    domain = max_value - min_value\n",
    "    y_shape = x.shape + (embedding_dim,)\n",
    "    y = torch.zeros(y_shape, device=x.device)\n",
    "    even_indices = torch.arange(0, embedding_dim, 2)\n",
    "    log_term = torch.log(torch.tensor(domain)) / embedding_dim\n",
    "    div_term = torch.exp(even_indices * -log_term)\n",
    "    x = x - min_value\n",
    "    values = x.unsqueeze(-1).float() * div_term\n",
    "    y[..., 0::2] = torch.sin(values)\n",
    "    y[..., 1::2] = torch.cos(values)\n",
    "\n",
    "    if use_l2_norm:\n",
    "        y = torch.cat([y, torch.ones_like(y[..., :reserved_dim])], dim=-1)\n",
    "        y /= y.norm(dim=-1, keepdim=True, p=2)\n",
    "\n",
    "    if norm_const is not None:\n",
    "        y *= norm_const\n",
    "\n",
    "    return y\n",
    "\n",
    "def binary_encode(\n",
    "    x: Tensor,\n",
    "    embedding_dim: int,\n",
    "    min_value: int | float,\n",
    "    max_value: int | float,\n",
    "    use_l2_norm: bool = False,\n",
    "    norm_const: float | None = None,\n",
    ") -> Tensor:\n",
    "    y = torch.zeros(x.shape + (embedding_dim,), device=x.device)\n",
    "    reserve_dim = 0 if not use_l2_norm else 1\n",
    "    x = x - min_value\n",
    "    maximum = x.max()\n",
    "    for i in range(embedding_dim - reserve_dim):\n",
    "        coeff = 2**i\n",
    "        if maximum < coeff:\n",
    "            break\n",
    "        y[..., -i - 1] = torch.floor(x / coeff) % 2\n",
    "        x = x - coeff * y[..., -i - 1]\n",
    "    if use_l2_norm:\n",
    "        y = torch.cat([y, torch.ones_like(y[..., :reserve_dim])], dim=-1)\n",
    "        y /= y.norm(dim=-1, keepdim=True, p=2)\n",
    "    if norm_const is not None:\n",
    "        y *= norm_const\n",
    "    return y"
   ],
   "id": "aee6481c6481b767",
   "outputs": [],
   "execution_count": 3
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Prepare model and data",
   "id": "785b4c78f485ff64"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T18:37:28.543365Z",
     "start_time": "2025-09-15T18:37:23.039031Z"
    }
   },
   "cell_type": "code",
   "source": [
    "model_ckpt = \"meta-llama/Llama-3.2-1B\"\n",
    "model = transformers.AutoModel.from_pretrained(model_ckpt, token=\"XXX\").eval()\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(model_ckpt)\n",
    "tokenizer.add_special_tokens({'pad_token': '<|end_of_text|>'})\n",
    "model = model.half().to(device).eval()"
   ],
   "id": "11f1d7f066a43d8",
   "outputs": [],
   "execution_count": 4
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T18:37:28.646815Z",
     "start_time": "2025-09-15T18:37:28.642814Z"
    }
   },
   "cell_type": "code",
   "source": [
    "all_values = torch.arange(0, 1000)\n",
    "mask = torch.rand(len(all_values), generator=torch.Generator().manual_seed(0))\n",
    "train_mask = mask < 0.9\n",
    "valid_mask = ~train_mask & (mask < 0.95)\n",
    "test_mask = ~train_mask & ~valid_mask\n",
    "\n",
    "train_values = all_values[train_mask]\n",
    "valid_values = all_values[valid_mask]\n",
    "test_values = all_values[test_mask]"
   ],
   "id": "77893cb1a5a3c08e",
   "outputs": [],
   "execution_count": 5
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T18:37:29.014358Z",
     "start_time": "2025-09-15T18:37:28.793807Z"
    }
   },
   "cell_type": "code",
   "source": [
    "all_inputs = all_values.tolist()\n",
    "all_inputs_val = [(x1, x2) for x1, x2 in itertools.product(all_values.tolist(), repeat=2) if x1 + x2 < 1000]\n",
    "train_values_set = set(train_values.tolist())\n",
    "valid_values_set = set(valid_values.tolist())\n",
    "test_values_set = set(test_values.tolist())\n",
    "        \n",
    "train_inputs = [x for x in all_inputs if x in train_values_set]\n",
    "valid_inputs = [(x1, x2) for x1, x2 in all_inputs_val if x2 in valid_values_set]\n",
    "test_inputs = [(x1, x2) for x1, x2 in all_inputs_val if x2 in test_values_set]\n",
    "\n",
    "# sanity check\n",
    "assert set(train_inputs) & set(valid_inputs) == set()\n",
    "assert set(train_inputs) & set(test_inputs) == set()\n",
    "assert set(valid_inputs) & set(test_inputs) == set()\n",
    "\n",
    "random.seed(0)\n",
    "random.shuffle(train_inputs)\n",
    "random.shuffle(valid_inputs)\n",
    "random.shuffle(test_inputs)\n",
    "valid_size = 4096\n",
    "train_size = 100_000\n",
    "train_inputs = train_inputs[:train_size]\n",
    "valid_inputs = valid_inputs[:valid_size]"
   ],
   "id": "b7034237fc92de8a",
   "outputs": [],
   "execution_count": 6
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T18:37:29.164817Z",
     "start_time": "2025-09-15T18:37:29.157021Z"
    }
   },
   "cell_type": "code",
   "source": "len(train_inputs)",
   "id": "b977f354e932dc29",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "888"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 7
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Constructing altered natural texts -- with all numbers from pre-defined ranges",
   "id": "665d5464c926042c"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T18:37:43.316650Z",
     "start_time": "2025-09-15T18:37:29.325285Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# cell loading the input texts\n",
    "import json\n",
    "from glob import glob\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "import datasets\n",
    "from git import Repo\n",
    "import os\n",
    "\n",
    "import itertools\n",
    "\n",
    "\n",
    "HOME_PATH = \"./\"\n",
    "\n",
    "def load_data(genre=\"food-1\", downsample_to=0):\n",
    "    \"\"\"\n",
    "    genre: input , genre of dataset you want to load\n",
    "    data :  output,\n",
    "\n",
    "    \"\"\"\n",
    "    if genre ==  'food-1':\n",
    "        directory_path = \"./FoodRecipe-ImageCaptioning/\"\n",
    "        if os.path.exists(directory_path) and os.path.isdir(directory_path):\n",
    "            1;\n",
    "        else:\n",
    "            Repo.clone_from(\"https://github.com/samsatp/FoodRecipe-ImageCaptioning.git/\", \"./FoodRecipe-ImageCaptioning/\")\n",
    "\n",
    "        with open(HOME_PATH + directory_path + \"data/data_strings_local.json\", \"r\") as fp:\n",
    "            recipes = json.load(fp)\n",
    "            #print(recipes)\n",
    "            concated_data = [' '.join(d) for d in recipes.values()]\n",
    "            data = concated_data\n",
    "            print(len(data))\n",
    "\n",
    "    elif genre == 'food-2':\n",
    "        reciepe_data2 = datasets.load_dataset(\"m3hrdadfi/recipe_nlg_lite\",trust_remote_code=True) #steps o ingredients\n",
    "        #train 6118 test 1000\n",
    "        # ['uid', 'name', 'description', 'link', 'ner', 'ingredients', 'steps']\n",
    "        data  = reciepe_data2['train']['steps']\n",
    "\n",
    "    elif genre == 'arthmetic-1':\n",
    "\n",
    "        metamathqa = datasets.load_dataset(\"meta-math/MetaMathQA\") #original_question\n",
    "        data = metamathqa['train']['original_question']\n",
    "\n",
    "    elif genre == 'arthmetic-2':\n",
    "\n",
    "        drop = datasets.load_dataset(\"ucinlp/drop\") #passage\n",
    "        data = drop['train']['passage']#['section_id', 'query_id', 'passage', 'question', 'answers_spans']\n",
    "\n",
    "    elif genre == 'arthmetic-3':\n",
    "        aquarat = datasets.load_dataset(\"deepmind/aqua_rat\") #['question', 'options', 'rationale', 'correct'] go question or rationale\n",
    "        data = aquarat['train']['question']\n",
    "\n",
    "    elif genre == 'technical-1':\n",
    "        icdatta = datasets.load_dataset(\"atta00/icd10-codes\") #['chapter', 'section', 'category', 'category_code', 'code', 'description']\n",
    "        data = [f\"description: {d} | code: {c}\" for d,c in zip(icdatta['train']['description'], icdatta['train']['code'] )] # go for description + code\n",
    "\n",
    "    elif genre == 'technical-2':\n",
    "        icdcm = datasets.load_dataset(\"Gokul-waterlabs/ICD-10-CM\")#input+output\n",
    "        data = [f\"Description: {d} | code: {c}\" for d,c in zip(icdcm['train']['input'], icdcm['train']['output'] )]\n",
    "\n",
    "    elif genre == 'datetime-1':\n",
    "\n",
    "        directory_path = \"./TimeLineExtractionDecisionLettersCASE/\"\n",
    "        if os.path.exists(directory_path) and os.path.isdir(directory_path):\n",
    "            1;\n",
    "        else:\n",
    "            Repo.clone_from(\"https://github.com/irlabamsterdam/TimeLineExtractionDecisionLettersCASE.git\", directory_path)\n",
    "\n",
    "        data = []\n",
    "        for file in tqdm(glob(HOME_PATH + directory_path + 'data/txt_files/train/*txt')):\n",
    "            with open(file, 'r') as fp:\n",
    "                data.append(fp.read())\n",
    "    else:\n",
    "        data=\"ERROR : Pick a genre from [food-1/2, arthmetic-1/2/3, techincal-1/2, datetime]\"\n",
    "        print(data)\n",
    "    print(\"Number of samples in the data loaded:\", len(data))\n",
    "    if downsample_to and len(data) > downsample_to:\n",
    "        print(\"Downsampling to %s\" % downsample_to)\n",
    "        data = data[:downsample_to]\n",
    "\n",
    "    return data\n",
    "\n",
    "texts = list(itertools.chain(*(load_data(k) for k in ['food-1', 'food-2', 'arthmetic-1', 'arthmetic-2', 'arthmetic-3', 'technical-1', 'technical-2', 'datetime-1'])))\n",
    "print(len(texts))"
   ],
   "id": "f25ec2c4beaa530f",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "719\n",
      "Number of samples in the data loaded: 719\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Repo card metadata block was not found. Setting CardData to empty.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of samples in the data loaded: 6118\n",
      "Number of samples in the data loaded: 395000\n",
      "Number of samples in the data loaded: 77400\n",
      "Number of samples in the data loaded: 97467\n",
      "Number of samples in the data loaded: 25719\n",
      "Number of samples in the data loaded: 74044\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:00<00:00, 3672.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of samples in the data loaded: 50\n",
      "676517\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "execution_count": 8
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T18:37:43.511962Z",
     "start_time": "2025-09-15T18:37:43.507683Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import re\n",
    "\n",
    "def make_str_input(all_possible_operands: list[int]) -> str:\n",
    "    selected_text = random.choice(texts)\n",
    "    text_with_replaced_nums = re.sub(r\"\\d+\", lambda _: str(random.choice(all_possible_operands)), selected_text)\n",
    "    return text_with_replaced_nums\n",
    "\n",
    "make_str_input(train_inputs), make_str_input(valid_inputs)"
   ],
   "id": "92cc3db4cd042de6",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('Jared likes to draw monsters. He drew a monster family portrait. The mom had 263 eye and the dad had 958. They had 562 kids, each with 180 eyes. How many eyes did the whole family have?',\n",
       " 'Compared with its metropolitan area, the city of Houstons population has a higher proportion of minorities. According to the (406, 498) United States Census, whites made up (669, 81)% of the city of Houstons population; (60, 170)% of the total population was non-Hispanic whites. Blacks or African Americans made up (227, 570)% of Houstons population, Native Americans in the United States made up (64, 383).(162, 276)% of the population,  Asians made up (255, 377)% ((145, 377).(267, 455)% Vietnamese Americans, (455, 498).(923, 73)% Chinese Americans, (728, 231).(9, 900)% Indian Americans, (704, 21).(445, 73)% Pakistani Americans, (4, 37).(560, 218)% Filipino Americans, (840, 31).(380, 455)% Korean Americans, (12, 707).(478, 338)% Japanese Americans) and Pacific Islanders made up (565, 345).(181, 74)%. Individuals from some other race made up (383, 75).(41, 233)% of the citys population, of which (719, 31).(305, 652)% were non-Hispanic. Individuals from two or more races made up (534, 268).(251, 21)% of the city.')"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 9
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T18:37:43.713450Z",
     "start_time": "2025-09-15T18:37:43.708449Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def make_str_input_nums(operands: tuple[int, int] | list[int]) -> str:\n",
    "    x1, x2 = operands\n",
    "    return f\"{x1} + {x2}\"\n",
    "\n",
    "make_str_input_nums((3, 500)), make_str_input_nums((3, 0))"
   ],
   "id": "b9c43ab54c443cf7",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('3 + 500', '3 + 0')"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 10
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Inference of model's hidden states",
   "id": "576cfa3c4fb69ddf"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T18:37:43.887381Z",
     "start_time": "2025-09-15T18:37:43.879876Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import tqdm\n",
    "\n",
    "def get_hidden_states(model, str_inputs: list[str], batch_size: int) -> tuple[dict[int, Tensor], Tensor]:\n",
    "    model.eval()\n",
    "    num_input_ids = tokenizer(list(map(str, all_inputs)), add_special_tokens=False, return_tensors=\"pt\").input_ids[:, 0]\n",
    "\n",
    "    nums: list[str] = []\n",
    "    hidden_states = collections.defaultdict(list)\n",
    "    with torch.no_grad():\n",
    "        num_batches = (len(str_inputs) + batch_size - 1) // batch_size\n",
    "        for batch_str in tqdm.auto.tqdm(itertools.batched(str_inputs, n=batch_size), total=num_batches):\n",
    "            batch_inputs = tokenizer(batch_str, return_tensors=\"pt\", padding=True, truncation=True)\n",
    "            num_pos = torch.isin(batch_inputs.input_ids, num_input_ids)\n",
    "            hidden_reprs = model(**batch_inputs.to(model.device), output_hidden_states=True).hidden_states\n",
    "            for layer_idx, hidden_state in enumerate(hidden_reprs):\n",
    "                hidden_states[layer_idx].extend(hidden_state[num_pos].detach().cpu())\n",
    "            new_nums = tokenizer.batch_decode(batch_inputs.input_ids[num_pos])\n",
    "            nums.extend(new_nums)\n",
    "\n",
    "    return {k: torch.stack(v) for k, v in hidden_states.items()}, torch.tensor(list(map(int, nums)), device=device)"
   ],
   "id": "2a7460caa4e09ba7",
   "outputs": [],
   "execution_count": 11
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T18:37:44.063748Z",
     "start_time": "2025-09-15T18:37:44.057767Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def get_hidden_states_raw_numbers(model, str_inputs: list[str], batch_size: int) -> collections.defaultdict[int, Tensor]:\n",
    "    model.eval()\n",
    "    hidden_states = collections.defaultdict(list)\n",
    "    with torch.no_grad():\n",
    "        num_batches = (len(str_inputs) + batch_size - 1) // batch_size\n",
    "        for batch_str in tqdm.auto.tqdm(itertools.batched(str_inputs, n=batch_size), total=num_batches):\n",
    "            batch_inputs = tokenizer(batch_str, return_tensors=\"pt\")\n",
    "            hidden_reprs = model(**batch_inputs.to(model.device), output_hidden_states=True).hidden_states\n",
    "            for layer_idx, hidden_state in enumerate(hidden_reprs):\n",
    "                hidden_states[layer_idx].extend(hidden_state[:, -1, :].detach().cpu())\n",
    "    return {k: torch.stack(v) for k, v in hidden_states.items()}"
   ],
   "id": "8240cb6b5fa39470",
   "outputs": [],
   "execution_count": 12
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T18:40:44.255992Z",
     "start_time": "2025-09-15T18:37:44.254281Z"
    }
   },
   "cell_type": "code",
   "source": [
    "batch_size = 8\n",
    "\n",
    "train_input_texts = [make_str_input(train_inputs) for _ in range(10_000)]\n",
    "train_hidden_states, train_labels = get_hidden_states(model, train_input_texts, batch_size)\n",
    "assert train_hidden_states[0].shape[0] == len(train_labels)\n",
    "\n",
    "valid_hidden_states = get_hidden_states_raw_numbers(model, [make_str_input_nums(val) for val in valid_inputs], batch_size)\n",
    "test_hidden_states = get_hidden_states_raw_numbers(model, [make_str_input_nums(val) for val in test_inputs], batch_size)\n",
    "\n",
    "# hidden_state, new_nums = get_hidden_states(model, train_input_texts, batch_size)\n"
   ],
   "id": "831000bb04b5ffd7",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "  0%|          | 0/1250 [00:00<?, ?it/s]"
      ],
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "b7422d39299a4ee3987b62a7ff86b940"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "  0%|          | 0/512 [00:00<?, ?it/s]"
      ],
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "e9143ca65ed84d9f885cfa444ba9a497"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "  0%|          | 0/3032 [00:00<?, ?it/s]"
      ],
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "5d7856b5a3294551b02a2fd577d058b8"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "execution_count": 13
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Probing",
   "id": "bfa6b454ea64b77f"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T18:40:44.564310Z",
     "start_time": "2025-09-15T18:40:44.496800Z"
    }
   },
   "cell_type": "code",
   "source": [
    "basis_embs_sin = sinusoidal_encode(\n",
    "    torch.arange(1000),\n",
    "    min_value=0,\n",
    "    max_value=1000,\n",
    "    embedding_dim=train_hidden_states[0].shape[-1],\n",
    ")\n",
    "\n",
    "basis_embs_bin = binary_encode(\n",
    "    torch.arange(1000),\n",
    "    min_value=0,\n",
    "    max_value=1000,\n",
    "    embedding_dim=10,\n",
    ")"
   ],
   "id": "1788ca655c6b2f13",
   "outputs": [],
   "execution_count": 14
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T18:40:44.980884Z",
     "start_time": "2025-09-15T18:40:44.975795Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class ClassifierProbe(torch.nn.Module):\n",
    "    def __init__(self, emb_dim: int, hidden_dim: int, basis: torch.Tensor, heldout_mask: torch.Tensor):\n",
    "        super().__init__()\n",
    "        self.emb_to_latent = torch.nn.Linear(emb_dim, hidden_dim, bias=True)\n",
    "        self.basis_to_latent = torch.nn.Linear(basis.shape[-1], hidden_dim, bias=True)\n",
    "        self.basis: torch.nn.Buffer\n",
    "        self.heldout_mask: torch.nn.Buffer\n",
    "        self.register_buffer(\"basis\", basis)\n",
    "        self.register_buffer(\"heldout_mask\", heldout_mask)\n",
    "    def forward(self, x: Tensor, holdout_eval_tokens: bool) -> Tensor:\n",
    "        latent_x = self.emb_to_latent(x)\n",
    "        # during training, model learns to choose among only training tokens\n",
    "        # but during eval, model must choose among all tokens\n",
    "        # this means that the model is never exposed to the eval tokens during training\n",
    "        latent_choices = self.basis_to_latent(self.basis)\n",
    "        logits = latent_x @ latent_choices.T\n",
    "        if holdout_eval_tokens:\n",
    "            logits[:, self.heldout_mask] = float(\"-inf\")\n",
    "        return logits"
   ],
   "id": "37cff5a15e9a866f",
   "outputs": [],
   "execution_count": 15
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:12:59.555349Z",
     "start_time": "2025-09-15T18:40:45.142179Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# train_labels = torch.tensor([x2 for x1, x2 in train_inputs])\n",
    "valid_labels = torch.tensor([x2 for x1, x2 in valid_inputs]).to(device)\n",
    "test_labels = torch.tensor([x2 for x1, x2 in test_inputs]).to(device)\n",
    "\n",
    "test_accuracies = {\"sin\": {}, \"bin\": {}, \"lin\": {}, \"log\": {}}\n",
    "\n",
    "for basis_name, basis_embs in {\"sin\": basis_embs_sin, \"bin\": basis_embs_bin}.items():\n",
    "    for layer_idx in range(len(train_hidden_states)):\n",
    "\n",
    "        torch.manual_seed(0)\n",
    "        probe = ClassifierProbe(\n",
    "            emb_dim=train_hidden_states[0].shape[-1],\n",
    "            hidden_dim=100,\n",
    "            basis=basis_embs,\n",
    "            heldout_mask=test_mask,\n",
    "        ).to(device)\n",
    "\n",
    "        optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)\n",
    "\n",
    "        rng = torch.Generator().manual_seed(0)\n",
    "        best_val_acc = -1\n",
    "        best_ckpt = None\n",
    "        for i in range(10000+1):\n",
    "            probe.train()\n",
    "            optimizer.zero_grad()\n",
    "            minibatch_idcs = torch.randint(len(train_labels), size=(1024,), generator=rng)\n",
    "            x = train_hidden_states[layer_idx][minibatch_idcs].float().to(device)\n",
    "            y = train_labels[minibatch_idcs].to(device)\n",
    "            logits = probe(x, holdout_eval_tokens=True)\n",
    "            # add l1 regularization of all params to the loss\n",
    "            loss = torch.nn.functional.cross_entropy(logits, y) + 0.001 * sum(p.abs().sum() for p in probe.parameters())\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            if i % 500 == 0:\n",
    "                train_acc = (logits.argmax(dim=-1) == y).float().mean().item()\n",
    "                probe.eval()\n",
    "                with torch.no_grad():\n",
    "                    valid_logits = probe(valid_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)\n",
    "                    valid_loss = torch.nn.functional.cross_entropy(valid_logits, valid_labels)\n",
    "                    valid_accuracy = (valid_logits.argmax(dim=-1) == valid_labels).float().mean().item()\n",
    "                    if valid_accuracy > best_val_acc:\n",
    "                        best_val_acc = valid_accuracy\n",
    "                        best_ckpt = probe.state_dict()\n",
    "                print(f\"{basis_name} {i=:>5} train loss: {loss.item():5.2f}  train acc: {train_acc:.2f}  val loss: {valid_loss.item():5.2f}  valid acc: {valid_accuracy:.2f}\")\n",
    "        probe.load_state_dict(best_ckpt)\n",
    "        probe.eval()\n",
    "        with torch.no_grad():\n",
    "            test_logits = probe(test_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)\n",
    "            test_accuracy = (test_logits.argmax(dim=-1) == test_labels).float().mean().item()\n",
    "        test_accuracies[basis_name][layer_idx] = test_accuracy\n",
    "        print(f\"->  {basis_name}  layer idx: {layer_idx:<3}, best valid accuracy: {best_val_acc:.2f}, test accuracy: {test_accuracy:.2f}\")\n"
   ],
   "id": "d248aea218155782",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sin i=    0 train loss: 11.39  train acc: 0.00  val loss:  6.87  valid acc: 0.00\n",
      "sin i=  500 train loss:  2.40  train acc: 0.98  val loss:  1.49  valid acc: 0.77\n",
      "sin i= 1000 train loss:  1.95  train acc: 1.00  val loss:  1.03  valid acc: 0.94\n",
      "sin i= 1500 train loss:  1.74  train acc: 1.00  val loss:  0.86  valid acc: 0.90\n",
      "sin i= 2000 train loss:  1.62  train acc: 1.00  val loss:  0.81  valid acc: 0.90\n",
      "sin i= 2500 train loss:  1.54  train acc: 1.00  val loss:  0.79  valid acc: 0.85\n",
      "sin i= 3000 train loss:  1.46  train acc: 1.00  val loss:  0.78  valid acc: 0.85\n",
      "sin i= 3500 train loss:  1.39  train acc: 1.00  val loss:  0.76  valid acc: 0.84\n",
      "sin i= 4000 train loss:  1.35  train acc: 1.00  val loss:  0.74  valid acc: 0.82\n",
      "sin i= 4500 train loss:  1.30  train acc: 1.00  val loss:  0.69  valid acc: 0.86\n",
      "sin i= 5000 train loss:  1.27  train acc: 1.00  val loss:  0.67  valid acc: 0.88\n",
      "sin i= 5500 train loss:  1.23  train acc: 1.00  val loss:  0.65  valid acc: 0.88\n",
      "sin i= 6000 train loss:  1.21  train acc: 1.00  val loss:  0.62  valid acc: 0.88\n",
      "sin i= 6500 train loss:  1.19  train acc: 1.00  val loss:  0.59  valid acc: 0.86\n",
      "sin i= 7000 train loss:  1.17  train acc: 1.00  val loss:  0.61  valid acc: 0.88\n",
      "sin i= 7500 train loss:  1.14  train acc: 1.00  val loss:  0.59  valid acc: 0.89\n",
      "sin i= 8000 train loss:  1.12  train acc: 1.00  val loss:  0.60  valid acc: 0.89\n",
      "sin i= 8500 train loss:  1.10  train acc: 1.00  val loss:  0.57  valid acc: 0.94\n",
      "sin i= 9000 train loss:  1.08  train acc: 1.00  val loss:  0.58  valid acc: 0.94\n",
      "sin i= 9500 train loss:  1.06  train acc: 1.00  val loss:  0.59  valid acc: 0.90\n",
      "sin i=10000 train loss:  1.04  train acc: 1.00  val loss:  0.61  valid acc: 0.87\n",
      "->  sin  layer idx: 0  , best valid accuracy: 0.94, test accuracy: 0.85\n",
      "sin i=    0 train loss: 11.39  train acc: 0.00  val loss:  6.68  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.48  train acc: 1.00  val loss:  0.49  valid acc: 1.00\n",
      "sin i= 1000 train loss:  1.14  train acc: 1.00  val loss:  0.28  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.98  train acc: 1.00  val loss:  0.21  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.86  train acc: 1.00  val loss:  0.18  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.78  train acc: 1.00  val loss:  0.15  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.71  train acc: 1.00  val loss:  0.15  valid acc: 0.99\n",
      "sin i= 3500 train loss:  0.67  train acc: 1.00  val loss:  0.13  valid acc: 0.98\n",
      "sin i= 4000 train loss:  0.63  train acc: 1.00  val loss:  0.13  valid acc: 0.97\n",
      "sin i= 4500 train loss:  0.60  train acc: 1.00  val loss:  0.12  valid acc: 0.97\n",
      "sin i= 5000 train loss:  0.58  train acc: 1.00  val loss:  0.12  valid acc: 0.97\n",
      "sin i= 5500 train loss:  0.57  train acc: 1.00  val loss:  0.11  valid acc: 0.97\n",
      "sin i= 6000 train loss:  0.55  train acc: 1.00  val loss:  0.11  valid acc: 0.97\n",
      "sin i= 6500 train loss:  0.54  train acc: 1.00  val loss:  0.11  valid acc: 0.97\n",
      "sin i= 7000 train loss:  0.53  train acc: 1.00  val loss:  0.12  valid acc: 0.97\n",
      "sin i= 7500 train loss:  0.52  train acc: 1.00  val loss:  0.11  valid acc: 0.97\n",
      "sin i= 8000 train loss:  0.51  train acc: 1.00  val loss:  0.12  valid acc: 0.97\n",
      "sin i= 8500 train loss:  0.50  train acc: 1.00  val loss:  0.11  valid acc: 0.97\n",
      "sin i= 9000 train loss:  0.50  train acc: 1.00  val loss:  0.11  valid acc: 0.97\n",
      "sin i= 9500 train loss:  0.49  train acc: 1.00  val loss:  0.12  valid acc: 0.97\n",
      "sin i=10000 train loss:  0.49  train acc: 1.00  val loss:  0.12  valid acc: 0.97\n",
      "->  sin  layer idx: 1  , best valid accuracy: 1.00, test accuracy: 0.93\n",
      "sin i=    0 train loss: 11.38  train acc: 0.00  val loss:  6.69  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.44  train acc: 1.00  val loss:  0.45  valid acc: 1.00\n",
      "sin i= 1000 train loss:  1.12  train acc: 1.00  val loss:  0.27  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.94  train acc: 1.00  val loss:  0.20  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.82  train acc: 1.00  val loss:  0.16  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.74  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.68  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.63  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.60  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.57  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.56  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.54  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.53  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.52  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.52  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.50  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.50  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.49  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.49  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.49  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.49  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "->  sin  layer idx: 2  , best valid accuracy: 1.00, test accuracy: 1.00\n",
      "sin i=    0 train loss: 11.39  train acc: 0.00  val loss:  6.71  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.38  train acc: 1.00  val loss:  0.47  valid acc: 1.00\n",
      "sin i= 1000 train loss:  1.07  train acc: 1.00  val loss:  0.27  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.91  train acc: 1.00  val loss:  0.20  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.82  train acc: 1.00  val loss:  0.17  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.74  train acc: 1.00  val loss:  0.15  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.69  train acc: 1.00  val loss:  0.14  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.65  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.61  train acc: 1.00  val loss:  0.13  valid acc: 0.98\n",
      "sin i= 4500 train loss:  0.58  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.56  train acc: 1.00  val loss:  0.11  valid acc: 0.99\n",
      "sin i= 5500 train loss:  0.54  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.53  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.52  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.51  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.50  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.50  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.50  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.50  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.49  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.49  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "->  sin  layer idx: 3  , best valid accuracy: 1.00, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.39  train acc: 0.00  val loss:  6.79  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.29  train acc: 1.00  val loss:  0.43  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.99  train acc: 1.00  val loss:  0.25  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.84  train acc: 1.00  val loss:  0.18  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.74  train acc: 1.00  val loss:  0.16  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.68  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.63  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.60  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.57  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.55  train acc: 1.00  val loss:  0.12  valid acc: 0.99\n",
      "sin i= 5000 train loss:  0.54  train acc: 1.00  val loss:  0.13  valid acc: 0.98\n",
      "sin i= 5500 train loss:  0.52  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.51  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.50  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.49  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.47  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.47  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.46  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.45  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.44  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.44  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "->  sin  layer idx: 4  , best valid accuracy: 1.00, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.40  train acc: 0.00  val loss:  6.65  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.27  train acc: 1.00  val loss:  0.39  valid acc: 1.00\n",
      "sin i= 1000 train loss:  1.00  train acc: 1.00  val loss:  0.23  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.84  train acc: 1.00  val loss:  0.17  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.72  train acc: 1.00  val loss:  0.15  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.65  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.59  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.57  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.54  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.52  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.50  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.49  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.47  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.47  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.46  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.45  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.44  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.44  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.43  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.43  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.42  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "->  sin  layer idx: 5  , best valid accuracy: 1.00, test accuracy: 1.00\n",
      "sin i=    0 train loss: 11.42  train acc: 0.00  val loss:  6.76  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.28  train acc: 1.00  val loss:  0.39  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.99  train acc: 1.00  val loss:  0.22  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.84  train acc: 1.00  val loss:  0.16  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.74  train acc: 1.00  val loss:  0.15  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.68  train acc: 1.00  val loss:  0.14  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.63  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.59  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.56  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.54  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.53  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.51  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.51  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.50  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.50  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.48  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.47  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.46  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.45  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.44  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.44  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "->  sin  layer idx: 6  , best valid accuracy: 1.00, test accuracy: 1.00\n",
      "sin i=    0 train loss: 11.42  train acc: 0.00  val loss:  6.79  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.29  train acc: 1.00  val loss:  0.39  valid acc: 1.00\n",
      "sin i= 1000 train loss:  1.00  train acc: 1.00  val loss:  0.22  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.84  train acc: 1.00  val loss:  0.17  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.74  train acc: 1.00  val loss:  0.15  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.67  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.60  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.57  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.53  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.51  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.49  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.48  train acc: 1.00  val loss:  0.11  valid acc: 0.99\n",
      "sin i= 6000 train loss:  0.47  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 6500 train loss:  0.46  train acc: 1.00  val loss:  0.10  valid acc: 0.99\n",
      "sin i= 7000 train loss:  0.46  train acc: 1.00  val loss:  0.11  valid acc: 0.99\n",
      "sin i= 7500 train loss:  0.45  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 8000 train loss:  0.45  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 8500 train loss:  0.44  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 9000 train loss:  0.44  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.44  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.43  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "->  sin  layer idx: 7  , best valid accuracy: 1.00, test accuracy: 1.00\n",
      "sin i=    0 train loss: 11.41  train acc: 0.00  val loss:  6.83  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.26  train acc: 1.00  val loss:  0.36  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.98  train acc: 1.00  val loss:  0.20  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.83  train acc: 1.00  val loss:  0.14  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.73  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.66  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.60  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.57  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.55  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.52  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.51  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.49  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.47  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.46  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.45  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.44  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.43  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.43  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 9000 train loss:  0.42  train acc: 1.00  val loss:  0.07  valid acc: 0.99\n",
      "sin i= 9500 train loss:  0.42  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i=10000 train loss:  0.41  train acc: 1.00  val loss:  0.08  valid acc: 0.98\n",
      "->  sin  layer idx: 8  , best valid accuracy: 1.00, test accuracy: 0.97\n",
      "sin i=    0 train loss: 11.42  train acc: 0.00  val loss:  6.90  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.23  train acc: 1.00  val loss:  0.31  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.96  train acc: 1.00  val loss:  0.17  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.81  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.70  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.64  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.58  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.56  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.53  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.51  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.49  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.47  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.47  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.45  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.44  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.44  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.43  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.43  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.42  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.42  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.42  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "->  sin  layer idx: 9  , best valid accuracy: 1.00, test accuracy: 0.97\n",
      "sin i=    0 train loss: 11.45  train acc: 0.00  val loss:  6.89  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.16  train acc: 1.00  val loss:  0.28  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.90  train acc: 1.00  val loss:  0.16  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.77  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.69  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.62  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.58  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.55  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.52  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.50  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.48  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.46  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.45  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.44  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.44  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.43  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.42  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.42  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.41  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.41  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.41  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "->  sin  layer idx: 10 , best valid accuracy: 1.00, test accuracy: 1.00\n",
      "sin i=    0 train loss: 11.48  train acc: 0.00  val loss:  7.10  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.05  train acc: 0.99  val loss:  0.22  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.81  train acc: 0.99  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.68  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.60  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.55  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.51  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.48  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.46  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.44  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.43  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.42  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.42  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.40  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.40  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.40  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.39  train acc: 1.00  val loss:  0.02  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.38  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.38  train acc: 1.00  val loss:  0.02  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.37  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.37  train acc: 1.00  val loss:  0.02  valid acc: 1.00\n",
      "->  sin  layer idx: 11 , best valid accuracy: 1.00, test accuracy: 1.00\n",
      "sin i=    0 train loss: 11.51  train acc: 0.00  val loss:  7.36  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.00  train acc: 1.00  val loss:  0.24  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.79  train acc: 0.99  val loss:  0.14  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.67  train acc: 0.99  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.58  train acc: 0.99  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.52  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.48  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.46  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.44  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.42  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.40  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.39  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.39  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.38  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.38  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.37  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.36  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.36  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.36  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.35  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.35  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "->  sin  layer idx: 12 , best valid accuracy: 1.00, test accuracy: 1.00\n",
      "sin i=    0 train loss: 11.57  train acc: 0.00  val loss:  8.05  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.01  train acc: 0.99  val loss:  0.23  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.79  train acc: 0.99  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.66  train acc: 0.99  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.59  train acc: 0.99  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.53  train acc: 0.99  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.49  train acc: 0.99  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.46  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.44  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.43  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.41  train acc: 0.99  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.40  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.39  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.38  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.38  train acc: 0.99  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.37  train acc: 1.00  val loss:  0.02  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.36  train acc: 1.00  val loss:  0.02  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.36  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.35  train acc: 1.00  val loss:  0.02  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.35  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.35  train acc: 1.00  val loss:  0.02  valid acc: 1.00\n",
      "->  sin  layer idx: 13 , best valid accuracy: 1.00, test accuracy: 1.00\n",
      "sin i=    0 train loss: 11.60  train acc: 0.00  val loss:  8.37  valid acc: 0.01\n",
      "sin i=  500 train loss:  1.00  train acc: 0.99  val loss:  0.23  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.77  train acc: 0.98  val loss:  0.14  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.65  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.58  train acc: 0.99  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.53  train acc: 0.99  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.50  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.48  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.46  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.45  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.44  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.42  train acc: 0.99  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.41  train acc: 0.99  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.40  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.40  train acc: 0.99  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.38  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.38  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.37  train acc: 0.99  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.37  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.36  train acc: 0.99  val loss:  0.05  valid acc: 0.99\n",
      "sin i=10000 train loss:  0.36  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "->  sin  layer idx: 14 , best valid accuracy: 1.00, test accuracy: 0.98\n",
      "sin i=    0 train loss: 11.70  train acc: 0.00  val loss:  7.70  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.06  train acc: 0.98  val loss:  0.26  valid acc: 0.99\n",
      "sin i= 1000 train loss:  0.80  train acc: 0.98  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.68  train acc: 0.99  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.62  train acc: 0.98  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.56  train acc: 0.99  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.54  train acc: 0.99  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.51  train acc: 0.99  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.48  train acc: 0.99  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.47  train acc: 0.99  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.46  train acc: 0.99  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.63  train acc: 0.95  val loss:  1.18  valid acc: 0.79\n",
      "sin i= 6000 train loss:  0.44  train acc: 0.99  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.42  train acc: 0.99  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.41  train acc: 0.99  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.40  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.40  train acc: 0.99  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.40  train acc: 0.99  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.38  train acc: 0.99  val loss:  0.02  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.38  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.38  train acc: 0.99  val loss:  0.03  valid acc: 1.00\n",
      "->  sin  layer idx: 15 , best valid accuracy: 1.00, test accuracy: 0.97\n",
      "sin i=    0 train loss: 19.01  train acc: 0.00  val loss: 27.86  valid acc: 0.00\n",
      "sin i=  500 train loss:  5.37  train acc: 0.46  val loss:  2.96  valid acc: 0.11\n",
      "sin i= 1000 train loss:  3.63  train acc: 0.66  val loss:  2.13  valid acc: 0.31\n",
      "sin i= 1500 train loss:  2.76  train acc: 0.70  val loss:  1.46  valid acc: 0.47\n",
      "sin i= 2000 train loss: 56.16  train acc: 0.15  val loss: 32.16  valid acc: 0.05\n",
      "sin i= 2500 train loss:  6.21  train acc: 0.58  val loss:  2.09  valid acc: 0.36\n",
      "sin i= 3000 train loss:  5.46  train acc: 0.68  val loss:  2.44  valid acc: 0.34\n",
      "sin i= 3500 train loss:  4.98  train acc: 0.68  val loss:  2.55  valid acc: 0.37\n",
      "sin i= 4000 train loss:  5.22  train acc: 0.70  val loss:  1.89  valid acc: 0.43\n",
      "sin i= 4500 train loss:  4.27  train acc: 0.76  val loss:  1.51  valid acc: 0.48\n",
      "sin i= 5000 train loss:  3.88  train acc: 0.79  val loss:  1.53  valid acc: 0.50\n",
      "sin i= 5500 train loss:  3.30  train acc: 0.79  val loss:  1.45  valid acc: 0.50\n",
      "sin i= 6000 train loss:  7.40  train acc: 0.58  val loss:  2.03  valid acc: 0.44\n",
      "sin i= 6500 train loss:  5.62  train acc: 0.73  val loss:  2.28  valid acc: 0.39\n",
      "sin i= 7000 train loss:  5.39  train acc: 0.74  val loss:  2.42  valid acc: 0.38\n",
      "sin i= 7500 train loss:  4.55  train acc: 0.80  val loss:  1.67  valid acc: 0.43\n",
      "sin i= 8000 train loss:  3.96  train acc: 0.82  val loss:  1.35  valid acc: 0.47\n",
      "sin i= 8500 train loss: 10.47  train acc: 0.47  val loss:  3.48  valid acc: 0.30\n",
      "sin i= 9000 train loss:  6.15  train acc: 0.75  val loss:  2.46  valid acc: 0.27\n",
      "sin i= 9500 train loss:  5.43  train acc: 0.80  val loss:  2.22  valid acc: 0.33\n",
      "sin i=10000 train loss:  7.89  train acc: 0.62  val loss:  4.06  valid acc: 0.28\n",
      "->  sin  layer idx: 16 , best valid accuracy: 0.50, test accuracy: 0.45\n",
      "bin i=    0 train loss:  9.29  train acc: 0.00  val loss:  6.89  valid acc: 0.00\n",
      "bin i=  500 train loss:  5.27  train acc: 0.24  val loss:  4.14  valid acc: 0.07\n",
      "bin i= 1000 train loss:  4.41  train acc: 0.41  val loss:  3.39  valid acc: 0.14\n",
      "bin i= 1500 train loss:  3.90  train acc: 0.62  val loss:  3.14  valid acc: 0.14\n",
      "bin i= 2000 train loss:  3.50  train acc: 0.81  val loss:  3.05  valid acc: 0.23\n",
      "bin i= 2500 train loss:  3.25  train acc: 0.87  val loss:  3.01  valid acc: 0.27\n",
      "bin i= 3000 train loss:  3.05  train acc: 0.92  val loss:  3.01  valid acc: 0.28\n",
      "bin i= 3500 train loss:  2.89  train acc: 0.96  val loss:  3.00  valid acc: 0.26\n",
      "bin i= 4000 train loss:  2.77  train acc: 0.97  val loss:  2.97  valid acc: 0.26\n",
      "bin i= 4500 train loss:  2.66  train acc: 0.98  val loss:  2.94  valid acc: 0.23\n",
      "bin i= 5000 train loss:  2.58  train acc: 0.99  val loss:  2.92  valid acc: 0.22\n",
      "bin i= 5500 train loss:  2.48  train acc: 0.99  val loss:  2.90  valid acc: 0.27\n",
      "bin i= 6000 train loss:  2.40  train acc: 0.99  val loss:  2.92  valid acc: 0.27\n",
      "bin i= 6500 train loss:  2.34  train acc: 0.99  val loss:  2.91  valid acc: 0.26\n",
      "bin i= 7000 train loss:  2.26  train acc: 1.00  val loss:  2.93  valid acc: 0.26\n",
      "bin i= 7500 train loss:  2.22  train acc: 1.00  val loss:  2.94  valid acc: 0.26\n",
      "bin i= 8000 train loss:  2.18  train acc: 1.00  val loss:  2.96  valid acc: 0.26\n",
      "bin i= 8500 train loss:  2.14  train acc: 1.00  val loss:  2.94  valid acc: 0.21\n",
      "bin i= 9000 train loss:  2.11  train acc: 1.00  val loss:  2.97  valid acc: 0.29\n",
      "bin i= 9500 train loss:  2.08  train acc: 1.00  val loss:  2.96  valid acc: 0.27\n",
      "bin i=10000 train loss:  2.05  train acc: 1.00  val loss:  2.99  valid acc: 0.22\n",
      "->  bin  layer idx: 0  , best valid accuracy: 0.29, test accuracy: 0.33\n",
      "bin i=    0 train loss:  9.29  train acc: 0.00  val loss:  6.87  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.64  train acc: 0.16  val loss:  3.84  valid acc: 0.02\n",
      "bin i= 1000 train loss:  4.07  train acc: 0.32  val loss:  3.46  valid acc: 0.08\n",
      "bin i= 1500 train loss:  3.71  train acc: 0.47  val loss:  3.19  valid acc: 0.12\n",
      "bin i= 2000 train loss:  3.31  train acc: 0.65  val loss:  3.16  valid acc: 0.22\n",
      "bin i= 2500 train loss:  3.00  train acc: 0.78  val loss:  3.17  valid acc: 0.18\n",
      "bin i= 3000 train loss:  2.78  train acc: 0.83  val loss:  3.21  valid acc: 0.20\n",
      "bin i= 3500 train loss:  2.64  train acc: 0.88  val loss:  3.23  valid acc: 0.20\n",
      "bin i= 4000 train loss:  2.53  train acc: 0.88  val loss:  3.29  valid acc: 0.23\n",
      "bin i= 4500 train loss:  2.41  train acc: 0.89  val loss:  3.36  valid acc: 0.22\n",
      "bin i= 5000 train loss:  2.33  train acc: 0.91  val loss:  3.41  valid acc: 0.21\n",
      "bin i= 5500 train loss:  2.27  train acc: 0.94  val loss:  3.51  valid acc: 0.19\n",
      "bin i= 6000 train loss:  2.19  train acc: 0.94  val loss:  3.54  valid acc: 0.21\n",
      "bin i= 6500 train loss:  2.14  train acc: 0.95  val loss:  3.53  valid acc: 0.16\n",
      "bin i= 7000 train loss:  2.06  train acc: 0.96  val loss:  3.60  valid acc: 0.19\n",
      "bin i= 7500 train loss:  2.02  train acc: 0.96  val loss:  3.55  valid acc: 0.19\n",
      "bin i= 8000 train loss:  1.98  train acc: 0.96  val loss:  3.64  valid acc: 0.19\n",
      "bin i= 8500 train loss:  1.96  train acc: 0.96  val loss:  3.64  valid acc: 0.23\n",
      "bin i= 9000 train loss:  1.92  train acc: 0.96  val loss:  3.70  valid acc: 0.22\n",
      "bin i= 9500 train loss:  1.90  train acc: 0.96  val loss:  3.69  valid acc: 0.23\n",
      "bin i=10000 train loss:  1.86  train acc: 0.97  val loss:  3.79  valid acc: 0.23\n",
      "->  bin  layer idx: 1  , best valid accuracy: 0.23, test accuracy: 0.18\n",
      "bin i=    0 train loss:  9.29  train acc: 0.00  val loss:  6.87  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.65  train acc: 0.13  val loss:  3.90  valid acc: 0.04\n",
      "bin i= 1000 train loss:  4.21  train acc: 0.21  val loss:  3.60  valid acc: 0.09\n",
      "bin i= 1500 train loss:  3.96  train acc: 0.30  val loss:  3.39  valid acc: 0.08\n",
      "bin i= 2000 train loss:  3.58  train acc: 0.44  val loss:  3.35  valid acc: 0.12\n",
      "bin i= 2500 train loss:  3.33  train acc: 0.55  val loss:  3.32  valid acc: 0.14\n",
      "bin i= 3000 train loss:  3.09  train acc: 0.62  val loss:  3.36  valid acc: 0.14\n",
      "bin i= 3500 train loss:  2.97  train acc: 0.67  val loss:  3.29  valid acc: 0.16\n",
      "bin i= 4000 train loss:  2.82  train acc: 0.70  val loss:  3.36  valid acc: 0.19\n",
      "bin i= 4500 train loss:  2.72  train acc: 0.74  val loss:  3.43  valid acc: 0.18\n",
      "bin i= 5000 train loss:  2.64  train acc: 0.74  val loss:  3.43  valid acc: 0.19\n",
      "bin i= 5500 train loss:  2.60  train acc: 0.75  val loss:  3.47  valid acc: 0.19\n",
      "bin i= 6000 train loss:  2.53  train acc: 0.76  val loss:  3.62  valid acc: 0.18\n",
      "bin i= 6500 train loss:  2.46  train acc: 0.80  val loss:  3.50  valid acc: 0.17\n",
      "bin i= 7000 train loss:  2.39  train acc: 0.81  val loss:  3.61  valid acc: 0.18\n",
      "bin i= 7500 train loss:  2.37  train acc: 0.80  val loss:  3.54  valid acc: 0.18\n",
      "bin i= 8000 train loss:  2.35  train acc: 0.80  val loss:  3.57  valid acc: 0.19\n",
      "bin i= 8500 train loss:  2.31  train acc: 0.83  val loss:  3.67  valid acc: 0.19\n",
      "bin i= 9000 train loss:  2.29  train acc: 0.82  val loss:  3.61  valid acc: 0.18\n",
      "bin i= 9500 train loss:  2.28  train acc: 0.82  val loss:  3.68  valid acc: 0.19\n",
      "bin i=10000 train loss:  2.23  train acc: 0.85  val loss:  3.77  valid acc: 0.21\n",
      "->  bin  layer idx: 2  , best valid accuracy: 0.21, test accuracy: 0.15\n",
      "bin i=    0 train loss:  9.29  train acc: 0.00  val loss:  6.86  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.63  train acc: 0.11  val loss:  4.03  valid acc: 0.01\n",
      "bin i= 1000 train loss:  4.33  train acc: 0.15  val loss:  3.88  valid acc: 0.01\n",
      "bin i= 1500 train loss:  4.23  train acc: 0.14  val loss:  3.71  valid acc: 0.02\n",
      "bin i= 2000 train loss:  4.03  train acc: 0.18  val loss:  3.61  valid acc: 0.05\n",
      "bin i= 2500 train loss:  3.84  train acc: 0.29  val loss:  3.58  valid acc: 0.07\n",
      "bin i= 3000 train loss:  3.66  train acc: 0.32  val loss:  3.53  valid acc: 0.10\n",
      "bin i= 3500 train loss:  3.54  train acc: 0.34  val loss:  3.58  valid acc: 0.11\n",
      "bin i= 4000 train loss:  3.39  train acc: 0.40  val loss:  3.52  valid acc: 0.14\n",
      "bin i= 4500 train loss:  3.27  train acc: 0.43  val loss:  3.58  valid acc: 0.14\n",
      "bin i= 5000 train loss:  3.20  train acc: 0.45  val loss:  3.57  valid acc: 0.16\n",
      "bin i= 5500 train loss:  3.15  train acc: 0.49  val loss:  3.65  valid acc: 0.15\n",
      "bin i= 6000 train loss:  3.08  train acc: 0.49  val loss:  3.82  valid acc: 0.14\n",
      "bin i= 6500 train loss:  3.02  train acc: 0.52  val loss:  3.77  valid acc: 0.14\n",
      "bin i= 7000 train loss:  2.94  train acc: 0.54  val loss:  3.81  valid acc: 0.14\n",
      "bin i= 7500 train loss:  2.89  train acc: 0.56  val loss:  3.83  valid acc: 0.15\n",
      "bin i= 8000 train loss:  2.87  train acc: 0.54  val loss:  3.86  valid acc: 0.14\n",
      "bin i= 8500 train loss:  2.86  train acc: 0.57  val loss:  3.90  valid acc: 0.15\n",
      "bin i= 9000 train loss:  2.78  train acc: 0.60  val loss:  4.07  valid acc: 0.14\n",
      "bin i= 9500 train loss:  2.80  train acc: 0.57  val loss:  3.88  valid acc: 0.16\n",
      "bin i=10000 train loss:  2.78  train acc: 0.58  val loss:  3.99  valid acc: 0.14\n",
      "->  bin  layer idx: 3  , best valid accuracy: 0.16, test accuracy: 0.18\n",
      "bin i=    0 train loss:  9.30  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.54  train acc: 0.10  val loss:  4.09  valid acc: 0.01\n",
      "bin i= 1000 train loss:  4.32  train acc: 0.12  val loss:  4.04  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.30  train acc: 0.10  val loss:  3.94  valid acc: 0.02\n",
      "bin i= 2000 train loss:  4.17  train acc: 0.12  val loss:  3.83  valid acc: 0.04\n",
      "bin i= 2500 train loss:  4.09  train acc: 0.14  val loss:  3.77  valid acc: 0.05\n",
      "bin i= 3000 train loss:  3.96  train acc: 0.16  val loss:  3.70  valid acc: 0.06\n",
      "bin i= 3500 train loss:  3.88  train acc: 0.16  val loss:  3.65  valid acc: 0.04\n",
      "bin i= 4000 train loss:  3.77  train acc: 0.16  val loss:  3.61  valid acc: 0.07\n",
      "bin i= 4500 train loss:  3.74  train acc: 0.19  val loss:  3.58  valid acc: 0.08\n",
      "bin i= 5000 train loss:  3.66  train acc: 0.18  val loss:  3.54  valid acc: 0.08\n",
      "bin i= 5500 train loss:  3.64  train acc: 0.21  val loss:  3.50  valid acc: 0.07\n",
      "bin i= 6000 train loss:  3.58  train acc: 0.22  val loss:  3.52  valid acc: 0.07\n",
      "bin i= 6500 train loss:  3.55  train acc: 0.21  val loss:  3.51  valid acc: 0.09\n",
      "bin i= 7000 train loss:  3.46  train acc: 0.24  val loss:  3.52  valid acc: 0.09\n",
      "bin i= 7500 train loss:  3.46  train acc: 0.25  val loss:  3.53  valid acc: 0.09\n",
      "bin i= 8000 train loss:  3.42  train acc: 0.24  val loss:  3.52  valid acc: 0.09\n",
      "bin i= 8500 train loss:  3.36  train acc: 0.28  val loss:  3.56  valid acc: 0.09\n",
      "bin i= 9000 train loss:  3.35  train acc: 0.28  val loss:  3.59  valid acc: 0.09\n",
      "bin i= 9500 train loss:  3.36  train acc: 0.26  val loss:  3.58  valid acc: 0.10\n",
      "bin i=10000 train loss:  3.31  train acc: 0.30  val loss:  3.62  valid acc: 0.09\n",
      "->  bin  layer idx: 4  , best valid accuracy: 0.10, test accuracy: 0.08\n",
      "bin i=    0 train loss:  9.30  train acc: 0.00  val loss:  6.83  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.53  train acc: 0.09  val loss:  4.06  valid acc: 0.01\n",
      "bin i= 1000 train loss:  4.32  train acc: 0.11  val loss:  3.98  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.32  train acc: 0.09  val loss:  3.90  valid acc: 0.03\n",
      "bin i= 2000 train loss:  4.20  train acc: 0.11  val loss:  3.85  valid acc: 0.03\n",
      "bin i= 2500 train loss:  4.15  train acc: 0.12  val loss:  3.84  valid acc: 0.04\n",
      "bin i= 3000 train loss:  4.03  train acc: 0.12  val loss:  3.77  valid acc: 0.04\n",
      "bin i= 3500 train loss:  3.97  train acc: 0.13  val loss:  3.75  valid acc: 0.04\n",
      "bin i= 4000 train loss:  3.90  train acc: 0.12  val loss:  3.72  valid acc: 0.04\n",
      "bin i= 4500 train loss:  3.90  train acc: 0.13  val loss:  3.75  valid acc: 0.04\n",
      "bin i= 5000 train loss:  3.80  train acc: 0.12  val loss:  3.66  valid acc: 0.05\n",
      "bin i= 5500 train loss:  3.79  train acc: 0.14  val loss:  3.66  valid acc: 0.03\n",
      "bin i= 6000 train loss:  3.79  train acc: 0.14  val loss:  3.63  valid acc: 0.03\n",
      "bin i= 6500 train loss:  3.73  train acc: 0.14  val loss:  3.66  valid acc: 0.03\n",
      "bin i= 7000 train loss:  3.71  train acc: 0.15  val loss:  3.63  valid acc: 0.04\n",
      "bin i= 7500 train loss:  3.69  train acc: 0.17  val loss:  3.70  valid acc: 0.04\n",
      "bin i= 8000 train loss:  3.67  train acc: 0.15  val loss:  3.62  valid acc: 0.04\n",
      "bin i= 8500 train loss:  3.63  train acc: 0.17  val loss:  3.63  valid acc: 0.04\n",
      "bin i= 9000 train loss:  3.64  train acc: 0.15  val loss:  3.69  valid acc: 0.04\n",
      "bin i= 9500 train loss:  3.62  train acc: 0.15  val loss:  3.66  valid acc: 0.05\n",
      "bin i=10000 train loss:  3.65  train acc: 0.15  val loss:  3.67  valid acc: 0.03\n",
      "->  bin  layer idx: 5  , best valid accuracy: 0.05, test accuracy: 0.03\n",
      "bin i=    0 train loss:  9.31  train acc: 0.00  val loss:  6.82  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.54  train acc: 0.09  val loss:  4.06  valid acc: 0.02\n",
      "bin i= 1000 train loss:  4.32  train acc: 0.11  val loss:  3.99  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.34  train acc: 0.09  val loss:  3.94  valid acc: 0.03\n",
      "bin i= 2000 train loss:  4.22  train acc: 0.12  val loss:  3.89  valid acc: 0.03\n",
      "bin i= 2500 train loss:  4.17  train acc: 0.11  val loss:  3.90  valid acc: 0.03\n",
      "bin i= 3000 train loss:  4.05  train acc: 0.13  val loss:  3.86  valid acc: 0.03\n",
      "bin i= 3500 train loss:  4.01  train acc: 0.11  val loss:  3.87  valid acc: 0.01\n",
      "bin i= 4000 train loss:  3.95  train acc: 0.11  val loss:  3.86  valid acc: 0.01\n",
      "bin i= 4500 train loss:  3.96  train acc: 0.11  val loss:  3.89  valid acc: 0.02\n",
      "bin i= 5000 train loss:  3.87  train acc: 0.11  val loss:  3.78  valid acc: 0.03\n",
      "bin i= 5500 train loss:  3.87  train acc: 0.14  val loss:  3.80  valid acc: 0.03\n",
      "bin i= 6000 train loss:  3.85  train acc: 0.12  val loss:  3.83  valid acc: 0.03\n",
      "bin i= 6500 train loss:  3.82  train acc: 0.12  val loss:  3.83  valid acc: 0.02\n",
      "bin i= 7000 train loss:  3.78  train acc: 0.11  val loss:  3.87  valid acc: 0.03\n",
      "bin i= 7500 train loss:  3.78  train acc: 0.15  val loss:  3.87  valid acc: 0.02\n",
      "bin i= 8000 train loss:  3.75  train acc: 0.12  val loss:  3.81  valid acc: 0.02\n",
      "bin i= 8500 train loss:  3.71  train acc: 0.15  val loss:  3.84  valid acc: 0.03\n",
      "bin i= 9000 train loss:  3.76  train acc: 0.12  val loss:  3.86  valid acc: 0.02\n",
      "bin i= 9500 train loss:  3.72  train acc: 0.12  val loss:  3.82  valid acc: 0.02\n",
      "bin i=10000 train loss:  3.74  train acc: 0.13  val loss:  3.92  valid acc: 0.01\n",
      "->  bin  layer idx: 6  , best valid accuracy: 0.03, test accuracy: 0.02\n",
      "bin i=    0 train loss:  9.31  train acc: 0.00  val loss:  6.86  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.55  train acc: 0.08  val loss:  4.06  valid acc: 0.02\n",
      "bin i= 1000 train loss:  4.36  train acc: 0.10  val loss:  4.00  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.36  train acc: 0.09  val loss:  3.95  valid acc: 0.03\n",
      "bin i= 2000 train loss:  4.24  train acc: 0.12  val loss:  3.91  valid acc: 0.03\n",
      "bin i= 2500 train loss:  4.21  train acc: 0.10  val loss:  3.94  valid acc: 0.03\n",
      "bin i= 3000 train loss:  4.07  train acc: 0.12  val loss:  3.88  valid acc: 0.03\n",
      "bin i= 3500 train loss:  4.05  train acc: 0.11  val loss:  3.86  valid acc: 0.03\n",
      "bin i= 4000 train loss:  3.98  train acc: 0.10  val loss:  3.87  valid acc: 0.03\n",
      "bin i= 4500 train loss:  4.00  train acc: 0.12  val loss:  3.96  valid acc: 0.02\n",
      "bin i= 5000 train loss:  3.89  train acc: 0.10  val loss:  3.82  valid acc: 0.04\n",
      "bin i= 5500 train loss:  3.90  train acc: 0.12  val loss:  3.82  valid acc: 0.04\n",
      "bin i= 6000 train loss:  3.89  train acc: 0.11  val loss:  3.87  valid acc: 0.03\n",
      "bin i= 6500 train loss:  3.86  train acc: 0.11  val loss:  3.90  valid acc: 0.03\n",
      "bin i= 7000 train loss:  3.82  train acc: 0.12  val loss:  3.92  valid acc: 0.03\n",
      "bin i= 7500 train loss:  3.80  train acc: 0.14  val loss:  3.94  valid acc: 0.02\n",
      "bin i= 8000 train loss:  3.80  train acc: 0.12  val loss:  3.86  valid acc: 0.02\n",
      "bin i= 8500 train loss:  3.74  train acc: 0.14  val loss:  3.93  valid acc: 0.03\n",
      "bin i= 9000 train loss:  3.78  train acc: 0.12  val loss:  3.99  valid acc: 0.01\n",
      "bin i= 9500 train loss:  3.74  train acc: 0.11  val loss:  3.90  valid acc: 0.02\n",
      "bin i=10000 train loss:  3.75  train acc: 0.13  val loss:  3.92  valid acc: 0.02\n",
      "->  bin  layer idx: 7  , best valid accuracy: 0.04, test accuracy: 0.01\n",
      "bin i=    0 train loss:  9.30  train acc: 0.00  val loss:  6.86  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.52  train acc: 0.09  val loss:  4.10  valid acc: 0.01\n",
      "bin i= 1000 train loss:  4.33  train acc: 0.11  val loss:  4.03  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.34  train acc: 0.08  val loss:  4.01  valid acc: 0.03\n",
      "bin i= 2000 train loss:  4.23  train acc: 0.11  val loss:  3.94  valid acc: 0.02\n",
      "bin i= 2500 train loss:  4.20  train acc: 0.10  val loss:  3.96  valid acc: 0.02\n",
      "bin i= 3000 train loss:  4.07  train acc: 0.13  val loss:  3.92  valid acc: 0.03\n",
      "bin i= 3500 train loss:  4.03  train acc: 0.09  val loss:  3.92  valid acc: 0.01\n",
      "bin i= 4000 train loss:  3.98  train acc: 0.11  val loss:  3.92  valid acc: 0.01\n",
      "bin i= 4500 train loss:  3.99  train acc: 0.11  val loss:  3.98  valid acc: 0.01\n",
      "bin i= 5000 train loss:  3.89  train acc: 0.10  val loss:  3.85  valid acc: 0.03\n",
      "bin i= 5500 train loss:  3.90  train acc: 0.12  val loss:  3.84  valid acc: 0.03\n",
      "bin i= 6000 train loss:  3.88  train acc: 0.13  val loss:  3.87  valid acc: 0.03\n",
      "bin i= 6500 train loss:  3.87  train acc: 0.12  val loss:  3.89  valid acc: 0.02\n",
      "bin i= 7000 train loss:  3.82  train acc: 0.12  val loss:  3.89  valid acc: 0.03\n",
      "bin i= 7500 train loss:  3.80  train acc: 0.14  val loss:  3.90  valid acc: 0.03\n",
      "bin i= 8000 train loss:  3.79  train acc: 0.12  val loss:  3.83  valid acc: 0.02\n",
      "bin i= 8500 train loss:  3.75  train acc: 0.13  val loss:  3.90  valid acc: 0.03\n",
      "bin i= 9000 train loss:  3.78  train acc: 0.11  val loss:  3.98  valid acc: 0.01\n",
      "bin i= 9500 train loss:  3.74  train acc: 0.10  val loss:  3.89  valid acc: 0.01\n",
      "bin i=10000 train loss:  3.76  train acc: 0.12  val loss:  3.89  valid acc: 0.01\n",
      "->  bin  layer idx: 8  , best valid accuracy: 0.03, test accuracy: 0.02\n",
      "bin i=    0 train loss:  9.29  train acc: 0.00  val loss:  6.90  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.49  train acc: 0.09  val loss:  4.06  valid acc: 0.01\n",
      "bin i= 1000 train loss:  4.32  train acc: 0.11  val loss:  4.03  valid acc: 0.01\n",
      "bin i= 1500 train loss:  4.33  train acc: 0.08  val loss:  4.00  valid acc: 0.02\n",
      "bin i= 2000 train loss:  4.22  train acc: 0.11  val loss:  3.94  valid acc: 0.02\n",
      "bin i= 2500 train loss:  4.17  train acc: 0.11  val loss:  3.95  valid acc: 0.01\n",
      "bin i= 3000 train loss:  4.06  train acc: 0.11  val loss:  3.96  valid acc: 0.02\n",
      "bin i= 3500 train loss:  4.04  train acc: 0.11  val loss:  4.00  valid acc: 0.00\n",
      "bin i= 4000 train loss:  3.96  train acc: 0.12  val loss:  3.93  valid acc: 0.01\n",
      "bin i= 4500 train loss:  3.98  train acc: 0.12  val loss:  3.99  valid acc: 0.02\n",
      "bin i= 5000 train loss:  3.88  train acc: 0.10  val loss:  3.87  valid acc: 0.01\n",
      "bin i= 5500 train loss:  3.87  train acc: 0.12  val loss:  3.90  valid acc: 0.01\n",
      "bin i= 6000 train loss:  3.86  train acc: 0.12  val loss:  3.92  valid acc: 0.01\n",
      "bin i= 6500 train loss:  3.84  train acc: 0.12  val loss:  3.98  valid acc: 0.01\n",
      "bin i= 7000 train loss:  3.82  train acc: 0.13  val loss:  3.96  valid acc: 0.01\n",
      "bin i= 7500 train loss:  3.78  train acc: 0.14  val loss:  4.03  valid acc: 0.01\n",
      "bin i= 8000 train loss:  3.78  train acc: 0.11  val loss:  3.89  valid acc: 0.02\n",
      "bin i= 8500 train loss:  3.73  train acc: 0.14  val loss:  3.96  valid acc: 0.02\n",
      "bin i= 9000 train loss:  3.76  train acc: 0.11  val loss:  4.05  valid acc: 0.03\n",
      "bin i= 9500 train loss:  3.73  train acc: 0.11  val loss:  3.95  valid acc: 0.02\n",
      "bin i=10000 train loss:  3.75  train acc: 0.12  val loss:  3.98  valid acc: 0.02\n",
      "->  bin  layer idx: 9  , best valid accuracy: 0.03, test accuracy: 0.06\n",
      "bin i=    0 train loss:  9.31  train acc: 0.00  val loss:  6.94  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.42  train acc: 0.09  val loss:  4.00  valid acc: 0.03\n",
      "bin i= 1000 train loss:  4.26  train acc: 0.12  val loss:  4.00  valid acc: 0.03\n",
      "bin i= 1500 train loss:  4.28  train acc: 0.08  val loss:  3.98  valid acc: 0.04\n",
      "bin i= 2000 train loss:  4.18  train acc: 0.10  val loss:  3.93  valid acc: 0.03\n",
      "bin i= 2500 train loss:  4.12  train acc: 0.11  val loss:  3.95  valid acc: 0.04\n",
      "bin i= 3000 train loss:  4.02  train acc: 0.12  val loss:  3.97  valid acc: 0.02\n",
      "bin i= 3500 train loss:  3.99  train acc: 0.11  val loss:  4.04  valid acc: 0.01\n",
      "bin i= 4000 train loss:  3.93  train acc: 0.12  val loss:  3.95  valid acc: 0.02\n",
      "bin i= 4500 train loss:  3.95  train acc: 0.09  val loss:  3.98  valid acc: 0.01\n",
      "bin i= 5000 train loss:  3.86  train acc: 0.10  val loss:  3.91  valid acc: 0.02\n",
      "bin i= 5500 train loss:  3.82  train acc: 0.13  val loss:  3.91  valid acc: 0.03\n",
      "bin i= 6000 train loss:  3.83  train acc: 0.13  val loss:  3.91  valid acc: 0.03\n",
      "bin i= 6500 train loss:  3.83  train acc: 0.12  val loss:  3.96  valid acc: 0.02\n",
      "bin i= 7000 train loss:  3.79  train acc: 0.11  val loss:  3.92  valid acc: 0.04\n",
      "bin i= 7500 train loss:  3.77  train acc: 0.14  val loss:  3.96  valid acc: 0.03\n",
      "bin i= 8000 train loss:  3.77  train acc: 0.10  val loss:  3.90  valid acc: 0.03\n",
      "bin i= 8500 train loss:  3.73  train acc: 0.14  val loss:  3.91  valid acc: 0.04\n",
      "bin i= 9000 train loss:  3.76  train acc: 0.11  val loss:  4.01  valid acc: 0.01\n",
      "bin i= 9500 train loss:  3.70  train acc: 0.11  val loss:  3.97  valid acc: 0.01\n",
      "bin i=10000 train loss:  3.71  train acc: 0.14  val loss:  4.03  valid acc: 0.02\n",
      "->  bin  layer idx: 10 , best valid accuracy: 0.04, test accuracy: 0.09\n",
      "bin i=    0 train loss:  9.32  train acc: 0.00  val loss:  7.00  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.31  train acc: 0.10  val loss:  3.95  valid acc: 0.02\n",
      "bin i= 1000 train loss:  4.16  train acc: 0.11  val loss:  3.94  valid acc: 0.03\n",
      "bin i= 1500 train loss:  4.19  train acc: 0.09  val loss:  3.84  valid acc: 0.03\n",
      "bin i= 2000 train loss:  4.08  train acc: 0.10  val loss:  3.85  valid acc: 0.04\n",
      "bin i= 2500 train loss:  4.02  train acc: 0.11  val loss:  3.88  valid acc: 0.03\n",
      "bin i= 3000 train loss:  3.91  train acc: 0.12  val loss:  3.87  valid acc: 0.04\n",
      "bin i= 3500 train loss:  3.92  train acc: 0.12  val loss:  3.86  valid acc: 0.02\n",
      "bin i= 4000 train loss:  3.84  train acc: 0.13  val loss:  3.83  valid acc: 0.04\n",
      "bin i= 4500 train loss:  3.87  train acc: 0.12  val loss:  3.90  valid acc: 0.04\n",
      "bin i= 5000 train loss:  3.75  train acc: 0.12  val loss:  3.80  valid acc: 0.04\n",
      "bin i= 5500 train loss:  3.74  train acc: 0.13  val loss:  3.87  valid acc: 0.04\n",
      "bin i= 6000 train loss:  3.73  train acc: 0.12  val loss:  3.77  valid acc: 0.04\n",
      "bin i= 6500 train loss:  3.73  train acc: 0.13  val loss:  3.87  valid acc: 0.03\n",
      "bin i= 7000 train loss:  3.68  train acc: 0.14  val loss:  3.89  valid acc: 0.04\n",
      "bin i= 7500 train loss:  3.66  train acc: 0.14  val loss:  3.90  valid acc: 0.04\n",
      "bin i= 8000 train loss:  3.69  train acc: 0.11  val loss:  3.88  valid acc: 0.04\n",
      "bin i= 8500 train loss:  3.64  train acc: 0.14  val loss:  3.85  valid acc: 0.04\n",
      "bin i= 9000 train loss:  3.64  train acc: 0.13  val loss:  3.96  valid acc: 0.03\n",
      "bin i= 9500 train loss:  3.62  train acc: 0.12  val loss:  3.86  valid acc: 0.03\n",
      "bin i=10000 train loss:  3.62  train acc: 0.14  val loss:  3.97  valid acc: 0.03\n",
      "->  bin  layer idx: 11 , best valid accuracy: 0.04, test accuracy: 0.06\n",
      "bin i=    0 train loss:  9.33  train acc: 0.00  val loss:  7.16  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.28  train acc: 0.09  val loss:  3.93  valid acc: 0.02\n",
      "bin i= 1000 train loss:  4.14  train acc: 0.11  val loss:  3.91  valid acc: 0.03\n",
      "bin i= 1500 train loss:  4.16  train acc: 0.08  val loss:  3.86  valid acc: 0.04\n",
      "bin i= 2000 train loss:  4.04  train acc: 0.11  val loss:  3.86  valid acc: 0.02\n",
      "bin i= 2500 train loss:  4.01  train acc: 0.12  val loss:  3.90  valid acc: 0.03\n",
      "bin i= 3000 train loss:  3.91  train acc: 0.11  val loss:  3.87  valid acc: 0.05\n",
      "bin i= 3500 train loss:  3.90  train acc: 0.12  val loss:  3.83  valid acc: 0.02\n",
      "bin i= 4000 train loss:  3.82  train acc: 0.12  val loss:  3.84  valid acc: 0.04\n",
      "bin i= 4500 train loss:  3.87  train acc: 0.12  val loss:  3.86  valid acc: 0.03\n",
      "bin i= 5000 train loss:  3.75  train acc: 0.11  val loss:  3.79  valid acc: 0.04\n",
      "bin i= 5500 train loss:  3.74  train acc: 0.12  val loss:  3.81  valid acc: 0.04\n",
      "bin i= 6000 train loss:  3.75  train acc: 0.13  val loss:  3.74  valid acc: 0.05\n",
      "bin i= 6500 train loss:  3.73  train acc: 0.12  val loss:  3.80  valid acc: 0.02\n",
      "bin i= 7000 train loss:  3.70  train acc: 0.11  val loss:  3.81  valid acc: 0.03\n",
      "bin i= 7500 train loss:  3.66  train acc: 0.15  val loss:  3.80  valid acc: 0.04\n",
      "bin i= 8000 train loss:  3.70  train acc: 0.11  val loss:  3.76  valid acc: 0.03\n",
      "bin i= 8500 train loss:  3.67  train acc: 0.13  val loss:  3.76  valid acc: 0.05\n",
      "bin i= 9000 train loss:  3.66  train acc: 0.13  val loss:  3.87  valid acc: 0.02\n",
      "bin i= 9500 train loss:  3.65  train acc: 0.12  val loss:  3.78  valid acc: 0.02\n",
      "bin i=10000 train loss:  3.64  train acc: 0.13  val loss:  3.91  valid acc: 0.03\n",
      "->  bin  layer idx: 12 , best valid accuracy: 0.05, test accuracy: 0.10\n",
      "bin i=    0 train loss:  9.34  train acc: 0.00  val loss:  7.30  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.28  train acc: 0.09  val loss:  3.91  valid acc: 0.03\n",
      "bin i= 1000 train loss:  4.13  train acc: 0.10  val loss:  3.87  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.16  train acc: 0.09  val loss:  3.82  valid acc: 0.04\n",
      "bin i= 2000 train loss:  4.02  train acc: 0.10  val loss:  3.81  valid acc: 0.04\n",
      "bin i= 2500 train loss:  4.01  train acc: 0.11  val loss:  3.81  valid acc: 0.04\n",
      "bin i= 3000 train loss:  3.88  train acc: 0.12  val loss:  3.83  valid acc: 0.04\n",
      "bin i= 3500 train loss:  3.87  train acc: 0.11  val loss:  3.85  valid acc: 0.02\n",
      "bin i= 4000 train loss:  3.78  train acc: 0.13  val loss:  3.84  valid acc: 0.03\n",
      "bin i= 4500 train loss:  3.84  train acc: 0.12  val loss:  3.82  valid acc: 0.03\n",
      "bin i= 5000 train loss:  3.72  train acc: 0.10  val loss:  3.83  valid acc: 0.03\n",
      "bin i= 5500 train loss:  3.72  train acc: 0.13  val loss:  3.78  valid acc: 0.03\n",
      "bin i= 6000 train loss:  3.72  train acc: 0.14  val loss:  3.77  valid acc: 0.03\n",
      "bin i= 6500 train loss:  3.70  train acc: 0.12  val loss:  3.82  valid acc: 0.04\n",
      "bin i= 7000 train loss:  3.67  train acc: 0.11  val loss:  3.88  valid acc: 0.03\n",
      "bin i= 7500 train loss:  3.65  train acc: 0.15  val loss:  3.79  valid acc: 0.04\n",
      "bin i= 8000 train loss:  3.65  train acc: 0.12  val loss:  3.81  valid acc: 0.03\n",
      "bin i= 8500 train loss:  3.63  train acc: 0.13  val loss:  3.80  valid acc: 0.05\n",
      "bin i= 9000 train loss:  3.65  train acc: 0.13  val loss:  3.80  valid acc: 0.03\n",
      "bin i= 9500 train loss:  3.64  train acc: 0.12  val loss:  3.76  valid acc: 0.03\n",
      "bin i=10000 train loss:  3.65  train acc: 0.12  val loss:  3.86  valid acc: 0.07\n",
      "->  bin  layer idx: 13 , best valid accuracy: 0.07, test accuracy: 0.04\n",
      "bin i=    0 train loss:  9.38  train acc: 0.00  val loss:  7.51  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.26  train acc: 0.10  val loss:  3.93  valid acc: 0.01\n",
      "bin i= 1000 train loss:  4.11  train acc: 0.11  val loss:  3.89  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.09  train acc: 0.10  val loss:  3.83  valid acc: 0.03\n",
      "bin i= 2000 train loss:  3.95  train acc: 0.11  val loss:  3.88  valid acc: 0.02\n",
      "bin i= 2500 train loss:  3.95  train acc: 0.12  val loss:  3.90  valid acc: 0.02\n",
      "bin i= 3000 train loss:  3.81  train acc: 0.12  val loss:  3.86  valid acc: 0.03\n",
      "bin i= 3500 train loss:  3.78  train acc: 0.13  val loss:  3.93  valid acc: 0.01\n",
      "bin i= 4000 train loss:  3.71  train acc: 0.12  val loss:  3.84  valid acc: 0.02\n",
      "bin i= 4500 train loss:  3.78  train acc: 0.11  val loss:  3.88  valid acc: 0.02\n",
      "bin i= 5000 train loss:  3.66  train acc: 0.12  val loss:  3.82  valid acc: 0.02\n",
      "bin i= 5500 train loss:  3.63  train acc: 0.14  val loss:  3.87  valid acc: 0.02\n",
      "bin i= 6000 train loss:  3.64  train acc: 0.16  val loss:  3.83  valid acc: 0.02\n",
      "bin i= 6500 train loss:  3.63  train acc: 0.12  val loss:  3.85  valid acc: 0.02\n",
      "bin i= 7000 train loss:  3.59  train acc: 0.13  val loss:  3.90  valid acc: 0.03\n",
      "bin i= 7500 train loss:  3.59  train acc: 0.15  val loss:  3.85  valid acc: 0.02\n",
      "bin i= 8000 train loss:  3.57  train acc: 0.12  val loss:  3.87  valid acc: 0.03\n",
      "bin i= 8500 train loss:  3.59  train acc: 0.12  val loss:  3.81  valid acc: 0.03\n",
      "bin i= 9000 train loss:  3.58  train acc: 0.13  val loss:  3.96  valid acc: 0.00\n",
      "bin i= 9500 train loss:  3.55  train acc: 0.13  val loss:  3.84  valid acc: 0.01\n",
      "bin i=10000 train loss:  3.55  train acc: 0.14  val loss:  3.88  valid acc: 0.03\n",
      "->  bin  layer idx: 14 , best valid accuracy: 0.03, test accuracy: 0.09\n",
      "bin i=    0 train loss:  9.45  train acc: 0.00  val loss:  8.51  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.32  train acc: 0.10  val loss:  4.01  valid acc: 0.01\n",
      "bin i= 1000 train loss:  4.17  train acc: 0.10  val loss:  3.99  valid acc: 0.01\n",
      "bin i= 1500 train loss:  4.11  train acc: 0.09  val loss:  3.94  valid acc: 0.03\n",
      "bin i= 2000 train loss:  3.95  train acc: 0.12  val loss:  3.95  valid acc: 0.03\n",
      "bin i= 2500 train loss:  3.95  train acc: 0.11  val loss:  3.93  valid acc: 0.03\n",
      "bin i= 3000 train loss:  3.83  train acc: 0.13  val loss:  3.99  valid acc: 0.02\n",
      "bin i= 3500 train loss:  3.80  train acc: 0.11  val loss:  4.03  valid acc: 0.00\n",
      "bin i= 4000 train loss:  3.71  train acc: 0.12  val loss:  3.94  valid acc: 0.01\n",
      "bin i= 4500 train loss:  3.78  train acc: 0.12  val loss:  4.01  valid acc: 0.01\n",
      "bin i= 5000 train loss:  3.67  train acc: 0.11  val loss:  3.93  valid acc: 0.03\n",
      "bin i= 5500 train loss:  3.68  train acc: 0.14  val loss:  3.87  valid acc: 0.01\n",
      "bin i= 6000 train loss:  3.67  train acc: 0.15  val loss:  3.86  valid acc: 0.02\n",
      "bin i= 6500 train loss:  3.65  train acc: 0.12  val loss:  3.93  valid acc: 0.02\n",
      "bin i= 7000 train loss:  3.62  train acc: 0.13  val loss:  3.94  valid acc: 0.02\n",
      "bin i= 7500 train loss:  3.61  train acc: 0.14  val loss:  3.83  valid acc: 0.04\n",
      "bin i= 8000 train loss:  3.61  train acc: 0.13  val loss:  3.92  valid acc: 0.02\n",
      "bin i= 8500 train loss:  3.60  train acc: 0.14  val loss:  3.90  valid acc: 0.03\n",
      "bin i= 9000 train loss:  3.62  train acc: 0.13  val loss:  4.04  valid acc: 0.00\n",
      "bin i= 9500 train loss:  3.57  train acc: 0.14  val loss:  3.89  valid acc: 0.04\n",
      "bin i=10000 train loss:  3.60  train acc: 0.14  val loss:  4.04  valid acc: 0.03\n",
      "->  bin  layer idx: 15 , best valid accuracy: 0.04, test accuracy: 0.05\n",
      "bin i=    0 train loss: 12.76  train acc: 0.00  val loss: 53.32  valid acc: 0.00\n",
      "bin i=  500 train loss:  6.20  train acc: 0.04  val loss:  4.72  valid acc: 0.01\n",
      "bin i= 1000 train loss:  5.63  train acc: 0.06  val loss:  4.54  valid acc: 0.04\n",
      "bin i= 1500 train loss:  5.46  train acc: 0.05  val loss:  4.47  valid acc: 0.01\n",
      "bin i= 2000 train loss:  5.20  train acc: 0.04  val loss:  4.46  valid acc: 0.02\n",
      "bin i= 2500 train loss:  5.24  train acc: 0.06  val loss:  4.48  valid acc: 0.00\n",
      "bin i= 3000 train loss:  5.13  train acc: 0.06  val loss:  4.55  valid acc: 0.01\n",
      "bin i= 3500 train loss:  5.09  train acc: 0.04  val loss:  4.58  valid acc: 0.00\n",
      "bin i= 4000 train loss:  4.86  train acc: 0.07  val loss:  4.38  valid acc: 0.01\n",
      "bin i= 4500 train loss:  4.92  train acc: 0.06  val loss:  4.52  valid acc: 0.00\n",
      "bin i= 5000 train loss:  4.78  train acc: 0.06  val loss:  4.34  valid acc: 0.06\n",
      "bin i= 5500 train loss:  4.76  train acc: 0.05  val loss:  4.24  valid acc: 0.02\n",
      "bin i= 6000 train loss:  4.76  train acc: 0.07  val loss:  4.45  valid acc: 0.01\n",
      "bin i= 6500 train loss:  4.73  train acc: 0.06  val loss:  4.35  valid acc: 0.01\n",
      "bin i= 7000 train loss:  4.65  train acc: 0.07  val loss:  4.20  valid acc: 0.02\n",
      "bin i= 7500 train loss:  4.59  train acc: 0.07  val loss:  4.23  valid acc: 0.00\n",
      "bin i= 8000 train loss:  4.61  train acc: 0.07  val loss:  4.30  valid acc: 0.00\n",
      "bin i= 8500 train loss:  4.57  train acc: 0.06  val loss:  4.43  valid acc: 0.00\n",
      "bin i= 9000 train loss:  4.64  train acc: 0.07  val loss:  4.41  valid acc: 0.00\n",
      "bin i= 9500 train loss:  4.55  train acc: 0.08  val loss:  4.36  valid acc: 0.00\n",
      "bin i=10000 train loss:  4.61  train acc: 0.07  val loss:  4.40  valid acc: 0.00\n",
      "->  bin  layer idx: 16 , best valid accuracy: 0.06, test accuracy: 0.01\n"
     ]
    }
   ],
   "execution_count": 16
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:12:59.878940Z",
     "start_time": "2025-09-15T19:12:59.873954Z"
    }
   },
   "cell_type": "code",
   "source": "valid_hidden_states[0][layer_idx].shape, len(valid_labels)",
   "id": "d4643f15b8c0ddd5",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([2048]), 4096)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 17
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:13:00.237922Z",
     "start_time": "2025-09-15T19:13:00.233767Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def solve_linear_layer(x: Tensor, y: Tensor) -> torch.nn.Linear:\n",
    "    if y.ndim == 1:\n",
    "        y = y.unsqueeze(-1)\n",
    "    if not y.is_floating_point():\n",
    "        y = y.float()\n",
    "   \n",
    "    lin = torch.nn.Linear(x.shape[-1], y.shape[-1], device=x.device)\n",
    "    x_aug = torch.cat([x, torch.ones(len(x), 1, device=x.device)], dim=1)\n",
    "    coeffs = torch.linalg.lstsq(x_aug, y).solution\n",
    "    w, b = coeffs[:-1], coeffs[-1]\n",
    "    with torch.no_grad():\n",
    "        lin.weight[:] = w.T\n",
    "        lin.bias[:] = b\n",
    "    return lin"
   ],
   "id": "180a6d66da68307d",
   "outputs": [],
   "execution_count": 18
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:13:15.480321Z",
     "start_time": "2025-09-15T19:13:00.514812Z"
    }
   },
   "cell_type": "code",
   "source": [
    "for layer_idx in range(len(train_hidden_states)):\n",
    "    lin_probe = solve_linear_layer(\n",
    "        train_hidden_states[layer_idx].float().to(device),\n",
    "        train_labels.to(device),\n",
    "    )\n",
    "    log_probe = solve_linear_layer(\n",
    "        train_hidden_states[layer_idx].float().to(device),\n",
    "        train_labels.log1p().to(device),\n",
    "    )\n",
    "    lin_test_pred = lin_probe(test_hidden_states[layer_idx].float().to(device)).flatten().round().int()\n",
    "    lin_test_accuracy = (lin_test_pred == test_labels).float().mean().item()\n",
    "    \n",
    "    log_test_pred = log_probe(test_hidden_states[layer_idx].float().to(device)).flatten().exp().add(1).round().int()\n",
    "    log_test_accuracy = (log_test_pred == test_labels).float().mean().item()\n",
    "    \n",
    "    test_accuracies[\"lin\"][layer_idx] = lin_test_accuracy\n",
    "    test_accuracies[\"log\"][layer_idx] = log_test_accuracy\n",
    "\n",
    "    print(f\"layer idx: {layer_idx:<3}, linear probe acc: {lin_test_accuracy:.2f}, log probe acc: {log_test_accuracy:.2f}\")"
   ],
   "id": "6f18c4fd0b785e1b",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "layer idx: 0  , linear probe acc: 0.00, log probe acc: 0.00\n",
      "layer idx: 1  , linear probe acc: 0.02, log probe acc: 0.04\n",
      "layer idx: 2  , linear probe acc: 0.02, log probe acc: 0.04\n",
      "layer idx: 3  , linear probe acc: 0.01, log probe acc: 0.03\n",
      "layer idx: 4  , linear probe acc: 0.01, log probe acc: 0.02\n",
      "layer idx: 5  , linear probe acc: 0.01, log probe acc: 0.02\n",
      "layer idx: 6  , linear probe acc: 0.01, log probe acc: 0.02\n",
      "layer idx: 7  , linear probe acc: 0.01, log probe acc: 0.02\n",
      "layer idx: 8  , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 9  , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 10 , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 11 , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 12 , linear probe acc: 0.01, log probe acc: 0.02\n",
      "layer idx: 13 , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 14 , linear probe acc: 0.00, log probe acc: 0.01\n",
      "layer idx: 15 , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 16 , linear probe acc: 0.01, log probe acc: 0.01\n"
     ]
    }
   ],
   "execution_count": 19
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:13:15.682667Z",
     "start_time": "2025-09-15T19:13:15.677753Z"
    }
   },
   "cell_type": "code",
   "source": [
    "for name, accs in test_accuracies.items():\n",
    "    print(f\"{name} accs: | \" + \" | \".join([f\"{x:.0%}\" for layer, x in sorted(accs.items())]) + \" |\")"
   ],
   "id": "9d2a87552870d319",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sin accs: | 85% | 93% | 100% | 99% | 99% | 100% | 100% | 100% | 97% | 97% | 100% | 100% | 100% | 100% | 98% | 97% | 45% |\n",
      "bin accs: | 33% | 18% | 15% | 18% | 8% | 3% | 2% | 1% | 2% | 6% | 9% | 6% | 10% | 4% | 9% | 5% | 1% |\n",
      "lin accs: | 0% | 2% | 2% | 1% | 1% | 1% | 1% | 1% | 1% | 1% | 1% | 1% | 1% | 1% | 0% | 1% | 1% |\n",
      "log accs: | 0% | 4% | 4% | 3% | 2% | 2% | 2% | 2% | 1% | 1% | 1% | 1% | 2% | 1% | 1% | 1% | 1% |\n"
     ]
    }
   ],
   "execution_count": 20
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "numllama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
