{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "k0ncgC5Ow9Zw",
        "outputId": "bfb2921e-96b7-4d08-a2bc-6a5c62033b7c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting pyscnn\n",
            "  Downloading pyscnn-0.0.9b0-py3-none-any.whl (123 kB)\n",
            "\u001b[?25l     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/123.1 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m123.1/123.1 kB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: numpy>=1.21.3 in /usr/local/lib/python3.10/dist-packages (from pyscnn) (1.25.2)\n",
            "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from pyscnn) (2.3.0+cu121)\n",
            "Requirement already satisfied: cvxpy>=1.2.1 in /usr/local/lib/python3.10/dist-packages (from pyscnn) (1.3.4)\n",
            "Requirement already satisfied: scikit-learn>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from pyscnn) (1.2.2)\n",
            "Requirement already satisfied: scipy>=1.7.2 in /usr/local/lib/python3.10/dist-packages (from pyscnn) (1.11.4)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from pyscnn) (4.11.0)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from pyscnn) (4.66.4)\n",
            "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from pyscnn) (3.3.0)\n",
            "Collecting linalg-backends (from pyscnn)\n",
            "  Downloading linalg_backends-0.0.1b0-py3-none-any.whl (11 kB)\n",
            "Requirement already satisfied: osqp>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from cvxpy>=1.2.1->pyscnn) (0.6.2.post8)\n",
            "Requirement already satisfied: ecos>=2 in /usr/local/lib/python3.10/dist-packages (from cvxpy>=1.2.1->pyscnn) (2.0.13)\n",
            "Requirement already satisfied: scs>=1.1.6 in /usr/local/lib/python3.10/dist-packages (from cvxpy>=1.2.1->pyscnn) (3.2.4.post1)\n",
            "Requirement already satisfied: setuptools>65.5.1 in /usr/local/lib/python3.10/dist-packages (from cvxpy>=1.2.1->pyscnn) (67.7.2)\n",
            "Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=1.0.0->pyscnn) (1.4.2)\n",
            "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=1.0.0->pyscnn) (3.5.0)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->pyscnn) (3.14.0)\n",
            "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->pyscnn) (1.12)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->pyscnn) (3.3)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->pyscnn) (3.1.4)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->pyscnn) (2023.6.0)\n",
            "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.10.0->pyscnn)\n",
            "  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n",
            "Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.10.0->pyscnn)\n",
            "  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n",
            "Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.10.0->pyscnn)\n",
            "  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n",
            "Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.10.0->pyscnn)\n",
            "  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n",
            "Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=1.10.0->pyscnn)\n",
            "  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n",
            "Collecting nvidia-cufft-cu12==11.0.2.54 (from torch>=1.10.0->pyscnn)\n",
            "  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n",
            "Collecting nvidia-curand-cu12==10.3.2.106 (from torch>=1.10.0->pyscnn)\n",
            "  Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n",
            "Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch>=1.10.0->pyscnn)\n",
            "  Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n",
            "Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch>=1.10.0->pyscnn)\n",
            "  Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n",
            "Collecting nvidia-nccl-cu12==2.20.5 (from torch>=1.10.0->pyscnn)\n",
            "  Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)\n",
            "Collecting nvidia-nvtx-cu12==12.1.105 (from torch>=1.10.0->pyscnn)\n",
            "  Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n",
            "Requirement already satisfied: triton==2.3.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->pyscnn) (2.3.0)\n",
            "Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.10.0->pyscnn)\n",
            "  Downloading nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl (21.3 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.3/21.3 MB\u001b[0m \u001b[31m74.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: qdldl in /usr/local/lib/python3.10/dist-packages (from osqp>=0.4.1->cvxpy>=1.2.1->pyscnn) (0.1.7.post2)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->pyscnn) (2.1.5)\n",
            "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->pyscnn) (1.3.0)\n",
            "Installing collected packages: nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, linalg-backends, pyscnn\n",
            "Successfully installed linalg-backends-0.0.1b0 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.5.40 nvidia-nvtx-cu12-12.1.105 pyscnn-0.0.9b0\n"
          ]
        }
      ],
      "source": [
        "!pip install pyscnn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xDpQUUiJxCEO"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from scnn.private.utils.data import gen_classification_data, gen_regression_data\n",
        "\n",
        "\n",
        "from scnn.models import ConvexGatedReLU, ConvexReLU\n",
        "from scnn.solvers import RFISTA, AL, LeastSquaresSolver, CVXPYSolver, ApproximateConeDecomposition\n",
        "from scnn.regularizers import NeuronGL1, L2, L1, FeatureGL1\n",
        "from scnn.metrics import Metrics\n",
        "from scnn.activations import sample_gate_vectors\n",
        "from scnn.optimize import optimize_model, optimize"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zYCP_lhjw_H3"
      },
      "outputs": [],
      "source": [
        "from typing import Tuple, Optional, Union, Callable\n",
        "from typing_extensions import Literal\n",
        "import math\n",
        "\n",
        "from scipy.stats import ortho_group  # type: ignore\n",
        "\n",
        "Transform = Literal[\"cosine\", \"polynomial\"]\n",
        "\n",
        "Dataset = Tuple[np.ndarray, np.ndarray]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0gEjEU597gtx"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "import torchvision\n",
        "from torch import nn, optim\n",
        "import torch.nn.functional as F\n",
        "from torchvision import datasets, transforms\n",
        "from torch.autograd import Variable\n",
        "from torch.utils.data.sampler import Sampler\n",
        "import matplotlib.pyplot as plt\n",
        "import torch\n",
        "\n",
        "import torch.utils.data as data\n",
        "import os\n",
        "import random"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NpdN8YFD7lRD",
        "outputId": "89e0c941-254c-4a2c-ab6f-530af47e1f6f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Mounted at /content/gdrive\n"
          ]
        }
      ],
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/gdrive', force_remount=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "q64jjQys7mmf",
        "outputId": "76df4092-df0a-448a-9656-f10cf0a452e1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "cuda\n"
          ]
        }
      ],
      "source": [
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "print(device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8oh2H5-0wwuc",
        "outputId": "551859b3-d0f7-41fd-9a46-353d5adc752d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 170498071/170498071 [00:12<00:00, 13338764.58it/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Extracting ./data/cifar-10-python.tar.gz to ./data\n",
            "Files already downloaded and verified\n"
          ]
        }
      ],
      "source": [
        "torch.manual_seed(7)\n",
        "\n",
        "transform = transforms.Compose([transforms.ToTensor(),\n",
        "                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
        "\n",
        "transform_train = transforms.Compose([\n",
        "    transforms.RandomCrop(32, padding=4),\n",
        "    transforms.RandomHorizontalFlip(),\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
        "])\n",
        "\n",
        "transform_test = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
        "])\n",
        "\n",
        "train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)\n",
        "test_dataset = datasets.CIFAR10('./data', train=False, download=True, transform=transform_test)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kaX9_ZDurwXr"
      },
      "outputs": [],
      "source": [
        "batch_size = 128"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4VNkusFdxWMt"
      },
      "outputs": [],
      "source": [
        "# train_loader = torch.utils.data.DataLoader(train_subset, batch_size = batch_size, num_workers = 2, pin_memory = True)\n",
        "validation_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, num_workers = 2, pin_memory = True, shuffle = False)\n",
        "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = batch_size, num_workers = 2, pin_memory = True, shuffle = False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "__3GkWOXn8z3"
      },
      "outputs": [],
      "source": [
        "num_classes = 10"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "N4UwG3fbIdmw",
        "outputId": "762bd3ba-b746-4355-bf5a-48a041c9564b"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
            "  warnings.warn(\n",
            "/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
            "  warnings.warn(msg)\n",
            "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n",
            "100%|██████████| 44.7M/44.7M [00:00<00:00, 180MB/s]\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "ResNet(\n",
              "  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
              "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "  (relu): ReLU(inplace=True)\n",
              "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
              "  (layer1): Sequential(\n",
              "    (0): BasicBlock(\n",
              "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (relu): ReLU(inplace=True)\n",
              "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "    )\n",
              "    (1): BasicBlock(\n",
              "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (relu): ReLU(inplace=True)\n",
              "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "    )\n",
              "  )\n",
              "  (layer2): Sequential(\n",
              "    (0): BasicBlock(\n",
              "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (relu): ReLU(inplace=True)\n",
              "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (downsample): Sequential(\n",
              "        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (1): BasicBlock(\n",
              "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (relu): ReLU(inplace=True)\n",
              "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "    )\n",
              "  )\n",
              "  (layer3): Sequential(\n",
              "    (0): BasicBlock(\n",
              "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (relu): ReLU(inplace=True)\n",
              "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (downsample): Sequential(\n",
              "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (1): BasicBlock(\n",
              "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (relu): ReLU(inplace=True)\n",
              "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "    )\n",
              "  )\n",
              "  (layer4): Sequential(\n",
              "    (0): BasicBlock(\n",
              "      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (relu): ReLU(inplace=True)\n",
              "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (downsample): Sequential(\n",
              "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (1): BasicBlock(\n",
              "      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (relu): ReLU(inplace=True)\n",
              "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "    )\n",
              "  )\n",
              "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
              "  (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
              ")"
            ]
          },
          "execution_count": 11,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import torchvision.models as models\n",
        "checkpoint = torch.load('/content/gdrive/My Drive/From Non-Convex to Convex/Fine-tune Pre-trained Resnet18/10 Class Experiments/TransformedImages/FFTSGDCosine/epoch189_testloss_0.0047605928719043735_testacc_87.61')\n",
        "model = torchvision.models.resnet18(pretrained=True).to(device)\n",
        "model.load_state_dict(checkpoint['model_state_dict'])\n",
        "model.eval()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xmtYWBFIIdm1",
        "outputId": "44977a61-5dd9-4187-8866-f0d23cfe8ac8"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Input image shape torch.Size([128, 3, 32, 32])\n",
            "Layer1 Input torch.Size([128, 64, 8, 8])\n",
            "Layer2 Input torch.Size([128, 64, 8, 8])\n",
            "Layer3 Input torch.Size([128, 128, 4, 4])\n",
            "Layer3 Output torch.Size([128, 256, 2, 2])\n",
            "Layer4 Output torch.Size([128, 512, 1, 1])\n"
          ]
        }
      ],
      "source": [
        "def inspect_shapes(model, x):\n",
        "        print(\"Input image shape\", x.shape)\n",
        "        out = model.conv1(x)\n",
        "        out = model.bn1(out)\n",
        "        out = model.relu(out)\n",
        "        out = model.maxpool(out)\n",
        "        print(\"Layer1 Input\", out.shape)\n",
        "        out = model.layer1(out)\n",
        "        print(\"Layer2 Input\", out.shape)\n",
        "        out = model.layer2(out)\n",
        "        print(\"Layer3 Input\", out.shape)\n",
        "        out = model.layer3(out)\n",
        "        print(\"Layer3 Output\", out.shape)\n",
        "        out = model.layer4(out)\n",
        "        print(\"Layer4 Output\", out.shape)\n",
        "\n",
        "with torch.no_grad():\n",
        "\n",
        "    for data in validation_loader:\n",
        "\n",
        "        images, labels = data\n",
        "        images, labels = images.to(device), labels.to(device)\n",
        "\n",
        "        inspect_shapes(model, images)\n",
        "\n",
        "        break"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "922JhWo5IdnB",
        "outputId": "a4aa7ddd-d733-4597-e9cc-9d122dd64ba4"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<torch.utils.hooks.RemovableHandle at 0x7a17e7b84250>"
            ]
          },
          "execution_count": 13,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "activation = {}\n",
        "def get_activation(name):\n",
        "    def hook(model, input, output):\n",
        "        activation[name] = output.to(device)\n",
        "    return hook\n",
        "\n",
        "model.maxpool.register_forward_hook(get_activation('layer0'))\n",
        "model.layer1.register_forward_hook(get_activation('layer1'))\n",
        "model.layer2.register_forward_hook(get_activation('layer2'))\n",
        "model.layer3.register_forward_hook(get_activation('layer3'))\n",
        "model.layer4.register_forward_hook(get_activation('layer4'))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZdOdL96aIdmx"
      },
      "outputs": [],
      "source": [
        "layer1_list = []\n",
        "\n",
        "for k,v in model.named_modules():\n",
        "    if 'layer1' in k:\n",
        "        layer1_list.append(k)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hYOeyaXvEtQv"
      },
      "outputs": [],
      "source": [
        "num_train_samples = 10\n",
        "num_train_gaussian_samples = 0\n",
        "\n",
        "# debugging_data = torch.load('/content/gdrive/My Drive/From Non-Convex to Convex/Fine-tune Pre-trained Resnet18/10 Class Experiments/debugging_scnn_numtrain5')\n",
        "\n",
        "debugging_data = torch.load('/content/gdrive/My Drive/From Non-Convex to Convex/Fine-tune Pre-trained Resnet18/10 Class Experiments/debugging_scnn_numtrain{}_gaussian{}'.format(num_train_samples, num_train_gaussian_samples))\n",
        "\n",
        "X_train_SCNN_new = debugging_data['X_train_SCNN_new']\n",
        "y_train_SCNN_new = debugging_data['y_train_SCNN_new']\n",
        "# X_test_SCNN = debugging_data['X_test_SCNN']\n",
        "# y_test_SCNN = debugging_data['y_test_SCNN']\n",
        "W1 = debugging_data['W1']\n",
        "W2 = debugging_data['W2']\n",
        "gates = debugging_data['gates']\n",
        "gated_activations = debugging_data['gated_activations']\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "u0kghPOB_yHZ",
        "outputId": "54fa0f50-d490-4f3f-a96f-bf332c730b25"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "tensor(1.7676, device='cuda:0')\n",
            "tensor(3.5061, device='cuda:0')\n"
          ]
        }
      ],
      "source": [
        "print(torch.max(abs(X_train_SCNN_new)))\n",
        "print(torch.max(abs(y_train_SCNN_new)))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zqC_ECSQ4WQe"
      },
      "outputs": [],
      "source": [
        "# model parameters\n",
        "lam = 4.50E-06\n",
        "\n",
        "# optimization parameters\n",
        "tol = 1e-7\n",
        "max_epochs = 1000\n",
        "\n",
        "# try playing with the step size...\n",
        "\n",
        "# lr = 0.01\n",
        "# lr = 0.001\n",
        "# lr = 0.0001\n",
        "lr = 5e-4"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tRFJN9hF5V7Q"
      },
      "outputs": [],
      "source": [
        "class nonconvex_distil_model(nn.Module):\n",
        "    def __init__(self, hidden_units=100):\n",
        "        super().__init__()\n",
        "        self.fc1 = torch.nn.Linear(in_features=X_train_SCNN_new.shape[-1], out_features=hidden_units, bias=False)\n",
        "        self.fc2 = torch.nn.Linear(in_features=hidden_units, out_features=y_train_SCNN_new.shape[-1], bias=False)\n",
        "        self.relu = torch.nn.ReLU()\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.fc1(x)\n",
        "        x = self.relu(x)\n",
        "        x = self.fc2(x)\n",
        "        return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QUHX387RQTqV"
      },
      "outputs": [],
      "source": [
        "class gated_convex_distil_model(nn.Module):\n",
        "    def __init__(self, hidden_units=100):\n",
        "        super().__init__()\n",
        "        self.fc1 = torch.nn.Linear(in_features=X_train_SCNN_new.shape[-1], out_features=hidden_units, bias=False)\n",
        "        self.fc2 = torch.nn.Linear(in_features=hidden_units, out_features=y_train_SCNN_new.shape[-1], bias=False)\n",
        "\n",
        "        self.gates = torch.nn.Linear(in_features=X_train_SCNN_new.shape[-1], out_features=hidden_units, bias=False)\n",
        "        self.gates.weight.requires_grad = False\n",
        "\n",
        "    def forward(self, x):\n",
        "        g = self.gates(x)\n",
        "        x = self.fc1(x)\n",
        "\n",
        "        x = x * ( g > 0 )\n",
        "\n",
        "        x = self.fc2(x)\n",
        "        return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "n25_j8NY4itW",
        "outputId": "b09d08d4-b89a-4b4b-ba8d-49a43981572d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "153600\n"
          ]
        }
      ],
      "source": [
        "# create model\n",
        "hidden_units = 100\n",
        "\n",
        "# distil_model = gated_convex_distil_model(hidden_units).to(device)\n",
        "distil_model = nonconvex_distil_model(hidden_units).to(device)\n",
        "\n",
        "print(sum(p.numel() for p in distil_model.parameters() if p.requires_grad))\n",
        "\n",
        "optimizer = optim.Adam(distil_model.parameters(), lr=lr)\n",
        "loss_function = nn.MSELoss()\n",
        "\n",
        "# time_budget = np.sum(metrics['grelu_metrics'].time)\n",
        "time_budget = 2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "W7-GotvqrbbC",
        "outputId": "97684be3-82d6-48cd-8a17-a4e46b73ee4d"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "1029046"
            ]
          },
          "execution_count": 27,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "1029046"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZQXBScKyWLwr"
      },
      "outputs": [],
      "source": [
        "def report_test_accuracy(distil_model):\n",
        "    \"\"\"Return overall test accuracy (percent) with `distil_model` standing in for layer4.\n",
        "\n",
        "    Every test batch is passed through the global `model`; the tensors found under\n",
        "    activation['layer3'] and activation['layer4'] afterwards (presumably captured by\n",
        "    forward hooks registered elsewhere in the notebook -- TODO confirm) give the\n",
        "    distilled block's input and the shape its output is reshaped to. The distilled\n",
        "    output is then sent through `model.avgpool` and `model.fc` to get class scores.\n",
        "\n",
        "    Relies on the notebook globals `model`, `test_loader`, `activation` and `device`.\n",
        "    Side effect: `model` and `distil_model` are both switched to eval mode.\n",
        "\n",
        "    Args:\n",
        "        distil_model: module applied to the flattened layer3 activations; its output\n",
        "            must be reshapeable to the layer4 activation shape.\n",
        "\n",
        "    Returns:\n",
        "        float: 100 * (correct predictions) / (test samples seen).\n",
        "\n",
        "    Raises:\n",
        "        ZeroDivisionError: if `test_loader` yields no samples.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    distil_model.eval()\n",
        "\n",
        "    num_correct = 0\n",
        "    num_seen = 0\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for images, labels in test_loader:\n",
        "            images, labels = images.to(device), labels.to(device)\n",
        "\n",
        "            # This forward pass is run only to refresh `activation`; its output is unused.\n",
        "            model(images)\n",
        "            block_input = activation[\"layer3\"]\n",
        "            target_output = activation[\"layer4\"]\n",
        "\n",
        "            distil_output = distil_model(\n",
        "                block_input.reshape(images.shape[0], -1)\n",
        "            ).reshape(target_output.shape)\n",
        "            logits = model.fc(model.avgpool(distil_output).reshape(block_input.shape[0], -1))\n",
        "\n",
        "            # argmax without keepdim stays 1-D even when the batch holds a single sample,\n",
        "            # whereas argmax(keepdim=True).squeeze() would collapse it to a 0-d tensor.\n",
        "            predicted = logits.argmax(dim=1)\n",
        "\n",
        "            # One vectorised comparison per batch instead of one .item() call per sample.\n",
        "            num_correct += (predicted == labels.reshape(-1)).sum().item()\n",
        "            num_seen += predicted.numel()\n",
        "\n",
        "    return 100.0 * num_correct / num_seen"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "L6UadgSOS7s5",
        "outputId": "4cb894d5-891f-4315-8f69-6a8b0c1ea135"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "epoch:  1 , time elapse:  0.002899646759033203 , mse train loss:  0.13541565835475922 , test acc:  8.8\n",
            "epoch:  1 , time elapse:  0.006285667419433594 , mse train loss:  0.1340165138244629 , test acc:  12.71\n",
            "epoch:  1 , time elapse:  0.011560678482055664 , mse train loss:  0.13260406255722046 , test acc:  15.69\n",
            "epoch:  1 , time elapse:  0.014116287231445312 , mse train loss:  0.1310487687587738 , test acc:  17.35\n",
            "epoch:  1 , time elapse:  0.018825054168701172 , mse train loss:  0.12926803529262543 , test acc:  18.12\n",
            "epoch:  1 , time elapse:  0.022591352462768555 , mse train loss:  0.12723322212696075 , test acc:  18.73\n",
            "epoch:  1 , time elapse:  0.026854991912841797 , mse train loss:  0.12495900690555573 , test acc:  19.21\n",
            "epoch:  1 , time elapse:  0.03184103965759277 , mse train loss:  0.12247583270072937 , test acc:  19.47\n",
            "epoch:  1 , time elapse:  0.03471517562866211 , mse train loss:  0.1198301687836647 , test acc:  19.87\n",
            "epoch:  1 , time elapse:  0.04107999801635742 , mse train loss:  0.1170908510684967 , test acc:  20.3\n",
            "epoch:  1 , time elapse:  0.04550433158874512 , mse train loss:  0.11433785408735275 , test acc:  20.71\n",
            "epoch:  1 , time elapse:  0.04802703857421875 , mse train loss:  0.11166281253099442 , test acc:  21.38\n",
            "epoch:  1 , time elapse:  0.050742149353027344 , mse train loss:  0.1091492623090744 , test acc:  21.85\n",
            "epoch:  1 , time elapse:  0.05767369270324707 , mse train loss:  0.10686378926038742 , test acc:  22.29\n",
            "epoch:  1 , time elapse:  0.0607607364654541 , mse train loss:  0.10483850538730621 , test acc:  22.52\n",
            "epoch:  1 , time elapse:  0.06578969955444336 , mse train loss:  0.10306553542613983 , test acc:  22.36\n",
            "epoch:  1 , time elapse:  0.0718083381652832 , mse train loss:  0.10150188952684402 , test acc:  21.83\n",
            "epoch:  1 , time elapse:  0.07539200782775879 , mse train loss:  0.10007987916469574 , test acc:  21.45\n",
            "epoch:  1 , time elapse:  0.08014440536499023 , mse train loss:  0.09873053431510925 , test acc:  20.72\n",
            "epoch:  1 , time elapse:  0.08520936965942383 , mse train loss:  0.09740204364061356 , test acc:  20.13\n",
            "epoch:  1 , time elapse:  0.09143924713134766 , mse train loss:  0.0960666611790657 , test acc:  19.54\n",
            "epoch:  1 , time elapse:  0.09806394577026367 , mse train loss:  0.09472224861383438 , test acc:  19.41\n",
            "epoch:  1 , time elapse:  0.10302138328552246 , mse train loss:  0.09338529407978058 , test acc:  19.61\n",
            "epoch:  1 , time elapse:  0.10804104804992676 , mse train loss:  0.0920867919921875 , test acc:  21.07\n",
            "epoch:  1 , time elapse:  0.11289215087890625 , mse train loss:  0.09085837006568909 , test acc:  23.54\n",
            "epoch:  1 , time elapse:  0.1185460090637207 , mse train loss:  0.0897224023938179 , test acc:  26.24\n",
            "epoch:  1 , time elapse:  0.1233515739440918 , mse train loss:  0.08868595957756042 , test acc:  29.68\n",
            "epoch:  1 , time elapse:  0.12825512886047363 , mse train loss:  0.08774510771036148 , test acc:  33.51\n",
            "epoch:  1 , time elapse:  0.1338489055633545 , mse train loss:  0.086888886988163 , test acc:  37.07\n",
            "epoch:  1 , time elapse:  0.13985133171081543 , mse train loss:  0.08609918504953384 , test acc:  40.39\n",
            "epoch:  1 , time elapse:  0.14525461196899414 , mse train loss:  0.08535254001617432 , test acc:  43.81\n",
            "epoch:  1 , time elapse:  0.15050697326660156 , mse train loss:  0.08462927490472794 , test acc:  46.84\n",
            "epoch:  1 , time elapse:  0.15381360054016113 , mse train loss:  0.08391502499580383 , test acc:  49.78\n",
            "epoch:  1 , time elapse:  0.16128897666931152 , mse train loss:  0.08319686353206635 , test acc:  52.72\n",
            "epoch:  1 , time elapse:  0.16569924354553223 , mse train loss:  0.08247456699609756 , test acc:  55.5\n",
            "epoch:  1 , time elapse:  0.1687636375427246 , mse train loss:  0.0817418247461319 , test acc:  58.23\n",
            "epoch:  1 , time elapse:  0.1729416847229004 , mse train loss:  0.08099941164255142 , test acc:  60.54\n",
            "epoch:  1 , time elapse:  0.17704296112060547 , mse train loss:  0.08024632185697556 , test acc:  63.01\n",
            "epoch:  1 , time elapse:  0.1813974380493164 , mse train loss:  0.07948355376720428 , test acc:  65.18\n",
            "epoch:  1 , time elapse:  0.18630123138427734 , mse train loss:  0.07871004194021225 , test acc:  66.96\n",
            "epoch:  1 , time elapse:  0.19101357460021973 , mse train loss:  0.07792172580957413 , test acc:  68.77\n",
            "epoch:  1 , time elapse:  0.1958613395690918 , mse train loss:  0.07711285352706909 , test acc:  70.22\n",
            "epoch:  1 , time elapse:  0.2018442153930664 , mse train loss:  0.07628165185451508 , test acc:  71.54\n",
            "epoch:  1 , time elapse:  0.2082967758178711 , mse train loss:  0.07542368769645691 , test acc:  72.67\n",
            "epoch:  1 , time elapse:  0.21318721771240234 , mse train loss:  0.07453738152980804 , test acc:  73.5\n",
            "epoch:  1 , time elapse:  0.21756601333618164 , mse train loss:  0.07362332195043564 , test acc:  74.3\n",
            "epoch:  1 , time elapse:  0.2246260643005371 , mse train loss:  0.07268504798412323 , test acc:  74.96\n",
            "epoch:  1 , time elapse:  0.229417085647583 , mse train loss:  0.07172516733407974 , test acc:  75.51\n",
            "epoch:  1 , time elapse:  0.23418712615966797 , mse train loss:  0.070747509598732 , test acc:  75.99\n",
            "epoch:  1 , time elapse:  0.23911786079406738 , mse train loss:  0.06975294649600983 , test acc:  76.19\n",
            "epoch:  1 , time elapse:  0.24517822265625 , mse train loss:  0.06874119490385056 , test acc:  76.45\n",
            "epoch:  1 , time elapse:  0.25231003761291504 , mse train loss:  0.06771159172058105 , test acc:  76.68\n",
            "epoch:  1 , time elapse:  0.2568395137786865 , mse train loss:  0.06666402518749237 , test acc:  76.91\n",
            "epoch:  1 , time elapse:  0.26001667976379395 , mse train loss:  0.06559862941503525 , test acc:  77.13\n",
            "epoch:  1 , time elapse:  0.2632405757904053 , mse train loss:  0.06451540440320969 , test acc:  77.29\n",
            "epoch:  1 , time elapse:  0.2693343162536621 , mse train loss:  0.0634155198931694 , test acc:  77.39\n",
            "epoch:  1 , time elapse:  0.27338290214538574 , mse train loss:  0.062300827354192734 , test acc:  77.43\n",
            "epoch:  1 , time elapse:  0.27832674980163574 , mse train loss:  0.061173275113105774 , test acc:  77.68\n",
            "epoch:  1 , time elapse:  0.28254055976867676 , mse train loss:  0.06003419682383537 , test acc:  77.84\n",
            "epoch:  1 , time elapse:  0.2886214256286621 , mse train loss:  0.05888575315475464 , test acc:  78.0\n",
            "epoch:  1 , time elapse:  0.2927117347717285 , mse train loss:  0.057728689163923264 , test acc:  78.29\n",
            "epoch:  1 , time elapse:  0.2966616153717041 , mse train loss:  0.05656462907791138 , test acc:  78.25\n",
            "epoch:  1 , time elapse:  0.30171847343444824 , mse train loss:  0.05539403483271599 , test acc:  78.34\n",
            "epoch:  1 , time elapse:  0.3079047203063965 , mse train loss:  0.054219238460063934 , test acc:  78.49\n",
            "epoch:  1 , time elapse:  0.31253957748413086 , mse train loss:  0.053042132407426834 , test acc:  78.65\n",
            "epoch:  1 , time elapse:  0.3161463737487793 , mse train loss:  0.05186500772833824 , test acc:  78.73\n",
            "epoch:  1 , time elapse:  0.32120609283447266 , mse train loss:  0.050689324736595154 , test acc:  78.79\n",
            "epoch:  1 , time elapse:  0.32643961906433105 , mse train loss:  0.049516137689352036 , test acc:  78.88\n",
            "epoch:  1 , time elapse:  0.33234429359436035 , mse train loss:  0.048346880823373795 , test acc:  78.94\n",
            "epoch:  1 , time elapse:  0.3373849391937256 , mse train loss:  0.04718324542045593 , test acc:  78.99\n",
            "epoch:  1 , time elapse:  0.34326815605163574 , mse train loss:  0.04602643847465515 , test acc:  79.06\n",
            "epoch:  1 , time elapse:  0.34892964363098145 , mse train loss:  0.04487685486674309 , test acc:  79.19\n",
            "epoch:  1 , time elapse:  0.35316038131713867 , mse train loss:  0.0437362976372242 , test acc:  79.25\n",
            "epoch:  1 , time elapse:  0.3581371307373047 , mse train loss:  0.04260595887899399 , test acc:  79.3\n",
            "epoch:  1 , time elapse:  0.36338329315185547 , mse train loss:  0.04148674011230469 , test acc:  79.39\n",
            "epoch:  1 , time elapse:  0.36584973335266113 , mse train loss:  0.04038063809275627 , test acc:  79.37\n",
            "epoch:  1 , time elapse:  0.3694319725036621 , mse train loss:  0.03928894177079201 , test acc:  79.48\n",
            "epoch:  1 , time elapse:  0.3738260269165039 , mse train loss:  0.038212068378925323 , test acc:  79.56\n",
            "epoch:  1 , time elapse:  0.37758779525756836 , mse train loss:  0.03715077415108681 , test acc:  79.61\n",
            "epoch:  1 , time elapse:  0.38344621658325195 , mse train loss:  0.03610595315694809 , test acc:  79.7\n",
            "epoch:  1 , time elapse:  0.38630080223083496 , mse train loss:  0.03507835790514946 , test acc:  79.76\n",
            "epoch:  1 , time elapse:  0.3915684223175049 , mse train loss:  0.034069325774908066 , test acc:  79.82\n",
            "epoch:  1 , time elapse:  0.3958396911621094 , mse train loss:  0.03307901695370674 , test acc:  79.97\n",
            "epoch:  1 , time elapse:  0.39963579177856445 , mse train loss:  0.032107919454574585 , test acc:  80.03\n",
            "epoch:  1 , time elapse:  0.4066429138183594 , mse train loss:  0.031156817451119423 , test acc:  80.09\n",
            "epoch:  1 , time elapse:  0.4115176200866699 , mse train loss:  0.030226141214370728 , test acc:  80.23\n",
            "epoch:  1 , time elapse:  0.4156150817871094 , mse train loss:  0.02931627258658409 , test acc:  80.21\n",
            "epoch:  1 , time elapse:  0.42000722885131836 , mse train loss:  0.028427813202142715 , test acc:  80.27\n",
            "epoch:  1 , time elapse:  0.42267870903015137 , mse train loss:  0.027561163529753685 , test acc:  80.34\n",
            "epoch:  1 , time elapse:  0.4278080463409424 , mse train loss:  0.026716727763414383 , test acc:  80.37\n",
            "epoch:  1 , time elapse:  0.43312573432922363 , mse train loss:  0.025894828140735626 , test acc:  80.49\n",
            "epoch:  1 , time elapse:  0.4383983612060547 , mse train loss:  0.02509533427655697 , test acc:  80.61\n",
            "epoch:  1 , time elapse:  0.4444873332977295 , mse train loss:  0.024317927658557892 , test acc:  80.68\n",
            "epoch:  1 , time elapse:  0.44948816299438477 , mse train loss:  0.023562731221318245 , test acc:  80.8\n",
            "epoch:  1 , time elapse:  0.4543170928955078 , mse train loss:  0.022829942405223846 , test acc:  80.88\n",
            "epoch:  1 , time elapse:  0.45914554595947266 , mse train loss:  0.022119585424661636 , test acc:  80.93\n",
            "epoch:  1 , time elapse:  0.4618206024169922 , mse train loss:  0.0214316938072443 , test acc:  80.92\n",
            "epoch:  1 , time elapse:  0.46758294105529785 , mse train loss:  0.020766353234648705 , test acc:  81.06\n",
            "epoch:  1 , time elapse:  0.4716670513153076 , mse train loss:  0.020122552290558815 , test acc:  81.12\n",
            "epoch:  1 , time elapse:  0.4766554832458496 , mse train loss:  0.01950008049607277 , test acc:  81.18\n",
            "epoch:  1 , time elapse:  0.4789307117462158 , mse train loss:  0.018898891285061836 , test acc:  81.3\n",
            "epoch:  1 , time elapse:  0.483705997467041 , mse train loss:  0.018318716436624527 , test acc:  81.34\n",
            "epoch:  1 , time elapse:  0.4884488582611084 , mse train loss:  0.01775924488902092 , test acc:  81.38\n",
            "epoch:  1 , time elapse:  0.4925541877746582 , mse train loss:  0.01722009852528572 , test acc:  81.37\n",
            "epoch:  1 , time elapse:  0.4975144863128662 , mse train loss:  0.016700908541679382 , test acc:  81.4\n",
            "epoch:  1 , time elapse:  0.5011365413665771 , mse train loss:  0.016201350837945938 , test acc:  81.45\n",
            "epoch:  1 , time elapse:  0.5060639381408691 , mse train loss:  0.015720903873443604 , test acc:  81.49\n",
            "epoch:  1 , time elapse:  0.5113303661346436 , mse train loss:  0.015259797684848309 , test acc:  81.53\n",
            "epoch:  1 , time elapse:  0.5164022445678711 , mse train loss:  0.014817052520811558 , test acc:  81.64\n",
            "epoch:  1 , time elapse:  0.5231077671051025 , mse train loss:  0.014392126351594925 , test acc:  81.7\n",
            "epoch:  1 , time elapse:  0.5276286602020264 , mse train loss:  0.013984544202685356 , test acc:  81.72\n",
            "epoch:  1 , time elapse:  0.5325496196746826 , mse train loss:  0.013593563809990883 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  0.5368883609771729 , mse train loss:  0.013218880631029606 , test acc:  81.67\n",
            "epoch:  1 , time elapse:  0.5437021255493164 , mse train loss:  0.012859945185482502 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  0.5487103462219238 , mse train loss:  0.012516308575868607 , test acc:  81.71\n",
            "epoch:  1 , time elapse:  0.5536816120147705 , mse train loss:  0.01218740176409483 , test acc:  81.76\n",
            "epoch:  1 , time elapse:  0.5578756332397461 , mse train loss:  0.011872715316712856 , test acc:  81.82\n",
            "epoch:  1 , time elapse:  0.5636723041534424 , mse train loss:  0.011571688577532768 , test acc:  81.83\n",
            "epoch:  1 , time elapse:  0.5686709880828857 , mse train loss:  0.011283919215202332 , test acc:  81.8\n",
            "epoch:  1 , time elapse:  0.5727138519287109 , mse train loss:  0.011008848436176777 , test acc:  81.83\n",
            "epoch:  1 , time elapse:  0.5771169662475586 , mse train loss:  0.010745810344815254 , test acc:  81.85\n",
            "epoch:  1 , time elapse:  0.5823864936828613 , mse train loss:  0.010494446381926537 , test acc:  81.87\n",
            "epoch:  1 , time elapse:  0.5883491039276123 , mse train loss:  0.010254357941448689 , test acc:  81.86\n",
            "epoch:  1 , time elapse:  0.5933680534362793 , mse train loss:  0.01002503838390112 , test acc:  81.9\n",
            "epoch:  1 , time elapse:  0.5980827808380127 , mse train loss:  0.00980574730783701 , test acc:  81.92\n",
            "epoch:  1 , time elapse:  0.6007590293884277 , mse train loss:  0.009596080519258976 , test acc:  81.96\n",
            "epoch:  1 , time elapse:  0.6091206073760986 , mse train loss:  0.009395924396812916 , test acc:  81.94\n",
            "epoch:  1 , time elapse:  0.6155509948730469 , mse train loss:  0.009204695001244545 , test acc:  81.93\n",
            "epoch:  1 , time elapse:  0.6208488941192627 , mse train loss:  0.009021971374750137 , test acc:  81.93\n",
            "epoch:  1 , time elapse:  0.6268444061279297 , mse train loss:  0.008847313933074474 , test acc:  81.95\n",
            "epoch:  1 , time elapse:  0.6334187984466553 , mse train loss:  0.008680401369929314 , test acc:  81.93\n",
            "epoch:  1 , time elapse:  0.6401028633117676 , mse train loss:  0.00852101668715477 , test acc:  81.94\n",
            "epoch:  1 , time elapse:  0.6455123424530029 , mse train loss:  0.008368751965463161 , test acc:  81.95\n",
            "epoch:  1 , time elapse:  0.6504726409912109 , mse train loss:  0.008223243057727814 , test acc:  81.96\n",
            "epoch:  1 , time elapse:  0.6534340381622314 , mse train loss:  0.008084123022854328 , test acc:  81.94\n",
            "epoch:  1 , time elapse:  0.6582181453704834 , mse train loss:  0.007951382547616959 , test acc:  81.91\n",
            "epoch:  1 , time elapse:  0.6631166934967041 , mse train loss:  0.007824412547051907 , test acc:  81.9\n",
            "epoch:  1 , time elapse:  0.6671361923217773 , mse train loss:  0.0077029624953866005 , test acc:  81.88\n",
            "epoch:  1 , time elapse:  0.6737496852874756 , mse train loss:  0.007586859166622162 , test acc:  81.92\n",
            "epoch:  1 , time elapse:  0.6771624088287354 , mse train loss:  0.007475902792066336 , test acc:  81.92\n",
            "epoch:  1 , time elapse:  0.681513786315918 , mse train loss:  0.007369794882833958 , test acc:  81.9\n",
            "epoch:  1 , time elapse:  0.6866703033447266 , mse train loss:  0.007268309593200684 , test acc:  81.93\n",
            "epoch:  1 , time elapse:  0.6894793510437012 , mse train loss:  0.007171167526394129 , test acc:  81.94\n",
            "epoch:  1 , time elapse:  0.694260835647583 , mse train loss:  0.007078260648995638 , test acc:  81.91\n",
            "epoch:  1 , time elapse:  0.6991703510284424 , mse train loss:  0.006989287678152323 , test acc:  81.91\n",
            "epoch:  1 , time elapse:  0.7028608322143555 , mse train loss:  0.006904133595526218 , test acc:  81.89\n",
            "epoch:  1 , time elapse:  0.708484411239624 , mse train loss:  0.006822624709457159 , test acc:  81.87\n",
            "epoch:  1 , time elapse:  0.7162830829620361 , mse train loss:  0.006744569167494774 , test acc:  81.9\n",
            "epoch:  1 , time elapse:  0.721198558807373 , mse train loss:  0.006669722031801939 , test acc:  81.89\n",
            "epoch:  1 , time elapse:  0.7257328033447266 , mse train loss:  0.006598060950636864 , test acc:  81.87\n",
            "epoch:  1 , time elapse:  0.7313168048858643 , mse train loss:  0.006529384292662144 , test acc:  81.85\n",
            "epoch:  1 , time elapse:  0.7358360290527344 , mse train loss:  0.006463595665991306 , test acc:  81.85\n",
            "epoch:  1 , time elapse:  0.7409067153930664 , mse train loss:  0.006400511134415865 , test acc:  81.87\n",
            "epoch:  1 , time elapse:  0.7452454566955566 , mse train loss:  0.006340013816952705 , test acc:  81.87\n",
            "epoch:  1 , time elapse:  0.7502107620239258 , mse train loss:  0.006282012443989515 , test acc:  81.85\n",
            "epoch:  1 , time elapse:  0.755840539932251 , mse train loss:  0.006226368248462677 , test acc:  81.86\n",
            "epoch:  1 , time elapse:  0.7623872756958008 , mse train loss:  0.006172908004373312 , test acc:  81.88\n",
            "epoch:  1 , time elapse:  0.7674229145050049 , mse train loss:  0.00612157816067338 , test acc:  81.87\n",
            "epoch:  1 , time elapse:  0.7702353000640869 , mse train loss:  0.006072317250072956 , test acc:  81.86\n",
            "epoch:  1 , time elapse:  0.7753195762634277 , mse train loss:  0.006024980917572975 , test acc:  81.86\n",
            "epoch:  1 , time elapse:  0.782263994216919 , mse train loss:  0.005979450885206461 , test acc:  81.87\n",
            "epoch:  1 , time elapse:  0.787177324295044 , mse train loss:  0.00593568803742528 , test acc:  81.88\n",
            "epoch:  1 , time elapse:  0.7922718524932861 , mse train loss:  0.0058936262503266335 , test acc:  81.85\n",
            "epoch:  1 , time elapse:  0.7972145080566406 , mse train loss:  0.005853163078427315 , test acc:  81.86\n",
            "epoch:  1 , time elapse:  0.8009853363037109 , mse train loss:  0.005814183969050646 , test acc:  81.85\n",
            "epoch:  1 , time elapse:  0.8060624599456787 , mse train loss:  0.0057766935788095 , test acc:  81.85\n",
            "epoch:  1 , time elapse:  0.8100752830505371 , mse train loss:  0.005740553606301546 , test acc:  81.84\n",
            "epoch:  1 , time elapse:  0.8150556087493896 , mse train loss:  0.005705740302801132 , test acc:  81.84\n",
            "epoch:  1 , time elapse:  0.8184788227081299 , mse train loss:  0.005672219209372997 , test acc:  81.83\n",
            "epoch:  1 , time elapse:  0.8228394985198975 , mse train loss:  0.00563985388725996 , test acc:  81.82\n",
            "epoch:  1 , time elapse:  0.8278713226318359 , mse train loss:  0.005608629900962114 , test acc:  81.83\n",
            "epoch:  1 , time elapse:  0.8329677581787109 , mse train loss:  0.005578490439802408 , test acc:  81.83\n",
            "epoch:  1 , time elapse:  0.8382699489593506 , mse train loss:  0.0055493880063295364 , test acc:  81.82\n",
            "epoch:  1 , time elapse:  0.8433837890625 , mse train loss:  0.0055212401784956455 , test acc:  81.83\n",
            "epoch:  1 , time elapse:  0.8483350276947021 , mse train loss:  0.0054940455593168736 , test acc:  81.84\n",
            "epoch:  1 , time elapse:  0.8534848690032959 , mse train loss:  0.0054677436128258705 , test acc:  81.84\n",
            "epoch:  1 , time elapse:  0.8589048385620117 , mse train loss:  0.00544227659702301 , test acc:  81.84\n",
            "epoch:  1 , time elapse:  0.8627116680145264 , mse train loss:  0.00541763985529542 , test acc:  81.83\n",
            "epoch:  1 , time elapse:  0.8679139614105225 , mse train loss:  0.005393763072788715 , test acc:  81.84\n",
            "epoch:  1 , time elapse:  0.8730409145355225 , mse train loss:  0.00537063367664814 , test acc:  81.84\n",
            "epoch:  1 , time elapse:  0.8781676292419434 , mse train loss:  0.005348210223019123 , test acc:  81.86\n",
            "epoch:  1 , time elapse:  0.885575532913208 , mse train loss:  0.005326446611434221 , test acc:  81.84\n",
            "epoch:  1 , time elapse:  0.8924365043640137 , mse train loss:  0.005305327009409666 , test acc:  81.83\n",
            "epoch:  1 , time elapse:  0.8976538181304932 , mse train loss:  0.0052848076447844505 , test acc:  81.83\n",
            "epoch:  1 , time elapse:  0.9026596546173096 , mse train loss:  0.005264870822429657 , test acc:  81.82\n",
            "epoch:  1 , time elapse:  0.908684492111206 , mse train loss:  0.005245490465313196 , test acc:  81.83\n",
            "epoch:  1 , time elapse:  0.9150362014770508 , mse train loss:  0.005226644221693277 , test acc:  81.82\n",
            "epoch:  1 , time elapse:  0.9198753833770752 , mse train loss:  0.005208300892263651 , test acc:  81.83\n",
            "epoch:  1 , time elapse:  0.9263355731964111 , mse train loss:  0.0051904236897826195 , test acc:  81.85\n",
            "epoch:  1 , time elapse:  0.9313070774078369 , mse train loss:  0.0051730177365243435 , test acc:  81.87\n",
            "epoch:  1 , time elapse:  0.9351460933685303 , mse train loss:  0.005156038794666529 , test acc:  81.86\n",
            "epoch:  1 , time elapse:  0.9400229454040527 , mse train loss:  0.005139497574418783 , test acc:  81.87\n",
            "epoch:  1 , time elapse:  0.944835901260376 , mse train loss:  0.005123347043991089 , test acc:  81.88\n",
            "epoch:  1 , time elapse:  0.9499485492706299 , mse train loss:  0.005107592325657606 , test acc:  81.88\n",
            "epoch:  1 , time elapse:  0.9556922912597656 , mse train loss:  0.005092186853289604 , test acc:  81.88\n",
            "epoch:  1 , time elapse:  0.9609172344207764 , mse train loss:  0.00507711386308074 , test acc:  81.88\n",
            "epoch:  1 , time elapse:  0.9652085304260254 , mse train loss:  0.005062384065240622 , test acc:  81.88\n",
            "epoch:  1 , time elapse:  0.9702198505401611 , mse train loss:  0.005047950427979231 , test acc:  81.86\n",
            "epoch:  1 , time elapse:  0.9767940044403076 , mse train loss:  0.005033804103732109 , test acc:  81.85\n",
            "epoch:  1 , time elapse:  0.9830596446990967 , mse train loss:  0.0050199502147734165 , test acc:  81.84\n",
            "epoch:  1 , time elapse:  0.9881370067596436 , mse train loss:  0.005006375256925821 , test acc:  81.84\n",
            "epoch:  1 , time elapse:  0.9945354461669922 , mse train loss:  0.004993060603737831 , test acc:  81.85\n",
            "epoch:  1 , time elapse:  1.0003535747528076 , mse train loss:  0.004979982972145081 , test acc:  81.85\n",
            "epoch:  1 , time elapse:  1.0061218738555908 , mse train loss:  0.004967140033841133 , test acc:  81.85\n",
            "epoch:  1 , time elapse:  1.011518955230713 , mse train loss:  0.004954525735229254 , test acc:  81.83\n",
            "epoch:  1 , time elapse:  1.0164496898651123 , mse train loss:  0.004942137282341719 , test acc:  81.83\n",
            "epoch:  1 , time elapse:  1.0225856304168701 , mse train loss:  0.004929941147565842 , test acc:  81.82\n",
            "epoch:  1 , time elapse:  1.0255587100982666 , mse train loss:  0.004917937330901623 , test acc:  81.82\n",
            "epoch:  1 , time elapse:  1.031374216079712 , mse train loss:  0.004906127694994211 , test acc:  81.82\n",
            "epoch:  1 , time elapse:  1.036564826965332 , mse train loss:  0.004894488025456667 , test acc:  81.82\n",
            "epoch:  1 , time elapse:  1.039318561553955 , mse train loss:  0.004883010871708393 , test acc:  81.81\n",
            "epoch:  1 , time elapse:  1.0420467853546143 , mse train loss:  0.004871703684329987 , test acc:  81.81\n",
            "epoch:  1 , time elapse:  1.0468065738677979 , mse train loss:  0.0048605469055473804 , test acc:  81.8\n",
            "epoch:  1 , time elapse:  1.0519521236419678 , mse train loss:  0.004849541001021862 , test acc:  81.79\n",
            "epoch:  1 , time elapse:  1.055185317993164 , mse train loss:  0.0048386696726083755 , test acc:  81.79\n",
            "epoch:  1 , time elapse:  1.061082124710083 , mse train loss:  0.004827938508242369 , test acc:  81.79\n",
            "epoch:  1 , time elapse:  1.0658490657806396 , mse train loss:  0.004817333538085222 , test acc:  81.79\n",
            "epoch:  1 , time elapse:  1.0708155632019043 , mse train loss:  0.004806858953088522 , test acc:  81.79\n",
            "epoch:  1 , time elapse:  1.0748519897460938 , mse train loss:  0.004796494729816914 , test acc:  81.79\n",
            "epoch:  1 , time elapse:  1.0801188945770264 , mse train loss:  0.004786239936947823 , test acc:  81.8\n",
            "epoch:  1 , time elapse:  1.0857844352722168 , mse train loss:  0.004776100628077984 , test acc:  81.79\n",
            "epoch:  1 , time elapse:  1.093151330947876 , mse train loss:  0.004766048863530159 , test acc:  81.8\n",
            "epoch:  1 , time elapse:  1.0980584621429443 , mse train loss:  0.004756108392030001 , test acc:  81.8\n",
            "epoch:  1 , time elapse:  1.1006395816802979 , mse train loss:  0.004746251739561558 , test acc:  81.79\n",
            "epoch:  1 , time elapse:  1.1041033267974854 , mse train loss:  0.004736481700092554 , test acc:  81.78\n",
            "epoch:  1 , time elapse:  1.1090407371520996 , mse train loss:  0.004726790823042393 , test acc:  81.8\n",
            "epoch:  1 , time elapse:  1.1140763759613037 , mse train loss:  0.004717170260846615 , test acc:  81.8\n",
            "epoch:  1 , time elapse:  1.1189303398132324 , mse train loss:  0.004707622341811657 , test acc:  81.8\n",
            "epoch:  1 , time elapse:  1.1254303455352783 , mse train loss:  0.004698135890066624 , test acc:  81.79\n",
            "epoch:  1 , time elapse:  1.1320946216583252 , mse train loss:  0.004688723478466272 , test acc:  81.79\n",
            "epoch:  1 , time elapse:  1.1371769905090332 , mse train loss:  0.004679370205849409 , test acc:  81.8\n",
            "epoch:  1 , time elapse:  1.141986608505249 , mse train loss:  0.004670075140893459 , test acc:  81.8\n",
            "epoch:  1 , time elapse:  1.147745132446289 , mse train loss:  0.004660838283598423 , test acc:  81.8\n",
            "epoch:  1 , time elapse:  1.1552026271820068 , mse train loss:  0.004651656839996576 , test acc:  81.8\n",
            "epoch:  1 , time elapse:  1.1607186794281006 , mse train loss:  0.00464252196252346 , test acc:  81.79\n",
            "epoch:  1 , time elapse:  1.1632490158081055 , mse train loss:  0.004633442498743534 , test acc:  81.77\n",
            "epoch:  1 , time elapse:  1.1684534549713135 , mse train loss:  0.004624409135431051 , test acc:  81.77\n",
            "epoch:  1 , time elapse:  1.1728487014770508 , mse train loss:  0.00461542047560215 , test acc:  81.76\n",
            "epoch:  1 , time elapse:  1.1770217418670654 , mse train loss:  0.00460647651925683 , test acc:  81.76\n",
            "epoch:  1 , time elapse:  1.181854486465454 , mse train loss:  0.004597574472427368 , test acc:  81.77\n",
            "epoch:  1 , time elapse:  1.1871402263641357 , mse train loss:  0.004588708281517029 , test acc:  81.77\n",
            "epoch:  1 , time elapse:  1.189723253250122 , mse train loss:  0.0045798770152032375 , test acc:  81.77\n",
            "epoch:  1 , time elapse:  1.1948025226593018 , mse train loss:  0.0045710778795182705 , test acc:  81.76\n",
            "epoch:  1 , time elapse:  1.199936866760254 , mse train loss:  0.0045623076148331165 , test acc:  81.76\n",
            "epoch:  1 , time elapse:  1.2040221691131592 , mse train loss:  0.004553578794002533 , test acc:  81.76\n",
            "epoch:  1 , time elapse:  1.2088263034820557 , mse train loss:  0.004544876050204039 , test acc:  81.76\n",
            "epoch:  1 , time elapse:  1.2141082286834717 , mse train loss:  0.0045361886732280254 , test acc:  81.75\n",
            "epoch:  1 , time elapse:  1.2191994190216064 , mse train loss:  0.004527526441961527 , test acc:  81.75\n",
            "epoch:  1 , time elapse:  1.224320411682129 , mse train loss:  0.004518888890743256 , test acc:  81.75\n",
            "epoch:  1 , time elapse:  1.2269072532653809 , mse train loss:  0.004510275088250637 , test acc:  81.75\n",
            "epoch:  1 , time elapse:  1.234001874923706 , mse train loss:  0.004501677118241787 , test acc:  81.74\n",
            "epoch:  1 , time elapse:  1.2403247356414795 , mse train loss:  0.004493105225265026 , test acc:  81.73\n",
            "epoch:  1 , time elapse:  1.2455086708068848 , mse train loss:  0.004484544508159161 , test acc:  81.73\n",
            "epoch:  1 , time elapse:  1.2520599365234375 , mse train loss:  0.004475998226553202 , test acc:  81.74\n",
            "epoch:  1 , time elapse:  1.2557275295257568 , mse train loss:  0.004467471037060022 , test acc:  81.74\n",
            "epoch:  1 , time elapse:  1.2599265575408936 , mse train loss:  0.004458967130631208 , test acc:  81.75\n",
            "epoch:  1 , time elapse:  1.2650845050811768 , mse train loss:  0.004450464621186256 , test acc:  81.71\n",
            "epoch:  1 , time elapse:  1.271383285522461 , mse train loss:  0.004441975150257349 , test acc:  81.71\n",
            "epoch:  1 , time elapse:  1.277282953262329 , mse train loss:  0.004433503840118647 , test acc:  81.72\n",
            "epoch:  1 , time elapse:  1.2835688591003418 , mse train loss:  0.004425043240189552 , test acc:  81.73\n",
            "epoch:  1 , time elapse:  1.2888567447662354 , mse train loss:  0.004416596610099077 , test acc:  81.7\n",
            "epoch:  1 , time elapse:  1.2938611507415771 , mse train loss:  0.00440815556794405 , test acc:  81.7\n",
            "epoch:  1 , time elapse:  1.2977888584136963 , mse train loss:  0.004399731755256653 , test acc:  81.7\n",
            "epoch:  1 , time elapse:  1.3006353378295898 , mse train loss:  0.004391308408230543 , test acc:  81.7\n",
            "epoch:  1 , time elapse:  1.3051707744598389 , mse train loss:  0.004382906015962362 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.309051752090454 , mse train loss:  0.004374500829726458 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.313920259475708 , mse train loss:  0.004366109147667885 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.32102632522583 , mse train loss:  0.004357724916189909 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.3259141445159912 , mse train loss:  0.004349346738308668 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.330209493637085 , mse train loss:  0.004340980667620897 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.3362007141113281 , mse train loss:  0.0043326131999492645 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.3388950824737549 , mse train loss:  0.004324252717196941 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.3455233573913574 , mse train loss:  0.004315895494073629 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.3504900932312012 , mse train loss:  0.004307546652853489 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.3553218841552734 , mse train loss:  0.004299185238778591 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.3582892417907715 , mse train loss:  0.004290817771106958 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.362389087677002 , mse train loss:  0.004282444715499878 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.3674421310424805 , mse train loss:  0.004274075385183096 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.3723642826080322 , mse train loss:  0.004265695810317993 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.3772857189178467 , mse train loss:  0.0042573255486786366 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  1.3806638717651367 , mse train loss:  0.004248952027410269 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.3860716819763184 , mse train loss:  0.00424057524651289 , test acc:  81.7\n",
            "epoch:  1 , time elapse:  1.3908417224884033 , mse train loss:  0.004232205916196108 , test acc:  81.71\n",
            "epoch:  1 , time elapse:  1.3956029415130615 , mse train loss:  0.004223840311169624 , test acc:  81.71\n",
            "epoch:  1 , time elapse:  1.4013667106628418 , mse train loss:  0.004215466324239969 , test acc:  81.71\n",
            "epoch:  1 , time elapse:  1.4063329696655273 , mse train loss:  0.004207102581858635 , test acc:  81.7\n",
            "epoch:  1 , time elapse:  1.4109432697296143 , mse train loss:  0.004198750015348196 , test acc:  81.7\n",
            "epoch:  1 , time elapse:  1.4162921905517578 , mse train loss:  0.004190403036773205 , test acc:  81.7\n",
            "epoch:  1 , time elapse:  1.4219021797180176 , mse train loss:  0.004182048141956329 , test acc:  81.71\n",
            "epoch:  1 , time elapse:  1.4281105995178223 , mse train loss:  0.004173702094703913 , test acc:  81.71\n",
            "epoch:  1 , time elapse:  1.4326183795928955 , mse train loss:  0.004165357444435358 , test acc:  81.71\n",
            "epoch:  1 , time elapse:  1.437654733657837 , mse train loss:  0.004157006274908781 , test acc:  81.71\n",
            "epoch:  1 , time elapse:  1.443333387374878 , mse train loss:  0.004148662090301514 , test acc:  81.71\n",
            "epoch:  1 , time elapse:  1.4492628574371338 , mse train loss:  0.004140318371355534 , test acc:  81.72\n",
            "epoch:  1 , time elapse:  1.4543533325195312 , mse train loss:  0.004131978377699852 , test acc:  81.72\n",
            "epoch:  1 , time elapse:  1.4586892127990723 , mse train loss:  0.004123637452721596 , test acc:  81.72\n",
            "epoch:  1 , time elapse:  1.464128017425537 , mse train loss:  0.004115303512662649 , test acc:  81.72\n",
            "epoch:  1 , time elapse:  1.470947265625 , mse train loss:  0.0041069709695875645 , test acc:  81.72\n",
            "epoch:  1 , time elapse:  1.4759933948516846 , mse train loss:  0.004098638892173767 , test acc:  81.72\n",
            "epoch:  1 , time elapse:  1.4810597896575928 , mse train loss:  0.004090301692485809 , test acc:  81.72\n",
            "epoch:  1 , time elapse:  1.4885318279266357 , mse train loss:  0.004081968683749437 , test acc:  81.72\n",
            "epoch:  1 , time elapse:  1.4942448139190674 , mse train loss:  0.00407363148406148 , test acc:  81.7\n",
            "epoch:  1 , time elapse:  1.4991645812988281 , mse train loss:  0.004065294750034809 , test acc:  81.7\n",
            "epoch:  1 , time elapse:  1.5034129619598389 , mse train loss:  0.004056956619024277 , test acc:  81.7\n",
            "epoch:  1 , time elapse:  1.5083491802215576 , mse train loss:  0.004048620350658894 , test acc:  81.7\n",
            "epoch:  1 , time elapse:  1.5113391876220703 , mse train loss:  0.004040297586470842 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.5153210163116455 , mse train loss:  0.0040319617837667465 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.5205490589141846 , mse train loss:  0.004023627378046513 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.5256123542785645 , mse train loss:  0.004015303682535887 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  1.5324556827545166 , mse train loss:  0.004006983712315559 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  1.5386388301849365 , mse train loss:  0.003998663276433945 , test acc:  81.7\n",
            "epoch:  1 , time elapse:  1.5438714027404785 , mse train loss:  0.00399034796282649 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.5507092475891113 , mse train loss:  0.003982027061283588 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.556168556213379 , mse train loss:  0.0039737108163535595 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  1.561344861984253 , mse train loss:  0.003965400159358978 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  1.5663485527038574 , mse train loss:  0.003957090899348259 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  1.5715341567993164 , mse train loss:  0.0039487904869019985 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  1.5776195526123047 , mse train loss:  0.003940490540117025 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  1.5831866264343262 , mse train loss:  0.003932190593332052 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.5881164073944092 , mse train loss:  0.003923894837498665 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.59309720993042 , mse train loss:  0.003915604203939438 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.598041296005249 , mse train loss:  0.0039073205552995205 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.604302167892456 , mse train loss:  0.003899033647030592 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.609501600265503 , mse train loss:  0.0038907555863261223 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.6145358085632324 , mse train loss:  0.0038824849762022495 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  1.6197624206542969 , mse train loss:  0.0038742104079574347 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  1.6256263256072998 , mse train loss:  0.003865940496325493 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  1.6307032108306885 , mse train loss:  0.0038576775696128607 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  1.6356117725372314 , mse train loss:  0.0038494153413921595 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  1.6406033039093018 , mse train loss:  0.0038411577697843313 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  1.6433029174804688 , mse train loss:  0.0038328985683619976 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  1.6500194072723389 , mse train loss:  0.003824650077149272 , test acc:  81.68\n",
            "epoch:  1 , time elapse:  1.6550920009613037 , mse train loss:  0.0038164006546139717 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.6601786613464355 , mse train loss:  0.00380814541131258 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.6628954410552979 , mse train loss:  0.0037998936604708433 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.6677522659301758 , mse train loss:  0.0037916405126452446 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.6726188659667969 , mse train loss:  0.003783399472013116 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.676725149154663 , mse train loss:  0.003775154473260045 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.6792337894439697 , mse train loss:  0.003766910871490836 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.6818664073944092 , mse train loss:  0.003758662845939398 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.686838150024414 , mse train loss:  0.0037504248321056366 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.691166877746582 , mse train loss:  0.003742196364328265 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.6962904930114746 , mse train loss:  0.003733959048986435 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.7026927471160889 , mse train loss:  0.003725733608007431 , test acc:  81.69\n",
            "epoch:  1 , time elapse:  1.705273151397705 , mse train loss:  0.0037175107281655073 , test acc:  81.69\n"
          ]
        }
      ],
      "source": [
        "import time\n",
        "import copy\n",
        "# Start the timer\n",
        "training_time = 0\n",
        "\n",
        "# Training loop\n",
        "# for epoch in range(num_epochs):\n",
        "epoch = 0\n",
        "best_test_accuracy = 0\n",
        "best_epoch = 0\n",
        "best_model = None\n",
        "while training_time < time_budget:\n",
        "\n",
        "    start_time = time.time()\n",
        "    distil_model.train()\n",
        "\n",
        "    optimizer.zero_grad()\n",
        "    output = distil_model(X_train_SCNN_new)\n",
        "    loss = loss_function(output, y_train_SCNN_new)\n",
        "    temp_loss = loss.item()\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    epoch_time = time.time() - start_time\n",
        "    training_time += epoch_time\n",
        "\n",
        "    train_mse = torch.mean((output-y_train_SCNN_new)**2).item()\n",
        "\n",
        "    # Check if the time budget has been exceeded\n",
        "    if epoch%100 == 0:\n",
        "        temp_acc = report_test_accuracy(distil_model)\n",
        "        if best_test_accuracy < temp_acc:\n",
        "            best_test_accuracy = temp_acc\n",
        "            best_epoch = epoch\n",
        "            best_model = copy.deepcopy(distil_model)\n",
        "        if epoch%1000 == 0:\n",
        "            print(\"epoch: \", epoch+1, \", time elapse: \", training_time, \", mse train loss: \", train_mse, \", test acc: \", temp_acc)\n",
        "# Continue with further code, such as evaluation or saving the model\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "70J-_eHSygVy"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "import copy\n",
        "# Start the timer\n",
        "training_time = 0\n",
        "\n",
        "# Training loop\n",
        "# for epoch in range(num_epochs):\n",
        "epoch = 0\n",
        "best_test_accuracy = 0\n",
        "best_epoch = 0\n",
        "best_model = None\n",
        "while training_time < time_budget:\n",
        "\n",
        "    start_time = time.time()\n",
        "    distil_model.train()\n",
        "\n",
        "    optimizer.zero_grad()\n",
        "    output = distil_model(X_train_SCNN_new)\n",
        "    loss = loss_function(output, y_train_SCNN_new)\n",
        "    temp_loss = loss.item()\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    epoch_time = time.time() - start_time\n",
        "    training_time += epoch_time\n",
        "\n",
        "    train_mse = torch.mean((output-y_train_SCNN_new)**2).item()\n",
        "\n",
        "    # Check if the time budget has been exceeded\n",
        "    if epoch%100 == 0:\n",
        "        temp_acc = report_test_accuracy(distil_model)\n",
        "        if best_test_accuracy < temp_acc:\n",
        "            best_test_accuracy = temp_acc\n",
        "            best_epoch = epoch\n",
        "            best_model = copy.deepcopy(distil_model)\n",
        "        if epoch%1000 == 0:\n",
        "            print(\"epoch: \", epoch+1, \", time elapse: \", training_time, \", mse train loss: \", train_mse, \", test acc: \", temp_acc)\n",
        "# Continue with further code, such as evaluation or saving the model\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xZtUClDWWXE8"
      },
      "outputs": [],
      "source": [
        "best_test_accuracy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "86aI6iaWayd3",
        "outputId": "626d39ea-bc99-4686-c48a-ba480ca3b86a"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "0.14885425567626953"
            ]
          },
          "execution_count": 159,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "training_time"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "h_CEmGPLTdyG",
        "outputId": "c21fc7c2-c58d-41c3-dad7-3094b3eda371"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "73.61"
            ]
          },
          "execution_count": 52,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "report_test_accuracy(best_model)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
