{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7a542467",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Accelerator used for the LLM forward passes. Defaults to cuda:7; set the PROBE_DEVICE\n",
    "# environment variable (e.g. \"cuda:0\") to run elsewhere without editing the notebook.\n",
    "device = os.environ.get(\"PROBE_DEVICE\", \"cuda:7\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b9c875c",
   "metadata": {},
   "source": [
    "### Preliminaries: imports and number-encoding helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3b08b50c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import random\n",
    "import collections\n",
    "\n",
    "\n",
    "import transformers\n",
    "import torch\n",
    "import tqdm.auto\n",
    "from torch import Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "da5bca2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sinusoidal_encode(\n",
    "    x: Tensor,\n",
    "    embedding_dim: int,\n",
    "    min_value: int,\n",
    "    max_value: int,\n",
    "    use_l2_norm: bool = False,\n",
    "    norm_const: float | None = None,\n",
    ") -> Tensor:\n",
    "    \"\"\"\n",
    "    Encodes a tensor of numbers into a sinusoidal representation, inspired by how absolute positional\n",
    "    encoding works in transformers.\n",
    "\n",
    "    The encoding is an evaluation of a sine and cosine function at different frequencies, where the\n",
    "    frequency is determined by the embedding dimension and the allowed range of the input values.\n",
    "    Inputs are shifted by ``min_value`` first, so ``min_value`` maps to ``[0, 1, 0, 1, ...]``.\n",
    "\n",
    "    Args:\n",
    "        x: Numbers to encode, any shape. The result has shape ``x.shape + (embedding_dim,)``\n",
    "            and is allocated on ``x.device``.\n",
    "        embedding_dim: Size of the last output dimension. Must be even unless ``use_l2_norm``.\n",
    "        min_value: Lower end of the value range (mapped to phase 0).\n",
    "        max_value: Upper end of the value range; must be greater than ``min_value``.\n",
    "        use_l2_norm: If True, the last 2 (even ``embedding_dim``) or 1 (odd) entries are set to 1\n",
    "            and every output vector is scaled to unit L2 norm.\n",
    "        norm_const: Optional scalar multiplied into the output after any normalisation.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``embedding_dim`` is odd without ``use_l2_norm``, or if\n",
    "            ``max_value <= min_value`` (the frequency scale would be the log of a value <= 0).\n",
    "\n",
    "    >>> sinusoidal_encode(\n",
    "    ...     torch.tensor([-5, 2, 1, 0]),\n",
    "    ...     embedding_dim=6,\n",
    "    ...     min_value=-5,\n",
    "    ...     max_value=5,\n",
    "    ... )\n",
    "    tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],\n",
    "            [ 0.6570,  0.7539, -0.1073, -0.9942,  0.9980,  0.0627],\n",
    "            [-0.2794,  0.9602,  0.3491, -0.9371,  0.9616,  0.2746],\n",
    "            [-0.9589,  0.2837,  0.7317, -0.6816,  0.8806,  0.4738]])\n",
    "    \"\"\"\n",
    "\n",
    "    if embedding_dim % 2 != 0 and not use_l2_norm:\n",
    "        raise ValueError(\"Embedding dimension must be even\")\n",
    "    if max_value <= min_value:\n",
    "        raise ValueError(\"max_value must be greater than min_value\")\n",
    "\n",
    "    # With L2 normalisation, keep the sinusoidal part even-sized and fill the rest with ones.\n",
    "    if use_l2_norm:\n",
    "        reserved_dim = 2 if embedding_dim % 2 == 0 else 1\n",
    "    else:\n",
    "        reserved_dim = 0\n",
    "    embedding_dim -= reserved_dim\n",
    "\n",
    "    domain = max_value - min_value\n",
    "    y = torch.zeros(x.shape + (embedding_dim,), device=x.device)\n",
    "    # Build the frequency table on x's device; a CPU table cannot be multiplied with CUDA inputs.\n",
    "    even_indices = torch.arange(0, embedding_dim, 2, device=x.device)\n",
    "    log_term = torch.log(torch.tensor(domain, device=x.device)) / embedding_dim\n",
    "    div_term = torch.exp(even_indices * -log_term)\n",
    "    values = (x - min_value).unsqueeze(-1).float() * div_term\n",
    "    y[..., 0::2] = torch.sin(values)\n",
    "    y[..., 1::2] = torch.cos(values)\n",
    "\n",
    "    if use_l2_norm:\n",
    "        # Allocate the padding explicitly: slicing y[..., :reserved_dim] comes back too narrow when\n",
    "        # the sinusoidal part has fewer than reserved_dim columns (e.g. embedding_dim=2).\n",
    "        y = torch.cat([y, y.new_ones((*y.shape[:-1], reserved_dim))], dim=-1)\n",
    "        y /= y.norm(dim=-1, keepdim=True, p=2)\n",
    "\n",
    "    if norm_const is not None:\n",
    "        y *= norm_const\n",
    "\n",
    "    return y\n",
    "\n",
    "def binary_encode(\n",
    "    x: Tensor,\n",
    "    embedding_dim: int,\n",
    "    min_value: int | float,\n",
    "    max_value: int | float,\n",
    "    use_l2_norm: bool = False,\n",
    "    norm_const: float | None = None,\n",
    ") -> Tensor:\n",
    "    \"\"\"\n",
    "    Encodes a tensor of numbers as fixed-width binary vectors, most significant bit first.\n",
    "\n",
    "    Each value is shifted by ``min_value`` and the floor of the result is written in base 2, with the\n",
    "    least significant bit in the last bit position. Only the low bits that fit are kept, so a shifted\n",
    "    value >= 2**num_bits wraps around. Assumes ``x >= min_value``.\n",
    "\n",
    "    Args:\n",
    "        x: Numbers to encode, any shape.\n",
    "        embedding_dim: Size of the last output dimension (all bits, plus one reserved entry when\n",
    "            ``use_l2_norm`` is set).\n",
    "        min_value: Offset subtracted from ``x`` before encoding.\n",
    "        max_value: Accepted for signature parity with ``sinusoidal_encode``; not used.\n",
    "        use_l2_norm: If True, the last entry is set to 1 and every vector is scaled to unit L2 norm.\n",
    "        norm_const: Optional scalar multiplied into the output after any normalisation.\n",
    "\n",
    "    Returns:\n",
    "        Float tensor of shape ``x.shape + (embedding_dim,)`` on ``x.device``.\n",
    "    \"\"\"\n",
    "    reserve_dim = 1 if use_l2_norm else 0\n",
    "    # Only allocate the bit columns; the reserved column is appended below, so the output width is\n",
    "    # exactly embedding_dim (previously it was embedding_dim + 1 with a dead zero column).\n",
    "    num_bits = embedding_dim - reserve_dim\n",
    "    y = torch.zeros(x.shape + (num_bits,), device=x.device)\n",
    "    x = x - min_value\n",
    "    if x.numel() > 0:  # x.max() raises on empty tensors\n",
    "        maximum = x.max()\n",
    "        for i in range(num_bits):\n",
    "            coeff = 2**i\n",
    "            if maximum < coeff:\n",
    "                break  # every remaining (higher) bit is zero for all values\n",
    "            y[..., -i - 1] = torch.floor(x / coeff) % 2\n",
    "            x = x - coeff * y[..., -i - 1]\n",
    "    if use_l2_norm:\n",
    "        y = torch.cat([y, y.new_ones((*y.shape[:-1], reserve_dim))], dim=-1)\n",
    "        y /= y.norm(dim=-1, keepdim=True, p=2)\n",
    "    if norm_const is not None:\n",
    "        y *= norm_const\n",
    "    return y"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ff97ee3",
   "metadata": {},
   "source": [
    "### Prepare model and data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8723d3d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_ckpt = \"meta-llama/Llama-3.2-1B\"\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(model_ckpt)\n",
    "# Load the weights directly in fp16 rather than loading fp32 and calling .half(), which first\n",
    "# materialises a full fp32 copy in host memory. Only hidden states are read downstream, so the\n",
    "# base AutoModel (no LM head) is sufficient.\n",
    "model = transformers.AutoModel.from_pretrained(model_ckpt, torch_dtype=torch.float16).to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9b394470",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One seeded uniform draw per value decides its split: [0, 0.9) train, [0.9, 0.95) valid, [0.95, 1) test.\n",
    "all_values = torch.arange(1000)\n",
    "mask = torch.rand(all_values.numel(), generator=torch.Generator().manual_seed(0))\n",
    "train_mask = mask < 0.9\n",
    "test_mask = mask >= 0.95\n",
    "valid_mask = ~(train_mask | test_mask)\n",
    "\n",
    "train_values, valid_values, test_values = (all_values[split] for split in (train_mask, valid_mask, test_mask))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c15ab591",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every ordered pair (x1, x2) drawn from 0..999 whose sum stays below 1000.\n",
    "all_inputs = [pair for pair in itertools.product(all_values.tolist(), repeat=2) if sum(pair) < 1000]\n",
    "train_values_set, valid_values_set, test_values_set = (\n",
    "    set(split_values.tolist()) for split_values in (train_values, valid_values, test_values)\n",
    ")\n",
    "\n",
    "# A pair's split is decided by its second operand alone; the first operand is unrestricted.\n",
    "train_inputs = [(a, b) for a, b in all_inputs if b in train_values_set]\n",
    "valid_inputs = [(a, b) for a, b in all_inputs if b in valid_values_set]\n",
    "test_inputs = [(a, b) for a, b in all_inputs if b in test_values_set]\n",
    "\n",
    "# sanity check: no pair may land in two splits\n",
    "assert all(\n",
    "    set(first).isdisjoint(second)\n",
    "    for first, second in itertools.combinations((train_inputs, valid_inputs, test_inputs), 2)\n",
    ")\n",
    "\n",
    "# Seed the global RNG once and shuffle the splits in a fixed order so the subsets below are reproducible.\n",
    "random.seed(0)\n",
    "random.shuffle(train_inputs)\n",
    "random.shuffle(valid_inputs)\n",
    "random.shuffle(test_inputs)\n",
    "\n",
    "train_size = 100_000\n",
    "valid_size = 4096\n",
    "train_inputs = train_inputs[:train_size]\n",
    "valid_inputs = valid_inputs[:valid_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c41d1f88",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('3 + 500', '3 + 0')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def make_str_input(operands: tuple[int, int] | list[int]) -> str:\n",
    "    \"\"\"Formats an operand pair as an addition prompt, e.g. ``(3, 500) -> \"3 + 500\"``.\"\"\"\n",
    "    left, right = operands  # unpacking enforces exactly two operands\n",
    "    return \" + \".join((f\"{left}\", f\"{right}\"))\n",
    "\n",
    "make_str_input((3, 500)), make_str_input((3, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "36dd2f3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_hidden_states(model, str_inputs: list[str], batch_size: int) -> dict[int, Tensor]:\n",
    "    \"\"\"\n",
    "    Runs ``model`` over ``str_inputs`` in batches and keeps each layer's hidden state at the last token.\n",
    "\n",
    "    Tokenizes with the notebook-level ``tokenizer`` and no padding, so every string within a batch\n",
    "    must tokenize to the same length (otherwise building the batch tensor fails); this also means\n",
    "    position -1 is always a real token rather than padding.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping the index into the model's ``hidden_states`` tuple to a CPU tensor of shape\n",
    "        ``(len(str_inputs), hidden_dim)``. Empty if ``str_inputs`` is empty.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    per_layer: collections.defaultdict[int, list[Tensor]] = collections.defaultdict(list)\n",
    "    num_batches = (len(str_inputs) + batch_size - 1) // batch_size\n",
    "    with torch.no_grad():\n",
    "        for batch_str in tqdm.auto.tqdm(itertools.batched(str_inputs, n=batch_size), total=num_batches):\n",
    "            batch_inputs = tokenizer(batch_str, return_tensors=\"pt\")\n",
    "            hidden_reprs = model(**batch_inputs.to(model.device), output_hidden_states=True).hidden_states\n",
    "            for layer_idx, hidden_state in enumerate(hidden_reprs):\n",
    "                # Keep whole (batch, hidden) slices and concatenate once at the end, instead of\n",
    "                # splitting every batch into per-example tensors and stacking those.\n",
    "                per_layer[layer_idx].append(hidden_state[:, -1, :].cpu())\n",
    "    return {layer_idx: torch.cat(chunks) for layer_idx, chunks in per_layer.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "da87aca4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9a8dd26122f84877b06895cdf0eb9eb2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/98 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ed62f9697fd6432d984de595d1f4ae87",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0d13e38f7c8c47b7ad08779d04a99507",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "batch_size = 1024\n",
    "\n",
    "\n",
    "def hidden_states_for(operand_pairs):\n",
    "    \"\"\"Last-token hidden states of every layer for a list of (x1, x2) operand pairs.\"\"\"\n",
    "    return get_hidden_states(model, list(map(make_str_input, operand_pairs)), batch_size)\n",
    "\n",
    "\n",
    "train_hidden_states = hidden_states_for(train_inputs)\n",
    "valid_hidden_states = hidden_states_for(valid_inputs)\n",
    "test_hidden_states = hidden_states_for(test_inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6288a397",
   "metadata": {},
   "source": [
    "### Probing: train a classifier probe on each layer's hidden states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "cff73d20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encodings of the integers 0..999 (one row per value) in two schemes, presumably used as the\n",
    "# candidate `basis` for the ClassifierProbe below.\n",
    "basis_embs_sin, basis_embs_bin = (\n",
    "    sinusoidal_encode(\n",
    "        torch.arange(1000),\n",
    "        embedding_dim=train_hidden_states[0].shape[-1],  # same width as the model's hidden states\n",
    "        min_value=0,\n",
    "        max_value=1000,\n",
    "    ),\n",
    "    binary_encode(torch.arange(1000), embedding_dim=10, min_value=0, max_value=1000),  # 10 bits cover 0..1023\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "184992d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ClassifierProbe(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Scores embeddings against a fixed table of candidate encodings (the ``basis``).\n",
    "\n",
    "    The input embedding and every basis row are each projected linearly into a ``hidden_dim``\n",
    "    latent space; the logit for candidate ``j`` is the dot product of the two projections.\n",
    "    ``basis`` and ``heldout_mask`` are registered as buffers, so they move with the module but\n",
    "    are not trained.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, emb_dim: int, hidden_dim: int, basis: torch.Tensor, heldout_mask: torch.Tensor):\n",
    "        super().__init__()\n",
    "        # Keep this creation order: it fixes which random draws each layer's init consumes.\n",
    "        self.emb_to_latent = torch.nn.Linear(emb_dim, hidden_dim, bias=True)\n",
    "        self.basis_to_latent = torch.nn.Linear(basis.shape[-1], hidden_dim, bias=True)\n",
    "        self.basis: torch.nn.Buffer\n",
    "        self.heldout_mask: torch.nn.Buffer\n",
    "        for buffer_name, buffer_value in ((\"basis\", basis), (\"heldout_mask\", heldout_mask)):\n",
    "            self.register_buffer(buffer_name, buffer_value)\n",
    "\n",
    "    def forward(self, x: Tensor, holdout_eval_tokens: bool) -> Tensor:\n",
    "        \"\"\"\n",
    "        Returns logits of shape ``(batch, num_candidates)``; ``x`` is expected to be 2-D\n",
    "        ``(batch, emb_dim)`` because the masking indexes dimension 1.\n",
    "\n",
    "        With ``holdout_eval_tokens=True`` the candidates flagged in ``heldout_mask`` get ``-inf``\n",
    "        logits, so only the remaining candidates can be chosen. Intended use: pass True during\n",
    "        training so the probe only learns to choose among training values, and False during\n",
    "        evaluation so it must choose among all values.\n",
    "        \"\"\"\n",
    "        query = self.emb_to_latent(x)\n",
    "        candidates = self.basis_to_latent(self.basis)\n",
    "        scores = query @ candidates.T\n",
    "        if not holdout_eval_tokens:\n",
    "            return scores\n",
    "        scores[:, self.heldout_mask] = float(\"-inf\")\n",
    "        return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2d3d2c7b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sin i=    0 train loss: 11.39  train acc: 0.00  val loss:  6.72  valid acc: 0.00\n",
      "sin i=  500 train loss:  2.23  train acc: 0.99  val loss:  1.36  valid acc: 0.83\n",
      "sin i= 1000 train loss:  1.83  train acc: 0.99  val loss:  0.94  valid acc: 0.92\n",
      "sin i= 1500 train loss:  1.66  train acc: 1.00  val loss:  0.78  valid acc: 0.88\n",
      "sin i= 2000 train loss:  1.56  train acc: 0.99  val loss:  0.74  valid acc: 0.89\n",
      "sin i= 2500 train loss:  1.50  train acc: 0.99  val loss:  0.73  valid acc: 0.84\n",
      "sin i= 3000 train loss:  1.41  train acc: 1.00  val loss:  0.71  valid acc: 0.84\n",
      "sin i= 3500 train loss:  1.35  train acc: 1.00  val loss:  0.71  valid acc: 0.82\n",
      "sin i= 4000 train loss:  1.32  train acc: 1.00  val loss:  0.71  valid acc: 0.82\n",
      "sin i= 4500 train loss:  1.28  train acc: 1.00  val loss:  0.72  valid acc: 0.80\n",
      "sin i= 5000 train loss:  1.26  train acc: 1.00  val loss:  0.68  valid acc: 0.80\n",
      "sin i= 5500 train loss:  1.22  train acc: 1.00  val loss:  0.69  valid acc: 0.82\n",
      "sin i= 6000 train loss:  1.18  train acc: 0.99  val loss:  0.69  valid acc: 0.83\n",
      "sin i= 6500 train loss:  1.17  train acc: 1.00  val loss:  0.68  valid acc: 0.83\n",
      "sin i= 7000 train loss:  1.14  train acc: 1.00  val loss:  0.67  valid acc: 0.85\n",
      "sin i= 7500 train loss:  1.13  train acc: 1.00  val loss:  0.65  valid acc: 0.85\n",
      "sin i= 8000 train loss:  1.11  train acc: 1.00  val loss:  0.64  valid acc: 0.83\n",
      "sin i= 8500 train loss:  1.09  train acc: 1.00  val loss:  0.62  valid acc: 0.86\n",
      "sin i= 9000 train loss:  1.08  train acc: 1.00  val loss:  0.61  valid acc: 0.83\n",
      "sin i= 9500 train loss:  1.06  train acc: 1.00  val loss:  0.61  valid acc: 0.83\n",
      "sin i=10000 train loss:  1.04  train acc: 1.00  val loss:  0.58  valid acc: 0.86\n",
      "->  sin  layer idx: 0  , best valid accuracy: 0.92, test accuracy: 0.89\n",
      "sin i=    0 train loss: 11.40  train acc: 0.00  val loss:  6.48  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.38  train acc: 1.00  val loss:  0.47  valid acc: 1.00\n",
      "sin i= 1000 train loss:  1.07  train acc: 1.00  val loss:  0.29  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.91  train acc: 1.00  val loss:  0.22  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.79  train acc: 1.00  val loss:  0.18  valid acc: 0.99\n",
      "sin i= 2500 train loss:  0.70  train acc: 1.00  val loss:  0.16  valid acc: 0.98\n",
      "sin i= 3000 train loss:  0.62  train acc: 1.00  val loss:  0.14  valid acc: 0.97\n",
      "sin i= 3500 train loss:  0.56  train acc: 1.00  val loss:  0.13  valid acc: 0.97\n",
      "sin i= 4000 train loss:  0.53  train acc: 1.00  val loss:  0.12  valid acc: 0.98\n",
      "sin i= 4500 train loss:  0.51  train acc: 1.00  val loss:  0.12  valid acc: 0.98\n",
      "sin i= 5000 train loss:  0.49  train acc: 1.00  val loss:  0.13  valid acc: 0.96\n",
      "sin i= 5500 train loss:  0.48  train acc: 1.00  val loss:  0.14  valid acc: 0.96\n",
      "sin i= 6000 train loss:  0.47  train acc: 1.00  val loss:  0.15  valid acc: 0.96\n",
      "sin i= 6500 train loss:  0.46  train acc: 1.00  val loss:  0.17  valid acc: 0.96\n",
      "sin i= 7000 train loss:  0.46  train acc: 1.00  val loss:  0.15  valid acc: 0.96\n",
      "sin i= 7500 train loss:  0.45  train acc: 1.00  val loss:  0.15  valid acc: 0.96\n",
      "sin i= 8000 train loss:  0.45  train acc: 1.00  val loss:  0.16  valid acc: 0.97\n",
      "sin i= 8500 train loss:  0.44  train acc: 1.00  val loss:  0.14  valid acc: 0.96\n",
      "sin i= 9000 train loss:  0.44  train acc: 1.00  val loss:  0.15  valid acc: 0.97\n",
      "sin i= 9500 train loss:  0.44  train acc: 1.00  val loss:  0.14  valid acc: 0.96\n",
      "sin i=10000 train loss:  0.43  train acc: 1.00  val loss:  0.13  valid acc: 0.97\n",
      "->  sin  layer idx: 1  , best valid accuracy: 1.00, test accuracy: 0.96\n",
      "sin i=    0 train loss: 11.43  train acc: 0.00  val loss:  6.50  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.31  train acc: 1.00  val loss:  0.43  valid acc: 1.00\n",
      "sin i= 1000 train loss:  1.02  train acc: 1.00  val loss:  0.26  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.87  train acc: 1.00  val loss:  0.20  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.76  train acc: 1.00  val loss:  0.16  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.68  train acc: 1.00  val loss:  0.14  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.62  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.58  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.55  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.53  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.52  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.50  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.49  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.47  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.47  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.46  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.45  train acc: 1.00  val loss:  0.09  valid acc: 0.99\n",
      "sin i= 8500 train loss:  0.44  train acc: 1.00  val loss:  0.10  valid acc: 0.99\n",
      "sin i= 9000 train loss:  0.44  train acc: 1.00  val loss:  0.10  valid acc: 0.98\n",
      "sin i= 9500 train loss:  0.43  train acc: 1.00  val loss:  0.10  valid acc: 0.99\n",
      "sin i=10000 train loss:  0.42  train acc: 1.00  val loss:  0.10  valid acc: 0.98\n",
      "->  sin  layer idx: 2  , best valid accuracy: 1.00, test accuracy: 1.00\n",
      "sin i=    0 train loss: 11.43  train acc: 0.00  val loss:  6.59  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.30  train acc: 1.00  val loss:  0.43  valid acc: 1.00\n",
      "sin i= 1000 train loss:  1.00  train acc: 1.00  val loss:  0.25  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.84  train acc: 1.00  val loss:  0.18  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.72  train acc: 1.00  val loss:  0.14  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.64  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.58  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.54  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.52  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.50  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.49  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.48  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.46  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.45  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.44  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.44  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.43  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.42  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.42  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.41  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.40  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "->  sin  layer idx: 3  , best valid accuracy: 1.00, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.41  train acc: 0.00  val loss:  6.51  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.20  train acc: 1.00  val loss:  0.38  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.93  train acc: 1.00  val loss:  0.23  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.77  train acc: 1.00  val loss:  0.17  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.68  train acc: 1.00  val loss:  0.14  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.61  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.56  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.53  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.50  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.48  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.47  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.45  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.44  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.43  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.42  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.41  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.40  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.39  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.39  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.38  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.38  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "->  sin  layer idx: 4  , best valid accuracy: 1.00, test accuracy: 1.00\n",
      "sin i=    0 train loss: 11.42  train acc: 0.00  val loss:  6.63  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.19  train acc: 1.00  val loss:  0.38  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.92  train acc: 1.00  val loss:  0.23  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.79  train acc: 1.00  val loss:  0.18  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.69  train acc: 1.00  val loss:  0.15  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.62  train acc: 1.00  val loss:  0.14  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.57  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.53  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.50  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.48  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.46  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.45  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.43  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.42  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.41  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.40  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.40  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.39  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.38  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.37  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.37  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "->  sin  layer idx: 5  , best valid accuracy: 1.00, test accuracy: 1.00\n",
      "sin i=    0 train loss: 11.41  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.19  train acc: 1.00  val loss:  0.38  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.92  train acc: 1.00  val loss:  0.23  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.77  train acc: 1.00  val loss:  0.18  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.67  train acc: 1.00  val loss:  0.15  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.60  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.56  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.52  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.49  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.48  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.47  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.45  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.44  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.43  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.42  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.41  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.40  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.40  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.39  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.38  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.38  train acc: 1.00  val loss:  0.08  valid acc: 0.99\n",
      "->  sin  layer idx: 6  , best valid accuracy: 1.00, test accuracy: 1.00\n",
      "sin i=    0 train loss: 11.40  train acc: 0.00  val loss:  6.96  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.19  train acc: 1.00  val loss:  0.39  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.92  train acc: 1.00  val loss:  0.23  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.77  train acc: 1.00  val loss:  0.17  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.67  train acc: 1.00  val loss:  0.14  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.60  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.55  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.50  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.48  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.46  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.45  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.44  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.43  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.41  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.41  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.40  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.39  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.38  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.37  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.37  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.36  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "->  sin  layer idx: 7  , best valid accuracy: 1.00, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.39  train acc: 0.00  val loss:  7.09  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.15  train acc: 1.00  val loss:  0.35  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.89  train acc: 1.00  val loss:  0.21  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.74  train acc: 1.00  val loss:  0.16  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.65  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.59  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.53  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.49  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.47  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.45  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.44  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.42  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.41  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.40  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.39  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.38  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.38  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.37  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.37  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.36  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.36  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "->  sin  layer idx: 8  , best valid accuracy: 1.00, test accuracy: 0.98\n",
      "sin i=    0 train loss: 11.39  train acc: 0.00  val loss:  7.05  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.12  train acc: 1.00  val loss:  0.33  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.86  train acc: 1.00  val loss:  0.19  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.71  train acc: 1.00  val loss:  0.15  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.61  train acc: 1.00  val loss:  0.12  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.55  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.50  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.47  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.45  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.43  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.42  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.41  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.39  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.39  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.38  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.38  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.37  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.37  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.37  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.37  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.36  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "->  sin  layer idx: 9  , best valid accuracy: 1.00, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.45  train acc: 0.00  val loss:  7.18  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.06  train acc: 1.00  val loss:  0.30  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.81  train acc: 1.00  val loss:  0.18  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.66  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.56  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.50  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.45  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.41  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.39  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.37  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.36  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.35  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.34  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.33  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.33  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.33  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.32  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.32  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.31  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.31  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.31  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "->  sin  layer idx: 10 , best valid accuracy: 1.00, test accuracy: 1.00\n",
      "sin i=    0 train loss: 11.50  train acc: 0.00  val loss:  7.66  valid acc: 0.00\n",
      "sin i=  500 train loss:  0.96  train acc: 1.00  val loss:  0.25  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.72  train acc: 1.00  val loss:  0.15  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.59  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.51  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.44  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.39  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.37  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.35  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.34  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.33  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.32  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.31  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.30  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.30  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.30  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.29  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.29  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.29  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.29  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.29  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "->  sin  layer idx: 11 , best valid accuracy: 1.00, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.52  train acc: 0.00  val loss:  8.12  valid acc: 0.00\n",
      "sin i=  500 train loss:  0.93  train acc: 1.00  val loss:  0.26  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.69  train acc: 1.00  val loss:  0.14  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.57  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.48  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.42  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.38  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.36  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.34  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.33  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.32  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.31  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.30  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.30  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.29  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.29  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.28  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.28  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.28  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.28  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.28  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "->  sin  layer idx: 12 , best valid accuracy: 1.00, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.63  train acc: 0.00  val loss:  9.47  valid acc: 0.00\n",
      "sin i=  500 train loss:  0.96  train acc: 1.00  val loss:  0.30  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.67  train acc: 1.00  val loss:  0.15  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.55  train acc: 1.00  val loss:  0.10  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.48  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.42  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.38  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.35  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.34  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.32  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.32  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.34  train acc: 1.00  val loss:  0.11  valid acc: 0.97\n",
      "sin i= 6000 train loss:  0.30  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.29  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.29  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.29  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 8000 train loss:  0.28  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 8500 train loss:  0.28  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.28  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.28  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.28  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "->  sin  layer idx: 13 , best valid accuracy: 1.00, test accuracy: 0.99\n",
      "sin i=    0 train loss: 11.78  train acc: 0.00  val loss:  9.97  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.02  train acc: 1.00  val loss:  0.33  valid acc: 1.00\n",
      "sin i= 1000 train loss:  0.67  train acc: 1.00  val loss:  0.15  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.55  train acc: 1.00  val loss:  0.11  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.48  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.43  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.40  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.36  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.34  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.33  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.31  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.30  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.29  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.28  train acc: 1.00  val loss:  0.05  valid acc: 0.99\n",
      "sin i= 7000 train loss:  0.28  train acc: 1.00  val loss:  0.06  valid acc: 0.99\n",
      "sin i= 7500 train loss:  0.27  train acc: 1.00  val loss:  0.05  valid acc: 0.99\n",
      "sin i= 8000 train loss:  0.27  train acc: 1.00  val loss:  0.05  valid acc: 0.99\n",
      "sin i= 8500 train loss:  0.26  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 9000 train loss:  0.26  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.25  train acc: 1.00  val loss:  0.05  valid acc: 0.99\n",
      "sin i=10000 train loss:  0.25  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "->  sin  layer idx: 14 , best valid accuracy: 1.00, test accuracy: 0.96\n",
      "sin i=    0 train loss: 11.98  train acc: 0.00  val loss: 12.30  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.38  train acc: 1.00  val loss:  0.48  valid acc: 0.99\n",
      "sin i= 1000 train loss:  0.74  train acc: 1.00  val loss:  0.20  valid acc: 1.00\n",
      "sin i= 1500 train loss:  0.57  train acc: 1.00  val loss:  0.13  valid acc: 1.00\n",
      "sin i= 2000 train loss:  0.49  train acc: 1.00  val loss:  0.09  valid acc: 1.00\n",
      "sin i= 2500 train loss:  0.43  train acc: 1.00  val loss:  0.08  valid acc: 1.00\n",
      "sin i= 3000 train loss:  0.40  train acc: 1.00  val loss:  0.07  valid acc: 1.00\n",
      "sin i= 3500 train loss:  0.37  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 4000 train loss:  0.35  train acc: 1.00  val loss:  0.06  valid acc: 1.00\n",
      "sin i= 4500 train loss:  0.33  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 5000 train loss:  0.32  train acc: 1.00  val loss:  0.05  valid acc: 1.00\n",
      "sin i= 5500 train loss:  0.31  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 6000 train loss:  0.29  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 6500 train loss:  0.28  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 7000 train loss:  0.27  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i= 7500 train loss:  0.27  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 8000 train loss:  2.61  train acc: 1.00  val loss:  0.27  valid acc: 0.94\n",
      "sin i= 8500 train loss:  1.23  train acc: 1.00  val loss:  0.06  valid acc: 0.99\n",
      "sin i= 9000 train loss:  0.69  train acc: 1.00  val loss:  0.04  valid acc: 1.00\n",
      "sin i= 9500 train loss:  0.45  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "sin i=10000 train loss:  0.33  train acc: 1.00  val loss:  0.03  valid acc: 1.00\n",
      "->  sin  layer idx: 15 , best valid accuracy: 1.00, test accuracy: 0.99\n",
      "sin i=    0 train loss: 19.28  train acc: 0.00  val loss: 42.22  valid acc: 0.00\n",
      "sin i=  500 train loss:  5.37  train acc: 0.81  val loss:  2.27  valid acc: 0.23\n",
      "sin i= 1000 train loss:  3.95  train acc: 0.97  val loss:  1.33  valid acc: 0.69\n",
      "sin i= 1500 train loss:  2.86  train acc: 0.99  val loss:  0.94  valid acc: 0.82\n",
      "sin i= 2000 train loss:  2.03  train acc: 0.99  val loss:  0.72  valid acc: 0.90\n",
      "sin i= 2500 train loss:  1.46  train acc: 1.00  val loss:  0.58  valid acc: 0.94\n",
      "sin i= 3000 train loss:  1.07  train acc: 1.00  val loss:  0.43  valid acc: 0.97\n",
      "sin i= 3500 train loss: 10.44  train acc: 0.53  val loss: 10.38  valid acc: 0.40\n",
      "sin i= 4000 train loss:  5.98  train acc: 0.94  val loss:  1.99  valid acc: 0.77\n",
      "sin i= 4500 train loss:  5.29  train acc: 0.98  val loss:  1.54  valid acc: 0.82\n",
      "sin i= 5000 train loss:  4.71  train acc: 0.99  val loss:  1.29  valid acc: 0.84\n",
      "sin i= 5500 train loss:  5.14  train acc: 0.98  val loss:  1.62  valid acc: 0.85\n",
      "sin i= 6000 train loss:  4.40  train acc: 0.99  val loss:  0.87  valid acc: 0.88\n",
      "sin i= 6500 train loss:  3.74  train acc: 1.00  val loss:  0.80  valid acc: 0.89\n",
      "sin i= 7000 train loss:  3.15  train acc: 1.00  val loss:  0.58  valid acc: 0.89\n",
      "sin i= 7500 train loss:  5.88  train acc: 0.95  val loss:  1.24  valid acc: 0.76\n",
      "sin i= 8000 train loss:  5.11  train acc: 0.99  val loss:  0.72  valid acc: 0.85\n",
      "sin i= 8500 train loss:  4.47  train acc: 0.99  val loss:  0.61  valid acc: 0.87\n",
      "sin i= 9000 train loss:  5.13  train acc: 0.98  val loss:  1.40  valid acc: 0.78\n",
      "sin i= 9500 train loss:  4.23  train acc: 1.00  val loss:  1.18  valid acc: 0.85\n",
      "sin i=10000 train loss:  3.56  train acc: 1.00  val loss:  1.10  valid acc: 0.84\n",
      "->  sin  layer idx: 16 , best valid accuracy: 0.97, test accuracy: 0.86\n",
      "bin i=    0 train loss:  9.28  train acc: 0.00  val loss:  6.87  valid acc: 0.00\n",
      "bin i=  500 train loss:  5.10  train acc: 0.26  val loss:  4.03  valid acc: 0.10\n",
      "bin i= 1000 train loss:  4.23  train acc: 0.48  val loss:  3.31  valid acc: 0.11\n",
      "bin i= 1500 train loss:  3.71  train acc: 0.69  val loss:  3.08  valid acc: 0.21\n",
      "bin i= 2000 train loss:  3.40  train acc: 0.83  val loss:  2.98  valid acc: 0.23\n",
      "bin i= 2500 train loss:  3.12  train acc: 0.89  val loss:  2.95  valid acc: 0.17\n",
      "bin i= 3000 train loss:  2.91  train acc: 0.93  val loss:  2.93  valid acc: 0.22\n",
      "bin i= 3500 train loss:  2.76  train acc: 0.96  val loss:  2.91  valid acc: 0.24\n",
      "bin i= 4000 train loss:  2.62  train acc: 0.97  val loss:  2.93  valid acc: 0.24\n",
      "bin i= 4500 train loss:  2.53  train acc: 0.98  val loss:  2.92  valid acc: 0.21\n",
      "bin i= 5000 train loss:  2.46  train acc: 0.97  val loss:  2.93  valid acc: 0.24\n",
      "bin i= 5500 train loss:  2.39  train acc: 0.98  val loss:  2.97  valid acc: 0.27\n",
      "bin i= 6000 train loss:  2.35  train acc: 0.97  val loss:  3.00  valid acc: 0.27\n",
      "bin i= 6500 train loss:  2.27  train acc: 0.98  val loss:  3.00  valid acc: 0.32\n",
      "bin i= 7000 train loss:  2.23  train acc: 0.98  val loss:  3.07  valid acc: 0.29\n",
      "bin i= 7500 train loss:  2.18  train acc: 0.99  val loss:  3.07  valid acc: 0.29\n",
      "bin i= 8000 train loss:  2.15  train acc: 0.99  val loss:  3.09  valid acc: 0.29\n",
      "bin i= 8500 train loss:  2.12  train acc: 0.99  val loss:  3.14  valid acc: 0.30\n",
      "bin i= 9000 train loss:  2.08  train acc: 0.99  val loss:  3.16  valid acc: 0.28\n",
      "bin i= 9500 train loss:  2.03  train acc: 0.99  val loss:  3.22  valid acc: 0.30\n",
      "bin i=10000 train loss:  2.01  train acc: 1.00  val loss:  3.21  valid acc: 0.30\n",
      "->  bin  layer idx: 0  , best valid accuracy: 0.32, test accuracy: 0.41\n",
      "bin i=    0 train loss:  9.28  train acc: 0.00  val loss:  6.76  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.36  train acc: 0.24  val loss:  3.70  valid acc: 0.10\n",
      "bin i= 1000 train loss:  3.73  train acc: 0.50  val loss:  3.31  valid acc: 0.14\n",
      "bin i= 1500 train loss:  3.28  train acc: 0.74  val loss:  3.16  valid acc: 0.15\n",
      "bin i= 2000 train loss:  3.00  train acc: 0.85  val loss:  3.13  valid acc: 0.16\n",
      "bin i= 2500 train loss:  2.74  train acc: 0.89  val loss:  3.18  valid acc: 0.18\n",
      "bin i= 3000 train loss:  2.52  train acc: 0.93  val loss:  3.24  valid acc: 0.12\n",
      "bin i= 3500 train loss:  2.37  train acc: 0.95  val loss:  3.32  valid acc: 0.17\n",
      "bin i= 4000 train loss:  2.25  train acc: 0.97  val loss:  3.41  valid acc: 0.16\n",
      "bin i= 4500 train loss:  2.18  train acc: 0.96  val loss:  3.47  valid acc: 0.15\n",
      "bin i= 5000 train loss:  2.09  train acc: 0.97  val loss:  3.53  valid acc: 0.17\n",
      "bin i= 5500 train loss:  2.03  train acc: 0.97  val loss:  3.58  valid acc: 0.17\n",
      "bin i= 6000 train loss:  1.96  train acc: 0.98  val loss:  3.68  valid acc: 0.19\n",
      "bin i= 6500 train loss:  1.90  train acc: 0.98  val loss:  3.67  valid acc: 0.19\n",
      "bin i= 7000 train loss:  1.86  train acc: 0.98  val loss:  3.81  valid acc: 0.18\n",
      "bin i= 7500 train loss:  1.80  train acc: 0.99  val loss:  3.84  valid acc: 0.19\n",
      "bin i= 8000 train loss:  1.77  train acc: 0.99  val loss:  3.89  valid acc: 0.20\n",
      "bin i= 8500 train loss:  1.73  train acc: 0.99  val loss:  3.93  valid acc: 0.21\n",
      "bin i= 9000 train loss:  1.70  train acc: 0.99  val loss:  4.00  valid acc: 0.19\n",
      "bin i= 9500 train loss:  1.64  train acc: 0.99  val loss:  4.07  valid acc: 0.20\n",
      "bin i=10000 train loss:  1.61  train acc: 1.00  val loss:  4.12  valid acc: 0.21\n",
      "->  bin  layer idx: 1  , best valid accuracy: 0.21, test accuracy: 0.15\n",
      "bin i=    0 train loss:  9.28  train acc: 0.00  val loss:  6.71  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.32  train acc: 0.22  val loss:  3.69  valid acc: 0.11\n",
      "bin i= 1000 train loss:  3.77  train acc: 0.44  val loss:  3.34  valid acc: 0.14\n",
      "bin i= 1500 train loss:  3.34  train acc: 0.65  val loss:  3.12  valid acc: 0.18\n",
      "bin i= 2000 train loss:  3.07  train acc: 0.74  val loss:  2.99  valid acc: 0.20\n",
      "bin i= 2500 train loss:  2.81  train acc: 0.82  val loss:  2.96  valid acc: 0.21\n",
      "bin i= 3000 train loss:  2.58  train acc: 0.88  val loss:  2.92  valid acc: 0.23\n",
      "bin i= 3500 train loss:  2.41  train acc: 0.91  val loss:  2.97  valid acc: 0.21\n",
      "bin i= 4000 train loss:  2.31  train acc: 0.93  val loss:  2.97  valid acc: 0.22\n",
      "bin i= 4500 train loss:  2.23  train acc: 0.94  val loss:  2.99  valid acc: 0.25\n",
      "bin i= 5000 train loss:  2.15  train acc: 0.95  val loss:  3.03  valid acc: 0.23\n",
      "bin i= 5500 train loss:  2.11  train acc: 0.95  val loss:  3.04  valid acc: 0.24\n",
      "bin i= 6000 train loss:  2.03  train acc: 0.96  val loss:  3.10  valid acc: 0.25\n",
      "bin i= 6500 train loss:  1.97  train acc: 0.95  val loss:  2.99  valid acc: 0.28\n",
      "bin i= 7000 train loss:  1.92  train acc: 0.96  val loss:  3.09  valid acc: 0.27\n",
      "bin i= 7500 train loss:  1.86  train acc: 0.97  val loss:  3.13  valid acc: 0.27\n",
      "bin i= 8000 train loss:  1.84  train acc: 0.97  val loss:  3.09  valid acc: 0.29\n",
      "bin i= 8500 train loss:  1.78  train acc: 0.98  val loss:  3.13  valid acc: 0.29\n",
      "bin i= 9000 train loss:  1.75  train acc: 0.98  val loss:  3.09  valid acc: 0.30\n",
      "bin i= 9500 train loss:  1.73  train acc: 0.98  val loss:  3.19  valid acc: 0.29\n",
      "bin i=10000 train loss:  1.72  train acc: 0.98  val loss:  3.16  valid acc: 0.29\n",
      "->  bin  layer idx: 2  , best valid accuracy: 0.30, test accuracy: 0.12\n",
      "bin i=    0 train loss:  9.29  train acc: 0.00  val loss:  6.67  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.31  train acc: 0.20  val loss:  3.78  valid acc: 0.11\n",
      "bin i= 1000 train loss:  3.84  train acc: 0.35  val loss:  3.48  valid acc: 0.12\n",
      "bin i= 1500 train loss:  3.48  train acc: 0.51  val loss:  3.32  valid acc: 0.12\n",
      "bin i= 2000 train loss:  3.23  train acc: 0.61  val loss:  3.25  valid acc: 0.13\n",
      "bin i= 2500 train loss:  2.97  train acc: 0.71  val loss:  3.25  valid acc: 0.15\n",
      "bin i= 3000 train loss:  2.75  train acc: 0.79  val loss:  3.25  valid acc: 0.14\n",
      "bin i= 3500 train loss:  2.56  train acc: 0.85  val loss:  3.28  valid acc: 0.18\n",
      "bin i= 4000 train loss:  2.46  train acc: 0.87  val loss:  3.35  valid acc: 0.18\n",
      "bin i= 4500 train loss:  2.41  train acc: 0.86  val loss:  3.37  valid acc: 0.22\n",
      "bin i= 5000 train loss:  2.31  train acc: 0.89  val loss:  3.46  valid acc: 0.20\n",
      "bin i= 5500 train loss:  2.30  train acc: 0.88  val loss:  3.43  valid acc: 0.20\n",
      "bin i= 6000 train loss:  2.21  train acc: 0.91  val loss:  3.51  valid acc: 0.20\n",
      "bin i= 6500 train loss:  2.12  train acc: 0.91  val loss:  3.49  valid acc: 0.20\n",
      "bin i= 7000 train loss:  2.07  train acc: 0.93  val loss:  3.57  valid acc: 0.20\n",
      "bin i= 7500 train loss:  2.02  train acc: 0.93  val loss:  3.60  valid acc: 0.23\n",
      "bin i= 8000 train loss:  2.01  train acc: 0.93  val loss:  3.64  valid acc: 0.24\n",
      "bin i= 8500 train loss:  1.95  train acc: 0.94  val loss:  3.65  valid acc: 0.23\n",
      "bin i= 9000 train loss:  1.92  train acc: 0.95  val loss:  3.70  valid acc: 0.23\n",
      "bin i= 9500 train loss:  1.87  train acc: 0.95  val loss:  3.81  valid acc: 0.23\n",
      "bin i=10000 train loss:  1.87  train acc: 0.94  val loss:  3.70  valid acc: 0.25\n",
      "->  bin  layer idx: 3  , best valid accuracy: 0.25, test accuracy: 0.23\n",
      "bin i=    0 train loss:  9.29  train acc: 0.00  val loss:  6.64  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.28  train acc: 0.17  val loss:  3.84  valid acc: 0.08\n",
      "bin i= 1000 train loss:  3.93  train acc: 0.23  val loss:  3.69  valid acc: 0.08\n",
      "bin i= 1500 train loss:  3.71  train acc: 0.29  val loss:  3.59  valid acc: 0.10\n",
      "bin i= 2000 train loss:  3.54  train acc: 0.38  val loss:  3.48  valid acc: 0.13\n",
      "bin i= 2500 train loss:  3.32  train acc: 0.43  val loss:  3.46  valid acc: 0.14\n",
      "bin i= 3000 train loss:  3.10  train acc: 0.52  val loss:  3.42  valid acc: 0.13\n",
      "bin i= 3500 train loss:  2.92  train acc: 0.55  val loss:  3.46  valid acc: 0.12\n",
      "bin i= 4000 train loss:  2.78  train acc: 0.63  val loss:  3.45  valid acc: 0.13\n",
      "bin i= 4500 train loss:  2.71  train acc: 0.66  val loss:  3.49  valid acc: 0.15\n",
      "bin i= 5000 train loss:  2.62  train acc: 0.72  val loss:  3.52  valid acc: 0.14\n",
      "bin i= 5500 train loss:  2.57  train acc: 0.71  val loss:  3.59  valid acc: 0.14\n",
      "bin i= 6000 train loss:  2.51  train acc: 0.73  val loss:  3.67  valid acc: 0.15\n",
      "bin i= 6500 train loss:  2.37  train acc: 0.77  val loss:  3.61  valid acc: 0.17\n",
      "bin i= 7000 train loss:  2.35  train acc: 0.78  val loss:  3.63  valid acc: 0.16\n",
      "bin i= 7500 train loss:  2.31  train acc: 0.78  val loss:  3.64  valid acc: 0.16\n",
      "bin i= 8000 train loss:  2.27  train acc: 0.82  val loss:  3.73  valid acc: 0.15\n",
      "bin i= 8500 train loss:  2.24  train acc: 0.80  val loss:  3.68  valid acc: 0.17\n",
      "bin i= 9000 train loss:  2.17  train acc: 0.84  val loss:  3.75  valid acc: 0.15\n",
      "bin i= 9500 train loss:  2.13  train acc: 0.84  val loss:  3.80  valid acc: 0.16\n",
      "bin i=10000 train loss:  2.10  train acc: 0.85  val loss:  3.64  valid acc: 0.19\n",
      "->  bin  layer idx: 4  , best valid accuracy: 0.19, test accuracy: 0.16\n",
      "bin i=    0 train loss:  9.29  train acc: 0.00  val loss:  6.65  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.27  train acc: 0.15  val loss:  3.91  valid acc: 0.04\n",
      "bin i= 1000 train loss:  3.98  train acc: 0.19  val loss:  3.78  valid acc: 0.06\n",
      "bin i= 1500 train loss:  3.85  train acc: 0.19  val loss:  3.71  valid acc: 0.07\n",
      "bin i= 2000 train loss:  3.75  train acc: 0.24  val loss:  3.66  valid acc: 0.07\n",
      "bin i= 2500 train loss:  3.60  train acc: 0.25  val loss:  3.62  valid acc: 0.10\n",
      "bin i= 3000 train loss:  3.44  train acc: 0.30  val loss:  3.58  valid acc: 0.10\n",
      "bin i= 3500 train loss:  3.28  train acc: 0.37  val loss:  3.59  valid acc: 0.09\n",
      "bin i= 4000 train loss:  3.19  train acc: 0.41  val loss:  3.51  valid acc: 0.14\n",
      "bin i= 4500 train loss:  3.15  train acc: 0.41  val loss:  3.54  valid acc: 0.13\n",
      "bin i= 5000 train loss:  3.03  train acc: 0.46  val loss:  3.52  valid acc: 0.12\n",
      "bin i= 5500 train loss:  3.00  train acc: 0.48  val loss:  3.55  valid acc: 0.12\n",
      "bin i= 6000 train loss:  2.98  train acc: 0.50  val loss:  3.56  valid acc: 0.12\n",
      "bin i= 6500 train loss:  2.82  train acc: 0.54  val loss:  3.52  valid acc: 0.13\n",
      "bin i= 7000 train loss:  2.78  train acc: 0.57  val loss:  3.56  valid acc: 0.13\n",
      "bin i= 7500 train loss:  2.72  train acc: 0.58  val loss:  3.51  valid acc: 0.14\n",
      "bin i= 8000 train loss:  2.72  train acc: 0.59  val loss:  3.59  valid acc: 0.13\n",
      "bin i= 8500 train loss:  2.67  train acc: 0.60  val loss:  3.60  valid acc: 0.13\n",
      "bin i= 9000 train loss:  2.58  train acc: 0.66  val loss:  3.69  valid acc: 0.12\n",
      "bin i= 9500 train loss:  2.52  train acc: 0.65  val loss:  3.66  valid acc: 0.12\n",
      "bin i=10000 train loss:  2.47  train acc: 0.68  val loss:  3.64  valid acc: 0.14\n",
      "->  bin  layer idx: 5  , best valid accuracy: 0.14, test accuracy: 0.12\n",
      "bin i=    0 train loss:  9.28  train acc: 0.00  val loss:  6.69  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.26  train acc: 0.14  val loss:  3.96  valid acc: 0.05\n",
      "bin i= 1000 train loss:  4.01  train acc: 0.17  val loss:  3.83  valid acc: 0.05\n",
      "bin i= 1500 train loss:  3.93  train acc: 0.19  val loss:  3.80  valid acc: 0.05\n",
      "bin i= 2000 train loss:  3.85  train acc: 0.22  val loss:  3.78  valid acc: 0.03\n",
      "bin i= 2500 train loss:  3.72  train acc: 0.20  val loss:  3.74  valid acc: 0.04\n",
      "bin i= 3000 train loss:  3.62  train acc: 0.24  val loss:  3.74  valid acc: 0.05\n",
      "bin i= 3500 train loss:  3.47  train acc: 0.26  val loss:  3.72  valid acc: 0.04\n",
      "bin i= 4000 train loss:  3.41  train acc: 0.30  val loss:  3.70  valid acc: 0.04\n",
      "bin i= 4500 train loss:  3.36  train acc: 0.30  val loss:  3.66  valid acc: 0.07\n",
      "bin i= 5000 train loss:  3.27  train acc: 0.33  val loss:  3.65  valid acc: 0.08\n",
      "bin i= 5500 train loss:  3.22  train acc: 0.35  val loss:  3.66  valid acc: 0.11\n",
      "bin i= 6000 train loss:  3.21  train acc: 0.37  val loss:  3.64  valid acc: 0.11\n",
      "bin i= 6500 train loss:  3.05  train acc: 0.39  val loss:  3.60  valid acc: 0.11\n",
      "bin i= 7000 train loss:  3.06  train acc: 0.40  val loss:  3.61  valid acc: 0.11\n",
      "bin i= 7500 train loss:  3.00  train acc: 0.41  val loss:  3.54  valid acc: 0.12\n",
      "bin i= 8000 train loss:  3.02  train acc: 0.43  val loss:  3.70  valid acc: 0.12\n",
      "bin i= 8500 train loss:  2.97  train acc: 0.46  val loss:  3.62  valid acc: 0.12\n",
      "bin i= 9000 train loss:  2.87  train acc: 0.48  val loss:  3.66  valid acc: 0.12\n",
      "bin i= 9500 train loss:  2.78  train acc: 0.50  val loss:  3.68  valid acc: 0.13\n",
      "bin i=10000 train loss:  2.79  train acc: 0.53  val loss:  3.62  valid acc: 0.12\n",
      "->  bin  layer idx: 6  , best valid accuracy: 0.13, test accuracy: 0.08\n",
      "bin i=    0 train loss:  9.28  train acc: 0.00  val loss:  6.70  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.27  train acc: 0.13  val loss:  3.98  valid acc: 0.05\n",
      "bin i= 1000 train loss:  4.03  train acc: 0.16  val loss:  3.85  valid acc: 0.04\n",
      "bin i= 1500 train loss:  3.95  train acc: 0.17  val loss:  3.82  valid acc: 0.04\n",
      "bin i= 2000 train loss:  3.89  train acc: 0.20  val loss:  3.82  valid acc: 0.02\n",
      "bin i= 2500 train loss:  3.77  train acc: 0.19  val loss:  3.79  valid acc: 0.04\n",
      "bin i= 3000 train loss:  3.66  train acc: 0.22  val loss:  3.78  valid acc: 0.04\n",
      "bin i= 3500 train loss:  3.50  train acc: 0.23  val loss:  3.76  valid acc: 0.03\n",
      "bin i= 4000 train loss:  3.46  train acc: 0.27  val loss:  3.78  valid acc: 0.03\n",
      "bin i= 4500 train loss:  3.42  train acc: 0.26  val loss:  3.73  valid acc: 0.07\n",
      "bin i= 5000 train loss:  3.36  train acc: 0.26  val loss:  3.75  valid acc: 0.04\n",
      "bin i= 5500 train loss:  3.32  train acc: 0.28  val loss:  3.71  valid acc: 0.07\n",
      "bin i= 6000 train loss:  3.33  train acc: 0.28  val loss:  3.65  valid acc: 0.09\n",
      "bin i= 6500 train loss:  3.15  train acc: 0.33  val loss:  3.65  valid acc: 0.09\n",
      "bin i= 7000 train loss:  3.18  train acc: 0.33  val loss:  3.69  valid acc: 0.08\n",
      "bin i= 7500 train loss:  3.11  train acc: 0.34  val loss:  3.63  valid acc: 0.10\n",
      "bin i= 8000 train loss:  3.11  train acc: 0.34  val loss:  3.69  valid acc: 0.08\n",
      "bin i= 8500 train loss:  3.10  train acc: 0.36  val loss:  3.68  valid acc: 0.09\n",
      "bin i= 9000 train loss:  3.00  train acc: 0.40  val loss:  3.71  valid acc: 0.08\n",
      "bin i= 9500 train loss:  2.95  train acc: 0.39  val loss:  3.69  valid acc: 0.09\n",
      "bin i=10000 train loss:  2.92  train acc: 0.42  val loss:  3.66  valid acc: 0.09\n",
      "->  bin  layer idx: 7  , best valid accuracy: 0.10, test accuracy: 0.09\n",
      "bin i=    0 train loss:  9.25  train acc: 0.00  val loss:  6.73  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.23  train acc: 0.14  val loss:  3.98  valid acc: 0.04\n",
      "bin i= 1000 train loss:  4.00  train acc: 0.16  val loss:  3.87  valid acc: 0.03\n",
      "bin i= 1500 train loss:  3.92  train acc: 0.16  val loss:  3.84  valid acc: 0.03\n",
      "bin i= 2000 train loss:  3.87  train acc: 0.20  val loss:  3.84  valid acc: 0.02\n",
      "bin i= 2500 train loss:  3.76  train acc: 0.18  val loss:  3.79  valid acc: 0.05\n",
      "bin i= 3000 train loss:  3.67  train acc: 0.20  val loss:  3.80  valid acc: 0.03\n",
      "bin i= 3500 train loss:  3.54  train acc: 0.22  val loss:  3.78  valid acc: 0.02\n",
      "bin i= 4000 train loss:  3.49  train acc: 0.25  val loss:  3.80  valid acc: 0.03\n",
      "bin i= 4500 train loss:  3.48  train acc: 0.24  val loss:  3.76  valid acc: 0.06\n",
      "bin i= 5000 train loss:  3.44  train acc: 0.23  val loss:  3.80  valid acc: 0.04\n",
      "bin i= 5500 train loss:  3.40  train acc: 0.25  val loss:  3.74  valid acc: 0.06\n",
      "bin i= 6000 train loss:  3.43  train acc: 0.23  val loss:  3.69  valid acc: 0.09\n",
      "bin i= 6500 train loss:  3.25  train acc: 0.27  val loss:  3.68  valid acc: 0.08\n",
      "bin i= 7000 train loss:  3.28  train acc: 0.26  val loss:  3.77  valid acc: 0.07\n",
      "bin i= 7500 train loss:  3.24  train acc: 0.27  val loss:  3.67  valid acc: 0.08\n",
      "bin i= 8000 train loss:  3.27  train acc: 0.25  val loss:  3.73  valid acc: 0.08\n",
      "bin i= 8500 train loss:  3.25  train acc: 0.28  val loss:  3.73  valid acc: 0.07\n",
      "bin i= 9000 train loss:  3.18  train acc: 0.29  val loss:  3.67  valid acc: 0.08\n",
      "bin i= 9500 train loss:  3.08  train acc: 0.31  val loss:  3.67  valid acc: 0.09\n",
      "bin i=10000 train loss:  3.12  train acc: 0.32  val loss:  3.68  valid acc: 0.08\n",
      "->  bin  layer idx: 8  , best valid accuracy: 0.09, test accuracy: 0.09\n",
      "bin i=    0 train loss:  9.25  train acc: 0.00  val loss:  6.78  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.21  train acc: 0.14  val loss:  4.00  valid acc: 0.04\n",
      "bin i= 1000 train loss:  3.98  train acc: 0.15  val loss:  3.88  valid acc: 0.05\n",
      "bin i= 1500 train loss:  3.91  train acc: 0.15  val loss:  3.88  valid acc: 0.03\n",
      "bin i= 2000 train loss:  3.87  train acc: 0.18  val loss:  3.87  valid acc: 0.03\n",
      "bin i= 2500 train loss:  3.76  train acc: 0.18  val loss:  3.84  valid acc: 0.04\n",
      "bin i= 3000 train loss:  3.69  train acc: 0.19  val loss:  3.84  valid acc: 0.03\n",
      "bin i= 3500 train loss:  3.56  train acc: 0.20  val loss:  3.81  valid acc: 0.03\n",
      "bin i= 4000 train loss:  3.50  train acc: 0.24  val loss:  3.83  valid acc: 0.04\n",
      "bin i= 4500 train loss:  3.48  train acc: 0.23  val loss:  3.79  valid acc: 0.05\n",
      "bin i= 5000 train loss:  3.42  train acc: 0.23  val loss:  3.84  valid acc: 0.06\n",
      "bin i= 5500 train loss:  3.39  train acc: 0.25  val loss:  3.82  valid acc: 0.04\n",
      "bin i= 6000 train loss:  3.39  train acc: 0.22  val loss:  3.74  valid acc: 0.09\n",
      "bin i= 6500 train loss:  3.22  train acc: 0.28  val loss:  3.71  valid acc: 0.10\n",
      "bin i= 7000 train loss:  3.24  train acc: 0.27  val loss:  3.82  valid acc: 0.06\n",
      "bin i= 7500 train loss:  3.18  train acc: 0.29  val loss:  3.69  valid acc: 0.09\n",
      "bin i= 8000 train loss:  3.21  train acc: 0.27  val loss:  3.70  valid acc: 0.09\n",
      "bin i= 8500 train loss:  3.17  train acc: 0.32  val loss:  3.76  valid acc: 0.08\n",
      "bin i= 9000 train loss:  3.11  train acc: 0.31  val loss:  3.70  valid acc: 0.09\n",
      "bin i= 9500 train loss:  3.02  train acc: 0.32  val loss:  3.72  valid acc: 0.09\n",
      "bin i=10000 train loss:  3.01  train acc: 0.37  val loss:  3.68  valid acc: 0.09\n",
      "->  bin  layer idx: 9  , best valid accuracy: 0.10, test accuracy: 0.09\n",
      "bin i=    0 train loss:  9.22  train acc: 0.00  val loss:  6.93  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.12  train acc: 0.14  val loss:  3.99  valid acc: 0.05\n",
      "bin i= 1000 train loss:  3.92  train acc: 0.14  val loss:  3.88  valid acc: 0.05\n",
      "bin i= 1500 train loss:  3.86  train acc: 0.15  val loss:  3.87  valid acc: 0.04\n",
      "bin i= 2000 train loss:  3.82  train acc: 0.17  val loss:  3.88  valid acc: 0.03\n",
      "bin i= 2500 train loss:  3.73  train acc: 0.17  val loss:  3.83  valid acc: 0.05\n",
      "bin i= 3000 train loss:  3.68  train acc: 0.18  val loss:  3.83  valid acc: 0.03\n",
      "bin i= 3500 train loss:  3.52  train acc: 0.19  val loss:  3.77  valid acc: 0.05\n",
      "bin i= 4000 train loss:  3.46  train acc: 0.22  val loss:  3.77  valid acc: 0.05\n",
      "bin i= 4500 train loss:  3.44  train acc: 0.21  val loss:  3.75  valid acc: 0.05\n",
      "bin i= 5000 train loss:  3.39  train acc: 0.22  val loss:  3.74  valid acc: 0.07\n",
      "bin i= 5500 train loss:  3.34  train acc: 0.24  val loss:  3.73  valid acc: 0.06\n",
      "bin i= 6000 train loss:  3.37  train acc: 0.22  val loss:  3.64  valid acc: 0.09\n",
      "bin i= 6500 train loss:  3.21  train acc: 0.26  val loss:  3.62  valid acc: 0.10\n",
      "bin i= 7000 train loss:  3.22  train acc: 0.26  val loss:  3.72  valid acc: 0.08\n",
      "bin i= 7500 train loss:  3.16  train acc: 0.26  val loss:  3.63  valid acc: 0.09\n",
      "bin i= 8000 train loss:  3.17  train acc: 0.26  val loss:  3.61  valid acc: 0.09\n",
      "bin i= 8500 train loss:  3.15  train acc: 0.28  val loss:  3.62  valid acc: 0.09\n",
      "bin i= 9000 train loss:  3.07  train acc: 0.28  val loss:  3.58  valid acc: 0.09\n",
      "bin i= 9500 train loss:  2.99  train acc: 0.31  val loss:  3.63  valid acc: 0.10\n",
      "bin i=10000 train loss:  3.01  train acc: 0.31  val loss:  3.55  valid acc: 0.11\n",
      "->  bin  layer idx: 10 , best valid accuracy: 0.11, test accuracy: 0.12\n",
      "bin i=    0 train loss:  9.25  train acc: 0.00  val loss:  7.16  valid acc: 0.00\n",
      "bin i=  500 train loss:  3.99  train acc: 0.14  val loss:  3.94  valid acc: 0.05\n",
      "bin i= 1000 train loss:  3.80  train acc: 0.16  val loss:  3.84  valid acc: 0.05\n",
      "bin i= 1500 train loss:  3.76  train acc: 0.15  val loss:  3.83  valid acc: 0.05\n",
      "bin i= 2000 train loss:  3.73  train acc: 0.17  val loss:  3.82  valid acc: 0.03\n",
      "bin i= 2500 train loss:  3.65  train acc: 0.18  val loss:  3.77  valid acc: 0.07\n",
      "bin i= 3000 train loss:  3.57  train acc: 0.17  val loss:  3.76  valid acc: 0.03\n",
      "bin i= 3500 train loss:  3.45  train acc: 0.20  val loss:  3.73  valid acc: 0.07\n",
      "bin i= 4000 train loss:  3.37  train acc: 0.22  val loss:  3.72  valid acc: 0.07\n",
      "bin i= 4500 train loss:  3.36  train acc: 0.22  val loss:  3.74  valid acc: 0.07\n",
      "bin i= 5000 train loss:  3.31  train acc: 0.22  val loss:  3.77  valid acc: 0.07\n",
      "bin i= 5500 train loss:  3.28  train acc: 0.22  val loss:  3.71  valid acc: 0.08\n",
      "bin i= 6000 train loss:  3.27  train acc: 0.24  val loss:  3.67  valid acc: 0.09\n",
      "bin i= 6500 train loss:  3.12  train acc: 0.25  val loss:  3.64  valid acc: 0.10\n",
      "bin i= 7000 train loss:  3.14  train acc: 0.26  val loss:  3.72  valid acc: 0.08\n",
      "bin i= 7500 train loss:  3.11  train acc: 0.29  val loss:  3.66  valid acc: 0.08\n",
      "bin i= 8000 train loss:  3.14  train acc: 0.25  val loss:  3.60  valid acc: 0.10\n",
      "bin i= 8500 train loss:  3.08  train acc: 0.26  val loss:  3.70  valid acc: 0.08\n",
      "bin i= 9000 train loss:  3.01  train acc: 0.30  val loss:  3.64  valid acc: 0.08\n",
      "bin i= 9500 train loss:  2.90  train acc: 0.32  val loss:  3.68  valid acc: 0.08\n",
      "bin i=10000 train loss:  2.95  train acc: 0.33  val loss:  3.64  valid acc: 0.08\n",
      "->  bin  layer idx: 11 , best valid accuracy: 0.10, test accuracy: 0.11\n",
      "bin i=    0 train loss:  9.29  train acc: 0.00  val loss:  7.25  valid acc: 0.00\n",
      "bin i=  500 train loss:  3.96  train acc: 0.14  val loss:  3.96  valid acc: 0.05\n",
      "bin i= 1000 train loss:  3.77  train acc: 0.16  val loss:  3.85  valid acc: 0.05\n",
      "bin i= 1500 train loss:  3.73  train acc: 0.15  val loss:  3.84  valid acc: 0.05\n",
      "bin i= 2000 train loss:  3.68  train acc: 0.17  val loss:  3.81  valid acc: 0.05\n",
      "bin i= 2500 train loss:  3.61  train acc: 0.18  val loss:  3.76  valid acc: 0.06\n",
      "bin i= 3000 train loss:  3.53  train acc: 0.17  val loss:  3.74  valid acc: 0.02\n",
      "bin i= 3500 train loss:  3.40  train acc: 0.20  val loss:  3.68  valid acc: 0.05\n",
      "bin i= 4000 train loss:  3.33  train acc: 0.22  val loss:  3.70  valid acc: 0.07\n",
      "bin i= 4500 train loss:  3.33  train acc: 0.22  val loss:  3.66  valid acc: 0.08\n",
      "bin i= 5000 train loss:  3.28  train acc: 0.21  val loss:  3.71  valid acc: 0.08\n",
      "bin i= 5500 train loss:  3.23  train acc: 0.22  val loss:  3.62  valid acc: 0.08\n",
      "bin i= 6000 train loss:  3.28  train acc: 0.23  val loss:  3.57  valid acc: 0.10\n",
      "bin i= 6500 train loss:  3.08  train acc: 0.27  val loss:  3.57  valid acc: 0.10\n",
      "bin i= 7000 train loss:  3.11  train acc: 0.27  val loss:  3.65  valid acc: 0.07\n",
      "bin i= 7500 train loss:  3.06  train acc: 0.28  val loss:  3.56  valid acc: 0.10\n",
      "bin i= 8000 train loss:  3.07  train acc: 0.26  val loss:  3.52  valid acc: 0.10\n",
      "bin i= 8500 train loss:  3.03  train acc: 0.27  val loss:  3.58  valid acc: 0.08\n",
      "bin i= 9000 train loss:  2.98  train acc: 0.31  val loss:  3.55  valid acc: 0.09\n",
      "bin i= 9500 train loss:  2.87  train acc: 0.33  val loss:  3.60  valid acc: 0.10\n",
      "bin i=10000 train loss:  2.93  train acc: 0.31  val loss:  3.54  valid acc: 0.10\n",
      "->  bin  layer idx: 12 , best valid accuracy: 0.10, test accuracy: 0.09\n",
      "bin i=    0 train loss:  9.32  train acc: 0.00  val loss:  7.83  valid acc: 0.00\n",
      "bin i=  500 train loss:  3.93  train acc: 0.15  val loss:  3.96  valid acc: 0.04\n",
      "bin i= 1000 train loss:  3.73  train acc: 0.17  val loss:  3.84  valid acc: 0.05\n",
      "bin i= 1500 train loss:  3.67  train acc: 0.16  val loss:  3.78  valid acc: 0.06\n",
      "bin i= 2000 train loss:  3.61  train acc: 0.20  val loss:  3.77  valid acc: 0.05\n",
      "bin i= 2500 train loss:  3.50  train acc: 0.20  val loss:  3.74  valid acc: 0.05\n",
      "bin i= 3000 train loss:  3.45  train acc: 0.20  val loss:  3.71  valid acc: 0.05\n",
      "bin i= 3500 train loss:  3.31  train acc: 0.21  val loss:  3.69  valid acc: 0.08\n",
      "bin i= 4000 train loss:  3.20  train acc: 0.24  val loss:  3.69  valid acc: 0.07\n",
      "bin i= 4500 train loss:  3.21  train acc: 0.23  val loss:  3.64  valid acc: 0.09\n",
      "bin i= 5000 train loss:  3.14  train acc: 0.25  val loss:  3.66  valid acc: 0.07\n",
      "bin i= 5500 train loss:  3.06  train acc: 0.26  val loss:  3.64  valid acc: 0.09\n",
      "bin i= 6000 train loss:  3.13  train acc: 0.24  val loss:  3.62  valid acc: 0.09\n",
      "bin i= 6500 train loss:  2.89  train acc: 0.33  val loss:  3.60  valid acc: 0.09\n",
      "bin i= 7000 train loss:  2.94  train acc: 0.28  val loss:  3.63  valid acc: 0.08\n",
      "bin i= 7500 train loss:  2.86  train acc: 0.33  val loss:  3.57  valid acc: 0.09\n",
      "bin i= 8000 train loss:  2.87  train acc: 0.32  val loss:  3.57  valid acc: 0.08\n",
      "bin i= 8500 train loss:  2.86  train acc: 0.32  val loss:  3.54  valid acc: 0.09\n",
      "bin i= 9000 train loss:  2.77  train acc: 0.37  val loss:  3.55  valid acc: 0.08\n",
      "bin i= 9500 train loss:  2.69  train acc: 0.36  val loss:  3.62  valid acc: 0.07\n",
      "bin i=10000 train loss:  2.75  train acc: 0.34  val loss:  3.46  valid acc: 0.11\n",
      "->  bin  layer idx: 13 , best valid accuracy: 0.11, test accuracy: 0.07\n",
      "bin i=    0 train loss:  9.30  train acc: 0.00  val loss:  8.19  valid acc: 0.00\n",
      "bin i=  500 train loss:  3.89  train acc: 0.17  val loss:  3.93  valid acc: 0.06\n",
      "bin i= 1000 train loss:  3.66  train acc: 0.17  val loss:  3.84  valid acc: 0.06\n",
      "bin i= 1500 train loss:  3.55  train acc: 0.19  val loss:  3.80  valid acc: 0.07\n",
      "bin i= 2000 train loss:  3.48  train acc: 0.22  val loss:  3.77  valid acc: 0.07\n",
      "bin i= 2500 train loss:  3.34  train acc: 0.23  val loss:  3.75  valid acc: 0.07\n",
      "bin i= 3000 train loss:  3.29  train acc: 0.24  val loss:  3.70  valid acc: 0.06\n",
      "bin i= 3500 train loss:  3.15  train acc: 0.27  val loss:  3.68  valid acc: 0.08\n",
      "bin i= 4000 train loss:  3.02  train acc: 0.31  val loss:  3.70  valid acc: 0.08\n",
      "bin i= 4500 train loss:  3.04  train acc: 0.29  val loss:  3.67  valid acc: 0.08\n",
      "bin i= 5000 train loss:  2.94  train acc: 0.30  val loss:  3.60  valid acc: 0.09\n",
      "bin i= 5500 train loss:  2.87  train acc: 0.31  val loss:  3.64  valid acc: 0.08\n",
      "bin i= 6000 train loss:  2.88  train acc: 0.31  val loss:  3.70  valid acc: 0.08\n",
      "bin i= 6500 train loss:  2.74  train acc: 0.39  val loss:  3.59  valid acc: 0.09\n",
      "bin i= 7000 train loss:  2.77  train acc: 0.35  val loss:  3.68  valid acc: 0.09\n",
      "bin i= 7500 train loss:  2.73  train acc: 0.37  val loss:  3.60  valid acc: 0.10\n",
      "bin i= 8000 train loss:  2.70  train acc: 0.41  val loss:  3.66  valid acc: 0.08\n",
      "bin i= 8500 train loss:  2.68  train acc: 0.37  val loss:  3.72  valid acc: 0.09\n",
      "bin i= 9000 train loss:  2.60  train acc: 0.44  val loss:  3.65  valid acc: 0.08\n",
      "bin i= 9500 train loss:  2.55  train acc: 0.43  val loss:  3.67  valid acc: 0.10\n",
      "bin i=10000 train loss:  2.62  train acc: 0.43  val loss:  3.66  valid acc: 0.11\n",
      "->  bin  layer idx: 14 , best valid accuracy: 0.11, test accuracy: 0.11\n",
      "bin i=    0 train loss:  9.40  train acc: 0.00  val loss: 10.07  valid acc: 0.00\n",
      "bin i=  500 train loss:  3.91  train acc: 0.17  val loss:  3.96  valid acc: 0.07\n",
      "bin i= 1000 train loss:  3.67  train acc: 0.18  val loss:  3.88  valid acc: 0.07\n",
      "bin i= 1500 train loss:  3.52  train acc: 0.19  val loss:  3.81  valid acc: 0.07\n",
      "bin i= 2000 train loss:  3.43  train acc: 0.24  val loss:  3.75  valid acc: 0.07\n",
      "bin i= 2500 train loss:  3.30  train acc: 0.25  val loss:  3.77  valid acc: 0.06\n",
      "bin i= 3000 train loss:  3.28  train acc: 0.25  val loss:  3.76  valid acc: 0.06\n",
      "bin i= 3500 train loss:  3.04  train acc: 0.32  val loss:  3.68  valid acc: 0.07\n",
      "bin i= 4000 train loss:  2.93  train acc: 0.35  val loss:  3.69  valid acc: 0.08\n",
      "bin i= 4500 train loss:  2.98  train acc: 0.29  val loss:  3.72  valid acc: 0.08\n",
      "bin i= 5000 train loss:  2.86  train acc: 0.34  val loss:  3.67  valid acc: 0.08\n",
      "bin i= 5500 train loss:  2.80  train acc: 0.35  val loss:  3.73  valid acc: 0.08\n",
      "bin i= 6000 train loss:  2.78  train acc: 0.37  val loss:  3.69  valid acc: 0.08\n",
      "bin i= 6500 train loss:  2.72  train acc: 0.39  val loss:  3.63  valid acc: 0.09\n",
      "bin i= 7000 train loss:  2.70  train acc: 0.41  val loss:  3.79  valid acc: 0.08\n",
      "bin i= 7500 train loss:  2.62  train acc: 0.42  val loss:  3.65  valid acc: 0.12\n",
      "bin i= 8000 train loss:  2.63  train acc: 0.44  val loss:  3.67  valid acc: 0.09\n",
      "bin i= 8500 train loss:  2.57  train acc: 0.44  val loss:  3.70  valid acc: 0.10\n",
      "bin i= 9000 train loss:  2.49  train acc: 0.49  val loss:  3.69  valid acc: 0.09\n",
      "bin i= 9500 train loss:  2.47  train acc: 0.47  val loss:  3.67  valid acc: 0.10\n",
      "bin i=10000 train loss:  2.52  train acc: 0.47  val loss:  3.66  valid acc: 0.11\n",
      "->  bin  layer idx: 15 , best valid accuracy: 0.12, test accuracy: 0.14\n",
      "bin i=    0 train loss: 13.33  train acc: 0.00  val loss: 53.20  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.93  train acc: 0.20  val loss:  4.20  valid acc: 0.10\n",
      "bin i= 1000 train loss:  4.19  train acc: 0.25  val loss:  4.10  valid acc: 0.05\n",
      "bin i= 1500 train loss:  4.21  train acc: 0.17  val loss:  3.97  valid acc: 0.04\n",
      "bin i= 2000 train loss:  4.03  train acc: 0.22  val loss:  4.07  valid acc: 0.05\n",
      "bin i= 2500 train loss:  3.76  train acc: 0.19  val loss:  4.00  valid acc: 0.04\n",
      "bin i= 3000 train loss:  3.71  train acc: 0.25  val loss:  3.97  valid acc: 0.05\n",
      "bin i= 3500 train loss:  3.45  train acc: 0.28  val loss:  3.78  valid acc: 0.09\n",
      "bin i= 4000 train loss:  3.33  train acc: 0.25  val loss:  3.75  valid acc: 0.03\n",
      "bin i= 4500 train loss:  3.23  train acc: 0.29  val loss:  3.93  valid acc: 0.09\n",
      "bin i= 5000 train loss:  3.19  train acc: 0.30  val loss:  3.64  valid acc: 0.09\n",
      "bin i= 5500 train loss:  3.35  train acc: 0.27  val loss:  3.89  valid acc: 0.09\n",
      "bin i= 6000 train loss:  3.00  train acc: 0.33  val loss:  3.75  valid acc: 0.05\n",
      "bin i= 6500 train loss:  3.24  train acc: 0.31  val loss:  3.96  valid acc: 0.10\n",
      "bin i= 7000 train loss:  2.66  train acc: 0.42  val loss:  3.78  valid acc: 0.05\n",
      "bin i= 7500 train loss:  2.57  train acc: 0.43  val loss:  3.89  valid acc: 0.05\n",
      "bin i= 8000 train loss:  2.47  train acc: 0.47  val loss:  3.80  valid acc: 0.09\n",
      "bin i= 8500 train loss:  2.46  train acc: 0.46  val loss:  3.75  valid acc: 0.05\n",
      "bin i= 9000 train loss:  2.52  train acc: 0.42  val loss:  4.02  valid acc: 0.06\n",
      "bin i= 9500 train loss:  2.27  train acc: 0.52  val loss:  3.88  valid acc: 0.07\n",
      "bin i=10000 train loss:  2.66  train acc: 0.35  val loss:  4.61  valid acc: 0.05\n",
      "->  bin  layer idx: 16 , best valid accuracy: 0.10, test accuracy: 0.11\n"
     ]
    }
   ],
   "source": [
    "train_labels = torch.tensor([x2 for x1, x2 in train_inputs])\n",
    "valid_labels = torch.tensor([x2 for x1, x2 in valid_inputs]).to(device)\n",
    "test_labels = torch.tensor([x2 for x1, x2 in test_inputs]).to(device) \n",
    "\n",
    "test_accuracies = {\"sin\": {}, \"bin\": {}, \"lin\": {}, \"log\": {}}\n",
    "\n",
    "for basis_name, basis_embs in {\"sin\": basis_embs_sin, \"bin\": basis_embs_bin}.items():\n",
    "    for layer_idx in range(len(train_hidden_states)):\n",
    "\n",
    "        torch.manual_seed(0)\n",
    "        probe = ClassifierProbe(\n",
    "            emb_dim=train_hidden_states[0].shape[-1],\n",
    "            hidden_dim=100,\n",
    "            basis=basis_embs,\n",
    "            heldout_mask=test_mask,\n",
    "        ).to(device)\n",
    "\n",
    "        optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)\n",
    "\n",
    "        rng = torch.Generator().manual_seed(0)\n",
    "        best_val_acc = -1\n",
    "        best_ckpt = None\n",
    "        for i in range(10000+1):\n",
    "            probe.train()\n",
    "            optimizer.zero_grad()\n",
    "            minibatch_idcs = torch.randint(len(train_labels), size=(1024,), generator=rng)\n",
    "            x = train_hidden_states[layer_idx][minibatch_idcs].float().to(device)\n",
    "            y = train_labels[minibatch_idcs].to(device)\n",
    "            logits = probe(x, holdout_eval_tokens=True)\n",
    "            # add l1 regularization of all params to the loss\n",
    "            loss = torch.nn.functional.cross_entropy(logits, y) + 0.001 * sum(p.abs().sum() for p in probe.parameters())\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            if i % 500 == 0:\n",
    "                train_acc = (logits.argmax(dim=-1) == y).float().mean().item()\n",
    "                probe.eval()\n",
    "                with torch.no_grad():\n",
    "                    valid_logits = probe(valid_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)\n",
    "                    valid_loss = torch.nn.functional.cross_entropy(valid_logits, valid_labels)\n",
    "                    valid_accuracy = (valid_logits.argmax(dim=-1) == valid_labels).float().mean().item()\n",
    "                    if valid_accuracy > best_val_acc:\n",
    "                        best_val_acc = valid_accuracy\n",
    "                        best_ckpt = probe.state_dict()\n",
    "                print(f\"{basis_name} {i=:>5} train loss: {loss.item():5.2f}  train acc: {train_acc:.2f}  val loss: {valid_loss.item():5.2f}  valid acc: {valid_accuracy:.2f}\")\n",
    "        probe.load_state_dict(best_ckpt)\n",
    "        probe.eval()\n",
    "        with torch.no_grad():\n",
    "            test_logits = probe(test_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)\n",
    "            test_accuracy = (test_logits.argmax(dim=-1) == test_labels).float().mean().item()\n",
    "        test_accuracies[basis_name][layer_idx] = test_accuracy\n",
    "        print(f\"->  {basis_name}  layer idx: {layer_idx:<3}, best valid accuracy: {best_val_acc:.2f}, test accuracy: {test_accuracy:.2f}\")\n",
    "                    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3b7adf81",
   "metadata": {},
   "outputs": [],
   "source": [
    "def solve_linear_layer(x: Tensor, y: Tensor) -> torch.nn.Linear:\n",
    "    if y.ndim == 1:\n",
    "        y = y.unsqueeze(-1)\n",
    "    if not y.is_floating_point():\n",
    "        y = y.float()\n",
    "   \n",
    "    lin = torch.nn.Linear(x.shape[-1], y.shape[-1], device=x.device)\n",
    "    x_aug = torch.cat([x, torch.ones(len(x), 1, device=x.device)], dim=1)\n",
    "    coeffs = torch.linalg.lstsq(x_aug, y).solution\n",
    "    w, b = coeffs[:-1], coeffs[-1]\n",
    "    with torch.no_grad():\n",
    "        lin.weight[:] = w.T\n",
    "        lin.bias[:] = b\n",
    "    return lin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "207fb479",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "layer idx: 0  , linear probe acc: 0.02, log probe acc: 0.00\n",
      "layer idx: 1  , linear probe acc: 0.01, log probe acc: 0.00\n",
      "layer idx: 2  , linear probe acc: 0.02, log probe acc: 0.02\n",
      "layer idx: 3  , linear probe acc: 0.03, log probe acc: 0.03\n",
      "layer idx: 4  , linear probe acc: 0.02, log probe acc: 0.03\n",
      "layer idx: 5  , linear probe acc: 0.02, log probe acc: 0.03\n",
      "layer idx: 6  , linear probe acc: 0.02, log probe acc: 0.04\n",
      "layer idx: 7  , linear probe acc: 0.02, log probe acc: 0.04\n",
      "layer idx: 8  , linear probe acc: 0.02, log probe acc: 0.05\n",
      "layer idx: 9  , linear probe acc: 0.02, log probe acc: 0.04\n",
      "layer idx: 10 , linear probe acc: 0.03, log probe acc: 0.05\n",
      "layer idx: 11 , linear probe acc: 0.02, log probe acc: 0.03\n",
      "layer idx: 12 , linear probe acc: 0.02, log probe acc: 0.03\n",
      "layer idx: 13 , linear probe acc: 0.02, log probe acc: 0.03\n",
      "layer idx: 14 , linear probe acc: 0.03, log probe acc: 0.05\n",
      "layer idx: 15 , linear probe acc: 0.02, log probe acc: 0.04\n",
      "layer idx: 16 , linear probe acc: 0.02, log probe acc: 0.04\n"
     ]
    }
   ],
   "source": [
    "for layer_idx in range(len(train_hidden_states)):\n",
    "    lin_probe = solve_linear_layer(\n",
    "        train_hidden_states[layer_idx].float().to(device),\n",
    "        train_labels.to(device),\n",
    "    )\n",
    "    log_probe = solve_linear_layer(\n",
    "        train_hidden_states[layer_idx].float().to(device),\n",
    "        train_labels.log1p().to(device),\n",
    "    )\n",
    "    lin_test_pred = lin_probe(test_hidden_states[layer_idx].float().to(device)).flatten().round().int()\n",
    "    lin_test_accuracy = (lin_test_pred == test_labels).float().mean().item()\n",
    "    \n",
    "    log_test_pred = log_probe(test_hidden_states[layer_idx].float().to(device)).flatten().exp().add(1).round().int()\n",
    "    log_test_accuracy = (log_test_pred == test_labels).float().mean().item()\n",
    "    \n",
    "    test_accuracies[\"lin\"][layer_idx] = lin_test_accuracy\n",
    "    test_accuracies[\"log\"][layer_idx] = log_test_accuracy\n",
    "\n",
    "    print(f\"layer idx: {layer_idx:<3}, linear probe acc: {lin_test_accuracy:.2f}, log probe acc: {log_test_accuracy:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "8a5a762b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sin accs: | 89% | 96% | 100% | 99% | 100% | 100% | 100% | 99% | 98% | 99% | 100% | 99% | 99% | 99% | 96% | 99% | 86% |\n",
      "bin accs: | 41% | 15% | 12% | 23% | 16% | 12% | 8% | 9% | 9% | 9% | 12% | 11% | 9% | 7% | 11% | 14% | 11% |\n",
      "lin accs: | 2% | 1% | 2% | 3% | 2% | 2% | 2% | 2% | 2% | 2% | 3% | 2% | 2% | 2% | 3% | 2% | 2% |\n",
      "log accs: | 0% | 0% | 2% | 3% | 3% | 3% | 4% | 4% | 5% | 4% | 5% | 3% | 3% | 3% | 5% | 4% | 4% |\n"
     ]
    }
   ],
   "source": [
    "for name, accs in test_accuracies.items():\n",
    "    print(f\"{name} accs: | \" + \" | \".join([f\"{x:.0%}\" for layer, x in sorted(accs.items())]) + \" |\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "numllama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
