{
 "cells": [
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:32:07.607975Z",
     "start_time": "2025-09-15T19:32:07.604641Z"
    }
   },
   "cell_type": "code",
   "source": "# Device used for the model and the probes; hard-coded to a CUDA GPU.\ndevice = \"cuda\"",
   "id": "966f831c279e30cc",
   "outputs": [],
   "execution_count": 1
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Preliminaries",
   "id": "560b3e3e4dbf7eed"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:32:09.814935Z",
     "start_time": "2025-09-15T19:32:07.926486Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import random\n",
    "import collections\n",
    "\n",
    "\n",
    "import transformers\n",
    "import torch\n",
    "import tqdm.auto\n",
    "from torch import Tensor"
   ],
   "id": "6267d52fa4cf1f38",
   "outputs": [],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:32:09.991024Z",
     "start_time": "2025-09-15T19:32:09.968671Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def sinusoidal_encode(\n",
    "    x: Tensor,\n",
    "    embedding_dim: int,\n",
    "    min_value: int,\n",
    "    max_value: int,\n",
    "    use_l2_norm: bool = False,\n",
    "    norm_const: float | None = None,\n",
    ") -> Tensor:\n",
    "    \"\"\"\n",
    "    Encodes a tensor of numbers into a sinusoidal representation, inspired by how absolute positional\n",
    "    encoding works in transformers.\n",
    "\n",
    "    The encoding is an evaluation of a sine and cosine function at different frequencies, where the\n",
    "    frequency is determined by the embedding dimension and the allowed range of the input values.\n",
    "\n",
    "    Args:\n",
    "        x: Numbers to encode, of any shape. The encoding is added as a new last dimension.\n",
    "        embedding_dim: Size of the new last dimension. Must be even unless `use_l2_norm` is set.\n",
    "        min_value: Lower end of the value range; it is subtracted from `x` before encoding.\n",
    "        max_value: Upper end of the value range; must be greater than `min_value`.\n",
    "        use_l2_norm: If true, the last 2 (even `embedding_dim`) or 1 (odd `embedding_dim`)\n",
    "            positions hold a constant one and every vector is then scaled to unit L2 norm.\n",
    "        norm_const: Optional factor by which the final encoding is multiplied.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape `x.shape + (embedding_dim,)` on the device of `x`, with sines at the even\n",
    "        and cosines at the odd indices of the non-reserved positions.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `embedding_dim` is odd while `use_l2_norm` is false, or if `max_value` is\n",
    "            not greater than `min_value`.\n",
    "\n",
    "    >>> sinusoidal_encode(\n",
    "    ...     torch.tensor([-5, 2, 1, 0]),\n",
    "    ...     embedding_dim=6,\n",
    "    ...     min_value=-5,\n",
    "    ...     max_value=5,\n",
    "    ... )\n",
    "    tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],\n",
    "            [ 0.6570,  0.7539, -0.1073, -0.9942,  0.9980,  0.0627],\n",
    "            [-0.2794,  0.9602,  0.3491, -0.9371,  0.9616,  0.2746],\n",
    "            [-0.9589,  0.2837,  0.7317, -0.6816,  0.8806,  0.4738]])\n",
    "    \"\"\"\n",
    "\n",
    "    if embedding_dim % 2 != 0 and not use_l2_norm:\n",
    "        raise ValueError(\"Embedding dimension must be even\")\n",
    "\n",
    "    domain = max_value - min_value\n",
    "    if domain <= 0:\n",
    "        # log(domain) would be -inf or NaN and silently poison the whole encoding\n",
    "        raise ValueError(\"max_value must be greater than min_value\")\n",
    "\n",
    "    if use_l2_norm:\n",
    "        # Reserve trailing positions for the constant part so that the sinusoidal part keeps an\n",
    "        # even size and the total size stays `embedding_dim`.\n",
    "        reserved_dim = 2 if embedding_dim % 2 == 0 else 1\n",
    "        embedding_dim -= reserved_dim\n",
    "    else:\n",
    "        reserved_dim = 0  # will not be used\n",
    "\n",
    "    y = torch.zeros(x.shape + (embedding_dim,), device=x.device)\n",
    "    # Build the frequency terms on the device of `x`; CPU terms fail to combine with a non-CPU `x`.\n",
    "    even_indices = torch.arange(0, embedding_dim, 2, device=x.device)\n",
    "    log_term = torch.log(torch.tensor(domain, device=x.device)) / embedding_dim\n",
    "    div_term = torch.exp(even_indices * -log_term)\n",
    "    x = x - min_value\n",
    "    values = x.unsqueeze(-1).float() * div_term\n",
    "    y[..., 0::2] = torch.sin(values)\n",
    "    y[..., 1::2] = torch.cos(values)\n",
    "\n",
    "    if use_l2_norm:\n",
    "        # Create the constant part explicitly: slicing it out of `y` yields too few columns\n",
    "        # whenever the sinusoidal part is narrower than the reserved part.\n",
    "        ones = torch.ones(x.shape + (reserved_dim,), dtype=y.dtype, device=y.device)\n",
    "        y = torch.cat([y, ones], dim=-1)\n",
    "        y /= y.norm(dim=-1, keepdim=True, p=2)\n",
    "\n",
    "    if norm_const is not None:\n",
    "        y *= norm_const\n",
    "\n",
    "    return y\n",
    "\n",
    "def binary_encode(\n",
    "    x: Tensor,\n",
    "    embedding_dim: int,\n",
    "    min_value: int | float,\n",
    "    max_value: int | float,\n",
    "    use_l2_norm: bool = False,\n",
    "    norm_const: float | None = None,\n",
    ") -> Tensor:\n",
    "    \"\"\"\n",
    "    Encodes a tensor of numbers as the binary digits of `x - min_value`.\n",
    "\n",
    "    The least significant bit is written to the last bit position and more significant bits to\n",
    "    the positions before it. Bits that do not fit into the available positions are dropped.\n",
    "\n",
    "    Args:\n",
    "        x: Numbers to encode, of any shape. The encoding is added as a new last dimension.\n",
    "        embedding_dim: Size of the new last dimension.\n",
    "        min_value: Offset that is subtracted from `x` before encoding.\n",
    "        max_value: Not used by the encoding; kept so that the signature matches\n",
    "            `sinusoidal_encode`.\n",
    "        use_l2_norm: If true, the last position holds a constant one (leaving\n",
    "            `embedding_dim - 1` bit positions) and every vector is scaled to unit L2 norm.\n",
    "        norm_const: Optional factor by which the final encoding is multiplied.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape `x.shape + (embedding_dim,)` on the device of `x`.\n",
    "    \"\"\"\n",
    "    reserve_dim = 1 if use_l2_norm else 0\n",
    "    # Allocate only the bit positions; the reserved constant is appended below, so the result\n",
    "    # is `embedding_dim` wide in both modes (as in `sinusoidal_encode`).\n",
    "    num_bits = embedding_dim - reserve_dim\n",
    "    y = torch.zeros(x.shape + (num_bits,), device=x.device)\n",
    "    x = x - min_value\n",
    "    if x.numel() > 0:  # `max()` is undefined for an empty tensor\n",
    "        maximum = x.max()\n",
    "        for i in range(num_bits):\n",
    "            coeff = 2**i\n",
    "            if maximum < coeff:\n",
    "                # no value has this or any higher bit set\n",
    "                break\n",
    "            y[..., -i - 1] = torch.floor(x / coeff) % 2\n",
    "            x = x - coeff * y[..., -i - 1]\n",
    "    if use_l2_norm:\n",
    "        ones = torch.ones(x.shape + (reserve_dim,), dtype=y.dtype, device=y.device)\n",
    "        y = torch.cat([y, ones], dim=-1)\n",
    "        y /= y.norm(dim=-1, keepdim=True, p=2)\n",
    "    if norm_const is not None:\n",
    "        y *= norm_const\n",
    "    return y"
   ],
   "id": "aee6481c6481b767",
   "outputs": [],
   "execution_count": 3
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Prepare model and data",
   "id": "785b4c78f485ff64"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:32:16.791331Z",
     "start_time": "2025-09-15T19:32:10.097506Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import os\n",
    "\n",
    "model_ckpt = \"meta-llama/Llama-3.2-1B\"\n",
    "# Read the Hugging Face access token from the environment instead of hard-coding it in the notebook.\n",
    "# NOTE(review): with HF_TOKEN unset, token=None is passed; presumably transformers then falls back\n",
    "# to locally stored credentials -- confirm against the transformers documentation.\n",
    "hf_token = os.environ.get(\"HF_TOKEN\")\n",
    "model = transformers.AutoModel.from_pretrained(model_ckpt, token=hf_token)\n",
    "# The tokenizer is fetched from the same repository, so it gets the same token as the model.\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(model_ckpt, token=hf_token)\n",
    "tokenizer.add_special_tokens({'pad_token': '<|end_of_text|>'})\n",
    "# Half-precision weights on the target device, in inference mode.\n",
    "model = model.half().to(device).eval()"
   ],
   "id": "11f1d7f066a43d8",
   "outputs": [],
   "execution_count": 4
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:32:16.969290Z",
     "start_time": "2025-09-15T19:32:16.963862Z"
    }
   },
   "cell_type": "code",
   "source": [
    "all_values = torch.arange(0, 1000)\n",
    "# One uniform draw in [0, 1) per value, reproducible through the fixed seed; the draw decides the split.\n",
    "mask = torch.rand(len(all_values), generator=torch.Generator().manual_seed(0))\n",
    "train_mask = mask < 0.9  # draws below 0.9\n",
    "valid_mask = (mask >= 0.9) & (mask < 0.95)  # draws in [0.9, 0.95)\n",
    "test_mask = mask >= 0.95  # everything else\n",
    "\n",
    "train_values, valid_values, test_values = (\n",
    "    all_values[split_mask] for split_mask in (train_mask, valid_mask, test_mask)\n",
    ")"
   ],
   "id": "77893cb1a5a3c08e",
   "outputs": [],
   "execution_count": 5
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:32:17.161858Z",
     "start_time": "2025-09-15T19:32:17.155051Z"
    }
   },
   "cell_type": "code",
   "source": [
    "all_inputs = all_values.tolist()\n",
    "train_values_set, valid_values_set, test_values_set = (\n",
    "    set(split_values.tolist()) for split_values in (train_values, valid_values, test_values)\n",
    ")\n",
    "\n",
    "# Keep the numbers in the order of `all_inputs`; the sets only serve as membership tests.\n",
    "train_inputs = list(filter(train_values_set.__contains__, all_inputs))\n",
    "valid_inputs = list(filter(valid_values_set.__contains__, all_inputs))\n",
    "test_inputs = list(filter(test_values_set.__contains__, all_inputs))\n",
    "\n",
    "# sanity check: no number may appear in two splits\n",
    "assert set(train_inputs).isdisjoint(valid_inputs)\n",
    "assert set(train_inputs).isdisjoint(test_inputs)\n",
    "assert set(valid_inputs).isdisjoint(test_inputs)\n",
    "\n",
    "# Shuffle in a fixed order (train, valid, test) after seeding, so the result is reproducible.\n",
    "random.seed(0)\n",
    "random.shuffle(train_inputs)\n",
    "random.shuffle(valid_inputs)\n",
    "random.shuffle(test_inputs)\n",
    "\n",
    "# Upper bounds on the split sizes; the test split is not truncated.\n",
    "train_size = 100_000\n",
    "valid_size = 4096\n",
    "del train_inputs[train_size:]\n",
    "del valid_inputs[valid_size:]"
   ],
   "id": "b7034237fc92de8a",
   "outputs": [],
   "execution_count": 6
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:32:17.314263Z",
     "start_time": "2025-09-15T19:32:17.299817Z"
    }
   },
   "cell_type": "code",
   "source": "# Number of distinct numbers in the training split.\nlen(train_inputs)",
   "id": "b977f354e932dc29",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "888"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 7
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Constructing altered natural texts -- with all numbers from pre-defined ranges",
   "id": "665d5464c926042c"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:32:32.008774Z",
     "start_time": "2025-09-15T19:32:17.473039Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# cell loading the input texts\n",
    "import json\n",
    "from glob import glob\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "import datasets\n",
    "from git import Repo\n",
    "import os\n",
    "\n",
    "import itertools\n",
    "\n",
    "\n",
    "# Base directory that load_data prefixes to the paths of the cloned dataset repositories.\n",
    "HOME_PATH = \"./\"\n",
    "\n",
    "def _ensure_repo(url: str, directory_path: str) -> None:\n",
    "    \"\"\"Clone the git repository at `url` into `directory_path` unless that directory already exists.\"\"\"\n",
    "    if not os.path.isdir(directory_path):\n",
    "        Repo.clone_from(url, directory_path)\n",
    "\n",
    "\n",
    "def load_data(genre=\"food-1\", downsample_to=0):\n",
    "    \"\"\"\n",
    "    Load the texts of one dataset.\n",
    "\n",
    "    Args:\n",
    "        genre: Name of the dataset to load. One of 'food-1', 'food-2', 'arthmetic-1',\n",
    "            'arthmetic-2', 'arthmetic-3', 'technical-1', 'technical-2' or 'datetime-1'.\n",
    "        downsample_to: If non-zero, keep at most this many samples from the start of the data.\n",
    "\n",
    "    Returns:\n",
    "        The texts of the selected dataset.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `genre` is not one of the names listed above.\n",
    "    \"\"\"\n",
    "    if genre == 'food-1':\n",
    "        # Clone target and read path are built from the same directory so they cannot diverge.\n",
    "        directory_path = os.path.join(HOME_PATH, \"FoodRecipe-ImageCaptioning\")\n",
    "        _ensure_repo(\"https://github.com/samsatp/FoodRecipe-ImageCaptioning.git/\", directory_path)\n",
    "\n",
    "        with open(os.path.join(directory_path, \"data\", \"data_strings_local.json\"), \"r\") as fp:\n",
    "            recipes = json.load(fp)\n",
    "        # every value is a sequence of strings; join each into one text\n",
    "        data = [' '.join(d) for d in recipes.values()]\n",
    "\n",
    "    elif genre == 'food-2':\n",
    "        # NOTE(review): trust_remote_code=True runs code shipped with the dataset repository.\n",
    "        recipe_nlg = datasets.load_dataset(\"m3hrdadfi/recipe_nlg_lite\", trust_remote_code=True)\n",
    "        # columns: ['uid', 'name', 'description', 'link', 'ner', 'ingredients', 'steps']\n",
    "        data = recipe_nlg['train']['steps']\n",
    "\n",
    "    elif genre == 'arthmetic-1':\n",
    "        metamathqa = datasets.load_dataset(\"meta-math/MetaMathQA\")\n",
    "        data = metamathqa['train']['original_question']\n",
    "\n",
    "    elif genre == 'arthmetic-2':\n",
    "        drop = datasets.load_dataset(\"ucinlp/drop\")\n",
    "        # columns: ['section_id', 'query_id', 'passage', 'question', 'answers_spans']\n",
    "        data = drop['train']['passage']\n",
    "\n",
    "    elif genre == 'arthmetic-3':\n",
    "        aquarat = datasets.load_dataset(\"deepmind/aqua_rat\")\n",
    "        # columns: ['question', 'options', 'rationale', 'correct']\n",
    "        data = aquarat['train']['question']\n",
    "\n",
    "    elif genre == 'technical-1':\n",
    "        icdatta = datasets.load_dataset(\"atta00/icd10-codes\")\n",
    "        # columns: ['chapter', 'section', 'category', 'category_code', 'code', 'description']\n",
    "        data = [f\"description: {d} | code: {c}\" for d, c in zip(icdatta['train']['description'], icdatta['train']['code'])]\n",
    "\n",
    "    elif genre == 'technical-2':\n",
    "        icdcm = datasets.load_dataset(\"Gokul-waterlabs/ICD-10-CM\")\n",
    "        data = [f\"Description: {d} | code: {c}\" for d, c in zip(icdcm['train']['input'], icdcm['train']['output'])]\n",
    "\n",
    "    elif genre == 'datetime-1':\n",
    "        directory_path = os.path.join(HOME_PATH, \"TimeLineExtractionDecisionLettersCASE\")\n",
    "        _ensure_repo(\"https://github.com/irlabamsterdam/TimeLineExtractionDecisionLettersCASE.git\", directory_path)\n",
    "\n",
    "        # Imported locally: the notebook binds the global name `tqdm` to the function in one\n",
    "        # cell and to the module in another, so the global is not reliable here.\n",
    "        from tqdm import tqdm as progress_bar\n",
    "\n",
    "        data = []\n",
    "        for file in progress_bar(glob(os.path.join(directory_path, 'data', 'txt_files', 'train', '*txt'))):\n",
    "            with open(file, 'r') as fp:\n",
    "                data.append(fp.read())\n",
    "\n",
    "    else:\n",
    "        # Fail loudly: a returned error string would be measured, sliced and chained like real data.\n",
    "        raise ValueError(\n",
    "            f\"Unknown genre {genre!r}; pick one of: food-1, food-2, arthmetic-1, arthmetic-2, \"\n",
    "            \"arthmetic-3, technical-1, technical-2, datetime-1\"\n",
    "        )\n",
    "\n",
    "    print(\"Number of samples in the data loaded:\", len(data))\n",
    "    if downsample_to and len(data) > downsample_to:\n",
    "        print(\"Downsampling to %s\" % downsample_to)\n",
    "        data = data[:downsample_to]\n",
    "\n",
    "    return data\n",
    "\n",
    "# Concatenate the texts of all datasets, in this order, into one flat list.\n",
    "texts = list(\n",
    "    itertools.chain.from_iterable(\n",
    "        load_data(genre)\n",
    "        for genre in (\n",
    "            'food-1', 'food-2',\n",
    "            'arthmetic-1', 'arthmetic-2', 'arthmetic-3',\n",
    "            'technical-1', 'technical-2',\n",
    "            'datetime-1',\n",
    "        )\n",
    "    )\n",
    ")\n",
    "print(len(texts))"
   ],
   "id": "f25ec2c4beaa530f",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "719\n",
      "Number of samples in the data loaded: 719\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Repo card metadata block was not found. Setting CardData to empty.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of samples in the data loaded: 6118\n",
      "Number of samples in the data loaded: 395000\n",
      "Number of samples in the data loaded: 77400\n",
      "Number of samples in the data loaded: 97467\n",
      "Number of samples in the data loaded: 25719\n",
      "Number of samples in the data loaded: 74044\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:00<00:00, 2521.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of samples in the data loaded: 50\n",
      "676517\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "execution_count": 8
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:32:32.227953Z",
     "start_time": "2025-09-15T19:32:32.220473Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import re\n",
    "\n",
    "\n",
    "def make_str_input(all_possible_operands: list[int]) -> str:\n",
    "    \"\"\"\n",
    "    Pick a random text from `texts` and replace every run of digits in it with a number drawn\n",
    "    at random from `all_possible_operands`.\n",
    "    \"\"\"\n",
    "\n",
    "    def draw_number(_match) -> str:\n",
    "        # called once per digit run, so each run gets an independent draw\n",
    "        return str(random.choice(all_possible_operands))\n",
    "\n",
    "    template = random.choice(texts)\n",
    "    return re.sub(r\"\\d+\", draw_number, template)\n",
    "\n",
    "make_str_input(train_inputs), make_str_input(valid_inputs)"
   ],
   "id": "92cc3db4cd042de6",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('The population of Port Perry is seven times as many as the population of Wellington. The population of Port Perry is 659 more than the population of Lazy Harbor. If Wellington has a population of 699, how many people live in Port Perry and Lazy Harbor combined?',\n",
       " \"The Gnollish language consists of 338 words, ``splargh,'' ``glumph,'' and ``amr.''  In a sentence, ``splargh'' cannot come directly before ``glumph''; all other sentences are grammatically correct (including sentences with repeated words).  How many valid 545-word sentences are there in Gnollish?\")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 9
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Inference of model's hidden states",
   "id": "576cfa3c4fb69ddf"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:32:32.450125Z",
     "start_time": "2025-09-15T19:32:32.393006Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# First token id of every number in `all_inputs`.\n",
    "num_input_ids = tokenizer(\n",
    "    [str(number) for number in all_inputs],\n",
    "    add_special_tokens=False,\n",
    "    return_tensors=\"pt\",\n",
    ").input_ids[:, 0]\n",
    "batch_inputs = tokenizer(\n",
    "    'In a shower, 801 cm of rain falls. The volume of water that falls on 289.564 hectares of ground is:',\n",
    "    return_tensors=\"pt\",\n",
    ")\n",
    "# Mark the positions of the example sentence whose token is one of the number tokens.\n",
    "torch.isin(batch_inputs.input_ids, num_input_ids)"
   ],
   "id": "5d155876416af74b",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[False, False, False, False, False, False,  True, False, False, False,\n",
       "         False, False, False, False, False, False, False, False, False, False,\n",
       "          True, False,  True, False, False, False, False, False]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 10
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:32:32.665246Z",
     "start_time": "2025-09-15T19:32:32.641713Z"
    }
   },
   "cell_type": "code",
   "source": "# Display the first token id that the tokenizer produces for each number in all_inputs.\ntokenizer(list(map(str, all_inputs)), add_special_tokens=False, return_tensors=\"pt\").input_ids[:, 0]",
   "id": "f712bd258566dc9a",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([   15,    16,    17,    18,    19,    20,    21,    22,    23,    24,\n",
       "          605,   806,   717,  1032,   975,   868,   845,  1114,   972,   777,\n",
       "          508,  1691,  1313,  1419,  1187,   914,  1627,  1544,  1591,  1682,\n",
       "          966,  2148,   843,  1644,  1958,  1758,  1927,  1806,  1987,  2137,\n",
       "         1272,  3174,  2983,  3391,  2096,  1774,  2790,  2618,  2166,  2491,\n",
       "         1135,  3971,  4103,  4331,  4370,  2131,  3487,  3226,  2970,  2946,\n",
       "         1399,  5547,  5538,  5495,  1227,  2397,  2287,  3080,  2614,  3076,\n",
       "         2031,  6028,  5332,  5958,  5728,  2075,  4767,  2813,  2495,  4643,\n",
       "         1490,  5932,  6086,  6069,  5833,  5313,  4218,  4044,  2421,  4578,\n",
       "         1954,  5925,  6083,  6365,  6281,  2721,  4161,  3534,  3264,  1484,\n",
       "         1041,  4645,  4278,  6889,  6849,  6550,  7461,  7699,  6640,  7743,\n",
       "         5120,  5037,  7261,  8190,  8011,  7322,  8027,  8546,  8899,  9079,\n",
       "         4364,  7994,  8259,  4513,  8874,  6549,  9390,  6804,  4386,  9748,\n",
       "         5894,  9263,  9413,  9423,  9565,  8878,  9795, 10148, 10350, 10125,\n",
       "         6860,  9335, 10239, 10290,  8929,  9591, 10465, 10288, 10410, 10161,\n",
       "         3965,  9690,  9756,  9800, 10559,  9992, 10132, 10895, 11286, 11068,\n",
       "         6330, 10718, 10674,  9892, 10513, 10680, 11247, 11515,  8953, 11739,\n",
       "         8258, 11123, 10861, 11908, 11771, 10005, 10967, 11242, 11256, 11128,\n",
       "         5245, 10562, 10828, 10750, 10336,  9741,  9714,  9674,  9367,  9378,\n",
       "         7028,  7529,  5926,  7285,  6393,  6280,  5162,  4468,  3753,  2550,\n",
       "         1049,   679,  2366,  9639,  7854, 10866, 11056, 12060, 12171, 12652,\n",
       "         8848, 11483, 11227, 11702, 11584, 12112, 12463, 13460, 13302, 13762,\n",
       "         8610, 12425,  9716, 12533, 10697, 11057, 14057, 14206, 14261, 14378,\n",
       "         9870, 12245, 12338, 12994, 11727, 12422, 14087, 14590, 13895, 14815,\n",
       "         8273, 13341, 12754, 14052, 13719, 13078, 14205, 14125, 14185, 14735,\n",
       "         5154, 13860, 12326, 14022, 12375,  3192,  4146, 15574, 15966, 15537,\n",
       "        11387, 15602, 14274, 15666, 12815, 14374, 15999, 16567, 16332, 16955,\n",
       "        10914, 15828, 15741, 15451, 16590, 14417, 16660, 16367, 16949, 17267,\n",
       "        11209, 15282, 16544, 16085, 17058, 15935, 17361, 17897, 15287, 17212,\n",
       "        13754, 17335, 16443, 17313, 17168, 16780, 17408, 18163, 17690, 15531,\n",
       "         3101, 12405, 13121, 13236, 12166, 13364, 12879, 14777, 14498, 15500,\n",
       "        12226, 15134, 13384, 15231, 16104, 15189, 15340, 16718, 17592, 16874,\n",
       "         9588, 14423, 15805, 15726, 16723, 15257, 17470, 13817, 16884, 18196,\n",
       "        10568, 16707, 17079,  8765, 17153, 16596, 17014, 17609, 18633, 17887,\n",
       "        13679, 16546, 17590, 16522, 17451, 12901, 18061, 17678, 19746, 18634,\n",
       "         8652, 18113, 16482, 17228, 18384, 17306, 18349, 18520, 17112, 19192,\n",
       "         6843, 18277, 18509, 18199, 15951, 12676, 18044, 18775, 19057, 19929,\n",
       "        14648, 18650, 17662, 18017, 18265, 12935, 18322, 10898, 19166, 19867,\n",
       "        13897, 19162, 18781, 19230, 12910, 18695, 16481, 20062, 19081, 20422,\n",
       "        15515, 19631, 19695, 18252, 20077, 19498, 19615, 20698, 19838, 18572,\n",
       "         3443, 10841, 16496, 13074,  7507, 16408, 17264, 18501, 18058, 12378,\n",
       "        14487, 17337, 17574, 19288, 17448, 18136, 17763, 19561, 19770, 19391,\n",
       "        12819, 18245, 16460, 19711, 18517, 17837, 20363, 20465, 19140, 16371,\n",
       "        14245, 19852, 16739, 20153, 20165, 19305, 21299, 18318, 20596, 20963,\n",
       "        14868, 18495, 20502, 17147, 14870, 19697, 20385, 20800, 19956, 21125,\n",
       "        10617, 20360, 21098, 20235, 20555, 20325, 10961, 21675, 21209, 22094,\n",
       "        16551, 19608, 20911, 21290, 21033, 19988, 21404, 20419, 20304, 21330,\n",
       "        17711, 20617, 21757, 21505, 21358, 19799, 22191, 21144, 22086, 21848,\n",
       "        11738, 21235, 21984, 21884, 20339, 19773, 21511, 22184, 21310, 22418,\n",
       "        18518, 21824, 21776, 22741, 22054, 21038, 19447, 22640, 21962, 18162,\n",
       "         2636, 14408, 17824, 17735, 18048, 17786, 19673, 20068, 19869, 12448,\n",
       "        15633, 18625,  8358, 21164, 20998, 19633, 20571, 22507, 21312, 21851,\n",
       "        15830, 20767, 20936, 21123, 21177, 18415, 22593, 22369, 21458, 21618,\n",
       "        17252, 20823, 20711, 21876, 22467, 20618, 21600, 19038, 22600, 23033,\n",
       "        17048, 22058, 21791, 19642, 21239, 20749, 22048, 23215, 22287, 22782,\n",
       "        13506, 21860, 21478, 22663, 22303, 14148, 20866, 23906, 22895, 22424,\n",
       "        17698, 20460, 19242, 21789, 22210, 20943, 23477, 19282, 22049, 23642,\n",
       "        18712, 22005, 22468, 22529, 23402, 21228, 20758, 23411, 22915, 24847,\n",
       "        18216, 23864, 23670, 23493, 23816, 21535, 22345, 22159, 20691, 22905,\n",
       "        20615, 24380, 20128, 22608, 23428, 22754, 24515, 24574, 21856, 21944,\n",
       "         5067, 18262, 20224, 21006, 20354, 19666, 20213, 21996, 19944, 21138,\n",
       "        17608, 20973, 21018, 22922, 22638, 21385, 21379, 21717, 21985, 23388,\n",
       "        17416, 22488, 19808, 22801, 23000, 15894, 22385, 23103, 23574, 24239,\n",
       "        18660, 21729, 20775, 23736, 24307, 22276, 22422, 21788, 24495, 23079,\n",
       "        14033, 23525, 22266, 22956, 21975, 22926, 22642, 22644, 23802, 24734,\n",
       "        13655, 23409, 23181, 21598, 21969, 15573, 20744, 23480, 23654, 25090,\n",
       "        19274, 24132, 24199, 24491, 23888, 23467, 10943, 19774, 24427, 25289,\n",
       "        21218, 23403, 22768, 24938, 25513, 21129, 24187, 24375, 17458, 25136,\n",
       "        17814, 25091, 25178, 24887, 24313, 23717, 22347, 21897, 23292, 25458,\n",
       "        21741, 25168, 25073, 25298, 25392, 24394, 23578, 25388, 25169, 23459,\n",
       "         7007, 19597, 20253, 20436, 21949, 21469, 22457, 18770, 21295, 22874,\n",
       "        19027, 22375, 22708, 22977, 23193, 22744, 23929, 25150, 21982, 24758,\n",
       "        13104, 20873, 23024, 24388, 24735, 23309, 24430, 23486, 24054, 22194,\n",
       "        20785, 24626, 24289, 24865, 24438, 24939, 23969, 22039, 25527, 25809,\n",
       "        21112, 25021, 25560, 26260, 23800, 23901, 25594, 23619, 20338, 25541,\n",
       "        11711, 23986, 23644, 25504, 23952, 23532, 24456, 23776, 25302, 26439,\n",
       "        19104, 25110, 24376, 26083, 24402, 22240, 25358, 23275, 17521, 24619,\n",
       "        20772, 24876, 23624, 23267, 24472, 22908, 23823, 15831, 23592, 25659,\n",
       "        19423, 21893, 23833, 26008, 22148, 22539, 25251, 23171, 24216, 16474,\n",
       "        22876, 26234, 24763, 24531, 25926, 25808, 24832, 25314, 26519, 23987,\n",
       "         4728, 17973, 13135, 20899, 20417, 21032, 22397, 23178, 11770, 21474,\n",
       "        19232, 22588, 19270, 24288, 25498, 23582, 23713, 25528, 23141, 18831,\n",
       "        18248, 23282, 23105, 23848, 25016, 22091, 23038, 24920, 22716, 26218,\n",
       "        21221, 25009, 23879, 22904, 26223, 23424, 25192, 26244, 24250, 25465,\n",
       "        19899, 25496, 25377, 23996, 24344, 24650, 26563, 25125, 24951, 26537,\n",
       "        16217, 24866, 24571, 25724, 25515, 22869, 25505, 20907, 23805, 24061,\n",
       "        18670, 24963, 24071, 26051, 19355, 24678, 22455, 26013, 25862, 26497,\n",
       "        22440, 25665, 25303, 25747, 25822, 17419, 24870, 23873, 25890, 25622,\n",
       "        19272, 25339, 23213, 24902, 25962, 19445, 25399, 26058, 12251, 25354,\n",
       "        21381, 24962, 24110, 26088, 26227, 25238, 24542, 24777, 24809, 22889,\n",
       "         7467, 19319, 21026, 23305, 22777, 22393, 22224, 23505, 23629, 21278,\n",
       "        21056, 17000, 22750, 24331, 24579, 22387, 24487, 24391, 25828, 24337,\n",
       "        18485, 22536, 20275, 22614, 23890, 21910, 26026, 26437, 25001, 25344,\n",
       "        19306, 25717, 25401, 25806, 24347, 26970, 25612, 21936, 25454, 26164,\n",
       "        21251, 21322, 20249, 26576, 25687, 24599, 26491, 26511, 26979, 24680,\n",
       "        15862, 24989, 24597, 25326, 25741, 25875, 26067, 27341, 27079, 26328,\n",
       "        16415, 26114, 26366, 26087, 26281, 24837, 25285, 27134, 23386, 24792,\n",
       "        21133, 25693, 24425, 24471, 26007, 24609, 25208, 26409, 17272, 25476,\n",
       "        19068, 25643, 25873, 24742, 23812, 24961, 27468, 22207, 24538, 25350,\n",
       "        19146, 24606, 22992, 24242, 22897, 22101, 23031, 22694, 19416,  5500])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 11
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:32:32.851735Z",
     "start_time": "2025-09-15T19:32:32.843461Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import tqdm\n",
    "\n",
    "def get_hidden_states(model, str_inputs: list[str], batch_size: int) -> tuple[dict[int, Tensor], Tensor]:\n",
    "    \"\"\"\n",
    "    Run `model` over `str_inputs` and collect the hidden states at every number-token position.\n",
    "\n",
    "    A position counts as a number token when its token id equals the first token id of one of\n",
    "    the numbers in the global `all_inputs`. Texts are tokenized with the global `tokenizer`.\n",
    "\n",
    "    Args:\n",
    "        model: Model that is called with the tokenizer output and `output_hidden_states=True`\n",
    "            and whose result exposes `.hidden_states`; it must provide a `.device` attribute.\n",
    "        str_inputs: Texts to run through the model.\n",
    "        batch_size: Number of texts per forward pass.\n",
    "\n",
    "    Returns:\n",
    "        A pair `(hidden_states, labels)`. `hidden_states` maps the index of each entry of the\n",
    "        model's `hidden_states` output to a CPU tensor with one row per number token found.\n",
    "        `labels` holds the decoded integer value of each of those tokens, in the same order, on\n",
    "        the global `device`.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    num_input_ids = tokenizer(list(map(str, all_inputs)), add_special_tokens=False, return_tensors=\"pt\").input_ids[:, 0]\n",
    "\n",
    "    nums: list[str] = []\n",
    "    hidden_states = collections.defaultdict(list)\n",
    "    num_batches = (len(str_inputs) + batch_size - 1) // batch_size\n",
    "    with torch.no_grad():\n",
    "        for batch_str in tqdm.auto.tqdm(itertools.batched(str_inputs, n=batch_size), total=num_batches):\n",
    "            batch_inputs = tokenizer(batch_str, return_tensors=\"pt\", padding=True, truncation=True)\n",
    "            # Keep a handle to the ids as tokenized, independent of the device transfer below.\n",
    "            input_ids = batch_inputs.input_ids\n",
    "            num_pos = torch.isin(input_ids, num_input_ids)\n",
    "            hidden_reprs = model(**batch_inputs.to(model.device), output_hidden_states=True).hidden_states\n",
    "            for layer_idx, hidden_state in enumerate(hidden_reprs):\n",
    "                # Store one tensor per batch rather than one per token; with many tokens and\n",
    "                # layers, per-token tensor objects dominate time and memory.\n",
    "                hidden_states[layer_idx].append(hidden_state[num_pos].detach().cpu())\n",
    "            nums.extend(tokenizer.batch_decode(input_ids[num_pos]))\n",
    "\n",
    "    # `cat` also handles the case in which no number token was found (zero rows per layer).\n",
    "    stacked = {layer_idx: torch.cat(chunks) for layer_idx, chunks in hidden_states.items()}\n",
    "    # Explicit dtype keeps the labels integral even when `nums` is empty.\n",
    "    labels = torch.tensor(list(map(int, nums)), dtype=torch.long, device=device)\n",
    "    return stacked, labels"
   ],
   "id": "2a7460caa4e09ba7",
   "outputs": [],
   "execution_count": 12
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:55:20.669793Z",
     "start_time": "2025-09-15T19:32:32.989904Z"
    }
   },
   "cell_type": "code",
   "source": [
    "batch_size = 16\n",
    "\n",
    "# Build the texts of each split. make_str_input draws from the global `random` state, so the\n",
    "# order of these three lines determines which texts and numbers every split receives.\n",
    "train_input_texts = [make_str_input(train_inputs) for _ in range(train_size)]\n",
    "valid_input_texts = [make_str_input(valid_inputs) for _ in range(valid_size)]\n",
    "# note: the test split is generated with `valid_size` texts as well\n",
    "test_input_texts = [make_str_input(test_inputs) for _ in range(valid_size)]\n",
    "\n",
    "# Run the model on every split; each assert checks that there is one label per hidden-state row.\n",
    "train_hidden_states, train_labels = get_hidden_states(model, train_input_texts, batch_size)\n",
    "assert train_hidden_states[0].shape[0] == len(train_labels)\n",
    "\n",
    "valid_hidden_states, valid_labels = get_hidden_states(model, valid_input_texts, batch_size)\n",
    "assert valid_hidden_states[0].shape[0] == len(valid_labels)\n",
    "\n",
    "test_hidden_states, test_labels = get_hidden_states(model, test_input_texts, batch_size)\n",
    "assert test_hidden_states[0].shape[0] == len(test_labels)\n",
    "\n",
    "# hidden_state, new_nums = get_hidden_states(model, train_input_texts, batch_size)\n"
   ],
   "id": "831000bb04b5ffd7",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "  0%|          | 0/6250 [00:00<?, ?it/s]"
      ],
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "da27c023249b4502a84a6ad0b4fa0e1b"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "  0%|          | 0/256 [00:00<?, ?it/s]"
      ],
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "5644af3e785a4a039824b57481b183f8"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "  0%|          | 0/256 [00:00<?, ?it/s]"
      ],
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "79f32c9be26940ba82188cec83410fbd"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "execution_count": 13
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Probing",
   "id": "bfa6b454ea64b77f"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:55:20.867917Z",
     "start_time": "2025-09-15T19:55:20.856114Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Sinusoidal basis: one row per number 0..999, as wide as the model's hidden states.\n",
    "basis_embs_sin = sinusoidal_encode(\n",
    "    x=torch.arange(1000),\n",
    "    embedding_dim=train_hidden_states[0].shape[-1],\n",
    "    min_value=0,\n",
    "    max_value=1000,\n",
    ")\n",
    "\n",
    "# Binary basis: 10 bits are enough for every number below 1000.\n",
    "basis_embs_bin = binary_encode(\n",
    "    x=torch.arange(1000),\n",
    "    embedding_dim=10,\n",
    "    min_value=0,\n",
    "    max_value=1000,\n",
    ")"
   ],
   "id": "1788ca655c6b2f13",
   "outputs": [],
   "execution_count": 14
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T19:55:21.055965Z",
     "start_time": "2025-09-15T19:55:21.044910Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class ClassifierProbe(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Probe that scores an embedding against every row of a fixed basis.\n",
    "\n",
    "    The embedding and each basis row are linearly projected into a shared latent space; the\n",
    "    logit of a class is the dot product of the two projections.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, emb_dim: int, hidden_dim: int, basis: torch.Tensor, heldout_mask: torch.Tensor):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            emb_dim: Size of the last dimension of the embeddings passed to `forward`.\n",
    "            hidden_dim: Size of the shared latent space.\n",
    "            basis: Tensor with one row per class; registered as a buffer.\n",
    "            heldout_mask: Selects, along the class dimension, the classes whose logits are\n",
    "                set to -inf when `forward` is called with `holdout_eval_tokens=True`;\n",
    "                registered as a buffer.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.emb_to_latent = torch.nn.Linear(emb_dim, hidden_dim, bias=True)\n",
    "        self.basis_to_latent = torch.nn.Linear(basis.shape[-1], hidden_dim, bias=True)\n",
    "        self.basis: Tensor\n",
    "        self.heldout_mask: Tensor\n",
    "        self.register_buffer(\"basis\", basis)\n",
    "        self.register_buffer(\"heldout_mask\", heldout_mask)\n",
    "\n",
    "    def forward(self, x: Tensor, holdout_eval_tokens: bool) -> Tensor:\n",
    "        \"\"\"\n",
    "        Compute the class logits for `x`.\n",
    "\n",
    "        Args:\n",
    "            x: Embeddings whose last dimension has size `emb_dim`; any leading dimensions.\n",
    "            holdout_eval_tokens: If true, the logits of the held-out classes are set to -inf.\n",
    "\n",
    "        Returns:\n",
    "            Logits with the leading dimensions of `x` and one entry per basis row.\n",
    "        \"\"\"\n",
    "        latent_x = self.emb_to_latent(x)\n",
    "        # during training, model learns to choose among only training tokens\n",
    "        # but during eval, model must choose among all tokens\n",
    "        # this means that the model is never exposed to the eval tokens during training\n",
    "        latent_choices = self.basis_to_latent(self.basis)\n",
    "        logits = latent_x @ latent_choices.T\n",
    "        if holdout_eval_tokens:\n",
    "            # index the class (last) dimension so that inputs of any rank are supported\n",
    "            logits[..., self.heldout_mask] = float(\"-inf\")\n",
    "        return logits"
   ],
   "id": "37cff5a15e9a866f",
   "outputs": [],
   "execution_count": 15
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T20:17:54.785769Z",
     "start_time": "2025-09-15T19:55:21.229846Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Train one ClassifierProbe per (basis, layer) pair and record its test accuracy.\n",
    "# `train_labels`, `valid_labels` and `test_labels` are not defined in this cell and must\n",
    "# already exist from an earlier cell.\n",
    "\n",
    "NUM_STEPS = 30_000  # optimisation steps per probe (the loop runs NUM_STEPS + 1 iterations)\n",
    "EVAL_EVERY = 500    # validate and log every this many steps\n",
    "BATCH_SIZE = 1024   # training minibatch size (sampled with replacement)\n",
    "# Weight of the L1 penalty over all probe parameters. The penalty is applied exactly once\n",
    "# per step; use 2e-3 to reproduce runs in which the same term was added to the loss twice.\n",
    "L1_COEF = 1e-3\n",
    "\n",
    "test_accuracies = {\"sin\": {}, \"bin\": {}, \"lin\": {}, \"log\": {}}\n",
    "\n",
    "for basis_name, basis_embs in {\n",
    "    \"sin\": basis_embs_sin,\n",
    "    \"bin\": basis_embs_bin,\n",
    "}.items():\n",
    "    for layer_idx in range(len(train_hidden_states)):\n",
    "        # Same seed for every run so probes differ only by basis and layer.\n",
    "        torch.manual_seed(0)\n",
    "        probe = ClassifierProbe(\n",
    "            emb_dim=train_hidden_states[0].shape[-1],\n",
    "            hidden_dim=100,\n",
    "            basis=basis_embs,\n",
    "            heldout_mask=test_mask,\n",
    "        ).to(device)\n",
    "\n",
    "        optimizer = torch.optim.Adam(probe.parameters(), lr=1e-4, weight_decay=0)\n",
    "\n",
    "        # Dedicated generator so the minibatch sequence is identical across runs.\n",
    "        rng = torch.Generator().manual_seed(0)\n",
    "        # Loop-invariant: move this layer's validation inputs to the device once.\n",
    "        valid_x = valid_hidden_states[layer_idx].float().to(device)\n",
    "        best_val_acc = -1\n",
    "        best_ckpt = None\n",
    "        for i in range(NUM_STEPS + 1):\n",
    "            probe.train()\n",
    "            optimizer.zero_grad()\n",
    "            minibatch_idcs = torch.randint(len(train_labels), size=(BATCH_SIZE,), generator=rng)\n",
    "            x = train_hidden_states[layer_idx][minibatch_idcs].float().to(device)\n",
    "            y = train_labels[minibatch_idcs].to(device)\n",
    "            # Training masks the held-out classes; validation/test below score all classes.\n",
    "            logits = probe(x, holdout_eval_tokens=True)\n",
    "            l1_penalty = sum(p.abs().sum() for p in probe.parameters())\n",
    "            loss = torch.nn.functional.cross_entropy(logits, y) + L1_COEF * l1_penalty\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            if i % EVAL_EVERY == 0:\n",
    "                train_acc = (logits.argmax(dim=-1) == y).float().mean().item()\n",
    "                probe.eval()\n",
    "                with torch.no_grad():\n",
    "                    valid_logits = probe(valid_x, holdout_eval_tokens=False)\n",
    "                    valid_loss = torch.nn.functional.cross_entropy(valid_logits, valid_labels)\n",
    "                    valid_accuracy = (valid_logits.argmax(dim=-1) == valid_labels).float().mean().item()\n",
    "                    if valid_accuracy > best_val_acc:\n",
    "                        best_val_acc = valid_accuracy\n",
    "                        # state_dict() returns references to the live tensors, which later\n",
    "                        # optimizer steps overwrite in place; clone to get a real snapshot.\n",
    "                        best_ckpt = {\n",
    "                            name: tensor.detach().clone()\n",
    "                            for name, tensor in probe.state_dict().items()\n",
    "                        }\n",
    "                print(f\"{basis_name} {i=:>5} train loss: {loss.item():5.2f}  train acc: {train_acc:.2f}  val loss: {valid_loss.item():5.2f}  valid acc: {valid_accuracy:.2f}\")\n",
    "        # Evaluate the checkpoint with the best validation accuracy, not the final weights.\n",
    "        probe.load_state_dict(best_ckpt)\n",
    "        probe.eval()\n",
    "        with torch.no_grad():\n",
    "            test_logits = probe(test_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)\n",
    "            test_accuracy = (test_logits.argmax(dim=-1) == test_labels).float().mean().item()\n",
    "        test_accuracies[basis_name][layer_idx] = test_accuracy\n",
    "        print(f\"->  {basis_name}  layer idx: {layer_idx:<3}, best valid accuracy: {best_val_acc:.2f}, test accuracy: {test_accuracy:.2f}\")"
   ],
   "id": "d248aea218155782",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sin i=    0 train loss: 11.38  train acc: 0.00  val loss:  6.85  valid acc: 0.00\n",
      "sin i=  500 train loss:  2.41  train acc: 0.98  val loss:  1.48  valid acc: 0.72\n",
      "sin i= 1000 train loss:  1.94  train acc: 1.00  val loss:  1.02  valid acc: 0.84\n",
      "sin i= 1500 train loss:  1.74  train acc: 1.00  val loss:  0.88  valid acc: 0.88\n",
      "sin i= 2000 train loss:  1.62  train acc: 1.00  val loss:  0.82  valid acc: 0.84\n",
      "sin i= 2500 train loss:  1.53  train acc: 1.00  val loss:  0.77  valid acc: 0.83\n",
      "sin i= 3000 train loss:  1.45  train acc: 1.00  val loss:  0.73  valid acc: 0.83\n",
      "sin i= 3500 train loss:  1.39  train acc: 1.00  val loss:  0.69  valid acc: 0.84\n",
      "sin i= 4000 train loss:  1.35  train acc: 1.00  val loss:  0.66  valid acc: 0.86\n",
      "sin i= 4500 train loss:  1.30  train acc: 1.00  val loss:  0.64  valid acc: 0.88\n",
      "sin i= 5000 train loss:  1.27  train acc: 1.00  val loss:  0.63  valid acc: 0.90\n",
      "sin i= 5500 train loss:  1.24  train acc: 1.00  val loss:  0.61  valid acc: 0.88\n",
      "sin i= 6000 train loss:  1.20  train acc: 1.00  val loss:  0.60  valid acc: 0.91\n",
      "sin i= 6500 train loss:  1.17  train acc: 1.00  val loss:  0.59  valid acc: 0.88\n",
      "sin i= 7000 train loss:  1.14  train acc: 1.00  val loss:  0.59  valid acc: 0.90\n",
      "sin i= 7500 train loss:  1.12  train acc: 1.00  val loss:  0.57  valid acc: 0.88\n",
      "sin i= 8000 train loss:  1.10  train acc: 1.00  val loss:  0.58  valid acc: 0.90\n",
      "sin i= 8500 train loss:  1.08  train acc: 1.00  val loss:  0.58  valid acc: 0.88\n",
      "sin i= 9000 train loss:  1.07  train acc: 1.00  val loss:  0.57  valid acc: 0.86\n",
      "sin i= 9500 train loss:  1.05  train acc: 1.00  val loss:  0.57  valid acc: 0.86\n",
      "sin i=10000 train loss:  1.04  train acc: 1.00  val loss:  0.56  valid acc: 0.88\n",
      "->  sin  layer idx: 0  , best valid accuracy: 0.91, test accuracy: 0.89\n",
      "sin i=    0 train loss: 11.38  train acc: 0.00  val loss:  6.63  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.48  train acc: 1.00  val loss:  0.52  valid acc: 1.00\n",
      "sin i= 1000 train loss:  1.14  train acc: 1.00  val loss:  0.30  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.96  train acc: 1.00  val loss:  0.23  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.84  train acc: 1.00  val loss:  0.18  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.75  train acc: 1.00  val loss:  0.15  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.68  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.64  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.60  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.58  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.56  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.55  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.53  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.52  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.51  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.50  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.50  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.49  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.48  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.49  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.48  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "->  sin  layer idx: 1  , best valid accuracy: 1.00, test accuracy: 0.98\n",
      "sin i=    0 train loss: 11.38  train acc: 0.00  val loss:  6.60  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.43  train acc: 1.00  val loss:  0.49  valid acc: 1.00\n",
      "sin i= 1000 train loss:  1.12  train acc: 1.00  val loss:  0.29  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.95  train acc: 1.00  val loss:  0.21  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.83  train acc: 1.00  val loss:  0.18  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.75  train acc: 1.00  val loss:  0.16  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.68  train acc: 1.00  val loss:  0.14  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.64  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.62  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.60  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.58  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.57  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.56  train acc: 1.00  val loss:  0.11  valid acc: 0.99\n",
      "sin i= 6500 train loss:  0.55  train acc: 1.00  val loss:  0.11  valid acc: 0.99\n",
      "sin i= 7000 train loss:  0.53  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.53  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.53  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.52  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.51  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.51  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.51  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "->  sin  layer idx: 2  , best valid accuracy: 1.00, test accuracy: 0.98\n",
      "sin i=    0 train loss: 11.38  train acc: 0.00  val loss:  6.59  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.40  train acc: 1.00  val loss:  0.48  valid acc: 1.00\n",
      "sin i= 1000 train loss:  1.09  train acc: 1.00  val loss:  0.29  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.91  train acc: 1.00  val loss:  0.21  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.81  train acc: 1.00  val loss:  0.17  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.73  train acc: 1.00  val loss:  0.15  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.67  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.63  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.61  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.59  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.57  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.56  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.55  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.53  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.52  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.51  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.50  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.49  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.49  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.49  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.48  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "->  sin  layer idx: 3  , best valid accuracy: 1.00, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.38  train acc: 0.00  val loss:  6.55  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.29  train acc: 1.00  val loss:  0.43  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.99  train acc: 1.00  val loss:  0.25  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.83  train acc: 1.00  val loss:  0.20  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.74  train acc: 1.00  val loss:  0.16  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.67  train acc: 1.00  val loss:  0.15  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.62  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.58  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.56  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.53  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.51  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.50  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.48  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.47  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.47  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.46  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.46  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.45  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.44  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.44  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.43  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "->  sin  layer idx: 4  , best valid accuracy: 1.00, test accuracy: 1.00\n",
      "sin i=    0 train loss: 11.39  train acc: 0.00  val loss:  6.52  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.27  train acc: 1.00  val loss:  0.43  valid acc: 0.99\n",
      "sin i= 1000 train loss:  1.00  train acc: 1.00  val loss:  0.26  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.83  train acc: 1.00  val loss:  0.19  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.74  train acc: 1.00  val loss:  0.17  valid acc: 0.99\n",
      "sin i= 2500 train loss:  0.66  train acc: 1.00  val loss:  0.14  valid acc: 0.99\n",
      "sin i= 3000 train loss:  0.61  train acc: 1.00  val loss:  0.13  valid acc: 0.99\n",
      "sin i= 3500 train loss:  0.57  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.55  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.52  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.50  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.48  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.47  train acc: 1.00  val loss:  0.10  valid acc: 0.99\n",
      "sin i= 6500 train loss:  0.46  train acc: 1.00  val loss:  0.10  valid acc: 0.99\n",
      "sin i= 7000 train loss:  0.45  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.45  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 8000 train loss:  0.44  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.44  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.44  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.43  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.43  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "->  sin  layer idx: 5  , best valid accuracy: 1.00, test accuracy: 1.00\n",
      "sin i=    0 train loss: 11.40  train acc: 0.00  val loss:  6.65  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.29  train acc: 1.00  val loss:  0.42  valid acc: 1.00\n",
      "sin i= 1000 train loss:  1.01  train acc: 1.00  val loss:  0.27  valid acc: 0.99\n",
      "sin i= 1500 train loss:  0.85  train acc: 1.00  val loss:  0.20  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.75  train acc: 1.00  val loss:  0.17  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.67  train acc: 1.00  val loss:  0.15  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.62  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.59  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.56  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.54  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.52  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.51  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.50  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.48  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.48  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.47  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.47  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 8500 train loss:  0.47  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.46  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.46  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.45  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "->  sin  layer idx: 6  , best valid accuracy: 1.00, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.40  train acc: 0.00  val loss:  6.69  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.28  train acc: 1.00  val loss:  0.43  valid acc: 1.00\n",
      "sin i= 1000 train loss:  1.01  train acc: 1.00  val loss:  0.26  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.85  train acc: 1.00  val loss:  0.20  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.76  train acc: 1.00  val loss:  0.17  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.69  train acc: 1.00  val loss:  0.15  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.64  train acc: 1.00  val loss:  0.14  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.60  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.57  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.55  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.53  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.51  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.49  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.48  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.47  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.46  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.46  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 8500 train loss:  0.46  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.46  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 9500 train loss:  0.45  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i=10000 train loss:  0.45  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "->  sin  layer idx: 7  , best valid accuracy: 1.00, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.41  train acc: 0.00  val loss:  6.74  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.25  train acc: 1.00  val loss:  0.41  valid acc: 0.99\n",
      "sin i= 1000 train loss:  0.97  train acc: 1.00  val loss:  0.25  valid acc: 0.99\n",
      "sin i= 1500 train loss:  0.82  train acc: 1.00  val loss:  0.19  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.74  train acc: 1.00  val loss:  0.16  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.67  train acc: 1.00  val loss:  0.14  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.63  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.59  train acc: 1.00  val loss:  0.12  valid acc: 0.99\n",
      "sin i= 4000 train loss:  0.57  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.55  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.53  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.51  train acc: 1.00  val loss:  0.10  valid acc: 0.99\n",
      "sin i= 6000 train loss:  0.50  train acc: 1.00  val loss:  0.10  valid acc: 0.99\n",
      "sin i= 6500 train loss:  0.49  train acc: 1.00  val loss:  0.10  valid acc: 0.99\n",
      "sin i= 7000 train loss:  0.47  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.46  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.45  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.44  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.43  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.43  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.42  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "->  sin  layer idx: 8  , best valid accuracy: 1.00, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.42  train acc: 0.00  val loss:  6.65  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.21  train acc: 1.00  val loss:  0.38  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.95  train acc: 1.00  val loss:  0.23  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.80  train acc: 1.00  val loss:  0.17  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.71  train acc: 1.00  val loss:  0.14  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.64  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.60  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.57  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.54  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.53  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.51  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.50  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.48  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.47  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.45  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.44  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.44  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 8500 train loss:  0.43  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.44  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.43  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.43  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "->  sin  layer idx: 9  , best valid accuracy: 1.00, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.45  train acc: 0.00  val loss:  6.60  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.15  train acc: 1.00  val loss:  0.35  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.90  train acc: 1.00  val loss:  0.22  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.76  train acc: 1.00  val loss:  0.16  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.68  train acc: 1.00  val loss:  0.14  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.62  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.58  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.55  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.52  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.49  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.47  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.45  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.44  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.42  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.41  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.40  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.40  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.39  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.39  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.39  train acc: 0.99  val loss:  0.07  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.38  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "->  sin  layer idx: 10 , best valid accuracy: 1.00, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.49  train acc: 0.00  val loss:  6.62  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.05  train acc: 1.00  val loss:  0.31  valid acc: 0.99\n",
      "sin i= 1000 train loss:  0.83  train acc: 1.00  val loss:  0.20  valid acc: 0.99\n",
      "sin i= 1500 train loss:  0.70  train acc: 1.00  val loss:  0.16  valid acc: 0.99\n",
      "sin i= 2000 train loss:  0.63  train acc: 1.00  val loss:  0.13  valid acc: 0.99\n",
      "sin i= 2500 train loss:  0.58  train acc: 1.00  val loss:  0.12  valid acc: 0.99\n",
      "sin i= 3000 train loss:  0.54  train acc: 1.00  val loss:  0.11  valid acc: 0.99\n",
      "sin i= 3500 train loss:  0.51  train acc: 0.99  val loss:  0.11  valid acc: 0.99\n",
      "sin i= 4000 train loss:  0.48  train acc: 1.00  val loss:  0.10  valid acc: 0.99\n",
      "sin i= 4500 train loss:  0.47  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 5000 train loss:  0.45  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 5500 train loss:  0.44  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 6000 train loss:  0.43  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 6500 train loss:  0.41  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 7000 train loss:  0.40  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 7500 train loss:  0.39  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 8000 train loss:  0.38  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 8500 train loss:  0.38  train acc: 1.00  val loss:  0.07  valid acc: 0.99\n",
      "sin i= 9000 train loss:  0.38  train acc: 1.00  val loss:  0.07  valid acc: 0.99\n",
      "sin i= 9500 train loss:  0.38  train acc: 1.00  val loss:  0.07  valid acc: 0.99\n",
      "sin i=10000 train loss:  0.37  train acc: 1.00  val loss:  0.06  valid acc: 0.99\n",
      "->  sin  layer idx: 11 , best valid accuracy: 0.99, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.51  train acc: 0.00  val loss:  6.70  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.01  train acc: 1.00  val loss:  0.30  valid acc: 0.99\n",
      "sin i= 1000 train loss:  0.79  train acc: 0.99  val loss:  0.20  valid acc: 0.99\n",
      "sin i= 1500 train loss:  0.67  train acc: 1.00  val loss:  0.15  valid acc: 0.99\n",
      "sin i= 2000 train loss:  0.59  train acc: 0.99  val loss:  0.13  valid acc: 0.99\n",
      "sin i= 2500 train loss:  0.54  train acc: 0.99  val loss:  0.12  valid acc: 0.99\n",
      "sin i= 3000 train loss:  0.49  train acc: 1.00  val loss:  0.11  valid acc: 0.99\n",
      "sin i= 3500 train loss:  0.46  train acc: 1.00  val loss:  0.11  valid acc: 0.99\n",
      "sin i= 4000 train loss:  0.44  train acc: 1.00  val loss:  0.10  valid acc: 0.99\n",
      "sin i= 4500 train loss:  0.43  train acc: 1.00  val loss:  0.10  valid acc: 0.99\n",
      "sin i= 5000 train loss:  0.42  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 5500 train loss:  0.42  train acc: 0.99  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 6000 train loss:  0.41  train acc: 0.99  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 6500 train loss:  0.40  train acc: 0.99  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 7000 train loss:  0.40  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 7500 train loss:  0.39  train acc: 0.99  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 8000 train loss:  0.38  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 8500 train loss:  0.39  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 9000 train loss:  0.38  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 9500 train loss:  0.38  train acc: 0.99  val loss:  0.08  valid acc: 0.99\n",
      "sin i=10000 train loss:  0.38  train acc: 1.00  val loss:  0.07  valid acc: 0.99\n",
      "->  sin  layer idx: 12 , best valid accuracy: 0.99, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.59  train acc: 0.00  val loss:  6.78  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.01  train acc: 0.99  val loss:  0.31  valid acc: 0.99\n",
      "sin i= 1000 train loss:  0.79  train acc: 0.99  val loss:  0.20  valid acc: 0.99\n",
      "sin i= 1500 train loss:  0.67  train acc: 1.00  val loss:  0.15  valid acc: 0.99\n",
      "sin i= 2000 train loss:  0.60  train acc: 0.99  val loss:  0.13  valid acc: 0.99\n",
      "sin i= 2500 train loss:  0.54  train acc: 0.99  val loss:  0.12  valid acc: 0.99\n",
      "sin i= 3000 train loss:  0.51  train acc: 1.00  val loss:  0.11  valid acc: 0.99\n",
      "sin i= 3500 train loss:  0.49  train acc: 1.00  val loss:  0.11  valid acc: 0.99\n",
      "sin i= 4000 train loss:  0.47  train acc: 0.99  val loss:  0.10  valid acc: 0.99\n",
      "sin i= 4500 train loss:  0.46  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 5000 train loss:  0.44  train acc: 0.99  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 5500 train loss:  0.43  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 6000 train loss:  0.42  train acc: 0.99  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 6500 train loss:  0.41  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 7000 train loss:  0.40  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 7500 train loss:  0.39  train acc: 0.99  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 8000 train loss:  0.38  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 8500 train loss:  0.38  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 9000 train loss:  0.38  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 9500 train loss:  0.38  train acc: 0.99  val loss:  0.07  valid acc: 0.99\n",
      "sin i=10000 train loss:  0.37  train acc: 0.99  val loss:  0.07  valid acc: 0.99\n",
      "->  sin  layer idx: 13 , best valid accuracy: 0.99, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.63  train acc: 0.00  val loss:  6.83  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.01  train acc: 0.99  val loss:  0.32  valid acc: 0.98\n",
      "sin i= 1000 train loss:  0.78  train acc: 0.99  val loss:  0.22  valid acc: 0.98\n",
      "sin i= 1500 train loss:  0.66  train acc: 1.00  val loss:  0.16  valid acc: 0.99\n",
      "sin i= 2000 train loss:  0.58  train acc: 0.99  val loss:  0.14  valid acc: 0.99\n",
      "sin i= 2500 train loss:  0.53  train acc: 1.00  val loss:  0.14  valid acc: 0.98\n",
      "sin i= 3000 train loss:  0.49  train acc: 1.00  val loss:  0.12  valid acc: 0.99\n",
      "sin i= 3500 train loss:  0.48  train acc: 0.99  val loss:  0.12  valid acc: 0.99\n",
      "sin i= 4000 train loss:  0.46  train acc: 0.99  val loss:  0.11  valid acc: 0.99\n",
      "sin i= 4500 train loss:  0.44  train acc: 0.99  val loss:  0.10  valid acc: 0.99\n",
      "sin i= 5000 train loss:  0.43  train acc: 0.99  val loss:  0.10  valid acc: 0.99\n",
      "sin i= 5500 train loss:  0.41  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 6000 train loss:  0.41  train acc: 0.99  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 6500 train loss:  0.40  train acc: 0.99  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 7000 train loss:  0.40  train acc: 0.99  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 7500 train loss:  0.39  train acc: 0.99  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 8000 train loss:  0.39  train acc: 0.99  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 8500 train loss:  0.39  train acc: 0.99  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 9000 train loss:  0.38  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "sin i= 9500 train loss:  0.38  train acc: 0.99  val loss:  0.08  valid acc: 0.99\n",
      "sin i=10000 train loss:  0.37  train acc: 1.00  val loss:  0.07  valid acc: 0.99\n",
      "->  sin  layer idx: 14 , best valid accuracy: 0.99, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.74  train acc: 0.00  val loss:  7.85  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.08  train acc: 0.99  val loss:  0.38  valid acc: 0.97\n",
      "sin i= 1000 train loss:  0.83  train acc: 0.99  val loss:  0.26  valid acc: 0.97\n",
      "sin i= 1500 train loss:  0.71  train acc: 0.99  val loss:  0.19  valid acc: 0.98\n",
      "sin i= 2000 train loss:  0.63  train acc: 0.99  val loss:  0.17  valid acc: 0.98\n",
      "sin i= 2500 train loss:  0.58  train acc: 0.99  val loss:  0.16  valid acc: 0.98\n",
      "sin i= 3000 train loss:  0.54  train acc: 0.99  val loss:  0.14  valid acc: 0.98\n",
      "sin i= 3500 train loss:  0.53  train acc: 0.98  val loss:  0.15  valid acc: 0.98\n",
      "sin i= 4000 train loss:  0.50  train acc: 0.99  val loss:  0.14  valid acc: 0.98\n",
      "sin i= 4500 train loss:  0.50  train acc: 0.99  val loss:  0.13  valid acc: 0.98\n",
      "sin i= 5000 train loss:  0.47  train acc: 0.99  val loss:  0.12  valid acc: 0.98\n",
      "sin i= 5500 train loss:  0.46  train acc: 0.99  val loss:  0.12  valid acc: 0.98\n",
      "sin i= 6000 train loss:  0.46  train acc: 0.99  val loss:  0.11  valid acc: 0.98\n",
      "sin i= 6500 train loss:  0.46  train acc: 0.99  val loss:  0.11  valid acc: 0.98\n",
      "sin i= 7000 train loss:  0.44  train acc: 0.99  val loss:  0.10  valid acc: 0.99\n",
      "sin i= 7500 train loss:  0.44  train acc: 0.99  val loss:  0.10  valid acc: 0.99\n",
      "sin i= 8000 train loss:  0.43  train acc: 0.99  val loss:  0.10  valid acc: 0.99\n",
      "sin i= 8500 train loss:  0.43  train acc: 0.99  val loss:  0.10  valid acc: 0.98\n",
      "sin i= 9000 train loss:  0.42  train acc: 0.99  val loss:  0.10  valid acc: 0.98\n",
      "sin i= 9500 train loss:  0.42  train acc: 0.99  val loss:  0.11  valid acc: 0.98\n",
      "sin i=10000 train loss:  0.40  train acc: 0.99  val loss:  0.09  valid acc: 0.99\n",
      "->  sin  layer idx: 15 , best valid accuracy: 0.99, test accuracy: 0.98\n",
      "sin i=    0 train loss: 19.38  train acc: 0.00  val loss: 30.92  valid acc: 0.00\n",
      "sin i=  500 train loss:  5.89  train acc: 0.37  val loss:  3.28  valid acc: 0.09\n",
      "sin i= 1000 train loss:  4.00  train acc: 0.58  val loss:  2.23  valid acc: 0.30\n",
      "sin i= 1500 train loss:  3.27  train acc: 0.51  val loss:  2.11  valid acc: 0.31\n",
      "sin i= 2000 train loss:  2.67  train acc: 0.63  val loss:  1.93  valid acc: 0.36\n",
      "sin i= 2500 train loss:  2.40  train acc: 0.65  val loss:  2.09  valid acc: 0.34\n",
      "sin i= 3000 train loss:  2.44  train acc: 0.57  val loss:  2.39  valid acc: 0.28\n",
      "sin i= 3500 train loss:  2.11  train acc: 0.62  val loss:  2.02  valid acc: 0.35\n",
      "sin i= 4000 train loss:  2.07  train acc: 0.65  val loss:  1.96  valid acc: 0.36\n",
      "sin i= 4500 train loss:  2.01  train acc: 0.65  val loss:  2.14  valid acc: 0.33\n",
      "sin i= 5000 train loss:  1.97  train acc: 0.65  val loss:  2.01  valid acc: 0.35\n",
      "sin i= 5500 train loss:  2.02  train acc: 0.62  val loss:  1.96  valid acc: 0.36\n",
      "sin i= 6000 train loss:  1.98  train acc: 0.62  val loss:  1.95  valid acc: 0.38\n",
      "sin i= 6500 train loss:  9.75  train acc: 0.44  val loss:  7.07  valid acc: 0.26\n",
      "sin i= 7000 train loss:  9.40  train acc: 0.41  val loss:  6.40  valid acc: 0.24\n",
      "sin i= 7500 train loss:  5.81  train acc: 0.62  val loss:  3.03  valid acc: 0.34\n",
      "sin i= 8000 train loss:  5.32  train acc: 0.62  val loss:  3.41  valid acc: 0.34\n",
      "sin i= 8500 train loss: 95.52  train acc: 0.05  val loss: 80.79  valid acc: 0.03\n",
      "sin i= 9000 train loss:  4.29  train acc: 0.66  val loss:  2.30  valid acc: 0.38\n",
      "sin i= 9500 train loss:  3.61  train acc: 0.70  val loss:  2.22  valid acc: 0.38\n",
      "sin i=10000 train loss:  6.58  train acc: 0.59  val loss:  4.40  valid acc: 0.31\n",
      "->  sin  layer idx: 16 , best valid accuracy: 0.38, test accuracy: 0.37\n",
      "bin i=    0 train loss:  9.29  train acc: 0.00  val loss:  6.90  valid acc: 0.00\n",
      "bin i=  500 train loss:  5.26  train acc: 0.22  val loss:  4.18  valid acc: 0.07\n",
      "bin i= 1000 train loss:  4.38  train acc: 0.42  val loss:  3.44  valid acc: 0.18\n",
      "bin i= 1500 train loss:  3.87  train acc: 0.62  val loss:  3.25  valid acc: 0.19\n",
      "bin i= 2000 train loss:  3.50  train acc: 0.79  val loss:  3.20  valid acc: 0.19\n",
      "bin i= 2500 train loss:  3.24  train acc: 0.87  val loss:  3.16  valid acc: 0.21\n",
      "bin i= 3000 train loss:  3.04  train acc: 0.92  val loss:  3.14  valid acc: 0.28\n",
      "bin i= 3500 train loss:  2.87  train acc: 0.96  val loss:  3.12  valid acc: 0.23\n",
      "bin i= 4000 train loss:  2.73  train acc: 0.98  val loss:  3.10  valid acc: 0.28\n",
      "bin i= 4500 train loss:  2.63  train acc: 0.99  val loss:  3.09  valid acc: 0.28\n",
      "bin i= 5000 train loss:  2.52  train acc: 0.99  val loss:  3.07  valid acc: 0.26\n",
      "bin i= 5500 train loss:  2.44  train acc: 0.99  val loss:  3.07  valid acc: 0.26\n",
      "bin i= 6000 train loss:  2.37  train acc: 1.00  val loss:  3.07  valid acc: 0.25\n",
      "bin i= 6500 train loss:  2.32  train acc: 1.00  val loss:  3.08  valid acc: 0.21\n",
      "bin i= 7000 train loss:  2.26  train acc: 1.00  val loss:  3.09  valid acc: 0.19\n",
      "bin i= 7500 train loss:  2.21  train acc: 1.00  val loss:  3.10  valid acc: 0.18\n",
      "bin i= 8000 train loss:  2.18  train acc: 1.00  val loss:  3.11  valid acc: 0.19\n",
      "bin i= 8500 train loss:  2.15  train acc: 1.00  val loss:  3.12  valid acc: 0.18\n",
      "bin i= 9000 train loss:  2.11  train acc: 1.00  val loss:  3.14  valid acc: 0.18\n",
      "bin i= 9500 train loss:  2.08  train acc: 1.00  val loss:  3.16  valid acc: 0.18\n",
      "bin i=10000 train loss:  2.06  train acc: 1.00  val loss:  3.22  valid acc: 0.21\n",
      "->  bin  layer idx: 0  , best valid accuracy: 0.28, test accuracy: 0.22\n",
      "bin i=    0 train loss:  9.29  train acc: 0.00  val loss:  6.85  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.56  train acc: 0.15  val loss:  3.90  valid acc: 0.05\n",
      "bin i= 1000 train loss:  4.07  train acc: 0.32  val loss:  3.50  valid acc: 0.07\n",
      "bin i= 1500 train loss:  3.69  train acc: 0.46  val loss:  3.27  valid acc: 0.13\n",
      "bin i= 2000 train loss:  3.32  train acc: 0.63  val loss:  3.22  valid acc: 0.17\n",
      "bin i= 2500 train loss:  3.04  train acc: 0.74  val loss:  3.22  valid acc: 0.17\n",
      "bin i= 3000 train loss:  2.85  train acc: 0.80  val loss:  3.24  valid acc: 0.16\n",
      "bin i= 3500 train loss:  2.66  train acc: 0.87  val loss:  3.31  valid acc: 0.15\n",
      "bin i= 4000 train loss:  2.54  train acc: 0.90  val loss:  3.36  valid acc: 0.17\n",
      "bin i= 4500 train loss:  2.44  train acc: 0.92  val loss:  3.39  valid acc: 0.16\n",
      "bin i= 5000 train loss:  2.37  train acc: 0.91  val loss:  3.46  valid acc: 0.16\n",
      "bin i= 5500 train loss:  2.30  train acc: 0.93  val loss:  3.52  valid acc: 0.16\n",
      "bin i= 6000 train loss:  2.23  train acc: 0.93  val loss:  3.54  valid acc: 0.16\n",
      "bin i= 6500 train loss:  2.16  train acc: 0.95  val loss:  3.56  valid acc: 0.16\n",
      "bin i= 7000 train loss:  2.12  train acc: 0.95  val loss:  3.62  valid acc: 0.17\n",
      "bin i= 7500 train loss:  2.09  train acc: 0.94  val loss:  3.60  valid acc: 0.17\n",
      "bin i= 8000 train loss:  2.02  train acc: 0.95  val loss:  3.63  valid acc: 0.16\n",
      "bin i= 8500 train loss:  1.99  train acc: 0.96  val loss:  3.64  valid acc: 0.17\n",
      "bin i= 9000 train loss:  1.95  train acc: 0.97  val loss:  3.67  valid acc: 0.16\n",
      "bin i= 9500 train loss:  1.93  train acc: 0.96  val loss:  3.71  valid acc: 0.16\n",
      "bin i=10000 train loss:  1.91  train acc: 0.96  val loss:  3.73  valid acc: 0.17\n",
      "->  bin  layer idx: 1  , best valid accuracy: 0.17, test accuracy: 0.16\n",
      "bin i=    0 train loss:  9.29  train acc: 0.00  val loss:  6.85  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.56  train acc: 0.14  val loss:  3.99  valid acc: 0.03\n",
      "bin i= 1000 train loss:  4.21  train acc: 0.21  val loss:  3.68  valid acc: 0.06\n",
      "bin i= 1500 train loss:  3.96  train acc: 0.32  val loss:  3.50  valid acc: 0.10\n",
      "bin i= 2000 train loss:  3.63  train acc: 0.43  val loss:  3.41  valid acc: 0.11\n",
      "bin i= 2500 train loss:  3.34  train acc: 0.52  val loss:  3.29  valid acc: 0.13\n",
      "bin i= 3000 train loss:  3.14  train acc: 0.60  val loss:  3.25  valid acc: 0.13\n",
      "bin i= 3500 train loss:  2.94  train acc: 0.66  val loss:  3.25  valid acc: 0.15\n",
      "bin i= 4000 train loss:  2.84  train acc: 0.67  val loss:  3.25  valid acc: 0.16\n",
      "bin i= 4500 train loss:  2.73  train acc: 0.73  val loss:  3.25  valid acc: 0.16\n",
      "bin i= 5000 train loss:  2.65  train acc: 0.74  val loss:  3.29  valid acc: 0.17\n",
      "bin i= 5500 train loss:  2.63  train acc: 0.71  val loss:  3.33  valid acc: 0.17\n",
      "bin i= 6000 train loss:  2.53  train acc: 0.75  val loss:  3.36  valid acc: 0.17\n",
      "bin i= 6500 train loss:  2.48  train acc: 0.76  val loss:  3.37  valid acc: 0.18\n",
      "bin i= 7000 train loss:  2.48  train acc: 0.75  val loss:  3.41  valid acc: 0.17\n",
      "bin i= 7500 train loss:  2.46  train acc: 0.76  val loss:  3.43  valid acc: 0.17\n",
      "bin i= 8000 train loss:  2.38  train acc: 0.79  val loss:  3.42  valid acc: 0.17\n",
      "bin i= 8500 train loss:  2.32  train acc: 0.81  val loss:  3.44  valid acc: 0.18\n",
      "bin i= 9000 train loss:  2.30  train acc: 0.79  val loss:  3.45  valid acc: 0.18\n",
      "bin i= 9500 train loss:  2.28  train acc: 0.79  val loss:  3.51  valid acc: 0.19\n",
      "bin i=10000 train loss:  2.26  train acc: 0.80  val loss:  3.50  valid acc: 0.18\n",
      "->  bin  layer idx: 2  , best valid accuracy: 0.19, test accuracy: 0.15\n",
      "bin i=    0 train loss:  9.30  train acc: 0.00  val loss:  6.85  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.56  train acc: 0.09  val loss:  4.06  valid acc: 0.03\n",
      "bin i= 1000 train loss:  4.34  train acc: 0.13  val loss:  3.88  valid acc: 0.03\n",
      "bin i= 1500 train loss:  4.25  train acc: 0.15  val loss:  3.77  valid acc: 0.05\n",
      "bin i= 2000 train loss:  4.05  train acc: 0.18  val loss:  3.71  valid acc: 0.04\n",
      "bin i= 2500 train loss:  3.83  train acc: 0.27  val loss:  3.62  valid acc: 0.06\n",
      "bin i= 3000 train loss:  3.72  train acc: 0.27  val loss:  3.54  valid acc: 0.09\n",
      "bin i= 3500 train loss:  3.54  train acc: 0.33  val loss:  3.56  valid acc: 0.08\n",
      "bin i= 4000 train loss:  3.43  train acc: 0.38  val loss:  3.54  valid acc: 0.09\n",
      "bin i= 4500 train loss:  3.34  train acc: 0.38  val loss:  3.54  valid acc: 0.09\n",
      "bin i= 5000 train loss:  3.24  train acc: 0.44  val loss:  3.57  valid acc: 0.10\n",
      "bin i= 5500 train loss:  3.20  train acc: 0.42  val loss:  3.60  valid acc: 0.11\n",
      "bin i= 6000 train loss:  3.10  train acc: 0.51  val loss:  3.59  valid acc: 0.11\n",
      "bin i= 6500 train loss:  3.07  train acc: 0.50  val loss:  3.56  valid acc: 0.11\n",
      "bin i= 7000 train loss:  2.98  train acc: 0.52  val loss:  3.59  valid acc: 0.12\n",
      "bin i= 7500 train loss:  2.99  train acc: 0.53  val loss:  3.61  valid acc: 0.12\n",
      "bin i= 8000 train loss:  2.95  train acc: 0.55  val loss:  3.61  valid acc: 0.12\n",
      "bin i= 8500 train loss:  2.91  train acc: 0.54  val loss:  3.61  valid acc: 0.13\n",
      "bin i= 9000 train loss:  2.87  train acc: 0.55  val loss:  3.60  valid acc: 0.13\n",
      "bin i= 9500 train loss:  2.85  train acc: 0.54  val loss:  3.63  valid acc: 0.13\n",
      "bin i=10000 train loss:  2.81  train acc: 0.56  val loss:  3.65  valid acc: 0.14\n",
      "->  bin  layer idx: 3  , best valid accuracy: 0.14, test accuracy: 0.13\n",
      "bin i=    0 train loss:  9.31  train acc: 0.00  val loss:  6.83  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.49  train acc: 0.09  val loss:  4.08  valid acc: 0.03\n",
      "bin i= 1000 train loss:  4.32  train acc: 0.11  val loss:  3.97  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.30  train acc: 0.12  val loss:  3.91  valid acc: 0.04\n",
      "bin i= 2000 train loss:  4.17  train acc: 0.12  val loss:  3.87  valid acc: 0.02\n",
      "bin i= 2500 train loss:  4.03  train acc: 0.15  val loss:  3.83  valid acc: 0.04\n",
      "bin i= 3000 train loss:  4.02  train acc: 0.14  val loss:  3.74  valid acc: 0.07\n",
      "bin i= 3500 train loss:  3.88  train acc: 0.15  val loss:  3.74  valid acc: 0.05\n",
      "bin i= 4000 train loss:  3.79  train acc: 0.17  val loss:  3.71  valid acc: 0.05\n",
      "bin i= 4500 train loss:  3.77  train acc: 0.16  val loss:  3.70  valid acc: 0.05\n",
      "bin i= 5000 train loss:  3.66  train acc: 0.20  val loss:  3.68  valid acc: 0.06\n",
      "bin i= 5500 train loss:  3.63  train acc: 0.19  val loss:  3.66  valid acc: 0.06\n",
      "bin i= 6000 train loss:  3.59  train acc: 0.19  val loss:  3.70  valid acc: 0.04\n",
      "bin i= 6500 train loss:  3.56  train acc: 0.20  val loss:  3.64  valid acc: 0.06\n",
      "bin i= 7000 train loss:  3.50  train acc: 0.21  val loss:  3.65  valid acc: 0.06\n",
      "bin i= 7500 train loss:  3.49  train acc: 0.24  val loss:  3.65  valid acc: 0.06\n",
      "bin i= 8000 train loss:  3.45  train acc: 0.23  val loss:  3.65  valid acc: 0.07\n",
      "bin i= 8500 train loss:  3.47  train acc: 0.23  val loss:  3.65  valid acc: 0.06\n",
      "bin i= 9000 train loss:  3.44  train acc: 0.23  val loss:  3.62  valid acc: 0.07\n",
      "bin i= 9500 train loss:  3.39  train acc: 0.26  val loss:  3.64  valid acc: 0.07\n",
      "bin i=10000 train loss:  3.34  train acc: 0.27  val loss:  3.64  valid acc: 0.07\n",
      "->  bin  layer idx: 4  , best valid accuracy: 0.07, test accuracy: 0.06\n",
      "bin i=    0 train loss:  9.30  train acc: 0.00  val loss:  6.83  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.47  train acc: 0.08  val loss:  4.10  valid acc: 0.02\n",
      "bin i= 1000 train loss:  4.33  train acc: 0.09  val loss:  3.99  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.30  train acc: 0.12  val loss:  3.96  valid acc: 0.03\n",
      "bin i= 2000 train loss:  4.20  train acc: 0.11  val loss:  3.94  valid acc: 0.02\n",
      "bin i= 2500 train loss:  4.08  train acc: 0.14  val loss:  3.91  valid acc: 0.03\n",
      "bin i= 3000 train loss:  4.06  train acc: 0.12  val loss:  3.84  valid acc: 0.05\n",
      "bin i= 3500 train loss:  3.97  train acc: 0.11  val loss:  3.85  valid acc: 0.04\n",
      "bin i= 4000 train loss:  3.92  train acc: 0.12  val loss:  3.85  valid acc: 0.04\n",
      "bin i= 4500 train loss:  3.91  train acc: 0.11  val loss:  3.84  valid acc: 0.03\n",
      "bin i= 5000 train loss:  3.83  train acc: 0.13  val loss:  3.83  valid acc: 0.04\n",
      "bin i= 5500 train loss:  3.80  train acc: 0.13  val loss:  3.84  valid acc: 0.04\n",
      "bin i= 6000 train loss:  3.78  train acc: 0.11  val loss:  3.84  valid acc: 0.03\n",
      "bin i= 6500 train loss:  3.77  train acc: 0.14  val loss:  3.79  valid acc: 0.06\n",
      "bin i= 7000 train loss:  3.72  train acc: 0.13  val loss:  3.80  valid acc: 0.05\n",
      "bin i= 7500 train loss:  3.76  train acc: 0.14  val loss:  3.80  valid acc: 0.04\n",
      "bin i= 8000 train loss:  3.72  train acc: 0.13  val loss:  3.80  valid acc: 0.05\n",
      "bin i= 8500 train loss:  3.71  train acc: 0.14  val loss:  3.81  valid acc: 0.04\n",
      "bin i= 9000 train loss:  3.73  train acc: 0.14  val loss:  3.78  valid acc: 0.05\n",
      "bin i= 9500 train loss:  3.67  train acc: 0.14  val loss:  3.77  valid acc: 0.05\n",
      "bin i=10000 train loss:  3.65  train acc: 0.17  val loss:  3.75  valid acc: 0.05\n",
      "->  bin  layer idx: 5  , best valid accuracy: 0.06, test accuracy: 0.03\n",
      "bin i=    0 train loss:  9.30  train acc: 0.00  val loss:  6.86  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.47  train acc: 0.08  val loss:  4.11  valid acc: 0.02\n",
      "bin i= 1000 train loss:  4.33  train acc: 0.09  val loss:  4.02  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.31  train acc: 0.12  val loss:  4.00  valid acc: 0.03\n",
      "bin i= 2000 train loss:  4.22  train acc: 0.11  val loss:  3.98  valid acc: 0.01\n",
      "bin i= 2500 train loss:  4.10  train acc: 0.12  val loss:  3.94  valid acc: 0.02\n",
      "bin i= 3000 train loss:  4.10  train acc: 0.11  val loss:  3.88  valid acc: 0.05\n",
      "bin i= 3500 train loss:  4.02  train acc: 0.10  val loss:  3.90  valid acc: 0.03\n",
      "bin i= 4000 train loss:  3.98  train acc: 0.11  val loss:  3.89  valid acc: 0.03\n",
      "bin i= 4500 train loss:  3.96  train acc: 0.10  val loss:  3.88  valid acc: 0.03\n",
      "bin i= 5000 train loss:  3.89  train acc: 0.13  val loss:  3.87  valid acc: 0.04\n",
      "bin i= 5500 train loss:  3.86  train acc: 0.11  val loss:  3.87  valid acc: 0.04\n",
      "bin i= 6000 train loss:  3.84  train acc: 0.11  val loss:  3.87  valid acc: 0.03\n",
      "bin i= 6500 train loss:  3.82  train acc: 0.13  val loss:  3.83  valid acc: 0.05\n",
      "bin i= 7000 train loss:  3.78  train acc: 0.11  val loss:  3.85  valid acc: 0.04\n",
      "bin i= 7500 train loss:  3.83  train acc: 0.11  val loss:  3.85  valid acc: 0.04\n",
      "bin i= 8000 train loss:  3.81  train acc: 0.11  val loss:  3.86  valid acc: 0.05\n",
      "bin i= 8500 train loss:  3.77  train acc: 0.12  val loss:  3.86  valid acc: 0.04\n",
      "bin i= 9000 train loss:  3.78  train acc: 0.12  val loss:  3.83  valid acc: 0.04\n",
      "bin i= 9500 train loss:  3.75  train acc: 0.13  val loss:  3.85  valid acc: 0.04\n",
      "bin i=10000 train loss:  3.72  train acc: 0.15  val loss:  3.83  valid acc: 0.04\n",
      "->  bin  layer idx: 6  , best valid accuracy: 0.05, test accuracy: 0.02\n",
      "bin i=    0 train loss:  9.30  train acc: 0.00  val loss:  6.89  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.49  train acc: 0.09  val loss:  4.12  valid acc: 0.03\n",
      "bin i= 1000 train loss:  4.35  train acc: 0.10  val loss:  4.03  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.33  train acc: 0.11  val loss:  4.01  valid acc: 0.03\n",
      "bin i= 2000 train loss:  4.25  train acc: 0.10  val loss:  3.99  valid acc: 0.02\n",
      "bin i= 2500 train loss:  4.13  train acc: 0.11  val loss:  3.96  valid acc: 0.03\n",
      "bin i= 3000 train loss:  4.12  train acc: 0.11  val loss:  3.91  valid acc: 0.04\n",
      "bin i= 3500 train loss:  4.06  train acc: 0.10  val loss:  3.92  valid acc: 0.04\n",
      "bin i= 4000 train loss:  4.02  train acc: 0.10  val loss:  3.94  valid acc: 0.03\n",
      "bin i= 4500 train loss:  3.99  train acc: 0.10  val loss:  3.92  valid acc: 0.03\n",
      "bin i= 5000 train loss:  3.94  train acc: 0.12  val loss:  3.91  valid acc: 0.04\n",
      "bin i= 5500 train loss:  3.90  train acc: 0.11  val loss:  3.90  valid acc: 0.03\n",
      "bin i= 6000 train loss:  3.87  train acc: 0.11  val loss:  3.90  valid acc: 0.03\n",
      "bin i= 6500 train loss:  3.87  train acc: 0.12  val loss:  3.87  valid acc: 0.05\n",
      "bin i= 7000 train loss:  3.83  train acc: 0.10  val loss:  3.89  valid acc: 0.03\n",
      "bin i= 7500 train loss:  3.86  train acc: 0.11  val loss:  3.88  valid acc: 0.04\n",
      "bin i= 8000 train loss:  3.86  train acc: 0.11  val loss:  3.88  valid acc: 0.04\n",
      "bin i= 8500 train loss:  3.82  train acc: 0.12  val loss:  3.90  valid acc: 0.04\n",
      "bin i= 9000 train loss:  3.81  train acc: 0.12  val loss:  3.89  valid acc: 0.03\n",
      "bin i= 9500 train loss:  3.79  train acc: 0.12  val loss:  3.87  valid acc: 0.04\n",
      "bin i=10000 train loss:  3.75  train acc: 0.15  val loss:  3.84  valid acc: 0.04\n",
      "->  bin  layer idx: 7  , best valid accuracy: 0.05, test accuracy: 0.02\n",
      "bin i=    0 train loss:  9.30  train acc: 0.00  val loss:  6.89  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.45  train acc: 0.08  val loss:  4.12  valid acc: 0.03\n",
      "bin i= 1000 train loss:  4.34  train acc: 0.09  val loss:  4.03  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.32  train acc: 0.10  val loss:  4.02  valid acc: 0.03\n",
      "bin i= 2000 train loss:  4.23  train acc: 0.09  val loss:  4.00  valid acc: 0.02\n",
      "bin i= 2500 train loss:  4.12  train acc: 0.12  val loss:  3.98  valid acc: 0.03\n",
      "bin i= 3000 train loss:  4.11  train acc: 0.11  val loss:  3.93  valid acc: 0.04\n",
      "bin i= 3500 train loss:  4.03  train acc: 0.10  val loss:  3.93  valid acc: 0.03\n",
      "bin i= 4000 train loss:  4.00  train acc: 0.10  val loss:  3.93  valid acc: 0.03\n",
      "bin i= 4500 train loss:  3.99  train acc: 0.11  val loss:  3.92  valid acc: 0.03\n",
      "bin i= 5000 train loss:  3.94  train acc: 0.12  val loss:  3.92  valid acc: 0.04\n",
      "bin i= 5500 train loss:  3.90  train acc: 0.11  val loss:  3.91  valid acc: 0.03\n",
      "bin i= 6000 train loss:  3.88  train acc: 0.10  val loss:  3.92  valid acc: 0.03\n",
      "bin i= 6500 train loss:  3.88  train acc: 0.12  val loss:  3.88  valid acc: 0.05\n",
      "bin i= 7000 train loss:  3.82  train acc: 0.09  val loss:  3.90  valid acc: 0.03\n",
      "bin i= 7500 train loss:  3.87  train acc: 0.10  val loss:  3.89  valid acc: 0.04\n",
      "bin i= 8000 train loss:  3.84  train acc: 0.11  val loss:  3.88  valid acc: 0.04\n",
      "bin i= 8500 train loss:  3.82  train acc: 0.12  val loss:  3.90  valid acc: 0.03\n",
      "bin i= 9000 train loss:  3.79  train acc: 0.12  val loss:  3.88  valid acc: 0.04\n",
      "bin i= 9500 train loss:  3.80  train acc: 0.11  val loss:  3.88  valid acc: 0.04\n",
      "bin i=10000 train loss:  3.73  train acc: 0.14  val loss:  3.84  valid acc: 0.04\n",
      "->  bin  layer idx: 8  , best valid accuracy: 0.05, test accuracy: 0.02\n",
      "bin i=    0 train loss:  9.30  train acc: 0.00  val loss:  6.89  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.44  train acc: 0.08  val loss:  4.12  valid acc: 0.03\n",
      "bin i= 1000 train loss:  4.34  train acc: 0.08  val loss:  4.04  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.30  train acc: 0.10  val loss:  4.01  valid acc: 0.03\n",
      "bin i= 2000 train loss:  4.21  train acc: 0.09  val loss:  3.99  valid acc: 0.02\n",
      "bin i= 2500 train loss:  4.11  train acc: 0.11  val loss:  3.97  valid acc: 0.03\n",
      "bin i= 3000 train loss:  4.11  train acc: 0.10  val loss:  3.91  valid acc: 0.05\n",
      "bin i= 3500 train loss:  4.04  train acc: 0.10  val loss:  3.93  valid acc: 0.04\n",
      "bin i= 4000 train loss:  4.00  train acc: 0.12  val loss:  3.93  valid acc: 0.03\n",
      "bin i= 4500 train loss:  3.99  train acc: 0.10  val loss:  3.94  valid acc: 0.03\n",
      "bin i= 5000 train loss:  3.94  train acc: 0.12  val loss:  3.92  valid acc: 0.04\n",
      "bin i= 5500 train loss:  3.88  train acc: 0.11  val loss:  3.93  valid acc: 0.03\n",
      "bin i= 6000 train loss:  3.86  train acc: 0.11  val loss:  3.94  valid acc: 0.03\n",
      "bin i= 6500 train loss:  3.86  train acc: 0.12  val loss:  3.89  valid acc: 0.05\n",
      "bin i= 7000 train loss:  3.80  train acc: 0.10  val loss:  3.90  valid acc: 0.04\n",
      "bin i= 7500 train loss:  3.86  train acc: 0.11  val loss:  3.89  valid acc: 0.04\n",
      "bin i= 8000 train loss:  3.84  train acc: 0.11  val loss:  3.89  valid acc: 0.05\n",
      "bin i= 8500 train loss:  3.82  train acc: 0.11  val loss:  3.91  valid acc: 0.03\n",
      "bin i= 9000 train loss:  3.81  train acc: 0.10  val loss:  3.88  valid acc: 0.04\n",
      "bin i= 9500 train loss:  3.80  train acc: 0.11  val loss:  3.88  valid acc: 0.04\n",
      "bin i=10000 train loss:  3.72  train acc: 0.15  val loss:  3.84  valid acc: 0.04\n",
      "->  bin  layer idx: 9  , best valid accuracy: 0.05, test accuracy: 0.01\n",
      "bin i=    0 train loss:  9.31  train acc: 0.00  val loss:  6.96  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.38  train acc: 0.08  val loss:  4.12  valid acc: 0.03\n",
      "bin i= 1000 train loss:  4.28  train acc: 0.09  val loss:  4.08  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.25  train acc: 0.09  val loss:  4.03  valid acc: 0.04\n",
      "bin i= 2000 train loss:  4.15  train acc: 0.08  val loss:  4.00  valid acc: 0.03\n",
      "bin i= 2500 train loss:  4.06  train acc: 0.11  val loss:  4.00  valid acc: 0.03\n",
      "bin i= 3000 train loss:  4.07  train acc: 0.09  val loss:  3.93  valid acc: 0.05\n",
      "bin i= 3500 train loss:  4.00  train acc: 0.09  val loss:  3.97  valid acc: 0.04\n",
      "bin i= 4000 train loss:  3.96  train acc: 0.11  val loss:  3.95  valid acc: 0.04\n",
      "bin i= 4500 train loss:  3.96  train acc: 0.10  val loss:  3.95  valid acc: 0.04\n",
      "bin i= 5000 train loss:  3.91  train acc: 0.12  val loss:  3.95  valid acc: 0.05\n",
      "bin i= 5500 train loss:  3.85  train acc: 0.10  val loss:  3.96  valid acc: 0.04\n",
      "bin i= 6000 train loss:  3.84  train acc: 0.12  val loss:  3.94  valid acc: 0.04\n",
      "bin i= 6500 train loss:  3.84  train acc: 0.11  val loss:  3.90  valid acc: 0.05\n",
      "bin i= 7000 train loss:  3.80  train acc: 0.10  val loss:  3.92  valid acc: 0.04\n",
      "bin i= 7500 train loss:  3.85  train acc: 0.11  val loss:  3.92  valid acc: 0.04\n",
      "bin i= 8000 train loss:  3.84  train acc: 0.11  val loss:  3.91  valid acc: 0.05\n",
      "bin i= 8500 train loss:  3.82  train acc: 0.11  val loss:  3.94  valid acc: 0.04\n",
      "bin i= 9000 train loss:  3.80  train acc: 0.12  val loss:  3.90  valid acc: 0.05\n",
      "bin i= 9500 train loss:  3.78  train acc: 0.12  val loss:  3.91  valid acc: 0.05\n",
      "bin i=10000 train loss:  3.72  train acc: 0.14  val loss:  3.89  valid acc: 0.04\n",
      "->  bin  layer idx: 10 , best valid accuracy: 0.05, test accuracy: 0.02\n",
      "bin i=    0 train loss:  9.32  train acc: 0.00  val loss:  6.99  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.27  train acc: 0.09  val loss:  4.07  valid acc: 0.04\n",
      "bin i= 1000 train loss:  4.18  train acc: 0.10  val loss:  4.05  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.14  train acc: 0.10  val loss:  3.98  valid acc: 0.04\n",
      "bin i= 2000 train loss:  4.06  train acc: 0.10  val loss:  3.95  valid acc: 0.03\n",
      "bin i= 2500 train loss:  3.97  train acc: 0.12  val loss:  3.95  valid acc: 0.04\n",
      "bin i= 3000 train loss:  3.99  train acc: 0.11  val loss:  3.88  valid acc: 0.05\n",
      "bin i= 3500 train loss:  3.92  train acc: 0.10  val loss:  3.90  valid acc: 0.05\n",
      "bin i= 4000 train loss:  3.88  train acc: 0.12  val loss:  3.90  valid acc: 0.04\n",
      "bin i= 4500 train loss:  3.88  train acc: 0.11  val loss:  3.90  valid acc: 0.04\n",
      "bin i= 5000 train loss:  3.84  train acc: 0.11  val loss:  3.89  valid acc: 0.05\n",
      "bin i= 5500 train loss:  3.79  train acc: 0.10  val loss:  3.92  valid acc: 0.04\n",
      "bin i= 6000 train loss:  3.76  train acc: 0.12  val loss:  3.93  valid acc: 0.03\n",
      "bin i= 6500 train loss:  3.77  train acc: 0.11  val loss:  3.87  valid acc: 0.06\n",
      "bin i= 7000 train loss:  3.74  train acc: 0.11  val loss:  3.89  valid acc: 0.04\n",
      "bin i= 7500 train loss:  3.77  train acc: 0.12  val loss:  3.89  valid acc: 0.05\n",
      "bin i= 8000 train loss:  3.77  train acc: 0.12  val loss:  3.88  valid acc: 0.05\n",
      "bin i= 8500 train loss:  3.74  train acc: 0.12  val loss:  3.90  valid acc: 0.04\n",
      "bin i= 9000 train loss:  3.72  train acc: 0.13  val loss:  3.86  valid acc: 0.05\n",
      "bin i= 9500 train loss:  3.70  train acc: 0.13  val loss:  3.90  valid acc: 0.04\n",
      "bin i=10000 train loss:  3.64  train acc: 0.15  val loss:  3.86  valid acc: 0.04\n",
      "->  bin  layer idx: 11 , best valid accuracy: 0.06, test accuracy: 0.02\n",
      "bin i=    0 train loss:  9.33  train acc: 0.00  val loss:  7.13  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.25  train acc: 0.10  val loss:  4.07  valid acc: 0.04\n",
      "bin i= 1000 train loss:  4.15  train acc: 0.10  val loss:  4.06  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.12  train acc: 0.10  val loss:  3.99  valid acc: 0.04\n",
      "bin i= 2000 train loss:  4.04  train acc: 0.10  val loss:  3.97  valid acc: 0.03\n",
      "bin i= 2500 train loss:  3.95  train acc: 0.12  val loss:  3.98  valid acc: 0.03\n",
      "bin i= 3000 train loss:  3.97  train acc: 0.11  val loss:  3.90  valid acc: 0.05\n",
      "bin i= 3500 train loss:  3.87  train acc: 0.10  val loss:  3.92  valid acc: 0.05\n",
      "bin i= 4000 train loss:  3.86  train acc: 0.11  val loss:  3.93  valid acc: 0.04\n",
      "bin i= 4500 train loss:  3.86  train acc: 0.10  val loss:  3.92  valid acc: 0.04\n",
      "bin i= 5000 train loss:  3.83  train acc: 0.13  val loss:  3.91  valid acc: 0.05\n",
      "bin i= 5500 train loss:  3.76  train acc: 0.10  val loss:  3.94  valid acc: 0.04\n",
      "bin i= 6000 train loss:  3.75  train acc: 0.11  val loss:  3.92  valid acc: 0.04\n",
      "bin i= 6500 train loss:  3.74  train acc: 0.11  val loss:  3.88  valid acc: 0.06\n",
      "bin i= 7000 train loss:  3.72  train acc: 0.11  val loss:  3.92  valid acc: 0.04\n",
      "bin i= 7500 train loss:  3.72  train acc: 0.12  val loss:  3.91  valid acc: 0.04\n",
      "bin i= 8000 train loss:  3.73  train acc: 0.11  val loss:  3.89  valid acc: 0.05\n",
      "bin i= 8500 train loss:  3.72  train acc: 0.12  val loss:  3.94  valid acc: 0.03\n",
      "bin i= 9000 train loss:  3.70  train acc: 0.11  val loss:  3.89  valid acc: 0.05\n",
      "bin i= 9500 train loss:  3.68  train acc: 0.13  val loss:  3.91  valid acc: 0.04\n",
      "bin i=10000 train loss:  3.64  train acc: 0.15  val loss:  3.88  valid acc: 0.04\n",
      "->  bin  layer idx: 12 , best valid accuracy: 0.06, test accuracy: 0.02\n",
      "bin i=    0 train loss:  9.35  train acc: 0.00  val loss:  7.35  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.25  train acc: 0.10  val loss:  4.07  valid acc: 0.04\n",
      "bin i= 1000 train loss:  4.15  train acc: 0.10  val loss:  4.04  valid acc: 0.03\n",
      "bin i= 1500 train loss:  4.12  train acc: 0.10  val loss:  3.98  valid acc: 0.05\n",
      "bin i= 2000 train loss:  4.03  train acc: 0.09  val loss:  3.99  valid acc: 0.04\n",
      "bin i= 2500 train loss:  3.95  train acc: 0.13  val loss:  3.99  valid acc: 0.04\n",
      "bin i= 3000 train loss:  3.95  train acc: 0.11  val loss:  3.91  valid acc: 0.06\n",
      "bin i= 3500 train loss:  3.86  train acc: 0.10  val loss:  3.94  valid acc: 0.05\n",
      "bin i= 4000 train loss:  3.84  train acc: 0.11  val loss:  3.96  valid acc: 0.05\n",
      "bin i= 4500 train loss:  3.83  train acc: 0.10  val loss:  3.93  valid acc: 0.04\n",
      "bin i= 5000 train loss:  3.81  train acc: 0.11  val loss:  3.93  valid acc: 0.05\n",
      "bin i= 5500 train loss:  3.73  train acc: 0.11  val loss:  3.95  valid acc: 0.04\n",
      "bin i= 6000 train loss:  3.72  train acc: 0.11  val loss:  3.95  valid acc: 0.04\n",
      "bin i= 6500 train loss:  3.72  train acc: 0.11  val loss:  3.88  valid acc: 0.06\n",
      "bin i= 7000 train loss:  3.72  train acc: 0.11  val loss:  3.91  valid acc: 0.05\n",
      "bin i= 7500 train loss:  3.71  train acc: 0.11  val loss:  3.91  valid acc: 0.05\n",
      "bin i= 8000 train loss:  3.73  train acc: 0.10  val loss:  3.90  valid acc: 0.06\n",
      "bin i= 8500 train loss:  3.72  train acc: 0.11  val loss:  3.94  valid acc: 0.04\n",
      "bin i= 9000 train loss:  3.70  train acc: 0.10  val loss:  3.89  valid acc: 0.06\n",
      "bin i= 9500 train loss:  3.67  train acc: 0.12  val loss:  3.93  valid acc: 0.05\n",
      "bin i=10000 train loss:  3.65  train acc: 0.13  val loss:  3.92  valid acc: 0.04\n",
      "->  bin  layer idx: 13 , best valid accuracy: 0.06, test accuracy: 0.02\n",
      "bin i=    0 train loss:  9.38  train acc: 0.00  val loss:  7.55  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.23  train acc: 0.09  val loss:  4.07  valid acc: 0.04\n",
      "bin i= 1000 train loss:  4.13  train acc: 0.09  val loss:  4.04  valid acc: 0.03\n",
      "bin i= 1500 train loss:  4.06  train acc: 0.11  val loss:  3.98  valid acc: 0.04\n",
      "bin i= 2000 train loss:  3.96  train acc: 0.11  val loss:  4.00  valid acc: 0.03\n",
      "bin i= 2500 train loss:  3.89  train acc: 0.13  val loss:  3.99  valid acc: 0.03\n",
      "bin i= 3000 train loss:  3.86  train acc: 0.12  val loss:  3.92  valid acc: 0.06\n",
      "bin i= 3500 train loss:  3.79  train acc: 0.11  val loss:  3.96  valid acc: 0.04\n",
      "bin i= 4000 train loss:  3.75  train acc: 0.12  val loss:  3.97  valid acc: 0.04\n",
      "bin i= 4500 train loss:  3.75  train acc: 0.11  val loss:  3.94  valid acc: 0.04\n",
      "bin i= 5000 train loss:  3.73  train acc: 0.12  val loss:  3.95  valid acc: 0.05\n",
      "bin i= 5500 train loss:  3.68  train acc: 0.12  val loss:  3.96  valid acc: 0.04\n",
      "bin i= 6000 train loss:  3.67  train acc: 0.11  val loss:  3.96  valid acc: 0.04\n",
      "bin i= 6500 train loss:  3.68  train acc: 0.12  val loss:  3.92  valid acc: 0.06\n",
      "bin i= 7000 train loss:  3.68  train acc: 0.11  val loss:  3.95  valid acc: 0.04\n",
      "bin i= 7500 train loss:  3.67  train acc: 0.12  val loss:  3.96  valid acc: 0.04\n",
      "bin i= 8000 train loss:  3.67  train acc: 0.12  val loss:  3.93  valid acc: 0.05\n",
      "bin i= 8500 train loss:  3.66  train acc: 0.11  val loss:  3.99  valid acc: 0.03\n",
      "bin i= 9000 train loss:  3.62  train acc: 0.11  val loss:  3.92  valid acc: 0.05\n",
      "bin i= 9500 train loss:  3.65  train acc: 0.12  val loss:  3.97  valid acc: 0.04\n",
      "bin i=10000 train loss:  3.61  train acc: 0.15  val loss:  3.96  valid acc: 0.03\n",
      "->  bin  layer idx: 14 , best valid accuracy: 0.06, test accuracy: 0.02\n",
      "bin i=    0 train loss:  9.46  train acc: 0.00  val loss:  8.59  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.27  train acc: 0.09  val loss:  4.12  valid acc: 0.03\n",
      "bin i= 1000 train loss:  4.20  train acc: 0.09  val loss:  4.06  valid acc: 0.03\n",
      "bin i= 1500 train loss:  4.09  train acc: 0.10  val loss:  4.00  valid acc: 0.04\n",
      "bin i= 2000 train loss:  4.00  train acc: 0.10  val loss:  4.03  valid acc: 0.02\n",
      "bin i= 2500 train loss:  3.91  train acc: 0.12  val loss:  4.03  valid acc: 0.03\n",
      "bin i= 3000 train loss:  3.89  train acc: 0.11  val loss:  3.96  valid acc: 0.05\n",
      "bin i= 3500 train loss:  3.81  train acc: 0.10  val loss:  4.00  valid acc: 0.04\n",
      "bin i= 4000 train loss:  3.77  train acc: 0.11  val loss:  4.01  valid acc: 0.03\n",
      "bin i= 4500 train loss:  3.77  train acc: 0.10  val loss:  4.03  valid acc: 0.03\n",
      "bin i= 5000 train loss:  3.77  train acc: 0.11  val loss:  3.98  valid acc: 0.04\n",
      "bin i= 5500 train loss:  3.74  train acc: 0.10  val loss:  3.98  valid acc: 0.04\n",
      "bin i= 6000 train loss:  3.72  train acc: 0.09  val loss:  4.00  valid acc: 0.03\n",
      "bin i= 6500 train loss:  3.74  train acc: 0.12  val loss:  3.96  valid acc: 0.05\n",
      "bin i= 7000 train loss:  3.72  train acc: 0.11  val loss:  3.97  valid acc: 0.04\n",
      "bin i= 7500 train loss:  3.70  train acc: 0.12  val loss:  3.98  valid acc: 0.04\n",
      "bin i= 8000 train loss:  3.73  train acc: 0.11  val loss:  3.99  valid acc: 0.04\n",
      "bin i= 8500 train loss:  3.71  train acc: 0.11  val loss:  4.00  valid acc: 0.03\n",
      "bin i= 9000 train loss:  3.65  train acc: 0.11  val loss:  3.94  valid acc: 0.04\n",
      "bin i= 9500 train loss:  3.69  train acc: 0.12  val loss:  3.99  valid acc: 0.03\n",
      "bin i=10000 train loss:  3.68  train acc: 0.13  val loss:  3.98  valid acc: 0.03\n",
      "->  bin  layer idx: 15 , best valid accuracy: 0.05, test accuracy: 0.03\n",
      "bin i=    0 train loss: 12.84  train acc: 0.00  val loss: 52.06  valid acc: 0.00\n",
      "bin i=  500 train loss:  6.54  train acc: 0.04  val loss:  5.54  valid acc: 0.01\n",
      "bin i= 1000 train loss:  5.77  train acc: 0.04  val loss:  5.13  valid acc: 0.02\n",
      "bin i= 1500 train loss:  5.56  train acc: 0.04  val loss:  5.19  valid acc: 0.02\n",
      "bin i= 2000 train loss:  5.38  train acc: 0.05  val loss:  5.23  valid acc: 0.01\n",
      "bin i= 2500 train loss:  5.22  train acc: 0.06  val loss:  5.14  valid acc: 0.02\n",
      "bin i= 3000 train loss:  5.26  train acc: 0.03  val loss:  5.07  valid acc: 0.01\n",
      "bin i= 3500 train loss:  5.10  train acc: 0.04  val loss:  5.01  valid acc: 0.02\n",
      "bin i= 4000 train loss:  5.01  train acc: 0.04  val loss:  4.91  valid acc: 0.02\n",
      "bin i= 4500 train loss:  5.16  train acc: 0.04  val loss:  4.94  valid acc: 0.02\n",
      "bin i= 5000 train loss:  4.89  train acc: 0.05  val loss:  5.10  valid acc: 0.01\n",
      "bin i= 5500 train loss:  4.93  train acc: 0.06  val loss:  4.93  valid acc: 0.01\n",
      "bin i= 6000 train loss:  4.85  train acc: 0.05  val loss:  5.04  valid acc: 0.01\n",
      "bin i= 6500 train loss:  4.87  train acc: 0.05  val loss:  4.83  valid acc: 0.02\n",
      "bin i= 7000 train loss:  4.79  train acc: 0.05  val loss:  4.84  valid acc: 0.01\n",
      "bin i= 7500 train loss:  4.79  train acc: 0.05  val loss:  5.06  valid acc: 0.01\n",
      "bin i= 8000 train loss:  4.85  train acc: 0.05  val loss:  4.88  valid acc: 0.02\n",
      "bin i= 8500 train loss:  4.74  train acc: 0.06  val loss:  4.99  valid acc: 0.01\n",
      "bin i= 9000 train loss:  4.82  train acc: 0.05  val loss:  4.99  valid acc: 0.02\n",
      "bin i= 9500 train loss:  4.76  train acc: 0.05  val loss:  4.88  valid acc: 0.02\n",
      "bin i=10000 train loss:  4.71  train acc: 0.07  val loss:  4.89  valid acc: 0.02\n",
      "->  bin  layer idx: 16 , best valid accuracy: 0.02, test accuracy: 0.01\n"
     ]
    }
   ],
   "execution_count": 16
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T20:17:55.156343Z",
     "start_time": "2025-09-15T20:17:55.151554Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def solve_linear_layer(x: Tensor, y: Tensor) -> torch.nn.Linear:\n",
    "    if y.ndim == 1:\n",
    "        y = y.unsqueeze(-1)\n",
    "    if not y.is_floating_point():\n",
    "        y = y.float()\n",
    "   \n",
    "    lin = torch.nn.Linear(x.shape[-1], y.shape[-1], device=x.device)\n",
    "    x_aug = torch.cat([x, torch.ones(len(x), 1, device=x.device)], dim=1)\n",
    "    coeffs = torch.linalg.lstsq(x_aug, y).solution\n",
    "    w, b = coeffs[:-1], coeffs[-1]\n",
    "    with torch.no_grad():\n",
    "        lin.weight[:] = w.T\n",
    "        lin.bias[:] = b\n",
    "    return lin"
   ],
   "id": "180a6d66da68307d",
   "outputs": [],
   "execution_count": 17
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T20:20:06.775553Z",
     "start_time": "2025-09-15T20:17:55.348246Z"
    }
   },
   "cell_type": "code",
   "source": [
    "for layer_idx in range(len(train_hidden_states)):\n",
    "    lin_probe = solve_linear_layer(\n",
    "        train_hidden_states[layer_idx].float().to(device),\n",
    "        train_labels.to(device),\n",
    "    )\n",
    "    log_probe = solve_linear_layer(\n",
    "        train_hidden_states[layer_idx].float().to(device),\n",
    "        train_labels.log1p().to(device),\n",
    "    )\n",
    "    lin_test_pred = lin_probe(test_hidden_states[layer_idx].float().to(device)).flatten().round().int()\n",
    "    lin_test_accuracy = (lin_test_pred == test_labels).float().mean().item()\n",
    "    \n",
    "    log_test_pred = log_probe(test_hidden_states[layer_idx].float().to(device)).flatten().exp().add(1).round().int()\n",
    "    log_test_accuracy = (log_test_pred == test_labels).float().mean().item()\n",
    "    \n",
    "    test_accuracies[\"lin\"][layer_idx] = lin_test_accuracy\n",
    "    test_accuracies[\"log\"][layer_idx] = log_test_accuracy\n",
    "\n",
    "    print(f\"layer idx: {layer_idx:<3}, linear probe acc: {lin_test_accuracy:.2f}, log probe acc: {log_test_accuracy:.2f}\")"
   ],
   "id": "6f18c4fd0b785e1b",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "layer idx: 0  , linear probe acc: 0.00, log probe acc: 0.00\n",
      "layer idx: 1  , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 2  , linear probe acc: 0.02, log probe acc: 0.02\n",
      "layer idx: 3  , linear probe acc: 0.02, log probe acc: 0.02\n",
      "layer idx: 4  , linear probe acc: 0.02, log probe acc: 0.02\n",
      "layer idx: 5  , linear probe acc: 0.01, log probe acc: 0.02\n",
      "layer idx: 6  , linear probe acc: 0.01, log probe acc: 0.02\n",
      "layer idx: 7  , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 8  , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 9  , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 10 , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 11 , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 12 , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 13 , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 14 , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 15 , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 16 , linear probe acc: 0.00, log probe acc: 0.00\n"
     ]
    }
   ],
   "execution_count": 18
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-15T20:20:07.116504Z",
     "start_time": "2025-09-15T20:20:07.111295Z"
    }
   },
   "cell_type": "code",
   "source": [
    "for name, accs in test_accuracies.items():\n",
    "    print(f\"{name} accs: | \" + \" | \".join([f\"{x:.0%}\" for layer, x in sorted(accs.items())]) + \" |\")"
   ],
   "id": "9d2a87552870d319",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sin accs: | 89% | 98% | 98% | 99% | 100% | 100% | 99% | 99% | 99% | 99% | 99% | 99% | 99% | 99% | 99% | 98% | 37% |\n",
      "bin accs: | 22% | 16% | 15% | 13% | 6% | 3% | 2% | 2% | 2% | 1% | 2% | 2% | 2% | 2% | 2% | 3% | 1% |\n",
      "lin accs: | 0% | 1% | 2% | 2% | 2% | 1% | 1% | 1% | 1% | 1% | 1% | 1% | 1% | 1% | 1% | 1% | 0% |\n",
      "log accs: | 0% | 1% | 2% | 2% | 2% | 2% | 2% | 1% | 1% | 1% | 1% | 1% | 1% | 1% | 1% | 1% | 0% |\n"
     ]
    }
   ],
   "execution_count": 19
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "numllama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
