{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7a542467",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Device string used by the cells below when placing the model.\n",
    "# NOTE(review): the GPU index is hard-coded; adjust it to the local machine before running.\n",
    "device = \"cuda:7\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b9c875c",
   "metadata": {},
   "source": [
    "### Preliminaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3b08b50c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import random\n",
    "import collections\n",
    "\n",
    "\n",
    "import transformers\n",
    "import torch\n",
    "import tqdm.auto\n",
    "from torch import Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "da5bca2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sinusoidal_encode(\n",
    "    x: Tensor,\n",
    "    embedding_dim: int,\n",
    "    min_value: int,\n",
    "    max_value: int,\n",
    "    use_l2_norm: bool = False,\n",
    "    norm_const: float | None = None,\n",
    ") -> Tensor:\n",
    "    \"\"\"\n",
    "    Encodes a tensor of numbers into a sinusoidal representation, inspired by how absolute positional\n",
    "    encoding works in transformers.\n",
    "\n",
    "    The encoding is an evaluation of a sine and cosine function at different frequencies, where the\n",
    "    frequency is determined by the embedding dimension and the allowed range of the input values.\n",
    "\n",
    "    Args:\n",
    "        x: Tensor of numbers to encode, of any shape.\n",
    "        embedding_dim: Size of the last dimension of the result. Must be even unless\n",
    "            ``use_l2_norm`` is set.\n",
    "        min_value: Lower end of the value range; subtracted from ``x`` before encoding.\n",
    "        max_value: Upper end of the value range; ``max_value - min_value`` sets the frequency scale.\n",
    "        use_l2_norm: If true, the last 2 (even ``embedding_dim``) or 1 (odd ``embedding_dim``)\n",
    "            features are a constant padding instead of sin/cos features, and every encoding is\n",
    "            divided by its L2 norm.\n",
    "        norm_const: Optional factor the final encoding is multiplied by.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape ``x.shape + (embedding_dim,)`` located on the same device as ``x``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``embedding_dim`` is odd and ``use_l2_norm`` is false.\n",
    "\n",
    "    >>> sinusoidal_encode(\n",
    "    ...     torch.tensor([-5, 2, 1, 0]),\n",
    "    ...     embedding_dim=6,\n",
    "    ...     min_value=-5,\n",
    "    ...     max_value=5,\n",
    "    ... )\n",
    "    tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],\n",
    "            [ 0.6570,  0.7539, -0.1073, -0.9942,  0.9980,  0.0627],\n",
    "            [-0.2794,  0.9602,  0.3491, -0.9371,  0.9616,  0.2746],\n",
    "            [-0.9589,  0.2837,  0.7317, -0.6816,  0.8806,  0.4738]])\n",
    "    \"\"\"\n",
    "\n",
    "    if embedding_dim % 2 != 0 and not use_l2_norm:\n",
    "        raise ValueError(\"Embedding dimension must be even\")\n",
    "\n",
    "    if use_l2_norm:\n",
    "        # Reserve trailing features for the constant padding so that the number of sin/cos\n",
    "        # features is even and the total width stays at the requested embedding_dim.\n",
    "        reserved_dim = 2 if embedding_dim % 2 == 0 else 1\n",
    "        embedding_dim -= reserved_dim\n",
    "    else:\n",
    "        reserved_dim = 0  # will not be used\n",
    "\n",
    "    domain = max_value - min_value\n",
    "    y_shape = x.shape + (embedding_dim,)\n",
    "    y = torch.zeros(y_shape, device=x.device)\n",
    "    # Build the frequency terms on the device of x: mixing CPU frequencies with a GPU input\n",
    "    # would fail in the multiplication below.\n",
    "    even_indices = torch.arange(0, embedding_dim, 2, device=x.device)\n",
    "    log_term = torch.log(torch.tensor(domain, device=x.device)) / embedding_dim\n",
    "    div_term = torch.exp(even_indices * -log_term)\n",
    "    x = x - min_value\n",
    "    values = x.unsqueeze(-1).float() * div_term\n",
    "    # Interleave: even features hold the sine, odd features the cosine of the same frequency.\n",
    "    y[..., 0::2] = torch.sin(values)\n",
    "    y[..., 1::2] = torch.cos(values)\n",
    "\n",
    "    if use_l2_norm:\n",
    "        y = torch.cat([y, torch.ones_like(y[..., :reserved_dim])], dim=-1)\n",
    "        y /= y.norm(dim=-1, keepdim=True, p=2)\n",
    "\n",
    "    if norm_const is not None:\n",
    "        y *= norm_const\n",
    "\n",
    "    return y\n",
    "\n",
    "\n",
    "def binary_encode(\n",
    "    x: Tensor,\n",
    "    embedding_dim: int,\n",
    "    min_value: int | float,\n",
    "    max_value: int | float,\n",
    "    use_l2_norm: bool = False,\n",
    "    norm_const: float | None = None,\n",
    ") -> Tensor:\n",
    "    \"\"\"\n",
    "    Encodes a tensor of numbers as the binary digits of ``x - min_value``.\n",
    "\n",
    "    The least significant bit is written to the last bit feature, so unused leading features stay\n",
    "    zero. Only as many bits as there are bit features are written: higher bits of larger values\n",
    "    are dropped without an error.\n",
    "\n",
    "    Args:\n",
    "        x: Tensor of numbers to encode, of any shape. Assumed to satisfy ``x >= min_value``.\n",
    "        embedding_dim: Size of the last dimension of the result.\n",
    "        min_value: Offset subtracted from ``x`` before the bits are extracted.\n",
    "        max_value: Not used by the encoding; accepted so that the signature matches\n",
    "            ``sinusoidal_encode``.\n",
    "        use_l2_norm: If true, the last feature is a constant padding instead of a bit, and every\n",
    "            encoding is divided by its L2 norm.\n",
    "        norm_const: Optional factor the final encoding is multiplied by.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape ``x.shape + (embedding_dim,)`` located on the same device as ``x``.\n",
    "    \"\"\"\n",
    "    reserve_dim = 1 if use_l2_norm else 0\n",
    "    # Allocate only the bit features here: the padding feature appended below brings the total\n",
    "    # width back to embedding_dim, matching what sinusoidal_encode returns.\n",
    "    num_bits = embedding_dim - reserve_dim\n",
    "    y = torch.zeros(x.shape + (num_bits,), device=x.device)\n",
    "    x = x - min_value\n",
    "    if x.numel() > 0:  # max() is not defined for an empty tensor\n",
    "        maximum = x.max()\n",
    "        for i in range(num_bits):\n",
    "            coeff = 2**i\n",
    "            if maximum < coeff:\n",
    "                # the largest element is below 2**i, so all remaining bits stay zero\n",
    "                break\n",
    "            y[..., -i - 1] = torch.floor(x / coeff) % 2\n",
    "            x = x - coeff * y[..., -i - 1]\n",
    "    if use_l2_norm:\n",
    "        y = torch.cat([y, y.new_ones(y.shape[:-1] + (reserve_dim,))], dim=-1)\n",
    "        y /= y.norm(dim=-1, keepdim=True, p=2)\n",
    "    if norm_const is not None:\n",
    "        y *= norm_const\n",
    "    return y"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ff97ee3",
   "metadata": {},
   "source": [
    "### Prepare model and data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8723d3d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint name shared by the model and its tokenizer.\n",
    "model_ckpt = \"meta-llama/Llama-3.2-1B\"\n",
    "# Load the weights, cast them to half precision, move them to `device` and switch to eval mode.\n",
    "model = transformers.AutoModel.from_pretrained(model_ckpt).half().to(device).eval()\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(model_ckpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9b394470",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the integers 0..999 into train / valid / test using one seeded uniform draw per value:\n",
    "# draw < 0.9 -> train, 0.9 <= draw < 0.95 -> valid, everything else -> test.\n",
    "all_values = torch.arange(0, 1000)\n",
    "mask = torch.rand(len(all_values), generator=torch.Generator().manual_seed(0))\n",
    "train_mask = mask < 0.9\n",
    "valid_mask = ~train_mask & (mask < 0.95)\n",
    "test_mask = ~train_mask & ~valid_mask\n",
    "\n",
    "# Boolean-mask indexing keeps the values of each split in ascending order.\n",
    "train_values = all_values[train_mask]\n",
    "valid_values = all_values[valid_mask]\n",
    "test_values = all_values[test_mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c15ab591",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every ordered pair of operands whose sum is below 1000.\n",
    "all_inputs = [(x1, x2) for x1, x2 in itertools.product(all_values.tolist(), repeat=2) if x1 + x2 < 1000]\n",
    "train_values_set = set(train_values.tolist())\n",
    "valid_values_set = set(valid_values.tolist())\n",
    "test_values_set = set(test_values.tolist())\n",
    "        \n",
    "# A pair is assigned to the split that contains its *sum*, not its operands.\n",
    "train_inputs = [(x1, x2) for x1, x2 in all_inputs if x1 + x2 in train_values_set]\n",
    "valid_inputs = [(x1, x2) for x1, x2 in all_inputs if x1 + x2 in valid_values_set]\n",
    "test_inputs = [(x1, x2) for x1, x2 in all_inputs if x1 + x2 in test_values_set]\n",
    "\n",
    "# sanity check\n",
    "assert set(train_inputs) & set(valid_inputs) == set()\n",
    "assert set(train_inputs) & set(test_inputs) == set()\n",
    "assert set(valid_inputs) & set(test_inputs) == set()\n",
    "\n",
    "# Shuffle reproducibly, then cap the train and valid sets; the test set is kept in full.\n",
    "random.seed(0)\n",
    "random.shuffle(train_inputs)\n",
    "random.shuffle(valid_inputs)\n",
    "random.shuffle(test_inputs)\n",
    "valid_size = 4096\n",
    "train_size = 100_000\n",
    "train_inputs = train_inputs[:train_size]\n",
    "valid_inputs = valid_inputs[:valid_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c41d1f88",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('3 + 500 = ', '3 + 0 = ')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def make_str_input(operands: tuple[int, int] | list[int]) -> str:\n",
    "    \"\"\"Render a pair of operands as an addition prompt that ends right before the answer.\"\"\"\n",
    "    # Unpacking enforces that exactly two operands are given.\n",
    "    x1, x2 = operands\n",
    "    prompt = f\"{x1} + {x2} = \"\n",
    "    return prompt\n",
    "\n",
    "make_str_input((3, 500)), make_str_input((3, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "36dd2f3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_hidden_states(model, str_inputs: list[str], batch_size: int) -> dict[int, Tensor]:\n",
    "    \"\"\"Collect, for every input string, the hidden state at the last sequence position.\n",
    "\n",
    "    The strings are tokenized batch by batch with the notebook-level ``tokenizer`` and passed\n",
    "    through ``model`` without gradient tracking. For every entry of the model's ``hidden_states``\n",
    "    output, the vector at the final sequence position is moved to the CPU and kept.\n",
    "\n",
    "    Args:\n",
    "        model: Model that accepts ``output_hidden_states=True`` and exposes ``.device``.\n",
    "        str_inputs: Prompts to encode.\n",
    "        batch_size: Number of prompts per forward pass.\n",
    "\n",
    "    Returns:\n",
    "        A plain dict mapping the index within ``hidden_states`` to a tensor with one row per\n",
    "        prompt, in the order of ``str_inputs``. The dict is empty when ``str_inputs`` is empty.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    hidden_states = collections.defaultdict(list)\n",
    "    with torch.no_grad():\n",
    "        num_batches = (len(str_inputs) + batch_size - 1) // batch_size\n",
    "        for batch_str in tqdm.auto.tqdm(itertools.batched(str_inputs, n=batch_size), total=num_batches):\n",
    "            # NOTE(review): no padding is requested, so this presumably relies on every prompt of\n",
    "            # a batch tokenizing to the same length; position -1 is then the last real token.\n",
    "            batch_inputs = tokenizer(batch_str, return_tensors=\"pt\")\n",
    "            hidden_reprs = model(**batch_inputs.to(model.device), output_hidden_states=True).hidden_states\n",
    "            for layer_idx, hidden_state in enumerate(hidden_reprs):\n",
    "                # keep one tensor per batch instead of one Python object per row\n",
    "                hidden_states[layer_idx].append(hidden_state[:, -1, :].detach().cpu())\n",
    "    return {k: torch.cat(v, dim=0) for k, v in hidden_states.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "da87aca4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "96831aca89cb4fc49f24d86c5300a0c2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/98 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8cd4f1eed2d644e186dbb8bce2a5b6a7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "957716e0ed9b4af1b880a33d3e9a0623",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/31 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Number of prompts per forward pass when extracting hidden states.\n",
    "batch_size = 1024\n",
    "# One extraction pass per split; each prompt is rendered with make_str_input.\n",
    "train_hidden_states = get_hidden_states(model, [make_str_input(val) for val in train_inputs], batch_size)\n",
    "valid_hidden_states = get_hidden_states(model, [make_str_input(val) for val in valid_inputs], batch_size)\n",
    "test_hidden_states = get_hidden_states(model, [make_str_input(val) for val in test_inputs], batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6288a397",
   "metadata": {},
   "source": [
    "### Probing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "cff73d20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sinusoidal encodings of the integers 0..999, one row per integer, as wide as the extracted\n",
    "# hidden states.\n",
    "basis_embs_sin = sinusoidal_encode(\n",
    "    torch.arange(1000),\n",
    "    min_value=0,\n",
    "    max_value=1000,\n",
    "    embedding_dim=train_hidden_states[0].shape[-1],\n",
    ")\n",
    "\n",
    "\n",
    "# Binary encodings of the same integers; 10 bits cover every value below 1024.\n",
    "basis_embs_bin = binary_encode(\n",
    "    torch.arange(1000),\n",
    "    min_value=0,\n",
    "    max_value=1000,\n",
    "    embedding_dim=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "184992d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ClassifierProbe(torch.nn.Module):\n",
    "    \"\"\"Scores input embeddings against every row of a fixed basis in a shared latent space.\n",
    "\n",
    "    Both the inputs and the basis rows are mapped by their own linear layer into ``hidden_dim``\n",
    "    dimensions; the logit of an (input, basis row) pair is the dot product of the two projections.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, emb_dim: int, hidden_dim: int, basis: torch.Tensor, heldout_mask: torch.Tensor):\n",
    "        \"\"\"Create the two projections and store ``basis`` and ``heldout_mask`` as buffers.\"\"\"\n",
    "        super().__init__()\n",
    "        # The creation order of the layers is kept, since it determines how the random\n",
    "        # initialisation is consumed.\n",
    "        self.emb_to_latent = torch.nn.Linear(emb_dim, hidden_dim, bias=True)\n",
    "        self.basis_to_latent = torch.nn.Linear(basis.shape[-1], hidden_dim, bias=True)\n",
    "        # Declarations for type checkers only; the values are set by register_buffer below.\n",
    "        self.basis: torch.nn.Buffer\n",
    "        self.heldout_mask: torch.nn.Buffer\n",
    "        self.register_buffer(\"basis\", basis)\n",
    "        self.register_buffer(\"heldout_mask\", heldout_mask)\n",
    "\n",
    "    def forward(self, x: Tensor, holdout_eval_tokens: bool) -> Tensor:\n",
    "        \"\"\"Return one logit per (row of ``x``, basis row) pair.\n",
    "\n",
    "        With ``holdout_eval_tokens`` set, the columns selected by ``heldout_mask`` are set to\n",
    "        ``-inf``, so those candidates cannot be chosen; otherwise all candidates compete.\n",
    "        \"\"\"\n",
    "        queries = self.emb_to_latent(x)\n",
    "        candidates = self.basis_to_latent(self.basis)\n",
    "        scores = torch.matmul(queries, candidates.T)\n",
    "        if not holdout_eval_tokens:\n",
    "            # evaluation: the model must choose among all tokens\n",
    "            return scores\n",
    "        # training: the held-out tokens are removed from the choice, so the model is never\n",
    "        # exposed to them\n",
    "        scores[:, self.heldout_mask] = float(\"-inf\")\n",
    "        return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2d3d2c7b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sin i=    0 train loss: 16.61  train acc: 0.00  val loss: 24.12  valid acc: 0.02\n",
      "sin i=  500 train loss:  3.65  train acc: 0.88  val loss:  0.79  valid acc: 0.71\n",
      "sin i= 1000 train loss:  2.10  train acc: 0.89  val loss:  0.58  valid acc: 0.82\n",
      "sin i= 1500 train loss:  1.22  train acc: 0.92  val loss:  0.44  valid acc: 0.87\n",
      "sin i= 2000 train loss:  0.79  train acc: 0.92  val loss:  0.36  valid acc: 0.89\n",
      "sin i= 2500 train loss:  0.64  train acc: 0.92  val loss:  0.29  valid acc: 0.91\n",
      "sin i= 3000 train loss:  0.50  train acc: 0.94  val loss:  0.29  valid acc: 0.90\n",
      "sin i= 3500 train loss:  0.49  train acc: 0.93  val loss:  0.28  valid acc: 0.91\n",
      "sin i= 4000 train loss:  6.89  train acc: 0.85  val loss:  1.66  valid acc: 0.79\n",
      "sin i= 4500 train loss:  4.37  train acc: 0.90  val loss:  0.80  valid acc: 0.85\n",
      "sin i= 5000 train loss:  3.39  train acc: 0.92  val loss:  0.61  valid acc: 0.88\n",
      "sin i= 5500 train loss:  2.48  train acc: 0.95  val loss:  0.38  valid acc: 0.89\n",
      "sin i= 6000 train loss:  1.75  train acc: 0.95  val loss:  0.24  valid acc: 0.92\n",
      "sin i= 6500 train loss:  1.25  train acc: 0.95  val loss:  0.22  valid acc: 0.93\n",
      "sin i= 7000 train loss:  0.89  train acc: 0.95  val loss:  0.20  valid acc: 0.93\n",
      "sin i= 7500 train loss:  0.65  train acc: 0.95  val loss:  0.21  valid acc: 0.93\n",
      "sin i= 8000 train loss:  0.53  train acc: 0.96  val loss:  0.22  valid acc: 0.93\n",
      "sin i= 8500 train loss:  0.42  train acc: 0.97  val loss:  0.22  valid acc: 0.93\n",
      "sin i= 9000 train loss:  0.49  train acc: 0.97  val loss:  0.21  valid acc: 0.93\n",
      "sin i= 9500 train loss:  0.43  train acc: 0.95  val loss:  0.21  valid acc: 0.93\n",
      "sin i=10000 train loss:  0.39  train acc: 0.95  val loss:  0.21  valid acc: 0.94\n",
      "-> sin layer idx: 16 , best valid accuracy: 0.94, test accuracy: 0.93\n",
      "sin i=    0 train loss: 11.54  train acc: 0.00  val loss:  7.86  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.08  train acc: 0.84  val loss:  0.62  valid acc: 0.80\n",
      "sin i= 1000 train loss:  0.86  train acc: 0.86  val loss:  0.50  valid acc: 0.83\n",
      "sin i= 1500 train loss:  0.79  train acc: 0.87  val loss:  0.48  valid acc: 0.84\n",
      "sin i= 2000 train loss:  0.72  train acc: 0.87  val loss:  0.43  valid acc: 0.85\n",
      "sin i= 2500 train loss:  0.71  train acc: 0.89  val loss:  0.41  valid acc: 0.85\n",
      "sin i= 3000 train loss:  0.66  train acc: 0.90  val loss:  0.40  valid acc: 0.86\n",
      "sin i= 3500 train loss:  0.68  train acc: 0.88  val loss:  0.38  valid acc: 0.87\n",
      "sin i= 4000 train loss:  0.72  train acc: 0.85  val loss:  0.37  valid acc: 0.87\n",
      "sin i= 4500 train loss:  0.63  train acc: 0.88  val loss:  0.38  valid acc: 0.86\n",
      "sin i= 5000 train loss:  0.66  train acc: 0.89  val loss:  0.35  valid acc: 0.88\n",
      "sin i= 5500 train loss:  0.62  train acc: 0.90  val loss:  0.36  valid acc: 0.87\n",
      "sin i= 6000 train loss:  0.62  train acc: 0.90  val loss:  0.34  valid acc: 0.88\n",
      "sin i= 6500 train loss:  0.62  train acc: 0.90  val loss:  0.35  valid acc: 0.88\n",
      "sin i= 7000 train loss:  0.62  train acc: 0.90  val loss:  0.33  valid acc: 0.89\n",
      "sin i= 7500 train loss:  0.62  train acc: 0.90  val loss:  0.33  valid acc: 0.89\n",
      "sin i= 8000 train loss:  0.59  train acc: 0.91  val loss:  0.32  valid acc: 0.89\n",
      "sin i= 8500 train loss:  0.58  train acc: 0.91  val loss:  0.32  valid acc: 0.89\n",
      "sin i= 9000 train loss:  0.56  train acc: 0.91  val loss:  0.33  valid acc: 0.88\n",
      "sin i= 9500 train loss:  0.58  train acc: 0.91  val loss:  0.31  valid acc: 0.90\n",
      "sin i=10000 train loss:  0.58  train acc: 0.91  val loss:  0.31  valid acc: 0.89\n",
      "-> sin layer idx: 15 , best valid accuracy: 0.90, test accuracy: 0.90\n",
      "sin i=    0 train loss: 11.49  train acc: 0.00  val loss:  7.90  valid acc: 0.00\n",
      "sin i=  500 train loss:  1.93  train acc: 0.66  val loss:  1.39  valid acc: 0.55\n",
      "sin i= 1000 train loss:  1.58  train acc: 0.73  val loss:  1.08  valid acc: 0.63\n",
      "sin i= 1500 train loss:  1.41  train acc: 0.76  val loss:  0.95  valid acc: 0.67\n",
      "sin i= 2000 train loss:  1.33  train acc: 0.78  val loss:  0.89  valid acc: 0.70\n",
      "sin i= 2500 train loss:  1.28  train acc: 0.81  val loss:  0.77  valid acc: 0.73\n",
      "sin i= 3000 train loss:  1.19  train acc: 0.82  val loss:  0.74  valid acc: 0.74\n",
      "sin i= 3500 train loss:  1.18  train acc: 0.82  val loss:  0.70  valid acc: 0.77\n",
      "sin i= 4000 train loss:  1.20  train acc: 0.80  val loss:  0.70  valid acc: 0.76\n",
      "sin i= 4500 train loss:  1.11  train acc: 0.83  val loss:  0.66  valid acc: 0.77\n",
      "sin i= 5000 train loss:  1.11  train acc: 0.84  val loss:  0.63  valid acc: 0.78\n",
      "sin i= 5500 train loss:  1.09  train acc: 0.84  val loss:  0.63  valid acc: 0.79\n",
      "sin i= 6000 train loss:  1.08  train acc: 0.83  val loss:  0.61  valid acc: 0.79\n",
      "sin i= 6500 train loss:  1.05  train acc: 0.85  val loss:  0.61  valid acc: 0.79\n",
      "sin i= 7000 train loss:  1.03  train acc: 0.85  val loss:  0.59  valid acc: 0.80\n",
      "sin i= 7500 train loss:  1.05  train acc: 0.87  val loss:  0.58  valid acc: 0.80\n",
      "sin i= 8000 train loss:  1.00  train acc: 0.86  val loss:  0.57  valid acc: 0.81\n",
      "sin i= 8500 train loss:  1.00  train acc: 0.85  val loss:  0.57  valid acc: 0.80\n",
      "sin i= 9000 train loss:  0.94  train acc: 0.88  val loss:  0.54  valid acc: 0.82\n",
      "sin i= 9500 train loss:  0.97  train acc: 0.87  val loss:  0.55  valid acc: 0.82\n",
      "sin i=10000 train loss:  0.97  train acc: 0.88  val loss:  0.55  valid acc: 0.81\n",
      "-> sin layer idx: 14 , best valid accuracy: 0.82, test accuracy: 0.85\n",
      "sin i=    0 train loss: 11.61  train acc: 0.00  val loss:  8.32  valid acc: 0.01\n",
      "sin i=  500 train loss:  2.86  train acc: 0.47  val loss:  2.54  valid acc: 0.14\n",
      "sin i= 1000 train loss:  2.23  train acc: 0.59  val loss:  1.84  valid acc: 0.35\n",
      "sin i= 1500 train loss:  1.99  train acc: 0.68  val loss:  1.59  valid acc: 0.44\n",
      "sin i= 2000 train loss:  1.85  train acc: 0.70  val loss:  1.46  valid acc: 0.48\n",
      "sin i= 2500 train loss:  1.82  train acc: 0.70  val loss:  1.27  valid acc: 0.56\n",
      "sin i= 3000 train loss:  1.69  train acc: 0.74  val loss:  1.17  valid acc: 0.60\n",
      "sin i= 3500 train loss:  1.70  train acc: 0.74  val loss:  1.10  valid acc: 0.63\n",
      "sin i= 4000 train loss:  1.67  train acc: 0.74  val loss:  1.03  valid acc: 0.65\n",
      "sin i= 4500 train loss:  1.59  train acc: 0.76  val loss:  1.00  valid acc: 0.66\n",
      "sin i= 5000 train loss:  1.57  train acc: 0.77  val loss:  1.02  valid acc: 0.66\n",
      "sin i= 5500 train loss:  1.55  train acc: 0.78  val loss:  0.93  valid acc: 0.69\n",
      "sin i= 6000 train loss:  1.54  train acc: 0.77  val loss:  0.91  valid acc: 0.70\n",
      "sin i= 6500 train loss:  1.53  train acc: 0.77  val loss:  0.89  valid acc: 0.71\n",
      "sin i= 7000 train loss:  1.52  train acc: 0.76  val loss:  0.88  valid acc: 0.70\n",
      "sin i= 7500 train loss:  1.56  train acc: 0.78  val loss:  0.88  valid acc: 0.72\n",
      "sin i= 8000 train loss:  1.47  train acc: 0.79  val loss:  0.83  valid acc: 0.73\n",
      "sin i= 8500 train loss:  1.48  train acc: 0.78  val loss:  0.81  valid acc: 0.74\n",
      "sin i= 9000 train loss:  1.40  train acc: 0.83  val loss:  0.83  valid acc: 0.73\n",
      "sin i= 9500 train loss:  1.46  train acc: 0.79  val loss:  0.81  valid acc: 0.73\n",
      "sin i=10000 train loss:  1.39  train acc: 0.84  val loss:  0.80  valid acc: 0.74\n",
      "-> sin layer idx: 13 , best valid accuracy: 0.74, test accuracy: 0.79\n",
      "sin i=    0 train loss: 11.44  train acc: 0.00  val loss:  8.10  valid acc: 0.02\n",
      "sin i=  500 train loss:  4.62  train acc: 0.03  val loss:  4.61  valid acc: 0.00\n",
      "sin i= 1000 train loss:  4.45  train acc: 0.04  val loss:  4.37  valid acc: 0.00\n",
      "sin i= 1500 train loss:  4.38  train acc: 0.04  val loss:  4.33  valid acc: 0.00\n",
      "sin i= 2000 train loss:  4.36  train acc: 0.05  val loss:  4.25  valid acc: 0.00\n",
      "sin i= 2500 train loss:  4.35  train acc: 0.05  val loss:  4.23  valid acc: 0.00\n",
      "sin i= 3000 train loss:  4.32  train acc: 0.05  val loss:  4.20  valid acc: 0.00\n",
      "sin i= 3500 train loss:  4.32  train acc: 0.05  val loss:  4.20  valid acc: 0.00\n",
      "sin i= 4000 train loss:  4.32  train acc: 0.05  val loss:  4.20  valid acc: 0.00\n",
      "sin i= 4500 train loss:  4.32  train acc: 0.05  val loss:  4.17  valid acc: 0.00\n",
      "sin i= 5000 train loss:  4.34  train acc: 0.04  val loss:  4.18  valid acc: 0.00\n",
      "sin i= 5500 train loss:  4.31  train acc: 0.06  val loss:  4.16  valid acc: 0.00\n",
      "sin i= 6000 train loss:  4.28  train acc: 0.05  val loss:  4.12  valid acc: 0.01\n",
      "sin i= 6500 train loss:  4.31  train acc: 0.04  val loss:  4.16  valid acc: 0.00\n",
      "sin i= 7000 train loss:  4.29  train acc: 0.06  val loss:  4.15  valid acc: 0.00\n",
      "sin i= 7500 train loss:  4.29  train acc: 0.04  val loss:  4.12  valid acc: 0.00\n",
      "sin i= 8000 train loss:  4.26  train acc: 0.05  val loss:  4.15  valid acc: 0.00\n",
      "sin i= 8500 train loss:  4.27  train acc: 0.05  val loss:  4.14  valid acc: 0.00\n",
      "sin i= 9000 train loss:  4.30  train acc: 0.05  val loss:  4.12  valid acc: 0.00\n",
      "sin i= 9500 train loss:  4.27  train acc: 0.04  val loss:  4.15  valid acc: 0.00\n",
      "sin i=10000 train loss:  4.26  train acc: 0.05  val loss:  4.09  valid acc: 0.01\n",
      "-> sin layer idx: 12 , best valid accuracy: 0.02, test accuracy: 0.05\n",
      "sin i=    0 train loss: 11.47  train acc: 0.00  val loss:  7.36  valid acc: 0.01\n",
      "sin i=  500 train loss:  5.85  train acc: 0.00  val loss:  5.72  valid acc: 0.00\n",
      "sin i= 1000 train loss:  5.57  train acc: 0.01  val loss:  5.50  valid acc: 0.00\n",
      "sin i= 1500 train loss:  5.49  train acc: 0.01  val loss:  5.39  valid acc: 0.00\n",
      "sin i= 2000 train loss:  5.39  train acc: 0.02  val loss:  5.33  valid acc: 0.00\n",
      "sin i= 2500 train loss:  5.38  train acc: 0.01  val loss:  5.29  valid acc: 0.00\n",
      "sin i= 3000 train loss:  5.35  train acc: 0.02  val loss:  5.23  valid acc: 0.00\n",
      "sin i= 3500 train loss:  5.35  train acc: 0.01  val loss:  5.20  valid acc: 0.00\n",
      "sin i= 4000 train loss:  5.32  train acc: 0.01  val loss:  5.23  valid acc: 0.00\n",
      "sin i= 4500 train loss:  5.33  train acc: 0.02  val loss:  5.20  valid acc: 0.00\n",
      "sin i= 5000 train loss:  5.31  train acc: 0.02  val loss:  5.16  valid acc: 0.00\n",
      "sin i= 5500 train loss:  5.28  train acc: 0.02  val loss:  5.17  valid acc: 0.00\n",
      "sin i= 6000 train loss:  5.28  train acc: 0.02  val loss:  5.16  valid acc: 0.00\n",
      "sin i= 6500 train loss:  5.32  train acc: 0.02  val loss:  5.17  valid acc: 0.00\n",
      "sin i= 7000 train loss:  5.22  train acc: 0.02  val loss:  5.14  valid acc: 0.00\n",
      "sin i= 7500 train loss:  5.23  train acc: 0.01  val loss:  5.17  valid acc: 0.00\n",
      "sin i= 8000 train loss:  5.24  train acc: 0.02  val loss:  5.16  valid acc: 0.00\n",
      "sin i= 8500 train loss:  5.20  train acc: 0.02  val loss:  5.08  valid acc: 0.00\n",
      "sin i= 9000 train loss:  5.24  train acc: 0.02  val loss:  5.08  valid acc: 0.00\n",
      "sin i= 9500 train loss:  5.19  train acc: 0.02  val loss:  5.10  valid acc: 0.00\n",
      "sin i=10000 train loss:  5.20  train acc: 0.01  val loss:  5.07  valid acc: 0.00\n",
      "-> sin layer idx: 11 , best valid accuracy: 0.01, test accuracy: 0.01\n",
      "sin i=    0 train loss: 11.41  train acc: 0.00  val loss:  7.47  valid acc: 0.00\n",
      "sin i=  500 train loss:  5.90  train acc: 0.01  val loss:  5.88  valid acc: 0.00\n",
      "sin i= 1000 train loss:  5.62  train acc: 0.01  val loss:  5.64  valid acc: 0.00\n",
      "sin i= 1500 train loss:  5.52  train acc: 0.01  val loss:  5.55  valid acc: 0.00\n",
      "sin i= 2000 train loss:  5.46  train acc: 0.01  val loss:  5.49  valid acc: 0.00\n",
      "sin i= 2500 train loss:  5.44  train acc: 0.01  val loss:  5.44  valid acc: 0.00\n",
      "sin i= 3000 train loss:  5.44  train acc: 0.02  val loss:  5.40  valid acc: 0.00\n",
      "sin i= 3500 train loss:  5.40  train acc: 0.02  val loss:  5.36  valid acc: 0.00\n",
      "sin i= 4000 train loss:  5.39  train acc: 0.01  val loss:  5.41  valid acc: 0.00\n",
      "sin i= 4500 train loss:  5.37  train acc: 0.02  val loss:  5.38  valid acc: 0.00\n",
      "sin i= 5000 train loss:  5.37  train acc: 0.01  val loss:  5.33  valid acc: 0.00\n",
      "sin i= 5500 train loss:  5.39  train acc: 0.01  val loss:  5.40  valid acc: 0.00\n",
      "sin i= 6000 train loss:  5.36  train acc: 0.01  val loss:  5.32  valid acc: 0.00\n",
      "sin i= 6500 train loss:  5.39  train acc: 0.01  val loss:  5.30  valid acc: 0.00\n",
      "sin i= 7000 train loss:  5.30  train acc: 0.01  val loss:  5.30  valid acc: 0.00\n",
      "sin i= 7500 train loss:  5.32  train acc: 0.01  val loss:  5.31  valid acc: 0.00\n",
      "sin i= 8000 train loss:  5.29  train acc: 0.01  val loss:  5.31  valid acc: 0.00\n",
      "sin i= 8500 train loss:  5.31  train acc: 0.01  val loss:  5.27  valid acc: 0.00\n",
      "sin i= 9000 train loss:  5.30  train acc: 0.01  val loss:  5.26  valid acc: 0.00\n",
      "sin i= 9500 train loss:  5.29  train acc: 0.02  val loss:  5.27  valid acc: 0.00\n",
      "sin i=10000 train loss:  5.28  train acc: 0.01  val loss:  5.24  valid acc: 0.00\n",
      "-> sin layer idx: 10 , best valid accuracy: 0.00, test accuracy: 0.01\n",
      "sin i=    0 train loss: 11.36  train acc: 0.00  val loss:  7.67  valid acc: 0.00\n",
      "sin i=  500 train loss:  5.90  train acc: 0.00  val loss:  5.81  valid acc: 0.00\n",
      "sin i= 1000 train loss:  5.65  train acc: 0.01  val loss:  5.60  valid acc: 0.00\n",
      "sin i= 1500 train loss:  5.56  train acc: 0.01  val loss:  5.52  valid acc: 0.00\n",
      "sin i= 2000 train loss:  5.52  train acc: 0.02  val loss:  5.47  valid acc: 0.00\n",
      "sin i= 2500 train loss:  5.48  train acc: 0.01  val loss:  5.42  valid acc: 0.00\n",
      "sin i= 3000 train loss:  5.45  train acc: 0.02  val loss:  5.35  valid acc: 0.00\n",
      "sin i= 3500 train loss:  5.43  train acc: 0.02  val loss:  5.34  valid acc: 0.00\n",
      "sin i= 4000 train loss:  5.40  train acc: 0.01  val loss:  5.34  valid acc: 0.00\n",
      "sin i= 4500 train loss:  5.40  train acc: 0.02  val loss:  5.32  valid acc: 0.00\n",
      "sin i= 5000 train loss:  5.40  train acc: 0.01  val loss:  5.28  valid acc: 0.00\n",
      "sin i= 5500 train loss:  5.37  train acc: 0.01  val loss:  5.30  valid acc: 0.00\n",
      "sin i= 6000 train loss:  5.39  train acc: 0.02  val loss:  5.26  valid acc: 0.00\n",
      "sin i= 6500 train loss:  5.42  train acc: 0.01  val loss:  5.24  valid acc: 0.00\n",
      "sin i= 7000 train loss:  5.35  train acc: 0.02  val loss:  5.25  valid acc: 0.00\n",
      "sin i= 7500 train loss:  5.36  train acc: 0.01  val loss:  5.24  valid acc: 0.00\n",
      "sin i= 8000 train loss:  5.32  train acc: 0.02  val loss:  5.25  valid acc: 0.00\n",
      "sin i= 8500 train loss:  5.35  train acc: 0.01  val loss:  5.21  valid acc: 0.00\n",
      "sin i= 9000 train loss:  5.34  train acc: 0.02  val loss:  5.21  valid acc: 0.00\n",
      "sin i= 9500 train loss:  5.32  train acc: 0.01  val loss:  5.21  valid acc: 0.00\n",
      "sin i=10000 train loss:  5.33  train acc: 0.01  val loss:  5.19  valid acc: 0.00\n",
      "-> sin layer idx: 9  , best valid accuracy: 0.00, test accuracy: 0.01\n",
      "sin i=    0 train loss: 11.34  train acc: 0.00  val loss:  7.53  valid acc: 0.00\n",
      "sin i=  500 train loss:  5.81  train acc: 0.01  val loss:  5.85  valid acc: 0.00\n",
      "sin i= 1000 train loss:  5.56  train acc: 0.01  val loss:  5.56  valid acc: 0.00\n",
      "sin i= 1500 train loss:  5.45  train acc: 0.01  val loss:  5.47  valid acc: 0.00\n",
      "sin i= 2000 train loss:  5.41  train acc: 0.02  val loss:  5.43  valid acc: 0.00\n",
      "sin i= 2500 train loss:  5.38  train acc: 0.01  val loss:  5.36  valid acc: 0.00\n",
      "sin i= 3000 train loss:  5.34  train acc: 0.02  val loss:  5.32  valid acc: 0.00\n",
      "sin i= 3500 train loss:  5.31  train acc: 0.01  val loss:  5.28  valid acc: 0.00\n",
      "sin i= 4000 train loss:  5.28  train acc: 0.01  val loss:  5.32  valid acc: 0.00\n",
      "sin i= 4500 train loss:  5.32  train acc: 0.02  val loss:  5.28  valid acc: 0.00\n",
      "sin i= 5000 train loss:  5.29  train acc: 0.01  val loss:  5.23  valid acc: 0.00\n",
      "sin i= 5500 train loss:  5.28  train acc: 0.01  val loss:  5.22  valid acc: 0.00\n",
      "sin i= 6000 train loss:  5.27  train acc: 0.01  val loss:  5.22  valid acc: 0.00\n",
      "sin i= 6500 train loss:  5.31  train acc: 0.01  val loss:  5.20  valid acc: 0.00\n",
      "sin i= 7000 train loss:  5.22  train acc: 0.02  val loss:  5.21  valid acc: 0.00\n",
      "sin i= 7500 train loss:  5.24  train acc: 0.01  val loss:  5.20  valid acc: 0.00\n",
      "sin i= 8000 train loss:  5.22  train acc: 0.02  val loss:  5.22  valid acc: 0.00\n",
      "sin i= 8500 train loss:  5.25  train acc: 0.02  val loss:  5.15  valid acc: 0.00\n",
      "sin i= 9000 train loss:  5.26  train acc: 0.01  val loss:  5.17  valid acc: 0.00\n",
      "sin i= 9500 train loss:  5.22  train acc: 0.02  val loss:  5.18  valid acc: 0.00\n",
      "sin i=10000 train loss:  5.25  train acc: 0.01  val loss:  5.16  valid acc: 0.00\n",
      "-> sin layer idx: 8  , best valid accuracy: 0.00, test accuracy: 0.01\n",
      "sin i=    0 train loss: 11.34  train acc: 0.00  val loss:  7.68  valid acc: 0.00\n",
      "sin i=  500 train loss:  5.73  train acc: 0.01  val loss:  5.74  valid acc: 0.00\n",
      "sin i= 1000 train loss:  5.49  train acc: 0.01  val loss:  5.53  valid acc: 0.00\n",
      "sin i= 1500 train loss:  5.40  train acc: 0.01  val loss:  5.41  valid acc: 0.00\n",
      "sin i= 2000 train loss:  5.36  train acc: 0.01  val loss:  5.36  valid acc: 0.00\n",
      "sin i= 2500 train loss:  5.32  train acc: 0.01  val loss:  5.29  valid acc: 0.00\n",
      "sin i= 3000 train loss:  5.29  train acc: 0.02  val loss:  5.25  valid acc: 0.00\n",
      "sin i= 3500 train loss:  5.27  train acc: 0.02  val loss:  5.21  valid acc: 0.00\n",
      "sin i= 4000 train loss:  5.23  train acc: 0.01  val loss:  5.23  valid acc: 0.00\n",
      "sin i= 4500 train loss:  5.26  train acc: 0.02  val loss:  5.23  valid acc: 0.00\n",
      "sin i= 5000 train loss:  5.24  train acc: 0.01  val loss:  5.17  valid acc: 0.00\n",
      "sin i= 5500 train loss:  5.23  train acc: 0.02  val loss:  5.15  valid acc: 0.00\n",
      "sin i= 6000 train loss:  5.23  train acc: 0.01  val loss:  5.16  valid acc: 0.00\n",
      "sin i= 6500 train loss:  5.24  train acc: 0.02  val loss:  5.15  valid acc: 0.00\n",
      "sin i= 7000 train loss:  5.18  train acc: 0.01  val loss:  5.16  valid acc: 0.00\n",
      "sin i= 7500 train loss:  5.19  train acc: 0.02  val loss:  5.16  valid acc: 0.00\n",
      "sin i= 8000 train loss:  5.17  train acc: 0.02  val loss:  5.16  valid acc: 0.00\n",
      "sin i= 8500 train loss:  5.21  train acc: 0.02  val loss:  5.11  valid acc: 0.00\n",
      "sin i= 9000 train loss:  5.22  train acc: 0.01  val loss:  5.10  valid acc: 0.00\n",
      "sin i= 9500 train loss:  5.19  train acc: 0.01  val loss:  5.12  valid acc: 0.00\n",
      "sin i=10000 train loss:  5.20  train acc: 0.02  val loss:  5.11  valid acc: 0.00\n",
      "-> sin layer idx: 7  , best valid accuracy: 0.00, test accuracy: 0.01\n",
      "sin i=    0 train loss: 11.36  train acc: 0.00  val loss:  7.33  valid acc: 0.00\n",
      "sin i=  500 train loss:  5.76  train acc: 0.01  val loss:  5.66  valid acc: 0.00\n",
      "sin i= 1000 train loss:  5.50  train acc: 0.01  val loss:  5.43  valid acc: 0.00\n",
      "sin i= 1500 train loss:  5.41  train acc: 0.01  val loss:  5.33  valid acc: 0.00\n",
      "sin i= 2000 train loss:  5.38  train acc: 0.02  val loss:  5.29  valid acc: 0.00\n",
      "sin i= 2500 train loss:  5.38  train acc: 0.01  val loss:  5.26  valid acc: 0.00\n",
      "sin i= 3000 train loss:  5.34  train acc: 0.02  val loss:  5.21  valid acc: 0.00\n",
      "sin i= 3500 train loss:  5.29  train acc: 0.01  val loss:  5.19  valid acc: 0.00\n",
      "sin i= 4000 train loss:  5.26  train acc: 0.01  val loss:  5.23  valid acc: 0.00\n",
      "sin i= 4500 train loss:  5.27  train acc: 0.02  val loss:  5.18  valid acc: 0.00\n",
      "sin i= 5000 train loss:  5.26  train acc: 0.01  val loss:  5.13  valid acc: 0.00\n",
      "sin i= 5500 train loss:  5.24  train acc: 0.02  val loss:  5.13  valid acc: 0.00\n",
      "sin i= 6000 train loss:  5.25  train acc: 0.02  val loss:  5.12  valid acc: 0.00\n",
      "sin i= 6500 train loss:  5.25  train acc: 0.02  val loss:  5.09  valid acc: 0.00\n",
      "sin i= 7000 train loss:  5.18  train acc: 0.01  val loss:  5.09  valid acc: 0.00\n",
      "sin i= 7500 train loss:  5.20  train acc: 0.02  val loss:  5.12  valid acc: 0.00\n",
      "sin i= 8000 train loss:  5.18  train acc: 0.02  val loss:  5.11  valid acc: 0.00\n",
      "sin i= 8500 train loss:  5.20  train acc: 0.02  val loss:  5.06  valid acc: 0.00\n",
      "sin i= 9000 train loss:  5.23  train acc: 0.01  val loss:  5.05  valid acc: 0.00\n",
      "sin i= 9500 train loss:  5.19  train acc: 0.02  val loss:  5.07  valid acc: 0.00\n",
      "sin i=10000 train loss:  5.21  train acc: 0.01  val loss:  5.05  valid acc: 0.00\n",
      "-> sin layer idx: 6  , best valid accuracy: 0.00, test accuracy: 0.01\n",
      "sin i=    0 train loss: 11.36  train acc: 0.01  val loss:  7.51  valid acc: 0.00\n",
      "sin i=  500 train loss:  5.77  train acc: 0.01  val loss:  5.73  valid acc: 0.00\n",
      "sin i= 1000 train loss:  5.55  train acc: 0.01  val loss:  5.50  valid acc: 0.00\n",
      "sin i= 1500 train loss:  5.48  train acc: 0.00  val loss:  5.42  valid acc: 0.00\n",
      "sin i= 2000 train loss:  5.41  train acc: 0.02  val loss:  5.36  valid acc: 0.00\n",
      "sin i= 2500 train loss:  5.37  train acc: 0.01  val loss:  5.32  valid acc: 0.00\n",
      "sin i= 3000 train loss:  5.32  train acc: 0.02  val loss:  5.26  valid acc: 0.00\n",
      "sin i= 3500 train loss:  5.26  train acc: 0.01  val loss:  5.24  valid acc: 0.00\n",
      "sin i= 4000 train loss:  5.23  train acc: 0.02  val loss:  5.26  valid acc: 0.00\n",
      "sin i= 4500 train loss:  5.25  train acc: 0.02  val loss:  5.23  valid acc: 0.00\n",
      "sin i= 5000 train loss:  5.23  train acc: 0.01  val loss:  5.19  valid acc: 0.00\n",
      "sin i= 5500 train loss:  5.22  train acc: 0.02  val loss:  5.17  valid acc: 0.00\n",
      "sin i= 6000 train loss:  5.22  train acc: 0.01  val loss:  5.18  valid acc: 0.00\n",
      "sin i= 6500 train loss:  5.23  train acc: 0.02  val loss:  5.15  valid acc: 0.00\n",
      "sin i= 7000 train loss:  5.17  train acc: 0.02  val loss:  5.16  valid acc: 0.00\n",
      "sin i= 7500 train loss:  5.18  train acc: 0.01  val loss:  5.15  valid acc: 0.00\n",
      "sin i= 8000 train loss:  5.22  train acc: 0.02  val loss:  5.16  valid acc: 0.00\n",
      "sin i= 8500 train loss:  5.21  train acc: 0.01  val loss:  5.15  valid acc: 0.00\n",
      "sin i= 9000 train loss:  5.19  train acc: 0.01  val loss:  5.11  valid acc: 0.00\n",
      "sin i= 9500 train loss:  5.16  train acc: 0.02  val loss:  5.15  valid acc: 0.00\n",
      "sin i=10000 train loss:  5.19  train acc: 0.01  val loss:  5.12  valid acc: 0.00\n",
      "-> sin layer idx: 5  , best valid accuracy: 0.00, test accuracy: 0.01\n",
      "sin i=    0 train loss: 11.38  train acc: 0.00  val loss:  7.09  valid acc: 0.00\n",
      "sin i=  500 train loss:  5.66  train acc: 0.01  val loss:  5.73  valid acc: 0.00\n",
      "sin i= 1000 train loss:  5.41  train acc: 0.01  val loss:  5.50  valid acc: 0.00\n",
      "sin i= 1500 train loss:  5.34  train acc: 0.01  val loss:  5.40  valid acc: 0.00\n",
      "sin i= 2000 train loss:  5.29  train acc: 0.01  val loss:  5.35  valid acc: 0.00\n",
      "sin i= 2500 train loss:  5.30  train acc: 0.01  val loss:  5.30  valid acc: 0.00\n",
      "sin i= 3000 train loss:  5.28  train acc: 0.02  val loss:  5.25  valid acc: 0.00\n",
      "sin i= 3500 train loss:  5.26  train acc: 0.02  val loss:  5.23  valid acc: 0.00\n",
      "sin i= 4000 train loss:  5.23  train acc: 0.01  val loss:  5.29  valid acc: 0.00\n",
      "sin i= 4500 train loss:  5.26  train acc: 0.02  val loss:  5.25  valid acc: 0.00\n",
      "sin i= 5000 train loss:  5.27  train acc: 0.01  val loss:  5.20  valid acc: 0.00\n",
      "sin i= 5500 train loss:  5.26  train acc: 0.01  val loss:  5.18  valid acc: 0.00\n",
      "sin i= 6000 train loss:  5.26  train acc: 0.01  val loss:  5.19  valid acc: 0.00\n",
      "sin i= 6500 train loss:  5.24  train acc: 0.01  val loss:  5.16  valid acc: 0.00\n",
      "sin i= 7000 train loss:  5.21  train acc: 0.01  val loss:  5.17  valid acc: 0.00\n",
      "sin i= 7500 train loss:  5.20  train acc: 0.02  val loss:  5.16  valid acc: 0.00\n",
      "sin i= 8000 train loss:  5.25  train acc: 0.01  val loss:  5.18  valid acc: 0.00\n",
      "sin i= 8500 train loss:  5.24  train acc: 0.01  val loss:  5.13  valid acc: 0.00\n",
      "sin i= 9000 train loss:  5.20  train acc: 0.02  val loss:  5.12  valid acc: 0.00\n",
      "sin i= 9500 train loss:  5.20  train acc: 0.01  val loss:  5.16  valid acc: 0.00\n",
      "sin i=10000 train loss:  5.20  train acc: 0.01  val loss:  5.13  valid acc: 0.00\n",
      "-> sin layer idx: 4  , best valid accuracy: 0.00, test accuracy: 0.01\n",
      "sin i=    0 train loss: 11.39  train acc: 0.00  val loss:  6.92  valid acc: 0.00\n",
      "sin i=  500 train loss:  5.91  train acc: 0.01  val loss:  5.90  valid acc: 0.00\n",
      "sin i= 1000 train loss:  5.70  train acc: 0.01  val loss:  5.64  valid acc: 0.00\n",
      "sin i= 1500 train loss:  5.62  train acc: 0.01  val loss:  5.53  valid acc: 0.00\n",
      "sin i= 2000 train loss:  5.60  train acc: 0.01  val loss:  5.49  valid acc: 0.00\n",
      "sin i= 2500 train loss:  5.57  train acc: 0.02  val loss:  5.45  valid acc: 0.00\n",
      "sin i= 3000 train loss:  5.54  train acc: 0.01  val loss:  5.43  valid acc: 0.00\n",
      "sin i= 3500 train loss:  5.51  train acc: 0.01  val loss:  5.41  valid acc: 0.00\n",
      "sin i= 4000 train loss:  5.49  train acc: 0.01  val loss:  5.43  valid acc: 0.00\n",
      "sin i= 4500 train loss:  5.51  train acc: 0.02  val loss:  5.42  valid acc: 0.00\n",
      "sin i= 5000 train loss:  5.51  train acc: 0.01  val loss:  5.36  valid acc: 0.00\n",
      "sin i= 5500 train loss:  5.52  train acc: 0.02  val loss:  5.36  valid acc: 0.00\n",
      "sin i= 6000 train loss:  5.52  train acc: 0.02  val loss:  5.37  valid acc: 0.00\n",
      "sin i= 6500 train loss:  5.50  train acc: 0.01  val loss:  5.35  valid acc: 0.00\n",
      "sin i= 7000 train loss:  5.46  train acc: 0.01  val loss:  5.34  valid acc: 0.00\n",
      "sin i= 7500 train loss:  5.45  train acc: 0.01  val loss:  5.34  valid acc: 0.00\n",
      "sin i= 8000 train loss:  5.47  train acc: 0.01  val loss:  5.34  valid acc: 0.00\n",
      "sin i= 8500 train loss:  5.47  train acc: 0.01  val loss:  5.30  valid acc: 0.00\n",
      "sin i= 9000 train loss:  5.48  train acc: 0.01  val loss:  5.31  valid acc: 0.00\n",
      "sin i= 9500 train loss:  5.47  train acc: 0.01  val loss:  5.30  valid acc: 0.00\n",
      "sin i=10000 train loss:  5.47  train acc: 0.01  val loss:  5.30  valid acc: 0.00\n",
      "-> sin layer idx: 3  , best valid accuracy: 0.00, test accuracy: 0.00\n",
      "sin i=    0 train loss: 11.37  train acc: 0.00  val loss:  6.90  valid acc: 0.00\n",
      "sin i=  500 train loss:  6.04  train acc: 0.01  val loss:  5.98  valid acc: 0.00\n",
      "sin i= 1000 train loss:  5.88  train acc: 0.01  val loss:  5.83  valid acc: 0.00\n",
      "sin i= 1500 train loss:  5.84  train acc: 0.01  val loss:  5.73  valid acc: 0.00\n",
      "sin i= 2000 train loss:  5.80  train acc: 0.01  val loss:  5.69  valid acc: 0.00\n",
      "sin i= 2500 train loss:  5.75  train acc: 0.01  val loss:  5.65  valid acc: 0.00\n",
      "sin i= 3000 train loss:  5.74  train acc: 0.00  val loss:  5.61  valid acc: 0.00\n",
      "sin i= 3500 train loss:  5.68  train acc: 0.01  val loss:  5.59  valid acc: 0.00\n",
      "sin i= 4000 train loss:  5.67  train acc: 0.01  val loss:  5.61  valid acc: 0.00\n",
      "sin i= 4500 train loss:  5.67  train acc: 0.02  val loss:  5.57  valid acc: 0.00\n",
      "sin i= 5000 train loss:  5.69  train acc: 0.01  val loss:  5.54  valid acc: 0.00\n",
      "sin i= 5500 train loss:  5.69  train acc: 0.01  val loss:  5.56  valid acc: 0.00\n",
      "sin i= 6000 train loss:  5.72  train acc: 0.01  val loss:  5.52  valid acc: 0.00\n",
      "sin i= 6500 train loss:  5.68  train acc: 0.01  val loss:  5.51  valid acc: 0.00\n",
      "sin i= 7000 train loss:  5.63  train acc: 0.01  val loss:  5.53  valid acc: 0.00\n",
      "sin i= 7500 train loss:  5.65  train acc: 0.01  val loss:  5.53  valid acc: 0.00\n",
      "sin i= 8000 train loss:  5.65  train acc: 0.01  val loss:  5.53  valid acc: 0.00\n",
      "sin i= 8500 train loss:  5.65  train acc: 0.01  val loss:  5.48  valid acc: 0.00\n",
      "sin i= 9000 train loss:  5.63  train acc: 0.01  val loss:  5.47  valid acc: 0.00\n",
      "sin i= 9500 train loss:  5.66  train acc: 0.01  val loss:  5.47  valid acc: 0.00\n",
      "sin i=10000 train loss:  5.64  train acc: 0.01  val loss:  5.46  valid acc: 0.00\n",
      "-> sin layer idx: 2  , best valid accuracy: 0.00, test accuracy: 0.00\n",
      "sin i=    0 train loss: 11.37  train acc: 0.00  val loss:  6.81  valid acc: 0.00\n",
      "sin i=  500 train loss:  6.44  train acc: 0.00  val loss:  6.63  valid acc: 0.00\n",
      "sin i= 1000 train loss:  6.26  train acc: 0.00  val loss:  6.41  valid acc: 0.00\n",
      "sin i= 1500 train loss:  6.21  train acc: 0.01  val loss:  6.32  valid acc: 0.00\n",
      "sin i= 2000 train loss:  6.19  train acc: 0.00  val loss:  6.28  valid acc: 0.00\n",
      "sin i= 2500 train loss:  6.18  train acc: 0.00  val loss:  6.25  valid acc: 0.00\n",
      "sin i= 3000 train loss:  6.18  train acc: 0.01  val loss:  6.26  valid acc: 0.00\n",
      "sin i= 3500 train loss:  6.13  train acc: 0.01  val loss:  6.25  valid acc: 0.00\n",
      "sin i= 4000 train loss:  6.12  train acc: 0.01  val loss:  6.21  valid acc: 0.00\n",
      "sin i= 4500 train loss:  6.16  train acc: 0.01  val loss:  6.20  valid acc: 0.00\n",
      "sin i= 5000 train loss:  6.13  train acc: 0.01  val loss:  6.19  valid acc: 0.00\n",
      "sin i= 5500 train loss:  6.11  train acc: 0.00  val loss:  6.16  valid acc: 0.00\n",
      "sin i= 6000 train loss:  6.14  train acc: 0.00  val loss:  6.14  valid acc: 0.00\n",
      "sin i= 6500 train loss:  6.14  train acc: 0.00  val loss:  6.13  valid acc: 0.00\n",
      "sin i= 7000 train loss:  6.11  train acc: 0.01  val loss:  6.11  valid acc: 0.00\n",
      "sin i= 7500 train loss:  6.09  train acc: 0.00  val loss:  6.09  valid acc: 0.00\n",
      "sin i= 8000 train loss:  6.11  train acc: 0.01  val loss:  6.07  valid acc: 0.00\n",
      "sin i= 8500 train loss:  6.10  train acc: 0.01  val loss:  6.07  valid acc: 0.00\n",
      "sin i= 9000 train loss:  6.14  train acc: 0.01  val loss:  6.07  valid acc: 0.00\n",
      "sin i= 9500 train loss:  6.10  train acc: 0.00  val loss:  6.05  valid acc: 0.00\n",
      "sin i=10000 train loss:  6.11  train acc: 0.00  val loss:  6.05  valid acc: 0.00\n",
      "-> sin layer idx: 1  , best valid accuracy: 0.00, test accuracy: 0.00\n",
      "sin i=    0 train loss: 11.37  train acc: 0.00  val loss:  6.80  valid acc: 0.00\n",
      "sin i=  500 train loss:  6.68  train acc: 0.00  val loss:  7.07  valid acc: 0.00\n",
      "sin i= 1000 train loss:  6.69  train acc: 0.00  val loss:  7.06  valid acc: 0.00\n",
      "sin i= 1500 train loss:  6.70  train acc: 0.00  val loss:  7.04  valid acc: 0.00\n",
      "sin i= 2000 train loss:  6.73  train acc: 0.00  val loss:  7.05  valid acc: 0.00\n",
      "sin i= 2500 train loss:  6.71  train acc: 0.00  val loss:  7.06  valid acc: 0.00\n",
      "sin i= 3000 train loss:  6.71  train acc: 0.00  val loss:  7.07  valid acc: 0.00\n",
      "sin i= 3500 train loss:  6.70  train acc: 0.00  val loss:  7.04  valid acc: 0.00\n",
      "sin i= 4000 train loss:  6.70  train acc: 0.00  val loss:  7.07  valid acc: 0.00\n",
      "sin i= 4500 train loss:  6.69  train acc: 0.00  val loss:  7.10  valid acc: 0.00\n",
      "sin i= 5000 train loss:  6.69  train acc: 0.00  val loss:  7.10  valid acc: 0.00\n",
      "sin i= 5500 train loss:  6.67  train acc: 0.00  val loss:  7.08  valid acc: 0.00\n",
      "sin i= 6000 train loss:  6.67  train acc: 0.00  val loss:  7.09  valid acc: 0.00\n",
      "sin i= 6500 train loss:  6.71  train acc: 0.00  val loss:  7.06  valid acc: 0.00\n",
      "sin i= 7000 train loss:  6.68  train acc: 0.00  val loss:  7.11  valid acc: 0.00\n",
      "sin i= 7500 train loss:  6.72  train acc: 0.00  val loss:  7.09  valid acc: 0.00\n",
      "sin i= 8000 train loss:  6.70  train acc: 0.00  val loss:  7.12  valid acc: 0.00\n",
      "sin i= 8500 train loss:  6.71  train acc: 0.00  val loss:  7.06  valid acc: 0.00\n",
      "sin i= 9000 train loss:  6.70  train acc: 0.00  val loss:  7.08  valid acc: 0.00\n",
      "sin i= 9500 train loss:  6.72  train acc: 0.00  val loss:  7.08  valid acc: 0.00\n",
      "sin i=10000 train loss:  6.70  train acc: 0.00  val loss:  7.07  valid acc: 0.00\n",
      "-> sin layer idx: 0  , best valid accuracy: 0.00, test accuracy: 0.00\n",
      "bin i=    0 train loss: 12.62  train acc: 0.00  val loss: 49.98  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.79  train acc: 0.10  val loss:  4.25  valid acc: 0.01\n",
      "bin i= 1000 train loss:  4.20  train acc: 0.10  val loss:  4.11  valid acc: 0.02\n",
      "bin i= 1500 train loss:  3.91  train acc: 0.11  val loss:  4.09  valid acc: 0.02\n",
      "bin i= 2000 train loss:  3.84  train acc: 0.12  val loss:  4.27  valid acc: 0.01\n",
      "bin i= 2500 train loss:  3.76  train acc: 0.11  val loss:  4.20  valid acc: 0.01\n",
      "bin i= 3000 train loss:  3.82  train acc: 0.10  val loss:  4.09  valid acc: 0.01\n",
      "bin i= 3500 train loss:  3.53  train acc: 0.12  val loss:  4.06  valid acc: 0.01\n",
      "bin i= 4000 train loss:  3.64  train acc: 0.12  val loss:  3.97  valid acc: 0.01\n",
      "bin i= 4500 train loss:  3.59  train acc: 0.12  val loss:  4.04  valid acc: 0.01\n",
      "bin i= 5000 train loss:  3.45  train acc: 0.12  val loss:  3.92  valid acc: 0.01\n",
      "bin i= 5500 train loss:  3.36  train acc: 0.13  val loss:  3.99  valid acc: 0.01\n",
      "bin i= 6000 train loss:  3.35  train acc: 0.13  val loss:  4.01  valid acc: 0.01\n",
      "bin i= 6500 train loss:  3.40  train acc: 0.13  val loss:  4.13  valid acc: 0.01\n",
      "bin i= 7000 train loss:  3.31  train acc: 0.13  val loss:  3.85  valid acc: 0.01\n",
      "bin i= 7500 train loss:  3.39  train acc: 0.12  val loss:  3.94  valid acc: 0.01\n",
      "bin i= 8000 train loss:  3.32  train acc: 0.14  val loss:  3.95  valid acc: 0.01\n",
      "bin i= 8500 train loss:  3.29  train acc: 0.14  val loss:  4.10  valid acc: 0.02\n",
      "bin i= 9000 train loss:  3.27  train acc: 0.15  val loss:  3.91  valid acc: 0.01\n",
      "bin i= 9500 train loss:  3.22  train acc: 0.16  val loss:  3.96  valid acc: 0.01\n",
      "bin i=10000 train loss:  3.22  train acc: 0.14  val loss:  3.92  valid acc: 0.01\n",
      "-> bin layer idx: 16 , best valid accuracy: 0.02, test accuracy: 0.02\n",
      "bin i=    0 train loss:  9.50  train acc: 0.00  val loss:  8.18  valid acc: 0.00\n",
      "bin i=  500 train loss:  4.70  train acc: 0.02  val loss:  4.46  valid acc: 0.01\n",
      "bin i= 1000 train loss:  4.48  train acc: 0.04  val loss:  4.28  valid acc: 0.02\n",
      "bin i= 1500 train loss:  4.42  train acc: 0.03  val loss:  4.29  valid acc: 0.01\n",
      "bin i= 2000 train loss:  4.32  train acc: 0.02  val loss:  4.25  valid acc: 0.00\n",
      "bin i= 2500 train loss:  4.28  train acc: 0.03  val loss:  4.18  valid acc: 0.00\n",
      "bin i= 3000 train loss:  4.22  train acc: 0.02  val loss:  4.14  valid acc: 0.00\n",
      "bin i= 3500 train loss:  4.14  train acc: 0.02  val loss:  4.10  valid acc: 0.00\n",
      "bin i= 4000 train loss:  4.14  train acc: 0.02  val loss:  4.11  valid acc: 0.00\n",
      "bin i= 4500 train loss:  4.17  train acc: 0.02  val loss:  4.07  valid acc: 0.00\n",
      "bin i= 5000 train loss:  4.09  train acc: 0.02  val loss:  4.04  valid acc: 0.00\n",
      "bin i= 5500 train loss:  4.09  train acc: 0.03  val loss:  4.16  valid acc: 0.00\n",
      "bin i= 6000 train loss:  4.09  train acc: 0.02  val loss:  4.06  valid acc: 0.00\n",
      "bin i= 6500 train loss:  4.11  train acc: 0.03  val loss:  4.06  valid acc: 0.01\n",
      "bin i= 7000 train loss:  4.06  train acc: 0.02  val loss:  4.08  valid acc: 0.00\n",
      "bin i= 7500 train loss:  4.07  train acc: 0.02  val loss:  4.06  valid acc: 0.00\n",
      "bin i= 8000 train loss:  4.04  train acc: 0.02  val loss:  3.99  valid acc: 0.00\n",
      "bin i= 8500 train loss:  4.05  train acc: 0.03  val loss:  4.02  valid acc: 0.02\n",
      "bin i= 9000 train loss:  4.05  train acc: 0.03  val loss:  4.05  valid acc: 0.01\n",
      "bin i= 9500 train loss:  4.03  train acc: 0.02  val loss:  4.01  valid acc: 0.01\n",
      "bin i=10000 train loss:  4.04  train acc: 0.02  val loss:  3.98  valid acc: 0.00\n",
      "-> bin layer idx: 15 , best valid accuracy: 0.02, test accuracy: 0.01\n",
      "bin i=    0 train loss:  9.37  train acc: 0.00  val loss:  7.87  valid acc: 0.00\n",
      "bin i=  500 train loss:  5.00  train acc: 0.01  val loss:  4.73  valid acc: 0.00\n",
      "bin i= 1000 train loss:  4.78  train acc: 0.02  val loss:  4.67  valid acc: 0.01\n",
      "bin i= 1500 train loss:  4.75  train acc: 0.01  val loss:  4.62  valid acc: 0.01\n",
      "bin i= 2000 train loss:  4.68  train acc: 0.03  val loss:  4.62  valid acc: 0.01\n",
      "bin i= 2500 train loss:  4.68  train acc: 0.02  val loss:  4.62  valid acc: 0.01\n",
      "bin i= 3000 train loss:  4.63  train acc: 0.02  val loss:  4.59  valid acc: 0.01\n",
      "bin i= 3500 train loss:  4.55  train acc: 0.01  val loss:  4.57  valid acc: 0.00\n",
      "bin i= 4000 train loss:  4.60  train acc: 0.02  val loss:  4.58  valid acc: 0.01\n",
      "bin i= 4500 train loss:  4.59  train acc: 0.03  val loss:  4.59  valid acc: 0.00\n",
      "bin i= 5000 train loss:  4.52  train acc: 0.02  val loss:  4.54  valid acc: 0.01\n",
      "bin i= 5500 train loss:  4.53  train acc: 0.01  val loss:  4.56  valid acc: 0.00\n",
      "bin i= 6000 train loss:  4.53  train acc: 0.03  val loss:  4.54  valid acc: 0.01\n",
      "bin i= 6500 train loss:  4.56  train acc: 0.02  val loss:  4.56  valid acc: 0.01\n",
      "bin i= 7000 train loss:  4.49  train acc: 0.01  val loss:  4.52  valid acc: 0.00\n",
      "bin i= 7500 train loss:  4.49  train acc: 0.03  val loss:  4.54  valid acc: 0.01\n",
      "bin i= 8000 train loss:  4.47  train acc: 0.02  val loss:  4.53  valid acc: 0.01\n",
      "bin i= 8500 train loss:  4.49  train acc: 0.02  val loss:  4.54  valid acc: 0.02\n",
      "bin i= 9000 train loss:  4.50  train acc: 0.01  val loss:  4.48  valid acc: 0.02\n",
      "bin i= 9500 train loss:  4.47  train acc: 0.03  val loss:  4.51  valid acc: 0.00\n",
      "bin i=10000 train loss:  4.47  train acc: 0.01  val loss:  4.48  valid acc: 0.00\n",
      "-> bin layer idx: 14 , best valid accuracy: 0.02, test accuracy: 0.01\n",
      "bin i=    0 train loss:  9.33  train acc: 0.00  val loss:  7.63  valid acc: 0.00\n",
      "bin i=  500 train loss:  5.25  train acc: 0.01  val loss:  4.98  valid acc: 0.00\n",
      "bin i= 1000 train loss:  5.06  train acc: 0.01  val loss:  4.95  valid acc: 0.00\n",
      "bin i= 1500 train loss:  5.03  train acc: 0.01  val loss:  4.90  valid acc: 0.00\n",
      "bin i= 2000 train loss:  4.97  train acc: 0.01  val loss:  4.92  valid acc: 0.00\n",
      "bin i= 2500 train loss:  4.93  train acc: 0.00  val loss:  4.85  valid acc: 0.00\n",
      "bin i= 3000 train loss:  4.92  train acc: 0.01  val loss:  4.84  valid acc: 0.00\n",
      "bin i= 3500 train loss:  4.83  train acc: 0.01  val loss:  4.81  valid acc: 0.01\n",
      "bin i= 4000 train loss:  4.83  train acc: 0.01  val loss:  4.85  valid acc: 0.00\n",
      "bin i= 4500 train loss:  4.85  train acc: 0.02  val loss:  4.85  valid acc: 0.01\n",
      "bin i= 5000 train loss:  4.79  train acc: 0.01  val loss:  4.79  valid acc: 0.01\n",
      "bin i= 5500 train loss:  4.82  train acc: 0.01  val loss:  4.80  valid acc: 0.00\n",
      "bin i= 6000 train loss:  4.81  train acc: 0.03  val loss:  4.79  valid acc: 0.00\n",
      "bin i= 6500 train loss:  4.85  train acc: 0.01  val loss:  4.80  valid acc: 0.01\n",
      "bin i= 7000 train loss:  4.74  train acc: 0.01  val loss:  4.73  valid acc: 0.00\n",
      "bin i= 7500 train loss:  4.78  train acc: 0.01  val loss:  4.72  valid acc: 0.01\n",
      "bin i= 8000 train loss:  4.74  train acc: 0.01  val loss:  4.72  valid acc: 0.01\n",
      "bin i= 8500 train loss:  4.73  train acc: 0.01  val loss:  4.76  valid acc: 0.01\n",
      "bin i= 9000 train loss:  4.74  train acc: 0.02  val loss:  4.68  valid acc: 0.01\n",
      "bin i= 9500 train loss:  4.72  train acc: 0.01  val loss:  4.72  valid acc: 0.00\n",
      "bin i=10000 train loss:  4.75  train acc: 0.01  val loss:  4.71  valid acc: 0.00\n",
      "-> bin layer idx: 13 , best valid accuracy: 0.01, test accuracy: 0.01\n",
      "bin i=    0 train loss:  9.30  train acc: 0.00  val loss:  7.48  valid acc: 0.00\n",
      "bin i=  500 train loss:  6.11  train acc: 0.00  val loss:  5.89  valid acc: 0.01\n",
      "bin i= 1000 train loss:  5.83  train acc: 0.01  val loss:  5.55  valid acc: 0.01\n",
      "bin i= 1500 train loss:  5.71  train acc: 0.01  val loss:  5.38  valid acc: 0.00\n",
      "bin i= 2000 train loss:  5.61  train acc: 0.01  val loss:  5.30  valid acc: 0.01\n",
      "bin i= 2500 train loss:  5.56  train acc: 0.01  val loss:  5.22  valid acc: 0.00\n",
      "bin i= 3000 train loss:  5.46  train acc: 0.01  val loss:  5.16  valid acc: 0.01\n",
      "bin i= 3500 train loss:  5.39  train acc: 0.01  val loss:  5.15  valid acc: 0.01\n",
      "bin i= 4000 train loss:  5.39  train acc: 0.01  val loss:  5.11  valid acc: 0.00\n",
      "bin i= 4500 train loss:  5.34  train acc: 0.00  val loss:  5.08  valid acc: 0.01\n",
      "bin i= 5000 train loss:  5.29  train acc: 0.02  val loss:  5.03  valid acc: 0.00\n",
      "bin i= 5500 train loss:  5.31  train acc: 0.01  val loss:  5.04  valid acc: 0.01\n",
      "bin i= 6000 train loss:  5.29  train acc: 0.01  val loss:  5.00  valid acc: 0.01\n",
      "bin i= 6500 train loss:  5.31  train acc: 0.01  val loss:  4.99  valid acc: 0.00\n",
      "bin i= 7000 train loss:  5.21  train acc: 0.01  val loss:  4.96  valid acc: 0.00\n",
      "bin i= 7500 train loss:  5.26  train acc: 0.01  val loss:  4.97  valid acc: 0.01\n",
      "bin i= 8000 train loss:  5.19  train acc: 0.01  val loss:  4.95  valid acc: 0.00\n",
      "bin i= 8500 train loss:  5.19  train acc: 0.01  val loss:  4.93  valid acc: 0.00\n",
      "bin i= 9000 train loss:  5.21  train acc: 0.01  val loss:  4.89  valid acc: 0.01\n",
      "bin i= 9500 train loss:  5.15  train acc: 0.01  val loss:  5.00  valid acc: 0.00\n",
      "bin i=10000 train loss:  5.18  train acc: 0.01  val loss:  4.87  valid acc: 0.00\n",
      "-> bin layer idx: 12 , best valid accuracy: 0.01, test accuracy: 0.00\n",
      "bin i=    0 train loss:  9.34  train acc: 0.00  val loss:  7.19  valid acc: 0.00\n",
      "bin i=  500 train loss:  6.55  train acc: 0.00  val loss:  6.54  valid acc: 0.00\n",
      "bin i= 1000 train loss:  6.44  train acc: 0.00  val loss:  6.48  valid acc: 0.00\n",
      "bin i= 1500 train loss:  6.44  train acc: 0.00  val loss:  6.41  valid acc: 0.00\n",
      "bin i= 2000 train loss:  6.39  train acc: 0.00  val loss:  6.39  valid acc: 0.00\n",
      "bin i= 2500 train loss:  6.38  train acc: 0.00  val loss:  6.37  valid acc: 0.00\n",
      "bin i= 3000 train loss:  6.36  train acc: 0.00  val loss:  6.34  valid acc: 0.00\n",
      "bin i= 3500 train loss:  6.27  train acc: 0.00  val loss:  6.34  valid acc: 0.00\n",
      "bin i= 4000 train loss:  6.28  train acc: 0.00  val loss:  6.31  valid acc: 0.00\n",
      "bin i= 4500 train loss:  6.27  train acc: 0.00  val loss:  6.28  valid acc: 0.00\n",
      "bin i= 5000 train loss:  6.24  train acc: 0.00  val loss:  6.29  valid acc: 0.00\n",
      "bin i= 5500 train loss:  6.23  train acc: 0.00  val loss:  6.25  valid acc: 0.00\n",
      "bin i= 6000 train loss:  6.23  train acc: 0.01  val loss:  6.23  valid acc: 0.00\n",
      "bin i= 6500 train loss:  6.24  train acc: 0.00  val loss:  6.22  valid acc: 0.00\n",
      "bin i= 7000 train loss:  6.21  train acc: 0.00  val loss:  6.21  valid acc: 0.00\n",
      "bin i= 7500 train loss:  6.20  train acc: 0.00  val loss:  6.21  valid acc: 0.00\n",
      "bin i= 8000 train loss:  6.19  train acc: 0.00  val loss:  6.20  valid acc: 0.00\n",
      "bin i= 8500 train loss:  6.17  train acc: 0.00  val loss:  6.20  valid acc: 0.00\n",
      "bin i= 9000 train loss:  6.21  train acc: 0.00  val loss:  6.19  valid acc: 0.00\n",
      "bin i= 9500 train loss:  6.17  train acc: 0.00  val loss:  6.16  valid acc: 0.00\n",
      "bin i=10000 train loss:  6.18  train acc: 0.00  val loss:  6.16  valid acc: 0.00\n",
      "-> bin layer idx: 11 , best valid accuracy: 0.00, test accuracy: 0.00\n",
      "bin i=    0 train loss:  9.32  train acc: 0.00  val loss:  7.21  valid acc: 0.00\n",
      "bin i=  500 train loss:  6.56  train acc: 0.00  val loss:  6.58  valid acc: 0.00\n",
      "bin i= 1000 train loss:  6.44  train acc: 0.00  val loss:  6.50  valid acc: 0.00\n",
      "bin i= 1500 train loss:  6.45  train acc: 0.00  val loss:  6.44  valid acc: 0.00\n",
      "bin i= 2000 train loss:  6.39  train acc: 0.00  val loss:  6.37  valid acc: 0.00\n",
      "bin i= 2500 train loss:  6.31  train acc: 0.00  val loss:  6.28  valid acc: 0.00\n",
      "bin i= 3000 train loss:  6.26  train acc: 0.00  val loss:  6.22  valid acc: 0.00\n",
      "bin i= 3500 train loss:  6.17  train acc: 0.00  val loss:  6.18  valid acc: 0.00\n",
      "bin i= 4000 train loss:  6.15  train acc: 0.00  val loss:  6.13  valid acc: 0.00\n",
      "bin i= 4500 train loss:  6.16  train acc: 0.00  val loss:  6.12  valid acc: 0.00\n",
      "bin i= 5000 train loss:  6.11  train acc: 0.00  val loss:  6.11  valid acc: 0.00\n",
      "bin i= 5500 train loss:  6.10  train acc: 0.00  val loss:  6.06  valid acc: 0.00\n",
      "bin i= 6000 train loss:  6.10  train acc: 0.01  val loss:  6.04  valid acc: 0.00\n",
      "bin i= 6500 train loss:  6.09  train acc: 0.00  val loss:  6.04  valid acc: 0.00\n",
      "bin i= 7000 train loss:  6.06  train acc: 0.00  val loss:  6.02  valid acc: 0.00\n",
      "bin i= 7500 train loss:  6.07  train acc: 0.00  val loss:  6.02  valid acc: 0.00\n",
      "bin i= 8000 train loss:  6.04  train acc: 0.00  val loss:  6.01  valid acc: 0.00\n",
      "bin i= 8500 train loss:  6.03  train acc: 0.01  val loss:  5.98  valid acc: 0.00\n",
      "bin i= 9000 train loss:  6.06  train acc: 0.00  val loss:  5.97  valid acc: 0.00\n",
      "bin i= 9500 train loss:  6.00  train acc: 0.00  val loss:  5.95  valid acc: 0.00\n",
      "bin i=10000 train loss:  6.05  train acc: 0.00  val loss:  5.96  valid acc: 0.00\n",
      "-> bin layer idx: 10 , best valid accuracy: 0.00, test accuracy: 0.00\n",
      "bin i=    0 train loss:  9.31  train acc: 0.00  val loss:  7.17  valid acc: 0.00\n",
      "bin i=  500 train loss:  6.56  train acc: 0.00  val loss:  6.56  valid acc: 0.00\n",
      "bin i= 1000 train loss:  6.45  train acc: 0.00  val loss:  6.54  valid acc: 0.00\n",
      "bin i= 1500 train loss:  6.46  train acc: 0.00  val loss:  6.44  valid acc: 0.00\n",
      "bin i= 2000 train loss:  6.42  train acc: 0.00  val loss:  6.43  valid acc: 0.00\n",
      "bin i= 2500 train loss:  6.40  train acc: 0.00  val loss:  6.44  valid acc: 0.00\n",
      "bin i= 3000 train loss:  6.39  train acc: 0.00  val loss:  6.40  valid acc: 0.00\n",
      "bin i= 3500 train loss:  6.32  train acc: 0.00  val loss:  6.43  valid acc: 0.00\n",
      "bin i= 4000 train loss:  6.32  train acc: 0.00  val loss:  6.39  valid acc: 0.00\n",
      "bin i= 4500 train loss:  6.32  train acc: 0.01  val loss:  6.39  valid acc: 0.00\n",
      "bin i= 5000 train loss:  6.29  train acc: 0.00  val loss:  6.39  valid acc: 0.00\n",
      "bin i= 5500 train loss:  6.29  train acc: 0.00  val loss:  6.36  valid acc: 0.00\n",
      "bin i= 6000 train loss:  6.29  train acc: 0.00  val loss:  6.34  valid acc: 0.00\n",
      "bin i= 6500 train loss:  6.29  train acc: 0.00  val loss:  6.34  valid acc: 0.00\n",
      "bin i= 7000 train loss:  6.27  train acc: 0.00  val loss:  6.33  valid acc: 0.00\n",
      "bin i= 7500 train loss:  6.28  train acc: 0.00  val loss:  6.33  valid acc: 0.00\n",
      "bin i= 8000 train loss:  6.25  train acc: 0.00  val loss:  6.33  valid acc: 0.00\n",
      "bin i= 8500 train loss:  6.25  train acc: 0.00  val loss:  6.31  valid acc: 0.00\n",
      "bin i= 9000 train loss:  6.27  train acc: 0.00  val loss:  6.31  valid acc: 0.00\n",
      "bin i= 9500 train loss:  6.26  train acc: 0.00  val loss:  6.30  valid acc: 0.00\n",
      "bin i=10000 train loss:  6.26  train acc: 0.00  val loss:  6.29  valid acc: 0.00\n",
      "-> bin layer idx: 9  , best valid accuracy: 0.00, test accuracy: 0.00\n",
      "bin i=    0 train loss:  9.34  train acc: 0.00  val loss:  7.07  valid acc: 0.00\n",
      "bin i=  500 train loss:  6.55  train acc: 0.00  val loss:  6.56  valid acc: 0.00\n",
      "bin i= 1000 train loss:  6.45  train acc: 0.00  val loss:  6.52  valid acc: 0.00\n",
      "bin i= 1500 train loss:  6.45  train acc: 0.00  val loss:  6.43  valid acc: 0.00\n",
      "bin i= 2000 train loss:  6.42  train acc: 0.00  val loss:  6.41  valid acc: 0.00\n",
      "bin i= 2500 train loss:  6.39  train acc: 0.00  val loss:  6.40  valid acc: 0.00\n",
      "bin i= 3000 train loss:  6.38  train acc: 0.00  val loss:  6.39  valid acc: 0.00\n",
      "bin i= 3500 train loss:  6.31  train acc: 0.00  val loss:  6.40  valid acc: 0.00\n",
      "bin i= 4000 train loss:  6.32  train acc: 0.00  val loss:  6.38  valid acc: 0.00\n",
      "bin i= 4500 train loss:  6.32  train acc: 0.00  val loss:  6.36  valid acc: 0.00\n",
      "bin i= 5000 train loss:  6.28  train acc: 0.00  val loss:  6.36  valid acc: 0.00\n",
      "bin i= 5500 train loss:  6.28  train acc: 0.00  val loss:  6.34  valid acc: 0.00\n",
      "bin i= 6000 train loss:  6.28  train acc: 0.00  val loss:  6.32  valid acc: 0.00\n",
      "bin i= 6500 train loss:  6.28  train acc: 0.00  val loss:  6.32  valid acc: 0.00\n",
      "bin i= 7000 train loss:  6.27  train acc: 0.00  val loss:  6.32  valid acc: 0.00\n",
      "bin i= 7500 train loss:  6.26  train acc: 0.00  val loss:  6.31  valid acc: 0.00\n",
      "bin i= 8000 train loss:  6.24  train acc: 0.00  val loss:  6.31  valid acc: 0.00\n",
      "bin i= 8500 train loss:  6.24  train acc: 0.00  val loss:  6.31  valid acc: 0.00\n",
      "bin i= 9000 train loss:  6.25  train acc: 0.00  val loss:  6.29  valid acc: 0.00\n",
      "bin i= 9500 train loss:  6.24  train acc: 0.00  val loss:  6.28  valid acc: 0.00\n",
      "bin i=10000 train loss:  6.26  train acc: 0.00  val loss:  6.29  valid acc: 0.00\n",
      "-> bin layer idx: 8  , best valid accuracy: 0.00, test accuracy: 0.00\n",
      "bin i=    0 train loss:  9.33  train acc: 0.00  val loss:  7.11  valid acc: 0.00\n",
      "bin i=  500 train loss:  6.54  train acc: 0.00  val loss:  6.55  valid acc: 0.00\n",
      "bin i= 1000 train loss:  6.44  train acc: 0.00  val loss:  6.52  valid acc: 0.00\n",
      "bin i= 1500 train loss:  6.44  train acc: 0.00  val loss:  6.43  valid acc: 0.00\n",
      "bin i= 2000 train loss:  6.40  train acc: 0.00  val loss:  6.41  valid acc: 0.00\n",
      "bin i= 2500 train loss:  6.38  train acc: 0.00  val loss:  6.40  valid acc: 0.00\n",
      "bin i= 3000 train loss:  6.37  train acc: 0.00  val loss:  6.39  valid acc: 0.00\n",
      "bin i= 3500 train loss:  6.31  train acc: 0.00  val loss:  6.40  valid acc: 0.00\n",
      "bin i= 4000 train loss:  6.31  train acc: 0.00  val loss:  6.38  valid acc: 0.00\n",
      "bin i= 4500 train loss:  6.31  train acc: 0.01  val loss:  6.36  valid acc: 0.00\n",
      "bin i= 5000 train loss:  6.27  train acc: 0.00  val loss:  6.37  valid acc: 0.00\n",
      "bin i= 5500 train loss:  6.27  train acc: 0.00  val loss:  6.34  valid acc: 0.00\n",
      "bin i= 6000 train loss:  6.28  train acc: 0.00  val loss:  6.33  valid acc: 0.00\n",
      "bin i= 6500 train loss:  6.27  train acc: 0.00  val loss:  6.33  valid acc: 0.00\n",
      "bin i= 7000 train loss:  6.26  train acc: 0.00  val loss:  6.33  valid acc: 0.00\n",
      "bin i= 7500 train loss:  6.26  train acc: 0.00  val loss:  6.32  valid acc: 0.00\n",
      "bin i= 8000 train loss:  6.24  train acc: 0.00  val loss:  6.32  valid acc: 0.00\n",
      "bin i= 8500 train loss:  6.24  train acc: 0.01  val loss:  6.31  valid acc: 0.00\n",
      "bin i= 9000 train loss:  6.24  train acc: 0.00  val loss:  6.30  valid acc: 0.00\n",
      "bin i= 9500 train loss:  6.23  train acc: 0.00  val loss:  6.28  valid acc: 0.00\n",
      "bin i=10000 train loss:  6.25  train acc: 0.00  val loss:  6.29  valid acc: 0.00\n",
      "-> bin layer idx: 7  , best valid accuracy: 0.00, test accuracy: 0.00\n",
      "bin i=    0 train loss:  9.35  train acc: 0.00  val loss:  7.11  valid acc: 0.00\n",
      "bin i=  500 train loss:  6.55  train acc: 0.00  val loss:  6.56  valid acc: 0.00\n",
      "bin i= 1000 train loss:  6.45  train acc: 0.00  val loss:  6.51  valid acc: 0.00\n",
      "bin i= 1500 train loss:  6.45  train acc: 0.00  val loss:  6.42  valid acc: 0.00\n",
      "bin i= 2000 train loss:  6.41  train acc: 0.00  val loss:  6.41  valid acc: 0.00\n",
      "bin i= 2500 train loss:  6.38  train acc: 0.00  val loss:  6.39  valid acc: 0.00\n",
      "bin i= 3000 train loss:  6.37  train acc: 0.00  val loss:  6.38  valid acc: 0.00\n",
      "bin i= 3500 train loss:  6.32  train acc: 0.00  val loss:  6.40  valid acc: 0.00\n",
      "bin i= 4000 train loss:  6.31  train acc: 0.00  val loss:  6.36  valid acc: 0.00\n",
      "bin i= 4500 train loss:  6.32  train acc: 0.00  val loss:  6.36  valid acc: 0.00\n",
      "bin i= 5000 train loss:  6.28  train acc: 0.00  val loss:  6.36  valid acc: 0.00\n",
      "bin i= 5500 train loss:  6.27  train acc: 0.00  val loss:  6.34  valid acc: 0.00\n",
      "bin i= 6000 train loss:  6.28  train acc: 0.00  val loss:  6.32  valid acc: 0.00\n",
      "bin i= 6500 train loss:  6.28  train acc: 0.00  val loss:  6.33  valid acc: 0.00\n",
      "bin i= 7000 train loss:  6.27  train acc: 0.00  val loss:  6.31  valid acc: 0.00\n",
      "bin i= 7500 train loss:  6.26  train acc: 0.00  val loss:  6.31  valid acc: 0.00\n",
      "bin i= 8000 train loss:  6.25  train acc: 0.00  val loss:  6.31  valid acc: 0.00\n",
      "bin i= 8500 train loss:  6.25  train acc: 0.00  val loss:  6.29  valid acc: 0.00\n",
      "bin i= 9000 train loss:  6.25  train acc: 0.00  val loss:  6.30  valid acc: 0.00\n",
      "bin i= 9500 train loss:  6.23  train acc: 0.00  val loss:  6.28  valid acc: 0.00\n",
      "bin i=10000 train loss:  6.26  train acc: 0.00  val loss:  6.28  valid acc: 0.00\n",
      "-> bin layer idx: 6  , best valid accuracy: 0.00, test accuracy: 0.00\n",
      "bin i=    0 train loss:  9.36  train acc: 0.00  val loss:  7.02  valid acc: 0.00\n",
      "bin i=  500 train loss:  6.53  train acc: 0.00  val loss:  6.53  valid acc: 0.00\n",
      "bin i= 1000 train loss:  6.43  train acc: 0.00  val loss:  6.46  valid acc: 0.00\n",
      "bin i= 1500 train loss:  6.44  train acc: 0.00  val loss:  6.40  valid acc: 0.00\n",
      "bin i= 2000 train loss:  6.40  train acc: 0.00  val loss:  6.38  valid acc: 0.00\n",
      "bin i= 2500 train loss:  6.36  train acc: 0.00  val loss:  6.37  valid acc: 0.00\n",
      "bin i= 3000 train loss:  6.35  train acc: 0.00  val loss:  6.35  valid acc: 0.00\n",
      "bin i= 3500 train loss:  6.29  train acc: 0.00  val loss:  6.37  valid acc: 0.00\n",
      "bin i= 4000 train loss:  6.29  train acc: 0.01  val loss:  6.34  valid acc: 0.00\n",
      "bin i= 4500 train loss:  6.28  train acc: 0.00  val loss:  6.32  valid acc: 0.00\n",
      "bin i= 5000 train loss:  6.25  train acc: 0.00  val loss:  6.32  valid acc: 0.00\n",
      "bin i= 5500 train loss:  6.23  train acc: 0.00  val loss:  6.31  valid acc: 0.00\n",
      "bin i= 6000 train loss:  6.26  train acc: 0.00  val loss:  6.30  valid acc: 0.00\n",
      "bin i= 6500 train loss:  6.24  train acc: 0.00  val loss:  6.30  valid acc: 0.00\n",
      "bin i= 7000 train loss:  6.24  train acc: 0.00  val loss:  6.28  valid acc: 0.00\n",
      "bin i= 7500 train loss:  6.23  train acc: 0.00  val loss:  6.29  valid acc: 0.00\n",
      "bin i= 8000 train loss:  6.22  train acc: 0.00  val loss:  6.29  valid acc: 0.00\n",
      "bin i= 8500 train loss:  6.23  train acc: 0.00  val loss:  6.27  valid acc: 0.00\n",
      "bin i= 9000 train loss:  6.21  train acc: 0.00  val loss:  6.26  valid acc: 0.00\n",
      "bin i= 9500 train loss:  6.21  train acc: 0.00  val loss:  6.26  valid acc: 0.00\n",
      "bin i=10000 train loss:  6.23  train acc: 0.00  val loss:  6.28  valid acc: 0.00\n",
      "-> bin layer idx: 5  , best valid accuracy: 0.00, test accuracy: 0.00\n",
      "bin i=    0 train loss:  9.36  train acc: 0.00  val loss:  6.96  valid acc: 0.00\n",
      "bin i=  500 train loss:  6.57  train acc: 0.00  val loss:  6.59  valid acc: 0.00\n",
      "bin i= 1000 train loss:  6.48  train acc: 0.00  val loss:  6.52  valid acc: 0.00\n",
      "bin i= 1500 train loss:  6.47  train acc: 0.00  val loss:  6.45  valid acc: 0.00\n",
      "bin i= 2000 train loss:  6.43  train acc: 0.00  val loss:  6.44  valid acc: 0.00\n",
      "bin i= 2500 train loss:  6.40  train acc: 0.00  val loss:  6.43  valid acc: 0.00\n",
      "bin i= 3000 train loss:  6.39  train acc: 0.00  val loss:  6.40  valid acc: 0.00\n",
      "bin i= 3500 train loss:  6.33  train acc: 0.00  val loss:  6.40  valid acc: 0.00\n",
      "bin i= 4000 train loss:  6.33  train acc: 0.00  val loss:  6.38  valid acc: 0.00\n",
      "bin i= 4500 train loss:  6.34  train acc: 0.01  val loss:  6.38  valid acc: 0.00\n",
      "bin i= 5000 train loss:  6.29  train acc: 0.00  val loss:  6.37  valid acc: 0.00\n",
      "bin i= 5500 train loss:  6.29  train acc: 0.00  val loss:  6.36  valid acc: 0.00\n",
      "bin i= 6000 train loss:  6.31  train acc: 0.00  val loss:  6.35  valid acc: 0.00\n",
      "bin i= 6500 train loss:  6.29  train acc: 0.00  val loss:  6.34  valid acc: 0.00\n",
      "bin i= 7000 train loss:  6.28  train acc: 0.00  val loss:  6.34  valid acc: 0.00\n",
      "bin i= 7500 train loss:  6.27  train acc: 0.00  val loss:  6.34  valid acc: 0.00\n",
      "bin i= 8000 train loss:  6.26  train acc: 0.00  val loss:  6.33  valid acc: 0.00\n",
      "bin i= 8500 train loss:  6.27  train acc: 0.00  val loss:  6.33  valid acc: 0.00\n",
      "bin i= 9000 train loss:  6.28  train acc: 0.00  val loss:  6.32  valid acc: 0.00\n",
      "bin i= 9500 train loss:  6.26  train acc: 0.00  val loss:  6.33  valid acc: 0.00\n",
      "bin i=10000 train loss:  6.27  train acc: 0.00  val loss:  6.31  valid acc: 0.00\n",
      "-> bin layer idx: 4  , best valid accuracy: 0.00, test accuracy: 0.00\n",
      "bin i=    0 train loss:  9.32  train acc: 0.00  val loss:  6.89  valid acc: 0.00\n",
      "bin i=  500 train loss:  6.66  train acc: 0.00  val loss:  6.77  valid acc: 0.00\n",
      "bin i= 1000 train loss:  6.54  train acc: 0.00  val loss:  6.62  valid acc: 0.00\n",
      "bin i= 1500 train loss:  6.52  train acc: 0.00  val loss:  6.55  valid acc: 0.00\n",
      "bin i= 2000 train loss:  6.48  train acc: 0.00  val loss:  6.51  valid acc: 0.00\n",
      "bin i= 2500 train loss:  6.44  train acc: 0.00  val loss:  6.48  valid acc: 0.00\n",
      "bin i= 3000 train loss:  6.44  train acc: 0.00  val loss:  6.47  valid acc: 0.00\n",
      "bin i= 3500 train loss:  6.38  train acc: 0.00  val loss:  6.45  valid acc: 0.00\n",
      "bin i= 4000 train loss:  6.38  train acc: 0.00  val loss:  6.44  valid acc: 0.00\n",
      "bin i= 4500 train loss:  6.39  train acc: 0.00  val loss:  6.43  valid acc: 0.00\n",
      "bin i= 5000 train loss:  6.35  train acc: 0.00  val loss:  6.43  valid acc: 0.00\n",
      "bin i= 5500 train loss:  6.35  train acc: 0.00  val loss:  6.41  valid acc: 0.00\n",
      "bin i= 6000 train loss:  6.36  train acc: 0.00  val loss:  6.42  valid acc: 0.00\n",
      "bin i= 6500 train loss:  6.35  train acc: 0.00  val loss:  6.43  valid acc: 0.00\n",
      "bin i= 7000 train loss:  6.35  train acc: 0.00  val loss:  6.40  valid acc: 0.00\n",
      "bin i= 7500 train loss:  6.34  train acc: 0.00  val loss:  6.40  valid acc: 0.00\n",
      "bin i= 8000 train loss:  6.33  train acc: 0.00  val loss:  6.39  valid acc: 0.00\n",
      "bin i= 8500 train loss:  6.33  train acc: 0.00  val loss:  6.41  valid acc: 0.00\n",
      "bin i= 9000 train loss:  6.33  train acc: 0.00  val loss:  6.39  valid acc: 0.00\n",
      "bin i= 9500 train loss:  6.33  train acc: 0.00  val loss:  6.38  valid acc: 0.00\n",
      "bin i=10000 train loss:  6.35  train acc: 0.00  val loss:  6.38  valid acc: 0.00\n",
      "-> bin layer idx: 3  , best valid accuracy: 0.00, test accuracy: 0.00\n",
      "bin i=    0 train loss:  9.32  train acc: 0.00  val loss:  6.87  valid acc: 0.00\n",
      "bin i=  500 train loss:  6.69  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 1000 train loss:  6.69  train acc: 0.00  val loss:  6.83  valid acc: 0.00\n",
      "bin i= 1500 train loss:  6.67  train acc: 0.00  val loss:  6.77  valid acc: 0.00\n",
      "bin i= 2000 train loss:  6.62  train acc: 0.00  val loss:  6.68  valid acc: 0.00\n",
      "bin i= 2500 train loss:  6.55  train acc: 0.00  val loss:  6.61  valid acc: 0.00\n",
      "bin i= 3000 train loss:  6.53  train acc: 0.00  val loss:  6.57  valid acc: 0.00\n",
      "bin i= 3500 train loss:  6.46  train acc: 0.00  val loss:  6.54  valid acc: 0.00\n",
      "bin i= 4000 train loss:  6.45  train acc: 0.00  val loss:  6.52  valid acc: 0.00\n",
      "bin i= 4500 train loss:  6.46  train acc: 0.00  val loss:  6.52  valid acc: 0.00\n",
      "bin i= 5000 train loss:  6.41  train acc: 0.00  val loss:  6.50  valid acc: 0.00\n",
      "bin i= 5500 train loss:  6.42  train acc: 0.00  val loss:  6.48  valid acc: 0.00\n",
      "bin i= 6000 train loss:  6.43  train acc: 0.00  val loss:  6.48  valid acc: 0.00\n",
      "bin i= 6500 train loss:  6.42  train acc: 0.00  val loss:  6.46  valid acc: 0.00\n",
      "bin i= 7000 train loss:  6.40  train acc: 0.00  val loss:  6.46  valid acc: 0.00\n",
      "bin i= 7500 train loss:  6.39  train acc: 0.00  val loss:  6.44  valid acc: 0.00\n",
      "bin i= 8000 train loss:  6.39  train acc: 0.00  val loss:  6.44  valid acc: 0.00\n",
      "bin i= 8500 train loss:  6.38  train acc: 0.00  val loss:  6.44  valid acc: 0.00\n",
      "bin i= 9000 train loss:  6.38  train acc: 0.00  val loss:  6.42  valid acc: 0.00\n",
      "bin i= 9500 train loss:  6.37  train acc: 0.00  val loss:  6.42  valid acc: 0.00\n",
      "bin i=10000 train loss:  6.39  train acc: 0.00  val loss:  6.42  valid acc: 0.00\n",
      "-> bin layer idx: 2  , best valid accuracy: 0.00, test accuracy: 0.00\n",
      "bin i=    0 train loss:  9.30  train acc: 0.00  val loss:  6.86  valid acc: 0.00\n",
      "bin i=  500 train loss:  6.69  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 1000 train loss:  6.69  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 1500 train loss:  6.71  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 2000 train loss:  6.73  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 2500 train loss:  6.71  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 3000 train loss:  6.71  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 3500 train loss:  6.69  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 4000 train loss:  6.70  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 4500 train loss:  6.68  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 5000 train loss:  6.68  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 5500 train loss:  6.67  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 6000 train loss:  6.67  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 6500 train loss:  6.71  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 7000 train loss:  6.68  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 7500 train loss:  6.71  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 8000 train loss:  6.69  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 8500 train loss:  6.70  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 9000 train loss:  6.69  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 9500 train loss:  6.70  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i=10000 train loss:  6.69  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "-> bin layer idx: 1  , best valid accuracy: 0.00, test accuracy: 0.00\n",
      "bin i=    0 train loss:  9.29  train acc: 0.00  val loss:  6.85  valid acc: 0.00\n",
      "bin i=  500 train loss:  6.69  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 1000 train loss:  6.69  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 1500 train loss:  6.71  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 2000 train loss:  6.73  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 2500 train loss:  6.71  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 3000 train loss:  6.71  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 3500 train loss:  6.69  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 4000 train loss:  6.70  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 4500 train loss:  6.68  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 5000 train loss:  6.67  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 5500 train loss:  6.67  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 6000 train loss:  6.67  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 6500 train loss:  6.71  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 7000 train loss:  6.68  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 7500 train loss:  6.71  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 8000 train loss:  6.69  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 8500 train loss:  6.70  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 9000 train loss:  6.69  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i= 9500 train loss:  6.70  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "bin i=10000 train loss:  6.69  train acc: 0.00  val loss:  6.84  valid acc: 0.00\n",
      "-> bin layer idx: 0  , best valid accuracy: 0.00, test accuracy: 0.00\n"
     ]
    }
   ],
   "source": [
    "# Labels are the sums x1 + x2 of each input pair. The train labels are moved to `device`\n",
    "# one minibatch at a time inside the loop; the valid/test labels are moved up front.\n",
    "train_labels = torch.tensor([x1 + x2 for x1, x2 in train_inputs])\n",
    "valid_labels = torch.tensor([x1 + x2 for x1, x2 in valid_inputs]).to(device)\n",
    "test_labels = torch.tensor([x1 + x2 for x1, x2 in test_inputs]).to(device)\n",
    "\n",
    "# test accuracy per probe kind, keyed by layer index\n",
    "test_accuracies = {\"sin\": {}, \"bin\": {}, \"lin\": {}, \"log\": {}}\n",
    "\n",
    "# Train one ClassifierProbe per (basis, layer) pair, iterating from the last layer to the first.\n",
    "for basis_name, basis_embs in {\"sin\": basis_embs_sin, \"bin\": basis_embs_bin}.items():\n",
    "    for layer_idx in reversed(range(len(train_hidden_states))):\n",
    "\n",
    "        # reseed so that every probe is initialized from the same RNG state\n",
    "        torch.manual_seed(0)\n",
    "        probe = ClassifierProbe(\n",
    "            emb_dim=train_hidden_states[0].shape[-1],\n",
    "            hidden_dim=100,\n",
    "            basis=basis_embs,\n",
    "            heldout_mask=test_mask,\n",
    "        ).to(device)\n",
    "\n",
    "        optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)\n",
    "\n",
    "        # dedicated generator: every run draws the same sequence of minibatches\n",
    "        rng = torch.Generator().manual_seed(0)\n",
    "        best_val_acc = -1\n",
    "        best_ckpt = None\n",
    "        for i in range(10000+1):\n",
    "            probe.train()\n",
    "            optimizer.zero_grad()\n",
    "            minibatch_idcs = torch.randint(len(train_labels), size=(1024,), generator=rng)\n",
    "            x = train_hidden_states[layer_idx][minibatch_idcs].float().to(device)\n",
    "            y = train_labels[minibatch_idcs].to(device)\n",
    "            logits = probe(x, holdout_eval_tokens=True)\n",
    "            # add l1 regularization of all params to the loss\n",
    "            loss = torch.nn.functional.cross_entropy(logits, y) + 0.001 * sum(p.abs().sum() for p in probe.parameters())\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            if i % 500 == 0:\n",
    "                # minibatch accuracy, computed from the logits of the pre-update forward pass\n",
    "                train_acc = (logits.argmax(dim=-1) == y).float().mean().item()\n",
    "                probe.eval()\n",
    "                with torch.no_grad():\n",
    "                    valid_logits = probe(valid_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)\n",
    "                    valid_loss = torch.nn.functional.cross_entropy(valid_logits, valid_labels)\n",
    "                    valid_accuracy = (valid_logits.argmax(dim=-1) == valid_labels).float().mean().item()\n",
    "                    if valid_accuracy > best_val_acc:\n",
    "                        best_val_acc = valid_accuracy\n",
    "                        # state_dict() returns references to the live parameters, so take a\n",
    "                        # snapshot; otherwise later optimizer steps overwrite the checkpoint\n",
    "                        # and load_state_dict() below would restore the *final* weights.\n",
    "                        best_ckpt = {k: v.detach().clone() for k, v in probe.state_dict().items()}\n",
    "                print(f\"{basis_name} {i=:>5} train loss: {loss.item():5.2f}  train acc: {train_acc:.2f}  val loss: {valid_loss.item():5.2f}  valid acc: {valid_accuracy:.2f}\")\n",
    "        # evaluate the checkpoint with the best validation accuracy on the test split\n",
    "        probe.load_state_dict(best_ckpt)\n",
    "        probe.eval()\n",
    "        with torch.no_grad():\n",
    "            test_logits = probe(test_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)\n",
    "            test_accuracy = (test_logits.argmax(dim=-1) == test_labels).float().mean().item()\n",
    "        test_accuracies[basis_name][layer_idx] = test_accuracy\n",
    "        print(f\"-> {basis_name} layer idx: {layer_idx:<3}, best valid accuracy: {best_val_acc:.2f}, test accuracy: {test_accuracy:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a121f72d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def solve_linear_layer(x: Tensor, y: Tensor) -> torch.nn.Linear:\n",
    "    if y.ndim == 1:\n",
    "        y = y.unsqueeze(-1)\n",
    "    if not y.is_floating_point():\n",
    "        y = y.float()\n",
    "   \n",
    "    lin = torch.nn.Linear(x.shape[-1], y.shape[-1], device=x.device)\n",
    "    x_aug = torch.cat([x, torch.ones(len(x), 1, device=x.device)], dim=1)\n",
    "    coeffs = torch.linalg.lstsq(x_aug, y).solution\n",
    "    w, b = coeffs[:-1], coeffs[-1]\n",
    "    with torch.no_grad():\n",
    "        lin.weight[:] = w.T\n",
    "        lin.bias[:] = b\n",
    "    return lin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "23e13e62",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "layer idx: 0  , linear probe acc: 0.00, log probe acc: 0.00\n",
      "layer idx: 1  , linear probe acc: 0.01, log probe acc: 0.01\n",
      "layer idx: 2  , linear probe acc: 0.02, log probe acc: 0.02\n",
      "layer idx: 3  , linear probe acc: 0.03, log probe acc: 0.02\n",
      "layer idx: 4  , linear probe acc: 0.02, log probe acc: 0.02\n",
      "layer idx: 5  , linear probe acc: 0.02, log probe acc: 0.02\n",
      "layer idx: 6  , linear probe acc: 0.03, log probe acc: 0.02\n",
      "layer idx: 7  , linear probe acc: 0.02, log probe acc: 0.02\n",
      "layer idx: 8  , linear probe acc: 0.02, log probe acc: 0.02\n",
      "layer idx: 9  , linear probe acc: 0.02, log probe acc: 0.02\n",
      "layer idx: 10 , linear probe acc: 0.02, log probe acc: 0.02\n",
      "layer idx: 11 , linear probe acc: 0.02, log probe acc: 0.02\n",
      "layer idx: 12 , linear probe acc: 0.03, log probe acc: 0.02\n",
      "layer idx: 13 , linear probe acc: 0.03, log probe acc: 0.02\n",
      "layer idx: 14 , linear probe acc: 0.03, log probe acc: 0.02\n",
      "layer idx: 15 , linear probe acc: 0.03, log probe acc: 0.02\n",
      "layer idx: 16 , linear probe acc: 0.02, log probe acc: 0.01\n"
     ]
    }
   ],
   "source": [
    "# Closed-form regression probes: for every layer, fit one least-squares probe on the raw labels\n",
    "# (\"lin\") and one on log1p(labels) (\"log\"), then score rounded predictions on the test split.\n",
    "for layer_idx in range(len(train_hidden_states)):\n",
    "    # convert each split once per layer instead of once per probe\n",
    "    layer_train_x = train_hidden_states[layer_idx].float().to(device)\n",
    "    layer_test_x = test_hidden_states[layer_idx].float().to(device)\n",
    "\n",
    "    lin_probe = solve_linear_layer(layer_train_x, train_labels.to(device))\n",
    "    log_probe = solve_linear_layer(layer_train_x, train_labels.log1p().to(device))\n",
    "\n",
    "    with torch.no_grad():\n",
    "        lin_test_pred = lin_probe(layer_test_x).flatten().round().int()\n",
    "        # the log probe regresses log1p(label) = log(1 + label); its inverse is\n",
    "        # expm1(p) = exp(p) - 1, not exp(p) + 1\n",
    "        log_test_pred = log_probe(layer_test_x).flatten().expm1().round().int()\n",
    "\n",
    "    lin_test_accuracy = (lin_test_pred == test_labels).float().mean().item()\n",
    "    log_test_accuracy = (log_test_pred == test_labels).float().mean().item()\n",
    "\n",
    "    test_accuracies[\"lin\"][layer_idx] = lin_test_accuracy\n",
    "    test_accuracies[\"log\"][layer_idx] = log_test_accuracy\n",
    "\n",
    "    print(f\"layer idx: {layer_idx:<3}, linear probe acc: {lin_test_accuracy:.2f}, log probe acc: {log_test_accuracy:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "2f0f5f7a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sin accs: | 0% | 0% | 0% | 0% | 1% | 1% | 1% | 1% | 1% | 1% | 1% | 1% | 5% | 79% | 85% | 90% | 93% |\n",
      "bin accs: | 0% | 0% | 0% | 0% | 0% | 0% | 0% | 0% | 0% | 0% | 0% | 0% | 0% | 1% | 1% | 1% | 2% |\n",
      "lin accs: | 0% | 1% | 2% | 3% | 2% | 2% | 3% | 2% | 2% | 2% | 2% | 2% | 3% | 3% | 3% | 3% | 2% |\n",
      "log accs: | 0% | 1% | 2% | 2% | 2% | 2% | 2% | 2% | 2% | 2% | 2% | 2% | 2% | 2% | 2% | 2% | 1% |\n"
     ]
    }
   ],
   "source": [
    "# One summary row per probe kind: test accuracy as a whole percentage, ordered by layer index.\n",
    "for name, accs in test_accuracies.items():\n",
    "    cells = [f\"{accs[layer_idx]:.0%}\" for layer_idx in sorted(accs)]\n",
    "    row = \" | \".join(cells)\n",
    "    print(f\"{name} accs: | \" + row + \" |\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "numllama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
